From c77a8196537ad6c956e213abc2c16bf4c388ad9a Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Wed, 11 Jun 2025 18:04:57 +0200 Subject: [PATCH 01/23] Added mat_znx_dft_mul_x_pow_minus_one --- backend/spqlios-arithmetic | 2 +- backend/src/mat_znx_dft_ops.rs | 214 +++++++++++++++++++++++++++++---- backend/src/vec_znx_dft_ops.rs | 96 +++++++++++++++ core/src/automorphism.rs | 4 +- core/src/gglwe_ciphertext.rs | 4 +- core/src/ggsw_ciphertext.rs | 8 +- core/src/keyswitch_key.rs | 4 +- 7 files changed, 297 insertions(+), 35 deletions(-) diff --git a/backend/spqlios-arithmetic b/backend/spqlios-arithmetic index 173b980..0ae9a7b 160000 --- a/backend/spqlios-arithmetic +++ b/backend/spqlios-arithmetic @@ -1 +1 @@ -Subproject commit 173b980c7b8a4f0523d04c2aed061c2e046e846c +Subproject commit 0ae9a7b5adf07ce0b1797562528dab8e28192238 diff --git a/backend/src/mat_znx_dft_ops.rs b/backend/src/mat_znx_dft_ops.rs index 9656dfb..9ed71a0 100644 --- a/backend/src/mat_znx_dft_ops.rs +++ b/backend/src/mat_znx_dft_ops.rs @@ -2,8 +2,8 @@ use crate::ffi::vec_znx_dft::vec_znx_dft_t; use crate::ffi::vmp; use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; use crate::{ - Backend, FFT64, MatZnxDft, MatZnxDftOwned, MatZnxToMut, MatZnxToRef, Module, Scratch, VecZnxDft, VecZnxDftToMut, - VecZnxDftToRef, + Backend, FFT64, MatZnxDft, MatZnxDftOwned, MatZnxToMut, MatZnxToRef, Module, ScalarZnxAlloc, ScalarZnxDftAlloc, + ScalarZnxDftOps, Scratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, ZnxZero, }; pub trait MatZnxDftAlloc { @@ -38,6 +38,8 @@ pub trait MatZnxDftScratch { b_cols_out: usize, b_size: usize, ) -> usize; + + fn mat_znx_dft_mul_x_pow_minus_one_scratch_space(&self, size: usize, cols_out: usize) -> usize; } /// This trait implements methods for vector matrix product, @@ -52,7 +54,7 @@ pub trait MatZnxDftOps { /// * `row_i`: the index of the row to prepare. /// /// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes]. - fn vmp_prepare_row(&self, res: &mut R, res_row: usize, res_col_in: usize, a: &A) + fn mat_znx_dft_set_row(&self, res: &mut R, res_row: usize, res_col_in: usize, a: &A) where R: MatZnxToMut, A: VecZnxDftToRef; @@ -64,11 +66,22 @@ pub trait MatZnxDftOps { /// * `res`: the [VecZnxDft] to on which to extract the row of the [MatZnxDft]. /// * `a`: [MatZnxDft] on which the values are encoded. /// * `row_i`: the index of the row to extract. - fn vmp_extract_row(&self, res: &mut R, a: &A, a_row: usize, a_col_in: usize) + fn mat_znx_dft_get_row(&self, res: &mut R, a: &A, a_row: usize, a_col_in: usize) where R: VecZnxDftToMut, A: MatZnxToRef; + /// Multiplies A by (X^{k} - 1) and stores the result on R. + fn mat_znx_dft_mul_x_pow_minus_one(&self, k: i64, res: &mut R, a: &A, scratch: &mut Scratch) + where + R: MatZnxToMut, + A: MatZnxToRef; + + /// Multiplies A by (X^{k} - 1). + fn mat_znx_dft_mul_x_pow_minus_one_inplace(&self, k: i64, a: &mut A, scratch: &mut Scratch) + where + A: MatZnxToMut; + /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft]. /// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. /// @@ -149,10 +162,97 @@ impl MatZnxDftScratch for Module { ) as usize } } + + fn mat_znx_dft_mul_x_pow_minus_one_scratch_space(&self, size: usize, cols_out: usize) -> usize { + let xpm1_dft: usize = self.bytes_of_scalar_znx(1); + let xpm1: usize = self.bytes_of_scalar_znx_dft(1); + let tmp: usize = self.bytes_of_vec_znx_dft(cols_out, size); + xpm1_dft + (xpm1 | 2 * tmp) + } } impl MatZnxDftOps for Module { - fn vmp_prepare_row(&self, res: &mut R, res_row: usize, res_col_in: usize, a: &A) + fn mat_znx_dft_mul_x_pow_minus_one(&self, k: i64, res: &mut R, a: &A, scratch: &mut Scratch) + where + R: MatZnxToMut, + A: MatZnxToRef, + { + let mut res: MatZnxDft<&mut [u8], FFT64> = res.to_mut(); + let a: MatZnxDft<&[u8], FFT64> = a.to_ref(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.n(), self.n()); + assert_eq!(a.n(), self.n()); + assert_eq!(res.rows(), a.rows()); + assert_eq!(res.cols_in(), a.cols_in()); + assert_eq!(res.cols_out(), a.cols_out()); + } + + let (mut xpm1_dft, scratch1) = scratch.tmp_scalar_znx_dft(self, 1); + + { + let (mut xpm1, _) = scratch1.tmp_scalar_znx(self, 1); + xpm1.data[0] = 1; + self.vec_znx_rotate_inplace(k, &mut xpm1, 0); + self.svp_prepare(&mut xpm1_dft, 0, &xpm1, 0); + } + + let (mut tmp_0, scratch2) = scratch1.tmp_vec_znx_dft(self, res.cols_out(), res.size()); + let (mut tmp_1, _) = scratch2.tmp_vec_znx_dft(self, res.cols_out(), res.size()); + + (0..res.rows()).for_each(|row_i| { + (0..res.cols_in()).for_each(|col_j| { + self.mat_znx_dft_get_row(&mut tmp_0, &a, row_i, col_j); + + (0..tmp_0.cols()).for_each(|i| { + self.svp_apply(&mut tmp_1, i, &xpm1_dft, 0, &tmp_0, i); + self.vec_znx_dft_sub_ab_inplace(&mut tmp_1, i, &tmp_0, i); + }); + + self.mat_znx_dft_set_row(&mut res, row_i, col_j, &tmp_1); + }); + }) + } + + fn mat_znx_dft_mul_x_pow_minus_one_inplace(&self, k: i64, a: &mut A, scratch: &mut Scratch) + where + A: MatZnxToMut, + { + let mut a: MatZnxDft<&mut [u8], FFT64> = a.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + } + + let (mut xpm1_dft, scratch1) = scratch.tmp_scalar_znx_dft(self, 1); + + { + let (mut xpm1, _) = scratch1.tmp_scalar_znx(self, 1); + xpm1.data[0] = 1; + self.vec_znx_rotate_inplace(k, &mut xpm1, 0); + self.svp_prepare(&mut xpm1_dft, 0, &xpm1, 0); + } + + let (mut tmp_0, scratch2) = scratch1.tmp_vec_znx_dft(self, a.cols_out(), a.size()); + let (mut tmp_1, _) = scratch2.tmp_vec_znx_dft(self, a.cols_out(), a.size()); + + (0..a.rows()).for_each(|row_i| { + (0..a.cols_in()).for_each(|col_j| { + self.mat_znx_dft_get_row(&mut tmp_0, &a, row_i, col_j); + + (0..tmp_0.cols()).for_each(|i| { + self.svp_apply(&mut tmp_1, i, &xpm1_dft, 0, &tmp_0, i); + self.vec_znx_dft_sub_ab_inplace(&mut tmp_1, i, &tmp_0, i); + }); + + self.mat_znx_dft_set_row(&mut a, row_i, col_j, &tmp_1); + }); + }) + } + + fn mat_znx_dft_set_row(&self, res: &mut R, res_row: usize, res_col_in: usize, a: &A) where R: MatZnxToMut, A: VecZnxDftToRef, @@ -204,7 +304,7 @@ impl MatZnxDftOps for Module { } } - fn vmp_extract_row(&self, res: &mut R, a: &A, a_row: usize, a_col_in: usize) + fn mat_znx_dft_get_row(&self, res: &mut R, a: &A, a_row: usize, a_col_in: usize) where R: VecZnxDftToMut, A: MatZnxToRef, @@ -376,7 +476,7 @@ mod tests { use super::{MatZnxDftAlloc, MatZnxDftScratch}; #[test] - fn vmp_prepare_row() { + fn vmp_set_row() { let module: Module = Module::::new(16); let basek: usize = 8; let mat_rows: usize = 4; @@ -395,8 +495,8 @@ mod tests { a.fill_uniform(basek, col_out, mat_size, &mut source); module.vec_znx_dft(1, 0, &mut a_dft, col_out, &a, col_out); }); - module.vmp_prepare_row(&mut mat, row_i, col_in, &a_dft); - module.vmp_extract_row(&mut b_dft, &mat, row_i, col_in); + module.mat_znx_dft_set_row(&mut mat, row_i, col_in, &a_dft); + module.mat_znx_dft_get_row(&mut b_dft, &mat, row_i, col_in); assert_eq!(a_dft.raw(), b_dft.raw()); } } @@ -413,10 +513,10 @@ mod tests { let mat_size: usize = 6; let res_size: usize = a_size; - [1, 2].iter().for_each(|in_cols| { - [1, 2].iter().for_each(|out_cols| { - let a_cols: usize = *in_cols; - let res_cols: usize = *out_cols; + [1, 2].iter().for_each(|cols_in| { + [1, 2].iter().for_each(|cols_out| { + let a_cols: usize = *cols_in; + let res_cols: usize = *cols_out; let mat_rows: usize = a_size; let mat_cols_in: usize = a_cols; @@ -456,7 +556,7 @@ mod tests { module.vec_znx_dft(1, 0, &mut c_dft, col_out_i, &tmp, col_out_i); tmp.at_mut(col_out_i, row_i)[idx] = 0 as i64; }); - module.vmp_prepare_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft); + module.mat_znx_dft_set_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft); }); }); @@ -499,11 +599,11 @@ mod tests { let res_size: usize = a_size; let mut source: Source = Source::new([0u8; 32]); - [1, 2].iter().for_each(|in_cols| { - [1, 2].iter().for_each(|out_cols| { + [1, 2].iter().for_each(|cols_in| { + [1, 2].iter().for_each(|cols_out| { (0..res_size).for_each(|shift| { - let a_cols: usize = *in_cols; - let res_cols: usize = *out_cols; + let a_cols: usize = *cols_in; + let res_cols: usize = *cols_out; let mat_rows: usize = a_size; let mat_cols_in: usize = a_cols; @@ -543,7 +643,7 @@ mod tests { module.vec_znx_dft(1, 0, &mut c_dft, col_out_i, &tmp, col_out_i); tmp.at_mut(col_out_i, row_i)[idx] = 0 as i64; }); - module.vmp_prepare_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft); + module.mat_znx_dft_set_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft); }); }); @@ -601,13 +701,13 @@ mod tests { let mat_size: usize = 6; let res_size: usize = a_size; - [1, 2].iter().for_each(|in_cols| { - [1, 2].iter().for_each(|out_cols| { + [1, 2].iter().for_each(|cols_in| { + [1, 2].iter().for_each(|cols_out| { [1, 3, 6].iter().for_each(|digits| { let mut source: Source = Source::new([0u8; 32]); - let a_cols: usize = *in_cols; - let res_cols: usize = *out_cols; + let a_cols: usize = *cols_in; + let res_cols: usize = *cols_out; let mat_rows: usize = a_size; let mat_cols_in: usize = a_cols; @@ -652,7 +752,7 @@ mod tests { module.vec_znx_dft(1, 0, &mut c_dft, col_out_i, &tmp, col_out_i); tmp.at_mut(col_out_i, limb)[idx] = 0 as i64; }); - module.vmp_prepare_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft); + module.mat_znx_dft_set_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft); }); }); @@ -697,4 +797,70 @@ mod tests { }); }); } + + #[test] + fn mat_znx_dft_mul_x_pow_minus_one() { + let log_n: i32 = 5; + let n: usize = 1 << log_n; + + let module: Module = Module::::new(n); + let basek: usize = 8; + let rows: usize = 2; + let cols_in: usize = 2; + let cols_out: usize = 2; + let size: usize = 4; + + let mut scratch: ScratchOwned = ScratchOwned::new(module.mat_znx_dft_mul_x_pow_minus_one_scratch_space(size, cols_out)); + + let mut mat_want: MatZnxDft, FFT64> = module.new_mat_znx_dft(rows, cols_in, cols_out, size); + let mut mat_have: MatZnxDft, FFT64> = module.new_mat_znx_dft(rows, cols_in, cols_out, size); + + let mut tmp: VecZnx> = module.new_vec_znx(1, size); + let mut tmp_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(cols_out, size); + + let mut source: Source = Source::new([0u8; 32]); + + (0..mat_want.rows()).for_each(|row_i| { + (0..mat_want.cols_in()).for_each(|col_i| { + (0..cols_out).for_each(|j| { + tmp.fill_uniform(basek, 0, size, &mut source); + module.vec_znx_dft(1, 0, &mut tmp_dft, j, &tmp, 0); + }); + + module.mat_znx_dft_set_row(&mut mat_want, row_i, col_i, &tmp_dft); + }); + }); + + let k: i64 = 1; + + module.mat_znx_dft_mul_x_pow_minus_one(k, &mut mat_have, &mat_want, scratch.borrow()); + + let mut have: VecZnx> = module.new_vec_znx(cols_out, size); + let mut want: VecZnx> = module.new_vec_znx(cols_out, size); + let mut tmp_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, size); + + (0..mat_want.rows()).for_each(|row_i| { + (0..mat_want.cols_in()).for_each(|col_i| { + module.mat_znx_dft_get_row(&mut tmp_dft, &mat_want, row_i, col_i); + + (0..cols_out).for_each(|j| { + module.vec_znx_idft(&mut tmp_big, 0, &tmp_dft, j, scratch.borrow()); + // module.vec_znx_big_normalize(basek, &mut want, j, &tmp_big, 0, scratch.borrow()); + module.vec_znx_big_normalize(basek, &mut tmp, 0, &tmp_big, 0, scratch.borrow()); + module.vec_znx_rotate(k, &mut want, j, &tmp, 0); + module.vec_znx_sub_ab_inplace(&mut want, j, &tmp, 0); + module.vec_znx_normalize_inplace(basek, &mut want, j, scratch.borrow()); + }); + + module.mat_znx_dft_get_row(&mut tmp_dft, &mat_have, row_i, col_i); + + (0..cols_out).for_each(|j| { + module.vec_znx_idft(&mut tmp_big, 0, &tmp_dft, j, scratch.borrow()); + module.vec_znx_big_normalize(basek, &mut have, j, &tmp_big, 0, scratch.borrow()); + }); + + assert_eq!(have, want) + }); + }); + } } diff --git a/backend/src/vec_znx_dft_ops.rs b/backend/src/vec_znx_dft_ops.rs index 963de18..5892155 100644 --- a/backend/src/vec_znx_dft_ops.rs +++ b/backend/src/vec_znx_dft_ops.rs @@ -53,6 +53,22 @@ pub trait VecZnxDftOps { R: VecZnxDftToMut, A: VecZnxDftToRef; + fn vec_znx_dft_sub(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + D: VecZnxDftToRef; + + fn vec_znx_dft_sub_ab_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef; + + fn vec_znx_dft_sub_ba_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef; + fn vec_znx_dft_copy(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxDftToMut, @@ -150,6 +166,86 @@ impl VecZnxDftOps for Module { } } + fn vec_znx_dft_sub(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + D: VecZnxDftToRef, + { + let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); + let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); + let b_ref: VecZnxDft<&[u8], FFT64> = b.to_ref(); + + let min_size: usize = res_mut.size().min(a_ref.size()).min(b_ref.size()); + + unsafe { + (0..min_size).for_each(|j| { + vec_znx_dft::vec_dft_sub( + self.ptr, + res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, + 1, + a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + b_ref.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + ); + }); + } + (min_size..res_mut.size()).for_each(|j| { + res_mut.zero_at(res_col, j); + }) + } + + fn vec_znx_dft_sub_ab_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + { + let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); + let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); + + let min_size: usize = res_mut.size().min(a_ref.size()); + + unsafe { + (0..min_size).for_each(|j| { + vec_znx_dft::vec_dft_sub( + self.ptr, + res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, + 1, + res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + ); + }); + } + } + + fn vec_znx_dft_sub_ba_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + { + let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); + let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); + + let min_size: usize = res_mut.size().min(a_ref.size()); + + unsafe { + (0..min_size).for_each(|j| { + vec_znx_dft::vec_dft_sub( + self.ptr, + res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, + 1, + a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + ); + }); + } + } + fn vec_znx_dft_copy(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxDftToMut, diff --git a/core/src/automorphism.rs b/core/src/automorphism.rs index 27ea44a..f91ba41 100644 --- a/core/src/automorphism.rs +++ b/core/src/automorphism.rs @@ -70,7 +70,7 @@ impl> GetRow for AutomorphismKey { col_j: usize, res: &mut GLWECiphertextFourier, ) { - module.vmp_extract_row(&mut res.data, &self.key.0.data, row_i, col_j); + module.mat_znx_dft_get_row(&mut res.data, &self.key.0.data, row_i, col_j); } } @@ -82,7 +82,7 @@ impl + AsRef<[u8]>> SetRow for AutomorphismKey { col_j: usize, a: &GLWECiphertextFourier, ) { - module.vmp_prepare_row(&mut self.key.0.data, row_i, col_j, &a.data); + module.mat_znx_dft_set_row(&mut self.key.0.data, row_i, col_j, &a.data); } } diff --git a/core/src/gglwe_ciphertext.rs b/core/src/gglwe_ciphertext.rs index 22d6749..e9f9684 100644 --- a/core/src/gglwe_ciphertext.rs +++ b/core/src/gglwe_ciphertext.rs @@ -219,7 +219,7 @@ impl> GetRow for GGLWECiphertext { col_j: usize, res: &mut GLWECiphertextFourier, ) { - module.vmp_extract_row(&mut res.data, &self.data, row_i, col_j); + module.mat_znx_dft_get_row(&mut res.data, &self.data, row_i, col_j); } } @@ -231,6 +231,6 @@ impl + AsRef<[u8]>> SetRow for GGLWECiphertext { col_j: usize, a: &GLWECiphertextFourier, ) { - module.vmp_prepare_row(&mut self.data, row_i, col_j, &a.data); + module.mat_znx_dft_set_row(&mut self.data, row_i, col_j, &a.data); } } diff --git a/core/src/ggsw_ciphertext.rs b/core/src/ggsw_ciphertext.rs index 82e0e81..ff1f1e7 100644 --- a/core/src/ggsw_ciphertext.rs +++ b/core/src/ggsw_ciphertext.rs @@ -429,7 +429,7 @@ impl + AsRef<[u8]>> GGSWCiphertext { module.vec_znx_dft(1, 0, &mut ci_dft, col_i, &tmp_res.data, col_i); }); - module.vmp_prepare_row(&mut self.data, row_i, 0, &ci_dft); + module.mat_znx_dft_set_row(&mut self.data, row_i, 0, &ci_dft); // Generates // @@ -525,7 +525,7 @@ impl + AsRef<[u8]>> GGSWCiphertext { module.vec_znx_dft(1, 0, &mut ci_dft, col_i, &tmp_res.data, col_i); }); - module.vmp_prepare_row(&mut self.data, row_i, 0, &ci_dft); + module.mat_znx_dft_set_row(&mut self.data, row_i, 0, &ci_dft); // Generates // @@ -688,7 +688,7 @@ impl> GetRow for GGSWCiphertext { col_j: usize, res: &mut GLWECiphertextFourier, ) { - module.vmp_extract_row(&mut res.data, &self.data, row_i, col_j); + module.mat_znx_dft_get_row(&mut res.data, &self.data, row_i, col_j); } } @@ -700,6 +700,6 @@ impl + AsRef<[u8]>> SetRow for GGSWCiphertext, ) { - module.vmp_prepare_row(&mut self.data, row_i, col_j, &a.data); + module.mat_znx_dft_set_row(&mut self.data, row_i, col_j, &a.data); } } diff --git a/core/src/keyswitch_key.rs b/core/src/keyswitch_key.rs index 56d42b4..fd4da76 100644 --- a/core/src/keyswitch_key.rs +++ b/core/src/keyswitch_key.rs @@ -75,7 +75,7 @@ impl> GetRow for GLWESwitchingKey { col_j: usize, res: &mut GLWECiphertextFourier, ) { - module.vmp_extract_row(&mut res.data, &self.0.data, row_i, col_j); + module.mat_znx_dft_get_row(&mut res.data, &self.0.data, row_i, col_j); } } @@ -87,7 +87,7 @@ impl + AsRef<[u8]>> SetRow for GLWESwitchingKey col_j: usize, a: &GLWECiphertextFourier, ) { - module.vmp_prepare_row(&mut self.0.data, row_i, col_j, &a.data); + module.mat_znx_dft_set_row(&mut self.0.data, row_i, col_j, &a.data); } } From d826fcd5c8118326faa6df7ced88c48350f409f2 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Thu, 12 Jun 2025 10:13:47 +0200 Subject: [PATCH 02/23] Added binary key distributions --- backend/src/mat_znx_dft_ops.rs | 2 +- backend/src/scalar_znx.rs | 25 +++++++++++++++++++++++++ core/src/glwe_ciphertext.rs | 3 +++ core/src/glwe_keys.rs | 29 ++++++++++++++++++++++++++++- 4 files changed, 57 insertions(+), 2 deletions(-) diff --git a/backend/src/mat_znx_dft_ops.rs b/backend/src/mat_znx_dft_ops.rs index 9ed71a0..5f08a89 100644 --- a/backend/src/mat_znx_dft_ops.rs +++ b/backend/src/mat_znx_dft_ops.rs @@ -3,7 +3,7 @@ use crate::ffi::vmp; use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; use crate::{ Backend, FFT64, MatZnxDft, MatZnxDftOwned, MatZnxToMut, MatZnxToRef, Module, ScalarZnxAlloc, ScalarZnxDftAlloc, - ScalarZnxDftOps, Scratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, ZnxZero, + ScalarZnxDftOps, Scratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, }; pub trait MatZnxDftAlloc { diff --git a/backend/src/scalar_znx.rs b/backend/src/scalar_znx.rs index cb51e0d..4c45f36 100644 --- a/backend/src/scalar_znx.rs +++ b/backend/src/scalar_znx.rs @@ -72,6 +72,31 @@ impl + AsRef<[u8]>> ScalarZnx { .for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1); self.at_mut(col, 0).shuffle(source); } + + pub fn fill_binary_prob(&mut self, col: usize, prob: f64, source: &mut Source) { + let choices: [i64; 2] = [0, 1]; + let weights: [f64; 2] = [1.0 - prob, prob]; + let dist: WeightedIndex = WeightedIndex::new(&weights).unwrap(); + self.at_mut(col, 0) + .iter_mut() + .for_each(|x: &mut i64| *x = choices[dist.sample(source)]); + } + + pub fn fill_binary_hw(&mut self, col: usize, hw: usize, source: &mut Source) { + assert!(hw <= self.n()); + self.at_mut(col, 0)[..hw] + .iter_mut() + .for_each(|x: &mut i64| *x = (source.next_u32() & 1) as i64); + self.at_mut(col, 0).shuffle(source); + } + + pub fn fill_binary_block(&mut self, col: usize, block_size: usize, source: &mut Source) { + assert!(self.n() % block_size == 0); + for chunk in self.at_mut(col, 0).chunks_mut(block_size) { + chunk[0] = 1; + chunk.shuffle(source); + } + } } impl>> ScalarZnx { diff --git a/core/src/glwe_ciphertext.rs b/core/src/glwe_ciphertext.rs index 28c1724..cd85f0c 100644 --- a/core/src/glwe_ciphertext.rs +++ b/core/src/glwe_ciphertext.rs @@ -763,6 +763,9 @@ impl + AsMut<[u8]>> GLWECiphertext { ), SecretDistribution::TernaryFixed(hw) => u.fill_ternary_hw(0, hw, source_xu), SecretDistribution::TernaryProb(prob) => u.fill_ternary_prob(0, prob, source_xu), + SecretDistribution::BinaryFixed(hw) => u.fill_binary_hw(0, hw, source_xu), + SecretDistribution::BinaryProb(prob) => u.fill_binary_prob(0, prob, source_xu), + SecretDistribution::BinaryBlock(block_size) => u.fill_binary_block(0, block_size, source_xu), SecretDistribution::ZERO => {} } diff --git a/core/src/glwe_keys.rs b/core/src/glwe_keys.rs index 8f04408..be8aa43 100644 --- a/core/src/glwe_keys.rs +++ b/core/src/glwe_keys.rs @@ -10,8 +10,11 @@ use crate::{GLWECiphertextFourier, Infos}; pub(crate) enum SecretDistribution { TernaryFixed(usize), // Ternary with fixed Hamming weight TernaryProb(f64), // Ternary with probabilistic Hamming weight + BinaryFixed(usize), // Binary with fixed Hamming weight + BinaryProb(f64), // Binary with probabilistic Hamming weight + BinaryBlock(usize), // Binary split in block of size 2^k ZERO, // Debug mod - NONE, + NONE, // Unitialized } pub struct GLWESecret { @@ -65,6 +68,30 @@ impl + AsRef<[u8]>> GLWESecret { self.dist = SecretDistribution::TernaryFixed(hw); } + pub fn fill_binary_prob(&mut self, module: &Module, prob: f64, source: &mut Source) { + (0..self.rank()).for_each(|i| { + self.data.fill_binary_prob(i, prob, source); + }); + self.prep_fourier(module); + self.dist = SecretDistribution::BinaryProb(prob); + } + + pub fn fill_binary_hw(&mut self, module: &Module, hw: usize, source: &mut Source) { + (0..self.rank()).for_each(|i| { + self.data.fill_binary_hw(i, hw, source); + }); + self.prep_fourier(module); + self.dist = SecretDistribution::BinaryFixed(hw); + } + + pub fn fill_binary_block(&mut self, module: &Module, block_size: usize, source: &mut Source) { + (0..self.rank()).for_each(|i| { + self.data.fill_binary_block(i, block_size, source); + }); + self.prep_fourier(module); + self.dist = SecretDistribution::BinaryBlock(block_size); + } + pub fn fill_zero(&mut self) { self.data.zero(); self.dist = SecretDistribution::ZERO; From d5dc9e690201ef58e2e1bbe413d3fd69d25d59f3 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Thu, 12 Jun 2025 10:54:23 +0200 Subject: [PATCH 03/23] Fixed block binary secret generation & added CGGI blind rotation key generation --- backend/src/scalar_znx.rs | 12 ++++--- core/src/blind_rotation/ccgi.rs | 0 core/src/blind_rotation/key.rs | 60 +++++++++++++++++++++++++++++++++ core/src/blind_rotation/mod.rs | 2 ++ core/src/lib.rs | 1 + 5 files changed, 71 insertions(+), 4 deletions(-) create mode 100644 core/src/blind_rotation/ccgi.rs create mode 100644 core/src/blind_rotation/key.rs create mode 100644 core/src/blind_rotation/mod.rs diff --git a/backend/src/scalar_znx.rs b/backend/src/scalar_znx.rs index 4c45f36..e252a9b 100644 --- a/backend/src/scalar_znx.rs +++ b/backend/src/scalar_znx.rs @@ -91,10 +91,14 @@ impl + AsRef<[u8]>> ScalarZnx { } pub fn fill_binary_block(&mut self, col: usize, block_size: usize, source: &mut Source) { - assert!(self.n() % block_size == 0); - for chunk in self.at_mut(col, 0).chunks_mut(block_size) { - chunk[0] = 1; - chunk.shuffle(source); + assert!(block_size & (block_size - 1) == 0); + let max_idx: u64 = (block_size + 1) as u64; + let mask_idx: u64 = (2 * block_size - 1) as u64; + for block in self.at_mut(col, 0).chunks_mut(block_size) { + let idx: usize = source.next_u64n(max_idx, mask_idx) as usize; + if idx != block_size { + block[idx] = 1; + } } } } diff --git a/core/src/blind_rotation/ccgi.rs b/core/src/blind_rotation/ccgi.rs new file mode 100644 index 0000000..e69de29 diff --git a/core/src/blind_rotation/key.rs b/core/src/blind_rotation/key.rs new file mode 100644 index 0000000..4076e14 --- /dev/null +++ b/core/src/blind_rotation/key.rs @@ -0,0 +1,60 @@ +use backend::{Backend, FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxToRef, Scratch, ZnxView, ZnxViewMut}; +use sampling::source::Source; + +use crate::{AutomorphismKey, GGSWCiphertext, GLWESecret, SecretDistribution}; + +pub struct BlindRotationKeyCGGI { + pub(crate) data: Vec, B>>, + pub(crate) dist: SecretDistribution, +} + +pub struct BlindRotationKeyFHEW { + pub(crate) data: Vec, B>>, + pub(crate) auto: Vec, B>>, +} + +impl BlindRotationKeyCGGI { + pub fn allocate(module: &Module, lwe_degree: usize, basek: usize, k: usize, rows: usize, rank: usize) -> Self { + let mut data: Vec, FFT64>> = Vec::with_capacity(lwe_degree); + (0..lwe_degree).for_each(|_| data.push(GGSWCiphertext::alloc(module, basek, k, rows, 1, rank))); + Self { + data, + dist: SecretDistribution::NONE, + } + } + + pub fn generate_from_sk( + &mut self, + module: &Module, + sk_glwe: &GLWESecret, + sk_lwe: &GLWESecret, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) where + DataSkGLWE: AsRef<[u8]>, + DataSkLWE: AsRef<[u8]>, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.data.len(), sk_lwe.n()); + assert_eq!(sk_glwe.n(), module.n()); + assert_eq!(sk_glwe.rank(), self.data[0].rank()); + match sk_lwe.dist { + SecretDistribution::BinaryBlock(_) | SecretDistribution::BinaryFixed(_) | SecretDistribution::BinaryProb(_) => {} + _ => panic!("invalid GLWESecret distribution: must be BinaryBlock, BinaryFixed or BinaryProb"), + } + } + + self.dist = sk_lwe.dist; + + let mut pt: ScalarZnx> = module.new_scalar_znx(1); + let sk_ref: ScalarZnx<&[u8]> = sk_lwe.data.to_ref(); + + self.data.iter_mut().enumerate().for_each(|(i, ggsw)| { + pt.at_mut(0, 0)[0] = sk_ref.at(0, 0)[i]; + ggsw.encrypt_sk(module, &pt, sk_glwe, source_xa, source_xe, sigma, scratch); + }) + } +} diff --git a/core/src/blind_rotation/mod.rs b/core/src/blind_rotation/mod.rs new file mode 100644 index 0000000..c531781 --- /dev/null +++ b/core/src/blind_rotation/mod.rs @@ -0,0 +1,2 @@ +// pub mod cggi; +pub mod key; diff --git a/core/src/lib.rs b/core/src/lib.rs index 69eb045..3aa05ea 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -1,4 +1,5 @@ pub mod automorphism; +pub mod blind_rotation; pub mod elem; pub mod gglwe_ciphertext; pub mod ggsw_ciphertext; From ec4253bb1cb056e674f8e35d9916a352a9e34a6c Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Thu, 12 Jun 2025 11:03:54 +0200 Subject: [PATCH 04/23] Added LWESecret --- backend/src/scalar_znx.rs | 20 +++++----- core/src/blind_rotation/key.rs | 4 +- core/src/{glwe_keys.rs => keys.rs} | 60 ++++++++++++++++++++++++++++++ core/src/lib.rs | 4 +- 4 files changed, 74 insertions(+), 14 deletions(-) rename core/src/{glwe_keys.rs => keys.rs} (76%) diff --git a/backend/src/scalar_znx.rs b/backend/src/scalar_znx.rs index e252a9b..09e7292 100644 --- a/backend/src/scalar_znx.rs +++ b/backend/src/scalar_znx.rs @@ -104,12 +104,12 @@ impl + AsRef<[u8]>> ScalarZnx { } impl>> ScalarZnx { - pub(crate) fn bytes_of(n: usize, cols: usize) -> usize { - n * cols * size_of::() + pub(crate) fn bytes_of(n: usize, cols: usize) -> usize { + n * cols * size_of::() } - pub(crate) fn new(n: usize, cols: usize) -> Self { - let data = alloc_aligned::(Self::bytes_of::(n, cols)); + pub fn new(n: usize, cols: usize) -> Self { + let data = alloc_aligned::(Self::bytes_of(n, cols)); Self { data: data.into(), n, @@ -117,9 +117,9 @@ impl>> ScalarZnx { } } - pub(crate) fn new_from_bytes(n: usize, cols: usize, bytes: impl Into>) -> Self { + pub(crate) fn new_from_bytes(n: usize, cols: usize, bytes: impl Into>) -> Self { let data: Vec = bytes.into(); - assert!(data.len() == Self::bytes_of::(n, cols)); + assert!(data.len() == Self::bytes_of(n, cols)); Self { data: data.into(), n, @@ -131,7 +131,7 @@ impl>> ScalarZnx { pub type ScalarZnxOwned = ScalarZnx>; pub(crate) fn bytes_of_scalar_znx(module: &Module, cols: usize) -> usize { - ScalarZnxOwned::bytes_of::(module.n(), cols) + ScalarZnxOwned::bytes_of(module.n(), cols) } pub trait ScalarZnxAlloc { @@ -142,13 +142,13 @@ pub trait ScalarZnxAlloc { impl ScalarZnxAlloc for Module { fn bytes_of_scalar_znx(&self, cols: usize) -> usize { - ScalarZnxOwned::bytes_of::(self.n(), cols) + ScalarZnxOwned::bytes_of(self.n(), cols) } fn new_scalar_znx(&self, cols: usize) -> ScalarZnxOwned { - ScalarZnxOwned::new::(self.n(), cols) + ScalarZnxOwned::new(self.n(), cols) } fn new_scalar_znx_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxOwned { - ScalarZnxOwned::new_from_bytes::(self.n(), cols, bytes) + ScalarZnxOwned::new_from_bytes(self.n(), cols, bytes) } } diff --git a/core/src/blind_rotation/key.rs b/core/src/blind_rotation/key.rs index 4076e14..b4c5d40 100644 --- a/core/src/blind_rotation/key.rs +++ b/core/src/blind_rotation/key.rs @@ -1,7 +1,7 @@ use backend::{Backend, FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxToRef, Scratch, ZnxView, ZnxViewMut}; use sampling::source::Source; -use crate::{AutomorphismKey, GGSWCiphertext, GLWESecret, SecretDistribution}; +use crate::{AutomorphismKey, GGSWCiphertext, GLWESecret, LWESecret, SecretDistribution}; pub struct BlindRotationKeyCGGI { pub(crate) data: Vec, B>>, @@ -27,7 +27,7 @@ impl BlindRotationKeyCGGI { &mut self, module: &Module, sk_glwe: &GLWESecret, - sk_lwe: &GLWESecret, + sk_lwe: &LWESecret, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, diff --git a/core/src/glwe_keys.rs b/core/src/keys.rs similarity index 76% rename from core/src/glwe_keys.rs rename to core/src/keys.rs index be8aa43..9549f15 100644 --- a/core/src/glwe_keys.rs +++ b/core/src/keys.rs @@ -17,6 +17,66 @@ pub(crate) enum SecretDistribution { NONE, // Unitialized } +pub struct LWESecret { + pub(crate) data: ScalarZnx, + pub(crate) dist: SecretDistribution, +} + +impl LWESecret> { + pub fn alloc(n: usize) -> Self { + Self { + data: ScalarZnx::new(n, 1), + dist: SecretDistribution::NONE, + } + } +} + +impl LWESecret { + pub fn n(&self) -> usize { + self.data.n() + } + + pub fn log_n(&self) -> usize { + self.data.log_n() + } + + pub fn rank(&self) -> usize { + self.data.cols() + } +} + +impl + AsMut<[u8]>> LWESecret { + pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) { + self.data.fill_ternary_prob(0, prob, source); + self.dist = SecretDistribution::TernaryProb(prob); + } + + pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) { + self.data.fill_ternary_hw(0, hw, source); + self.dist = SecretDistribution::TernaryFixed(hw); + } + + pub fn fill_binary_prob(&mut self, prob: f64, source: &mut Source) { + self.data.fill_binary_prob(0, prob, source); + self.dist = SecretDistribution::BinaryProb(prob); + } + + pub fn fill_binary_hw(&mut self, hw: usize, source: &mut Source) { + self.data.fill_binary_hw(0, hw, source); + self.dist = SecretDistribution::BinaryFixed(hw); + } + + pub fn fill_binary_block(&mut self, block_size: usize, source: &mut Source) { + self.data.fill_binary_block(0, block_size, source); + self.dist = SecretDistribution::BinaryBlock(block_size); + } + + pub fn fill_zero(&mut self) { + self.data.zero(); + self.dist = SecretDistribution::ZERO; + } +} + pub struct GLWESecret { pub(crate) data: ScalarZnx, pub(crate) data_fourier: ScalarZnxDft, diff --git a/core/src/lib.rs b/core/src/lib.rs index 3aa05ea..1bffc5a 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -5,10 +5,10 @@ pub mod gglwe_ciphertext; pub mod ggsw_ciphertext; pub mod glwe_ciphertext; pub mod glwe_ciphertext_fourier; -pub mod glwe_keys; pub mod glwe_ops; pub mod glwe_packing; pub mod glwe_plaintext; +pub mod keys; pub mod keyswitch_key; pub mod tensor_key; #[cfg(test)] @@ -24,10 +24,10 @@ pub use gglwe_ciphertext::*; pub use ggsw_ciphertext::*; pub use glwe_ciphertext::*; pub use glwe_ciphertext_fourier::*; -pub use glwe_keys::*; pub use glwe_ops::*; pub use glwe_packing::*; pub use glwe_plaintext::*; +pub use keys::*; pub use keyswitch_key::*; pub use tensor_key::*; From 989ea077a9967a9875baf8f97aa759fd6e71dab3 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Thu, 12 Jun 2025 15:46:05 +0200 Subject: [PATCH 05/23] Code organisation for glwe --- core/src/automorphism.rs | 10 +- core/src/blind_rotation/key.rs | 4 + core/src/elem.rs | 6 +- ..._ciphertext_fourier.rs => fourier_glwe.rs} | 24 +- core/src/{gglwe_ciphertext.rs => gglwe.rs} | 6 +- core/src/{ggsw_ciphertext.rs => ggsw.rs} | 18 +- core/src/glwe/automorphism.rs | 121 +++ core/src/glwe/ciphertext.rs | 115 +++ core/src/glwe/decryption.rs | 56 ++ core/src/glwe/encryption.rs | 254 +++++ core/src/glwe/external_product.rs | 129 +++ core/src/glwe/keyswitch.rs | 244 +++++ core/src/glwe/mod.rs | 31 + core/src/{glwe_ops.rs => glwe/ops.rs} | 0 core/src/{glwe_packing.rs => glwe/packing.rs} | 0 .../{glwe_plaintext.rs => glwe/plaintext.rs} | 6 +- core/src/glwe/public_key.rs | 75 ++ core/src/glwe/secret.rs | 93 ++ core/src/{ => glwe}/trace.rs | 0 core/src/glwe_ciphertext.rs | 886 ------------------ core/src/keys.rs | 226 ----- core/src/keyswitch_key.rs | 26 +- core/src/lib.rs | 37 +- core/src/lwe.rs | 64 ++ core/src/test_fft64/automorphism_key.rs | 11 +- core/src/test_fft64/gglwe.rs | 26 +- core/src/test_fft64/ggsw.rs | 30 +- core/src/test_fft64/glwe.rs | 8 +- core/src/test_fft64/glwe_fourier.rs | 24 +- core/src/test_fft64/tensor_key.rs | 4 +- 30 files changed, 1305 insertions(+), 1229 deletions(-) rename core/src/{glwe_ciphertext_fourier.rs => fourier_glwe.rs} (93%) rename core/src/{gglwe_ciphertext.rs => gglwe.rs} (97%) rename core/src/{ggsw_ciphertext.rs => ggsw.rs} (98%) create mode 100644 core/src/glwe/automorphism.rs create mode 100644 core/src/glwe/ciphertext.rs create mode 100644 core/src/glwe/decryption.rs create mode 100644 core/src/glwe/encryption.rs create mode 100644 core/src/glwe/external_product.rs create mode 100644 core/src/glwe/keyswitch.rs create mode 100644 core/src/glwe/mod.rs rename core/src/{glwe_ops.rs => glwe/ops.rs} (100%) rename core/src/{glwe_packing.rs => glwe/packing.rs} (100%) rename core/src/{glwe_plaintext.rs => glwe/plaintext.rs} (92%) create mode 100644 core/src/glwe/public_key.rs create mode 100644 core/src/glwe/secret.rs rename core/src/{ => glwe}/trace.rs (100%) delete mode 100644 core/src/glwe_ciphertext.rs create mode 100644 core/src/lwe.rs diff --git a/core/src/automorphism.rs b/core/src/automorphism.rs index f91ba41..3032532 100644 --- a/core/src/automorphism.rs +++ b/core/src/automorphism.rs @@ -2,7 +2,7 @@ use backend::{Backend, FFT64, MatZnxDft, MatZnxDftOps, Module, ScalarZnxOps, Scr use sampling::source::Source; use crate::{ - GGLWECiphertext, GGSWCiphertext, GLWECiphertext, GLWECiphertextFourier, GLWESecret, GLWESwitchingKey, GetRow, Infos, + FourierGLWECiphertext, GGLWECiphertext, GGSWCiphertext, GLWECiphertext, GLWESecret, GLWESwitchingKey, GetRow, Infos, ScratchCore, SetRow, }; @@ -68,7 +68,7 @@ impl> GetRow for AutomorphismKey { module: &Module, row_i: usize, col_j: usize, - res: &mut GLWECiphertextFourier, + res: &mut FourierGLWECiphertext, ) { module.mat_znx_dft_get_row(&mut res.data, &self.key.0.data, row_i, col_j); } @@ -80,7 +80,7 @@ impl + AsRef<[u8]>> SetRow for AutomorphismKey { module: &Module, row_i: usize, col_j: usize, - a: &GLWECiphertextFourier, + a: &FourierGLWECiphertext, ) { module.mat_znx_dft_set_row(&mut self.key.0.data, row_i, col_j, &a.data); } @@ -127,8 +127,8 @@ impl AutomorphismKey, FFT64> { digits: usize, rank: usize, ) -> usize { - let tmp_dft: usize = GLWECiphertextFourier::bytes_of(module, basek, k_in, rank); - let tmp_idft: usize = GLWECiphertextFourier::bytes_of(module, basek, k_out, rank); + let tmp_dft: usize = FourierGLWECiphertext::bytes_of(module, basek, k_in, rank); + let tmp_idft: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank); let idft: usize = module.vec_znx_idft_tmp_bytes(); let keyswitch: usize = GLWECiphertext::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank); tmp_dft + tmp_idft + idft + keyswitch diff --git a/core/src/blind_rotation/key.rs b/core/src/blind_rotation/key.rs index b4c5d40..ff4a887 100644 --- a/core/src/blind_rotation/key.rs +++ b/core/src/blind_rotation/key.rs @@ -23,6 +23,10 @@ impl BlindRotationKeyCGGI { } } + pub fn generate_from_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { + GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k, rank) + } + pub fn generate_from_sk( &mut self, module: &Module, diff --git a/core/src/elem.rs b/core/src/elem.rs index ae7e5f7..e659e1e 100644 --- a/core/src/elem.rs +++ b/core/src/elem.rs @@ -1,6 +1,6 @@ use backend::{Backend, Module, ZnxInfos}; -use crate::GLWECiphertextFourier; +use crate::{FourierGLWECiphertext, div_ceil}; pub trait Infos { type Inner: ZnxInfos; @@ -56,13 +56,13 @@ pub trait SetMetaData { } pub trait GetRow { - fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut GLWECiphertextFourier) + fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut FourierGLWECiphertext) where R: AsMut<[u8]> + AsRef<[u8]>; } pub trait SetRow { - fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &GLWECiphertextFourier) + fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &FourierGLWECiphertext) where R: AsRef<[u8]>; } diff --git a/core/src/glwe_ciphertext_fourier.rs b/core/src/fourier_glwe.rs similarity index 93% rename from core/src/glwe_ciphertext_fourier.rs rename to core/src/fourier_glwe.rs index 19582f6..0e024a9 100644 --- a/core/src/glwe_ciphertext_fourier.rs +++ b/core/src/fourier_glwe.rs @@ -6,13 +6,13 @@ use sampling::source::Source; use crate::{GGSWCiphertext, GLWECiphertext, GLWEPlaintext, GLWESecret, GLWESwitchingKey, Infos, ScratchCore}; -pub struct GLWECiphertextFourier { +pub struct FourierGLWECiphertext { pub data: VecZnxDft, pub basek: usize, pub k: usize, } -impl GLWECiphertextFourier, B> { +impl FourierGLWECiphertext, B> { pub fn alloc(module: &Module, basek: usize, k: usize, rank: usize) -> Self { Self { data: module.new_vec_znx_dft(rank + 1, k.div_ceil(basek)), @@ -26,7 +26,7 @@ impl GLWECiphertextFourier, B> { } } -impl Infos for GLWECiphertextFourier { +impl Infos for FourierGLWECiphertext { type Inner = VecZnxDft; fn inner(&self) -> &Self::Inner { @@ -42,13 +42,13 @@ impl Infos for GLWECiphertextFourier { } } -impl GLWECiphertextFourier { +impl FourierGLWECiphertext { pub fn rank(&self) -> usize { self.cols() - 1 } } -impl GLWECiphertextFourier, FFT64> { +impl FourierGLWECiphertext, FFT64> { #[allow(dead_code)] pub(crate) fn idft_scratch_space(module: &Module, basek: usize, k: usize) -> usize { module.bytes_of_vec_znx(1, k.div_ceil(basek)) @@ -125,7 +125,7 @@ impl GLWECiphertextFourier, FFT64> { } } -impl + AsRef<[u8]>> GLWECiphertextFourier { +impl + AsRef<[u8]>> FourierGLWECiphertext { pub fn encrypt_zero_sk>( &mut self, module: &Module, @@ -143,7 +143,7 @@ impl + AsRef<[u8]>> GLWECiphertextFourier pub fn keyswitch, DataRhs: AsRef<[u8]>>( &mut self, module: &Module, - lhs: &GLWECiphertextFourier, + lhs: &FourierGLWECiphertext, rhs: &GLWESwitchingKey, scratch: &mut Scratch, ) { @@ -159,7 +159,7 @@ impl + AsRef<[u8]>> GLWECiphertextFourier scratch: &mut Scratch, ) { unsafe { - let self_ptr: *mut GLWECiphertextFourier = self as *mut GLWECiphertextFourier; + let self_ptr: *mut FourierGLWECiphertext = self as *mut FourierGLWECiphertext; self.keyswitch(&module, &*self_ptr, rhs, scratch); } } @@ -167,7 +167,7 @@ impl + AsRef<[u8]>> GLWECiphertextFourier pub fn external_product, DataRhs: AsRef<[u8]>>( &mut self, module: &Module, - lhs: &GLWECiphertextFourier, + lhs: &FourierGLWECiphertext, rhs: &GGSWCiphertext, scratch: &mut Scratch, ) { @@ -184,7 +184,7 @@ impl + AsRef<[u8]>> GLWECiphertextFourier assert_eq!(lhs.n(), module.n()); assert!( scratch.available() - >= GLWECiphertextFourier::external_product_scratch_space( + >= FourierGLWECiphertext::external_product_scratch_space( module, self.basek(), self.k(), @@ -246,13 +246,13 @@ impl + AsRef<[u8]>> GLWECiphertextFourier scratch: &mut Scratch, ) { unsafe { - let self_ptr: *mut GLWECiphertextFourier = self as *mut GLWECiphertextFourier; + let self_ptr: *mut FourierGLWECiphertext = self as *mut FourierGLWECiphertext; self.external_product(&module, &*self_ptr, rhs, scratch); } } } -impl> GLWECiphertextFourier { +impl> FourierGLWECiphertext { pub fn decrypt + AsMut<[u8]>, DataSk: AsRef<[u8]>>( &self, module: &Module, diff --git a/core/src/gglwe_ciphertext.rs b/core/src/gglwe.rs similarity index 97% rename from core/src/gglwe_ciphertext.rs rename to core/src/gglwe.rs index e9f9684..66a5238 100644 --- a/core/src/gglwe_ciphertext.rs +++ b/core/src/gglwe.rs @@ -4,7 +4,7 @@ use backend::{ }; use sampling::source::Source; -use crate::{GLWECiphertext, GLWECiphertextFourier, GLWESecret, GetRow, Infos, ScratchCore, SetRow}; +use crate::{FourierGLWECiphertext, GLWECiphertext, GLWESecret, GetRow, Infos, ScratchCore, SetRow, div_ceil}; pub struct GGLWECiphertext { pub(crate) data: MatZnxDft, @@ -217,7 +217,7 @@ impl> GetRow for GGLWECiphertext { module: &Module, row_i: usize, col_j: usize, - res: &mut GLWECiphertextFourier, + res: &mut FourierGLWECiphertext, ) { module.mat_znx_dft_get_row(&mut res.data, &self.data, row_i, col_j); } @@ -229,7 +229,7 @@ impl + AsRef<[u8]>> SetRow for GGLWECiphertext { module: &Module, row_i: usize, col_j: usize, - a: &GLWECiphertextFourier, + a: &FourierGLWECiphertext, ) { module.mat_znx_dft_set_row(&mut self.data, row_i, col_j, &a.data); } diff --git a/core/src/ggsw_ciphertext.rs b/core/src/ggsw.rs similarity index 98% rename from core/src/ggsw_ciphertext.rs rename to core/src/ggsw.rs index ff1f1e7..38842df 100644 --- a/core/src/ggsw_ciphertext.rs +++ b/core/src/ggsw.rs @@ -6,8 +6,8 @@ use backend::{ use sampling::source::Source; use crate::{ - AutomorphismKey, GLWECiphertext, GLWECiphertextFourier, GLWESecret, GLWESwitchingKey, GetRow, Infos, ScratchCore, SetRow, - TensorKey, + AutomorphismKey, FourierGLWECiphertext, GLWECiphertext, GLWESecret, GLWESwitchingKey, GetRow, Infos, ScratchCore, SetRow, + TensorKey, div_ceil, }; pub struct GGSWCiphertext { @@ -220,9 +220,9 @@ impl GGSWCiphertext, FFT64> { digits: usize, rank: usize, ) -> usize { - let tmp_in: usize = GLWECiphertextFourier::bytes_of(module, basek, k_in, rank); - let tmp_out: usize = GLWECiphertextFourier::bytes_of(module, basek, k_out, rank); - let ggsw: usize = GLWECiphertextFourier::external_product_scratch_space(module, basek, k_out, k_in, k_ggsw, digits, rank); + let tmp_in: usize = FourierGLWECiphertext::bytes_of(module, basek, k_in, rank); + let tmp_out: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank); + let ggsw: usize = FourierGLWECiphertext::external_product_scratch_space(module, basek, k_out, k_in, k_ggsw, digits, rank); tmp_in + tmp_out + ggsw } @@ -234,9 +234,9 @@ impl GGSWCiphertext, FFT64> { digits: usize, rank: usize, ) -> usize { - let tmp: usize = GLWECiphertextFourier::bytes_of(module, basek, k_out, rank); + let tmp: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank); let ggsw: usize = - GLWECiphertextFourier::external_product_inplace_scratch_space(module, basek, k_out, k_ggsw, digits, rank); + FourierGLWECiphertext::external_product_inplace_scratch_space(module, basek, k_out, k_ggsw, digits, rank); tmp + ggsw } } @@ -686,7 +686,7 @@ impl> GetRow for GGSWCiphertext { module: &Module, row_i: usize, col_j: usize, - res: &mut GLWECiphertextFourier, + res: &mut FourierGLWECiphertext, ) { module.mat_znx_dft_get_row(&mut res.data, &self.data, row_i, col_j); } @@ -698,7 +698,7 @@ impl + AsRef<[u8]>> SetRow for GGSWCiphertext, row_i: usize, col_j: usize, - a: &GLWECiphertextFourier, + a: &FourierGLWECiphertext, ) { module.mat_znx_dft_set_row(&mut self.data, row_i, col_j, &a.data); } diff --git a/core/src/glwe/automorphism.rs b/core/src/glwe/automorphism.rs new file mode 100644 index 0000000..a4165aa --- /dev/null +++ b/core/src/glwe/automorphism.rs @@ -0,0 +1,121 @@ +use backend::{FFT64, Module, Scratch, VecZnxOps}; + +use crate::{AutomorphismKey, GLWECiphertext}; + +impl GLWECiphertext> { + pub fn automorphism_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_in: usize, + k_ksk: usize, + digits: usize, + rank: usize, + ) -> usize { + Self::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank, rank) + } + + pub fn automorphism_inplace_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_ksk: usize, + digits: usize, + rank: usize, + ) -> usize { + Self::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank) + } +} + +impl + AsMut<[u8]>> GLWECiphertext { + pub fn automorphism, DataRhs: AsRef<[u8]>>( + &mut self, + module: &Module, + lhs: &GLWECiphertext, + rhs: &AutomorphismKey, + scratch: &mut Scratch, + ) { + self.keyswitch(module, lhs, &rhs.key, scratch); + (0..self.rank() + 1).for_each(|i| { + module.vec_znx_automorphism_inplace(rhs.p(), &mut self.data, i); + }) + } + + pub fn automorphism_inplace>( + &mut self, + module: &Module, + rhs: &AutomorphismKey, + scratch: &mut Scratch, + ) { + self.keyswitch_inplace(module, &rhs.key, scratch); + (0..self.rank() + 1).for_each(|i| { + module.vec_znx_automorphism_inplace(rhs.p(), &mut self.data, i); + }) + } + + pub fn automorphism_add, DataRhs: AsRef<[u8]>>( + &mut self, + module: &Module, + lhs: &GLWECiphertext, + rhs: &AutomorphismKey, + scratch: &mut Scratch, + ) { + Self::keyswitch_private::<_, _, 1>(self, rhs.p(), module, lhs, &rhs.key, scratch); + } + + pub fn automorphism_add_inplace>( + &mut self, + module: &Module, + rhs: &AutomorphismKey, + scratch: &mut Scratch, + ) { + unsafe { + let self_ptr: *mut GLWECiphertext = self as *mut GLWECiphertext; + Self::keyswitch_private::<_, _, 1>(self, rhs.p(), module, &*self_ptr, &rhs.key, scratch); + } + } + + pub fn automorphism_sub_ab, DataRhs: AsRef<[u8]>>( + &mut self, + module: &Module, + lhs: &GLWECiphertext, + rhs: &AutomorphismKey, + scratch: &mut Scratch, + ) { + Self::keyswitch_private::<_, _, 2>(self, rhs.p(), module, lhs, &rhs.key, scratch); + } + + pub fn automorphism_sub_ab_inplace>( + &mut self, + module: &Module, + rhs: &AutomorphismKey, + scratch: &mut Scratch, + ) { + unsafe { + let self_ptr: *mut GLWECiphertext = self as *mut GLWECiphertext; + Self::keyswitch_private::<_, _, 2>(self, rhs.p(), module, &*self_ptr, &rhs.key, scratch); + } + } + + pub fn automorphism_sub_ba, DataRhs: AsRef<[u8]>>( + &mut self, + module: &Module, + lhs: &GLWECiphertext, + rhs: &AutomorphismKey, + scratch: &mut Scratch, + ) { + Self::keyswitch_private::<_, _, 3>(self, rhs.p(), module, lhs, &rhs.key, scratch); + } + + pub fn automorphism_sub_ba_inplace>( + &mut self, + module: &Module, + rhs: &AutomorphismKey, + scratch: &mut Scratch, + ) { + unsafe { + let self_ptr: *mut GLWECiphertext = self as *mut GLWECiphertext; + Self::keyswitch_private::<_, _, 3>(self, rhs.p(), module, &*self_ptr, &rhs.key, scratch); + } + } +} diff --git a/core/src/glwe/ciphertext.rs b/core/src/glwe/ciphertext.rs new file mode 100644 index 0000000..ff634bd --- /dev/null +++ b/core/src/glwe/ciphertext.rs @@ -0,0 +1,115 @@ +use backend::{ + Backend, FFT64, Module, VecZnx, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigScratch, VecZnxDftAlloc, VecZnxDftOps, VecZnxToMut, + VecZnxToRef, +}; + +use crate::{FourierGLWECiphertext, GLWEOps, Infos, SetMetaData, div_ceil}; + +pub struct GLWECiphertext { + pub data: VecZnx, + pub basek: usize, + pub k: usize, +} + +impl GLWECiphertext> { + pub fn alloc(module: &Module, basek: usize, k: usize, rank: usize) -> Self { + Self { + data: module.new_vec_znx(rank + 1, div_ceil(k, basek)), + basek, + k, + } + } + + pub fn bytes_of(module: &Module, basek: usize, k: usize, rank: usize) -> usize { + module.bytes_of_vec_znx(rank + 1, div_ceil(k, basek)) + } +} + +impl Infos for GLWECiphertext { + type Inner = VecZnx; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn basek(&self) -> usize { + self.basek + } + + fn k(&self) -> usize { + self.k + } +} + +impl GLWECiphertext { + pub fn rank(&self) -> usize { + self.cols() - 1 + } +} + +impl> GLWECiphertext { + #[allow(dead_code)] + pub(crate) fn dft + AsRef<[u8]>>(&self, module: &Module, res: &mut FourierGLWECiphertext) { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), res.rank()); + assert_eq!(self.basek(), res.basek()) + } + + (0..self.rank() + 1).for_each(|i| { + module.vec_znx_dft(1, 0, &mut res.data, i, &self.data, i); + }) + } +} + +impl GLWECiphertext> { + pub fn decrypt_scratch_space(module: &Module, basek: usize, k: usize) -> usize { + let size: usize = div_ceil(k, basek); + (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) + } +} + +impl + AsRef<[u8]>> SetMetaData for GLWECiphertext { + fn set_k(&mut self, k: usize) { + self.k = k + } + + fn set_basek(&mut self, basek: usize) { + self.basek = basek + } +} + +pub trait GLWECiphertextToRef { + fn to_ref(&self) -> GLWECiphertext<&[u8]>; +} + +impl> GLWECiphertextToRef for GLWECiphertext { + fn to_ref(&self) -> GLWECiphertext<&[u8]> { + GLWECiphertext { + data: self.data.to_ref(), + basek: self.basek, + k: self.k, + } + } +} + +pub trait GLWECiphertextToMut { + fn to_mut(&mut self) -> GLWECiphertext<&mut [u8]>; +} + +impl + AsRef<[u8]>> GLWECiphertextToMut for GLWECiphertext { + fn to_mut(&mut self) -> GLWECiphertext<&mut [u8]> { + GLWECiphertext { + data: self.data.to_mut(), + basek: self.basek, + k: self.k, + } + } +} + +impl GLWEOps for GLWECiphertext +where + D: AsRef<[u8]> + AsMut<[u8]>, + GLWECiphertext: GLWECiphertextToMut + Infos + SetMetaData, +{ +} diff --git a/core/src/glwe/decryption.rs b/core/src/glwe/decryption.rs new file mode 100644 index 0000000..dd6428d --- /dev/null +++ b/core/src/glwe/decryption.rs @@ -0,0 +1,56 @@ +use backend::{FFT64, Module, ScalarZnxDftOps, Scratch, VecZnxBigOps, VecZnxDftOps, ZnxZero}; + +use crate::{GLWECiphertext, GLWEPlaintext, GLWESecret, Infos}; + +impl> GLWECiphertext { + pub fn clone(&self) -> GLWECiphertext> { + GLWECiphertext { + data: self.data.clone(), + basek: self.basek(), + k: self.k(), + } + } + + pub fn decrypt + AsRef<[u8]>, DataSk: AsRef<[u8]>>( + &self, + module: &Module, + pt: &mut GLWEPlaintext, + sk: &GLWESecret, + scratch: &mut Scratch, + ) { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), sk.rank()); + assert_eq!(self.n(), module.n()); + assert_eq!(pt.n(), module.n()); + assert_eq!(sk.n(), module.n()); + } + + let cols: usize = self.rank() + 1; + + let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, self.size()); // TODO optimize size when pt << ct + c0_big.zero(); + + { + (1..cols).for_each(|i| { + // ci_dft = DFT(a[i]) * DFT(s[i]) + let (mut ci_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, self.size()); // TODO optimize size when pt << ct + module.vec_znx_dft(1, 0, &mut ci_dft, 0, &self.data, i); + module.svp_apply_inplace(&mut ci_dft, 0, &sk.data_fourier, i - 1); + let ci_big = module.vec_znx_idft_consume(ci_dft); + + // c0_big += a[i] * s[i] + module.vec_znx_big_add_inplace(&mut c0_big, 0, &ci_big, 0); + }); + } + + // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e) + module.vec_znx_big_add_small_inplace(&mut c0_big, 0, &self.data, 0); + + // pt = norm(BIG(m + e)) + module.vec_znx_big_normalize(self.basek(), &mut pt.data, 0, &mut c0_big, 0, scratch_1); + + pt.basek = self.basek(); + pt.k = pt.k().min(self.k()); + } +} diff --git a/core/src/glwe/encryption.rs b/core/src/glwe/encryption.rs new file mode 100644 index 0000000..1910f98 --- /dev/null +++ b/core/src/glwe/encryption.rs @@ -0,0 +1,254 @@ +use backend::{ + AddNormal, FFT64, FillUniform, Module, ScalarZnxAlloc, ScalarZnxDftAlloc, ScalarZnxDftOps, Scratch, VecZnxAlloc, VecZnxBig, + VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxZero, +}; +use sampling::source::Source; + +use crate::{GLWECiphertext, GLWEPlaintext, GLWEPublicKey, GLWESecret, Infos, SIX_SIGMA, div_ceil, keys::SecretDistribution}; + +impl GLWECiphertext> { + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize) -> usize { + let size: usize = div_ceil(k, basek); + module.vec_znx_big_normalize_tmp_bytes() + module.bytes_of_vec_znx_dft(1, size) + module.bytes_of_vec_znx(1, size) + } + pub fn encrypt_pk_scratch_space(module: &Module, basek: usize, k: usize) -> usize { + let size: usize = div_ceil(k, basek); + ((module.bytes_of_vec_znx_dft(1, size) + module.bytes_of_vec_znx_big(1, size)) | module.bytes_of_scalar_znx(1)) + + module.bytes_of_scalar_znx_dft(1) + + module.vec_znx_big_normalize_tmp_bytes() + } +} + +impl + AsMut<[u8]>> GLWECiphertext { + pub fn encrypt_sk, DataSk: AsRef<[u8]>>( + &mut self, + module: &Module, + pt: &GLWEPlaintext, + sk: &GLWESecret, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) { + self.encrypt_sk_private( + module, + Some((pt, 0)), + sk, + source_xa, + source_xe, + sigma, + scratch, + ); + } + + pub fn encrypt_zero_sk>( + &mut self, + module: &Module, + sk: &GLWESecret, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) { + self.encrypt_sk_private( + module, + None::<(&GLWEPlaintext>, usize)>, + sk, + source_xa, + source_xe, + sigma, + scratch, + ); + } + + pub fn encrypt_pk, DataPk: AsRef<[u8]>>( + &mut self, + module: &Module, + pt: &GLWEPlaintext, + pk: &GLWEPublicKey, + source_xu: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) { + self.encrypt_pk_private( + module, + Some((pt, 0)), + pk, + source_xu, + source_xe, + sigma, + scratch, + ); + } + + pub fn encrypt_zero_pk>( + &mut self, + module: &Module, + pk: &GLWEPublicKey, + source_xu: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) { + self.encrypt_pk_private( + module, + None::<(&GLWEPlaintext>, usize)>, + pk, + source_xu, + source_xe, + sigma, + scratch, + ); + } + + pub(crate) fn encrypt_sk_private, DataSk: AsRef<[u8]>>( + &mut self, + module: &Module, + pt: Option<(&GLWEPlaintext, usize)>, + sk: &GLWESecret, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), sk.rank()); + assert_eq!(sk.n(), module.n()); + assert_eq!(self.n(), module.n()); + if let Some((pt, col)) = pt { + assert_eq!(pt.n(), module.n()); + assert!(col < self.rank() + 1); + } + assert!( + scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k()), + "scratch.available(): {} < GLWECiphertext::encrypt_sk_scratch_space: {}", + scratch.available(), + GLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k()) + ) + } + + let basek: usize = self.basek(); + let k: usize = self.k(); + let size: usize = self.size(); + let cols: usize = self.rank() + 1; + + let (mut c0_big, scratch_1) = scratch.tmp_vec_znx(module, 1, size); + c0_big.zero(); + + { + // c[i] = uniform + // c[0] -= c[i] * s[i], + (1..cols).for_each(|i| { + let (mut ci_dft, scratch_2) = scratch_1.tmp_vec_znx_dft(module, 1, size); + + // c[i] = uniform + self.data.fill_uniform(basek, i, size, source_xa); + + // c[i] = norm(IDFT(DFT(c[i]) * DFT(s[i]))) + module.vec_znx_dft(1, 0, &mut ci_dft, 0, &self.data, i); + module.svp_apply_inplace(&mut ci_dft, 0, &sk.data_fourier, i - 1); + let ci_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(ci_dft); + + // use c[0] as buffer, which is overwritten later by the normalization step + module.vec_znx_big_normalize(basek, &mut self.data, 0, &ci_big, 0, scratch_2); + + // c0_tmp = -c[i] * s[i] (use c[0] as buffer) + module.vec_znx_sub_ab_inplace(&mut c0_big, 0, &self.data, 0); + + // c[i] += m if col = i + if let Some((pt, col)) = pt { + if i == col { + module.vec_znx_add_inplace(&mut self.data, i, &pt.data, 0); + module.vec_znx_normalize_inplace(basek, &mut self.data, i, scratch_2); + } + } + }); + } + + // c[0] += e + c0_big.add_normal(basek, 0, k, source_xe, sigma, sigma * SIX_SIGMA); + + // c[0] += m if col = 0 + if let Some((pt, col)) = pt { + if col == 0 { + module.vec_znx_add_inplace(&mut c0_big, 0, &pt.data, 0); + } + } + + // c[0] = norm(c[0]) + module.vec_znx_normalize(basek, &mut self.data, 0, &c0_big, 0, scratch_1); + } + + pub(crate) fn encrypt_pk_private, DataPk: AsRef<[u8]>>( + &mut self, + module: &Module, + pt: Option<(&GLWEPlaintext, usize)>, + pk: &GLWEPublicKey, + source_xu: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) { + #[cfg(debug_assertions)] + { + assert_eq!(self.basek(), pk.basek()); + assert_eq!(self.n(), module.n()); + assert_eq!(pk.n(), module.n()); + assert_eq!(self.rank(), pk.rank()); + if let Some((pt, _)) = pt { + assert_eq!(pt.basek(), pk.basek()); + assert_eq!(pt.n(), module.n()); + } + } + + let basek: usize = pk.basek(); + let size_pk: usize = pk.size(); + let cols: usize = self.rank() + 1; + + // Generates u according to the underlying secret distribution. + let (mut u_dft, scratch_1) = scratch.tmp_scalar_znx_dft(module, 1); + + { + let (mut u, _) = scratch_1.tmp_scalar_znx(module, 1); + match pk.dist { + SecretDistribution::NONE => panic!( + "invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through \ + Self::generate" + ), + SecretDistribution::TernaryFixed(hw) => u.fill_ternary_hw(0, hw, source_xu), + SecretDistribution::TernaryProb(prob) => u.fill_ternary_prob(0, prob, source_xu), + SecretDistribution::BinaryFixed(hw) => u.fill_binary_hw(0, hw, source_xu), + SecretDistribution::BinaryProb(prob) => u.fill_binary_prob(0, prob, source_xu), + SecretDistribution::BinaryBlock(block_size) => u.fill_binary_block(0, block_size, source_xu), + SecretDistribution::ZERO => {} + } + + module.svp_prepare(&mut u_dft, 0, &u, 0); + } + + // ct[i] = pk[i] * u + ei (+ m if col = i) + (0..cols).for_each(|i| { + let (mut ci_dft, scratch_2) = scratch_1.tmp_vec_znx_dft(module, 1, size_pk); + // ci_dft = DFT(u) * DFT(pk[i]) + module.svp_apply(&mut ci_dft, 0, &u_dft, 0, &pk.data.data, i); + + // ci_big = u * p[i] + let mut ci_big = module.vec_znx_idft_consume(ci_dft); + + // ci_big = u * pk[i] + e + ci_big.add_normal(basek, 0, pk.k(), source_xe, sigma, sigma * SIX_SIGMA); + + // ci_big = u * pk[i] + e + m (if col = i) + if let Some((pt, col)) = pt { + if col == i { + module.vec_znx_big_add_small_inplace(&mut ci_big, 0, &pt.data, 0); + } + } + + // ct[i] = norm(ci_big) + module.vec_znx_big_normalize(basek, &mut self.data, i, &ci_big, 0, scratch_2); + }); + } +} diff --git a/core/src/glwe/external_product.rs b/core/src/glwe/external_product.rs new file mode 100644 index 0000000..c44ab75 --- /dev/null +++ b/core/src/glwe/external_product.rs @@ -0,0 +1,129 @@ +use backend::{ + FFT64, MatZnxDftOps, MatZnxDftScratch, Module, Scratch, VecZnxBig, VecZnxBigOps, VecZnxDftAlloc, VecZnxDftOps, VecZnxScratch, +}; + +use crate::{FourierGLWECiphertext, GGSWCiphertext, GLWECiphertext, Infos, div_ceil}; + +impl GLWECiphertext> { + pub fn external_product_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_in: usize, + ggsw_k: usize, + digits: usize, + rank: usize, + ) -> usize { + let res_dft: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank); + let in_size: usize = div_ceil(div_ceil(k_in, basek), digits); + let out_size: usize = div_ceil(k_out, basek); + let ggsw_size: usize = div_ceil(ggsw_k, basek); + let vmp: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size) + + module.vmp_apply_tmp_bytes( + out_size, + in_size, + in_size, // rows + rank + 1, // cols in + rank + 1, // cols out + ggsw_size, + ); + let normalize: usize = module.vec_znx_normalize_tmp_bytes(); + res_dft + (vmp | normalize) + } + + pub fn external_product_inplace_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + ggsw_k: usize, + digits: usize, + rank: usize, + ) -> usize { + Self::external_product_scratch_space(module, basek, k_out, k_out, ggsw_k, digits, rank) + } +} + +impl + AsMut<[u8]>> GLWECiphertext { + pub fn external_product, DataRhs: AsRef<[u8]>>( + &mut self, + module: &Module, + lhs: &GLWECiphertext, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) { + let basek: usize = self.basek(); + + #[cfg(debug_assertions)] + { + assert_eq!(rhs.rank(), lhs.rank()); + assert_eq!(rhs.rank(), self.rank()); + assert_eq!(self.basek(), basek); + assert_eq!(lhs.basek(), basek); + assert_eq!(rhs.n(), module.n()); + assert_eq!(self.n(), module.n()); + assert_eq!(lhs.n(), module.n()); + assert!( + scratch.available() + >= GLWECiphertext::external_product_scratch_space( + module, + self.basek(), + self.k(), + lhs.k(), + rhs.k(), + rhs.digits(), + rhs.rank(), + ) + ); + } + + let cols: usize = rhs.rank() + 1; + let digits: usize = rhs.digits(); + + let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); // Todo optimise + let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, (lhs.size() + digits - 1) / digits); + + { + (0..digits).for_each(|di| { + // (lhs.size() + di) / digits = (a - (digit - di - 1) + digit - 1) / digits + a_dft.set_size((lhs.size() + di) / digits); + + // Small optimization for digits > 2 + // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then + // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}. + // As such we can ignore the last digits-2 limbs safely of the sum of vmp products. + // It is possible to further ignore the last digits-1 limbs, but this introduce + // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same + // noise is kept with respect to the ideal functionality. + res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize); + + (0..cols).for_each(|col_i| { + module.vec_znx_dft(digits, digits - 1 - di, &mut a_dft, col_i, &lhs.data, col_i); + }); + + if di == 0 { + module.vmp_apply(&mut res_dft, &a_dft, &rhs.data, scratch2); + } else { + module.vmp_apply_add(&mut res_dft, &a_dft, &rhs.data, di, scratch2); + } + }); + } + + let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft); + + (0..cols).for_each(|i| { + module.vec_znx_big_normalize(basek, &mut self.data, i, &res_big, i, scratch1); + }); + } + + pub fn external_product_inplace>( + &mut self, + module: &Module, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) { + unsafe { + let self_ptr: *mut GLWECiphertext = self as *mut GLWECiphertext; + self.external_product(&module, &*self_ptr, rhs, scratch); + } + } +} diff --git a/core/src/glwe/keyswitch.rs b/core/src/glwe/keyswitch.rs new file mode 100644 index 0000000..eace187 --- /dev/null +++ b/core/src/glwe/keyswitch.rs @@ -0,0 +1,244 @@ +use backend::{ + FFT64, MatZnxDftOps, MatZnxDftScratch, Module, Scratch, VecZnxBig, VecZnxBigOps, VecZnxBigScratch, VecZnxDftAlloc, + VecZnxDftOps, ZnxZero, +}; + +use crate::{FourierGLWECiphertext, GLWECiphertext, GLWESwitchingKey, Infos, div_ceil}; + +impl GLWECiphertext> { + pub fn keyswitch_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_in: usize, + k_ksk: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + ) -> usize { + let res_dft: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank_out + 1); + let in_size: usize = div_ceil(div_ceil(k_in, basek), digits); + let out_size: usize = div_ceil(k_out, basek); + let ksk_size: usize = div_ceil(k_ksk, basek); + let ai_dft: usize = module.bytes_of_vec_znx_dft(rank_in, in_size); + let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, rank_in, rank_out + 1, ksk_size) + + module.bytes_of_vec_znx_dft(rank_in, in_size); + let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(); + return res_dft + ((ai_dft + vmp) | normalize); + } + + pub fn keyswitch_from_fourier_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_in: usize, + k_ksk: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + ) -> usize { + Self::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank_in, rank_out) + } + + pub fn keyswitch_inplace_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_ksk: usize, + digits: usize, + rank: usize, + ) -> usize { + Self::keyswitch_scratch_space(module, basek, k_out, k_out, k_ksk, digits, rank, rank) + } +} + +impl + AsMut<[u8]>> GLWECiphertext { + pub fn keyswitch, DataRhs: AsRef<[u8]>>( + &mut self, + module: &Module, + lhs: &GLWECiphertext, + rhs: &GLWESwitchingKey, + scratch: &mut Scratch, + ) { + Self::keyswitch_private::<_, _, 0>(self, 0, module, lhs, rhs, scratch); + } + + pub fn keyswitch_inplace>( + &mut self, + module: &Module, + rhs: &GLWESwitchingKey, + scratch: &mut Scratch, + ) { + unsafe { + let self_ptr: *mut GLWECiphertext = self as *mut GLWECiphertext; + self.keyswitch(&module, &*self_ptr, rhs, scratch); + } + } + + pub(crate) fn keyswitch_private, DataRhs: AsRef<[u8]>, const OP: u8>( + &mut self, + apply_auto: i64, + module: &Module, + lhs: &GLWECiphertext, + rhs: &GLWESwitchingKey, + scratch: &mut Scratch, + ) { + let basek: usize = self.basek(); + + #[cfg(debug_assertions)] + { + assert_eq!(lhs.rank(), rhs.rank_in()); + assert_eq!(self.rank(), rhs.rank_out()); + assert_eq!(self.basek(), basek); + assert_eq!(lhs.basek(), basek); + assert_eq!(rhs.n(), module.n()); + assert_eq!(self.n(), module.n()); + assert_eq!(lhs.n(), module.n()); + assert!( + scratch.available() + >= GLWECiphertext::keyswitch_scratch_space( + module, + self.basek(), + self.k(), + lhs.k(), + rhs.k(), + rhs.digits(), + rhs.rank_in(), + rhs.rank_out(), + ) + ); + } + + let cols_in: usize = rhs.rank_in(); + let cols_out: usize = rhs.rank_out() + 1; + let digits: usize = rhs.digits(); + + let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise + let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, (lhs.size() + digits - 1) / digits); + ai_dft.zero(); + { + (0..digits).for_each(|di| { + ai_dft.set_size((lhs.size() + di) / digits); + + // Small optimization for digits > 2 + // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then + // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}. + // As such we can ignore the last digits-2 limbs safely of the sum of vmp products. + // It is possible to further ignore the last digits-1 limbs, but this introduce + // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same + // noise is kept with respect to the ideal functionality. + res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize); + + (0..cols_in).for_each(|col_i| { + module.vec_znx_dft( + digits, + digits - di - 1, + &mut ai_dft, + col_i, + &lhs.data, + col_i + 1, + ); + }); + + if di == 0 { + module.vmp_apply(&mut res_dft, &ai_dft, &rhs.0.data, scratch2); + } else { + module.vmp_apply_add(&mut res_dft, &ai_dft, &rhs.0.data, di, scratch2); + } + }); + } + + let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft); + + module.vec_znx_big_add_small_inplace(&mut res_big, 0, &lhs.data, 0); + + (0..cols_out).for_each(|i| { + if apply_auto != 0 { + module.vec_znx_big_automorphism_inplace(apply_auto, &mut res_big, i); + } + + match OP { + 1 => module.vec_znx_big_add_small_inplace(&mut res_big, i, &lhs.data, i), + 2 => module.vec_znx_big_sub_small_a_inplace(&mut res_big, i, &lhs.data, i), + 3 => module.vec_znx_big_sub_small_b_inplace(&mut res_big, i, &lhs.data, i), + _ => {} + } + module.vec_znx_big_normalize(basek, &mut self.data, i, &res_big, i, scratch1); + }); + } + + pub(crate) fn keyswitch_from_fourier, DataRhs: AsRef<[u8]>>( + &mut self, + module: &Module, + lhs: &FourierGLWECiphertext, + rhs: &GLWESwitchingKey, + scratch: &mut Scratch, + ) { + let basek: usize = self.basek(); + + #[cfg(debug_assertions)] + { + assert_eq!(lhs.rank(), rhs.rank_in()); + assert_eq!(self.rank(), rhs.rank_out()); + assert_eq!(self.basek(), basek); + assert_eq!(lhs.basek(), basek); + assert_eq!(rhs.n(), module.n()); + assert_eq!(self.n(), module.n()); + assert_eq!(lhs.n(), module.n()); + assert!( + scratch.available() + >= GLWECiphertext::keyswitch_from_fourier_scratch_space( + module, + self.basek(), + self.k(), + lhs.k(), + rhs.k(), + rhs.digits(), + rhs.rank_in(), + rhs.rank_out(), + ) + ); + } + + let cols_in: usize = rhs.rank_in(); + let cols_out: usize = rhs.rank_out() + 1; + + // Buffer of the result of VMP in DFT + let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise + + { + let digits = rhs.digits(); + + (0..digits).for_each(|di| { + // (lhs.size() + di) / digits = (a - (digit - di - 1) + digit - 1) / digits + let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, (lhs.size() + di) / digits); + + (0..cols_in).for_each(|col_i| { + module.vec_znx_dft_copy( + digits, + digits - 1 - di, + &mut ai_dft, + col_i, + &lhs.data, + col_i + 1, + ); + }); + + if di == 0 { + module.vmp_apply(&mut res_dft, &ai_dft, &rhs.0.data, scratch2); + } else { + module.vmp_apply_add(&mut res_dft, &ai_dft, &rhs.0.data, di, scratch2); + } + }); + } + + module.vec_znx_dft_add_inplace(&mut res_dft, 0, &lhs.data, 0); + + // Switches result of VMP outside of DFT + let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft); + + (0..cols_out).for_each(|i| { + module.vec_znx_big_normalize(basek, &mut self.data, i, &res_big, i, scratch1); + }); + } +} diff --git a/core/src/glwe/mod.rs b/core/src/glwe/mod.rs new file mode 100644 index 0000000..ed3395d --- /dev/null +++ b/core/src/glwe/mod.rs @@ -0,0 +1,31 @@ +pub mod automorphism; +pub mod ciphertext; +pub mod decryption; +pub mod encryption; +pub mod external_product; +pub mod keyswitch; +pub mod ops; +pub mod packing; +pub mod plaintext; +pub mod public_key; +pub mod secret; +pub mod trace; + +#[allow(unused_imports)] +pub use automorphism::*; +pub use ciphertext::*; +#[allow(unused_imports)] +pub use decryption::*; +#[allow(unused_imports)] +pub use encryption::*; +#[allow(unused_imports)] +pub use external_product::*; +#[allow(unused_imports)] +pub use keyswitch::*; +pub use ops::*; +pub use packing::*; +pub use plaintext::*; +pub use public_key::*; +pub use secret::*; +#[allow(unused_imports)] +pub use trace::*; diff --git a/core/src/glwe_ops.rs b/core/src/glwe/ops.rs similarity index 100% rename from core/src/glwe_ops.rs rename to core/src/glwe/ops.rs diff --git a/core/src/glwe_packing.rs b/core/src/glwe/packing.rs similarity index 100% rename from core/src/glwe_packing.rs rename to core/src/glwe/packing.rs diff --git a/core/src/glwe_plaintext.rs b/core/src/glwe/plaintext.rs similarity index 92% rename from core/src/glwe_plaintext.rs rename to core/src/glwe/plaintext.rs index 5bebc68..c1e9175 100644 --- a/core/src/glwe_plaintext.rs +++ b/core/src/glwe/plaintext.rs @@ -1,6 +1,10 @@ use backend::{Backend, FFT64, Module, VecZnx, VecZnxAlloc, VecZnxToMut, VecZnxToRef}; -use crate::{GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, GLWEOps, Infos, SetMetaData}; +use crate::{ + GLWEOps, Infos, SetMetaData, + ciphertext::{GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef}, + div_ceil, +}; pub struct GLWEPlaintext { pub data: VecZnx, diff --git a/core/src/glwe/public_key.rs b/core/src/glwe/public_key.rs new file mode 100644 index 0000000..4a1ed15 --- /dev/null +++ b/core/src/glwe/public_key.rs @@ -0,0 +1,75 @@ +use backend::{Backend, FFT64, Module, ScratchOwned, VecZnxDft}; +use sampling::source::Source; + +use crate::{FourierGLWECiphertext, GLWESecret, Infos, keys::SecretDistribution}; + +pub struct GLWEPublicKey { + pub(crate) data: FourierGLWECiphertext, + pub(crate) dist: SecretDistribution, +} + +impl GLWEPublicKey, B> { + pub fn alloc(module: &Module, basek: usize, k: usize, rank: usize) -> Self { + Self { + data: FourierGLWECiphertext::alloc(module, basek, k, rank), + dist: SecretDistribution::NONE, + } + } + + pub fn bytes_of(module: &Module, basek: usize, k: usize, rank: usize) -> usize { + FourierGLWECiphertext::, B>::bytes_of(module, basek, k, rank) + } +} + +impl Infos for GLWEPublicKey { + type Inner = VecZnxDft; + + fn inner(&self) -> &Self::Inner { + &self.data.data + } + + fn basek(&self) -> usize { + self.data.basek + } + + fn k(&self) -> usize { + self.data.k + } +} + +impl GLWEPublicKey { + pub fn rank(&self) -> usize { + self.cols() - 1 + } +} + +impl + AsMut<[u8]>> GLWEPublicKey { + pub fn generate_from_sk>( + &mut self, + module: &Module, + sk: &GLWESecret, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + ) { + #[cfg(debug_assertions)] + { + match sk.dist { + SecretDistribution::NONE => panic!("invalid sk: SecretDistribution::NONE"), + _ => {} + } + } + + // Its ok to allocate scratch space here since pk is usually generated only once. + let mut scratch: ScratchOwned = ScratchOwned::new(FourierGLWECiphertext::encrypt_sk_scratch_space( + module, + self.basek(), + self.k(), + self.rank(), + )); + + self.data + .encrypt_zero_sk(module, sk, source_xa, source_xe, sigma, scratch.borrow()); + self.dist = sk.dist; + } +} diff --git a/core/src/glwe/secret.rs b/core/src/glwe/secret.rs new file mode 100644 index 0000000..b704365 --- /dev/null +++ b/core/src/glwe/secret.rs @@ -0,0 +1,93 @@ +use backend::{ + Backend, FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ZnxInfos, ZnxZero, +}; +use sampling::source::Source; + +use crate::keys::SecretDistribution; + +pub struct GLWESecret { + pub(crate) data: ScalarZnx, + pub(crate) data_fourier: ScalarZnxDft, + pub(crate) dist: SecretDistribution, +} + +impl GLWESecret, B> { + pub fn alloc(module: &Module, rank: usize) -> Self { + Self { + data: module.new_scalar_znx(rank), + data_fourier: module.new_scalar_znx_dft(rank), + dist: SecretDistribution::NONE, + } + } + + pub fn bytes_of(module: &Module, rank: usize) -> usize { + module.bytes_of_scalar_znx(rank) + module.bytes_of_scalar_znx_dft(rank) + } +} + +impl GLWESecret { + pub fn n(&self) -> usize { + self.data.n() + } + + pub fn log_n(&self) -> usize { + self.data.log_n() + } + + pub fn rank(&self) -> usize { + self.data.cols() + } +} + +impl + AsRef<[u8]>> GLWESecret { + pub fn fill_ternary_prob(&mut self, module: &Module, prob: f64, source: &mut Source) { + (0..self.rank()).for_each(|i| { + self.data.fill_ternary_prob(i, prob, source); + }); + self.prep_fourier(module); + self.dist = SecretDistribution::TernaryProb(prob); + } + + pub fn fill_ternary_hw(&mut self, module: &Module, hw: usize, source: &mut Source) { + (0..self.rank()).for_each(|i| { + self.data.fill_ternary_hw(i, hw, source); + }); + self.prep_fourier(module); + self.dist = SecretDistribution::TernaryFixed(hw); + } + + pub fn fill_binary_prob(&mut self, module: &Module, prob: f64, source: &mut Source) { + (0..self.rank()).for_each(|i| { + self.data.fill_binary_prob(i, prob, source); + }); + self.prep_fourier(module); + self.dist = SecretDistribution::BinaryProb(prob); + } + + pub fn fill_binary_hw(&mut self, module: &Module, hw: usize, source: &mut Source) { + (0..self.rank()).for_each(|i| { + self.data.fill_binary_hw(i, hw, source); + }); + self.prep_fourier(module); + self.dist = SecretDistribution::BinaryFixed(hw); + } + + pub fn fill_binary_block(&mut self, module: &Module, block_size: usize, source: &mut Source) { + (0..self.rank()).for_each(|i| { + self.data.fill_binary_block(i, block_size, source); + }); + self.prep_fourier(module); + self.dist = SecretDistribution::BinaryBlock(block_size); + } + + pub fn fill_zero(&mut self) { + self.data.zero(); + self.dist = SecretDistribution::ZERO; + } + + pub(crate) fn prep_fourier(&mut self, module: &Module) { + (0..self.rank()).for_each(|i| { + module.svp_prepare(&mut self.data_fourier, i, &self.data, i); + }); + } +} diff --git a/core/src/trace.rs b/core/src/glwe/trace.rs similarity index 100% rename from core/src/trace.rs rename to core/src/glwe/trace.rs diff --git a/core/src/glwe_ciphertext.rs b/core/src/glwe_ciphertext.rs deleted file mode 100644 index cd85f0c..0000000 --- a/core/src/glwe_ciphertext.rs +++ /dev/null @@ -1,886 +0,0 @@ -use backend::{ - AddNormal, Backend, FFT64, FillUniform, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnxAlloc, ScalarZnxDftAlloc, - ScalarZnxDftOps, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDftAlloc, - VecZnxDftOps, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxZero, -}; -use sampling::source::Source; - -use crate::{ - AutomorphismKey, GGSWCiphertext, GLWECiphertextFourier, GLWEOps, GLWEPlaintext, GLWEPublicKey, GLWESecret, GLWESwitchingKey, - Infos, SIX_SIGMA, SecretDistribution, SetMetaData, -}; - -pub struct GLWECiphertext { - pub data: VecZnx, - pub basek: usize, - pub k: usize, -} - -impl GLWECiphertext> { - pub fn alloc(module: &Module, basek: usize, k: usize, rank: usize) -> Self { - Self { - data: module.new_vec_znx(rank + 1, k.div_ceil(basek)), - basek, - k, - } - } - - pub fn bytes_of(module: &Module, basek: usize, k: usize, rank: usize) -> usize { - module.bytes_of_vec_znx(rank + 1, k.div_ceil(basek)) - } -} - -impl Infos for GLWECiphertext { - type Inner = VecZnx; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn basek(&self) -> usize { - self.basek - } - - fn k(&self) -> usize { - self.k - } -} - -impl GLWECiphertext { - pub fn rank(&self) -> usize { - self.cols() - 1 - } -} - -impl> GLWECiphertext { - #[allow(dead_code)] - pub(crate) fn dft + AsRef<[u8]>>(&self, module: &Module, res: &mut GLWECiphertextFourier) { - #[cfg(debug_assertions)] - { - assert_eq!(self.rank(), res.rank()); - assert_eq!(self.basek(), res.basek()) - } - - (0..self.rank() + 1).for_each(|i| { - module.vec_znx_dft(1, 0, &mut res.data, i, &self.data, i); - }) - } -} - -impl GLWECiphertext> { - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize) -> usize { - let size: usize = k.div_ceil(basek); - module.vec_znx_big_normalize_tmp_bytes() + module.bytes_of_vec_znx_dft(1, size) + module.bytes_of_vec_znx(1, size) - } - pub fn encrypt_pk_scratch_space(module: &Module, basek: usize, k: usize) -> usize { - let size: usize = k.div_ceil(basek); - ((module.bytes_of_vec_znx_dft(1, size) + module.bytes_of_vec_znx_big(1, size)) | module.bytes_of_scalar_znx(1)) - + module.bytes_of_scalar_znx_dft(1) - + module.vec_znx_big_normalize_tmp_bytes() - } - - pub fn decrypt_scratch_space(module: &Module, basek: usize, k: usize) -> usize { - let size: usize = k.div_ceil(basek); - (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) - } - - pub fn keyswitch_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_in: usize, - k_ksk: usize, - digits: usize, - rank_in: usize, - rank_out: usize, - ) -> usize { - let res_dft: usize = GLWECiphertextFourier::bytes_of(module, basek, k_out, rank_out + 1); - let in_size: usize = k_in.div_ceil(basek).div_ceil(digits); - let out_size: usize = k_out.div_ceil(basek); - let ksk_size: usize = k_ksk.div_ceil(basek); - let ai_dft: usize = module.bytes_of_vec_znx_dft(rank_in, in_size); - let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, rank_in, rank_out + 1, ksk_size) - + module.bytes_of_vec_znx_dft(rank_in, in_size); - let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(); - return res_dft + ((ai_dft + vmp) | normalize); - } - - pub fn keyswitch_from_fourier_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_in: usize, - k_ksk: usize, - digits: usize, - rank_in: usize, - rank_out: usize, - ) -> usize { - Self::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank_in, rank_out) - } - - pub fn keyswitch_inplace_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_ksk: usize, - digits: usize, - rank: usize, - ) -> usize { - Self::keyswitch_scratch_space(module, basek, k_out, k_out, k_ksk, digits, rank, rank) - } - - pub fn automorphism_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_in: usize, - k_ksk: usize, - digits: usize, - rank: usize, - ) -> usize { - Self::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank, rank) - } - - pub fn automorphism_inplace_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_ksk: usize, - digits: usize, - rank: usize, - ) -> usize { - Self::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank) - } - - pub fn external_product_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_in: usize, - ggsw_k: usize, - digits: usize, - rank: usize, - ) -> usize { - let res_dft: usize = GLWECiphertextFourier::bytes_of(module, basek, k_out, rank); - let in_size: usize = k_in.div_ceil(basek).div_ceil(digits); - let out_size: usize = k_out.div_ceil(basek); - let ggsw_size: usize = ggsw_k.div_ceil(basek); - let vmp: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size) - + module.vmp_apply_tmp_bytes( - out_size, - in_size, - in_size, // rows - rank + 1, // cols in - rank + 1, // cols out - ggsw_size, - ); - let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(); - res_dft + (vmp | normalize) - } - - pub fn external_product_inplace_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - ggsw_k: usize, - digits: usize, - rank: usize, - ) -> usize { - Self::external_product_scratch_space(module, basek, k_out, k_out, ggsw_k, digits, rank) - } -} - -impl + AsRef<[u8]>> SetMetaData for GLWECiphertext { - fn set_k(&mut self, k: usize) { - self.k = k - } - - fn set_basek(&mut self, basek: usize) { - self.basek = basek - } -} - -impl + AsMut<[u8]>> GLWECiphertext { - pub fn encrypt_sk, DataSk: AsRef<[u8]>>( - &mut self, - module: &Module, - pt: &GLWEPlaintext, - sk: &GLWESecret, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - scratch: &mut Scratch, - ) { - self.encrypt_sk_private( - module, - Some((pt, 0)), - sk, - source_xa, - source_xe, - sigma, - scratch, - ); - } - - pub fn encrypt_zero_sk>( - &mut self, - module: &Module, - sk: &GLWESecret, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - scratch: &mut Scratch, - ) { - self.encrypt_sk_private( - module, - None::<(&GLWEPlaintext>, usize)>, - sk, - source_xa, - source_xe, - sigma, - scratch, - ); - } - - pub fn encrypt_pk, DataPk: AsRef<[u8]>>( - &mut self, - module: &Module, - pt: &GLWEPlaintext, - pk: &GLWEPublicKey, - source_xu: &mut Source, - source_xe: &mut Source, - sigma: f64, - scratch: &mut Scratch, - ) { - self.encrypt_pk_private( - module, - Some((pt, 0)), - pk, - source_xu, - source_xe, - sigma, - scratch, - ); - } - - pub fn encrypt_zero_pk>( - &mut self, - module: &Module, - pk: &GLWEPublicKey, - source_xu: &mut Source, - source_xe: &mut Source, - sigma: f64, - scratch: &mut Scratch, - ) { - self.encrypt_pk_private( - module, - None::<(&GLWEPlaintext>, usize)>, - pk, - source_xu, - source_xe, - sigma, - scratch, - ); - } - - pub fn automorphism, DataRhs: AsRef<[u8]>>( - &mut self, - module: &Module, - lhs: &GLWECiphertext, - rhs: &AutomorphismKey, - scratch: &mut Scratch, - ) { - self.keyswitch(module, lhs, &rhs.key, scratch); - (0..self.rank() + 1).for_each(|i| { - module.vec_znx_automorphism_inplace(rhs.p(), &mut self.data, i); - }) - } - - pub fn automorphism_inplace>( - &mut self, - module: &Module, - rhs: &AutomorphismKey, - scratch: &mut Scratch, - ) { - self.keyswitch_inplace(module, &rhs.key, scratch); - (0..self.rank() + 1).for_each(|i| { - module.vec_znx_automorphism_inplace(rhs.p(), &mut self.data, i); - }) - } - - pub fn automorphism_add, DataRhs: AsRef<[u8]>>( - &mut self, - module: &Module, - lhs: &GLWECiphertext, - rhs: &AutomorphismKey, - scratch: &mut Scratch, - ) { - Self::keyswitch_private::<_, _, 1>(self, rhs.p(), module, lhs, &rhs.key, scratch); - } - - pub fn automorphism_add_inplace>( - &mut self, - module: &Module, - rhs: &AutomorphismKey, - scratch: &mut Scratch, - ) { - unsafe { - let self_ptr: *mut GLWECiphertext = self as *mut GLWECiphertext; - Self::keyswitch_private::<_, _, 1>(self, rhs.p(), module, &*self_ptr, &rhs.key, scratch); - } - } - - pub fn automorphism_sub_ab, DataRhs: AsRef<[u8]>>( - &mut self, - module: &Module, - lhs: &GLWECiphertext, - rhs: &AutomorphismKey, - scratch: &mut Scratch, - ) { - Self::keyswitch_private::<_, _, 2>(self, rhs.p(), module, lhs, &rhs.key, scratch); - } - - pub fn automorphism_sub_ab_inplace>( - &mut self, - module: &Module, - rhs: &AutomorphismKey, - scratch: &mut Scratch, - ) { - unsafe { - let self_ptr: *mut GLWECiphertext = self as *mut GLWECiphertext; - Self::keyswitch_private::<_, _, 2>(self, rhs.p(), module, &*self_ptr, &rhs.key, scratch); - } - } - - pub fn automorphism_sub_ba, DataRhs: AsRef<[u8]>>( - &mut self, - module: &Module, - lhs: &GLWECiphertext, - rhs: &AutomorphismKey, - scratch: &mut Scratch, - ) { - Self::keyswitch_private::<_, _, 3>(self, rhs.p(), module, lhs, &rhs.key, scratch); - } - - pub fn automorphism_sub_ba_inplace>( - &mut self, - module: &Module, - rhs: &AutomorphismKey, - scratch: &mut Scratch, - ) { - unsafe { - let self_ptr: *mut GLWECiphertext = self as *mut GLWECiphertext; - Self::keyswitch_private::<_, _, 3>(self, rhs.p(), module, &*self_ptr, &rhs.key, scratch); - } - } - - pub(crate) fn keyswitch_from_fourier, DataRhs: AsRef<[u8]>>( - &mut self, - module: &Module, - lhs: &GLWECiphertextFourier, - rhs: &GLWESwitchingKey, - scratch: &mut Scratch, - ) { - let basek: usize = self.basek(); - - #[cfg(debug_assertions)] - { - assert_eq!(lhs.rank(), rhs.rank_in()); - assert_eq!(self.rank(), rhs.rank_out()); - assert_eq!(self.basek(), basek); - assert_eq!(lhs.basek(), basek); - assert_eq!(rhs.n(), module.n()); - assert_eq!(self.n(), module.n()); - assert_eq!(lhs.n(), module.n()); - assert!( - scratch.available() - >= GLWECiphertext::keyswitch_from_fourier_scratch_space( - module, - self.basek(), - self.k(), - lhs.k(), - rhs.k(), - rhs.digits(), - rhs.rank_in(), - rhs.rank_out(), - ) - ); - } - - let cols_in: usize = rhs.rank_in(); - let cols_out: usize = rhs.rank_out() + 1; - - // Buffer of the result of VMP in DFT - let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise - - { - let digits = rhs.digits(); - - (0..digits).for_each(|di| { - // (lhs.size() + di) / digits = (a - (digit - di - 1) + digit - 1) / digits - let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, (lhs.size() + di) / digits); - - (0..cols_in).for_each(|col_i| { - module.vec_znx_dft_copy( - digits, - digits - 1 - di, - &mut ai_dft, - col_i, - &lhs.data, - col_i + 1, - ); - }); - - if di == 0 { - module.vmp_apply(&mut res_dft, &ai_dft, &rhs.0.data, scratch2); - } else { - module.vmp_apply_add(&mut res_dft, &ai_dft, &rhs.0.data, di, scratch2); - } - }); - } - - module.vec_znx_dft_add_inplace(&mut res_dft, 0, &lhs.data, 0); - - // Switches result of VMP outside of DFT - let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft); - - (0..cols_out).for_each(|i| { - module.vec_znx_big_normalize(basek, &mut self.data, i, &res_big, i, scratch1); - }); - } - - pub fn keyswitch, DataRhs: AsRef<[u8]>>( - &mut self, - module: &Module, - lhs: &GLWECiphertext, - rhs: &GLWESwitchingKey, - scratch: &mut Scratch, - ) { - Self::keyswitch_private::<_, _, 0>(self, 0, module, lhs, rhs, scratch); - } - - pub(crate) fn keyswitch_private, DataRhs: AsRef<[u8]>, const OP: u8>( - &mut self, - apply_auto: i64, - module: &Module, - lhs: &GLWECiphertext, - rhs: &GLWESwitchingKey, - scratch: &mut Scratch, - ) { - let basek: usize = self.basek(); - - #[cfg(debug_assertions)] - { - assert_eq!(lhs.rank(), rhs.rank_in()); - assert_eq!(self.rank(), rhs.rank_out()); - assert_eq!(self.basek(), basek); - assert_eq!(lhs.basek(), basek); - assert_eq!(rhs.n(), module.n()); - assert_eq!(self.n(), module.n()); - assert_eq!(lhs.n(), module.n()); - assert!( - scratch.available() - >= GLWECiphertext::keyswitch_scratch_space( - module, - self.basek(), - self.k(), - lhs.k(), - rhs.k(), - rhs.digits(), - rhs.rank_in(), - rhs.rank_out(), - ) - ); - } - - let cols_in: usize = rhs.rank_in(); - let cols_out: usize = rhs.rank_out() + 1; - let digits: usize = rhs.digits(); - - let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise - let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, (lhs.size() + digits - 1) / digits); - ai_dft.zero(); - { - (0..digits).for_each(|di| { - ai_dft.set_size((lhs.size() + di) / digits); - - // Small optimization for digits > 2 - // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then - // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}. - // As such we can ignore the last digits-2 limbs safely of the sum of vmp products. - // It is possible to further ignore the last digits-1 limbs, but this introduce - // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same - // noise is kept with respect to the ideal functionality. - res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize); - - (0..cols_in).for_each(|col_i| { - module.vec_znx_dft( - digits, - digits - di - 1, - &mut ai_dft, - col_i, - &lhs.data, - col_i + 1, - ); - }); - - if di == 0 { - module.vmp_apply(&mut res_dft, &ai_dft, &rhs.0.data, scratch2); - } else { - module.vmp_apply_add(&mut res_dft, &ai_dft, &rhs.0.data, di, scratch2); - } - }); - } - - let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft); - - module.vec_znx_big_add_small_inplace(&mut res_big, 0, &lhs.data, 0); - - (0..cols_out).for_each(|i| { - if apply_auto != 0 { - module.vec_znx_big_automorphism_inplace(apply_auto, &mut res_big, i); - } - - match OP { - 1 => module.vec_znx_big_add_small_inplace(&mut res_big, i, &lhs.data, i), - 2 => module.vec_znx_big_sub_small_a_inplace(&mut res_big, i, &lhs.data, i), - 3 => module.vec_znx_big_sub_small_b_inplace(&mut res_big, i, &lhs.data, i), - _ => {} - } - module.vec_znx_big_normalize(basek, &mut self.data, i, &res_big, i, scratch1); - }); - } - - pub fn keyswitch_inplace>( - &mut self, - module: &Module, - rhs: &GLWESwitchingKey, - scratch: &mut Scratch, - ) { - unsafe { - let self_ptr: *mut GLWECiphertext = self as *mut GLWECiphertext; - self.keyswitch(&module, &*self_ptr, rhs, scratch); - } - } - - pub fn external_product, DataRhs: AsRef<[u8]>>( - &mut self, - module: &Module, - lhs: &GLWECiphertext, - rhs: &GGSWCiphertext, - scratch: &mut Scratch, - ) { - let basek: usize = self.basek(); - - #[cfg(debug_assertions)] - { - assert_eq!(rhs.rank(), lhs.rank()); - assert_eq!(rhs.rank(), self.rank()); - assert_eq!(self.basek(), basek); - assert_eq!(lhs.basek(), basek); - assert_eq!(rhs.n(), module.n()); - assert_eq!(self.n(), module.n()); - assert_eq!(lhs.n(), module.n()); - assert!( - scratch.available() - >= GLWECiphertext::external_product_scratch_space( - module, - self.basek(), - self.k(), - lhs.k(), - rhs.k(), - rhs.digits(), - rhs.rank(), - ) - ); - } - - let cols: usize = rhs.rank() + 1; - let digits: usize = rhs.digits(); - - let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); // Todo optimise - let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, (lhs.size() + digits - 1) / digits); - - { - (0..digits).for_each(|di| { - // (lhs.size() + di) / digits = (a - (digit - di - 1) + digit - 1) / digits - a_dft.set_size((lhs.size() + di) / digits); - - // Small optimization for digits > 2 - // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then - // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}. - // As such we can ignore the last digits-2 limbs safely of the sum of vmp products. - // It is possible to further ignore the last digits-1 limbs, but this introduce - // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same - // noise is kept with respect to the ideal functionality. - res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize); - - (0..cols).for_each(|col_i| { - module.vec_znx_dft(digits, digits - 1 - di, &mut a_dft, col_i, &lhs.data, col_i); - }); - - if di == 0 { - module.vmp_apply(&mut res_dft, &a_dft, &rhs.data, scratch2); - } else { - module.vmp_apply_add(&mut res_dft, &a_dft, &rhs.data, di, scratch2); - } - }); - } - - let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft); - - (0..cols).for_each(|i| { - module.vec_znx_big_normalize(basek, &mut self.data, i, &res_big, i, scratch1); - }); - } - - pub fn external_product_inplace>( - &mut self, - module: &Module, - rhs: &GGSWCiphertext, - scratch: &mut Scratch, - ) { - unsafe { - let self_ptr: *mut GLWECiphertext = self as *mut GLWECiphertext; - self.external_product(&module, &*self_ptr, rhs, scratch); - } - } - - pub(crate) fn encrypt_sk_private, DataSk: AsRef<[u8]>>( - &mut self, - module: &Module, - pt: Option<(&GLWEPlaintext, usize)>, - sk: &GLWESecret, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - scratch: &mut Scratch, - ) { - #[cfg(debug_assertions)] - { - assert_eq!(self.rank(), sk.rank()); - assert_eq!(sk.n(), module.n()); - assert_eq!(self.n(), module.n()); - if let Some((pt, col)) = pt { - assert_eq!(pt.n(), module.n()); - assert!(col < self.rank() + 1); - } - assert!( - scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k()), - "scratch.available(): {} < GLWECiphertext::encrypt_sk_scratch_space: {}", - scratch.available(), - GLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k()) - ) - } - - let basek: usize = self.basek(); - let k: usize = self.k(); - let size: usize = self.size(); - let cols: usize = self.rank() + 1; - - let (mut c0_big, scratch_1) = scratch.tmp_vec_znx(module, 1, size); - c0_big.zero(); - - { - // c[i] = uniform - // c[0] -= c[i] * s[i], - (1..cols).for_each(|i| { - let (mut ci_dft, scratch_2) = scratch_1.tmp_vec_znx_dft(module, 1, size); - - // c[i] = uniform - self.data.fill_uniform(basek, i, size, source_xa); - - // c[i] = norm(IDFT(DFT(c[i]) * DFT(s[i]))) - module.vec_znx_dft(1, 0, &mut ci_dft, 0, &self.data, i); - module.svp_apply_inplace(&mut ci_dft, 0, &sk.data_fourier, i - 1); - let ci_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(ci_dft); - - // use c[0] as buffer, which is overwritten later by the normalization step - module.vec_znx_big_normalize(basek, &mut self.data, 0, &ci_big, 0, scratch_2); - - // c0_tmp = -c[i] * s[i] (use c[0] as buffer) - module.vec_znx_sub_ab_inplace(&mut c0_big, 0, &self.data, 0); - - // c[i] += m if col = i - if let Some((pt, col)) = pt { - if i == col { - module.vec_znx_add_inplace(&mut self.data, i, &pt.data, 0); - module.vec_znx_normalize_inplace(basek, &mut self.data, i, scratch_2); - } - } - }); - } - - // c[0] += e - c0_big.add_normal(basek, 0, k, source_xe, sigma, sigma * SIX_SIGMA); - - // c[0] += m if col = 0 - if let Some((pt, col)) = pt { - if col == 0 { - module.vec_znx_add_inplace(&mut c0_big, 0, &pt.data, 0); - } - } - - // c[0] = norm(c[0]) - module.vec_znx_normalize(basek, &mut self.data, 0, &c0_big, 0, scratch_1); - } - - pub(crate) fn encrypt_pk_private, DataPk: AsRef<[u8]>>( - &mut self, - module: &Module, - pt: Option<(&GLWEPlaintext, usize)>, - pk: &GLWEPublicKey, - source_xu: &mut Source, - source_xe: &mut Source, - sigma: f64, - scratch: &mut Scratch, - ) { - #[cfg(debug_assertions)] - { - assert_eq!(self.basek(), pk.basek()); - assert_eq!(self.n(), module.n()); - assert_eq!(pk.n(), module.n()); - assert_eq!(self.rank(), pk.rank()); - if let Some((pt, _)) = pt { - assert_eq!(pt.basek(), pk.basek()); - assert_eq!(pt.n(), module.n()); - } - } - - let basek: usize = pk.basek(); - let size_pk: usize = pk.size(); - let cols: usize = self.rank() + 1; - - // Generates u according to the underlying secret distribution. - let (mut u_dft, scratch_1) = scratch.tmp_scalar_znx_dft(module, 1); - - { - let (mut u, _) = scratch_1.tmp_scalar_znx(module, 1); - match pk.dist { - SecretDistribution::NONE => panic!( - "invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through \ - Self::generate" - ), - SecretDistribution::TernaryFixed(hw) => u.fill_ternary_hw(0, hw, source_xu), - SecretDistribution::TernaryProb(prob) => u.fill_ternary_prob(0, prob, source_xu), - SecretDistribution::BinaryFixed(hw) => u.fill_binary_hw(0, hw, source_xu), - SecretDistribution::BinaryProb(prob) => u.fill_binary_prob(0, prob, source_xu), - SecretDistribution::BinaryBlock(block_size) => u.fill_binary_block(0, block_size, source_xu), - SecretDistribution::ZERO => {} - } - - module.svp_prepare(&mut u_dft, 0, &u, 0); - } - - // ct[i] = pk[i] * u + ei (+ m if col = i) - (0..cols).for_each(|i| { - let (mut ci_dft, scratch_2) = scratch_1.tmp_vec_znx_dft(module, 1, size_pk); - // ci_dft = DFT(u) * DFT(pk[i]) - module.svp_apply(&mut ci_dft, 0, &u_dft, 0, &pk.data.data, i); - - // ci_big = u * p[i] - let mut ci_big = module.vec_znx_idft_consume(ci_dft); - - // ci_big = u * pk[i] + e - ci_big.add_normal(basek, 0, pk.k(), source_xe, sigma, sigma * SIX_SIGMA); - - // ci_big = u * pk[i] + e + m (if col = i) - if let Some((pt, col)) = pt { - if col == i { - module.vec_znx_big_add_small_inplace(&mut ci_big, 0, &pt.data, 0); - } - } - - // ct[i] = norm(ci_big) - module.vec_znx_big_normalize(basek, &mut self.data, i, &ci_big, 0, scratch_2); - }); - } -} - -impl> GLWECiphertext { - pub fn clone(&self) -> GLWECiphertext> { - GLWECiphertext { - data: self.data.clone(), - basek: self.basek(), - k: self.k(), - } - } - - pub fn decrypt + AsRef<[u8]>, DataSk: AsRef<[u8]>>( - &self, - module: &Module, - pt: &mut GLWEPlaintext, - sk: &GLWESecret, - scratch: &mut Scratch, - ) { - #[cfg(debug_assertions)] - { - assert_eq!(self.rank(), sk.rank()); - assert_eq!(self.n(), module.n()); - assert_eq!(pt.n(), module.n()); - assert_eq!(sk.n(), module.n()); - } - - let cols: usize = self.rank() + 1; - - let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, self.size()); // TODO optimize size when pt << ct - c0_big.zero(); - - { - (1..cols).for_each(|i| { - // ci_dft = DFT(a[i]) * DFT(s[i]) - let (mut ci_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, self.size()); // TODO optimize size when pt << ct - module.vec_znx_dft(1, 0, &mut ci_dft, 0, &self.data, i); - module.svp_apply_inplace(&mut ci_dft, 0, &sk.data_fourier, i - 1); - let ci_big = module.vec_znx_idft_consume(ci_dft); - - // c0_big += a[i] * s[i] - module.vec_znx_big_add_inplace(&mut c0_big, 0, &ci_big, 0); - }); - } - - // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e) - module.vec_znx_big_add_small_inplace(&mut c0_big, 0, &self.data, 0); - - // pt = norm(BIG(m + e)) - module.vec_znx_big_normalize(self.basek(), &mut pt.data, 0, &mut c0_big, 0, scratch_1); - - pt.basek = self.basek(); - pt.k = pt.k().min(self.k()); - } -} - -pub trait GLWECiphertextToRef { - fn to_ref(&self) -> GLWECiphertext<&[u8]>; -} - -impl> GLWECiphertextToRef for GLWECiphertext { - fn to_ref(&self) -> GLWECiphertext<&[u8]> { - GLWECiphertext { - data: self.data.to_ref(), - basek: self.basek, - k: self.k, - } - } -} - -pub trait GLWECiphertextToMut { - fn to_mut(&mut self) -> GLWECiphertext<&mut [u8]>; -} - -impl + AsRef<[u8]>> GLWECiphertextToMut for GLWECiphertext { - fn to_mut(&mut self) -> GLWECiphertext<&mut [u8]> { - GLWECiphertext { - data: self.data.to_mut(), - basek: self.basek, - k: self.k, - } - } -} - -impl GLWEOps for GLWECiphertext -where - D: AsRef<[u8]> + AsMut<[u8]>, - GLWECiphertext: GLWECiphertextToMut + Infos + SetMetaData, -{ -} diff --git a/core/src/keys.rs b/core/src/keys.rs index 9549f15..45bdc61 100644 --- a/core/src/keys.rs +++ b/core/src/keys.rs @@ -1,11 +1,3 @@ -use backend::{ - Backend, FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScratchOwned, VecZnxDft, - ZnxInfos, ZnxZero, -}; -use sampling::source::Source; - -use crate::{GLWECiphertextFourier, Infos}; - #[derive(Clone, Copy, Debug)] pub(crate) enum SecretDistribution { TernaryFixed(usize), // Ternary with fixed Hamming weight @@ -16,221 +8,3 @@ pub(crate) enum SecretDistribution { ZERO, // Debug mod NONE, // Unitialized } - -pub struct LWESecret { - pub(crate) data: ScalarZnx, - pub(crate) dist: SecretDistribution, -} - -impl LWESecret> { - pub fn alloc(n: usize) -> Self { - Self { - data: ScalarZnx::new(n, 1), - dist: SecretDistribution::NONE, - } - } -} - -impl LWESecret { - pub fn n(&self) -> usize { - self.data.n() - } - - pub fn log_n(&self) -> usize { - self.data.log_n() - } - - pub fn rank(&self) -> usize { - self.data.cols() - } -} - -impl + AsMut<[u8]>> LWESecret { - pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) { - self.data.fill_ternary_prob(0, prob, source); - self.dist = SecretDistribution::TernaryProb(prob); - } - - pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) { - self.data.fill_ternary_hw(0, hw, source); - self.dist = SecretDistribution::TernaryFixed(hw); - } - - pub fn fill_binary_prob(&mut self, prob: f64, source: &mut Source) { - self.data.fill_binary_prob(0, prob, source); - self.dist = SecretDistribution::BinaryProb(prob); - } - - pub fn fill_binary_hw(&mut self, hw: usize, source: &mut Source) { - self.data.fill_binary_hw(0, hw, source); - self.dist = SecretDistribution::BinaryFixed(hw); - } - - pub fn fill_binary_block(&mut self, block_size: usize, source: &mut Source) { - self.data.fill_binary_block(0, block_size, source); - self.dist = SecretDistribution::BinaryBlock(block_size); - } - - pub fn fill_zero(&mut self) { - self.data.zero(); - self.dist = SecretDistribution::ZERO; - } -} - -pub struct GLWESecret { - pub(crate) data: ScalarZnx, - pub(crate) data_fourier: ScalarZnxDft, - pub(crate) dist: SecretDistribution, -} - -impl GLWESecret, B> { - pub fn alloc(module: &Module, rank: usize) -> Self { - Self { - data: module.new_scalar_znx(rank), - data_fourier: module.new_scalar_znx_dft(rank), - dist: SecretDistribution::NONE, - } - } - - pub fn bytes_of(module: &Module, rank: usize) -> usize { - module.bytes_of_scalar_znx(rank) + module.bytes_of_scalar_znx_dft(rank) - } -} - -impl GLWESecret { - pub fn n(&self) -> usize { - self.data.n() - } - - pub fn log_n(&self) -> usize { - self.data.log_n() - } - - pub fn rank(&self) -> usize { - self.data.cols() - } -} - -impl + AsRef<[u8]>> GLWESecret { - pub fn fill_ternary_prob(&mut self, module: &Module, prob: f64, source: &mut Source) { - (0..self.rank()).for_each(|i| { - self.data.fill_ternary_prob(i, prob, source); - }); - self.prep_fourier(module); - self.dist = SecretDistribution::TernaryProb(prob); - } - - pub fn fill_ternary_hw(&mut self, module: &Module, hw: usize, source: &mut Source) { - (0..self.rank()).for_each(|i| { - self.data.fill_ternary_hw(i, hw, source); - }); - self.prep_fourier(module); - self.dist = SecretDistribution::TernaryFixed(hw); - } - - pub fn fill_binary_prob(&mut self, module: &Module, prob: f64, source: &mut Source) { - (0..self.rank()).for_each(|i| { - self.data.fill_binary_prob(i, prob, source); - }); - self.prep_fourier(module); - self.dist = SecretDistribution::BinaryProb(prob); - } - - pub fn fill_binary_hw(&mut self, module: &Module, hw: usize, source: &mut Source) { - (0..self.rank()).for_each(|i| { - self.data.fill_binary_hw(i, hw, source); - }); - self.prep_fourier(module); - self.dist = SecretDistribution::BinaryFixed(hw); - } - - pub fn fill_binary_block(&mut self, module: &Module, block_size: usize, source: &mut Source) { - (0..self.rank()).for_each(|i| { - self.data.fill_binary_block(i, block_size, source); - }); - self.prep_fourier(module); - self.dist = SecretDistribution::BinaryBlock(block_size); - } - - pub fn fill_zero(&mut self) { - self.data.zero(); - self.dist = SecretDistribution::ZERO; - } - - pub(crate) fn prep_fourier(&mut self, module: &Module) { - (0..self.rank()).for_each(|i| { - module.svp_prepare(&mut self.data_fourier, i, &self.data, i); - }); - } -} - -pub struct GLWEPublicKey { - pub(crate) data: GLWECiphertextFourier, - pub(crate) dist: SecretDistribution, -} - -impl GLWEPublicKey, B> { - pub fn alloc(module: &Module, basek: usize, k: usize, rank: usize) -> Self { - Self { - data: GLWECiphertextFourier::alloc(module, basek, k, rank), - dist: SecretDistribution::NONE, - } - } - - pub fn bytes_of(module: &Module, basek: usize, k: usize, rank: usize) -> usize { - GLWECiphertextFourier::, B>::bytes_of(module, basek, k, rank) - } -} - -impl Infos for GLWEPublicKey { - type Inner = VecZnxDft; - - fn inner(&self) -> &Self::Inner { - &self.data.data - } - - fn basek(&self) -> usize { - self.data.basek - } - - fn k(&self) -> usize { - self.data.k - } -} - -impl GLWEPublicKey { - pub fn rank(&self) -> usize { - self.cols() - 1 - } -} - -impl + AsMut<[u8]>> GLWEPublicKey { - pub fn generate_from_sk>( - &mut self, - module: &Module, - sk: &GLWESecret, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - ) { - #[cfg(debug_assertions)] - { - match sk.dist { - SecretDistribution::NONE => panic!("invalid sk: SecretDistribution::NONE"), - _ => {} - } - } - - // Its ok to allocate scratch space here since pk is usually generated only once. - let mut scratch: ScratchOwned = ScratchOwned::new(GLWECiphertextFourier::encrypt_sk_scratch_space( - module, - self.basek(), - self.k(), - self.rank(), - )); - - self.data - .encrypt_zero_sk(module, sk, source_xa, source_xe, sigma, scratch.borrow()); - self.dist = sk.dist; - } -} diff --git a/core/src/keyswitch_key.rs b/core/src/keyswitch_key.rs index fd4da76..201ae22 100644 --- a/core/src/keyswitch_key.rs +++ b/core/src/keyswitch_key.rs @@ -1,7 +1,7 @@ use backend::{Backend, FFT64, MatZnxDft, MatZnxDftOps, Module, Scratch, ZnxZero}; use sampling::source::Source; -use crate::{GGLWECiphertext, GGSWCiphertext, GLWECiphertextFourier, GLWESecret, GetRow, Infos, ScratchCore, SetRow}; +use crate::{FourierGLWECiphertext, GGLWECiphertext, GGSWCiphertext, GLWESecret, GetRow, Infos, ScratchCore, SetRow}; pub struct GLWESwitchingKey(pub(crate) GGLWECiphertext); @@ -73,7 +73,7 @@ impl> GetRow for GLWESwitchingKey { module: &Module, row_i: usize, col_j: usize, - res: &mut GLWECiphertextFourier, + res: &mut FourierGLWECiphertext, ) { module.mat_znx_dft_get_row(&mut res.data, &self.0.data, row_i, col_j); } @@ -85,7 +85,7 @@ impl + AsRef<[u8]>> SetRow for GLWESwitchingKey module: &Module, row_i: usize, col_j: usize, - a: &GLWECiphertextFourier, + a: &FourierGLWECiphertext, ) { module.mat_znx_dft_set_row(&mut self.0.data, row_i, col_j, &a.data); } @@ -110,10 +110,10 @@ impl GLWESwitchingKey, FFT64> { rank_in: usize, rank_out: usize, ) -> usize { - let tmp_in: usize = GLWECiphertextFourier::bytes_of(module, basek, k_in, rank_in); - let tmp_out: usize = GLWECiphertextFourier::bytes_of(module, basek, k_out, rank_out); + let tmp_in: usize = FourierGLWECiphertext::bytes_of(module, basek, k_in, rank_in); + let tmp_out: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank_out); let ksk: usize = - GLWECiphertextFourier::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank_in, rank_out); + FourierGLWECiphertext::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank_in, rank_out); tmp_in + tmp_out + ksk } @@ -125,8 +125,8 @@ impl GLWESwitchingKey, FFT64> { digits: usize, rank: usize, ) -> usize { - let tmp: usize = GLWECiphertextFourier::bytes_of(module, basek, k_out, rank); - let ksk: usize = GLWECiphertextFourier::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank); + let tmp: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank); + let ksk: usize = FourierGLWECiphertext::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank); tmp + ksk } @@ -139,9 +139,9 @@ impl GLWESwitchingKey, FFT64> { digits: usize, rank: usize, ) -> usize { - let tmp_in: usize = GLWECiphertextFourier::bytes_of(module, basek, k_in, rank); - let tmp_out: usize = GLWECiphertextFourier::bytes_of(module, basek, k_out, rank); - let ggsw: usize = GLWECiphertextFourier::external_product_scratch_space(module, basek, k_out, k_in, k_ggsw, digits, rank); + let tmp_in: usize = FourierGLWECiphertext::bytes_of(module, basek, k_in, rank); + let tmp_out: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank); + let ggsw: usize = FourierGLWECiphertext::external_product_scratch_space(module, basek, k_out, k_in, k_ggsw, digits, rank); tmp_in + tmp_out + ggsw } @@ -153,9 +153,9 @@ impl GLWESwitchingKey, FFT64> { digits: usize, rank: usize, ) -> usize { - let tmp: usize = GLWECiphertextFourier::bytes_of(module, basek, k_out, rank); + let tmp: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank); let ggsw: usize = - GLWECiphertextFourier::external_product_inplace_scratch_space(module, basek, k_out, k_ggsw, digits, rank); + FourierGLWECiphertext::external_product_inplace_scratch_space(module, basek, k_out, k_ggsw, digits, rank); tmp + ggsw } } diff --git a/core/src/lib.rs b/core/src/lib.rs index 1bffc5a..ef181a1 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -1,39 +1,36 @@ pub mod automorphism; pub mod blind_rotation; pub mod elem; -pub mod gglwe_ciphertext; -pub mod ggsw_ciphertext; -pub mod glwe_ciphertext; -pub mod glwe_ciphertext_fourier; -pub mod glwe_ops; -pub mod glwe_packing; -pub mod glwe_plaintext; +pub mod fourier_glwe; +pub mod gglwe; +pub mod ggsw; +pub mod glwe; pub mod keys; pub mod keyswitch_key; +pub mod lwe; pub mod tensor_key; #[cfg(test)] mod test_fft64; -pub mod trace; +mod utils; pub use automorphism::*; use backend::Backend; use backend::FFT64; use backend::Module; pub use elem::*; -pub use gglwe_ciphertext::*; -pub use ggsw_ciphertext::*; -pub use glwe_ciphertext::*; -pub use glwe_ciphertext_fourier::*; -pub use glwe_ops::*; -pub use glwe_packing::*; -pub use glwe_plaintext::*; -pub use keys::*; +pub use fourier_glwe::*; +pub use gglwe::*; +pub use ggsw::*; +pub use glwe::*; pub use keyswitch_key::*; +pub use lwe::*; pub use tensor_key::*; pub use backend::Scratch; pub use backend::ScratchOwned; +use crate::keys::SecretDistribution; + pub(crate) const SIX_SIGMA: f64 = 6.0; pub trait ScratchCore { @@ -64,7 +61,7 @@ pub trait ScratchCore { basek: usize, k: usize, rank: usize, - ) -> (GLWECiphertextFourier<&mut [u8], B>, &mut Self); + ) -> (FourierGLWECiphertext<&mut [u8], B>, &mut Self); fn tmp_sk(&mut self, module: &Module, rank: usize) -> (GLWESecret<&mut [u8], B>, &mut Self); fn tmp_glwe_pk( &mut self, @@ -181,9 +178,9 @@ impl ScratchCore for Scratch { basek: usize, k: usize, rank: usize, - ) -> (GLWECiphertextFourier<&mut [u8], FFT64>, &mut Self) { - let (data, scratch) = self.tmp_vec_znx_dft(module, rank + 1, k.div_ceil(basek)); - (GLWECiphertextFourier { data, basek, k }, scratch) + ) -> (FourierGLWECiphertext<&mut [u8], FFT64>, &mut Self) { + let (data, scratch) = self.tmp_vec_znx_dft(module, rank + 1, div_ceil(k, basek)); + (FourierGLWECiphertext { data, basek, k }, scratch) } fn tmp_glwe_pk( diff --git a/core/src/lwe.rs b/core/src/lwe.rs new file mode 100644 index 0000000..3f0d749 --- /dev/null +++ b/core/src/lwe.rs @@ -0,0 +1,64 @@ +use backend::{ScalarZnx, ZnxInfos, ZnxZero}; +use sampling::source::Source; + +use crate::SecretDistribution; + +pub struct LWESecret { + pub(crate) data: ScalarZnx, + pub(crate) dist: SecretDistribution, +} + +impl LWESecret> { + pub fn alloc(n: usize) -> Self { + Self { + data: ScalarZnx::new(n, 1), + dist: SecretDistribution::NONE, + } + } +} + +impl LWESecret { + pub fn n(&self) -> usize { + self.data.n() + } + + pub fn log_n(&self) -> usize { + self.data.log_n() + } + + pub fn rank(&self) -> usize { + self.data.cols() + } +} + +impl + AsMut<[u8]>> LWESecret { + pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) { + self.data.fill_ternary_prob(0, prob, source); + self.dist = SecretDistribution::TernaryProb(prob); + } + + pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) { + self.data.fill_ternary_hw(0, hw, source); + self.dist = SecretDistribution::TernaryFixed(hw); + } + + pub fn fill_binary_prob(&mut self, prob: f64, source: &mut Source) { + self.data.fill_binary_prob(0, prob, source); + self.dist = SecretDistribution::BinaryProb(prob); + } + + pub fn fill_binary_hw(&mut self, hw: usize, source: &mut Source) { + self.data.fill_binary_hw(0, hw, source); + self.dist = SecretDistribution::BinaryFixed(hw); + } + + pub fn fill_binary_block(&mut self, block_size: usize, source: &mut Source) { + self.data.fill_binary_block(0, block_size, source); + self.dist = SecretDistribution::BinaryBlock(block_size); + } + + pub fn fill_zero(&mut self) { + self.data.zero(); + self.dist = SecretDistribution::ZERO; + } +} diff --git a/core/src/test_fft64/automorphism_key.rs b/core/src/test_fft64/automorphism_key.rs index f23b619..93cd5d6 100644 --- a/core/src/test_fft64/automorphism_key.rs +++ b/core/src/test_fft64/automorphism_key.rs @@ -2,7 +2,8 @@ use backend::{FFT64, Module, ScalarZnxOps, ScratchOwned, Stats, VecZnxOps}; use sampling::source::Source; use crate::{ - AutomorphismKey, GLWECiphertextFourier, GLWEPlaintext, GLWESecret, GetRow, Infos, test_fft64::log2_std_noise_gglwe_product, + AutomorphismKey, FourierGLWECiphertext, GLWEPlaintext, GLWESecret, GetRow, Infos, div_ceil, + test_fft64::log2_std_noise_gglwe_product, }; #[test] @@ -69,7 +70,7 @@ fn test_automorphism( let mut scratch: ScratchOwned = ScratchOwned::new( AutomorphismKey::generate_from_sk_scratch_space(&module, basek, k_apply, rank) - | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_out) + | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_out) | AutomorphismKey::automorphism_scratch_space(&module, basek, k_out, k_in, k_apply, digits, rank), ); @@ -101,7 +102,7 @@ fn test_automorphism( // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) auto_key_out.automorphism(&module, &auto_key_in, &auto_key_apply, scratch.borrow()); - let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_out, rank); + let mut ct_glwe_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_out, rank); let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); let mut sk_auto: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); @@ -183,7 +184,7 @@ fn test_automorphism_inplace( let mut scratch: ScratchOwned = ScratchOwned::new( AutomorphismKey::generate_from_sk_scratch_space(&module, basek, k_apply, rank) - | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_in) + | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_in) | AutomorphismKey::automorphism_inplace_scratch_space(&module, basek, k_in, k_apply, digits, rank), ); @@ -215,7 +216,7 @@ fn test_automorphism_inplace( // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) auto_key.automorphism_inplace(&module, &auto_key_apply, scratch.borrow()); - let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_in, rank); + let mut ct_glwe_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_in, rank); let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_in); let mut sk_auto: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); diff --git a/core/src/test_fft64/gglwe.rs b/core/src/test_fft64/gglwe.rs index 91798d2..bde6308 100644 --- a/core/src/test_fft64/gglwe.rs +++ b/core/src/test_fft64/gglwe.rs @@ -2,7 +2,7 @@ use backend::{FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxToMut, ScratchO use sampling::source::Source; use crate::{ - GGSWCiphertext, GLWECiphertextFourier, GLWEPlaintext, GLWESecret, GLWESwitchingKey, GetRow, Infos, + FourierGLWECiphertext, GGSWCiphertext, GLWEPlaintext, GLWESecret, GLWESwitchingKey, GetRow, Infos, div_ceil, test_fft64::{log2_std_noise_gglwe_product, noise_ggsw_product}, }; @@ -145,7 +145,7 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k_ksk: usize, digits: usize, rank let mut scratch: ScratchOwned = ScratchOwned::new( GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank_out) - | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_ksk), + | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_ksk), ); let mut sk_in: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_in); @@ -164,8 +164,8 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k_ksk: usize, digits: usize, rank scratch.borrow(), ); - let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::alloc(&module, basek, k_ksk, rank_out); + let mut ct_glwe_fourier: FourierGLWECiphertext, FFT64> = + FourierGLWECiphertext::alloc(&module, basek, k_ksk, rank_out); (0..ksk.rank_in()).for_each(|col_i| { (0..ksk.rows()).for_each(|row_i| { @@ -234,7 +234,7 @@ fn test_key_switch( let mut scratch: ScratchOwned = ScratchOwned::new( GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank_in_s0s1 | rank_out_s0s1) - | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_out) + | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_out) | GLWESwitchingKey::keyswitch_scratch_space( &module, basek, @@ -281,8 +281,8 @@ fn test_key_switch( // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) ct_gglwe_s0s2.keyswitch(&module, &ct_gglwe_s0s1, &ct_gglwe_s1s2, scratch.borrow()); - let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::alloc(&module, basek, k_out, rank_out_s1s2); + let mut ct_glwe_dft: FourierGLWECiphertext, FFT64> = + FourierGLWECiphertext::alloc(&module, basek, k_out, rank_out_s1s2); let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); (0..ct_gglwe_s0s2.rank_in()).for_each(|col_i| { @@ -346,7 +346,7 @@ fn test_key_switch_inplace( let mut scratch: ScratchOwned = ScratchOwned::new( GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank_out) - | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_ksk) + | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_ksk) | GLWESwitchingKey::keyswitch_inplace_scratch_space(&module, basek, k_ct, k_ksk, digits, rank_out), ); @@ -386,7 +386,7 @@ fn test_key_switch_inplace( let ct_gglwe_s0s2: GLWESwitchingKey, FFT64> = ct_gglwe_s0s1; - let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_ct, rank_out); + let mut ct_glwe_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_ct, rank_out); let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); (0..ct_gglwe_s0s2.rank_in()).for_each(|col_i| { @@ -455,7 +455,7 @@ fn test_external_product( let mut scratch: ScratchOwned = ScratchOwned::new( GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_in, rank_out) - | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_out) + | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_out) | GLWESwitchingKey::external_product_scratch_space(&module, basek, k_out, k_in, k_ggsw, digits, rank_out) | GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_ggsw, rank_out), ); @@ -494,7 +494,7 @@ fn test_external_product( // gglwe_(m) (x) RGSW_(X^k) = gglwe_(m * X^k) ct_gglwe_out.external_product(&module, &ct_gglwe_in, &ct_rgsw, scratch.borrow()); - let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_out, rank_out); + let mut ct_glwe_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_out, rank_out); let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); (0..rank_in).for_each(|i| { @@ -575,7 +575,7 @@ fn test_external_product_inplace( let mut scratch: ScratchOwned = ScratchOwned::new( GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ct, rank_out) - | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_ct) + | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_ct) | GLWESwitchingKey::external_product_inplace_scratch_space(&module, basek, k_ct, k_ggsw, digits, rank_out) | GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_ggsw, rank_out), ); @@ -614,7 +614,7 @@ fn test_external_product_inplace( // gglwe_(m) (x) RGSW_(X^k) = gglwe_(m * X^k) ct_gglwe.external_product_inplace(&module, &ct_rgsw, scratch.borrow()); - let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_ct, rank_out); + let mut ct_glwe_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_ct, rank_out); let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); (0..rank_in).for_each(|i| { diff --git a/core/src/test_fft64/ggsw.rs b/core/src/test_fft64/ggsw.rs index 3e6ee8b..bcfb950 100644 --- a/core/src/test_fft64/ggsw.rs +++ b/core/src/test_fft64/ggsw.rs @@ -5,7 +5,7 @@ use backend::{ use sampling::source::Source; use crate::{ - GGSWCiphertext, GLWECiphertextFourier, GLWEPlaintext, GLWESecret, GLWESwitchingKey, GetRow, Infos, TensorKey, + FourierGLWECiphertext, GGSWCiphertext, GLWEPlaintext, GLWESecret, GLWESwitchingKey, GetRow, Infos, TensorKey, automorphism::AutomorphismKey, test_fft64::{noise_ggsw_keyswitch, noise_ggsw_product}, }; @@ -139,7 +139,7 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k: usize, digits: usize, rank: us let mut scratch: ScratchOwned = ScratchOwned::new( GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k, rank) - | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k), + | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k), ); let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); @@ -155,7 +155,7 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k: usize, digits: usize, rank: us scratch.borrow(), ); - let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k, rank); + let mut ct_glwe_fourier: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k, rank); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct.size()); let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct.size()); @@ -221,7 +221,7 @@ fn test_keyswitch( let mut scratch: ScratchOwned = ScratchOwned::new( GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_in, rank) - | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_out) + | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_out) | GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank) | TensorKey::generate_from_sk_scratch_space(&module, basek, k_tsk, rank) | GGSWCiphertext::keyswitch_scratch_space( @@ -269,7 +269,7 @@ fn test_keyswitch( ct_out.keyswitch(&module, &ct_in, &ksk, &tsk, scratch.borrow()); - let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_out, rank); + let mut ct_glwe_fourier: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_out, rank); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_out.size()); let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_out.size()); @@ -348,7 +348,7 @@ fn test_keyswitch_inplace( let mut scratch: ScratchOwned = ScratchOwned::new( GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_ct, rank) - | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_ct) + | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_ct) | GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank) | TensorKey::generate_from_sk_scratch_space(&module, basek, k_tsk, rank) | GGSWCiphertext::keyswitch_inplace_scratch_space(&module, basek, k_ct, k_ksk, digits, k_tsk, digits, rank), @@ -394,7 +394,7 @@ fn test_keyswitch_inplace( ct.keyswitch_inplace(&module, &ksk, &tsk, scratch.borrow()); - let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_ct, rank); + let mut ct_glwe_fourier: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_ct, rank); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct.size()); let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct.size()); @@ -483,7 +483,7 @@ fn test_automorphism( let mut scratch: ScratchOwned = ScratchOwned::new( GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_in, rank) - | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_out) + | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_out) | AutomorphismKey::generate_from_sk_scratch_space(&module, basek, k_ksk, rank) | TensorKey::generate_from_sk_scratch_space(&module, basek, k_tsk, rank) | GGSWCiphertext::automorphism_scratch_space( @@ -530,7 +530,7 @@ fn test_automorphism( module.scalar_znx_automorphism_inplace(p, &mut pt_scalar, 0); - let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_out, rank); + let mut ct_glwe_fourier: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_out, rank); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_out.size()); let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_out.size()); @@ -608,7 +608,7 @@ fn test_automorphism_inplace( let mut scratch: ScratchOwned = ScratchOwned::new( GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_ct, rank) - | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_ct) + | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_ct) | AutomorphismKey::generate_from_sk_scratch_space(&module, basek, k_ksk, rank) | TensorKey::generate_from_sk_scratch_space(&module, basek, k_tsk, rank) | GGSWCiphertext::automorphism_inplace_scratch_space(&module, basek, k_ct, k_ksk, digits, k_tsk, digits, rank), @@ -653,7 +653,7 @@ fn test_automorphism_inplace( module.scalar_znx_automorphism_inplace(p, &mut pt_scalar, 0); - let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_ct, rank); + let mut ct_glwe_fourier: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_ct, rank); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct.size()); let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct.size()); @@ -737,7 +737,7 @@ fn test_external_product( pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} let mut scratch: ScratchOwned = ScratchOwned::new( - GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_out) + FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_out) | GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_ggsw, rank) | GGSWCiphertext::external_product_scratch_space(&module, basek, k_out, k_in, k_ggsw, digits, rank), ); @@ -767,7 +767,7 @@ fn test_external_product( ct_ggsw_lhs_out.external_product(&module, &ct_ggsw_lhs_in, &ct_ggsw_rhs, scratch.borrow()); - let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_out, rank); + let mut ct_glwe_fourier: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_out, rank); let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_ggsw_lhs_out.size()); let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_ggsw_lhs_out.size()); @@ -857,7 +857,7 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ct: usize, k_ggsw pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} let mut scratch: ScratchOwned = ScratchOwned::new( - GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_ct) + FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_ct) | GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_ggsw, rank) | GGSWCiphertext::external_product_inplace_scratch_space(&module, basek, k_ct, k_ggsw, digits, rank), ); @@ -887,7 +887,7 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ct: usize, k_ggsw ct_ggsw_lhs.external_product_inplace(&module, &ct_ggsw_rhs, scratch.borrow()); - let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_ct, rank); + let mut ct_glwe_fourier: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_ct, rank); let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_ggsw_lhs.size()); let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_ggsw_lhs.size()); diff --git a/core/src/test_fft64/glwe.rs b/core/src/test_fft64/glwe.rs index 4859ad1..7a32343 100644 --- a/core/src/test_fft64/glwe.rs +++ b/core/src/test_fft64/glwe.rs @@ -6,7 +6,7 @@ use itertools::izip; use sampling::source::Source; use crate::{ - GGSWCiphertext, GLWECiphertext, GLWECiphertextFourier, GLWEPlaintext, GLWEPublicKey, GLWESecret, Infos, + FourierGLWECiphertext, GGSWCiphertext, GLWECiphertext, GLWEPlaintext, GLWEPublicKey, GLWESecret, Infos, automorphism::AutomorphismKey, keyswitch_key::GLWESwitchingKey, test_fft64::{log2_std_noise_gglwe_product, noise_ggsw_product}, @@ -207,11 +207,11 @@ fn test_encrypt_zero_sk(log_n: usize, basek: usize, k_ct: usize, sigma: f64, ran let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); sk.fill_ternary_prob(&module, 0.5, &mut source_xs); - let mut ct_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_ct, rank); + let mut ct_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_ct, rank); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_ct) - | GLWECiphertextFourier::encrypt_sk_scratch_space(&module, basek, k_ct, rank), + FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_ct) + | FourierGLWECiphertext::encrypt_sk_scratch_space(&module, basek, k_ct, rank), ); ct_dft.encrypt_zero_sk( diff --git a/core/src/test_fft64/glwe_fourier.rs b/core/src/test_fft64/glwe_fourier.rs index 48a0f0d..fd54f57 100644 --- a/core/src/test_fft64/glwe_fourier.rs +++ b/core/src/test_fft64/glwe_fourier.rs @@ -1,5 +1,5 @@ use crate::{ - GGSWCiphertext, GLWECiphertext, GLWECiphertextFourier, GLWEOps, GLWEPlaintext, GLWESecret, GLWESwitchingKey, Infos, + FourierGLWECiphertext, GGSWCiphertext, GLWECiphertext, GLWEOps, GLWEPlaintext, GLWESecret, GLWESwitchingKey, Infos, div_ceil, test_fft64::{log2_std_noise_gglwe_product, noise_ggsw_product}, }; use backend::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, ZnxViewMut}; @@ -90,10 +90,10 @@ fn test_keyswitch( let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank_in, rank_out); let mut ct_glwe_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_in, rank_in); - let mut ct_glwe_dft_in: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_in, rank_in); + let mut ct_glwe_dft_in: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_in, rank_in); let mut ct_glwe_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_out, rank_out); - let mut ct_glwe_dft_out: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::alloc(&module, basek, k_out, rank_out); + let mut ct_glwe_dft_out: FourierGLWECiphertext, FFT64> = + FourierGLWECiphertext::alloc(&module, basek, k_out, rank_out); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_in); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); @@ -110,7 +110,7 @@ fn test_keyswitch( GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank_out) | GLWECiphertext::decrypt_scratch_space(&module, basek, k_out) | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, k_in) - | GLWECiphertextFourier::keyswitch_scratch_space( + | FourierGLWECiphertext::keyswitch_scratch_space( &module, basek, ct_glwe_out.k(), @@ -185,7 +185,7 @@ fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ct: usize, k_ksk: usize, let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank, rank); let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); - let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_ct, rank); + let mut ct_rlwe_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_ct, rank); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); @@ -202,7 +202,7 @@ fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ct: usize, k_ksk: usize, GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ksk.k(), rank) | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_glwe.k()) | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_glwe.k()) - | GLWECiphertextFourier::keyswitch_inplace_scratch_space(&module, basek, ct_rlwe_dft.k(), ksk.k(), digits, rank), + | FourierGLWECiphertext::keyswitch_inplace_scratch_space(&module, basek, ct_rlwe_dft.k(), ksk.k(), digits, rank), ); let mut sk_in: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); @@ -278,8 +278,8 @@ fn test_external_product( let mut ct_ggsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank); let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_in, rank); let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_out, rank); - let mut ct_in_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_in, rank); - let mut ct_out_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_out, rank); + let mut ct_in_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_in, rank); + let mut ct_out_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_out, rank); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_in); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); @@ -304,7 +304,7 @@ fn test_external_product( GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, ct_ggsw.k(), rank) | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_out.k()) | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_in.k()) - | GLWECiphertextFourier::external_product_scratch_space( + | FourierGLWECiphertext::external_product_scratch_space( &module, basek, ct_out.k(), @@ -384,7 +384,7 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ct: usize, k_ggsw let mut ct_ggsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank); let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); - let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_ct, rank); + let mut ct_rlwe_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_ct, rank); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); @@ -409,7 +409,7 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ct: usize, k_ggsw GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, ct_ggsw.k(), rank) | GLWECiphertext::decrypt_scratch_space(&module, basek, ct.k()) | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k()) - | GLWECiphertextFourier::external_product_inplace_scratch_space(&module, basek, ct.k(), ct_ggsw.k(), digits, rank), + | FourierGLWECiphertext::external_product_inplace_scratch_space(&module, basek, ct.k(), ct_ggsw.k(), digits, rank), ); let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); diff --git a/core/src/test_fft64/tensor_key.rs b/core/src/test_fft64/tensor_key.rs index 579fec4..e9827cb 100644 --- a/core/src/test_fft64/tensor_key.rs +++ b/core/src/test_fft64/tensor_key.rs @@ -1,7 +1,7 @@ use backend::{FFT64, Module, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxOps}; use sampling::source::Source; -use crate::{GLWECiphertextFourier, GLWEPlaintext, GLWESecret, GetRow, Infos, TensorKey}; +use crate::{FourierGLWECiphertext, GLWEPlaintext, GLWESecret, GetRow, Infos, TensorKey}; #[test] fn encrypt_sk() { @@ -42,7 +42,7 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k: usize, sigma: f64, rank: usize scratch.borrow(), ); - let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k, rank); + let mut ct_glwe_fourier: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k, rank); let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k); let mut sk_ij = GLWESecret::alloc(&module, 1); From 4d4b43a4e52e53550fe8f4021e482220be31640c Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Fri, 13 Jun 2025 09:20:14 +0200 Subject: [PATCH 06/23] Re-organized code for glwe testing --- core/src/glwe/mod.rs | 3 + core/src/glwe/test_fft64/automorphism.rs | 1 + core/src/glwe/test_fft64/encryption.rs | 180 ++++++++++++++++++ .../test_fft64/external_product.rs} | 0 core/src/glwe/test_fft64/keyswitch.rs | 1 + core/src/glwe/test_fft64/mod.rs | 6 + .../test_fft64/packing.rs} | 0 core/src/{ => glwe}/test_fft64/trace.rs | 0 core/src/test_fft64/mod.rs | 3 - 9 files changed, 191 insertions(+), 3 deletions(-) create mode 100644 core/src/glwe/test_fft64/automorphism.rs create mode 100644 core/src/glwe/test_fft64/encryption.rs rename core/src/{test_fft64/glwe.rs => glwe/test_fft64/external_product.rs} (100%) create mode 100644 core/src/glwe/test_fft64/keyswitch.rs create mode 100644 core/src/glwe/test_fft64/mod.rs rename core/src/{test_fft64/glwe_packing.rs => glwe/test_fft64/packing.rs} (100%) rename core/src/{ => glwe}/test_fft64/trace.rs (100%) diff --git a/core/src/glwe/mod.rs b/core/src/glwe/mod.rs index ed3395d..7f0c7aa 100644 --- a/core/src/glwe/mod.rs +++ b/core/src/glwe/mod.rs @@ -29,3 +29,6 @@ pub use public_key::*; pub use secret::*; #[allow(unused_imports)] pub use trace::*; + +#[cfg(test)] +mod test_fft64; diff --git a/core/src/glwe/test_fft64/automorphism.rs b/core/src/glwe/test_fft64/automorphism.rs new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/core/src/glwe/test_fft64/automorphism.rs @@ -0,0 +1 @@ + diff --git a/core/src/glwe/test_fft64/encryption.rs b/core/src/glwe/test_fft64/encryption.rs new file mode 100644 index 0000000..06490c5 --- /dev/null +++ b/core/src/glwe/test_fft64/encryption.rs @@ -0,0 +1,180 @@ +use backend::{Decoding, Encoding, FFT64, Module, ScratchOwned, Stats, VecZnxOps, ZnxZero}; +use itertools::izip; +use sampling::source::Source; + +use crate::{FourierGLWECiphertext, GLWECiphertext, GLWEPlaintext, GLWEPublicKey, GLWESecret, Infos}; + +#[test] +fn encrypt_sk() { + let log_n: usize = 8; + (1..4).for_each(|rank| { + println!("test encrypt_sk rank: {}", rank); + test_encrypt_sk(log_n, 8, 54, 30, 3.2, rank); + }); +} + +#[test] +fn encrypt_zero_sk() { + let log_n: usize = 8; + (1..4).for_each(|rank| { + println!("test encrypt_zero_sk rank: {}", rank); + test_encrypt_zero_sk(log_n, 8, 64, 3.2, rank); + }); +} + +#[test] +fn encrypt_pk() { + let log_n: usize = 8; + (1..4).for_each(|rank| { + println!("test encrypt_pk rank: {}", rank); + test_encrypt_pk(log_n, 8, 64, 64, 3.2, rank) + }); +} + +fn test_encrypt_sk(log_n: usize, basek: usize, k_ct: usize, k_pt: usize, sigma: f64, rank: usize) { + let module: Module = Module::::new(1 << log_n); + + let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_pt); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k()) + | GLWECiphertext::decrypt_scratch_space(&module, basek, ct.k()), + ); + + let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(&module, 0.5, &mut source_xs); + + let mut data_want: Vec = vec![0i64; module.n()]; + + data_want + .iter_mut() + .for_each(|x| *x = source_xa.next_i64() & 0xFF); + + pt.data.encode_vec_i64(0, basek, k_pt, &data_want, 10); + + ct.encrypt_sk( + &module, + &pt, + &sk, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + pt.data.zero(); + + ct.decrypt(&module, &mut pt, &sk, scratch.borrow()); + + let mut data_have: Vec = vec![0i64; module.n()]; + + pt.data + .decode_vec_i64(0, basek, pt.size() * basek, &mut data_have); + + // TODO: properly assert the decryption noise through std(dec(ct) - pt) + let scale: f64 = (1 << (pt.size() * basek - k_pt)) as f64; + izip!(data_want.iter(), data_have.iter()).for_each(|(a, b)| { + let b_scaled = (*b as f64) / scale; + assert!( + (*a as f64 - b_scaled).abs() < 0.1, + "{} {}", + *a as f64, + b_scaled + ) + }); +} + +fn test_encrypt_zero_sk(log_n: usize, basek: usize, k_ct: usize, sigma: f64, rank: usize) { + let module: Module = Module::::new(1 << log_n); + + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([1u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(&module, 0.5, &mut source_xs); + + let mut ct_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_ct, rank); + + let mut scratch: ScratchOwned = ScratchOwned::new( + FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_ct) + | FourierGLWECiphertext::encrypt_sk_scratch_space(&module, basek, k_ct, rank), + ); + + ct_dft.encrypt_zero_sk( + &module, + &sk, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + ct_dft.decrypt(&module, &mut pt, &sk, scratch.borrow()); + + assert!((sigma - pt.data.std(0, basek) * (k_ct as f64).exp2()) <= 0.2); +} + +fn test_encrypt_pk(log_n: usize, basek: usize, k_ct: usize, k_pk: usize, sigma: f64, rank: usize) { + let module: Module = Module::::new(1 << log_n); + + let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + let mut source_xu: Source = Source::new([0u8; 32]); + + let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(&module, 0.5, &mut source_xs); + + let mut pk: GLWEPublicKey, FFT64> = GLWEPublicKey::alloc(&module, basek, k_pk, rank); + pk.generate_from_sk(&module, &sk, &mut source_xa, &mut source_xe, sigma); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k()) + | GLWECiphertext::decrypt_scratch_space(&module, basek, ct.k()) + | GLWECiphertext::encrypt_pk_scratch_space(&module, basek, pk.k()), + ); + + let mut data_want: Vec = vec![0i64; module.n()]; + + data_want + .iter_mut() + .for_each(|x| *x = source_xa.next_i64() & 0); + + pt_want.data.encode_vec_i64(0, basek, k_ct, &data_want, 10); + + ct.encrypt_pk( + &module, + &pt_want, + &pk, + &mut source_xu, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); + + ct.decrypt(&module, &mut pt_have, &sk, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_want.data, 0, &pt_have.data, 0); + + let noise_have: f64 = pt_want.data.std(0, basek).log2(); + let noise_want: f64 = ((((rank as f64) + 1.0) * module.n() as f64 * 0.5 * sigma * sigma).sqrt()).log2() - (k_ct as f64); + + assert!( + (noise_have - noise_want).abs() < 0.2, + "{} {}", + noise_have, + noise_want + ); +} diff --git a/core/src/test_fft64/glwe.rs b/core/src/glwe/test_fft64/external_product.rs similarity index 100% rename from core/src/test_fft64/glwe.rs rename to core/src/glwe/test_fft64/external_product.rs diff --git a/core/src/glwe/test_fft64/keyswitch.rs b/core/src/glwe/test_fft64/keyswitch.rs new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/core/src/glwe/test_fft64/keyswitch.rs @@ -0,0 +1 @@ + diff --git a/core/src/glwe/test_fft64/mod.rs b/core/src/glwe/test_fft64/mod.rs new file mode 100644 index 0000000..e3f5425 --- /dev/null +++ b/core/src/glwe/test_fft64/mod.rs @@ -0,0 +1,6 @@ +pub mod automorphism; +pub mod encryption; +pub mod external_product; +pub mod keyswitch; +pub mod packing; +pub mod trace; diff --git a/core/src/test_fft64/glwe_packing.rs b/core/src/glwe/test_fft64/packing.rs similarity index 100% rename from core/src/test_fft64/glwe_packing.rs rename to core/src/glwe/test_fft64/packing.rs diff --git a/core/src/test_fft64/trace.rs b/core/src/glwe/test_fft64/trace.rs similarity index 100% rename from core/src/test_fft64/trace.rs rename to core/src/glwe/test_fft64/trace.rs diff --git a/core/src/test_fft64/mod.rs b/core/src/test_fft64/mod.rs index 73a58e9..4c0f513 100644 --- a/core/src/test_fft64/mod.rs +++ b/core/src/test_fft64/mod.rs @@ -1,11 +1,8 @@ mod automorphism_key; mod gglwe; mod ggsw; -mod glwe; mod glwe_fourier; -mod glwe_packing; mod tensor_key; -mod trace; pub(crate) fn var_noise_gglwe_product( n: f64, From e8cfb5e2ab647b8d4a64078444eac64d5e21af0b Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Fri, 13 Jun 2025 18:57:28 +0200 Subject: [PATCH 07/23] Reorganized other crates --- backend/src/scalar_znx_dft_ops.rs | 4 +- core/benches/external_product_glwe_fft64.rs | 20 +- core/benches/keyswitch_glwe_fft64.rs | 30 +- core/src/blind_rotation/key.rs | 12 +- core/src/{keys.rs => dist.rs} | 2 +- core/src/fourier_glwe/ciphertext.rs | 45 ++ core/src/fourier_glwe/decryption.rs | 84 +++ core/src/fourier_glwe/encryption.rs | 32 + core/src/fourier_glwe/external_product.rs | 129 ++++ core/src/fourier_glwe/keyswitch.rs | 56 ++ core/src/fourier_glwe/mod.rs | 12 + core/src/fourier_glwe/secret.rs | 58 ++ .../test_fft64/external_product.rs | 246 +++++++ core/src/fourier_glwe/test_fft64/keyswitch.rs | 235 +++++++ core/src/fourier_glwe/test_fft64/mod.rs | 2 + core/src/gglwe/automorphism.rs | 136 ++++ core/src/gglwe/automorphism_key.rs | 83 +++ core/src/gglwe/ciphertext.rs | 131 ++++ core/src/gglwe/encryption.rs | 253 +++++++ core/src/gglwe/external_product.rs | 162 +++++ core/src/gglwe/keyswitch.rs | 163 +++++ core/src/gglwe/keyswitch_key.rs | 91 +++ core/src/gglwe/mod.rs | 16 + core/src/{ => gglwe}/tensor_key.rs | 60 +- .../test_fft64/automorphism_key.rs | 51 +- core/src/{ => gglwe}/test_fft64/gglwe.rs | 88 +-- core/src/gglwe/test_fft64/mod.rs | 3 + core/src/{ => gglwe}/test_fft64/tensor_key.rs | 29 +- core/src/{ggsw.rs => ggsw/ciphertext.rs} | 20 +- core/src/ggsw/mod.rs | 6 + core/src/{ => ggsw}/test_fft64/ggsw.rs | 133 ++-- core/src/ggsw/test_fft64/mod.rs | 1 + core/src/glwe/automorphism.rs | 18 +- core/src/glwe/decryption.rs | 6 +- core/src/glwe/encryption.rs | 24 +- core/src/glwe/mod.rs | 25 +- core/src/glwe/packing.rs | 14 +- core/src/glwe/plaintext.rs | 6 +- core/src/glwe/public_key.rs | 10 +- core/src/glwe/secret.rs | 63 +- core/src/glwe/test_fft64/automorphism.rs | 223 ++++++ core/src/glwe/test_fft64/encryption.rs | 29 +- core/src/glwe/test_fft64/external_product.rs | 644 +----------------- core/src/glwe/test_fft64/keyswitch.rs | 226 ++++++ core/src/glwe/test_fft64/packing.rs | 27 +- core/src/glwe/test_fft64/trace.rs | 23 +- core/src/glwe/trace.rs | 6 +- core/src/keyswitch_key.rs | 343 ---------- core/src/lib.rs | 56 +- core/src/lwe/mod.rs | 3 + core/src/{lwe.rs => lwe/secret.rs} | 18 +- core/src/{test_fft64/mod.rs => noise.rs} | 10 +- 52 files changed, 2787 insertions(+), 1380 deletions(-) rename core/src/{keys.rs => dist.rs} (92%) create mode 100644 core/src/fourier_glwe/ciphertext.rs create mode 100644 core/src/fourier_glwe/decryption.rs create mode 100644 core/src/fourier_glwe/encryption.rs create mode 100644 core/src/fourier_glwe/external_product.rs create mode 100644 core/src/fourier_glwe/keyswitch.rs create mode 100644 core/src/fourier_glwe/mod.rs create mode 100644 core/src/fourier_glwe/secret.rs create mode 100644 core/src/fourier_glwe/test_fft64/external_product.rs create mode 100644 core/src/fourier_glwe/test_fft64/keyswitch.rs create mode 100644 core/src/fourier_glwe/test_fft64/mod.rs create mode 100644 core/src/gglwe/automorphism.rs create mode 100644 core/src/gglwe/automorphism_key.rs create mode 100644 core/src/gglwe/ciphertext.rs create mode 100644 core/src/gglwe/encryption.rs create mode 100644 core/src/gglwe/external_product.rs create mode 100644 core/src/gglwe/keyswitch.rs create mode 100644 core/src/gglwe/keyswitch_key.rs create mode 100644 core/src/gglwe/mod.rs rename core/src/{ => gglwe}/tensor_key.rs (52%) rename core/src/{ => gglwe}/test_fft64/automorphism_key.rs (76%) rename core/src/{ => gglwe}/test_fft64/gglwe.rs (86%) create mode 100644 core/src/gglwe/test_fft64/mod.rs rename core/src/{ => gglwe}/test_fft64/tensor_key.rs (65%) rename core/src/{ggsw.rs => ggsw/ciphertext.rs} (97%) create mode 100644 core/src/ggsw/mod.rs rename core/src/{ => ggsw}/test_fft64/ggsw.rs (86%) create mode 100644 core/src/ggsw/test_fft64/mod.rs delete mode 100644 core/src/keyswitch_key.rs create mode 100644 core/src/lwe/mod.rs rename core/src/{lwe.rs => lwe/secret.rs} (73%) rename core/src/{test_fft64/mod.rs => noise.rs} (97%) diff --git a/backend/src/scalar_znx_dft_ops.rs b/backend/src/scalar_znx_dft_ops.rs index 6d227c7..c89808d 100644 --- a/backend/src/scalar_znx_dft_ops.rs +++ b/backend/src/scalar_znx_dft_ops.rs @@ -29,7 +29,7 @@ pub trait ScalarZnxDftOps { R: VecZnxDftToMut, A: ScalarZnxDftToRef; - fn svp_idft(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) + fn scalar_znx_idft(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) where R: ScalarZnxToMut, A: ScalarZnxDftToRef; @@ -50,7 +50,7 @@ impl ScalarZnxDftAlloc for Module { } impl ScalarZnxDftOps for Module { - fn svp_idft(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) + fn scalar_znx_idft(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) where R: ScalarZnxToMut, A: ScalarZnxDftToRef, diff --git a/core/benches/external_product_glwe_fft64.rs b/core/benches/external_product_glwe_fft64.rs index e81829c..c48c626 100644 --- a/core/benches/external_product_glwe_fft64.rs +++ b/core/benches/external_product_glwe_fft64.rs @@ -1,5 +1,5 @@ use backend::{FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned}; -use core::{GGSWCiphertext, GLWECiphertext, GLWESecret, Infos}; +use core::{FourierGLWESecret, GGSWCiphertext, GLWECiphertext, GLWESecret, Infos}; use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; use sampling::source::Source; use std::hint::black_box; @@ -52,13 +52,14 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) { let mut source_xe = Source::new([0u8; 32]); let mut source_xa = Source::new([0u8; 32]); - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); ct_ggsw.encrypt_sk( &module, &pt_rgsw, - &sk, + &sk_dft, &mut source_xa, &mut source_xe, sigma, @@ -67,7 +68,7 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) { ct_glwe_in.encrypt_zero_sk( &module, - &sk, + &sk_dft, &mut source_xa, &mut source_xe, sigma, @@ -134,13 +135,14 @@ fn bench_external_product_glwe_inplace_fft64(c: &mut Criterion) { let mut source_xe = Source::new([0u8; 32]); let mut source_xa = Source::new([0u8; 32]); - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); ct_ggsw.encrypt_sk( &module, &pt_rgsw, - &sk, + &sk_dft, &mut source_xa, &mut source_xe, sigma, @@ -149,7 +151,7 @@ fn bench_external_product_glwe_inplace_fft64(c: &mut Criterion) { ct_glwe.encrypt_zero_sk( &module, - &sk, + &sk_dft, &mut source_xa, &mut source_xe, sigma, diff --git a/core/benches/keyswitch_glwe_fft64.rs b/core/benches/keyswitch_glwe_fft64.rs index 66a1d02..9de1e9c 100644 --- a/core/benches/keyswitch_glwe_fft64.rs +++ b/core/benches/keyswitch_glwe_fft64.rs @@ -1,5 +1,5 @@ use backend::{FFT64, Module, ScratchOwned}; -use core::{AutomorphismKey, GLWECiphertext, GLWESecret, GLWESwitchingKey, Infos}; +use core::{FourierGLWESecret, GLWEAutomorphismKey, GLWECiphertext, GLWESecret, GLWESwitchingKey, Infos}; use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; use sampling::source::Source; use std::{hint::black_box, time::Duration}; @@ -32,7 +32,8 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { let rows: usize = (p.k_ct_in + (p.basek * digits) - 1) / (p.basek * digits); let sigma: f64 = 3.2; - let mut ksk: AutomorphismKey, FFT64> = AutomorphismKey::alloc(&module, basek, k_grlwe, rows, digits, rank_out); + let mut ksk: GLWEAutomorphismKey, FFT64> = + GLWEAutomorphismKey::alloc(&module, basek, k_grlwe, rows, digits, rank_out); let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_rlwe_in, rank_in); let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_rlwe_out, rank_out); @@ -55,11 +56,12 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { let mut source_xe = Source::new([0u8; 32]); let mut source_xa = Source::new([0u8; 32]); - let mut sk_in: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_in); - sk_in.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk_in: GLWESecret> = GLWESecret::alloc(&module, rank_in); + sk_in.fill_ternary_prob(0.5, &mut source_xs); + let sk_in_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_in); - let mut sk_out: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_out); - sk_out.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk_out: GLWESecret> = GLWESecret::alloc(&module, rank_out); + sk_out.fill_ternary_prob(0.5, &mut source_xs); ksk.generate_from_sk( &module, @@ -73,7 +75,7 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { ct_in.encrypt_zero_sk( &module, - &sk_in, + &sk_in_dft, &mut source_xa, &mut source_xe, sigma, @@ -146,16 +148,18 @@ fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) { let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut sk_in: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk_in.fill_ternary_prob(&&module, 0.5, &mut source_xs); + let mut sk_in: GLWESecret> = GLWESecret::alloc(&module, rank); + sk_in.fill_ternary_prob(0.5, &mut source_xs); + let sk_in_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_in); - let mut sk_out: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk_out.fill_ternary_prob(&&module, 0.5, &mut source_xs); + let mut sk_out: GLWESecret> = GLWESecret::alloc(&module, rank); + sk_out.fill_ternary_prob(0.5, &mut source_xs); + let sk_out_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_out); ksk.generate_from_sk( &module, &sk_in, - &sk_out, + &sk_out_dft, &mut source_xa, &mut source_xe, sigma, @@ -164,7 +168,7 @@ fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) { ct.encrypt_zero_sk( &module, - &sk_in, + &sk_in_dft, &mut source_xa, &mut source_xe, sigma, diff --git a/core/src/blind_rotation/key.rs b/core/src/blind_rotation/key.rs index ff4a887..270080d 100644 --- a/core/src/blind_rotation/key.rs +++ b/core/src/blind_rotation/key.rs @@ -1,16 +1,16 @@ use backend::{Backend, FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxToRef, Scratch, ZnxView, ZnxViewMut}; use sampling::source::Source; -use crate::{AutomorphismKey, GGSWCiphertext, GLWESecret, LWESecret, SecretDistribution}; +use crate::{Distribution, FourierGLWESecret, GGSWCiphertext, GLWEAutomorphismKey, LWESecret}; pub struct BlindRotationKeyCGGI { pub(crate) data: Vec, B>>, - pub(crate) dist: SecretDistribution, + pub(crate) dist: Distribution, } pub struct BlindRotationKeyFHEW { pub(crate) data: Vec, B>>, - pub(crate) auto: Vec, B>>, + pub(crate) auto: Vec, B>>, } impl BlindRotationKeyCGGI { @@ -19,7 +19,7 @@ impl BlindRotationKeyCGGI { (0..lwe_degree).for_each(|_| data.push(GGSWCiphertext::alloc(module, basek, k, rows, 1, rank))); Self { data, - dist: SecretDistribution::NONE, + dist: Distribution::NONE, } } @@ -30,7 +30,7 @@ impl BlindRotationKeyCGGI { pub fn generate_from_sk( &mut self, module: &Module, - sk_glwe: &GLWESecret, + sk_glwe: &FourierGLWESecret, sk_lwe: &LWESecret, source_xa: &mut Source, source_xe: &mut Source, @@ -46,7 +46,7 @@ impl BlindRotationKeyCGGI { assert_eq!(sk_glwe.n(), module.n()); assert_eq!(sk_glwe.rank(), self.data[0].rank()); match sk_lwe.dist { - SecretDistribution::BinaryBlock(_) | SecretDistribution::BinaryFixed(_) | SecretDistribution::BinaryProb(_) => {} + Distribution::BinaryBlock(_) | Distribution::BinaryFixed(_) | Distribution::BinaryProb(_) => {} _ => panic!("invalid GLWESecret distribution: must be BinaryBlock, BinaryFixed or BinaryProb"), } } diff --git a/core/src/keys.rs b/core/src/dist.rs similarity index 92% rename from core/src/keys.rs rename to core/src/dist.rs index 45bdc61..4a97369 100644 --- a/core/src/keys.rs +++ b/core/src/dist.rs @@ -1,5 +1,5 @@ #[derive(Clone, Copy, Debug)] -pub(crate) enum SecretDistribution { +pub(crate) enum Distribution { TernaryFixed(usize), // Ternary with fixed Hamming weight TernaryProb(f64), // Ternary with probabilistic Hamming weight BinaryFixed(usize), // Binary with fixed Hamming weight diff --git a/core/src/fourier_glwe/ciphertext.rs b/core/src/fourier_glwe/ciphertext.rs new file mode 100644 index 0000000..425191f --- /dev/null +++ b/core/src/fourier_glwe/ciphertext.rs @@ -0,0 +1,45 @@ +use backend::{Backend, Module, VecZnxDft, VecZnxDftAlloc}; + +use crate::{Infos, div_ceil}; + +pub struct FourierGLWECiphertext { + pub data: VecZnxDft, + pub basek: usize, + pub k: usize, +} + +impl FourierGLWECiphertext, B> { + pub fn alloc(module: &Module, basek: usize, k: usize, rank: usize) -> Self { + Self { + data: module.new_vec_znx_dft(rank + 1, div_ceil(k, basek)), + basek: basek, + k: k, + } + } + + pub fn bytes_of(module: &Module, basek: usize, k: usize, rank: usize) -> usize { + module.bytes_of_vec_znx_dft(rank + 1, div_ceil(k, basek)) + } +} + +impl Infos for FourierGLWECiphertext { + type Inner = VecZnxDft; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn basek(&self) -> usize { + self.basek + } + + fn k(&self) -> usize { + self.k + } +} + +impl FourierGLWECiphertext { + pub fn rank(&self) -> usize { + self.cols() - 1 + } +} diff --git a/core/src/fourier_glwe/decryption.rs b/core/src/fourier_glwe/decryption.rs new file mode 100644 index 0000000..882be61 --- /dev/null +++ b/core/src/fourier_glwe/decryption.rs @@ -0,0 +1,84 @@ +use backend::{ + FFT64, Module, ScalarZnxDftOps, Scratch, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDftAlloc, + VecZnxDftOps, ZnxZero, +}; + +use crate::{FourierGLWECiphertext, FourierGLWESecret, GLWECiphertext, GLWEPlaintext, Infos, div_ceil}; + +impl FourierGLWECiphertext, FFT64> { + pub fn decrypt_scratch_space(module: &Module, basek: usize, k: usize) -> usize { + let size: usize = div_ceil(k, basek); + (module.vec_znx_big_normalize_tmp_bytes() + | module.bytes_of_vec_znx_dft(1, size) + | (module.bytes_of_vec_znx_big(1, size) + module.vec_znx_idft_tmp_bytes())) + + module.bytes_of_vec_znx_big(1, size) + } +} + +impl> FourierGLWECiphertext { + pub fn decrypt + AsMut<[u8]>, DataSk: AsRef<[u8]>>( + &self, + module: &Module, + pt: &mut GLWEPlaintext, + sk: &FourierGLWESecret, + scratch: &mut Scratch, + ) { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), sk.rank()); + assert_eq!(self.n(), module.n()); + assert_eq!(pt.n(), module.n()); + assert_eq!(sk.n(), module.n()); + } + + let cols = self.rank() + 1; + + let (mut pt_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, self.size()); // TODO optimize size when pt << ct + pt_big.zero(); + + { + (1..cols).for_each(|i| { + let (mut ci_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, self.size()); // TODO optimize size when pt << ct + module.svp_apply(&mut ci_dft, 0, &sk.data, i - 1, &self.data, i); + let ci_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(ci_dft); + module.vec_znx_big_add_inplace(&mut pt_big, 0, &ci_big, 0); + }); + } + + { + let (mut c0_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, self.size()); + // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e) + module.vec_znx_idft(&mut c0_big, 0, &self.data, 0, scratch_2); + module.vec_znx_big_add_inplace(&mut pt_big, 0, &c0_big, 0); + } + + // pt = norm(BIG(m + e)) + module.vec_znx_big_normalize(self.basek(), &mut pt.data, 0, &mut pt_big, 0, scratch_1); + + pt.basek = self.basek(); + pt.k = pt.k().min(self.k()); + } + + #[allow(dead_code)] + pub(crate) fn idft + AsMut<[u8]>>( + &self, + module: &Module, + res: &mut GLWECiphertext, + scratch: &mut Scratch, + ) { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), res.rank()); + assert_eq!(self.basek(), res.basek()) + } + + let min_size: usize = self.size().min(res.size()); + + let (mut res_big, scratch1) = scratch.tmp_vec_znx_big(module, 1, min_size); + + (0..self.rank() + 1).for_each(|i| { + module.vec_znx_idft(&mut res_big, 0, &self.data, i, scratch1); + module.vec_znx_big_normalize(self.basek(), &mut res.data, i, &res_big, 0, scratch1); + }); + } +} diff --git a/core/src/fourier_glwe/encryption.rs b/core/src/fourier_glwe/encryption.rs new file mode 100644 index 0000000..d23ff4a --- /dev/null +++ b/core/src/fourier_glwe/encryption.rs @@ -0,0 +1,32 @@ +use backend::{FFT64, Module, Scratch, VecZnxAlloc, VecZnxBigScratch, VecZnxDftOps}; +use sampling::source::Source; + +use crate::{FourierGLWECiphertext, FourierGLWESecret, GLWECiphertext, Infos, ScratchCore, div_ceil}; + +impl FourierGLWECiphertext, FFT64> { + #[allow(dead_code)] + pub(crate) fn idft_scratch_space(module: &Module, basek: usize, k: usize) -> usize { + module.bytes_of_vec_znx(1, div_ceil(k, basek)) + + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes()) + } + + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { + module.bytes_of_vec_znx(rank + 1, div_ceil(k, basek)) + GLWECiphertext::encrypt_sk_scratch_space(module, basek, k) + } +} + +impl + AsRef<[u8]>> FourierGLWECiphertext { + pub fn encrypt_zero_sk>( + &mut self, + module: &Module, + sk: &FourierGLWESecret, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) { + let (mut tmp_ct, scratch1) = scratch.tmp_glwe_ct(module, self.basek(), self.k(), self.rank()); + tmp_ct.encrypt_zero_sk(module, sk, source_xa, source_xe, sigma, scratch1); + tmp_ct.dft(module, self); + } +} diff --git a/core/src/fourier_glwe/external_product.rs b/core/src/fourier_glwe/external_product.rs new file mode 100644 index 0000000..116416b --- /dev/null +++ b/core/src/fourier_glwe/external_product.rs @@ -0,0 +1,129 @@ +use backend::{ + FFT64, MatZnxDftOps, MatZnxDftScratch, Module, Scratch, VecZnxAlloc, VecZnxBig, VecZnxBigOps, VecZnxBigScratch, + VecZnxDftAlloc, VecZnxDftOps, +}; + +use crate::{FourierGLWECiphertext, GGSWCiphertext, Infos, div_ceil}; + +impl FourierGLWECiphertext, FFT64> { + // WARNING TODO: UPDATE + pub fn external_product_scratch_space( + module: &Module, + basek: usize, + _k_out: usize, + k_in: usize, + k_ggsw: usize, + digits: usize, + rank: usize, + ) -> usize { + let ggsw_size: usize = div_ceil(k_ggsw, basek); + let res_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, ggsw_size); + let in_size: usize = div_ceil(div_ceil(k_in, basek), digits); + let ggsw_size: usize = div_ceil(k_ggsw, basek); + let vmp: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size) + + module.vmp_apply_tmp_bytes(ggsw_size, in_size, in_size, rank + 1, rank + 1, ggsw_size); + let res_small: usize = module.bytes_of_vec_znx(rank + 1, ggsw_size); + let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(); + res_dft + (vmp | (res_small + normalize)) + } + + pub fn external_product_inplace_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_ggsw: usize, + digits: usize, + rank: usize, + ) -> usize { + Self::external_product_scratch_space(module, basek, k_out, k_out, k_ggsw, digits, rank) + } +} + +impl + AsRef<[u8]>> FourierGLWECiphertext { + pub fn external_product, DataRhs: AsRef<[u8]>>( + &mut self, + module: &Module, + lhs: &FourierGLWECiphertext, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) { + let basek: usize = self.basek(); + + #[cfg(debug_assertions)] + { + assert_eq!(rhs.rank(), lhs.rank()); + assert_eq!(rhs.rank(), self.rank()); + assert_eq!(self.basek(), basek); + assert_eq!(lhs.basek(), basek); + assert_eq!(rhs.n(), module.n()); + assert_eq!(self.n(), module.n()); + assert_eq!(lhs.n(), module.n()); + assert!( + scratch.available() + >= FourierGLWECiphertext::external_product_scratch_space( + module, + self.basek(), + self.k(), + lhs.k(), + rhs.k(), + rhs.digits(), + rhs.rank(), + ) + ); + } + + let cols: usize = rhs.rank() + 1; + let digits = rhs.digits(); + + // Space for VMP result in DFT domain and high precision + let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); + let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, (lhs.size() + digits - 1) / digits); + + { + (0..digits).for_each(|di| { + a_dft.set_size((lhs.size() + di) / digits); + + // Small optimization for digits > 2 + // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then + // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}. + // As such we can ignore the last digits-2 limbs safely of the sum of vmp products. + // It is possible to further ignore the last digits-1 limbs, but this introduce + // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same + // noise is kept with respect to the ideal functionality. + res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize); + + (0..cols).for_each(|col_i| { + module.vec_znx_dft_copy(digits, digits - 1 - di, &mut a_dft, col_i, &lhs.data, col_i); + }); + + if di == 0 { + module.vmp_apply(&mut res_dft, &a_dft, &rhs.data, scratch2); + } else { + module.vmp_apply_add(&mut res_dft, &a_dft, &rhs.data, di, scratch2); + } + }); + } + + // VMP result in high precision + let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft); + + // Space for VMP result normalized + let (mut res_small, scratch2) = scratch1.tmp_vec_znx(module, cols, rhs.size()); + (0..cols).for_each(|i| { + module.vec_znx_big_normalize(basek, &mut res_small, i, &res_big, i, scratch2); + module.vec_znx_dft(1, 0, &mut self.data, i, &res_small, i); + }); + } + + pub fn external_product_inplace>( + &mut self, + module: &Module, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) { + unsafe { + let self_ptr: *mut FourierGLWECiphertext = self as *mut FourierGLWECiphertext; + self.external_product(&module, &*self_ptr, rhs, scratch); + } + } +} diff --git a/core/src/fourier_glwe/keyswitch.rs b/core/src/fourier_glwe/keyswitch.rs new file mode 100644 index 0000000..3abb26e --- /dev/null +++ b/core/src/fourier_glwe/keyswitch.rs @@ -0,0 +1,56 @@ +use backend::{FFT64, Module, Scratch}; + +use crate::{FourierGLWECiphertext, GLWECiphertext, GLWESwitchingKey, Infos, ScratchCore}; + +impl FourierGLWECiphertext, FFT64> { + pub fn keyswitch_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_in: usize, + k_ksk: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + ) -> usize { + GLWECiphertext::bytes_of(module, basek, k_out, rank_out) + + GLWECiphertext::keyswitch_from_fourier_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank_in, rank_out) + } + + pub fn keyswitch_inplace_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_ksk: usize, + digits: usize, + rank: usize, + ) -> usize { + Self::keyswitch_scratch_space(module, basek, k_out, k_out, k_ksk, digits, rank, rank) + } +} + +impl + AsRef<[u8]>> FourierGLWECiphertext { + pub fn keyswitch, DataRhs: AsRef<[u8]>>( + &mut self, + module: &Module, + lhs: &FourierGLWECiphertext, + rhs: &GLWESwitchingKey, + scratch: &mut Scratch, + ) { + let (mut tmp_ct, scratch1) = scratch.tmp_glwe_ct(module, self.basek(), self.k(), self.rank()); + tmp_ct.keyswitch_from_fourier(module, lhs, rhs, scratch1); + tmp_ct.dft(module, self); + } + + pub fn keyswitch_inplace>( + &mut self, + module: &Module, + rhs: &GLWESwitchingKey, + scratch: &mut Scratch, + ) { + unsafe { + let self_ptr: *mut FourierGLWECiphertext = self as *mut FourierGLWECiphertext; + self.keyswitch(&module, &*self_ptr, rhs, scratch); + } + } +} diff --git a/core/src/fourier_glwe/mod.rs b/core/src/fourier_glwe/mod.rs new file mode 100644 index 0000000..35c9905 --- /dev/null +++ b/core/src/fourier_glwe/mod.rs @@ -0,0 +1,12 @@ +pub mod ciphertext; +pub mod decryption; +pub mod encryption; +pub mod external_product; +pub mod keyswitch; +pub mod secret; + +pub use ciphertext::FourierGLWECiphertext; +pub use secret::FourierGLWESecret; + +#[cfg(test)] +pub mod test_fft64; diff --git a/core/src/fourier_glwe/secret.rs b/core/src/fourier_glwe/secret.rs new file mode 100644 index 0000000..0f28939 --- /dev/null +++ b/core/src/fourier_glwe/secret.rs @@ -0,0 +1,58 @@ +use backend::{Backend, FFT64, Module, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ZnxInfos}; + +use crate::{GLWESecret, dist::Distribution}; + +pub struct FourierGLWESecret { + pub(crate) data: ScalarZnxDft, + pub(crate) dist: Distribution, +} + +impl FourierGLWESecret, B> { + pub fn alloc(module: &Module, rank: usize) -> Self { + Self { + data: module.new_scalar_znx_dft(rank), + dist: Distribution::NONE, + } + } + + pub fn bytes_of(module: &Module, rank: usize) -> usize { + module.bytes_of_scalar_znx_dft(rank) + } +} + +impl FourierGLWESecret, FFT64> { + pub fn from(module: &Module, sk: &GLWESecret) -> Self + where + D: AsRef<[u8]>, + { + let mut sk_dft: FourierGLWESecret, FFT64> = Self::alloc(module, sk.rank()); + sk_dft.set(module, sk); + sk_dft + } +} + +impl FourierGLWESecret { + pub fn n(&self) -> usize { + self.data.n() + } + + pub fn log_n(&self) -> usize { + self.data.log_n() + } + + pub fn rank(&self) -> usize { + self.data.cols() + } +} + +impl + AsRef<[u8]>> FourierGLWESecret { + pub(crate) fn set(&mut self, module: &Module, sk: &GLWESecret) + where + D: AsRef<[u8]>, + { + (0..self.rank()).for_each(|i| { + module.svp_prepare(&mut self.data, i, &sk.data, i); + }); + self.dist = sk.dist + } +} diff --git a/core/src/fourier_glwe/test_fft64/external_product.rs b/core/src/fourier_glwe/test_fft64/external_product.rs new file mode 100644 index 0000000..2228d29 --- /dev/null +++ b/core/src/fourier_glwe/test_fft64/external_product.rs @@ -0,0 +1,246 @@ +use crate::{ + FourierGLWECiphertext, FourierGLWESecret, GGSWCiphertext, GLWECiphertext, GLWEOps, GLWEPlaintext, GLWESecret, Infos, + div_ceil, noise::noise_ggsw_product, +}; +use backend::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, ZnxViewMut}; +use sampling::source::Source; + +#[test] +fn apply() { + let log_n: usize = 8; + let basek: usize = 12; + let k_in: usize = 45; + let digits: usize = div_ceil(k_in, basek); + (1..4).for_each(|rank| { + (1..digits + 1).for_each(|di| { + let k_ggsw: usize = k_in + basek * di; + println!("test external_product digits: {} rank: {}", di, rank); + let k_out: usize = k_ggsw; // Better capture noise. + test_apply(log_n, basek, k_out, k_in, k_ggsw, di, rank, 3.2); + }); + }); +} + +#[test] +fn apply_inplace() { + let log_n: usize = 8; + let basek: usize = 12; + let k_ct: usize = 60; + let digits: usize = div_ceil(k_ct, basek); + (1..4).for_each(|rank| { + (1..digits + 1).for_each(|di| { + let k_ggsw: usize = k_ct + basek * di; + println!("test external_product digits: {} rank: {}", di, rank); + test_apply_inplace(log_n, basek, k_ct, k_ggsw, di, rank, 3.2); + }); + }); +} + +fn test_apply(log_n: usize, basek: usize, k_out: usize, k_in: usize, k_ggsw: usize, digits: usize, rank: usize, sigma: f64) { + let module: Module = Module::::new(1 << log_n); + + let rows: usize = div_ceil(k_in, digits * basek); + + let mut ct_ggsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank); + let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_in, rank); + let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_out, rank); + let mut ct_in_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_in, rank); + let mut ct_out_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_out, rank); + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + + pt_want.data.at_mut(0, 0)[1] = 1; + + let k: i64 = 1; + + pt_rgsw.raw_mut()[0] = 1; // X^{0} + module.vec_znx_rotate_inplace(k, &mut pt_rgsw, 0); // X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::new( + GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, ct_ggsw.k(), rank) + | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_out.k()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_in.k()) + | FourierGLWECiphertext::external_product_scratch_space( + &module, + basek, + ct_out.k(), + ct_in.k(), + ct_ggsw.k(), + digits, + rank, + ), + ); + + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); + + ct_ggsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_in.encrypt_sk( + &module, + &pt_want, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_in.dft(&module, &mut ct_in_dft); + ct_out_dft.external_product(&module, &ct_in_dft, &ct_ggsw, scratch.borrow()); + ct_out_dft.idft(&module, &mut ct_out, scratch.borrow()); + + ct_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + pt_want.rotate_inplace(&module, k); + pt_have.sub_inplace_ab(&module, &pt_want); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_ggsw_product( + module.n() as f64, + basek * digits, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + rank as f64, + k_in, + k_ggsw, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.5, + "{} {}", + noise_have, + noise_want + ); +} + +fn test_apply_inplace(log_n: usize, basek: usize, k_ct: usize, k_ggsw: usize, digits: usize, rank: usize, sigma: f64) { + let module: Module = Module::::new(1 << log_n); + let rows: usize = div_ceil(k_ct, digits * basek); + + let mut ct_ggsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank); + let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); + let mut ct_rlwe_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_ct, rank); + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + + pt_want.data.at_mut(0, 0)[1] = 1; + + let k: i64 = 1; + + pt_rgsw.raw_mut()[0] = 1; // X^{0} + module.vec_znx_rotate_inplace(k, &mut pt_rgsw, 0); // X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::new( + GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, ct_ggsw.k(), rank) + | GLWECiphertext::decrypt_scratch_space(&module, basek, ct.k()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k()) + | FourierGLWECiphertext::external_product_inplace_scratch_space(&module, basek, ct.k(), ct_ggsw.k(), digits, rank), + ); + + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); + + ct_ggsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct.encrypt_sk( + &module, + &pt_want, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct.dft(&module, &mut ct_rlwe_dft); + ct_rlwe_dft.external_product_inplace(&module, &ct_ggsw, scratch.borrow()); + ct_rlwe_dft.idft(&module, &mut ct, scratch.borrow()); + + ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + pt_want.rotate_inplace(&module, k); + pt_have.sub_inplace_ab(&module, &pt_want); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_ggsw_product( + module.n() as f64, + basek * digits, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + rank as f64, + k_ct, + k_ggsw, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.5, + "{} {}", + noise_have, + noise_want + ); + + println!("{} {}", noise_have, noise_want); +} diff --git a/core/src/fourier_glwe/test_fft64/keyswitch.rs b/core/src/fourier_glwe/test_fft64/keyswitch.rs new file mode 100644 index 0000000..a459964 --- /dev/null +++ b/core/src/fourier_glwe/test_fft64/keyswitch.rs @@ -0,0 +1,235 @@ +use crate::{ + FourierGLWECiphertext, FourierGLWESecret, GLWECiphertext, GLWEPlaintext, GLWESecret, GLWESwitchingKey, Infos, div_ceil, + noise::log2_std_noise_gglwe_product, +}; +use backend::{FFT64, FillUniform, Module, ScratchOwned, Stats, VecZnxOps}; +use sampling::source::Source; + +#[test] +fn apply() { + let log_n: usize = 8; + let basek: usize = 12; + let k_in: usize = 45; + let digits: usize = div_ceil(k_in, basek); + (1..4).for_each(|rank_in| { + (1..4).for_each(|rank_out| { + (1..digits + 1).for_each(|di| { + let k_ksk: usize = k_in + basek * di; + println!( + "test keyswitch digits: {} rank_in: {} rank_out: {}", + di, rank_in, rank_out + ); + let k_out: usize = k_ksk; // Better capture noise. + test_apply(log_n, basek, k_in, k_out, k_ksk, di, rank_in, rank_out, 3.2); + }) + }); + }); +} + +#[test] +fn apply_inplace() { + let log_n: usize = 8; + let basek: usize = 12; + let k_ct: usize = 45; + let digits: usize = div_ceil(k_ct, basek); + (1..4).for_each(|rank| { + (1..digits + 1).for_each(|di| { + let k_ksk: usize = k_ct + basek * di; + println!("test keyswitch_inplace digits: {} rank: {}", di, rank); + test_apply_inplace(log_n, basek, k_ct, k_ksk, di, rank, 3.2); + }); + }); +} + +fn test_apply( + log_n: usize, + basek: usize, + k_in: usize, + k_out: usize, + k_ksk: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + sigma: f64, +) { + let module: Module = Module::::new(1 << log_n); + + let rows: usize = div_ceil(k_in, basek * digits); + + let mut ksk: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank_in, rank_out); + let mut ct_glwe_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_in, rank_in); + let mut ct_glwe_dft_in: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_in, rank_in); + let mut ct_glwe_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_out, rank_out); + let mut ct_glwe_dft_out: FourierGLWECiphertext, FFT64> = + FourierGLWECiphertext::alloc(&module, basek, k_out, rank_out); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank_out) + | GLWECiphertext::decrypt_scratch_space(&module, basek, k_out) + | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, k_in) + | FourierGLWECiphertext::keyswitch_scratch_space( + &module, + basek, + ct_glwe_out.k(), + ksk.k(), + ct_glwe_in.k(), + digits, + rank_in, + rank_out, + ), + ); + + let mut sk_in: GLWESecret> = GLWESecret::alloc(&module, rank_in); + sk_in.fill_ternary_prob(0.5, &mut source_xs); + let sk_in_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_in); + + let mut sk_out: GLWESecret> = GLWESecret::alloc(&module, rank_out); + sk_out.fill_ternary_prob(0.5, &mut source_xs); + let sk_out_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_out); + + ksk.generate_from_sk( + &module, + &sk_in, + &sk_out_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_glwe_in.encrypt_sk( + &module, + &pt_want, + &sk_in_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_glwe_in.dft(&module, &mut ct_glwe_dft_in); + ct_glwe_dft_out.keyswitch(&module, &ct_glwe_dft_in, &ksk, scratch.borrow()); + ct_glwe_dft_out.idft(&module, &mut ct_glwe_out, scratch.borrow()); + + ct_glwe_out.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + let noise_want: f64 = log2_std_noise_gglwe_product( + module.n() as f64, + basek * digits, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + rank_in as f64, + k_in, + k_ksk, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.5, + "{} {}", + noise_have, + noise_want + ); +} + +fn test_apply_inplace(log_n: usize, basek: usize, k_ct: usize, k_ksk: usize, digits: usize, rank: usize, sigma: f64) { + let module: Module = Module::::new(1 << log_n); + + let rows: usize = div_ceil(k_ct, basek * digits); + + let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank, rank); + let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); + let mut ct_rlwe_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_ct, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ksk.k(), rank) + | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_glwe.k()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_glwe.k()) + | FourierGLWECiphertext::keyswitch_inplace_scratch_space(&module, basek, ct_rlwe_dft.k(), ksk.k(), digits, rank), + ); + + let mut sk_in: GLWESecret> = GLWESecret::alloc(&module, rank); + sk_in.fill_ternary_prob(0.5, &mut source_xs); + let sk_in_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_in); + + let mut sk_out: GLWESecret> = GLWESecret::alloc(&module, rank); + sk_out.fill_ternary_prob(0.5, &mut source_xs); + let sk_out_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_out); + + ksk.generate_from_sk( + &module, + &sk_in, + &sk_out_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_glwe.encrypt_sk( + &module, + &pt_want, + &sk_in_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_glwe.dft(&module, &mut ct_rlwe_dft); + ct_rlwe_dft.keyswitch_inplace(&module, &ksk, scratch.borrow()); + ct_rlwe_dft.idft(&module, &mut ct_glwe, scratch.borrow()); + + ct_glwe.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + let noise_want: f64 = log2_std_noise_gglwe_product( + module.n() as f64, + basek * digits, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + rank as f64, + k_ct, + k_ksk, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.5, + "{} {}", + noise_have, + noise_want + ); +} diff --git a/core/src/fourier_glwe/test_fft64/mod.rs b/core/src/fourier_glwe/test_fft64/mod.rs new file mode 100644 index 0000000..784c37c --- /dev/null +++ b/core/src/fourier_glwe/test_fft64/mod.rs @@ -0,0 +1,2 @@ +pub mod external_product; +pub mod keyswitch; diff --git a/core/src/gglwe/automorphism.rs b/core/src/gglwe/automorphism.rs new file mode 100644 index 0000000..5460511 --- /dev/null +++ b/core/src/gglwe/automorphism.rs @@ -0,0 +1,136 @@ +use backend::{FFT64, Module, Scratch, VecZnx, VecZnxDftOps, VecZnxOps, ZnxZero}; + +use crate::{FourierGLWECiphertext, GLWEAutomorphismKey, GLWECiphertext, GetRow, Infos, ScratchCore, SetRow}; + +impl GLWEAutomorphismKey, FFT64> { + pub fn automorphism_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_in: usize, + k_ksk: usize, + digits: usize, + rank: usize, + ) -> usize { + let tmp_dft: usize = FourierGLWECiphertext::bytes_of(module, basek, k_in, rank); + let tmp_idft: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank); + let idft: usize = module.vec_znx_idft_tmp_bytes(); + let keyswitch: usize = GLWECiphertext::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank); + tmp_dft + tmp_idft + idft + keyswitch + } + + pub fn automorphism_inplace_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_ksk: usize, + digits: usize, + rank: usize, + ) -> usize { + GLWEAutomorphismKey::automorphism_scratch_space(module, basek, k_out, k_out, k_ksk, digits, rank) + } +} + +impl + AsRef<[u8]>> GLWEAutomorphismKey { + pub fn automorphism, DataRhs: AsRef<[u8]>>( + &mut self, + module: &Module, + lhs: &GLWEAutomorphismKey, + rhs: &GLWEAutomorphismKey, + scratch: &mut Scratch, + ) { + #[cfg(debug_assertions)] + { + assert_eq!( + self.rank_in(), + lhs.rank_in(), + "ksk_out input rank: {} != ksk_in input rank: {}", + self.rank_in(), + lhs.rank_in() + ); + assert_eq!( + lhs.rank_out(), + rhs.rank_in(), + "ksk_in output rank: {} != ksk_apply input rank: {}", + self.rank_out(), + rhs.rank_in() + ); + assert_eq!( + self.rank_out(), + rhs.rank_out(), + "ksk_out output rank: {} != ksk_apply output rank: {}", + self.rank_out(), + rhs.rank_out() + ); + } + + let cols_out: usize = rhs.rank_out() + 1; + + let (mut tmp_dft, scratch1) = scratch.tmp_glwe_fourier(module, lhs.basek(), lhs.k(), lhs.rank()); + + (0..self.rank_in()).for_each(|col_i| { + (0..self.rows()).for_each(|row_j| { + // Extracts relevant row + lhs.get_row(module, row_j, col_i, &mut tmp_dft); + + // Get a VecZnxBig from scratch space + let (mut tmp_idft_data, scratch2) = scratch1.tmp_vec_znx_big(module, cols_out, self.size()); + + // Switches input outside of DFT + (0..cols_out).for_each(|i| { + module.vec_znx_idft(&mut tmp_idft_data, i, &tmp_dft.data, i, scratch2); + }); + + // Consumes to small vec znx + let mut tmp_idft_small_data: VecZnx<&mut [u8]> = tmp_idft_data.to_vec_znx_small(); + + // Reverts the automorphis key from (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a) + (0..cols_out).for_each(|i| { + module.vec_znx_automorphism_inplace(lhs.p(), &mut tmp_idft_small_data, i); + }); + + // Wraps into ciphertext + let mut tmp_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { + data: tmp_idft_small_data, + basek: self.basek(), + k: self.k(), + }; + + // Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) + tmp_idft.keyswitch_inplace(module, &rhs.key, scratch2); + + // Applies back the automorphism X^{k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) -> (-pi^{-1}_{k'+k}(s)a + s, a) + // and switches back to DFT domain + (0..self.rank_out() + 1).for_each(|i| { + module.vec_znx_automorphism_inplace(lhs.p(), &mut tmp_idft.data, i); + module.vec_znx_dft(1, 0, &mut tmp_dft.data, i, &tmp_idft.data, i); + }); + + // Sets back the relevant row + self.set_row(module, row_j, col_i, &tmp_dft); + }); + }); + + tmp_dft.data.zero(); + + (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| { + (0..self.rank_in()).for_each(|col_j| { + self.set_row(module, row_i, col_j, &tmp_dft); + }); + }); + + self.p = (lhs.p * rhs.p) % (module.cyclotomic_order() as i64); + } + + pub fn automorphism_inplace>( + &mut self, + module: &Module, + rhs: &GLWEAutomorphismKey, + scratch: &mut Scratch, + ) { + unsafe { + let self_ptr: *mut GLWEAutomorphismKey = self as *mut GLWEAutomorphismKey; + self.automorphism(&module, &*self_ptr, rhs, scratch); + } + } +} diff --git a/core/src/gglwe/automorphism_key.rs b/core/src/gglwe/automorphism_key.rs new file mode 100644 index 0000000..26fea52 --- /dev/null +++ b/core/src/gglwe/automorphism_key.rs @@ -0,0 +1,83 @@ +use backend::{Backend, FFT64, MatZnxDft, MatZnxDftOps, Module}; + +use crate::{FourierGLWECiphertext, GLWESwitchingKey, GetRow, Infos, SetRow}; + +pub struct GLWEAutomorphismKey { + pub(crate) key: GLWESwitchingKey, + pub(crate) p: i64, +} + +impl GLWEAutomorphismKey, FFT64> { + pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self { + GLWEAutomorphismKey { + key: GLWESwitchingKey::alloc(module, basek, k, rows, digits, rank, rank), + p: 0, + } + } + + pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize { + GLWESwitchingKey::, FFT64>::bytes_of(module, basek, k, rows, digits, rank, rank) + } +} + +impl Infos for GLWEAutomorphismKey { + type Inner = MatZnxDft; + + fn inner(&self) -> &Self::Inner { + &self.key.inner() + } + + fn basek(&self) -> usize { + self.key.basek() + } + + fn k(&self) -> usize { + self.key.k() + } +} + +impl GLWEAutomorphismKey { + pub fn p(&self) -> i64 { + self.p + } + + pub fn digits(&self) -> usize { + self.key.digits() + } + + pub fn rank(&self) -> usize { + self.key.rank() + } + + pub fn rank_in(&self) -> usize { + self.key.rank_in() + } + + pub fn rank_out(&self) -> usize { + self.key.rank_out() + } +} + +impl> GetRow for GLWEAutomorphismKey { + fn get_row + AsRef<[u8]>>( + &self, + module: &Module, + row_i: usize, + col_j: usize, + res: &mut FourierGLWECiphertext, + ) { + module.mat_znx_dft_get_row(&mut res.data, &self.key.0.data, row_i, col_j); + } +} + +impl + AsRef<[u8]>> SetRow for GLWEAutomorphismKey { + fn set_row>( + &mut self, + module: &Module, + row_i: usize, + col_j: usize, + a: &FourierGLWECiphertext, + ) { + module.mat_znx_dft_set_row(&mut self.key.0.data, row_i, col_j, &a.data); + } +} diff --git a/core/src/gglwe/ciphertext.rs b/core/src/gglwe/ciphertext.rs new file mode 100644 index 0000000..a4c2f1d --- /dev/null +++ b/core/src/gglwe/ciphertext.rs @@ -0,0 +1,131 @@ +use backend::{Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, Module}; + +use crate::{FourierGLWECiphertext, GetRow, Infos, SetRow, div_ceil}; + +pub struct GGLWECiphertext { + pub(crate) data: MatZnxDft, + pub(crate) basek: usize, + pub(crate) k: usize, + pub(crate) digits: usize, +} + +impl GGLWECiphertext, B> { + pub fn alloc( + module: &Module, + basek: usize, + k: usize, + rows: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + ) -> Self { + let size: usize = div_ceil(k, basek); + debug_assert!( + size > digits, + "invalid gglwe: ceil(k/basek): {} <= digits: {}", + size, + digits + ); + + assert!( + rows * digits <= size, + "invalid gglwe: rows: {} * digits:{} > ceil(k/basek): {}", + rows, + digits, + size + ); + + Self { + data: module.new_mat_znx_dft(rows, rank_in, rank_out + 1, size), + basek: basek, + k, + digits, + } + } + + pub fn bytes_of( + module: &Module, + basek: usize, + k: usize, + rows: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + ) -> usize { + let size: usize = div_ceil(k, basek); + debug_assert!( + size > digits, + "invalid gglwe: ceil(k/basek): {} <= digits: {}", + size, + digits + ); + + assert!( + rows * digits <= size, + "invalid gglwe: rows: {} * digits:{} > ceil(k/basek): {}", + rows, + digits, + size + ); + + module.bytes_of_mat_znx_dft(rows, rank_in, rank_out + 1, rows) + } +} + +impl Infos for GGLWECiphertext { + type Inner = MatZnxDft; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn basek(&self) -> usize { + self.basek + } + + fn k(&self) -> usize { + self.k + } +} + +impl GGLWECiphertext { + pub fn rank(&self) -> usize { + self.data.cols_out() - 1 + } + + pub fn digits(&self) -> usize { + self.digits + } + + pub fn rank_in(&self) -> usize { + self.data.cols_in() + } + + pub fn rank_out(&self) -> usize { + self.data.cols_out() - 1 + } +} + +impl> GetRow for GGLWECiphertext { + fn get_row + AsRef<[u8]>>( + &self, + module: &Module, + row_i: usize, + col_j: usize, + res: &mut FourierGLWECiphertext, + ) { + module.mat_znx_dft_get_row(&mut res.data, &self.data, row_i, col_j); + } +} + +impl + AsRef<[u8]>> SetRow for GGLWECiphertext { + fn set_row>( + &mut self, + module: &Module, + row_i: usize, + col_j: usize, + a: &FourierGLWECiphertext, + ) { + module.mat_znx_dft_set_row(&mut self.data, row_i, col_j, &a.data); + } +} diff --git a/core/src/gglwe/encryption.rs b/core/src/gglwe/encryption.rs new file mode 100644 index 0000000..7c4838b --- /dev/null +++ b/core/src/gglwe/encryption.rs @@ -0,0 +1,253 @@ +use backend::{ + FFT64, Module, ScalarZnx, ScalarZnxDftOps, ScalarZnxOps, Scratch, VecZnxAlloc, VecZnxDftAlloc, VecZnxOps, ZnxInfos, ZnxZero, +}; +use sampling::source::Source; + +use crate::{ + FourierGLWESecret, GGLWECiphertext, GLWEAutomorphismKey, GLWECiphertext, GLWESecret, GLWESwitchingKey, GLWETensorKey, Infos, + ScratchCore, SetRow, div_ceil, +}; + +impl GGLWECiphertext, FFT64> { + pub fn generate_from_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { + let size = div_ceil(k, basek); + GLWECiphertext::encrypt_sk_scratch_space(module, basek, k) + + module.bytes_of_vec_znx(rank + 1, size) + + module.bytes_of_vec_znx(1, size) + + module.bytes_of_vec_znx_dft(rank + 1, size) + } + + pub fn generate_from_pk_scratch_space(_module: &Module, _basek: usize, _k: usize, _rank: usize) -> usize { + unimplemented!() + } +} + +impl + AsRef<[u8]>> GGLWECiphertext { + pub fn encrypt_sk, DataSk: AsRef<[u8]>>( + &mut self, + module: &Module, + pt: &ScalarZnx, + sk: &FourierGLWESecret, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank_in(), pt.cols()); + assert_eq!(self.rank_out(), sk.rank()); + assert_eq!(self.n(), module.n()); + assert_eq!(sk.n(), module.n()); + assert_eq!(pt.n(), module.n()); + assert!( + scratch.available() + >= GGLWECiphertext::generate_from_sk_scratch_space(module, self.basek(), self.k(), self.rank()), + "scratch.available: {} < GGLWECiphertext::generate_from_sk_scratch_space(module, self.rank()={}, \ + self.size()={}): {}", + scratch.available(), + self.rank(), + self.size(), + GGLWECiphertext::generate_from_sk_scratch_space(module, self.basek(), self.k(), self.rank()) + ); + assert!( + self.rows() * self.digits() * self.basek() <= self.k(), + "self.rows() : {} * self.digits() : {} * self.basek() : {} = {} >= self.k() = {}", + self.rows(), + self.digits(), + self.basek(), + self.rows() * self.digits() * self.basek(), + self.k() + ); + } + + let rows: usize = self.rows(); + let digits: usize = self.digits(); + let basek: usize = self.basek(); + let k: usize = self.k(); + let rank_in: usize = self.rank_in(); + let rank_out: usize = self.rank_out(); + + let (mut tmp_pt, scrach_1) = scratch.tmp_glwe_pt(module, basek, k); + let (mut tmp_ct, scrach_2) = scrach_1.tmp_glwe_ct(module, basek, k, rank_out); + let (mut tmp_ct_dft, scratch_3) = scrach_2.tmp_glwe_fourier(module, basek, k, rank_out); + + // For each input column (i.e. rank) produces a GGLWE ciphertext of rank_out+1 columns + // + // Example for ksk rank 2 to rank 3: + // + // (-(a0*s0 + a1*s1 + a2*s2) + s0', a0, a1, a2) + // (-(b0*s0 + b1*s1 + b2*s2) + s0', b0, b1, b2) + // + // Example ksk rank 2 to rank 1 + // + // (-(a*s) + s0, a) + // (-(b*s) + s1, b) + (0..rank_in).for_each(|col_i| { + (0..rows).for_each(|row_i| { + // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt + tmp_pt.data.zero(); // zeroes for next iteration + module.vec_znx_add_scalar_inplace( + &mut tmp_pt.data, + 0, + (digits - 1) + row_i * digits, + pt, + col_i, + ); + module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scratch_3); + + // rlwe encrypt of vec_znx_pt into vec_znx_ct + tmp_ct.encrypt_sk(module, &tmp_pt, sk, source_xa, source_xe, sigma, scratch_3); + + // Switch vec_znx_ct into DFT domain + tmp_ct.dft(module, &mut tmp_ct_dft); + + // Stores vec_znx_dft_ct into thw i-th row of the MatZnxDft + self.set_row(module, row_i, col_i, &tmp_ct_dft); + }); + }); + } +} + +impl GLWESwitchingKey, FFT64> { + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { + GGLWECiphertext::generate_from_sk_scratch_space(module, basek, k, rank) + } + + pub fn encrypt_pk_scratch_space(module: &Module, _basek: usize, _k: usize, _rank: usize) -> usize { + GGLWECiphertext::generate_from_pk_scratch_space(module, _basek, _k, _rank) + } +} + +impl + AsRef<[u8]>> GLWESwitchingKey { + pub fn generate_from_sk, DataSkOut: AsRef<[u8]>>( + &mut self, + module: &Module, + sk_in: &GLWESecret, + sk_out: &FourierGLWESecret, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) { + self.0.encrypt_sk( + module, + &sk_in.data, + sk_out, + source_xa, + source_xe, + sigma, + scratch, + ); + } +} + +impl GLWEAutomorphismKey, FFT64> { + pub fn generate_from_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { + GGLWECiphertext::generate_from_sk_scratch_space(module, basek, k, rank) + GLWESecret::bytes_of(module, rank) + } + + pub fn generate_from_pk_scratch_space(module: &Module, _basek: usize, _k: usize, _rank: usize) -> usize { + GGLWECiphertext::generate_from_pk_scratch_space(module, _basek, _k, _rank) + } +} + +impl + AsRef<[u8]>> GLWEAutomorphismKey { + pub fn generate_from_sk>( + &mut self, + module: &Module, + p: i64, + sk: &GLWESecret, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) { + #[cfg(debug_assertions)] + { + assert_eq!(self.n(), module.n()); + assert_eq!(sk.n(), module.n()); + assert_eq!(self.rank_out(), self.rank_in()); + assert_eq!(sk.rank(), self.rank()); + assert!( + scratch.available() + >= GLWEAutomorphismKey::generate_from_sk_scratch_space(module, self.basek(), self.k(), self.rank()), + "scratch.available(): {} < AutomorphismKey::generate_from_sk_scratch_space(module, self.rank()={}, \ + self.size()={}): {}", + scratch.available(), + self.rank(), + self.size(), + GLWEAutomorphismKey::generate_from_sk_scratch_space(module, self.basek(), self.k(), self.rank()) + ) + } + + let (mut sk_out_dft, scratch_1) = scratch.tmp_fourier_sk(module, sk.rank()); + + { + let (mut sk_out, _) = scratch_1.tmp_sk(module, sk.rank()); + (0..self.rank()).for_each(|i| { + module.scalar_znx_automorphism( + module.galois_element_inv(p), + &mut sk_out.data, + i, + &sk.data, + i, + ); + }); + sk_out_dft.set(module, &sk_out); + } + + self.key.generate_from_sk( + module, + &sk, + &sk_out_dft, + source_xa, + source_xe, + sigma, + scratch_1, + ); + + self.p = p; + } +} + +impl GLWETensorKey, FFT64> { + pub fn generate_from_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { + GLWESecret::bytes_of(module, 1) + + FourierGLWESecret::bytes_of(module, 1) + + GLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, rank) + } +} + +impl + AsRef<[u8]>> GLWETensorKey { + pub fn generate_from_sk>( + &mut self, + module: &Module, + sk: &FourierGLWESecret, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), sk.rank()); + assert_eq!(self.n(), module.n()); + assert_eq!(sk.n(), module.n()); + } + + let rank: usize = self.rank(); + + let (mut sk_ij, scratch1) = scratch.tmp_sk(module, 1); + let (mut sk_ij_dft, scratch2) = scratch1.tmp_fourier_sk(module, 1); + + (0..rank).for_each(|i| { + (i..rank).for_each(|j| { + module.svp_apply(&mut sk_ij_dft.data, 0, &sk.data, i, &sk.data, j); + module.scalar_znx_idft(&mut sk_ij.data, 0, &sk_ij_dft.data, 0, scratch2); + self.at_mut(i, j) + .generate_from_sk(module, &sk_ij, sk, source_xa, source_xe, sigma, scratch2); + }); + }) + } +} diff --git a/core/src/gglwe/external_product.rs b/core/src/gglwe/external_product.rs new file mode 100644 index 0000000..2e063ef --- /dev/null +++ b/core/src/gglwe/external_product.rs @@ -0,0 +1,162 @@ +use backend::{FFT64, Module, Scratch, ZnxZero}; + +use crate::{FourierGLWECiphertext, GGSWCiphertext, GLWEAutomorphismKey, GLWESwitchingKey, GetRow, Infos, ScratchCore, SetRow}; + +impl GLWESwitchingKey, FFT64> { + pub fn external_product_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_in: usize, + k_ggsw: usize, + digits: usize, + rank: usize, + ) -> usize { + let tmp_in: usize = FourierGLWECiphertext::bytes_of(module, basek, k_in, rank); + let tmp_out: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank); + let ggsw: usize = FourierGLWECiphertext::external_product_scratch_space(module, basek, k_out, k_in, k_ggsw, digits, rank); + tmp_in + tmp_out + ggsw + } + + pub fn external_product_inplace_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_ggsw: usize, + digits: usize, + rank: usize, + ) -> usize { + let tmp: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank); + let ggsw: usize = + FourierGLWECiphertext::external_product_inplace_scratch_space(module, basek, k_out, k_ggsw, digits, rank); + tmp + ggsw + } +} + +impl + AsRef<[u8]>> GLWESwitchingKey { + pub fn external_product, DataRhs: AsRef<[u8]>>( + &mut self, + module: &Module, + lhs: &GLWESwitchingKey, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) { + #[cfg(debug_assertions)] + { + assert_eq!( + self.rank_in(), + lhs.rank_in(), + "ksk_out input rank: {} != ksk_in input rank: {}", + self.rank_in(), + lhs.rank_in() + ); + assert_eq!( + lhs.rank_out(), + rhs.rank(), + "ksk_in output rank: {} != ggsw rank: {}", + self.rank_out(), + rhs.rank() + ); + assert_eq!( + self.rank_out(), + rhs.rank(), + "ksk_out output rank: {} != ggsw rank: {}", + self.rank_out(), + rhs.rank() + ); + } + + let (mut tmp_in, scratch1) = scratch.tmp_glwe_fourier(module, lhs.basek(), lhs.k(), lhs.rank()); + let (mut tmp_out, scratch2) = scratch1.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); + + (0..self.rank_in()).for_each(|col_i| { + (0..self.rows()).for_each(|row_j| { + lhs.get_row(module, row_j, col_i, &mut tmp_in); + tmp_out.external_product(module, &tmp_in, rhs, scratch2); + self.set_row(module, row_j, col_i, &tmp_out); + }); + }); + + tmp_out.data.zero(); + + (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| { + (0..self.rank_in()).for_each(|col_j| { + self.set_row(module, row_i, col_j, &tmp_out); + }); + }); + } + + pub fn external_product_inplace>( + &mut self, + module: &Module, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) { + #[cfg(debug_assertions)] + { + assert_eq!( + self.rank_out(), + rhs.rank(), + "ksk_out output rank: {} != ggsw rank: {}", + self.rank_out(), + rhs.rank() + ); + } + + let (mut tmp, scratch1) = scratch.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); + println!("tmp: {}", tmp.size()); + (0..self.rank_in()).for_each(|col_i| { + (0..self.rows()).for_each(|row_j| { + self.get_row(module, row_j, col_i, &mut tmp); + tmp.external_product_inplace(module, rhs, scratch1); + self.set_row(module, row_j, col_i, &tmp); + }); + }); + } +} + +impl GLWEAutomorphismKey, FFT64> { + pub fn external_product_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_in: usize, + ggsw_k: usize, + digits: usize, + rank: usize, + ) -> usize { + GLWESwitchingKey::external_product_scratch_space(module, basek, k_out, k_in, ggsw_k, digits, rank) + } + + pub fn external_product_inplace_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + ggsw_k: usize, + digits: usize, + rank: usize, + ) -> usize { + GLWESwitchingKey::external_product_inplace_scratch_space(module, basek, k_out, ggsw_k, digits, rank) + } +} + +impl + AsRef<[u8]>> GLWEAutomorphismKey { + pub fn external_product, DataRhs: AsRef<[u8]>>( + &mut self, + module: &Module, + lhs: &GLWEAutomorphismKey, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) { + self.key.external_product(module, &lhs.key, rhs, scratch); + } + + pub fn external_product_inplace>( + &mut self, + module: &Module, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) { + self.key.external_product_inplace(module, rhs, scratch); + } +} diff --git a/core/src/gglwe/keyswitch.rs b/core/src/gglwe/keyswitch.rs new file mode 100644 index 0000000..632309d --- /dev/null +++ b/core/src/gglwe/keyswitch.rs @@ -0,0 +1,163 @@ +use backend::{FFT64, Module, Scratch, ZnxZero}; + +use crate::{FourierGLWECiphertext, GLWEAutomorphismKey, GLWESwitchingKey, GetRow, Infos, ScratchCore, SetRow}; + +impl GLWEAutomorphismKey, FFT64> { + pub fn keyswitch_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_in: usize, + k_ksk: usize, + digits: usize, + rank: usize, + ) -> usize { + GLWESwitchingKey::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank, rank) + } + + pub fn keyswitch_inplace_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_ksk: usize, + digits: usize, + rank: usize, + ) -> usize { + GLWESwitchingKey::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank) + } +} + +impl + AsRef<[u8]>> GLWEAutomorphismKey { + pub fn keyswitch, DataRhs: AsRef<[u8]>>( + &mut self, + module: &Module, + lhs: &GLWEAutomorphismKey, + rhs: &GLWESwitchingKey, + scratch: &mut Scratch, + ) { + self.key.keyswitch(module, &lhs.key, rhs, scratch); + } + + pub fn keyswitch_inplace>( + &mut self, + module: &Module, + rhs: &GLWEAutomorphismKey, + scratch: &mut Scratch, + ) { + self.key.keyswitch_inplace(module, &rhs.key, scratch); + } +} + +impl GLWESwitchingKey, FFT64> { + pub fn keyswitch_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_in: usize, + k_ksk: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + ) -> usize { + let tmp_in: usize = FourierGLWECiphertext::bytes_of(module, basek, k_in, rank_in); + let tmp_out: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank_out); + let ksk: usize = + FourierGLWECiphertext::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank_in, rank_out); + tmp_in + tmp_out + ksk + } + + pub fn keyswitch_inplace_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_ksk: usize, + digits: usize, + rank: usize, + ) -> usize { + let tmp: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank); + let ksk: usize = FourierGLWECiphertext::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank); + tmp + ksk + } +} + +impl + AsRef<[u8]>> GLWESwitchingKey { + pub fn keyswitch, DataRhs: AsRef<[u8]>>( + &mut self, + module: &Module, + lhs: &GLWESwitchingKey, + rhs: &GLWESwitchingKey, + scratch: &mut Scratch, + ) { + #[cfg(debug_assertions)] + { + assert_eq!( + self.rank_in(), + lhs.rank_in(), + "ksk_out input rank: {} != ksk_in input rank: {}", + self.rank_in(), + lhs.rank_in() + ); + assert_eq!( + lhs.rank_out(), + rhs.rank_in(), + "ksk_in output rank: {} != ksk_apply input rank: {}", + self.rank_out(), + rhs.rank_in() + ); + assert_eq!( + self.rank_out(), + rhs.rank_out(), + "ksk_out output rank: {} != ksk_apply output rank: {}", + self.rank_out(), + rhs.rank_out() + ); + } + + let (mut tmp_in, scratch1) = scratch.tmp_glwe_fourier(module, lhs.basek(), lhs.k(), lhs.rank()); + let (mut tmp_out, scratch2) = scratch1.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); + + (0..self.rank_in()).for_each(|col_i| { + (0..self.rows()).for_each(|row_j| { + lhs.get_row(module, row_j, col_i, &mut tmp_in); + tmp_out.keyswitch(module, &tmp_in, rhs, scratch2); + self.set_row(module, row_j, col_i, &tmp_out); + }); + }); + + tmp_out.data.zero(); + + (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| { + (0..self.rank_in()).for_each(|col_j| { + self.set_row(module, row_i, col_j, &tmp_out); + }); + }); + } + + pub fn keyswitch_inplace>( + &mut self, + module: &Module, + rhs: &GLWESwitchingKey, + scratch: &mut Scratch, + ) { + #[cfg(debug_assertions)] + { + assert_eq!( + self.rank_out(), + rhs.rank_out(), + "ksk_out output rank: {} != ksk_apply output rank: {}", + self.rank_out(), + rhs.rank_out() + ); + } + + let (mut tmp, scratch1) = scratch.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); + + (0..self.rank_in()).for_each(|col_i| { + (0..self.rows()).for_each(|row_j| { + self.get_row(module, row_j, col_i, &mut tmp); + tmp.keyswitch_inplace(module, rhs, scratch1); + self.set_row(module, row_j, col_i, &tmp); + }); + }); + } +} diff --git a/core/src/gglwe/keyswitch_key.rs b/core/src/gglwe/keyswitch_key.rs new file mode 100644 index 0000000..965d596 --- /dev/null +++ b/core/src/gglwe/keyswitch_key.rs @@ -0,0 +1,91 @@ +use backend::{Backend, FFT64, MatZnxDft, MatZnxDftOps, Module}; + +use crate::{FourierGLWECiphertext, GGLWECiphertext, GetRow, Infos, SetRow}; + +pub struct GLWESwitchingKey(pub(crate) GGLWECiphertext); + +impl GLWESwitchingKey, FFT64> { + pub fn alloc( + module: &Module, + basek: usize, + k: usize, + rows: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + ) -> Self { + GLWESwitchingKey(GGLWECiphertext::alloc( + module, basek, k, rows, digits, rank_in, rank_out, + )) + } + + pub fn bytes_of( + module: &Module, + basek: usize, + k: usize, + rows: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + ) -> usize { + GGLWECiphertext::, FFT64>::bytes_of(module, basek, k, rows, digits, rank_in, rank_out) + } +} + +impl Infos for GLWESwitchingKey { + type Inner = MatZnxDft; + + fn inner(&self) -> &Self::Inner { + self.0.inner() + } + + fn basek(&self) -> usize { + self.0.basek() + } + + fn k(&self) -> usize { + self.0.k() + } +} + +impl GLWESwitchingKey { + pub fn rank(&self) -> usize { + self.0.data.cols_out() - 1 + } + + pub fn rank_in(&self) -> usize { + self.0.data.cols_in() + } + + pub fn rank_out(&self) -> usize { + self.0.data.cols_out() - 1 + } + + pub fn digits(&self) -> usize { + self.0.digits() + } +} + +impl> GetRow for GLWESwitchingKey { + fn get_row + AsRef<[u8]>>( + &self, + module: &Module, + row_i: usize, + col_j: usize, + res: &mut FourierGLWECiphertext, + ) { + module.mat_znx_dft_get_row(&mut res.data, &self.0.data, row_i, col_j); + } +} + +impl + AsRef<[u8]>> SetRow for GLWESwitchingKey { + fn set_row>( + &mut self, + module: &Module, + row_i: usize, + col_j: usize, + a: &FourierGLWECiphertext, + ) { + module.mat_znx_dft_set_row(&mut self.0.data, row_i, col_j, &a.data); + } +} diff --git a/core/src/gglwe/mod.rs b/core/src/gglwe/mod.rs new file mode 100644 index 0000000..4c2d20a --- /dev/null +++ b/core/src/gglwe/mod.rs @@ -0,0 +1,16 @@ +pub mod automorphism; +pub mod automorphism_key; +pub mod ciphertext; +pub mod encryption; +pub mod external_product; +pub mod keyswitch; +pub mod keyswitch_key; +pub mod tensor_key; + +pub use automorphism_key::GLWEAutomorphismKey; +pub use ciphertext::GGLWECiphertext; +pub use keyswitch_key::GLWESwitchingKey; +pub use tensor_key::GLWETensorKey; + +#[cfg(test)] +mod test_fft64; diff --git a/core/src/tensor_key.rs b/core/src/gglwe/tensor_key.rs similarity index 52% rename from core/src/tensor_key.rs rename to core/src/gglwe/tensor_key.rs index c0887c9..c12c1f5 100644 --- a/core/src/tensor_key.rs +++ b/core/src/gglwe/tensor_key.rs @@ -1,13 +1,12 @@ -use backend::{Backend, FFT64, MatZnxDft, Module, ScalarZnxDftOps, Scratch}; -use sampling::source::Source; +use backend::{Backend, FFT64, MatZnxDft, Module}; -use crate::{GLWESecret, GLWESwitchingKey, Infos, ScratchCore}; +use crate::{GLWESwitchingKey, Infos}; -pub struct TensorKey { +pub struct GLWETensorKey { pub(crate) keys: Vec>, } -impl TensorKey, FFT64> { +impl GLWETensorKey, FFT64> { pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self { let mut keys: Vec, FFT64>> = Vec::new(); let pairs: usize = (((rank + 1) * rank) >> 1).max(1); @@ -25,7 +24,7 @@ impl TensorKey, FFT64> { } } -impl Infos for TensorKey { +impl Infos for GLWETensorKey { type Inner = MatZnxDft; fn inner(&self) -> &Self::Inner { @@ -41,7 +40,7 @@ impl Infos for TensorKey { } } -impl TensorKey { +impl GLWETensorKey { pub fn rank(&self) -> usize { self.keys[0].rank() } @@ -59,50 +58,7 @@ impl TensorKey { } } -impl TensorKey, FFT64> { - pub fn generate_from_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { - GLWESecret::bytes_of(module, 1) + GLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, rank) - } -} - -impl + AsRef<[u8]>> TensorKey { - pub fn generate_from_sk>( - &mut self, - module: &Module, - sk: &GLWESecret, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - scratch: &mut Scratch, - ) { - #[cfg(debug_assertions)] - { - assert_eq!(self.rank(), sk.rank()); - assert_eq!(self.n(), module.n()); - assert_eq!(sk.n(), module.n()); - } - - let rank: usize = self.rank(); - - let (mut sk_ij, scratch1) = scratch.tmp_sk(module, 1); - - (0..rank).for_each(|i| { - (i..rank).for_each(|j| { - module.svp_apply( - &mut sk_ij.data_fourier, - 0, - &sk.data_fourier, - i, - &sk.data_fourier, - j, - ); - module.svp_idft(&mut sk_ij.data, 0, &sk_ij.data_fourier, 0, scratch1); - self.at_mut(i, j) - .generate_from_sk(module, &sk_ij, sk, source_xa, source_xe, sigma, scratch1); - }); - }) - } - +impl + AsRef<[u8]>> GLWETensorKey { // Returns a mutable reference to GLWESwitchingKey_{s}(s[i] * s[j]) pub fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GLWESwitchingKey { if i > j { @@ -113,7 +69,7 @@ impl + AsRef<[u8]>> TensorKey { } } -impl> TensorKey { +impl> GLWETensorKey { // Returns a reference to GLWESwitchingKey_{s}(s[i] * s[j]) pub fn at(&self, mut i: usize, mut j: usize) -> &GLWESwitchingKey { if i > j { diff --git a/core/src/test_fft64/automorphism_key.rs b/core/src/gglwe/test_fft64/automorphism_key.rs similarity index 76% rename from core/src/test_fft64/automorphism_key.rs rename to core/src/gglwe/test_fft64/automorphism_key.rs index 93cd5d6..c06f0be 100644 --- a/core/src/test_fft64/automorphism_key.rs +++ b/core/src/gglwe/test_fft64/automorphism_key.rs @@ -2,8 +2,8 @@ use backend::{FFT64, Module, ScalarZnxOps, ScratchOwned, Stats, VecZnxOps}; use sampling::source::Source; use crate::{ - AutomorphismKey, FourierGLWECiphertext, GLWEPlaintext, GLWESecret, GetRow, Infos, div_ceil, - test_fft64::log2_std_noise_gglwe_product, + FourierGLWECiphertext, FourierGLWESecret, GLWEAutomorphismKey, GLWEPlaintext, GLWESecret, GetRow, Infos, div_ceil, + noise::log2_std_noise_gglwe_product, }; #[test] @@ -58,24 +58,25 @@ fn test_automorphism( let rows_in: usize = k_in / (basek * digits); let rows_apply: usize = k_in.div_ceil(basek * digits); - let mut auto_key_in: AutomorphismKey, FFT64> = AutomorphismKey::alloc(&module, basek, k_in, rows_in, digits_in, rank); - let mut auto_key_out: AutomorphismKey, FFT64> = - AutomorphismKey::alloc(&module, basek, k_out, rows_in, digits_in, rank); - let mut auto_key_apply: AutomorphismKey, FFT64> = - AutomorphismKey::alloc(&module, basek, k_apply, rows_apply, digits, rank); + let mut auto_key_in: GLWEAutomorphismKey, FFT64> = + GLWEAutomorphismKey::alloc(&module, basek, k_in, rows_in, digits_in, rank); + let mut auto_key_out: GLWEAutomorphismKey, FFT64> = + GLWEAutomorphismKey::alloc(&module, basek, k_out, rows_in, digits_in, rank); + let mut auto_key_apply: GLWEAutomorphismKey, FFT64> = + GLWEAutomorphismKey::alloc(&module, basek, k_apply, rows_apply, digits, rank); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - AutomorphismKey::generate_from_sk_scratch_space(&module, basek, k_apply, rank) + GLWEAutomorphismKey::generate_from_sk_scratch_space(&module, basek, k_apply, rank) | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_out) - | AutomorphismKey::automorphism_scratch_space(&module, basek, k_out, k_in, k_apply, digits, rank), + | GLWEAutomorphismKey::automorphism_scratch_space(&module, basek, k_out, k_in, k_apply, digits, rank), ); - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); // gglwe_{s1}(s0) = s0 -> s1 auto_key_in.generate_from_sk( @@ -105,7 +106,7 @@ fn test_automorphism( let mut ct_glwe_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_out, rank); let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); - let mut sk_auto: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); + let mut sk_auto: GLWESecret> = GLWESecret::alloc(&module, rank); sk_auto.fill_zero(); // Necessary to avoid panic of unfilled sk (0..rank).for_each(|i| { module.scalar_znx_automorphism( @@ -117,12 +118,12 @@ fn test_automorphism( ); }); - sk_auto.prep_fourier(&module); + let sk_auto_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_auto); (0..auto_key_out.rank_in()).for_each(|col_i| { (0..auto_key_out.rows()).for_each(|row_i| { auto_key_out.get_row(&module, row_i, col_i, &mut ct_glwe_dft); - ct_glwe_dft.decrypt(&module, &mut pt, &sk_auto, scratch.borrow()); + ct_glwe_dft.decrypt(&module, &mut pt, &sk_auto_dft, scratch.borrow()); module.vec_znx_sub_scalar_inplace( &mut pt.data, @@ -174,22 +175,23 @@ fn test_automorphism_inplace( let rows_in: usize = k_in / (basek * digits); let rows_apply: usize = k_in.div_ceil(basek * digits); - let mut auto_key: AutomorphismKey, FFT64> = AutomorphismKey::alloc(&module, basek, k_in, rows_in, digits_in, rank); - let mut auto_key_apply: AutomorphismKey, FFT64> = - AutomorphismKey::alloc(&module, basek, k_apply, rows_apply, digits, rank); + let mut auto_key: GLWEAutomorphismKey, FFT64> = + GLWEAutomorphismKey::alloc(&module, basek, k_in, rows_in, digits_in, rank); + let mut auto_key_apply: GLWEAutomorphismKey, FFT64> = + GLWEAutomorphismKey::alloc(&module, basek, k_apply, rows_apply, digits, rank); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - AutomorphismKey::generate_from_sk_scratch_space(&module, basek, k_apply, rank) + GLWEAutomorphismKey::generate_from_sk_scratch_space(&module, basek, k_apply, rank) | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_in) - | AutomorphismKey::automorphism_inplace_scratch_space(&module, basek, k_in, k_apply, digits, rank), + | GLWEAutomorphismKey::automorphism_inplace_scratch_space(&module, basek, k_in, k_apply, digits, rank), ); - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); // gglwe_{s1}(s0) = s0 -> s1 auto_key.generate_from_sk( @@ -219,8 +221,9 @@ fn test_automorphism_inplace( let mut ct_glwe_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_in, rank); let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_in); - let mut sk_auto: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); + let mut sk_auto: GLWESecret> = GLWESecret::alloc(&module, rank); sk_auto.fill_zero(); // Necessary to avoid panic of unfilled sk + (0..rank).for_each(|i| { module.scalar_znx_automorphism( module.galois_element_inv(p0 * p1), @@ -231,13 +234,13 @@ fn test_automorphism_inplace( ); }); - sk_auto.prep_fourier(&module); + let sk_auto_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_auto); (0..auto_key.rank_in()).for_each(|col_i| { (0..auto_key.rows()).for_each(|row_i| { auto_key.get_row(&module, row_i, col_i, &mut ct_glwe_dft); - ct_glwe_dft.decrypt(&module, &mut pt, &sk_auto, scratch.borrow()); + ct_glwe_dft.decrypt(&module, &mut pt, &sk_auto_dft, scratch.borrow()); module.vec_znx_sub_scalar_inplace( &mut pt.data, 0, diff --git a/core/src/test_fft64/gglwe.rs b/core/src/gglwe/test_fft64/gglwe.rs similarity index 86% rename from core/src/test_fft64/gglwe.rs rename to core/src/gglwe/test_fft64/gglwe.rs index bde6308..39aad9f 100644 --- a/core/src/test_fft64/gglwe.rs +++ b/core/src/gglwe/test_fft64/gglwe.rs @@ -2,8 +2,9 @@ use backend::{FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxToMut, ScratchO use sampling::source::Source; use crate::{ - FourierGLWECiphertext, GGSWCiphertext, GLWEPlaintext, GLWESecret, GLWESwitchingKey, GetRow, Infos, div_ceil, - test_fft64::{log2_std_noise_gglwe_product, noise_ggsw_product}, + FourierGLWECiphertext, FourierGLWESecret, GGSWCiphertext, GLWEPlaintext, GLWESecret, GLWESwitchingKey, GetRow, Infos, + div_ceil, + noise::{log2_std_noise_gglwe_product, noise_ggsw_product}, }; #[test] @@ -148,16 +149,17 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k_ksk: usize, digits: usize, rank | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_ksk), ); - let mut sk_in: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_in); - sk_in.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk_in: GLWESecret> = GLWESecret::alloc(&module, rank_in); + sk_in.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_out: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_out); - sk_out.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk_out: GLWESecret> = GLWESecret::alloc(&module, rank_out); + sk_out.fill_ternary_prob(0.5, &mut source_xs); + let sk_out_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_out); ksk.generate_from_sk( &module, &sk_in, - &sk_out, + &sk_out_dft, &mut source_xa, &mut source_xe, sigma, @@ -170,7 +172,7 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k_ksk: usize, digits: usize, rank (0..ksk.rank_in()).for_each(|col_i| { (0..ksk.rows()).for_each(|row_i| { ksk.get_row(&module, row_i, col_i, &mut ct_glwe_fourier); - ct_glwe_fourier.decrypt(&module, &mut pt, &sk_out, scratch.borrow()); + ct_glwe_fourier.decrypt(&module, &mut pt, &sk_out_dft, scratch.borrow()); module.vec_znx_sub_scalar_inplace( &mut pt.data, 0, @@ -247,20 +249,22 @@ fn test_key_switch( ), ); - let mut sk0: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_in_s0s1); - sk0.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk0: GLWESecret> = GLWESecret::alloc(&module, rank_in_s0s1); + sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_out_s0s1); - sk1.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk1: GLWESecret> = GLWESecret::alloc(&module, rank_out_s0s1); + sk1.fill_ternary_prob(0.5, &mut source_xs); + let sk1_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk1); - let mut sk2: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_out_s1s2); - sk2.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk2: GLWESecret> = GLWESecret::alloc(&module, rank_out_s1s2); + sk2.fill_ternary_prob(0.5, &mut source_xs); + let sk2_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk2); // gglwe_{s1}(s0) = s0 -> s1 ct_gglwe_s0s1.generate_from_sk( &module, &sk0, - &sk1, + &sk1_dft, &mut source_xa, &mut source_xe, sigma, @@ -271,7 +275,7 @@ fn test_key_switch( ct_gglwe_s1s2.generate_from_sk( &module, &sk1, - &sk2, + &sk2_dft, &mut source_xa, &mut source_xe, sigma, @@ -288,7 +292,7 @@ fn test_key_switch( (0..ct_gglwe_s0s2.rank_in()).for_each(|col_i| { (0..ct_gglwe_s0s2.rows()).for_each(|row_i| { ct_gglwe_s0s2.get_row(&module, row_i, col_i, &mut ct_glwe_dft); - ct_glwe_dft.decrypt(&module, &mut pt, &sk2, scratch.borrow()); + ct_glwe_dft.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow()); module.vec_znx_sub_scalar_inplace( &mut pt.data, 0, @@ -350,20 +354,22 @@ fn test_key_switch_inplace( | GLWESwitchingKey::keyswitch_inplace_scratch_space(&module, basek, k_ct, k_ksk, digits, rank_out), ); - let mut sk0: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_in); - sk0.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk0: GLWESecret> = GLWESecret::alloc(&module, rank_in); + sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_out); - sk1.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk1: GLWESecret> = GLWESecret::alloc(&module, rank_out); + sk1.fill_ternary_prob(0.5, &mut source_xs); + let sk1_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk1); - let mut sk2: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_out); - sk2.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk2: GLWESecret> = GLWESecret::alloc(&module, rank_out); + sk2.fill_ternary_prob(0.5, &mut source_xs); + let sk2_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk2); // gglwe_{s1}(s0) = s0 -> s1 ct_gglwe_s0s1.generate_from_sk( &module, &sk0, - &sk1, + &sk1_dft, &mut source_xa, &mut source_xe, sigma, @@ -374,7 +380,7 @@ fn test_key_switch_inplace( ct_gglwe_s1s2.generate_from_sk( &module, &sk1, - &sk2, + &sk2_dft, &mut source_xa, &mut source_xe, sigma, @@ -392,7 +398,7 @@ fn test_key_switch_inplace( (0..ct_gglwe_s0s2.rank_in()).for_each(|col_i| { (0..ct_gglwe_s0s2.rows()).for_each(|row_i| { ct_gglwe_s0s2.get_row(&module, row_i, col_i, &mut ct_glwe_dft); - ct_glwe_dft.decrypt(&module, &mut pt, &sk2, scratch.borrow()); + ct_glwe_dft.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow()); module.vec_znx_sub_scalar_inplace( &mut pt.data, 0, @@ -464,17 +470,18 @@ fn test_external_product( pt_rgsw.to_mut().raw_mut()[r] = 1; // X^{r} - let mut sk_in: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_in); - sk_in.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk_in: GLWESecret> = GLWESecret::alloc(&module, rank_in); + sk_in.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_out: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_out); - sk_out.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk_out: GLWESecret> = GLWESecret::alloc(&module, rank_out); + sk_out.fill_ternary_prob(0.5, &mut source_xs); + let sk_out_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_out); // gglwe_{s1}(s0) = s0 -> s1 ct_gglwe_in.generate_from_sk( &module, &sk_in, - &sk_out, + &sk_out_dft, &mut source_xa, &mut source_xe, sigma, @@ -484,7 +491,7 @@ fn test_external_product( ct_rgsw.encrypt_sk( &module, &pt_rgsw, - &sk_out, + &sk_out_dft, &mut source_xa, &mut source_xe, sigma, @@ -504,7 +511,7 @@ fn test_external_product( (0..rank_in).for_each(|col_i| { (0..ct_gglwe_out.rows()).for_each(|row_i| { ct_gglwe_out.get_row(&module, row_i, col_i, &mut ct_glwe_dft); - ct_glwe_dft.decrypt(&module, &mut pt, &sk_out, scratch.borrow()); + ct_glwe_dft.decrypt(&module, &mut pt, &sk_out_dft, scratch.borrow()); module.vec_znx_sub_scalar_inplace( &mut pt.data, @@ -584,17 +591,18 @@ fn test_external_product_inplace( pt_rgsw.to_mut().raw_mut()[r] = 1; // X^{r} - let mut sk_in: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_in); - sk_in.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk_in: GLWESecret> = GLWESecret::alloc(&module, rank_in); + sk_in.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_out: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_out); - sk_out.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk_out: GLWESecret> = GLWESecret::alloc(&module, rank_out); + sk_out.fill_ternary_prob(0.5, &mut source_xs); + let sk_out_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_out); // gglwe_{s1}(s0) = s0 -> s1 ct_gglwe.generate_from_sk( &module, &sk_in, - &sk_out, + &sk_out_dft, &mut source_xa, &mut source_xe, sigma, @@ -604,7 +612,7 @@ fn test_external_product_inplace( ct_rgsw.encrypt_sk( &module, &pt_rgsw, - &sk_out, + &sk_out_dft, &mut source_xa, &mut source_xe, sigma, @@ -624,7 +632,7 @@ fn test_external_product_inplace( (0..rank_in).for_each(|col_i| { (0..ct_gglwe.rows()).for_each(|row_i| { ct_gglwe.get_row(&module, row_i, col_i, &mut ct_glwe_dft); - ct_glwe_dft.decrypt(&module, &mut pt, &sk_out, scratch.borrow()); + ct_glwe_dft.decrypt(&module, &mut pt, &sk_out_dft, scratch.borrow()); module.vec_znx_sub_scalar_inplace( &mut pt.data, diff --git a/core/src/gglwe/test_fft64/mod.rs b/core/src/gglwe/test_fft64/mod.rs new file mode 100644 index 0000000..49d23cd --- /dev/null +++ b/core/src/gglwe/test_fft64/mod.rs @@ -0,0 +1,3 @@ +pub mod automorphism_key; +pub mod gglwe; +pub mod tensor_key; diff --git a/core/src/test_fft64/tensor_key.rs b/core/src/gglwe/test_fft64/tensor_key.rs similarity index 65% rename from core/src/test_fft64/tensor_key.rs rename to core/src/gglwe/test_fft64/tensor_key.rs index e9827cb..be69625 100644 --- a/core/src/test_fft64/tensor_key.rs +++ b/core/src/gglwe/test_fft64/tensor_key.rs @@ -1,7 +1,7 @@ use backend::{FFT64, Module, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxOps}; use sampling::source::Source; -use crate::{FourierGLWECiphertext, GLWEPlaintext, GLWESecret, GetRow, Infos, TensorKey}; +use crate::{FourierGLWECiphertext, FourierGLWESecret, GLWEPlaintext, GLWESecret, GLWETensorKey, GetRow, Infos}; #[test] fn encrypt_sk() { @@ -17,25 +17,26 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k: usize, sigma: f64, rank: usize let rows: usize = k / basek; - let mut tensor_key: TensorKey, FFT64> = TensorKey::alloc(&module, basek, k, rows, 1, rank); + let mut tensor_key: GLWETensorKey, FFT64> = GLWETensorKey::alloc(&module, basek, k, rows, 1, rank); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::new(TensorKey::generate_from_sk_scratch_space( + let mut scratch: ScratchOwned = ScratchOwned::new(GLWETensorKey::generate_from_sk_scratch_space( &module, basek, tensor_key.k(), rank, )); - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); tensor_key.generate_from_sk( &module, - &sk, + &sk_dft, &mut source_xa, &mut source_xe, sigma, @@ -45,25 +46,19 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k: usize, sigma: f64, rank: usize let mut ct_glwe_fourier: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k, rank); let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k); - let mut sk_ij = GLWESecret::alloc(&module, 1); + let mut sk_ij_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::alloc(&module, 1); + let mut sk_ij: GLWESecret> = GLWESecret::alloc(&module, 1); (0..rank).for_each(|i| { (0..rank).for_each(|j| { - module.svp_apply( - &mut sk_ij.data_fourier, - 0, - &sk.data_fourier, - i, - &sk.data_fourier, - j, - ); - module.svp_idft(&mut sk_ij.data, 0, &sk_ij.data_fourier, 0, scratch.borrow()); + module.svp_apply(&mut sk_ij_dft.data, 0, &sk_dft.data, i, &sk_dft.data, j); + module.scalar_znx_idft(&mut sk_ij.data, 0, &sk_ij_dft.data, 0, scratch.borrow()); (0..tensor_key.rank_in()).for_each(|col_i| { (0..tensor_key.rows()).for_each(|row_i| { tensor_key .at(i, j) .get_row(&module, row_i, col_i, &mut ct_glwe_fourier); - ct_glwe_fourier.decrypt(&module, &mut pt, &sk, scratch.borrow()); + ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); module.vec_znx_sub_scalar_inplace(&mut pt.data, 0, row_i, &sk_ij.data, col_i); let std_pt: f64 = pt.data.std(0, basek) * (k as f64).exp2(); assert!((sigma - std_pt).abs() <= 0.5, "{} {}", sigma, std_pt); diff --git a/core/src/ggsw.rs b/core/src/ggsw/ciphertext.rs similarity index 97% rename from core/src/ggsw.rs rename to core/src/ggsw/ciphertext.rs index 38842df..ac1fae5 100644 --- a/core/src/ggsw.rs +++ b/core/src/ggsw/ciphertext.rs @@ -6,8 +6,8 @@ use backend::{ use sampling::source::Source; use crate::{ - AutomorphismKey, FourierGLWECiphertext, GLWECiphertext, GLWESecret, GLWESwitchingKey, GetRow, Infos, ScratchCore, SetRow, - TensorKey, div_ceil, + FourierGLWECiphertext, FourierGLWESecret, GLWEAutomorphismKey, GLWECiphertext, GLWESwitchingKey, GLWETensorKey, GetRow, + Infos, ScratchCore, SetRow, div_ceil, }; pub struct GGSWCiphertext { @@ -246,7 +246,7 @@ impl + AsRef<[u8]>> GGSWCiphertext { &mut self, module: &Module, pt: &ScalarZnx, - sk: &GLWESecret, + sk: &FourierGLWESecret, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, @@ -304,7 +304,7 @@ impl + AsRef<[u8]>> GGSWCiphertext { col_j: usize, res: &mut R, ci_dft: &VecZnxDft, - tsk: &TensorKey, + tsk: &GLWETensorKey, scratch: &mut Scratch, ) where R: VecZnxToMut, @@ -408,7 +408,7 @@ impl + AsRef<[u8]>> GGSWCiphertext { module: &Module, lhs: &GGSWCiphertext, ksk: &GLWESwitchingKey, - tsk: &TensorKey, + tsk: &GLWETensorKey, scratch: &mut Scratch, ) { let rank: usize = self.rank(); @@ -449,7 +449,7 @@ impl + AsRef<[u8]>> GGSWCiphertext { &mut self, module: &Module, ksk: &GLWESwitchingKey, - tsk: &TensorKey, + tsk: &GLWETensorKey, scratch: &mut Scratch, ) { unsafe { @@ -462,8 +462,8 @@ impl + AsRef<[u8]>> GGSWCiphertext { &mut self, module: &Module, lhs: &GGSWCiphertext, - auto_key: &AutomorphismKey, - tensor_key: &TensorKey, + auto_key: &GLWEAutomorphismKey, + tensor_key: &GLWETensorKey, scratch: &mut Scratch, ) { #[cfg(debug_assertions)] @@ -551,8 +551,8 @@ impl + AsRef<[u8]>> GGSWCiphertext { pub fn automorphism_inplace, DataTsk: AsRef<[u8]>>( &mut self, module: &Module, - auto_key: &AutomorphismKey, - tensor_key: &TensorKey, + auto_key: &GLWEAutomorphismKey, + tensor_key: &GLWETensorKey, scratch: &mut Scratch, ) { unsafe { diff --git a/core/src/ggsw/mod.rs b/core/src/ggsw/mod.rs new file mode 100644 index 0000000..f27b96b --- /dev/null +++ b/core/src/ggsw/mod.rs @@ -0,0 +1,6 @@ +pub mod ciphertext; + +pub use ciphertext::GGSWCiphertext; + +#[cfg(test)] +mod test_fft64; diff --git a/core/src/test_fft64/ggsw.rs b/core/src/ggsw/test_fft64/ggsw.rs similarity index 86% rename from core/src/test_fft64/ggsw.rs rename to core/src/ggsw/test_fft64/ggsw.rs index bcfb950..9219703 100644 --- a/core/src/test_fft64/ggsw.rs +++ b/core/src/ggsw/test_fft64/ggsw.rs @@ -5,9 +5,9 @@ use backend::{ use sampling::source::Source; use crate::{ - FourierGLWECiphertext, GGSWCiphertext, GLWEPlaintext, GLWESecret, GLWESwitchingKey, GetRow, Infos, TensorKey, - automorphism::AutomorphismKey, - test_fft64::{noise_ggsw_keyswitch, noise_ggsw_product}, + FourierGLWECiphertext, FourierGLWESecret, GGSWCiphertext, GLWEAutomorphismKey, GLWEPlaintext, GLWESecret, GLWESwitchingKey, + GLWETensorKey, GetRow, Infos, div_ceil, + noise::{noise_ggsw_keyswitch, noise_ggsw_product}, }; #[test] @@ -142,13 +142,14 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k: usize, digits: usize, rank: us | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k), ); - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); ct.encrypt_sk( &module, &pt_scalar, - &sk, + &sk_dft, &mut source_xa, &mut source_xe, sigma, @@ -172,14 +173,14 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k: usize, digits: usize, rank: us // mul with sk[col_j-1] if col_j > 0 { module.vec_znx_dft(1, 0, &mut pt_dft, 0, &pt_want.data, 0); - module.svp_apply_inplace(&mut pt_dft, 0, &sk.data_fourier, col_j - 1); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft.data, col_j - 1); module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); module.vec_znx_big_normalize(basek, &mut pt_want.data, 0, &pt_big, 0, scratch.borrow()); } ct.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); - ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk, scratch.borrow()); + ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); @@ -209,7 +210,7 @@ fn test_keyswitch( let mut ct_in: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_in, rows, digits_in, rank); let mut ct_out: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_out, rows, digits_in, rank); - let mut tsk: TensorKey, FFT64> = TensorKey::alloc(&module, basek, k_ksk, rows, digits, rank); + let mut tsk: GLWETensorKey, FFT64> = GLWETensorKey::alloc(&module, basek, k_ksk, rows, digits, rank); let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank, rank); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); @@ -223,7 +224,7 @@ fn test_keyswitch( GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_in, rank) | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_out) | GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank) - | TensorKey::generate_from_sk_scratch_space(&module, basek, k_tsk, rank) + | GLWETensorKey::generate_from_sk_scratch_space(&module, basek, k_tsk, rank) | GGSWCiphertext::keyswitch_scratch_space( &module, basek, k_out, k_in, k_ksk, digits, k_tsk, digits, rank, ), @@ -231,16 +232,18 @@ fn test_keyswitch( let var_xs: f64 = 0.5; - let mut sk_in: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk_in.fill_ternary_prob(&module, var_xs, &mut source_xs); + let mut sk_in: GLWESecret> = GLWESecret::alloc(&module, rank); + sk_in.fill_ternary_prob(var_xs, &mut source_xs); + let sk_in_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_in); - let mut sk_out: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk_out.fill_ternary_prob(&module, var_xs, &mut source_xs); + let mut sk_out: GLWESecret> = GLWESecret::alloc(&module, rank); + sk_out.fill_ternary_prob(var_xs, &mut source_xs); + let sk_out_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_out); ksk.generate_from_sk( &module, &sk_in, - &sk_out, + &sk_out_dft, &mut source_xa, &mut source_xe, sigma, @@ -248,7 +251,7 @@ fn test_keyswitch( ); tsk.generate_from_sk( &module, - &sk_out, + &sk_out_dft, &mut source_xa, &mut source_xe, sigma, @@ -260,7 +263,7 @@ fn test_keyswitch( ct_in.encrypt_sk( &module, &pt_scalar, - &sk_in, + &sk_in_dft, &mut source_xa, &mut source_xe, sigma, @@ -280,14 +283,14 @@ fn test_keyswitch( // mul with sk[col_j-1] if col_j > 0 { module.vec_znx_dft(1, 0, &mut pt_dft, 0, &pt_want.data, 0); - module.svp_apply_inplace(&mut pt_dft, 0, &sk_out.data_fourier, col_j - 1); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_out_dft.data, col_j - 1); module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); module.vec_znx_big_normalize(basek, &mut pt_want.data, 0, &pt_big, 0, scratch.borrow()); } ct_out.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); - ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_out, scratch.borrow()); + ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow()); module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); @@ -336,7 +339,7 @@ fn test_keyswitch_inplace( let digits_in: usize = 1; let mut ct: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ct, rows, digits_in, rank); - let mut tsk: TensorKey, FFT64> = TensorKey::alloc(&module, basek, k_tsk, rows, digits, rank); + let mut tsk: GLWETensorKey, FFT64> = GLWETensorKey::alloc(&module, basek, k_tsk, rows, digits, rank); let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank, rank); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); @@ -350,22 +353,24 @@ fn test_keyswitch_inplace( GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_ct, rank) | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_ct) | GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank) - | TensorKey::generate_from_sk_scratch_space(&module, basek, k_tsk, rank) + | GLWETensorKey::generate_from_sk_scratch_space(&module, basek, k_tsk, rank) | GGSWCiphertext::keyswitch_inplace_scratch_space(&module, basek, k_ct, k_ksk, digits, k_tsk, digits, rank), ); let var_xs: f64 = 0.5; - let mut sk_in: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk_in.fill_ternary_prob(&module, var_xs, &mut source_xs); + let mut sk_in: GLWESecret> = GLWESecret::alloc(&module, rank); + sk_in.fill_ternary_prob(var_xs, &mut source_xs); + let sk_in_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_in); - let mut sk_out: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk_out.fill_ternary_prob(&module, var_xs, &mut source_xs); + let mut sk_out: GLWESecret> = GLWESecret::alloc(&module, rank); + sk_out.fill_ternary_prob(var_xs, &mut source_xs); + let sk_out_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_out); ksk.generate_from_sk( &module, &sk_in, - &sk_out, + &sk_out_dft, &mut source_xa, &mut source_xe, sigma, @@ -373,7 +378,7 @@ fn test_keyswitch_inplace( ); tsk.generate_from_sk( &module, - &sk_out, + &sk_out_dft, &mut source_xa, &mut source_xe, sigma, @@ -385,7 +390,7 @@ fn test_keyswitch_inplace( ct.encrypt_sk( &module, &pt_scalar, - &sk_in, + &sk_in_dft, &mut source_xa, &mut source_xe, sigma, @@ -411,14 +416,14 @@ fn test_keyswitch_inplace( // mul with sk[col_j-1] if col_j > 0 { module.vec_znx_dft(1, 0, &mut pt_dft, 0, &pt_want.data, 0); - module.svp_apply_inplace(&mut pt_dft, 0, &sk_out.data_fourier, col_j - 1); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_out_dft.data, col_j - 1); module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); module.vec_znx_big_normalize(basek, &mut pt_want.data, 0, &pt_big, 0, scratch.borrow()); } ct.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); - ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_out, scratch.borrow()); + ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow()); module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); @@ -471,8 +476,8 @@ fn test_automorphism( let mut ct_in: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_in, rows_in, digits_in, rank); let mut ct_out: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_out, rows_in, digits_in, rank); - let mut tensor_key: TensorKey, FFT64> = TensorKey::alloc(&module, basek, k_tsk, rows, digits, rank); - let mut auto_key: AutomorphismKey, FFT64> = AutomorphismKey::alloc(&module, basek, k_ksk, rows, digits, rank); + let mut tensor_key: GLWETensorKey, FFT64> = GLWETensorKey::alloc(&module, basek, k_tsk, rows, digits, rank); + let mut auto_key: GLWEAutomorphismKey, FFT64> = GLWEAutomorphismKey::alloc(&module, basek, k_ksk, rows, digits, rank); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); @@ -484,8 +489,8 @@ fn test_automorphism( let mut scratch: ScratchOwned = ScratchOwned::new( GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_in, rank) | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_out) - | AutomorphismKey::generate_from_sk_scratch_space(&module, basek, k_ksk, rank) - | TensorKey::generate_from_sk_scratch_space(&module, basek, k_tsk, rank) + | GLWEAutomorphismKey::generate_from_sk_scratch_space(&module, basek, k_ksk, rank) + | GLWETensorKey::generate_from_sk_scratch_space(&module, basek, k_tsk, rank) | GGSWCiphertext::automorphism_scratch_space( &module, basek, k_out, k_in, k_ksk, digits, k_tsk, digits, rank, ), @@ -493,8 +498,9 @@ fn test_automorphism( let var_xs: f64 = 0.5; - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, var_xs, &mut source_xs); + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(var_xs, &mut source_xs); + let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); auto_key.generate_from_sk( &module, @@ -507,7 +513,7 @@ fn test_automorphism( ); tensor_key.generate_from_sk( &module, - &sk, + &sk_dft, &mut source_xa, &mut source_xe, sigma, @@ -519,7 +525,7 @@ fn test_automorphism( ct_in.encrypt_sk( &module, &pt_scalar, - &sk, + &sk_dft, &mut source_xa, &mut source_xe, sigma, @@ -541,14 +547,14 @@ fn test_automorphism( // mul with sk[col_j-1] if col_j > 0 { module.vec_znx_dft(1, 0, &mut pt_dft, 0, &pt_want.data, 0); - module.svp_apply_inplace(&mut pt_dft, 0, &sk.data_fourier, col_j - 1); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft.data, col_j - 1); module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); module.vec_znx_big_normalize(basek, &mut pt_want.data, 0, &pt_big, 0, scratch.borrow()); } ct_out.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); - ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk, scratch.borrow()); + ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); @@ -596,8 +602,8 @@ fn test_automorphism_inplace( let digits_in: usize = 1; let mut ct: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ct, rows_in, digits_in, rank); - let mut tensor_key: TensorKey, FFT64> = TensorKey::alloc(&module, basek, k_tsk, rows, digits, rank); - let mut auto_key: AutomorphismKey, FFT64> = AutomorphismKey::alloc(&module, basek, k_ksk, rows, digits, rank); + let mut tensor_key: GLWETensorKey, FFT64> = GLWETensorKey::alloc(&module, basek, k_tsk, rows, digits, rank); + let mut auto_key: GLWEAutomorphismKey, FFT64> = GLWEAutomorphismKey::alloc(&module, basek, k_ksk, rows, digits, rank); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); @@ -609,15 +615,16 @@ fn test_automorphism_inplace( let mut scratch: ScratchOwned = ScratchOwned::new( GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_ct, rank) | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_ct) - | AutomorphismKey::generate_from_sk_scratch_space(&module, basek, k_ksk, rank) - | TensorKey::generate_from_sk_scratch_space(&module, basek, k_tsk, rank) + | GLWEAutomorphismKey::generate_from_sk_scratch_space(&module, basek, k_ksk, rank) + | GLWETensorKey::generate_from_sk_scratch_space(&module, basek, k_tsk, rank) | GGSWCiphertext::automorphism_inplace_scratch_space(&module, basek, k_ct, k_ksk, digits, k_tsk, digits, rank), ); let var_xs: f64 = 0.5; - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, var_xs, &mut source_xs); + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(var_xs, &mut source_xs); + let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); auto_key.generate_from_sk( &module, @@ -630,7 +637,7 @@ fn test_automorphism_inplace( ); tensor_key.generate_from_sk( &module, - &sk, + &sk_dft, &mut source_xa, &mut source_xe, sigma, @@ -642,7 +649,7 @@ fn test_automorphism_inplace( ct.encrypt_sk( &module, &pt_scalar, - &sk, + &sk_dft, &mut source_xa, &mut source_xe, sigma, @@ -664,14 +671,14 @@ fn test_automorphism_inplace( // mul with sk[col_j-1] if col_j > 0 { module.vec_znx_dft(1, 0, &mut pt_dft, 0, &pt_want.data, 0); - module.svp_apply_inplace(&mut pt_dft, 0, &sk.data_fourier, col_j - 1); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft.data, col_j - 1); module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); module.vec_znx_big_normalize(basek, &mut pt_want.data, 0, &pt_big, 0, scratch.borrow()); } ct.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); - ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk, scratch.borrow()); + ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); @@ -742,13 +749,14 @@ fn test_external_product( | GGSWCiphertext::external_product_scratch_space(&module, basek, k_out, k_in, k_ggsw, digits, rank), ); - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); ct_ggsw_rhs.encrypt_sk( &module, &pt_ggsw_rhs, - &sk, + &sk_dft, &mut source_xa, &mut source_xe, sigma, @@ -758,7 +766,7 @@ fn test_external_product( ct_ggsw_lhs_in.encrypt_sk( &module, &pt_ggsw_lhs, - &sk, + &sk_dft, &mut source_xa, &mut source_xe, sigma, @@ -787,13 +795,13 @@ fn test_external_product( if col_j > 0 { module.vec_znx_dft(1, 0, &mut pt_dft, 0, &pt_want.data, 0); - module.svp_apply_inplace(&mut pt_dft, 0, &sk.data_fourier, col_j - 1); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft.data, col_j - 1); module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); module.vec_znx_big_normalize(basek, &mut pt_want.data, 0, &pt_big, 0, scratch.borrow()); } ct_ggsw_lhs_out.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); - ct_glwe_fourier.decrypt(&module, &mut pt, &sk, scratch.borrow()); + ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); module.vec_znx_sub_ab_inplace(&mut pt.data, 0, &pt_want.data, 0); @@ -862,13 +870,14 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ct: usize, k_ggsw | GGSWCiphertext::external_product_inplace_scratch_space(&module, basek, k_ct, k_ggsw, digits, rank), ); - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); ct_ggsw_rhs.encrypt_sk( &module, &pt_ggsw_rhs, - &sk, + &sk_dft, &mut source_xa, &mut source_xe, sigma, @@ -878,7 +887,7 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ct: usize, k_ggsw ct_ggsw_lhs.encrypt_sk( &module, &pt_ggsw_lhs, - &sk, + &sk_dft, &mut source_xa, &mut source_xe, sigma, @@ -907,13 +916,13 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ct: usize, k_ggsw if col_j > 0 { module.vec_znx_dft(1, 0, &mut pt_dft, 0, &pt_want.data, 0); - module.svp_apply_inplace(&mut pt_dft, 0, &sk.data_fourier, col_j - 1); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft.data, col_j - 1); module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); module.vec_znx_big_normalize(basek, &mut pt_want.data, 0, &pt_big, 0, scratch.borrow()); } ct_ggsw_lhs.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); - ct_glwe_fourier.decrypt(&module, &mut pt, &sk, scratch.borrow()); + ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); module.vec_znx_sub_ab_inplace(&mut pt.data, 0, &pt_want.data, 0); diff --git a/core/src/ggsw/test_fft64/mod.rs b/core/src/ggsw/test_fft64/mod.rs new file mode 100644 index 0000000..3326f10 --- /dev/null +++ b/core/src/ggsw/test_fft64/mod.rs @@ -0,0 +1 @@ +mod ggsw; diff --git a/core/src/glwe/automorphism.rs b/core/src/glwe/automorphism.rs index a4165aa..1513362 100644 --- a/core/src/glwe/automorphism.rs +++ b/core/src/glwe/automorphism.rs @@ -1,6 +1,6 @@ use backend::{FFT64, Module, Scratch, VecZnxOps}; -use crate::{AutomorphismKey, GLWECiphertext}; +use crate::{GLWEAutomorphismKey, GLWECiphertext}; impl GLWECiphertext> { pub fn automorphism_scratch_space( @@ -32,7 +32,7 @@ impl + AsMut<[u8]>> GLWECiphertext { &mut self, module: &Module, lhs: &GLWECiphertext, - rhs: &AutomorphismKey, + rhs: &GLWEAutomorphismKey, scratch: &mut Scratch, ) { self.keyswitch(module, lhs, &rhs.key, scratch); @@ -44,7 +44,7 @@ impl + AsMut<[u8]>> GLWECiphertext { pub fn automorphism_inplace>( &mut self, module: &Module, - rhs: &AutomorphismKey, + rhs: &GLWEAutomorphismKey, scratch: &mut Scratch, ) { self.keyswitch_inplace(module, &rhs.key, scratch); @@ -57,7 +57,7 @@ impl + AsMut<[u8]>> GLWECiphertext { &mut self, module: &Module, lhs: &GLWECiphertext, - rhs: &AutomorphismKey, + rhs: &GLWEAutomorphismKey, scratch: &mut Scratch, ) { Self::keyswitch_private::<_, _, 1>(self, rhs.p(), module, lhs, &rhs.key, scratch); @@ -66,7 +66,7 @@ impl + AsMut<[u8]>> GLWECiphertext { pub fn automorphism_add_inplace>( &mut self, module: &Module, - rhs: &AutomorphismKey, + rhs: &GLWEAutomorphismKey, scratch: &mut Scratch, ) { unsafe { @@ -79,7 +79,7 @@ impl + AsMut<[u8]>> GLWECiphertext { &mut self, module: &Module, lhs: &GLWECiphertext, - rhs: &AutomorphismKey, + rhs: &GLWEAutomorphismKey, scratch: &mut Scratch, ) { Self::keyswitch_private::<_, _, 2>(self, rhs.p(), module, lhs, &rhs.key, scratch); @@ -88,7 +88,7 @@ impl + AsMut<[u8]>> GLWECiphertext { pub fn automorphism_sub_ab_inplace>( &mut self, module: &Module, - rhs: &AutomorphismKey, + rhs: &GLWEAutomorphismKey, scratch: &mut Scratch, ) { unsafe { @@ -101,7 +101,7 @@ impl + AsMut<[u8]>> GLWECiphertext { &mut self, module: &Module, lhs: &GLWECiphertext, - rhs: &AutomorphismKey, + rhs: &GLWEAutomorphismKey, scratch: &mut Scratch, ) { Self::keyswitch_private::<_, _, 3>(self, rhs.p(), module, lhs, &rhs.key, scratch); @@ -110,7 +110,7 @@ impl + AsMut<[u8]>> GLWECiphertext { pub fn automorphism_sub_ba_inplace>( &mut self, module: &Module, - rhs: &AutomorphismKey, + rhs: &GLWEAutomorphismKey, scratch: &mut Scratch, ) { unsafe { diff --git a/core/src/glwe/decryption.rs b/core/src/glwe/decryption.rs index dd6428d..eac91d6 100644 --- a/core/src/glwe/decryption.rs +++ b/core/src/glwe/decryption.rs @@ -1,6 +1,6 @@ use backend::{FFT64, Module, ScalarZnxDftOps, Scratch, VecZnxBigOps, VecZnxDftOps, ZnxZero}; -use crate::{GLWECiphertext, GLWEPlaintext, GLWESecret, Infos}; +use crate::{FourierGLWESecret, GLWECiphertext, GLWEPlaintext, Infos}; impl> GLWECiphertext { pub fn clone(&self) -> GLWECiphertext> { @@ -15,7 +15,7 @@ impl> GLWECiphertext { &self, module: &Module, pt: &mut GLWEPlaintext, - sk: &GLWESecret, + sk: &FourierGLWESecret, scratch: &mut Scratch, ) { #[cfg(debug_assertions)] @@ -36,7 +36,7 @@ impl> GLWECiphertext { // ci_dft = DFT(a[i]) * DFT(s[i]) let (mut ci_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, self.size()); // TODO optimize size when pt << ct module.vec_znx_dft(1, 0, &mut ci_dft, 0, &self.data, i); - module.svp_apply_inplace(&mut ci_dft, 0, &sk.data_fourier, i - 1); + module.svp_apply_inplace(&mut ci_dft, 0, &sk.data, i - 1); let ci_big = module.vec_znx_idft_consume(ci_dft); // c0_big += a[i] * s[i] diff --git a/core/src/glwe/encryption.rs b/core/src/glwe/encryption.rs index 1910f98..3b70d99 100644 --- a/core/src/glwe/encryption.rs +++ b/core/src/glwe/encryption.rs @@ -4,7 +4,7 @@ use backend::{ }; use sampling::source::Source; -use crate::{GLWECiphertext, GLWEPlaintext, GLWEPublicKey, GLWESecret, Infos, SIX_SIGMA, div_ceil, keys::SecretDistribution}; +use crate::{FourierGLWESecret, GLWECiphertext, GLWEPlaintext, GLWEPublicKey, Infos, SIX_SIGMA, dist::Distribution, div_ceil}; impl GLWECiphertext> { pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize) -> usize { @@ -24,7 +24,7 @@ impl + AsMut<[u8]>> GLWECiphertext { &mut self, module: &Module, pt: &GLWEPlaintext, - sk: &GLWESecret, + sk: &FourierGLWESecret, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, @@ -44,7 +44,7 @@ impl + AsMut<[u8]>> GLWECiphertext { pub fn encrypt_zero_sk>( &mut self, module: &Module, - sk: &GLWESecret, + sk: &FourierGLWESecret, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, @@ -106,7 +106,7 @@ impl + AsMut<[u8]>> GLWECiphertext { &mut self, module: &Module, pt: Option<(&GLWEPlaintext, usize)>, - sk: &GLWESecret, + sk: &FourierGLWESecret, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, @@ -148,7 +148,7 @@ impl + AsMut<[u8]>> GLWECiphertext { // c[i] = norm(IDFT(DFT(c[i]) * DFT(s[i]))) module.vec_znx_dft(1, 0, &mut ci_dft, 0, &self.data, i); - module.svp_apply_inplace(&mut ci_dft, 0, &sk.data_fourier, i - 1); + module.svp_apply_inplace(&mut ci_dft, 0, &sk.data, i - 1); let ci_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(ci_dft); // use c[0] as buffer, which is overwritten later by the normalization step @@ -213,16 +213,16 @@ impl + AsMut<[u8]>> GLWECiphertext { { let (mut u, _) = scratch_1.tmp_scalar_znx(module, 1); match pk.dist { - SecretDistribution::NONE => panic!( + Distribution::NONE => panic!( "invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through \ Self::generate" ), - SecretDistribution::TernaryFixed(hw) => u.fill_ternary_hw(0, hw, source_xu), - SecretDistribution::TernaryProb(prob) => u.fill_ternary_prob(0, prob, source_xu), - SecretDistribution::BinaryFixed(hw) => u.fill_binary_hw(0, hw, source_xu), - SecretDistribution::BinaryProb(prob) => u.fill_binary_prob(0, prob, source_xu), - SecretDistribution::BinaryBlock(block_size) => u.fill_binary_block(0, block_size, source_xu), - SecretDistribution::ZERO => {} + Distribution::TernaryFixed(hw) => u.fill_ternary_hw(0, hw, source_xu), + Distribution::TernaryProb(prob) => u.fill_ternary_prob(0, prob, source_xu), + Distribution::BinaryFixed(hw) => u.fill_binary_hw(0, hw, source_xu), + Distribution::BinaryProb(prob) => u.fill_binary_prob(0, prob, source_xu), + Distribution::BinaryBlock(block_size) => u.fill_binary_block(0, block_size, source_xu), + Distribution::ZERO => {} } module.svp_prepare(&mut u_dft, 0, &u, 0); diff --git a/core/src/glwe/mod.rs b/core/src/glwe/mod.rs index 7f0c7aa..e3879cd 100644 --- a/core/src/glwe/mod.rs +++ b/core/src/glwe/mod.rs @@ -11,24 +11,13 @@ pub mod public_key; pub mod secret; pub mod trace; -#[allow(unused_imports)] -pub use automorphism::*; -pub use ciphertext::*; -#[allow(unused_imports)] -pub use decryption::*; -#[allow(unused_imports)] -pub use encryption::*; -#[allow(unused_imports)] -pub use external_product::*; -#[allow(unused_imports)] -pub use keyswitch::*; -pub use ops::*; -pub use packing::*; -pub use plaintext::*; -pub use public_key::*; -pub use secret::*; -#[allow(unused_imports)] -pub use trace::*; +pub use ciphertext::GLWECiphertext; +pub(crate) use ciphertext::{GLWECiphertextToMut, GLWECiphertextToRef}; +pub use ops::GLWEOps; +pub use packing::GLWEPacker; +pub use plaintext::GLWEPlaintext; +pub use public_key::GLWEPublicKey; +pub use secret::GLWESecret; #[cfg(test)] mod test_fft64; diff --git a/core/src/glwe/packing.rs b/core/src/glwe/packing.rs index 85aceb6..3496994 100644 --- a/core/src/glwe/packing.rs +++ b/core/src/glwe/packing.rs @@ -1,4 +1,4 @@ -use crate::{AutomorphismKey, GLWECiphertext, GLWEOps, Infos, ScratchCore}; +use crate::{GLWEAutomorphismKey, GLWECiphertext, GLWEOps, Infos, ScratchCore}; use std::collections::HashMap; use backend::{FFT64, Module, Scratch}; @@ -7,7 +7,7 @@ use backend::{FFT64, Module, Scratch}; /// with constant memory of Log(N) ciphertexts. /// Main difference with usual GLWE packing is that /// the output is bit-reversed. -pub struct StreamPacker { +pub struct GLWEPacker { accumulators: Vec, log_batch: usize, counter: usize, @@ -39,7 +39,7 @@ impl Accumulator { } } -impl StreamPacker { +impl GLWEPacker { /// Instantiates a new [StreamPacker]. /// /// #Arguments @@ -98,7 +98,7 @@ impl StreamPacker { module: &Module, res: &mut Vec>>, a: Option<&GLWECiphertext>, - auto_keys: &HashMap>, + auto_keys: &HashMap>, scratch: &mut Scratch, ) { pack_core( @@ -125,7 +125,7 @@ impl StreamPacker { &mut self, module: &Module, res: &mut Vec>>, - auto_keys: &HashMap>, + auto_keys: &HashMap>, scratch: &mut Scratch, ) { if self.counter != 0 { @@ -151,7 +151,7 @@ fn pack_core, DataAK: AsRef<[u8]>>( a: Option<&GLWECiphertext>, accumulators: &mut [Accumulator], i: usize, - auto_keys: &HashMap>, + auto_keys: &HashMap>, scratch: &mut Scratch, ) { let log_n: usize = module.log_n(); @@ -215,7 +215,7 @@ fn combine, DataAK: AsRef<[u8]>>( acc: &mut Accumulator, b: Option<&GLWECiphertext>, i: usize, - auto_keys: &HashMap>, + auto_keys: &HashMap>, scratch: &mut Scratch, ) { let log_n: usize = module.log_n(); diff --git a/core/src/glwe/plaintext.rs b/core/src/glwe/plaintext.rs index c1e9175..9f24be0 100644 --- a/core/src/glwe/plaintext.rs +++ b/core/src/glwe/plaintext.rs @@ -1,10 +1,6 @@ use backend::{Backend, FFT64, Module, VecZnx, VecZnxAlloc, VecZnxToMut, VecZnxToRef}; -use crate::{ - GLWEOps, Infos, SetMetaData, - ciphertext::{GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef}, - div_ceil, -}; +use crate::{GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, GLWEOps, Infos, SetMetaData, div_ceil}; pub struct GLWEPlaintext { pub data: VecZnx, diff --git a/core/src/glwe/public_key.rs b/core/src/glwe/public_key.rs index 4a1ed15..f4871ad 100644 --- a/core/src/glwe/public_key.rs +++ b/core/src/glwe/public_key.rs @@ -1,18 +1,18 @@ use backend::{Backend, FFT64, Module, ScratchOwned, VecZnxDft}; use sampling::source::Source; -use crate::{FourierGLWECiphertext, GLWESecret, Infos, keys::SecretDistribution}; +use crate::{FourierGLWECiphertext, FourierGLWESecret, Infos, dist::Distribution}; pub struct GLWEPublicKey { pub(crate) data: FourierGLWECiphertext, - pub(crate) dist: SecretDistribution, + pub(crate) dist: Distribution, } impl GLWEPublicKey, B> { pub fn alloc(module: &Module, basek: usize, k: usize, rank: usize) -> Self { Self { data: FourierGLWECiphertext::alloc(module, basek, k, rank), - dist: SecretDistribution::NONE, + dist: Distribution::NONE, } } @@ -47,7 +47,7 @@ impl + AsMut<[u8]>> GLWEPublicKey { pub fn generate_from_sk>( &mut self, module: &Module, - sk: &GLWESecret, + sk: &FourierGLWESecret, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, @@ -55,7 +55,7 @@ impl + AsMut<[u8]>> GLWEPublicKey { #[cfg(debug_assertions)] { match sk.dist { - SecretDistribution::NONE => panic!("invalid sk: SecretDistribution::NONE"), + Distribution::NONE => panic!("invalid sk: SecretDistribution::NONE"), _ => {} } } diff --git a/core/src/glwe/secret.rs b/core/src/glwe/secret.rs index b704365..5073d2b 100644 --- a/core/src/glwe/secret.rs +++ b/core/src/glwe/secret.rs @@ -1,31 +1,27 @@ -use backend::{ - Backend, FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ZnxInfos, ZnxZero, -}; +use backend::{Backend, Module, ScalarZnx, ScalarZnxAlloc, ZnxInfos, ZnxZero}; use sampling::source::Source; -use crate::keys::SecretDistribution; +use crate::dist::Distribution; -pub struct GLWESecret { +pub struct GLWESecret { pub(crate) data: ScalarZnx, - pub(crate) data_fourier: ScalarZnxDft, - pub(crate) dist: SecretDistribution, + pub(crate) dist: Distribution, } -impl GLWESecret, B> { - pub fn alloc(module: &Module, rank: usize) -> Self { +impl GLWESecret> { + pub fn alloc(module: &Module, rank: usize) -> Self { Self { data: module.new_scalar_znx(rank), - data_fourier: module.new_scalar_znx_dft(rank), - dist: SecretDistribution::NONE, + dist: Distribution::NONE, } } - pub fn bytes_of(module: &Module, rank: usize) -> usize { - module.bytes_of_scalar_znx(rank) + module.bytes_of_scalar_znx_dft(rank) + pub fn bytes_of(module: &Module, rank: usize) -> usize { + module.bytes_of_scalar_znx(rank) } } -impl GLWESecret { +impl GLWESecret { pub fn n(&self) -> usize { self.data.n() } @@ -39,55 +35,50 @@ impl GLWESecret { } } -impl + AsRef<[u8]>> GLWESecret { - pub fn fill_ternary_prob(&mut self, module: &Module, prob: f64, source: &mut Source) { +impl + AsRef<[u8]>> GLWESecret { + pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) { (0..self.rank()).for_each(|i| { self.data.fill_ternary_prob(i, prob, source); }); - self.prep_fourier(module); - self.dist = SecretDistribution::TernaryProb(prob); + self.dist = Distribution::TernaryProb(prob); } - pub fn fill_ternary_hw(&mut self, module: &Module, hw: usize, source: &mut Source) { + pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) { (0..self.rank()).for_each(|i| { self.data.fill_ternary_hw(i, hw, source); }); - self.prep_fourier(module); - self.dist = SecretDistribution::TernaryFixed(hw); + self.dist = Distribution::TernaryFixed(hw); } - pub fn fill_binary_prob(&mut self, module: &Module, prob: f64, source: &mut Source) { + pub fn fill_binary_prob(&mut self, prob: f64, source: &mut Source) { (0..self.rank()).for_each(|i| { self.data.fill_binary_prob(i, prob, source); }); - self.prep_fourier(module); - self.dist = SecretDistribution::BinaryProb(prob); + self.dist = Distribution::BinaryProb(prob); } - pub fn fill_binary_hw(&mut self, module: &Module, hw: usize, source: &mut Source) { + pub fn fill_binary_hw(&mut self, hw: usize, source: &mut Source) { (0..self.rank()).for_each(|i| { self.data.fill_binary_hw(i, hw, source); }); - self.prep_fourier(module); - self.dist = SecretDistribution::BinaryFixed(hw); + self.dist = Distribution::BinaryFixed(hw); } - pub fn fill_binary_block(&mut self, module: &Module, block_size: usize, source: &mut Source) { + pub fn fill_binary_block(&mut self, block_size: usize, source: &mut Source) { (0..self.rank()).for_each(|i| { self.data.fill_binary_block(i, block_size, source); }); - self.prep_fourier(module); - self.dist = SecretDistribution::BinaryBlock(block_size); + self.dist = Distribution::BinaryBlock(block_size); } pub fn fill_zero(&mut self) { self.data.zero(); - self.dist = SecretDistribution::ZERO; + self.dist = Distribution::ZERO; } - pub(crate) fn prep_fourier(&mut self, module: &Module) { - (0..self.rank()).for_each(|i| { - module.svp_prepare(&mut self.data_fourier, i, &self.data, i); - }); - } + // pub(crate) fn prep_fourier(&mut self, module: &Module) { + // (0..self.rank()).for_each(|i| { + // module.svp_prepare(&mut self.data_fourier, i, &self.data, i); + // }); + // } } diff --git a/core/src/glwe/test_fft64/automorphism.rs b/core/src/glwe/test_fft64/automorphism.rs index 8b13789..963083b 100644 --- a/core/src/glwe/test_fft64/automorphism.rs +++ b/core/src/glwe/test_fft64/automorphism.rs @@ -1 +1,224 @@ +use backend::{FFT64, FillUniform, Module, ScratchOwned, Stats, VecZnxOps}; +use sampling::source::Source; + +use crate::{ + FourierGLWESecret, GLWEAutomorphismKey, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, div_ceil, + noise::log2_std_noise_gglwe_product, +}; + +#[test] +fn apply_inplace() { + let log_n: usize = 8; + let basek: usize = 12; + let k_ct: usize = 60; + let digits: usize = div_ceil(k_ct, basek); + (1..4).for_each(|rank| { + (1..digits + 1).for_each(|di| { + let k_ksk: usize = k_ct + basek * di; + println!("test automorphism_inplace digits: {} rank: {}", di, rank); + test_automorphism_inplace(log_n, basek, -5, k_ct, k_ksk, di, rank, 3.2); + }); + }); +} + +#[test] +fn apply() { + let log_n: usize = 8; + let basek: usize = 12; + let k_in: usize = 60; + let digits: usize = div_ceil(k_in, basek); + (1..4).for_each(|rank| { + (1..digits + 1).for_each(|di| { + let k_ksk: usize = k_in + basek * di; + let k_out: usize = k_ksk; // Better capture noise. + println!("test automorphism digits: {} rank: {}", di, rank); + test_automorphism(log_n, basek, -5, k_out, k_in, k_ksk, di, rank, 3.2); + }) + }); +} + +fn test_automorphism( + log_n: usize, + basek: usize, + p: i64, + k_out: usize, + k_in: usize, + k_ksk: usize, + digits: usize, + rank: usize, + sigma: f64, +) { + let module: Module = Module::::new(1 << log_n); + + let rows: usize = div_ceil(k_in, basek * digits); + + let mut autokey: GLWEAutomorphismKey, FFT64> = GLWEAutomorphismKey::alloc(&module, basek, k_ksk, rows, digits, rank); + let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_in, rank); + let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_out, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + pt_want + .data + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWEAutomorphismKey::generate_from_sk_scratch_space(&module, basek, autokey.k(), rank) + | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_out.k()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_in.k()) + | GLWECiphertext::automorphism_scratch_space( + &module, + basek, + ct_out.k(), + ct_in.k(), + autokey.k(), + digits, + rank, + ), + ); + + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); + + autokey.generate_from_sk( + &module, + p, + &sk, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_in.encrypt_sk( + &module, + &pt_want, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_out.automorphism(&module, &ct_in, &autokey, scratch.borrow()); + ct_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + module.vec_znx_automorphism_inplace(p, &mut pt_want.data, 0); + module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); + module.vec_znx_normalize_inplace(basek, &mut pt_have.data, 0, scratch.borrow()); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + + println!("{}", noise_have); + + let noise_want: f64 = log2_std_noise_gglwe_product( + module.n() as f64, + basek * digits, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + rank as f64, + k_in, + k_ksk, + ); + + assert!( + noise_have <= noise_want + 1.0, + "{} {}", + noise_have, + noise_want + ); +} + +fn test_automorphism_inplace( + log_n: usize, + basek: usize, + p: i64, + k_ct: usize, + k_ksk: usize, + digits: usize, + rank: usize, + sigma: f64, +) { + let module: Module = Module::::new(1 << log_n); + + let rows: usize = div_ceil(k_ct, basek * digits); + + let mut autokey: GLWEAutomorphismKey, FFT64> = GLWEAutomorphismKey::alloc(&module, basek, k_ksk, rows, digits, rank); + let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + pt_want + .data + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWEAutomorphismKey::generate_from_sk_scratch_space(&module, basek, autokey.k(), rank) + | GLWECiphertext::decrypt_scratch_space(&module, basek, ct.k()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k()) + | GLWECiphertext::automorphism_inplace_scratch_space(&module, basek, ct.k(), autokey.k(), digits, rank), + ); + + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); + + autokey.generate_from_sk( + &module, + p, + &sk, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct.encrypt_sk( + &module, + &pt_want, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct.automorphism_inplace(&module, &autokey, scratch.borrow()); + ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + module.vec_znx_automorphism_inplace(p, &mut pt_want.data, 0); + module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); + module.vec_znx_normalize_inplace(basek, &mut pt_have.data, 0, scratch.borrow()); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + let noise_want: f64 = log2_std_noise_gglwe_product( + module.n() as f64, + basek * digits, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + rank as f64, + k_ct, + k_ksk, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.5, + "{} {}", + noise_have, + noise_want + ); +} diff --git a/core/src/glwe/test_fft64/encryption.rs b/core/src/glwe/test_fft64/encryption.rs index 06490c5..2d82575 100644 --- a/core/src/glwe/test_fft64/encryption.rs +++ b/core/src/glwe/test_fft64/encryption.rs @@ -2,7 +2,7 @@ use backend::{Decoding, Encoding, FFT64, Module, ScratchOwned, Stats, VecZnxOps, use itertools::izip; use sampling::source::Source; -use crate::{FourierGLWECiphertext, GLWECiphertext, GLWEPlaintext, GLWEPublicKey, GLWESecret, Infos}; +use crate::{FourierGLWECiphertext, FourierGLWESecret, GLWECiphertext, GLWEPlaintext, GLWEPublicKey, GLWESecret, Infos}; #[test] fn encrypt_sk() { @@ -46,8 +46,9 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k_ct: usize, k_pt: usize, sigma: | GLWECiphertext::decrypt_scratch_space(&module, basek, ct.k()), ); - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); let mut data_want: Vec = vec![0i64; module.n()]; @@ -60,7 +61,7 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k_ct: usize, k_pt: usize, sigma: ct.encrypt_sk( &module, &pt, - &sk, + &sk_dft, &mut source_xa, &mut source_xe, sigma, @@ -69,7 +70,7 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k_ct: usize, k_pt: usize, sigma: pt.data.zero(); - ct.decrypt(&module, &mut pt, &sk, scratch.borrow()); + ct.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); let mut data_have: Vec = vec![0i64; module.n()]; @@ -98,8 +99,9 @@ fn test_encrypt_zero_sk(log_n: usize, basek: usize, k_ct: usize, sigma: f64, ran let mut source_xe: Source = Source::new([1u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); let mut ct_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_ct, rank); @@ -110,13 +112,13 @@ fn test_encrypt_zero_sk(log_n: usize, basek: usize, k_ct: usize, sigma: f64, ran ct_dft.encrypt_zero_sk( &module, - &sk, + &sk_dft, &mut source_xa, &mut source_xe, sigma, scratch.borrow(), ); - ct_dft.decrypt(&module, &mut pt, &sk, scratch.borrow()); + ct_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); assert!((sigma - pt.data.std(0, basek) * (k_ct as f64).exp2()) <= 0.2); } @@ -132,11 +134,12 @@ fn test_encrypt_pk(log_n: usize, basek: usize, k_ct: usize, k_pk: usize, sigma: let mut source_xa: Source = Source::new([0u8; 32]); let mut source_xu: Source = Source::new([0u8; 32]); - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); let mut pk: GLWEPublicKey, FFT64> = GLWEPublicKey::alloc(&module, basek, k_pk, rank); - pk.generate_from_sk(&module, &sk, &mut source_xa, &mut source_xe, sigma); + pk.generate_from_sk(&module, &sk_dft, &mut source_xa, &mut source_xe, sigma); let mut scratch: ScratchOwned = ScratchOwned::new( GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k()) @@ -164,7 +167,7 @@ fn test_encrypt_pk(log_n: usize, basek: usize, k_ct: usize, k_pk: usize, sigma: let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - ct.decrypt(&module, &mut pt_have, &sk, scratch.borrow()); + ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); module.vec_znx_sub_ab_inplace(&mut pt_want.data, 0, &pt_have.data, 0); diff --git a/core/src/glwe/test_fft64/external_product.rs b/core/src/glwe/test_fft64/external_product.rs index 7a32343..4ba77c3 100644 --- a/core/src/glwe/test_fft64/external_product.rs +++ b/core/src/glwe/test_fft64/external_product.rs @@ -1,86 +1,16 @@ -use backend::{ - Decoding, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, ZnxViewMut, - ZnxZero, -}; -use itertools::izip; +use backend::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, ZnxViewMut}; use sampling::source::Source; use crate::{ - FourierGLWECiphertext, GGSWCiphertext, GLWECiphertext, GLWEPlaintext, GLWEPublicKey, GLWESecret, Infos, - automorphism::AutomorphismKey, - keyswitch_key::GLWESwitchingKey, - test_fft64::{log2_std_noise_gglwe_product, noise_ggsw_product}, + FourierGLWESecret, GGSWCiphertext, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, div_ceil, noise::noise_ggsw_product, }; #[test] -fn encrypt_sk() { - let log_n: usize = 8; - (1..4).for_each(|rank| { - println!("test encrypt_sk rank: {}", rank); - test_encrypt_sk(log_n, 8, 54, 30, 3.2, rank); - }); -} - -#[test] -fn encrypt_zero_sk() { - let log_n: usize = 8; - (1..4).for_each(|rank| { - println!("test encrypt_zero_sk rank: {}", rank); - test_encrypt_zero_sk(log_n, 8, 64, 3.2, rank); - }); -} - -#[test] -fn encrypt_pk() { - let log_n: usize = 8; - (1..4).for_each(|rank| { - println!("test encrypt_pk rank: {}", rank); - test_encrypt_pk(log_n, 8, 64, 64, 3.2, rank) - }); -} - -#[test] -fn keyswitch() { +fn apply() { let log_n: usize = 8; let basek: usize = 12; let k_in: usize = 45; - let digits: usize = k_in.div_ceil(basek); - (1..4).for_each(|rank_in| { - (1..4).for_each(|rank_out| { - (1..digits + 1).for_each(|di| { - let k_ksk: usize = k_in + basek * di; - let k_out: usize = k_ksk; // better capture noise - println!( - "test keyswitch digits: {} rank_in: {} rank_out: {}", - di, rank_in, rank_out - ); - test_keyswitch(log_n, basek, k_out, k_in, k_ksk, di, rank_in, rank_out, 3.2); - }) - }); - }); -} - -#[test] -fn keyswitch_inplace() { - let log_n: usize = 8; - let basek: usize = 12; - let k_ct: usize = 45; - let digits: usize = k_ct.div_ceil(basek); - (1..4).for_each(|rank| { - (1..digits + 1).for_each(|di| { - let k_ksk: usize = k_ct + basek * di; - println!("test keyswitch_inplace digits: {} rank: {}", di, rank); - test_keyswitch_inplace(log_n, basek, k_ct, k_ksk, di, rank, 3.2); - }); - }); -} - -#[test] -fn external_product() { - let log_n: usize = 8; - let basek: usize = 12; - let k_in: usize = 45; - let digits: usize = k_in.div_ceil(basek); + let digits: usize = div_ceil(k_in, basek); (1..4).for_each(|rank| { (1..digits + 1).for_each(|di| { let k_ggsw: usize = k_in + basek * di; @@ -92,7 +22,7 @@ fn external_product() { } #[test] -fn external_product_inplace() { +fn apply_inplace() { let log_n: usize = 8; let basek: usize = 12; let k_ct: usize = 60; @@ -106,548 +36,6 @@ fn external_product_inplace() { }); } -#[test] -fn automorphism_inplace() { - let log_n: usize = 8; - let basek: usize = 12; - let k_ct: usize = 60; - let digits: usize = k_ct.div_ceil(basek); - (1..4).for_each(|rank| { - (1..digits + 1).for_each(|di| { - let k_ksk: usize = k_ct + basek * di; - println!("test automorphism_inplace digits: {} rank: {}", di, rank); - test_automorphism_inplace(log_n, basek, -5, k_ct, k_ksk, di, rank, 3.2); - }); - }); -} - -#[test] -fn automorphism() { - let log_n: usize = 8; - let basek: usize = 12; - let k_in: usize = 60; - let digits: usize = k_in.div_ceil(basek); - (1..4).for_each(|rank| { - (1..digits + 1).for_each(|di| { - let k_ksk: usize = k_in + basek * di; - let k_out: usize = k_ksk; // Better capture noise. - println!("test automorphism digits: {} rank: {}", di, rank); - test_automorphism(log_n, basek, -5, k_out, k_in, k_ksk, di, rank, 3.2); - }) - }); -} - -fn test_encrypt_sk(log_n: usize, basek: usize, k_ct: usize, k_pt: usize, sigma: f64, rank: usize) { - let module: Module = Module::::new(1 << log_n); - - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_pt); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k()) - | GLWECiphertext::decrypt_scratch_space(&module, basek, ct.k()), - ); - - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, 0.5, &mut source_xs); - - let mut data_want: Vec = vec![0i64; module.n()]; - - data_want - .iter_mut() - .for_each(|x| *x = source_xa.next_i64() & 0xFF); - - pt.data.encode_vec_i64(0, basek, k_pt, &data_want, 10); - - ct.encrypt_sk( - &module, - &pt, - &sk, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - pt.data.zero(); - - ct.decrypt(&module, &mut pt, &sk, scratch.borrow()); - - let mut data_have: Vec = vec![0i64; module.n()]; - - pt.data - .decode_vec_i64(0, basek, pt.size() * basek, &mut data_have); - - // TODO: properly assert the decryption noise through std(dec(ct) - pt) - let scale: f64 = (1 << (pt.size() * basek - k_pt)) as f64; - izip!(data_want.iter(), data_have.iter()).for_each(|(a, b)| { - let b_scaled = (*b as f64) / scale; - assert!( - (*a as f64 - b_scaled).abs() < 0.1, - "{} {}", - *a as f64, - b_scaled - ) - }); -} - -fn test_encrypt_zero_sk(log_n: usize, basek: usize, k_ct: usize, sigma: f64, rank: usize) { - let module: Module = Module::::new(1 << log_n); - - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([1u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, 0.5, &mut source_xs); - - let mut ct_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_ct, rank); - - let mut scratch: ScratchOwned = ScratchOwned::new( - FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_ct) - | FourierGLWECiphertext::encrypt_sk_scratch_space(&module, basek, k_ct, rank), - ); - - ct_dft.encrypt_zero_sk( - &module, - &sk, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - ct_dft.decrypt(&module, &mut pt, &sk, scratch.borrow()); - - assert!((sigma - pt.data.std(0, basek) * (k_ct as f64).exp2()) <= 0.2); -} - -fn test_encrypt_pk(log_n: usize, basek: usize, k_ct: usize, k_pk: usize, sigma: f64, rank: usize) { - let module: Module = Module::::new(1 << log_n); - - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - let mut source_xu: Source = Source::new([0u8; 32]); - - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, 0.5, &mut source_xs); - - let mut pk: GLWEPublicKey, FFT64> = GLWEPublicKey::alloc(&module, basek, k_pk, rank); - pk.generate_from_sk(&module, &sk, &mut source_xa, &mut source_xe, sigma); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k()) - | GLWECiphertext::decrypt_scratch_space(&module, basek, ct.k()) - | GLWECiphertext::encrypt_pk_scratch_space(&module, basek, pk.k()), - ); - - let mut data_want: Vec = vec![0i64; module.n()]; - - data_want - .iter_mut() - .for_each(|x| *x = source_xa.next_i64() & 0); - - pt_want.data.encode_vec_i64(0, basek, k_ct, &data_want, 10); - - ct.encrypt_pk( - &module, - &pt_want, - &pk, - &mut source_xu, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - - ct.decrypt(&module, &mut pt_have, &sk, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_want.data, 0, &pt_have.data, 0); - - let noise_have: f64 = pt_want.data.std(0, basek).log2(); - let noise_want: f64 = ((((rank as f64) + 1.0) * module.n() as f64 * 0.5 * sigma * sigma).sqrt()).log2() - (k_ct as f64); - - assert!( - (noise_have - noise_want).abs() < 0.2, - "{} {}", - noise_have, - noise_want - ); -} - -fn test_keyswitch( - log_n: usize, - basek: usize, - k_out: usize, - k_in: usize, - k_ksk: usize, - digits: usize, - rank_in: usize, - rank_out: usize, - sigma: f64, -) { - let module: Module = Module::::new(1 << log_n); - - let rows: usize = k_in.div_ceil(basek * digits); - - let mut ksk: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank_in, rank_out); - let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_in, rank_in); - let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_out, rank_out); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_in); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - pt_want - .data - .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ksk.k(), rank_out) - | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_out.k()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_in.k()) - | GLWECiphertext::keyswitch_scratch_space( - &module, - basek, - ct_out.k(), - ct_in.k(), - ksk.k(), - digits, - rank_in, - rank_out, - ), - ); - - let mut sk_in: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_in); - sk_in.fill_ternary_prob(&module, 0.5, &mut source_xs); - - let mut sk_out: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_out); - sk_out.fill_ternary_prob(&module, 0.5, &mut source_xs); - - ksk.generate_from_sk( - &module, - &sk_in, - &sk_out, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_in.encrypt_sk( - &module, - &pt_want, - &sk_in, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_out.keyswitch(&module, &ct_in, &ksk, scratch.borrow()); - ct_out.decrypt(&module, &mut pt_have, &sk_out, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); - - let noise_have: f64 = pt_have.data.std(0, basek).log2(); - let noise_want: f64 = log2_std_noise_gglwe_product( - module.n() as f64, - basek * digits, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - rank_in as f64, - k_in, - k_ksk, - ); - - println!("{} vs. {}", noise_have, noise_want); - - assert!( - (noise_have - noise_want).abs() <= 0.5, - "{} {}", - noise_have, - noise_want - ); -} - -fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ct: usize, k_ksk: usize, digits: usize, rank: usize, sigma: f64) { - let module: Module = Module::::new(1 << log_n); - - let rows: usize = k_ct.div_ceil(basek * digits); - - let mut ct_grlwe: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank, rank); - let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - pt_want - .data - .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ct_grlwe.k(), rank) - | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_glwe.k()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_glwe.k()) - | GLWECiphertext::keyswitch_inplace_scratch_space(&module, basek, ct_glwe.k(), ct_grlwe.k(), digits, rank), - ); - - let mut sk0: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk0.fill_ternary_prob(&module, 0.5, &mut source_xs); - - let mut sk1: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk1.fill_ternary_prob(&module, 0.5, &mut source_xs); - - ct_grlwe.generate_from_sk( - &module, - &sk0, - &sk1, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_glwe.encrypt_sk( - &module, - &pt_want, - &sk0, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_glwe.keyswitch_inplace(&module, &ct_grlwe, scratch.borrow()); - - ct_glwe.decrypt(&module, &mut pt_have, &sk1, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); - - let noise_have: f64 = pt_have.data.std(0, basek).log2(); - let noise_want: f64 = log2_std_noise_gglwe_product( - module.n() as f64, - basek * digits, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - rank as f64, - k_ct, - k_ksk, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.5, - "{} {}", - noise_have, - noise_want - ); -} - -fn test_automorphism( - log_n: usize, - basek: usize, - p: i64, - k_out: usize, - k_in: usize, - k_ksk: usize, - digits: usize, - rank: usize, - sigma: f64, -) { - let module: Module = Module::::new(1 << log_n); - - let rows: usize = k_in.div_ceil(basek * digits); - - let mut autokey: AutomorphismKey, FFT64> = AutomorphismKey::alloc(&module, basek, k_ksk, rows, digits, rank); - let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_in, rank); - let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_out, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_in); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - pt_want - .data - .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); - - let mut scratch: ScratchOwned = ScratchOwned::new( - AutomorphismKey::generate_from_sk_scratch_space(&module, basek, autokey.k(), rank) - | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_out.k()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_in.k()) - | GLWECiphertext::automorphism_scratch_space( - &module, - basek, - ct_out.k(), - ct_in.k(), - autokey.k(), - digits, - rank, - ), - ); - - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, 0.5, &mut source_xs); - - autokey.generate_from_sk( - &module, - p, - &sk, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_in.encrypt_sk( - &module, - &pt_want, - &sk, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_out.automorphism(&module, &ct_in, &autokey, scratch.borrow()); - ct_out.decrypt(&module, &mut pt_have, &sk, scratch.borrow()); - module.vec_znx_automorphism_inplace(p, &mut pt_want.data, 0); - module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); - module.vec_znx_normalize_inplace(basek, &mut pt_have.data, 0, scratch.borrow()); - - let noise_have: f64 = pt_have.data.std(0, basek).log2(); - - println!("{}", noise_have); - - let noise_want: f64 = log2_std_noise_gglwe_product( - module.n() as f64, - basek * digits, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - rank as f64, - k_in, - k_ksk, - ); - - assert!( - noise_have <= noise_want + 1.0, - "{} {}", - noise_have, - noise_want - ); -} - -fn test_automorphism_inplace( - log_n: usize, - basek: usize, - p: i64, - k_ct: usize, - k_ksk: usize, - digits: usize, - rank: usize, - sigma: f64, -) { - let module: Module = Module::::new(1 << log_n); - - let rows: usize = k_ct.div_ceil(basek * digits); - - let mut autokey: AutomorphismKey, FFT64> = AutomorphismKey::alloc(&module, basek, k_ksk, rows, digits, rank); - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - pt_want - .data - .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); - - let mut scratch: ScratchOwned = ScratchOwned::new( - AutomorphismKey::generate_from_sk_scratch_space(&module, basek, autokey.k(), rank) - | GLWECiphertext::decrypt_scratch_space(&module, basek, ct.k()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k()) - | GLWECiphertext::automorphism_inplace_scratch_space(&module, basek, ct.k(), autokey.k(), digits, rank), - ); - - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, 0.5, &mut source_xs); - - autokey.generate_from_sk( - &module, - p, - &sk, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct.encrypt_sk( - &module, - &pt_want, - &sk, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct.automorphism_inplace(&module, &autokey, scratch.borrow()); - ct.decrypt(&module, &mut pt_have, &sk, scratch.borrow()); - module.vec_znx_automorphism_inplace(p, &mut pt_want.data, 0); - module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); - module.vec_znx_normalize_inplace(basek, &mut pt_have.data, 0, scratch.borrow()); - - let noise_have: f64 = pt_have.data.std(0, basek).log2(); - let noise_want: f64 = log2_std_noise_gglwe_product( - module.n() as f64, - basek * digits, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - rank as f64, - k_ct, - k_ksk, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.5, - "{} {}", - noise_have, - noise_want - ); -} - fn test_external_product( log_n: usize, basek: usize, @@ -699,13 +87,14 @@ fn test_external_product( ), ); - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); ct_ggsw.encrypt_sk( &module, &pt_rgsw, - &sk, + &sk_dft, &mut source_xa, &mut source_xe, sigma, @@ -715,7 +104,7 @@ fn test_external_product( ct_glwe_in.encrypt_sk( &module, &pt_want, - &sk, + &sk_dft, &mut source_xa, &mut source_xe, sigma, @@ -724,7 +113,7 @@ fn test_external_product( ct_glwe_out.external_product(&module, &ct_glwe_in, &ct_ggsw, scratch.borrow()); - ct_glwe_out.decrypt(&module, &mut pt_have, &sk, scratch.borrow()); + ct_glwe_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); module.vec_znx_rotate_inplace(k as i64, &mut pt_want.data, 0); @@ -793,13 +182,14 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ct: usize, k_ggsw | GLWECiphertext::external_product_inplace_scratch_space(&module, basek, ct_glwe.k(), ct_ggsw.k(), digits, rank), ); - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); ct_ggsw.encrypt_sk( &module, &pt_rgsw, - &sk, + &sk_dft, &mut source_xa, &mut source_xe, sigma, @@ -809,7 +199,7 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ct: usize, k_ggsw ct_glwe.encrypt_sk( &module, &pt_want, - &sk, + &sk_dft, &mut source_xa, &mut source_xe, sigma, @@ -818,7 +208,7 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ct: usize, k_ggsw ct_glwe.external_product_inplace(&module, &ct_ggsw, scratch.borrow()); - ct_glwe.decrypt(&module, &mut pt_have, &sk, scratch.borrow()); + ct_glwe.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); module.vec_znx_rotate_inplace(k as i64, &mut pt_want.data, 0); diff --git a/core/src/glwe/test_fft64/keyswitch.rs b/core/src/glwe/test_fft64/keyswitch.rs index 8b13789..f80ae86 100644 --- a/core/src/glwe/test_fft64/keyswitch.rs +++ b/core/src/glwe/test_fft64/keyswitch.rs @@ -1 +1,227 @@ +use backend::{FFT64, FillUniform, Module, ScratchOwned, Stats, VecZnxOps}; +use sampling::source::Source; +use crate::{ + FourierGLWESecret, GLWECiphertext, GLWEPlaintext, GLWESecret, GLWESwitchingKey, Infos, div_ceil, + noise::log2_std_noise_gglwe_product, +}; + +#[test] +fn apply() { + let log_n: usize = 8; + let basek: usize = 12; + let k_in: usize = 45; + let digits: usize = div_ceil(k_in, basek); + (1..4).for_each(|rank_in| { + (1..4).for_each(|rank_out| { + (1..digits + 1).for_each(|di| { + let k_ksk: usize = k_in + basek * di; + let k_out: usize = k_ksk; // better capture noise + println!( + "test keyswitch digits: {} rank_in: {} rank_out: {}", + di, rank_in, rank_out + ); + test_keyswitch(log_n, basek, k_out, k_in, k_ksk, di, rank_in, rank_out, 3.2); + }) + }); + }); +} + +#[test] +fn apply_inplace() { + let log_n: usize = 8; + let basek: usize = 12; + let k_ct: usize = 45; + let digits: usize = div_ceil(k_ct, basek); + (1..4).for_each(|rank| { + (1..digits + 1).for_each(|di| { + let k_ksk: usize = k_ct + basek * di; + println!("test keyswitch_inplace digits: {} rank: {}", di, rank); + test_keyswitch_inplace(log_n, basek, k_ct, k_ksk, di, rank, 3.2); + }); + }); +} + +fn test_keyswitch( + log_n: usize, + basek: usize, + k_out: usize, + k_in: usize, + k_ksk: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + sigma: f64, +) { + let module: Module = Module::::new(1 << log_n); + + let rows: usize = div_ceil(k_in, basek * digits); + + let mut ksk: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank_in, rank_out); + let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_in, rank_in); + let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_out, rank_out); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + pt_want + .data + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ksk.k(), rank_out) + | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_out.k()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_in.k()) + | GLWECiphertext::keyswitch_scratch_space( + &module, + basek, + ct_out.k(), + ct_in.k(), + ksk.k(), + digits, + rank_in, + rank_out, + ), + ); + + let mut sk_in: GLWESecret> = GLWESecret::alloc(&module, rank_in); + sk_in.fill_ternary_prob(0.5, &mut source_xs); + let sk_in_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_in); + + let mut sk_out: GLWESecret> = GLWESecret::alloc(&module, rank_out); + sk_out.fill_ternary_prob(0.5, &mut source_xs); + let sk_out_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_out); + + ksk.generate_from_sk( + &module, + &sk_in, + &sk_out_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_in.encrypt_sk( + &module, + &pt_want, + &sk_in_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_out.keyswitch(&module, &ct_in, &ksk, scratch.borrow()); + ct_out.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + let noise_want: f64 = log2_std_noise_gglwe_product( + module.n() as f64, + basek * digits, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + rank_in as f64, + k_in, + k_ksk, + ); + + println!("{} vs. {}", noise_have, noise_want); + + assert!( + (noise_have - noise_want).abs() <= 0.5, + "{} {}", + noise_have, + noise_want + ); +} + +fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ct: usize, k_ksk: usize, digits: usize, rank: usize, sigma: f64) { + let module: Module = Module::::new(1 << log_n); + + let rows: usize = div_ceil(k_ct, basek * digits); + + let mut ct_grlwe: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank, rank); + let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + pt_want + .data + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ct_grlwe.k(), rank) + | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_glwe.k()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_glwe.k()) + | GLWECiphertext::keyswitch_inplace_scratch_space(&module, basek, ct_glwe.k(), ct_grlwe.k(), digits, rank), + ); + + let mut sk_in: GLWESecret> = GLWESecret::alloc(&module, rank); + sk_in.fill_ternary_prob(0.5, &mut source_xs); + let sk_in_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_in); + + let mut sk_out: GLWESecret> = GLWESecret::alloc(&module, rank); + sk_out.fill_ternary_prob(0.5, &mut source_xs); + let sk_out_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_out); + + ct_grlwe.generate_from_sk( + &module, + &sk_in, + &sk_out_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_glwe.encrypt_sk( + &module, + &pt_want, + &sk_in_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_glwe.keyswitch_inplace(&module, &ct_grlwe, scratch.borrow()); + + ct_glwe.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + let noise_want: f64 = log2_std_noise_gglwe_product( + module.n() as f64, + basek * digits, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + rank as f64, + k_ct, + k_ksk, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.5, + "{} {}", + noise_have, + noise_want + ); +} diff --git a/core/src/glwe/test_fft64/packing.rs b/core/src/glwe/test_fft64/packing.rs index 107461c..da817e6 100644 --- a/core/src/glwe/test_fft64/packing.rs +++ b/core/src/glwe/test_fft64/packing.rs @@ -1,11 +1,11 @@ -use crate::{AutomorphismKey, GLWECiphertext, GLWEOps, GLWEPlaintext, GLWESecret, StreamPacker}; +use crate::{FourierGLWESecret, GLWEAutomorphismKey, GLWECiphertext, GLWEOps, GLWEPacker, GLWEPlaintext, GLWESecret, div_ceil}; use std::collections::HashMap; use backend::{Encoding, FFT64, Module, ScratchOwned, Stats}; use sampling::source::Source; #[test] -fn packing() { +fn apply() { let log_n: usize = 5; let module: Module = Module::::new(1 << log_n); @@ -26,12 +26,13 @@ fn packing() { let mut scratch: ScratchOwned = ScratchOwned::new( GLWECiphertext::encrypt_sk_scratch_space(&module, basek, k_ct) | GLWECiphertext::decrypt_scratch_space(&module, basek, k_ct) - | AutomorphismKey::generate_from_sk_scratch_space(&module, basek, k_ksk, rank) - | StreamPacker::scratch_space(&module, basek, k_ct, k_ksk, digits, rank), + | GLWEAutomorphismKey::generate_from_sk_scratch_space(&module, basek, k_ksk, rank) + | GLWEPacker::scratch_space(&module, basek, k_ct, k_ksk, digits, rank), ); - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); let mut data: Vec = vec![0i64; module.n()]; @@ -40,11 +41,11 @@ fn packing() { }); pt.data.encode_vec_i64(0, basek, pt_k, &data, 32); - let gal_els: Vec = StreamPacker::galois_elements(&module); + let gal_els: Vec = GLWEPacker::galois_elements(&module); - let mut auto_keys: HashMap, FFT64>> = HashMap::new(); + let mut auto_keys: HashMap, FFT64>> = HashMap::new(); gal_els.iter().for_each(|gal_el| { - let mut key: AutomorphismKey, FFT64> = AutomorphismKey::alloc(&module, basek, k_ksk, rows, digits, rank); + let mut key: GLWEAutomorphismKey, FFT64> = GLWEAutomorphismKey::alloc(&module, basek, k_ksk, rows, digits, rank); key.generate_from_sk( &module, *gal_el, @@ -59,14 +60,14 @@ fn packing() { let log_batch: usize = 0; - let mut packer: StreamPacker = StreamPacker::new(&module, log_batch, basek, k_ct, rank); + let mut packer: GLWEPacker = GLWEPacker::new(&module, log_batch, basek, k_ct, rank); let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); ct.encrypt_sk( &module, &pt, - &sk, + &sk_dft, &mut source_xa, &mut source_xe, sigma, @@ -79,7 +80,7 @@ fn packing() { ct.encrypt_sk( &module, &pt, - &sk, + &sk_dft, &mut source_xa, &mut source_xe, sigma, @@ -115,7 +116,7 @@ fn packing() { }); pt_want.data.encode_vec_i64(0, basek, pt_k, &data, 32); - res_i.decrypt(&module, &mut pt, &sk, scratch.borrow()); + res_i.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); if i & 1 == 0 { pt.sub_inplace_ab(&module, &pt_want); diff --git a/core/src/glwe/test_fft64/trace.rs b/core/src/glwe/test_fft64/trace.rs index 885aa90..eae2bce 100644 --- a/core/src/glwe/test_fft64/trace.rs +++ b/core/src/glwe/test_fft64/trace.rs @@ -3,10 +3,13 @@ use std::collections::HashMap; use backend::{FFT64, FillUniform, Module, ScratchOwned, Stats, VecZnxOps, ZnxView, ZnxViewMut}; use sampling::source::Source; -use crate::{AutomorphismKey, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, test_fft64::var_noise_gglwe_product}; +use crate::{ + FourierGLWESecret, GLWEAutomorphismKey, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, div_ceil, + noise::var_noise_gglwe_product, +}; #[test] -fn trace_inplace() { +fn apply_inplace() { let log_n: usize = 8; (1..4).for_each(|rank| { println!("test trace_inplace rank: {}", rank); @@ -33,12 +36,13 @@ fn test_trace_inplace(log_n: usize, basek: usize, k: usize, sigma: f64, rank: us let mut scratch: ScratchOwned = ScratchOwned::new( GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k()) | GLWECiphertext::decrypt_scratch_space(&module, basek, ct.k()) - | AutomorphismKey::generate_from_sk_scratch_space(&module, basek, k_autokey, rank) + | GLWEAutomorphismKey::generate_from_sk_scratch_space(&module, basek, k_autokey, rank) | GLWECiphertext::trace_inplace_scratch_space(&module, basek, ct.k(), k_autokey, digits, rank), ); - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&&module, 0.5, &mut source_xs); + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); let mut data_want: Vec = vec![0i64; module.n()]; @@ -53,17 +57,18 @@ fn test_trace_inplace(log_n: usize, basek: usize, k: usize, sigma: f64, rank: us ct.encrypt_sk( &module, &pt_have, - &sk, + &sk_dft, &mut source_xa, &mut source_xe, sigma, scratch.borrow(), ); - let mut auto_keys: HashMap, FFT64>> = HashMap::new(); + let mut auto_keys: HashMap, FFT64>> = HashMap::new(); let gal_els: Vec = GLWECiphertext::trace_galois_elements(&module); gal_els.iter().for_each(|gal_el| { - let mut key: AutomorphismKey, FFT64> = AutomorphismKey::alloc(&module, basek, k_autokey, rows, digits, rank); + let mut key: GLWEAutomorphismKey, FFT64> = + GLWEAutomorphismKey::alloc(&module, basek, k_autokey, rows, digits, rank); key.generate_from_sk( &module, *gal_el, @@ -81,7 +86,7 @@ fn test_trace_inplace(log_n: usize, basek: usize, k: usize, sigma: f64, rank: us (0..pt_want.size()).for_each(|i| pt_want.data.at_mut(0, i)[0] = pt_have.data.at(0, i)[0]); - ct.decrypt(&module, &mut pt_have, &sk, scratch.borrow()); + ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); module.vec_znx_sub_ab_inplace(&mut pt_want.data, 0, &pt_have.data, 0); module.vec_znx_normalize_inplace(basek, &mut pt_want.data, 0, scratch.borrow()); diff --git a/core/src/glwe/trace.rs b/core/src/glwe/trace.rs index 3c6a5bb..c702489 100644 --- a/core/src/glwe/trace.rs +++ b/core/src/glwe/trace.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use backend::{FFT64, Module, Scratch}; -use crate::{AutomorphismKey, GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, GLWEOps, Infos, SetMetaData}; +use crate::{GLWEAutomorphismKey, GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, GLWEOps, Infos, SetMetaData}; impl GLWECiphertext> { pub fn trace_galois_elements(module: &Module) -> Vec { @@ -51,7 +51,7 @@ where start: usize, end: usize, lhs: &GLWECiphertext, - auto_keys: &HashMap>, + auto_keys: &HashMap>, scratch: &mut Scratch, ) where GLWECiphertext: GLWECiphertextToRef + Infos, @@ -65,7 +65,7 @@ where module: &Module, start: usize, end: usize, - auto_keys: &HashMap>, + auto_keys: &HashMap>, scratch: &mut Scratch, ) { (start..end).for_each(|i| { diff --git a/core/src/keyswitch_key.rs b/core/src/keyswitch_key.rs deleted file mode 100644 index 201ae22..0000000 --- a/core/src/keyswitch_key.rs +++ /dev/null @@ -1,343 +0,0 @@ -use backend::{Backend, FFT64, MatZnxDft, MatZnxDftOps, Module, Scratch, ZnxZero}; -use sampling::source::Source; - -use crate::{FourierGLWECiphertext, GGLWECiphertext, GGSWCiphertext, GLWESecret, GetRow, Infos, ScratchCore, SetRow}; - -pub struct GLWESwitchingKey(pub(crate) GGLWECiphertext); - -impl GLWESwitchingKey, FFT64> { - pub fn alloc( - module: &Module, - basek: usize, - k: usize, - rows: usize, - digits: usize, - rank_in: usize, - rank_out: usize, - ) -> Self { - GLWESwitchingKey(GGLWECiphertext::alloc( - module, basek, k, rows, digits, rank_in, rank_out, - )) - } - - pub fn bytes_of( - module: &Module, - basek: usize, - k: usize, - rows: usize, - digits: usize, - rank_in: usize, - rank_out: usize, - ) -> usize { - GGLWECiphertext::, FFT64>::bytes_of(module, basek, k, rows, digits, rank_in, rank_out) - } -} - -impl Infos for GLWESwitchingKey { - type Inner = MatZnxDft; - - fn inner(&self) -> &Self::Inner { - self.0.inner() - } - - fn basek(&self) -> usize { - self.0.basek() - } - - fn k(&self) -> usize { - self.0.k() - } -} - -impl GLWESwitchingKey { - pub fn rank(&self) -> usize { - self.0.data.cols_out() - 1 - } - - pub fn rank_in(&self) -> usize { - self.0.data.cols_in() - } - - pub fn rank_out(&self) -> usize { - self.0.data.cols_out() - 1 - } - - pub fn digits(&self) -> usize { - self.0.digits() - } -} - -impl> GetRow for GLWESwitchingKey { - fn get_row + AsRef<[u8]>>( - &self, - module: &Module, - row_i: usize, - col_j: usize, - res: &mut FourierGLWECiphertext, - ) { - module.mat_znx_dft_get_row(&mut res.data, &self.0.data, row_i, col_j); - } -} - -impl + AsRef<[u8]>> SetRow for GLWESwitchingKey { - fn set_row>( - &mut self, - module: &Module, - row_i: usize, - col_j: usize, - a: &FourierGLWECiphertext, - ) { - module.mat_znx_dft_set_row(&mut self.0.data, row_i, col_j, &a.data); - } -} - -impl GLWESwitchingKey, FFT64> { - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { - GGLWECiphertext::generate_from_sk_scratch_space(module, basek, k, rank) - } - - pub fn encrypt_pk_scratch_space(module: &Module, _basek: usize, _k: usize, _rank: usize) -> usize { - GGLWECiphertext::generate_from_pk_scratch_space(module, _basek, _k, _rank) - } - - pub fn keyswitch_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_in: usize, - k_ksk: usize, - digits: usize, - rank_in: usize, - rank_out: usize, - ) -> usize { - let tmp_in: usize = FourierGLWECiphertext::bytes_of(module, basek, k_in, rank_in); - let tmp_out: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank_out); - let ksk: usize = - FourierGLWECiphertext::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank_in, rank_out); - tmp_in + tmp_out + ksk - } - - pub fn keyswitch_inplace_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_ksk: usize, - digits: usize, - rank: usize, - ) -> usize { - let tmp: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank); - let ksk: usize = FourierGLWECiphertext::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank); - tmp + ksk - } - - pub fn external_product_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_in: usize, - k_ggsw: usize, - digits: usize, - rank: usize, - ) -> usize { - let tmp_in: usize = FourierGLWECiphertext::bytes_of(module, basek, k_in, rank); - let tmp_out: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank); - let ggsw: usize = FourierGLWECiphertext::external_product_scratch_space(module, basek, k_out, k_in, k_ggsw, digits, rank); - tmp_in + tmp_out + ggsw - } - - pub fn external_product_inplace_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_ggsw: usize, - digits: usize, - rank: usize, - ) -> usize { - let tmp: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank); - let ggsw: usize = - FourierGLWECiphertext::external_product_inplace_scratch_space(module, basek, k_out, k_ggsw, digits, rank); - tmp + ggsw - } -} -impl + AsRef<[u8]>> GLWESwitchingKey { - pub fn generate_from_sk, DataSkOut: AsRef<[u8]>>( - &mut self, - module: &Module, - sk_in: &GLWESecret, - sk_out: &GLWESecret, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - scratch: &mut Scratch, - ) { - self.0.encrypt_sk( - module, - &sk_in.data, - sk_out, - source_xa, - source_xe, - sigma, - scratch, - ); - } - - pub fn keyswitch, DataRhs: AsRef<[u8]>>( - &mut self, - module: &Module, - lhs: &GLWESwitchingKey, - rhs: &GLWESwitchingKey, - scratch: &mut Scratch, - ) { - #[cfg(debug_assertions)] - { - assert_eq!( - self.rank_in(), - lhs.rank_in(), - "ksk_out input rank: {} != ksk_in input rank: {}", - self.rank_in(), - lhs.rank_in() - ); - assert_eq!( - lhs.rank_out(), - rhs.rank_in(), - "ksk_in output rank: {} != ksk_apply input rank: {}", - self.rank_out(), - rhs.rank_in() - ); - assert_eq!( - self.rank_out(), - rhs.rank_out(), - "ksk_out output rank: {} != ksk_apply output rank: {}", - self.rank_out(), - rhs.rank_out() - ); - } - - let (mut tmp_in, scratch1) = scratch.tmp_glwe_fourier(module, lhs.basek(), lhs.k(), lhs.rank()); - let (mut tmp_out, scratch2) = scratch1.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); - - (0..self.rank_in()).for_each(|col_i| { - (0..self.rows()).for_each(|row_j| { - lhs.get_row(module, row_j, col_i, &mut tmp_in); - tmp_out.keyswitch(module, &tmp_in, rhs, scratch2); - self.set_row(module, row_j, col_i, &tmp_out); - }); - }); - - tmp_out.data.zero(); - - (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| { - (0..self.rank_in()).for_each(|col_j| { - self.set_row(module, row_i, col_j, &tmp_out); - }); - }); - } - - pub fn keyswitch_inplace>( - &mut self, - module: &Module, - rhs: &GLWESwitchingKey, - scratch: &mut Scratch, - ) { - #[cfg(debug_assertions)] - { - assert_eq!( - self.rank_out(), - rhs.rank_out(), - "ksk_out output rank: {} != ksk_apply output rank: {}", - self.rank_out(), - rhs.rank_out() - ); - } - - let (mut tmp, scratch1) = scratch.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); - - (0..self.rank_in()).for_each(|col_i| { - (0..self.rows()).for_each(|row_j| { - self.get_row(module, row_j, col_i, &mut tmp); - tmp.keyswitch_inplace(module, rhs, scratch1); - self.set_row(module, row_j, col_i, &tmp); - }); - }); - } - - pub fn external_product, DataRhs: AsRef<[u8]>>( - &mut self, - module: &Module, - lhs: &GLWESwitchingKey, - rhs: &GGSWCiphertext, - scratch: &mut Scratch, - ) { - #[cfg(debug_assertions)] - { - assert_eq!( - self.rank_in(), - lhs.rank_in(), - "ksk_out input rank: {} != ksk_in input rank: {}", - self.rank_in(), - lhs.rank_in() - ); - assert_eq!( - lhs.rank_out(), - rhs.rank(), - "ksk_in output rank: {} != ggsw rank: {}", - self.rank_out(), - rhs.rank() - ); - assert_eq!( - self.rank_out(), - rhs.rank(), - "ksk_out output rank: {} != ggsw rank: {}", - self.rank_out(), - rhs.rank() - ); - } - - let (mut tmp_in, scratch1) = scratch.tmp_glwe_fourier(module, lhs.basek(), lhs.k(), lhs.rank()); - let (mut tmp_out, scratch2) = scratch1.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); - - (0..self.rank_in()).for_each(|col_i| { - (0..self.rows()).for_each(|row_j| { - lhs.get_row(module, row_j, col_i, &mut tmp_in); - tmp_out.external_product(module, &tmp_in, rhs, scratch2); - self.set_row(module, row_j, col_i, &tmp_out); - }); - }); - - tmp_out.data.zero(); - - (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| { - (0..self.rank_in()).for_each(|col_j| { - self.set_row(module, row_i, col_j, &tmp_out); - }); - }); - } - - pub fn external_product_inplace>( - &mut self, - module: &Module, - rhs: &GGSWCiphertext, - scratch: &mut Scratch, - ) { - #[cfg(debug_assertions)] - { - assert_eq!( - self.rank_out(), - rhs.rank(), - "ksk_out output rank: {} != ggsw rank: {}", - self.rank_out(), - rhs.rank() - ); - } - - let (mut tmp, scratch1) = scratch.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); - println!("tmp: {}", tmp.size()); - (0..self.rank_in()).for_each(|col_i| { - (0..self.rows()).for_each(|row_j| { - self.get_row(module, row_j, col_i, &mut tmp); - tmp.external_product_inplace(module, rhs, scratch1); - self.set_row(module, row_j, col_i, &tmp); - }); - }); - } -} diff --git a/core/src/lib.rs b/core/src/lib.rs index ef181a1..fa6d009 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -1,35 +1,29 @@ -pub mod automorphism; pub mod blind_rotation; +pub mod dist; pub mod elem; pub mod fourier_glwe; pub mod gglwe; pub mod ggsw; pub mod glwe; -pub mod keys; -pub mod keyswitch_key; pub mod lwe; -pub mod tensor_key; -#[cfg(test)] -mod test_fft64; +pub mod noise; mod utils; -pub use automorphism::*; use backend::Backend; use backend::FFT64; use backend::Module; pub use elem::*; -pub use fourier_glwe::*; -pub use gglwe::*; +pub use fourier_glwe::{FourierGLWECiphertext, FourierGLWESecret}; +pub use gglwe::{GGLWECiphertext, GLWEAutomorphismKey, GLWESwitchingKey, GLWETensorKey}; pub use ggsw::*; -pub use glwe::*; -pub use keyswitch_key::*; +pub use glwe::{GLWECiphertext, GLWEOps, GLWEPacker, GLWEPlaintext, GLWEPublicKey, GLWESecret}; +pub(crate) use glwe::{GLWECiphertextToMut, GLWECiphertextToRef}; pub use lwe::*; -pub use tensor_key::*; pub use backend::Scratch; pub use backend::ScratchOwned; -use crate::keys::SecretDistribution; +use crate::dist::Distribution; pub(crate) const SIX_SIGMA: f64 = 6.0; @@ -62,7 +56,8 @@ pub trait ScratchCore { k: usize, rank: usize, ) -> (FourierGLWECiphertext<&mut [u8], B>, &mut Self); - fn tmp_sk(&mut self, module: &Module, rank: usize) -> (GLWESecret<&mut [u8], B>, &mut Self); + fn tmp_sk(&mut self, module: &Module, rank: usize) -> (GLWESecret<&mut [u8]>, &mut Self); + fn tmp_fourier_sk(&mut self, module: &Module, rank: usize) -> (FourierGLWESecret<&mut [u8], B>, &mut Self); fn tmp_glwe_pk( &mut self, module: &Module, @@ -88,7 +83,7 @@ pub trait ScratchCore { rows: usize, digits: usize, rank: usize, - ) -> (TensorKey<&mut [u8], B>, &mut Self); + ) -> (GLWETensorKey<&mut [u8], B>, &mut Self); fn tmp_autokey( &mut self, module: &Module, @@ -97,7 +92,7 @@ pub trait ScratchCore { rows: usize, digits: usize, rank: usize, - ) -> (AutomorphismKey<&mut [u8], B>, &mut Self); + ) -> (GLWEAutomorphismKey<&mut [u8], B>, &mut Self); } impl ScratchCore for Scratch { @@ -194,22 +189,31 @@ impl ScratchCore for Scratch { ( GLWEPublicKey { data, - dist: SecretDistribution::NONE, + dist: Distribution::NONE, }, scratch, ) } - fn tmp_sk(&mut self, module: &Module, rank: usize) -> (GLWESecret<&mut [u8], FFT64>, &mut Self) { + fn tmp_sk(&mut self, module: &Module, rank: usize) -> (GLWESecret<&mut [u8]>, &mut Self) { let (data, scratch) = self.tmp_scalar_znx(module, rank); - let (data_fourier, scratch1) = scratch.tmp_scalar_znx_dft(module, rank); ( GLWESecret { data, - data_fourier, - dist: SecretDistribution::NONE, + dist: Distribution::NONE, }, - scratch1, + scratch, + ) + } + + fn tmp_fourier_sk(&mut self, module: &Module, rank: usize) -> (FourierGLWESecret<&mut [u8], FFT64>, &mut Self) { + let (data, scratch) = self.tmp_scalar_znx_dft(module, rank); + ( + FourierGLWESecret { + data, + dist: Distribution::NONE, + }, + scratch, ) } @@ -235,9 +239,9 @@ impl ScratchCore for Scratch { rows: usize, digits: usize, rank: usize, - ) -> (AutomorphismKey<&mut [u8], FFT64>, &mut Self) { + ) -> (GLWEAutomorphismKey<&mut [u8], FFT64>, &mut Self) { let (data, scratch) = self.tmp_glwe_ksk(module, basek, k, rows, digits, rank, rank); - (AutomorphismKey { key: data, p: 0 }, scratch) + (GLWEAutomorphismKey { key: data, p: 0 }, scratch) } fn tmp_tsk( @@ -248,7 +252,7 @@ impl ScratchCore for Scratch { rows: usize, digits: usize, rank: usize, - ) -> (TensorKey<&mut [u8], FFT64>, &mut Self) { + ) -> (GLWETensorKey<&mut [u8], FFT64>, &mut Self) { let mut keys: Vec> = Vec::new(); let pairs: usize = (((rank + 1) * rank) >> 1).max(1); @@ -264,6 +268,6 @@ impl ScratchCore for Scratch { scratch = s; keys.push(gglwe); } - (TensorKey { keys }, scratch) + (GLWETensorKey { keys }, scratch) } } diff --git a/core/src/lwe/mod.rs b/core/src/lwe/mod.rs new file mode 100644 index 0000000..d91ce80 --- /dev/null +++ b/core/src/lwe/mod.rs @@ -0,0 +1,3 @@ +pub mod secret; + +pub use secret::LWESecret; diff --git a/core/src/lwe.rs b/core/src/lwe/secret.rs similarity index 73% rename from core/src/lwe.rs rename to core/src/lwe/secret.rs index 3f0d749..90776a7 100644 --- a/core/src/lwe.rs +++ b/core/src/lwe/secret.rs @@ -1,18 +1,18 @@ use backend::{ScalarZnx, ZnxInfos, ZnxZero}; use sampling::source::Source; -use crate::SecretDistribution; +use crate::Distribution; pub struct LWESecret { pub(crate) data: ScalarZnx, - pub(crate) dist: SecretDistribution, + pub(crate) dist: Distribution, } impl LWESecret> { pub fn alloc(n: usize) -> Self { Self { data: ScalarZnx::new(n, 1), - dist: SecretDistribution::NONE, + dist: Distribution::NONE, } } } @@ -34,31 +34,31 @@ impl LWESecret { impl + AsMut<[u8]>> LWESecret { pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) { self.data.fill_ternary_prob(0, prob, source); - self.dist = SecretDistribution::TernaryProb(prob); + self.dist = Distribution::TernaryProb(prob); } pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) { self.data.fill_ternary_hw(0, hw, source); - self.dist = SecretDistribution::TernaryFixed(hw); + self.dist = Distribution::TernaryFixed(hw); } pub fn fill_binary_prob(&mut self, prob: f64, source: &mut Source) { self.data.fill_binary_prob(0, prob, source); - self.dist = SecretDistribution::BinaryProb(prob); + self.dist = Distribution::BinaryProb(prob); } pub fn fill_binary_hw(&mut self, hw: usize, source: &mut Source) { self.data.fill_binary_hw(0, hw, source); - self.dist = SecretDistribution::BinaryFixed(hw); + self.dist = Distribution::BinaryFixed(hw); } pub fn fill_binary_block(&mut self, block_size: usize, source: &mut Source) { self.data.fill_binary_block(0, block_size, source); - self.dist = SecretDistribution::BinaryBlock(block_size); + self.dist = Distribution::BinaryBlock(block_size); } pub fn fill_zero(&mut self) { self.data.zero(); - self.dist = SecretDistribution::ZERO; + self.dist = Distribution::ZERO; } } diff --git a/core/src/test_fft64/mod.rs b/core/src/noise.rs similarity index 97% rename from core/src/test_fft64/mod.rs rename to core/src/noise.rs index 4c0f513..cfc7698 100644 --- a/core/src/test_fft64/mod.rs +++ b/core/src/noise.rs @@ -1,9 +1,4 @@ -mod automorphism_key; -mod gglwe; -mod ggsw; -mod glwe_fourier; -mod tensor_key; - +#[allow(dead_code)] pub(crate) fn var_noise_gglwe_product( n: f64, basek: usize, @@ -34,6 +29,7 @@ pub(crate) fn var_noise_gglwe_product( noise } +#[allow(dead_code)] pub(crate) fn log2_std_noise_gglwe_product( n: f64, basek: usize, @@ -62,6 +58,7 @@ pub(crate) fn log2_std_noise_gglwe_product( noise.log2().min(-1.0).max(-(a_logq as f64)) // max noise is [-2^{-1}, 2^{-1}] } +#[allow(dead_code)] pub(crate) fn noise_ggsw_product( n: f64, basek: usize, @@ -94,6 +91,7 @@ pub(crate) fn noise_ggsw_product( noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}] } +#[allow(dead_code)] pub(crate) fn noise_ggsw_keyswitch( n: f64, basek: usize, From 829b8be610bb518765be3eb857e58ffe09683c2d Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Fri, 13 Jun 2025 20:45:24 +0200 Subject: [PATCH 08/23] wip on BR + added enc/dec for LWE --- backend/src/lib.rs | 4 +- backend/src/mat_znx_dft_ops.rs | 136 +++++++- backend/src/scalar_znx.rs | 4 +- backend/src/vec_znx.rs | 10 +- core/benches/external_product_glwe_fft64.rs | 12 +- core/src/blind_rotation/ccgi.rs | 139 ++++++++ core/src/blind_rotation/key.rs | 39 ++- core/src/blind_rotation/mod.rs | 4 + core/src/blind_rotation/test_fft64/cggi.rs | 85 +++++ core/src/blind_rotation/test_fft64/mod.rs | 1 + core/src/elem.rs | 2 +- core/src/fourier_glwe.rs | 321 ------------------ core/src/fourier_glwe/ciphertext.rs | 6 +- core/src/fourier_glwe/decryption.rs | 4 +- core/src/fourier_glwe/encryption.rs | 6 +- core/src/fourier_glwe/external_product.rs | 8 +- .../test_fft64/external_product.rs | 10 +- core/src/fourier_glwe/test_fft64/keyswitch.rs | 10 +- core/src/gglwe.rs | 236 ------------- core/src/gglwe/automorphism.rs | 52 +-- core/src/gglwe/ciphertext.rs | 6 +- core/src/gglwe/encryption.rs | 4 +- core/src/gglwe/test_fft64/automorphism_key.rs | 2 +- core/src/gglwe/test_fft64/gglwe.rs | 1 - core/src/ggsw/ciphertext.rs | 6 +- core/src/ggsw/test_fft64/ggsw.rs | 2 +- core/src/glwe/ciphertext.rs | 22 +- core/src/glwe/decryption.rs | 18 +- core/src/glwe/encryption.rs | 10 +- core/src/glwe/external_product.rs | 14 +- core/src/glwe/keyswitch.rs | 8 +- core/src/glwe/plaintext.rs | 2 +- core/src/glwe/test_fft64/automorphism.rs | 10 +- core/src/glwe/test_fft64/external_product.rs | 4 +- core/src/glwe/test_fft64/keyswitch.rs | 10 +- core/src/glwe/test_fft64/packing.rs | 2 +- core/src/glwe/test_fft64/trace.rs | 2 +- core/src/lib.rs | 9 +- core/src/lwe/ciphertext.rs | 77 +++++ core/src/lwe/decryption.rs | 21 ++ core/src/lwe/encryption.rs | 35 ++ core/src/lwe/mod.rs | 6 + core/src/lwe/plaintext.rs | 73 ++++ 43 files changed, 745 insertions(+), 688 deletions(-) create mode 100644 core/src/blind_rotation/test_fft64/cggi.rs create mode 100644 core/src/blind_rotation/test_fft64/mod.rs delete mode 100644 core/src/fourier_glwe.rs delete mode 100644 core/src/gglwe.rs create mode 100644 core/src/lwe/ciphertext.rs create mode 100644 core/src/lwe/decryption.rs create mode 100644 core/src/lwe/encryption.rs create mode 100644 core/src/lwe/plaintext.rs diff --git a/backend/src/lib.rs b/backend/src/lib.rs index dcf4325..09e5556 100644 --- a/backend/src/lib.rs +++ b/backend/src/lib.rs @@ -103,7 +103,7 @@ fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec { /// Size of T * size msut be a multiple of [DEFAULTALIGN]. pub fn alloc_aligned_custom(size: usize, align: usize) -> Vec { assert_eq!( - (size * size_of::()) % align, + (size * size_of::()) % (align/ size_of::()), 0, "size={} must be a multiple of align={}", size, @@ -121,7 +121,7 @@ pub fn alloc_aligned_custom(size: usize, align: usize) -> Vec { /// of [DEFAULTALIGN]/size_of::() that is equal or greater to `size`. pub fn alloc_aligned(size: usize) -> Vec { alloc_aligned_custom::( - size + (size % (DEFAULTALIGN / size_of::())), + size + (DEFAULTALIGN - (size % (DEFAULTALIGN / size_of::()))), DEFAULTALIGN, ) } diff --git a/backend/src/mat_znx_dft_ops.rs b/backend/src/mat_znx_dft_ops.rs index 5f08a89..b48cb1a 100644 --- a/backend/src/mat_znx_dft_ops.rs +++ b/backend/src/mat_znx_dft_ops.rs @@ -82,6 +82,12 @@ pub trait MatZnxDftOps { where A: MatZnxToMut; + /// Multiplies A by (X^{k} - 1). + fn mat_znx_dft_mul_x_pow_minus_one_add_inplace(&self, k: i64, res: &mut R, a: &A, scratch: &mut Scratch) + where + R: MatZnxToMut, + A: MatZnxToRef; + /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft]. /// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. /// @@ -212,7 +218,7 @@ impl MatZnxDftOps for Module { self.mat_znx_dft_set_row(&mut res, row_i, col_j, &tmp_1); }); - }) + }); } fn mat_znx_dft_mul_x_pow_minus_one_inplace(&self, k: i64, a: &mut A, scratch: &mut Scratch) @@ -249,7 +255,52 @@ impl MatZnxDftOps for Module { self.mat_znx_dft_set_row(&mut a, row_i, col_j, &tmp_1); }); - }) + }); + } + + fn mat_znx_dft_mul_x_pow_minus_one_add_inplace(&self, k: i64, res: &mut R, a: &A, scratch: &mut Scratch) + where + R: MatZnxToMut, + A: MatZnxToRef, + { + let mut res: MatZnxDft<&mut [u8], FFT64> = res.to_mut(); + let a: MatZnxDft<&[u8], FFT64> = a.to_ref(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + } + + let (mut xpm1_dft, scratch1) = scratch.tmp_scalar_znx_dft(self, 1); + + { + let (mut xpm1, _) = scratch1.tmp_scalar_znx(self, 1); + xpm1.data[0] = 1; + self.vec_znx_rotate_inplace(k, &mut xpm1, 0); + self.svp_prepare(&mut xpm1_dft, 0, &xpm1, 0); + } + + let (mut tmp_0, scratch2) = scratch1.tmp_vec_znx_dft(self, a.cols_out(), a.size()); + let (mut tmp_1, _) = scratch2.tmp_vec_znx_dft(self, a.cols_out(), a.size()); + + (0..a.rows()).for_each(|row_i| { + (0..a.cols_in()).for_each(|col_j| { + self.mat_znx_dft_get_row(&mut tmp_0, &a, row_i, col_j); + + (0..tmp_0.cols()).for_each(|i| { + self.svp_apply(&mut tmp_1, i, &xpm1_dft, 0, &tmp_0, i); + self.vec_znx_dft_sub_ab_inplace(&mut tmp_1, i, &tmp_0, i); + }); + + self.mat_znx_dft_get_row(&mut tmp_0, &res, row_i, col_j); + + (0..tmp_0.cols()).for_each(|i| { + self.vec_znx_dft_add_inplace(&mut tmp_0, i, &tmp_1, i); + }); + + self.mat_znx_dft_set_row(&mut res, row_i, col_j, &tmp_0); + }); + }); } fn mat_znx_dft_set_row(&self, res: &mut R, res_row: usize, res_col_in: usize, a: &A) @@ -845,7 +896,6 @@ mod tests { (0..cols_out).for_each(|j| { module.vec_znx_idft(&mut tmp_big, 0, &tmp_dft, j, scratch.borrow()); - // module.vec_znx_big_normalize(basek, &mut want, j, &tmp_big, 0, scratch.borrow()); module.vec_znx_big_normalize(basek, &mut tmp, 0, &tmp_big, 0, scratch.borrow()); module.vec_znx_rotate(k, &mut want, j, &tmp, 0); module.vec_znx_sub_ab_inplace(&mut want, j, &tmp, 0); @@ -863,4 +913,84 @@ mod tests { }); }); } + + #[test] + fn mat_znx_dft_mul_x_pow_minus_one_add_inplace() { + let log_n: i32 = 5; + let n: usize = 1 << log_n; + + let module: Module = Module::::new(n); + let basek: usize = 8; + let rows: usize = 2; + let cols_in: usize = 2; + let cols_out: usize = 2; + let size: usize = 4; + + let mut scratch: ScratchOwned = ScratchOwned::new(module.mat_znx_dft_mul_x_pow_minus_one_scratch_space(size, cols_out)); + + let mut mat_want: MatZnxDft, FFT64> = module.new_mat_znx_dft(rows, cols_in, cols_out, size); + let mut mat_have: MatZnxDft, FFT64> = module.new_mat_znx_dft(rows, cols_in, cols_out, size); + + let mut tmp: VecZnx> = module.new_vec_znx(1, size); + let mut tmp_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(cols_out, size); + + let mut source: Source = Source::new([0u8; 32]); + + (0..mat_have.rows()).for_each(|row_i| { + (0..mat_have.cols_in()).for_each(|col_i| { + (0..cols_out).for_each(|j| { + tmp.fill_uniform(basek, 0, size, &mut source); + module.vec_znx_dft(1, 0, &mut tmp_dft, j, &tmp, 0); + }); + + module.mat_znx_dft_set_row(&mut mat_have, row_i, col_i, &tmp_dft); + }); + }); + + (0..mat_want.rows()).for_each(|row_i| { + (0..mat_want.cols_in()).for_each(|col_i| { + (0..cols_out).for_each(|j| { + tmp.fill_uniform(basek, 0, size, &mut source); + module.vec_znx_dft(1, 0, &mut tmp_dft, j, &tmp, 0); + }); + + module.mat_znx_dft_set_row(&mut mat_want, row_i, col_i, &tmp_dft); + }); + }); + + let k: i64 = 1; + + module.mat_znx_dft_mul_x_pow_minus_one_add_inplace(k, &mut mat_have, &mat_want, scratch.borrow()); + + let mut have: VecZnx> = module.new_vec_znx(cols_out, size); + let mut want: VecZnx> = module.new_vec_znx(cols_out, size); + let mut tmp_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, size); + + let mut source: Source = Source::new([0u8; 32]); + (0..mat_want.rows()).for_each(|row_i| { + (0..mat_want.cols_in()).for_each(|col_i| { + module.mat_znx_dft_get_row(&mut tmp_dft, &mat_want, row_i, col_i); + + (0..cols_out).for_each(|j| { + module.vec_znx_idft(&mut tmp_big, 0, &tmp_dft, j, scratch.borrow()); + module.vec_znx_big_normalize(basek, &mut tmp, 0, &tmp_big, 0, scratch.borrow()); + module.vec_znx_rotate(k, &mut want, j, &tmp, 0); + module.vec_znx_sub_ab_inplace(&mut want, j, &tmp, 0); + + tmp.fill_uniform(basek, 0, size, &mut source); + module.vec_znx_add_inplace(&mut want, j, &tmp, 0); + module.vec_znx_normalize_inplace(basek, &mut want, j, scratch.borrow()); + }); + + module.mat_znx_dft_get_row(&mut tmp_dft, &mat_have, row_i, col_i); + + (0..cols_out).for_each(|j| { + module.vec_znx_idft(&mut tmp_big, 0, &tmp_dft, j, scratch.borrow()); + module.vec_znx_big_normalize(basek, &mut have, j, &tmp_big, 0, scratch.borrow()); + }); + + assert_eq!(have, want) + }); + }); + } } diff --git a/backend/src/scalar_znx.rs b/backend/src/scalar_znx.rs index 09e7292..0a5bb64 100644 --- a/backend/src/scalar_znx.rs +++ b/backend/src/scalar_znx.rs @@ -91,9 +91,9 @@ impl + AsRef<[u8]>> ScalarZnx { } pub fn fill_binary_block(&mut self, col: usize, block_size: usize, source: &mut Source) { - assert!(block_size & (block_size - 1) == 0); + assert!(self.n() % block_size == 0); let max_idx: u64 = (block_size + 1) as u64; - let mask_idx: u64 = (2 * block_size - 1) as u64; + let mask_idx: u64 = (1<<((u64::BITS - max_idx.leading_zeros())as u64)) - 1 ; for block in self.at_mut(col, 0).chunks_mut(block_size) { let idx: usize = source.next_u64n(max_idx, mask_idx) as usize; if idx != block_size { diff --git a/backend/src/vec_znx.rs b/backend/src/vec_znx.rs index 84b9a84..8213d5e 100644 --- a/backend/src/vec_znx.rs +++ b/backend/src/vec_znx.rs @@ -177,7 +177,7 @@ impl>> VecZnx { n * cols * size * size_of::() } - pub(crate) fn new(n: usize, cols: usize, size: usize) -> Self { + pub fn new(n: usize, cols: usize, size: usize) -> Self { let data = alloc_aligned::(Self::bytes_of::(n, cols, size)); Self { data: data.into(), @@ -243,7 +243,13 @@ fn normalize_tmp_bytes(n: usize) -> usize { n * std::mem::size_of::() } -#[allow(dead_code)] +impl + AsMut<[u8]>> VecZnx{ + pub fn normalize(&mut self, basek: usize, a_col: usize, tmp_bytes: &mut [u8]){ + normalize(basek, self, a_col, tmp_bytes); + } +} + + fn normalize + AsRef<[u8]>>(basek: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]) { let n: usize = a.n(); diff --git a/core/benches/external_product_glwe_fft64.rs b/core/benches/external_product_glwe_fft64.rs index c48c626..fd6508a 100644 --- a/core/benches/external_product_glwe_fft64.rs +++ b/core/benches/external_product_glwe_fft64.rs @@ -26,7 +26,7 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) { let rank: usize = p.rank; let digits: usize = 1; - let rows: usize = (p.k_ct_in + p.basek - 1) / p.basek; + let rows: usize = 1; //(p.k_ct_in + p.basek - 1) / p.basek; let sigma: f64 = 3.2; let mut ct_ggsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank); @@ -81,11 +81,11 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) { } let params_set: Vec = vec![Params { - log_n: 10, - basek: 7, - k_ct_in: 27, - k_ct_out: 27, - k_ggsw: 27, + log_n: 11, + basek: 22, + k_ct_in: 44, + k_ct_out: 44, + k_ggsw: 54, rank: 1, }]; diff --git a/core/src/blind_rotation/ccgi.rs b/core/src/blind_rotation/ccgi.rs index e69de29..341045e 100644 --- a/core/src/blind_rotation/ccgi.rs +++ b/core/src/blind_rotation/ccgi.rs @@ -0,0 +1,139 @@ +use std::time::Instant; + +use backend::{MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnxDftOps, Scratch, VecZnxDftOps, VecZnxOps, ZnxView, ZnxViewMut, ZnxZero, FFT64}; +use itertools::izip; + +use crate::{ + GGSWCiphertext, GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, GLWEPlaintext, Infos, LWECiphertext, + ScratchCore, blind_rotation::key::BlindRotationKeyCGGI, lwe::ciphertext::LWECiphertextToRef, +}; + +pub fn cggi_blind_rotate_scratch_space( + module: &Module, + basek: usize, + k_lut: usize, + k_brk: usize, + rows: usize, + rank: usize, +) -> usize { + let size = k_brk.div_ceil(basek); + GGSWCiphertext::, FFT64>::bytes_of(module, basek, k_brk, rows, 1, rank) + + (module.mat_znx_dft_mul_x_pow_minus_one_scratch_space(size, rank + 1) + | GLWECiphertext::external_product_inplace_scratch_space(module, basek, k_lut, k_brk, 1, rank)) +} + +pub fn cggi_blind_rotate( + module: &Module, + res: &mut GLWECiphertext, + lwe: &LWECiphertext, + lut: &GLWEPlaintext, + brk: &BlindRotationKeyCGGI, + scratch: &mut Scratch, +) where + DataRes: AsRef<[u8]> + AsMut<[u8]>, + DataIn: AsRef<[u8]>, + DataLUT: AsRef<[u8]>, +{ + + println!("{}", lwe.n()); + + let mut lwe_2n: Vec = vec![0i64; lwe.n() + 1]; // TODO: from scratch space + let mut out_mut: GLWECiphertext<&mut [u8]> = res.to_mut(); + let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref(); + let lut_ref: GLWECiphertext<&[u8]> = lut.to_ref(); + + let cols = out_mut.rank()+1; + + mod_switch_2n(module, &mut lwe_2n, &lwe_ref); + + let a: &[i64] = &lwe_2n[1..]; + let b: i64 = lwe_2n[0]; + + out_mut.data.zero(); + + // Initialize out to X^{b} * LUT(X) + module.vec_znx_rotate(b, &mut out_mut.data, 0, &lut_ref.data, 0); + + let block_size: usize = brk.block_size(); + + // ACC + [sum DFT(X^ai -1) * (DFT(ACC) x BRKi)] + + let (mut acc_dft, scratch1) = scratch.tmp_glwe_fourier(module, brk.basek(), out_mut.k(), out_mut.rank()); + let (mut acc_add_dft, scratch2) = scratch1.tmp_glwe_fourier(module, brk.basek(), out_mut.k(), out_mut.rank()); + let (mut vmp_res, scratch3) = scratch2.tmp_vec_znx_dft(module, acc_dft.rank()+1, acc_dft.size()); + let (mut xai_minus_one, scratch4) = scratch3.tmp_scalar_znx(module, 1); + let (mut xai_minus_one_dft, scratch5) = scratch4.tmp_scalar_znx_dft(module, 1); + + let start: Instant = Instant::now(); + izip!( + a.chunks_exact(block_size), + brk.data.chunks_exact(block_size) + ) + .for_each(|(ai, ski)| { + + out_mut.dft(module, &mut acc_dft); + acc_add_dft.data.zero(); + + izip!(ai.iter(), ski.iter()) + .enumerate() + .for_each(|(i, (aii, skii))| { + + // vmp_res = DFT(acc) * BRK[i] + module.vmp_apply(&mut vmp_res, &acc_dft.data, &skii.data, scratch5); + + // DFT(X^ai -1) + xai_minus_one.zero(); + xai_minus_one.at_mut(0, 0)[0] = 1; + module.vec_znx_rotate_inplace(*aii, &mut xai_minus_one, 0); + xai_minus_one.at_mut(0, 0)[0] -= 1; + module.svp_prepare(&mut xai_minus_one_dft, 0, &xai_minus_one, 0); + + // DFT(X^ai -1) * (DFT(acc) * BRK[i]) + (0..cols).for_each(|i|{ + module.svp_apply_inplace(&mut vmp_res, i, &xai_minus_one_dft, 0); + module.vec_znx_dft_add_inplace(&mut acc_add_dft.data, i, &vmp_res, i); + }); + + }); + + acc_add_dft.idft(module, &mut out_mut, scratch5); + }); + let duration: std::time::Duration = start.elapsed(); + println!("external products: {} us", duration.as_micros()); +} + +fn mod_switch_2n(module: &Module, res: &mut [i64], lwe: &LWECiphertext<&[u8]>) { + let basek: usize = lwe.basek(); + + let log2n: usize = module.log_n() + 1; + + res.copy_from_slice(&lwe.data.at(0, 0)); + + if basek > log2n { + let diff: usize = basek - log2n; + res.iter_mut().for_each(|x| { + *x = div_signed_by_pow2(x, diff); + }) + } else { + let rem: usize = basek - (log2n % basek); + let size: usize = log2n.div_ceil(basek); + (1..size).for_each(|i| { + if i == size - 1 && rem != basek { + let k_rem: usize = basek - rem; + izip!(lwe.data.at(0, i).iter(), res.iter_mut()).for_each(|(x, y)| { + *y = (*y << k_rem) + (x >> rem); + }); + } else { + izip!(lwe.data.at(0, i).iter(), res.iter_mut()).for_each(|(x, y)| { + *y = (*y << basek) + x; + }); + } + }) + } +} + +#[inline(always)] +fn div_signed_by_pow2(x: &i64, k: usize) -> i64 { + let bias: i64 = (1 << k) - 1; + (x + ((x >> 63) & bias)) >> k +} diff --git a/core/src/blind_rotation/key.rs b/core/src/blind_rotation/key.rs index 270080d..9f23c61 100644 --- a/core/src/blind_rotation/key.rs +++ b/core/src/blind_rotation/key.rs @@ -1,22 +1,22 @@ use backend::{Backend, FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxToRef, Scratch, ZnxView, ZnxViewMut}; use sampling::source::Source; -use crate::{Distribution, FourierGLWESecret, GGSWCiphertext, GLWEAutomorphismKey, LWESecret}; +use crate::{Distribution, FourierGLWESecret, GGSWCiphertext, Infos, LWESecret}; pub struct BlindRotationKeyCGGI { pub(crate) data: Vec, B>>, pub(crate) dist: Distribution, } -pub struct BlindRotationKeyFHEW { - pub(crate) data: Vec, B>>, - pub(crate) auto: Vec, B>>, -} +// pub struct BlindRotationKeyFHEW { +// pub(crate) data: Vec, B>>, +// pub(crate) auto: Vec, B>>, +//} impl BlindRotationKeyCGGI { - pub fn allocate(module: &Module, lwe_degree: usize, basek: usize, k: usize, rows: usize, rank: usize) -> Self { - let mut data: Vec, FFT64>> = Vec::with_capacity(lwe_degree); - (0..lwe_degree).for_each(|_| data.push(GGSWCiphertext::alloc(module, basek, k, rows, 1, rank))); + pub fn allocate(module: &Module, n_lwe: usize, basek: usize, k: usize, rows: usize, rank: usize) -> Self { + let mut data: Vec, FFT64>> = Vec::with_capacity(n_lwe); + (0..n_lwe).for_each(|_| data.push(GGSWCiphertext::alloc(module, basek, k, rows, 1, rank))); Self { data, dist: Distribution::NONE, @@ -61,4 +61,27 @@ impl BlindRotationKeyCGGI { ggsw.encrypt_sk(module, &pt, sk_glwe, source_xa, source_xe, sigma, scratch); }) } + + pub(crate) fn block_size(&self) -> usize { + match self.dist { + Distribution::BinaryBlock(value) => value, + _ => 1, + } + } + + pub(crate) fn rows(&self) -> usize { + self.data[0].rows() + } + + pub(crate) fn k(&self) -> usize { + self.data[0].k() + } + + pub(crate) fn rank(&self) -> usize { + self.data[0].rank() + } + + pub(crate) fn basek(&self) -> usize { + self.data[0].basek() + } } diff --git a/core/src/blind_rotation/mod.rs b/core/src/blind_rotation/mod.rs index c531781..63fb3fd 100644 --- a/core/src/blind_rotation/mod.rs +++ b/core/src/blind_rotation/mod.rs @@ -1,2 +1,6 @@ // pub mod cggi; +pub mod ccgi; pub mod key; + +#[cfg(test)] +pub mod test_fft64; diff --git a/core/src/blind_rotation/test_fft64/cggi.rs b/core/src/blind_rotation/test_fft64/cggi.rs new file mode 100644 index 0000000..8a9246f --- /dev/null +++ b/core/src/blind_rotation/test_fft64/cggi.rs @@ -0,0 +1,85 @@ +use core::time; +use std::time::Instant; + +use backend::{Encoding, Module, ScratchOwned, FFT64}; +use sampling::source::Source; + +use crate::{ + blind_rotation::{ccgi::{cggi_blind_rotate, cggi_blind_rotate_scratch_space}, key::BlindRotationKeyCGGI}, lwe::LWEPlaintext, FourierGLWESecret, GLWECiphertext, GLWEPlaintext, GLWESecret, LWECiphertext, LWESecret +}; + +#[test] +fn blind_rotation() { + let module: Module = Module::::new(2048); + let basek: usize = 17; + + let n_lwe: usize = 1071; + + let k_lwe: usize = 22; + let k_brk: usize = 54; + let rows_brk: usize = 1; + let k_lut: usize = 44; + let rank: usize = 1; + let block_size: usize = 7; + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut sk_glwe: GLWESecret> = GLWESecret::alloc(&module, rank); + sk_glwe.fill_ternary_prob(0.5, &mut source_xs); + let sk_glwe_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_glwe); + + let mut sk_lwe: LWESecret> = LWESecret::alloc(n_lwe); + sk_lwe.fill_binary_block(block_size, &mut source_xs); + + let mut scratch: ScratchOwned = ScratchOwned::new(BlindRotationKeyCGGI::generate_from_sk_scratch_space( + &module, basek, k_brk, rank, + ) | cggi_blind_rotate_scratch_space(&module, basek, k_lut, k_brk, rows_brk, rank)); + + let start: Instant = Instant::now(); + let mut brk: BlindRotationKeyCGGI = BlindRotationKeyCGGI::allocate(&module, n_lwe, basek, k_brk, rows_brk, rank); + brk.generate_from_sk( + &module, + &sk_glwe_dft, + &sk_lwe, + &mut source_xa, + &mut source_xe, + 3.2, + scratch.borrow(), + ); + let duration: std::time::Duration = start.elapsed(); + println!("brk-gen: {} ms", duration.as_millis()); + + let mut lwe: LWECiphertext> = LWECiphertext::alloc(n_lwe, basek, k_lwe); + + let mut pt_lwe: LWEPlaintext> = LWEPlaintext::alloc(basek, k_lwe); + + pt_lwe.data.encode_coeff_i64(0, basek, 7, 0, 63, 7); + + println!("{}", pt_lwe.data); + + lwe.encrypt_sk(&pt_lwe, &sk_lwe, &mut source_xa, &mut source_xe, 3.2); + + lwe.decrypt(&mut pt_lwe, &sk_lwe); + + println!("{}", pt_lwe.data); + + let lut: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_lut); + + let mut res: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_lut, rank); + + let start: Instant = Instant::now(); + (0..32).for_each(|i|{ + cggi_blind_rotate(&module, &mut res, &lwe, &lut, &brk, scratch.borrow()); + }); + + let duration: std::time::Duration = start.elapsed(); + println!("blind-rotate: {} ms", duration.as_millis()); + + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_lut); + + res.decrypt(&module , &mut pt, &sk_glwe_dft, scratch.borrow()); + + println!("{}", pt.data); +} diff --git a/core/src/blind_rotation/test_fft64/mod.rs b/core/src/blind_rotation/test_fft64/mod.rs new file mode 100644 index 0000000..1a23dff --- /dev/null +++ b/core/src/blind_rotation/test_fft64/mod.rs @@ -0,0 +1 @@ +pub mod cggi; diff --git a/core/src/elem.rs b/core/src/elem.rs index e659e1e..9a1de39 100644 --- a/core/src/elem.rs +++ b/core/src/elem.rs @@ -1,6 +1,6 @@ use backend::{Backend, Module, ZnxInfos}; -use crate::{FourierGLWECiphertext, div_ceil}; +use crate::FourierGLWECiphertext; pub trait Infos { type Inner: ZnxInfos; diff --git a/core/src/fourier_glwe.rs b/core/src/fourier_glwe.rs deleted file mode 100644 index 0e024a9..0000000 --- a/core/src/fourier_glwe.rs +++ /dev/null @@ -1,321 +0,0 @@ -use backend::{ - Backend, FFT64, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnxDftOps, Scratch, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, - VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, ZnxZero, -}; -use sampling::source::Source; - -use crate::{GGSWCiphertext, GLWECiphertext, GLWEPlaintext, GLWESecret, GLWESwitchingKey, Infos, ScratchCore}; - -pub struct FourierGLWECiphertext { - pub data: VecZnxDft, - pub basek: usize, - pub k: usize, -} - -impl FourierGLWECiphertext, B> { - pub fn alloc(module: &Module, basek: usize, k: usize, rank: usize) -> Self { - Self { - data: module.new_vec_znx_dft(rank + 1, k.div_ceil(basek)), - basek: basek, - k: k, - } - } - - pub fn bytes_of(module: &Module, basek: usize, k: usize, rank: usize) -> usize { - module.bytes_of_vec_znx_dft(rank + 1, k.div_ceil(basek)) - } -} - -impl Infos for FourierGLWECiphertext { - type Inner = VecZnxDft; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn basek(&self) -> usize { - self.basek - } - - fn k(&self) -> usize { - self.k - } -} - -impl FourierGLWECiphertext { - pub fn rank(&self) -> usize { - self.cols() - 1 - } -} - -impl FourierGLWECiphertext, FFT64> { - #[allow(dead_code)] - pub(crate) fn idft_scratch_space(module: &Module, basek: usize, k: usize) -> usize { - module.bytes_of_vec_znx(1, k.div_ceil(basek)) - + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes()) - } - - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { - module.bytes_of_vec_znx(rank + 1, k.div_ceil(basek)) + GLWECiphertext::encrypt_sk_scratch_space(module, basek, k) - } - - pub fn decrypt_scratch_space(module: &Module, basek: usize, k: usize) -> usize { - let size: usize = k.div_ceil(basek); - (module.vec_znx_big_normalize_tmp_bytes() - | module.bytes_of_vec_znx_dft(1, size) - | (module.bytes_of_vec_znx_big(1, size) + module.vec_znx_idft_tmp_bytes())) - + module.bytes_of_vec_znx_big(1, size) - } - - pub fn keyswitch_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_in: usize, - k_ksk: usize, - digits: usize, - rank_in: usize, - rank_out: usize, - ) -> usize { - GLWECiphertext::bytes_of(module, basek, k_out, rank_out) - + GLWECiphertext::keyswitch_from_fourier_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank_in, rank_out) - } - - pub fn keyswitch_inplace_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_ksk: usize, - digits: usize, - rank: usize, - ) -> usize { - Self::keyswitch_scratch_space(module, basek, k_out, k_out, k_ksk, digits, rank, rank) - } - - // WARNING TODO: UPDATE - pub fn external_product_scratch_space( - module: &Module, - basek: usize, - _k_out: usize, - k_in: usize, - k_ggsw: usize, - digits: usize, - rank: usize, - ) -> usize { - let ggsw_size: usize = k_ggsw.div_ceil(basek); - let res_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, ggsw_size); - let in_size: usize = k_in.div_ceil(basek).div_ceil(digits); - let ggsw_size: usize = k_ggsw.div_ceil(basek); - let vmp: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size) - + module.vmp_apply_tmp_bytes(ggsw_size, in_size, in_size, rank + 1, rank + 1, ggsw_size); - let res_small: usize = module.bytes_of_vec_znx(rank + 1, ggsw_size); - let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(); - res_dft + (vmp | (res_small + normalize)) - } - - pub fn external_product_inplace_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_ggsw: usize, - digits: usize, - rank: usize, - ) -> usize { - Self::external_product_scratch_space(module, basek, k_out, k_out, k_ggsw, digits, rank) - } -} - -impl + AsRef<[u8]>> FourierGLWECiphertext { - pub fn encrypt_zero_sk>( - &mut self, - module: &Module, - sk: &GLWESecret, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - scratch: &mut Scratch, - ) { - let (mut tmp_ct, scratch1) = scratch.tmp_glwe_ct(module, self.basek(), self.k(), self.rank()); - tmp_ct.encrypt_zero_sk(module, sk, source_xa, source_xe, sigma, scratch1); - tmp_ct.dft(module, self); - } - - pub fn keyswitch, DataRhs: AsRef<[u8]>>( - &mut self, - module: &Module, - lhs: &FourierGLWECiphertext, - rhs: &GLWESwitchingKey, - scratch: &mut Scratch, - ) { - let (mut tmp_ct, scratch1) = scratch.tmp_glwe_ct(module, self.basek(), self.k(), self.rank()); - tmp_ct.keyswitch_from_fourier(module, lhs, rhs, scratch1); - tmp_ct.dft(module, self); - } - - pub fn keyswitch_inplace>( - &mut self, - module: &Module, - rhs: &GLWESwitchingKey, - scratch: &mut Scratch, - ) { - unsafe { - let self_ptr: *mut FourierGLWECiphertext = self as *mut FourierGLWECiphertext; - self.keyswitch(&module, &*self_ptr, rhs, scratch); - } - } - - pub fn external_product, DataRhs: AsRef<[u8]>>( - &mut self, - module: &Module, - lhs: &FourierGLWECiphertext, - rhs: &GGSWCiphertext, - scratch: &mut Scratch, - ) { - let basek: usize = self.basek(); - - #[cfg(debug_assertions)] - { - assert_eq!(rhs.rank(), lhs.rank()); - assert_eq!(rhs.rank(), self.rank()); - assert_eq!(self.basek(), basek); - assert_eq!(lhs.basek(), basek); - assert_eq!(rhs.n(), module.n()); - assert_eq!(self.n(), module.n()); - assert_eq!(lhs.n(), module.n()); - assert!( - scratch.available() - >= FourierGLWECiphertext::external_product_scratch_space( - module, - self.basek(), - self.k(), - lhs.k(), - rhs.k(), - rhs.digits(), - rhs.rank(), - ) - ); - } - - let cols: usize = rhs.rank() + 1; - let digits = rhs.digits(); - - // Space for VMP result in DFT domain and high precision - let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); - let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, (lhs.size() + digits - 1) / digits); - - { - (0..digits).for_each(|di| { - a_dft.set_size((lhs.size() + di) / digits); - - // Small optimization for digits > 2 - // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then - // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}. - // As such we can ignore the last digits-2 limbs safely of the sum of vmp products. - // It is possible to further ignore the last digits-1 limbs, but this introduce - // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same - // noise is kept with respect to the ideal functionality. - res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize); - - (0..cols).for_each(|col_i| { - module.vec_znx_dft_copy(digits, digits - 1 - di, &mut a_dft, col_i, &lhs.data, col_i); - }); - - if di == 0 { - module.vmp_apply(&mut res_dft, &a_dft, &rhs.data, scratch2); - } else { - module.vmp_apply_add(&mut res_dft, &a_dft, &rhs.data, di, scratch2); - } - }); - } - - // VMP result in high precision - let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft); - - // Space for VMP result normalized - let (mut res_small, scratch2) = scratch1.tmp_vec_znx(module, cols, rhs.size()); - (0..cols).for_each(|i| { - module.vec_znx_big_normalize(basek, &mut res_small, i, &res_big, i, scratch2); - module.vec_znx_dft(1, 0, &mut self.data, i, &res_small, i); - }); - } - - pub fn external_product_inplace>( - &mut self, - module: &Module, - rhs: &GGSWCiphertext, - scratch: &mut Scratch, - ) { - unsafe { - let self_ptr: *mut FourierGLWECiphertext = self as *mut FourierGLWECiphertext; - self.external_product(&module, &*self_ptr, rhs, scratch); - } - } -} - -impl> FourierGLWECiphertext { - pub fn decrypt + AsMut<[u8]>, DataSk: AsRef<[u8]>>( - &self, - module: &Module, - pt: &mut GLWEPlaintext, - sk: &GLWESecret, - scratch: &mut Scratch, - ) { - #[cfg(debug_assertions)] - { - assert_eq!(self.rank(), sk.rank()); - assert_eq!(self.n(), module.n()); - assert_eq!(pt.n(), module.n()); - assert_eq!(sk.n(), module.n()); - } - - let cols = self.rank() + 1; - - let (mut pt_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, self.size()); // TODO optimize size when pt << ct - pt_big.zero(); - - { - (1..cols).for_each(|i| { - let (mut ci_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, self.size()); // TODO optimize size when pt << ct - module.svp_apply(&mut ci_dft, 0, &sk.data_fourier, i - 1, &self.data, i); - let ci_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(ci_dft); - module.vec_znx_big_add_inplace(&mut pt_big, 0, &ci_big, 0); - }); - } - - { - let (mut c0_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, self.size()); - // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e) - module.vec_znx_idft(&mut c0_big, 0, &self.data, 0, scratch_2); - module.vec_znx_big_add_inplace(&mut pt_big, 0, &c0_big, 0); - } - - // pt = norm(BIG(m + e)) - module.vec_znx_big_normalize(self.basek(), &mut pt.data, 0, &mut pt_big, 0, scratch_1); - - pt.basek = self.basek(); - pt.k = pt.k().min(self.k()); - } - - #[allow(dead_code)] - pub(crate) fn idft + AsMut<[u8]>>( - &self, - module: &Module, - res: &mut GLWECiphertext, - scratch: &mut Scratch, - ) { - #[cfg(debug_assertions)] - { - assert_eq!(self.rank(), res.rank()); - assert_eq!(self.basek(), res.basek()) - } - - let min_size: usize = self.size().min(res.size()); - - let (mut res_big, scratch1) = scratch.tmp_vec_znx_big(module, 1, min_size); - - (0..self.rank() + 1).for_each(|i| { - module.vec_znx_idft(&mut res_big, 0, &self.data, i, scratch1); - module.vec_znx_big_normalize(self.basek(), &mut res.data, i, &res_big, 0, scratch1); - }); - } -} diff --git a/core/src/fourier_glwe/ciphertext.rs b/core/src/fourier_glwe/ciphertext.rs index 425191f..a742e31 100644 --- a/core/src/fourier_glwe/ciphertext.rs +++ b/core/src/fourier_glwe/ciphertext.rs @@ -1,6 +1,6 @@ use backend::{Backend, Module, VecZnxDft, VecZnxDftAlloc}; -use crate::{Infos, div_ceil}; +use crate::Infos; pub struct FourierGLWECiphertext { pub data: VecZnxDft, @@ -11,14 +11,14 @@ pub struct FourierGLWECiphertext { impl FourierGLWECiphertext, B> { pub fn alloc(module: &Module, basek: usize, k: usize, rank: usize) -> Self { Self { - data: module.new_vec_znx_dft(rank + 1, div_ceil(k, basek)), + data: module.new_vec_znx_dft(rank + 1, k.div_ceil(basek)), basek: basek, k: k, } } pub fn bytes_of(module: &Module, basek: usize, k: usize, rank: usize) -> usize { - module.bytes_of_vec_znx_dft(rank + 1, div_ceil(k, basek)) + module.bytes_of_vec_znx_dft(rank + 1, k.div_ceil(basek)) } } diff --git a/core/src/fourier_glwe/decryption.rs b/core/src/fourier_glwe/decryption.rs index 882be61..6c18383 100644 --- a/core/src/fourier_glwe/decryption.rs +++ b/core/src/fourier_glwe/decryption.rs @@ -3,11 +3,11 @@ use backend::{ VecZnxDftOps, ZnxZero, }; -use crate::{FourierGLWECiphertext, FourierGLWESecret, GLWECiphertext, GLWEPlaintext, Infos, div_ceil}; +use crate::{FourierGLWECiphertext, FourierGLWESecret, GLWECiphertext, GLWEPlaintext, Infos}; impl FourierGLWECiphertext, FFT64> { pub fn decrypt_scratch_space(module: &Module, basek: usize, k: usize) -> usize { - let size: usize = div_ceil(k, basek); + let size: usize = k.div_ceil(basek); (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size) | (module.bytes_of_vec_znx_big(1, size) + module.vec_znx_idft_tmp_bytes())) diff --git a/core/src/fourier_glwe/encryption.rs b/core/src/fourier_glwe/encryption.rs index d23ff4a..fd08709 100644 --- a/core/src/fourier_glwe/encryption.rs +++ b/core/src/fourier_glwe/encryption.rs @@ -1,17 +1,17 @@ use backend::{FFT64, Module, Scratch, VecZnxAlloc, VecZnxBigScratch, VecZnxDftOps}; use sampling::source::Source; -use crate::{FourierGLWECiphertext, FourierGLWESecret, GLWECiphertext, Infos, ScratchCore, div_ceil}; +use crate::{FourierGLWECiphertext, FourierGLWESecret, GLWECiphertext, Infos, ScratchCore}; impl FourierGLWECiphertext, FFT64> { #[allow(dead_code)] pub(crate) fn idft_scratch_space(module: &Module, basek: usize, k: usize) -> usize { - module.bytes_of_vec_znx(1, div_ceil(k, basek)) + module.bytes_of_vec_znx(1, k.div_ceil(basek)) + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes()) } pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { - module.bytes_of_vec_znx(rank + 1, div_ceil(k, basek)) + GLWECiphertext::encrypt_sk_scratch_space(module, basek, k) + module.bytes_of_vec_znx(rank + 1, k.div_ceil(basek)) + GLWECiphertext::encrypt_sk_scratch_space(module, basek, k) } } diff --git a/core/src/fourier_glwe/external_product.rs b/core/src/fourier_glwe/external_product.rs index 116416b..01a7371 100644 --- a/core/src/fourier_glwe/external_product.rs +++ b/core/src/fourier_glwe/external_product.rs @@ -3,7 +3,7 @@ use backend::{ VecZnxDftAlloc, VecZnxDftOps, }; -use crate::{FourierGLWECiphertext, GGSWCiphertext, Infos, div_ceil}; +use crate::{FourierGLWECiphertext, GGSWCiphertext, Infos}; impl FourierGLWECiphertext, FFT64> { // WARNING TODO: UPDATE @@ -16,10 +16,10 @@ impl FourierGLWECiphertext, FFT64> { digits: usize, rank: usize, ) -> usize { - let ggsw_size: usize = div_ceil(k_ggsw, basek); + let ggsw_size: usize = k_ggsw.div_ceil(basek); let res_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, ggsw_size); - let in_size: usize = div_ceil(div_ceil(k_in, basek), digits); - let ggsw_size: usize = div_ceil(k_ggsw, basek); + let in_size: usize = k_in.div_ceil(basek).div_ceil(digits); + let ggsw_size: usize = k_ggsw.div_ceil(basek); let vmp: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size) + module.vmp_apply_tmp_bytes(ggsw_size, in_size, in_size, rank + 1, rank + 1, ggsw_size); let res_small: usize = module.bytes_of_vec_znx(rank + 1, ggsw_size); diff --git a/core/src/fourier_glwe/test_fft64/external_product.rs b/core/src/fourier_glwe/test_fft64/external_product.rs index 2228d29..80c9c9a 100644 --- a/core/src/fourier_glwe/test_fft64/external_product.rs +++ b/core/src/fourier_glwe/test_fft64/external_product.rs @@ -1,6 +1,6 @@ use crate::{ FourierGLWECiphertext, FourierGLWESecret, GGSWCiphertext, GLWECiphertext, GLWEOps, GLWEPlaintext, GLWESecret, Infos, - div_ceil, noise::noise_ggsw_product, + noise::noise_ggsw_product, }; use backend::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, ZnxViewMut}; use sampling::source::Source; @@ -10,7 +10,7 @@ fn apply() { let log_n: usize = 8; let basek: usize = 12; let k_in: usize = 45; - let digits: usize = div_ceil(k_in, basek); + let digits: usize = k_in.div_ceil(basek); (1..4).for_each(|rank| { (1..digits + 1).for_each(|di| { let k_ggsw: usize = k_in + basek * di; @@ -26,7 +26,7 @@ fn apply_inplace() { let log_n: usize = 8; let basek: usize = 12; let k_ct: usize = 60; - let digits: usize = div_ceil(k_ct, basek); + let digits: usize = k_ct.div_ceil(basek); (1..4).for_each(|rank| { (1..digits + 1).for_each(|di| { let k_ggsw: usize = k_ct + basek * di; @@ -39,7 +39,7 @@ fn apply_inplace() { fn test_apply(log_n: usize, basek: usize, k_out: usize, k_in: usize, k_ggsw: usize, digits: usize, rank: usize, sigma: f64) { let module: Module = Module::::new(1 << log_n); - let rows: usize = div_ceil(k_in, digits * basek); + let rows: usize = k_in.div_ceil(digits * basek); let mut ct_ggsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank); let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_in, rank); @@ -147,7 +147,7 @@ fn test_apply(log_n: usize, basek: usize, k_out: usize, k_in: usize, k_ggsw: usi fn test_apply_inplace(log_n: usize, basek: usize, k_ct: usize, k_ggsw: usize, digits: usize, rank: usize, sigma: f64) { let module: Module = Module::::new(1 << log_n); - let rows: usize = div_ceil(k_ct, digits * basek); + let rows: usize = k_ct.div_ceil(digits * basek); let mut ct_ggsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank); let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); diff --git a/core/src/fourier_glwe/test_fft64/keyswitch.rs b/core/src/fourier_glwe/test_fft64/keyswitch.rs index a459964..50d56ee 100644 --- a/core/src/fourier_glwe/test_fft64/keyswitch.rs +++ b/core/src/fourier_glwe/test_fft64/keyswitch.rs @@ -1,5 +1,5 @@ use crate::{ - FourierGLWECiphertext, FourierGLWESecret, GLWECiphertext, GLWEPlaintext, GLWESecret, GLWESwitchingKey, Infos, div_ceil, + FourierGLWECiphertext, FourierGLWESecret, GLWECiphertext, GLWEPlaintext, GLWESecret, GLWESwitchingKey, Infos, noise::log2_std_noise_gglwe_product, }; use backend::{FFT64, FillUniform, Module, ScratchOwned, Stats, VecZnxOps}; @@ -10,7 +10,7 @@ fn apply() { let log_n: usize = 8; let basek: usize = 12; let k_in: usize = 45; - let digits: usize = div_ceil(k_in, basek); + let digits: usize = k_in.div_ceil(basek); (1..4).for_each(|rank_in| { (1..4).for_each(|rank_out| { (1..digits + 1).for_each(|di| { @@ -31,7 +31,7 @@ fn apply_inplace() { let log_n: usize = 8; let basek: usize = 12; let k_ct: usize = 45; - let digits: usize = div_ceil(k_ct, basek); + let digits: usize = k_ct.div_ceil(basek); (1..4).for_each(|rank| { (1..digits + 1).for_each(|di| { let k_ksk: usize = k_ct + basek * di; @@ -54,7 +54,7 @@ fn test_apply( ) { let module: Module = Module::::new(1 << log_n); - let rows: usize = div_ceil(k_in, basek * digits); + let rows: usize = k_in.div_ceil(basek * digits); let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank_in, rank_out); @@ -152,7 +152,7 @@ fn test_apply( fn test_apply_inplace(log_n: usize, basek: usize, k_ct: usize, k_ksk: usize, digits: usize, rank: usize, sigma: f64) { let module: Module = Module::::new(1 << log_n); - let rows: usize = div_ceil(k_ct, basek * digits); + let rows: usize = k_ct.div_ceil(basek * digits); let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank, rank); let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); diff --git a/core/src/gglwe.rs b/core/src/gglwe.rs deleted file mode 100644 index 66a5238..0000000 --- a/core/src/gglwe.rs +++ /dev/null @@ -1,236 +0,0 @@ -use backend::{ - Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, Module, ScalarZnx, Scratch, VecZnxAlloc, VecZnxDftAlloc, VecZnxOps, - ZnxInfos, ZnxZero, -}; -use sampling::source::Source; - -use crate::{FourierGLWECiphertext, GLWECiphertext, GLWESecret, GetRow, Infos, ScratchCore, SetRow, div_ceil}; - -pub struct GGLWECiphertext { - pub(crate) data: MatZnxDft, - pub(crate) basek: usize, - pub(crate) k: usize, - pub(crate) digits: usize, -} - -impl GGLWECiphertext, B> { - pub fn alloc( - module: &Module, - basek: usize, - k: usize, - rows: usize, - digits: usize, - rank_in: usize, - rank_out: usize, - ) -> Self { - let size: usize = k.div_ceil(basek); - debug_assert!( - size > digits, - "invalid gglwe: ceil(k/basek): {} <= digits: {}", - size, - digits - ); - - debug_assert!( - rows * digits <= size, - "invalid gglwe: rows: {} * digits:{} > ceil(k/basek): {}", - rows, - digits, - size - ); - - Self { - data: module.new_mat_znx_dft(rows, rank_in, rank_out + 1, size), - basek: basek, - k, - digits, - } - } - - pub fn bytes_of( - module: &Module, - basek: usize, - k: usize, - rows: usize, - digits: usize, - rank_in: usize, - rank_out: usize, - ) -> usize { - let size: usize = k.div_ceil(basek); - debug_assert!( - size > digits, - "invalid gglwe: ceil(k/basek): {} <= digits: {}", - size, - digits - ); - - assert!( - rows * digits <= size, - "invalid gglwe: rows: {} * digits:{} > ceil(k/basek): {}", - rows, - digits, - size - ); - - module.bytes_of_mat_znx_dft(rows, rank_in, rank_out + 1, size) - } -} - -impl Infos for GGLWECiphertext { - type Inner = MatZnxDft; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn basek(&self) -> usize { - self.basek - } - - fn k(&self) -> usize { - self.k - } -} - -impl GGLWECiphertext { - pub fn rank(&self) -> usize { - self.data.cols_out() - 1 - } - - pub fn digits(&self) -> usize { - self.digits - } - - pub fn rank_in(&self) -> usize { - self.data.cols_in() - } - - pub fn rank_out(&self) -> usize { - self.data.cols_out() - 1 - } -} - -impl GGLWECiphertext, FFT64> { - pub fn generate_from_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { - let size = k.div_ceil(basek); - GLWECiphertext::encrypt_sk_scratch_space(module, basek, k) - + module.bytes_of_vec_znx(rank + 1, size) - + module.bytes_of_vec_znx(1, size) - + module.bytes_of_vec_znx_dft(rank + 1, size) - } - - pub fn generate_from_pk_scratch_space(_module: &Module, _basek: usize, _k: usize, _rank: usize) -> usize { - unimplemented!() - } -} - -impl + AsRef<[u8]>> GGLWECiphertext { - pub fn encrypt_sk, DataSk: AsRef<[u8]>>( - &mut self, - module: &Module, - pt: &ScalarZnx, - sk: &GLWESecret, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - scratch: &mut Scratch, - ) { - #[cfg(debug_assertions)] - { - assert_eq!(self.rank_in(), pt.cols()); - assert_eq!(self.rank_out(), sk.rank()); - assert_eq!(self.n(), module.n()); - assert_eq!(sk.n(), module.n()); - assert_eq!(pt.n(), module.n()); - assert!( - scratch.available() - >= GGLWECiphertext::generate_from_sk_scratch_space(module, self.basek(), self.k(), self.rank()), - "scratch.available: {} < GGLWECiphertext::generate_from_sk_scratch_space(module, self.rank()={}, \ - self.size()={}): {}", - scratch.available(), - self.rank(), - self.size(), - GGLWECiphertext::generate_from_sk_scratch_space(module, self.basek(), self.k(), self.rank()) - ); - assert!( - self.rows() * self.digits() * self.basek() <= self.k(), - "self.rows() : {} * self.digits() : {} * self.basek() : {} = {} >= self.k() = {}", - self.rows(), - self.digits(), - self.basek(), - self.rows() * self.digits() * self.basek(), - self.k() - ); - } - - let rows: usize = self.rows(); - let digits: usize = self.digits(); - let basek: usize = self.basek(); - let k: usize = self.k(); - let rank_in: usize = self.rank_in(); - let rank_out: usize = self.rank_out(); - - let (mut tmp_pt, scrach_1) = scratch.tmp_glwe_pt(module, basek, k); - let (mut tmp_ct, scrach_2) = scrach_1.tmp_glwe_ct(module, basek, k, rank_out); - let (mut tmp_ct_dft, scratch_3) = scrach_2.tmp_glwe_fourier(module, basek, k, rank_out); - - // For each input column (i.e. rank) produces a GGLWE ciphertext of rank_out+1 columns - // - // Example for ksk rank 2 to rank 3: - // - // (-(a0*s0 + a1*s1 + a2*s2) + s0', a0, a1, a2) - // (-(b0*s0 + b1*s1 + b2*s2) + s0', b0, b1, b2) - // - // Example ksk rank 2 to rank 1 - // - // (-(a*s) + s0, a) - // (-(b*s) + s1, b) - (0..rank_in).for_each(|col_i| { - (0..rows).for_each(|row_i| { - // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt - tmp_pt.data.zero(); // zeroes for next iteration - module.vec_znx_add_scalar_inplace( - &mut tmp_pt.data, - 0, - (digits - 1) + row_i * digits, - pt, - col_i, - ); - module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scratch_3); - - // rlwe encrypt of vec_znx_pt into vec_znx_ct - tmp_ct.encrypt_sk(module, &tmp_pt, sk, source_xa, source_xe, sigma, scratch_3); - - // Switch vec_znx_ct into DFT domain - tmp_ct.dft(module, &mut tmp_ct_dft); - - // Stores vec_znx_dft_ct into thw i-th row of the MatZnxDft - self.set_row(module, row_i, col_i, &tmp_ct_dft); - }); - }); - } -} - -impl> GetRow for GGLWECiphertext { - fn get_row + AsRef<[u8]>>( - &self, - module: &Module, - row_i: usize, - col_j: usize, - res: &mut FourierGLWECiphertext, - ) { - module.mat_znx_dft_get_row(&mut res.data, &self.data, row_i, col_j); - } -} - -impl + AsRef<[u8]>> SetRow for GGLWECiphertext { - fn set_row>( - &mut self, - module: &Module, - row_i: usize, - col_j: usize, - a: &FourierGLWECiphertext, - ) { - module.mat_znx_dft_set_row(&mut self.data, row_i, col_j, &a.data); - } -} diff --git a/core/src/gglwe/automorphism.rs b/core/src/gglwe/automorphism.rs index 5460511..e18e65a 100644 --- a/core/src/gglwe/automorphism.rs +++ b/core/src/gglwe/automorphism.rs @@ -62,24 +62,33 @@ impl + AsRef<[u8]>> GLWEAutomorphismKey { self.rank_out(), rhs.rank_out() ); + assert!( + self.k() <= lhs.k(), + "output k={} cannot be greater than input k={}", + self.k(), + lhs.k() + ) } let cols_out: usize = rhs.rank_out() + 1; - let (mut tmp_dft, scratch1) = scratch.tmp_glwe_fourier(module, lhs.basek(), lhs.k(), lhs.rank()); - (0..self.rank_in()).for_each(|col_i| { (0..self.rows()).for_each(|row_j| { - // Extracts relevant row - lhs.get_row(module, row_j, col_i, &mut tmp_dft); + let (mut tmp_idft_data, scratct1) = scratch.tmp_vec_znx_big(module, cols_out, self.size()); - // Get a VecZnxBig from scratch space - let (mut tmp_idft_data, scratch2) = scratch1.tmp_vec_znx_big(module, cols_out, self.size()); + { + let (mut tmp_dft, scratch2) = scratct1.tmp_glwe_fourier(module, lhs.basek(), lhs.k(), lhs.rank()); - // Switches input outside of DFT - (0..cols_out).for_each(|i| { - module.vec_znx_idft(&mut tmp_idft_data, i, &tmp_dft.data, i, scratch2); - }); + // Extracts relevant row + lhs.get_row(module, row_j, col_i, &mut tmp_dft); + + // Get a VecZnxBig from scratch space + + // Switches input outside of DFT + (0..cols_out).for_each(|i| { + module.vec_znx_idft(&mut tmp_idft_data, i, &tmp_dft.data, i, scratch2); + }); + } // Consumes to small vec znx let mut tmp_idft_small_data: VecZnx<&mut [u8]> = tmp_idft_data.to_vec_znx_small(); @@ -97,20 +106,25 @@ impl + AsRef<[u8]>> GLWEAutomorphismKey { }; // Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) - tmp_idft.keyswitch_inplace(module, &rhs.key, scratch2); + tmp_idft.keyswitch_inplace(module, &rhs.key, scratct1); - // Applies back the automorphism X^{k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) -> (-pi^{-1}_{k'+k}(s)a + s, a) - // and switches back to DFT domain - (0..self.rank_out() + 1).for_each(|i| { - module.vec_znx_automorphism_inplace(lhs.p(), &mut tmp_idft.data, i); - module.vec_znx_dft(1, 0, &mut tmp_dft.data, i, &tmp_idft.data, i); - }); + { + let (mut tmp_dft, _) = scratct1.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); - // Sets back the relevant row - self.set_row(module, row_j, col_i, &tmp_dft); + // Applies back the automorphism X^{k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) -> (-pi^{-1}_{k'+k}(s)a + s, a) + // and switches back to DFT domain + (0..self.rank_out() + 1).for_each(|i| { + module.vec_znx_automorphism_inplace(lhs.p(), &mut tmp_idft.data, i); + module.vec_znx_dft(1, 0, &mut tmp_dft.data, i, &tmp_idft.data, i); + }); + + // Sets back the relevant row + self.set_row(module, row_j, col_i, &tmp_dft); + } }); }); + let (mut tmp_dft, _) = scratch.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); tmp_dft.data.zero(); (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| { diff --git a/core/src/gglwe/ciphertext.rs b/core/src/gglwe/ciphertext.rs index a4c2f1d..340b897 100644 --- a/core/src/gglwe/ciphertext.rs +++ b/core/src/gglwe/ciphertext.rs @@ -1,6 +1,6 @@ use backend::{Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, Module}; -use crate::{FourierGLWECiphertext, GetRow, Infos, SetRow, div_ceil}; +use crate::{FourierGLWECiphertext, GetRow, Infos, SetRow}; pub struct GGLWECiphertext { pub(crate) data: MatZnxDft, @@ -19,7 +19,7 @@ impl GGLWECiphertext, B> { rank_in: usize, rank_out: usize, ) -> Self { - let size: usize = div_ceil(k, basek); + let size: usize = k.div_ceil(basek); debug_assert!( size > digits, "invalid gglwe: ceil(k/basek): {} <= digits: {}", @@ -52,7 +52,7 @@ impl GGLWECiphertext, B> { rank_in: usize, rank_out: usize, ) -> usize { - let size: usize = div_ceil(k, basek); + let size: usize = k.div_ceil(basek); debug_assert!( size > digits, "invalid gglwe: ceil(k/basek): {} <= digits: {}", diff --git a/core/src/gglwe/encryption.rs b/core/src/gglwe/encryption.rs index 7c4838b..6c31b2a 100644 --- a/core/src/gglwe/encryption.rs +++ b/core/src/gglwe/encryption.rs @@ -5,12 +5,12 @@ use sampling::source::Source; use crate::{ FourierGLWESecret, GGLWECiphertext, GLWEAutomorphismKey, GLWECiphertext, GLWESecret, GLWESwitchingKey, GLWETensorKey, Infos, - ScratchCore, SetRow, div_ceil, + ScratchCore, SetRow, }; impl GGLWECiphertext, FFT64> { pub fn generate_from_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { - let size = div_ceil(k, basek); + let size = k.div_ceil(basek); GLWECiphertext::encrypt_sk_scratch_space(module, basek, k) + module.bytes_of_vec_znx(rank + 1, size) + module.bytes_of_vec_znx(1, size) diff --git a/core/src/gglwe/test_fft64/automorphism_key.rs b/core/src/gglwe/test_fft64/automorphism_key.rs index c06f0be..61f6825 100644 --- a/core/src/gglwe/test_fft64/automorphism_key.rs +++ b/core/src/gglwe/test_fft64/automorphism_key.rs @@ -2,7 +2,7 @@ use backend::{FFT64, Module, ScalarZnxOps, ScratchOwned, Stats, VecZnxOps}; use sampling::source::Source; use crate::{ - FourierGLWECiphertext, FourierGLWESecret, GLWEAutomorphismKey, GLWEPlaintext, GLWESecret, GetRow, Infos, div_ceil, + FourierGLWECiphertext, FourierGLWESecret, GLWEAutomorphismKey, GLWEPlaintext, GLWESecret, GetRow, Infos, noise::log2_std_noise_gglwe_product, }; diff --git a/core/src/gglwe/test_fft64/gglwe.rs b/core/src/gglwe/test_fft64/gglwe.rs index 39aad9f..0e3796f 100644 --- a/core/src/gglwe/test_fft64/gglwe.rs +++ b/core/src/gglwe/test_fft64/gglwe.rs @@ -3,7 +3,6 @@ use sampling::source::Source; use crate::{ FourierGLWECiphertext, FourierGLWESecret, GGSWCiphertext, GLWEPlaintext, GLWESecret, GLWESwitchingKey, GetRow, Infos, - div_ceil, noise::{log2_std_noise_gglwe_product, noise_ggsw_product}, }; diff --git a/core/src/ggsw/ciphertext.rs b/core/src/ggsw/ciphertext.rs index ac1fae5..12e9723 100644 --- a/core/src/ggsw/ciphertext.rs +++ b/core/src/ggsw/ciphertext.rs @@ -7,7 +7,7 @@ use sampling::source::Source; use crate::{ FourierGLWECiphertext, FourierGLWESecret, GLWEAutomorphismKey, GLWECiphertext, GLWESwitchingKey, GLWETensorKey, GetRow, - Infos, ScratchCore, SetRow, div_ceil, + Infos, ScratchCore, SetRow, }; pub struct GGSWCiphertext { @@ -17,8 +17,8 @@ pub struct GGSWCiphertext { pub(crate) digits: usize, } -impl GGSWCiphertext, B> { - pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self { +impl GGSWCiphertext, FFT64> { + pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self { let size: usize = k.div_ceil(basek); debug_assert!( size > digits, diff --git a/core/src/ggsw/test_fft64/ggsw.rs b/core/src/ggsw/test_fft64/ggsw.rs index 9219703..714ed2c 100644 --- a/core/src/ggsw/test_fft64/ggsw.rs +++ b/core/src/ggsw/test_fft64/ggsw.rs @@ -6,7 +6,7 @@ use sampling::source::Source; use crate::{ FourierGLWECiphertext, FourierGLWESecret, GGSWCiphertext, GLWEAutomorphismKey, GLWEPlaintext, GLWESecret, GLWESwitchingKey, - GLWETensorKey, GetRow, Infos, div_ceil, + GLWETensorKey, GetRow, Infos, noise::{noise_ggsw_keyswitch, noise_ggsw_product}, }; diff --git a/core/src/glwe/ciphertext.rs b/core/src/glwe/ciphertext.rs index ff634bd..d0fb39c 100644 --- a/core/src/glwe/ciphertext.rs +++ b/core/src/glwe/ciphertext.rs @@ -1,9 +1,6 @@ -use backend::{ - Backend, FFT64, Module, VecZnx, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigScratch, VecZnxDftAlloc, VecZnxDftOps, VecZnxToMut, - VecZnxToRef, -}; +use backend::{Backend, FFT64, Module, VecZnx, VecZnxAlloc, VecZnxDftOps, VecZnxToMut, VecZnxToRef}; -use crate::{FourierGLWECiphertext, GLWEOps, Infos, SetMetaData, div_ceil}; +use crate::{FourierGLWECiphertext, GLWEOps, Infos, SetMetaData}; pub struct GLWECiphertext { pub data: VecZnx, @@ -14,14 +11,14 @@ pub struct GLWECiphertext { impl GLWECiphertext> { pub fn alloc(module: &Module, basek: usize, k: usize, rank: usize) -> Self { Self { - data: module.new_vec_znx(rank + 1, div_ceil(k, basek)), + data: module.new_vec_znx(rank + 1, k.div_ceil(basek)), basek, k, } } pub fn bytes_of(module: &Module, basek: usize, k: usize, rank: usize) -> usize { - module.bytes_of_vec_znx(rank + 1, div_ceil(k, basek)) + module.bytes_of_vec_znx(rank + 1, k.div_ceil(basek)) } } @@ -62,10 +59,13 @@ impl> GLWECiphertext { } } -impl GLWECiphertext> { - pub fn decrypt_scratch_space(module: &Module, basek: usize, k: usize) -> usize { - let size: usize = div_ceil(k, basek); - (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) +impl> GLWECiphertext { + pub fn clone(&self) -> GLWECiphertext> { + GLWECiphertext { + data: self.data.clone(), + basek: self.basek(), + k: self.k(), + } } } diff --git a/core/src/glwe/decryption.rs b/core/src/glwe/decryption.rs index eac91d6..e543963 100644 --- a/core/src/glwe/decryption.rs +++ b/core/src/glwe/decryption.rs @@ -1,16 +1,18 @@ -use backend::{FFT64, Module, ScalarZnxDftOps, Scratch, VecZnxBigOps, VecZnxDftOps, ZnxZero}; +use backend::{ + FFT64, Module, ScalarZnxDftOps, Scratch, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDftAlloc, VecZnxDftOps, + ZnxZero, +}; use crate::{FourierGLWESecret, GLWECiphertext, GLWEPlaintext, Infos}; -impl> GLWECiphertext { - pub fn clone(&self) -> GLWECiphertext> { - GLWECiphertext { - data: self.data.clone(), - basek: self.basek(), - k: self.k(), - } +impl GLWECiphertext> { + pub fn decrypt_scratch_space(module: &Module, basek: usize, k: usize) -> usize { + let size: usize = k.div_ceil(basek); + (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) } +} +impl> GLWECiphertext { pub fn decrypt + AsRef<[u8]>, DataSk: AsRef<[u8]>>( &self, module: &Module, diff --git a/core/src/glwe/encryption.rs b/core/src/glwe/encryption.rs index 3b70d99..b0a7615 100644 --- a/core/src/glwe/encryption.rs +++ b/core/src/glwe/encryption.rs @@ -4,15 +4,15 @@ use backend::{ }; use sampling::source::Source; -use crate::{FourierGLWESecret, GLWECiphertext, GLWEPlaintext, GLWEPublicKey, Infos, SIX_SIGMA, dist::Distribution, div_ceil}; +use crate::{FourierGLWESecret, GLWECiphertext, GLWEPlaintext, GLWEPublicKey, Infos, SIX_SIGMA, dist::Distribution}; impl GLWECiphertext> { pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize) -> usize { - let size: usize = div_ceil(k, basek); + let size: usize = k.div_ceil(basek); module.vec_znx_big_normalize_tmp_bytes() + module.bytes_of_vec_znx_dft(1, size) + module.bytes_of_vec_znx(1, size) } pub fn encrypt_pk_scratch_space(module: &Module, basek: usize, k: usize) -> usize { - let size: usize = div_ceil(k, basek); + let size: usize = k.div_ceil(basek); ((module.bytes_of_vec_znx_dft(1, size) + module.bytes_of_vec_znx_big(1, size)) | module.bytes_of_scalar_znx(1)) + module.bytes_of_scalar_znx_dft(1) + module.vec_znx_big_normalize_tmp_bytes() @@ -71,7 +71,7 @@ impl + AsMut<[u8]>> GLWECiphertext { sigma: f64, scratch: &mut Scratch, ) { - self.encrypt_pk_private( + self.encrypt_pk_private::( module, Some((pt, 0)), pk, @@ -91,7 +91,7 @@ impl + AsMut<[u8]>> GLWECiphertext { sigma: f64, scratch: &mut Scratch, ) { - self.encrypt_pk_private( + self.encrypt_pk_private::, DataPk>( module, None::<(&GLWEPlaintext>, usize)>, pk, diff --git a/core/src/glwe/external_product.rs b/core/src/glwe/external_product.rs index c44ab75..3ebf339 100644 --- a/core/src/glwe/external_product.rs +++ b/core/src/glwe/external_product.rs @@ -2,7 +2,7 @@ use backend::{ FFT64, MatZnxDftOps, MatZnxDftScratch, Module, Scratch, VecZnxBig, VecZnxBigOps, VecZnxDftAlloc, VecZnxDftOps, VecZnxScratch, }; -use crate::{FourierGLWECiphertext, GGSWCiphertext, GLWECiphertext, Infos, div_ceil}; +use crate::{FourierGLWECiphertext, GGSWCiphertext, GLWECiphertext, Infos}; impl GLWECiphertext> { pub fn external_product_scratch_space( @@ -10,14 +10,14 @@ impl GLWECiphertext> { basek: usize, k_out: usize, k_in: usize, - ggsw_k: usize, + k_ggsw: usize, digits: usize, rank: usize, ) -> usize { let res_dft: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank); - let in_size: usize = div_ceil(div_ceil(k_in, basek), digits); - let out_size: usize = div_ceil(k_out, basek); - let ggsw_size: usize = div_ceil(ggsw_k, basek); + let in_size: usize = k_in.div_ceil(basek).div_ceil(digits); + let out_size: usize = k_out.div_ceil(basek); + let ggsw_size: usize = k_ggsw.div_ceil(basek); let vmp: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size) + module.vmp_apply_tmp_bytes( out_size, @@ -35,11 +35,11 @@ impl GLWECiphertext> { module: &Module, basek: usize, k_out: usize, - ggsw_k: usize, + k_ggsw: usize, digits: usize, rank: usize, ) -> usize { - Self::external_product_scratch_space(module, basek, k_out, k_out, ggsw_k, digits, rank) + Self::external_product_scratch_space(module, basek, k_out, k_out, k_ggsw, digits, rank) } } diff --git a/core/src/glwe/keyswitch.rs b/core/src/glwe/keyswitch.rs index eace187..5fb12fe 100644 --- a/core/src/glwe/keyswitch.rs +++ b/core/src/glwe/keyswitch.rs @@ -3,7 +3,7 @@ use backend::{ VecZnxDftOps, ZnxZero, }; -use crate::{FourierGLWECiphertext, GLWECiphertext, GLWESwitchingKey, Infos, div_ceil}; +use crate::{FourierGLWECiphertext, GLWECiphertext, GLWESwitchingKey, Infos}; impl GLWECiphertext> { pub fn keyswitch_scratch_space( @@ -17,9 +17,9 @@ impl GLWECiphertext> { rank_out: usize, ) -> usize { let res_dft: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank_out + 1); - let in_size: usize = div_ceil(div_ceil(k_in, basek), digits); - let out_size: usize = div_ceil(k_out, basek); - let ksk_size: usize = div_ceil(k_ksk, basek); + let in_size: usize = k_in.div_ceil(basek).div_ceil(digits); + let out_size: usize = k_out.div_ceil(basek); + let ksk_size: usize = k_ksk.div_ceil(basek); let ai_dft: usize = module.bytes_of_vec_znx_dft(rank_in, in_size); let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, rank_in, rank_out + 1, ksk_size) + module.bytes_of_vec_znx_dft(rank_in, in_size); diff --git a/core/src/glwe/plaintext.rs b/core/src/glwe/plaintext.rs index 9f24be0..5bebc68 100644 --- a/core/src/glwe/plaintext.rs +++ b/core/src/glwe/plaintext.rs @@ -1,6 +1,6 @@ use backend::{Backend, FFT64, Module, VecZnx, VecZnxAlloc, VecZnxToMut, VecZnxToRef}; -use crate::{GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, GLWEOps, Infos, SetMetaData, div_ceil}; +use crate::{GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, GLWEOps, Infos, SetMetaData}; pub struct GLWEPlaintext { pub data: VecZnx, diff --git a/core/src/glwe/test_fft64/automorphism.rs b/core/src/glwe/test_fft64/automorphism.rs index 963083b..1b46e63 100644 --- a/core/src/glwe/test_fft64/automorphism.rs +++ b/core/src/glwe/test_fft64/automorphism.rs @@ -3,7 +3,7 @@ use backend::{FFT64, FillUniform, Module, ScratchOwned, Stats, VecZnxOps}; use sampling::source::Source; use crate::{ - FourierGLWESecret, GLWEAutomorphismKey, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, div_ceil, + FourierGLWESecret, GLWEAutomorphismKey, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, noise::log2_std_noise_gglwe_product, }; @@ -12,7 +12,7 @@ fn apply_inplace() { let log_n: usize = 8; let basek: usize = 12; let k_ct: usize = 60; - let digits: usize = div_ceil(k_ct, basek); + let digits: usize = k_ct.div_ceil(basek); (1..4).for_each(|rank| { (1..digits + 1).for_each(|di| { let k_ksk: usize = k_ct + basek * di; @@ -27,7 +27,7 @@ fn apply() { let log_n: usize = 8; let basek: usize = 12; let k_in: usize = 60; - let digits: usize = div_ceil(k_in, basek); + let digits: usize = k_in.div_ceil(basek); (1..4).for_each(|rank| { (1..digits + 1).for_each(|di| { let k_ksk: usize = k_in + basek * di; @@ -51,7 +51,7 @@ fn test_automorphism( ) { let module: Module = Module::::new(1 << log_n); - let rows: usize = div_ceil(k_in, basek * digits); + let rows: usize = k_in.div_ceil(basek * digits); let mut autokey: GLWEAutomorphismKey, FFT64> = GLWEAutomorphismKey::alloc(&module, basek, k_ksk, rows, digits, rank); let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_in, rank); @@ -149,7 +149,7 @@ fn test_automorphism_inplace( ) { let module: Module = Module::::new(1 << log_n); - let rows: usize = div_ceil(k_ct, basek * digits); + let rows: usize = k_ct.div_ceil(basek * digits); let mut autokey: GLWEAutomorphismKey, FFT64> = GLWEAutomorphismKey::alloc(&module, basek, k_ksk, rows, digits, rank); let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); diff --git a/core/src/glwe/test_fft64/external_product.rs b/core/src/glwe/test_fft64/external_product.rs index 4ba77c3..aec23b6 100644 --- a/core/src/glwe/test_fft64/external_product.rs +++ b/core/src/glwe/test_fft64/external_product.rs @@ -2,7 +2,7 @@ use backend::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwne use sampling::source::Source; use crate::{ - FourierGLWESecret, GGSWCiphertext, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, div_ceil, noise::noise_ggsw_product, + FourierGLWESecret, GGSWCiphertext, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, noise::noise_ggsw_product, }; #[test] @@ -10,7 +10,7 @@ fn apply() { let log_n: usize = 8; let basek: usize = 12; let k_in: usize = 45; - let digits: usize = div_ceil(k_in, basek); + let digits: usize = k_in.div_ceil(basek); (1..4).for_each(|rank| { (1..digits + 1).for_each(|di| { let k_ggsw: usize = k_in + basek * di; diff --git a/core/src/glwe/test_fft64/keyswitch.rs b/core/src/glwe/test_fft64/keyswitch.rs index f80ae86..732ef03 100644 --- a/core/src/glwe/test_fft64/keyswitch.rs +++ b/core/src/glwe/test_fft64/keyswitch.rs @@ -2,7 +2,7 @@ use backend::{FFT64, FillUniform, Module, ScratchOwned, Stats, VecZnxOps}; use sampling::source::Source; use crate::{ - FourierGLWESecret, GLWECiphertext, GLWEPlaintext, GLWESecret, GLWESwitchingKey, Infos, div_ceil, + FourierGLWESecret, GLWECiphertext, GLWEPlaintext, GLWESecret, GLWESwitchingKey, Infos, noise::log2_std_noise_gglwe_product, }; @@ -11,7 +11,7 @@ fn apply() { let log_n: usize = 8; let basek: usize = 12; let k_in: usize = 45; - let digits: usize = div_ceil(k_in, basek); + let digits: usize = k_in.div_ceil(basek); (1..4).for_each(|rank_in| { (1..4).for_each(|rank_out| { (1..digits + 1).for_each(|di| { @@ -32,7 +32,7 @@ fn apply_inplace() { let log_n: usize = 8; let basek: usize = 12; let k_ct: usize = 45; - let digits: usize = div_ceil(k_ct, basek); + let digits: usize = k_ct.div_ceil(basek); (1..4).for_each(|rank| { (1..digits + 1).for_each(|di| { let k_ksk: usize = k_ct + basek * di; @@ -55,7 +55,7 @@ fn test_keyswitch( ) { let module: Module = Module::::new(1 << log_n); - let rows: usize = div_ceil(k_in, basek * digits); + let rows: usize = k_in.div_ceil(basek * digits); let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank_in, rank_out); @@ -148,7 +148,7 @@ fn test_keyswitch( fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ct: usize, k_ksk: usize, digits: usize, rank: usize, sigma: f64) { let module: Module = Module::::new(1 << log_n); - let rows: usize = div_ceil(k_ct, basek * digits); + let rows: usize = k_ct.div_ceil(basek * digits); let mut ct_grlwe: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank, rank); let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); diff --git a/core/src/glwe/test_fft64/packing.rs b/core/src/glwe/test_fft64/packing.rs index da817e6..0e9ee71 100644 --- a/core/src/glwe/test_fft64/packing.rs +++ b/core/src/glwe/test_fft64/packing.rs @@ -1,4 +1,4 @@ -use crate::{FourierGLWESecret, GLWEAutomorphismKey, GLWECiphertext, GLWEOps, GLWEPacker, GLWEPlaintext, GLWESecret, div_ceil}; +use crate::{FourierGLWESecret, GLWEAutomorphismKey, GLWECiphertext, GLWEOps, GLWEPacker, GLWEPlaintext, GLWESecret}; use std::collections::HashMap; use backend::{Encoding, FFT64, Module, ScratchOwned, Stats}; diff --git a/core/src/glwe/test_fft64/trace.rs b/core/src/glwe/test_fft64/trace.rs index eae2bce..24b0a75 100644 --- a/core/src/glwe/test_fft64/trace.rs +++ b/core/src/glwe/test_fft64/trace.rs @@ -4,7 +4,7 @@ use backend::{FFT64, FillUniform, Module, ScratchOwned, Stats, VecZnxOps, ZnxVie use sampling::source::Source; use crate::{ - FourierGLWESecret, GLWEAutomorphismKey, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, div_ceil, + FourierGLWESecret, GLWEAutomorphismKey, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, noise::var_noise_gglwe_product, }; diff --git a/core/src/lib.rs b/core/src/lib.rs index fa6d009..afdcc55 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -7,18 +7,17 @@ pub mod ggsw; pub mod glwe; pub mod lwe; pub mod noise; -mod utils; use backend::Backend; use backend::FFT64; use backend::Module; -pub use elem::*; +pub use elem::{GetRow, Infos, SetMetaData, SetRow}; pub use fourier_glwe::{FourierGLWECiphertext, FourierGLWESecret}; pub use gglwe::{GGLWECiphertext, GLWEAutomorphismKey, GLWESwitchingKey, GLWETensorKey}; -pub use ggsw::*; +pub use ggsw::GGSWCiphertext; pub use glwe::{GLWECiphertext, GLWEOps, GLWEPacker, GLWEPlaintext, GLWEPublicKey, GLWESecret}; pub(crate) use glwe::{GLWECiphertextToMut, GLWECiphertextToRef}; -pub use lwe::*; +pub use lwe::{LWECiphertext, LWESecret}; pub use backend::Scratch; pub use backend::ScratchOwned; @@ -174,7 +173,7 @@ impl ScratchCore for Scratch { k: usize, rank: usize, ) -> (FourierGLWECiphertext<&mut [u8], FFT64>, &mut Self) { - let (data, scratch) = self.tmp_vec_znx_dft(module, rank + 1, div_ceil(k, basek)); + let (data, scratch) = self.tmp_vec_znx_dft(module, rank + 1, k.div_ceil(basek)); (FourierGLWECiphertext { data, basek, k }, scratch) } diff --git a/core/src/lwe/ciphertext.rs b/core/src/lwe/ciphertext.rs new file mode 100644 index 0000000..432a866 --- /dev/null +++ b/core/src/lwe/ciphertext.rs @@ -0,0 +1,77 @@ +use backend::{VecZnx, VecZnxToMut, VecZnxToRef}; + +use crate::{Infos, SetMetaData}; + +pub struct LWECiphertext { + pub(crate) data: VecZnx, + pub(crate) k: usize, + pub(crate) basek: usize, +} + +impl LWECiphertext> { + pub fn alloc(n: usize, basek: usize, k: usize) -> Self { + Self { + data: VecZnx::new::(n, 1, k.div_ceil(basek)), + k: k, + basek: basek, + } + } +} + +impl Infos for LWECiphertext { + type Inner = VecZnx; + + fn n(&self) -> usize{ + &self.inner().n-1 + } + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn basek(&self) -> usize { + self.basek + } + + fn k(&self) -> usize { + self.k + } +} + +impl + AsRef<[u8]>> SetMetaData for LWECiphertext { + fn set_k(&mut self, k: usize) { + self.k = k + } + + fn set_basek(&mut self, basek: usize) { + self.basek = basek + } +} + +pub trait LWECiphertextToRef { + fn to_ref(&self) -> LWECiphertext<&[u8]>; +} + +impl> LWECiphertextToRef for LWECiphertext { + fn to_ref(&self) -> LWECiphertext<&[u8]> { + LWECiphertext { + data: self.data.to_ref(), + basek: self.basek, + k: self.k, + } + } +} + +pub trait LWECiphertextToMut { + fn to_mut(&mut self) -> LWECiphertext<&mut [u8]>; +} + +impl + AsRef<[u8]>> LWECiphertextToMut for LWECiphertext { + fn to_mut(&mut self) -> LWECiphertext<&mut [u8]> { + LWECiphertext { + data: self.data.to_mut(), + basek: self.basek, + k: self.k, + } + } +} diff --git a/core/src/lwe/decryption.rs b/core/src/lwe/decryption.rs new file mode 100644 index 0000000..5f7c201 --- /dev/null +++ b/core/src/lwe/decryption.rs @@ -0,0 +1,21 @@ +use backend::{alloc_aligned, ZnxView, ZnxViewMut}; + +use crate::{lwe::{LWEPlaintext}, Infos, LWECiphertext, LWESecret, SetMetaData}; + +impl LWECiphertext where DataSelf: AsRef<[u8]>{ + pub fn decrypt(&self, pt: &mut LWEPlaintext, sk: &LWESecret) where DataPt: AsRef<[u8]> + AsMut<[u8]>, DataSk: AsRef<[u8]>{ + #[cfg(debug_assertions)]{ + assert_eq!(self.n(), sk.n()); + } + + (0..pt.size().min(self.size())).for_each(|i|{ + pt.data.at_mut(0, i)[0] = self.data.at(0, i)[0] + self.data.at(0, i)[1..].iter().zip(sk.data.at(0, 0)).map(|(x, y)| x * y).sum::(); + }); + + let mut tmp_bytes: Vec = alloc_aligned(size_of::()); + pt.data.normalize(self.basek(), 0, &mut tmp_bytes); + + pt.set_basek(self.basek()); + pt.set_k(self.k().min(pt.size() * self.basek())); + } +} \ No newline at end of file diff --git a/core/src/lwe/encryption.rs b/core/src/lwe/encryption.rs new file mode 100644 index 0000000..5e993d9 --- /dev/null +++ b/core/src/lwe/encryption.rs @@ -0,0 +1,35 @@ +use backend::{alloc_aligned, AddNormal, FillUniform, VecZnx, ZnxView, ZnxViewMut}; +use sampling::source::Source; + +use crate::{lwe::LWEPlaintext, Infos, LWECiphertext, LWESecret, SIX_SIGMA}; + + + +impl LWECiphertext where DataSelf: AsMut<[u8]> + AsRef<[u8]>{ + pub fn encrypt_sk(&mut self, pt: &LWEPlaintext, sk: &LWESecret, source_xa: &mut Source, source_xe: &mut Source, sigma: f64) where DataPt: AsRef<[u8]>, DataSk: AsRef<[u8]>{ + + #[cfg(debug_assertions)]{ + assert_eq!(self.n(), sk.n()) + } + + let basek: usize = self.basek(); + + self.data.fill_uniform(basek, 0, self.size(), source_xa); + let mut tmp_znx: VecZnx> = VecZnx::>::new::(1, 1, self.size()); + + (0..self.size()).for_each(|i|{ + tmp_znx.at_mut(0, i)[0] = pt.data.at(0, i)[0] - self.data.at(0, i)[1..].iter().zip(sk.data.at(0, 0)).map(|(x, y)| x * y).sum::(); + }); + + tmp_znx.add_normal(basek, 0, self.k(), source_xe, sigma, sigma*SIX_SIGMA); + + let mut tmp_bytes: Vec = alloc_aligned(size_of::()); + + tmp_znx.normalize(basek, 0, &mut tmp_bytes); + + (0..self.size()).for_each(|i|{ + self.data.at_mut(0, i)[0] = tmp_znx.at(0, i)[0]; + }); + + } +} diff --git a/core/src/lwe/mod.rs b/core/src/lwe/mod.rs index d91ce80..19a8362 100644 --- a/core/src/lwe/mod.rs +++ b/core/src/lwe/mod.rs @@ -1,3 +1,9 @@ +pub mod ciphertext; pub mod secret; +pub mod encryption; +pub mod decryption; +pub mod plaintext; +pub use ciphertext::LWECiphertext; pub use secret::LWESecret; +pub use plaintext::LWEPlaintext; diff --git a/core/src/lwe/plaintext.rs b/core/src/lwe/plaintext.rs new file mode 100644 index 0000000..a1c0482 --- /dev/null +++ b/core/src/lwe/plaintext.rs @@ -0,0 +1,73 @@ +use backend::{VecZnx, VecZnxToMut, VecZnxToRef}; + +use crate::{Infos, SetMetaData}; + +pub struct LWEPlaintext{ + pub(crate) data: VecZnx, + pub(crate) k: usize, + pub(crate) basek: usize, +} + +impl LWEPlaintext> { + pub fn alloc(basek: usize, k: usize) -> Self { + Self { + data: VecZnx::new::(1, 1, k.div_ceil(basek)), + k: k, + basek: basek, + } + } +} + +impl Infos for LWEPlaintext { + type Inner = VecZnx; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn basek(&self) -> usize { + self.basek + } + + fn k(&self) -> usize { + self.k + } +} + +impl + AsRef<[u8]>> SetMetaData for LWEPlaintext { + fn set_k(&mut self, k: usize) { + self.k = k + } + + fn set_basek(&mut self, basek: usize) { + self.basek = basek + } +} + +pub trait LWEPlaintextToRef { + fn to_ref(&self) -> LWEPlaintext<&[u8]>; +} + +impl> LWEPlaintextToRef for LWEPlaintext { + fn to_ref(&self) -> LWEPlaintext<&[u8]> { + LWEPlaintext { + data: self.data.to_ref(), + basek: self.basek, + k: self.k, + } + } +} + +pub trait LWEPlaintextToMut { + fn to_mut(&mut self) -> LWEPlaintext<&mut [u8]>; +} + +impl + AsRef<[u8]>> LWEPlaintextToMut for LWEPlaintext { + fn to_mut(&mut self) -> LWEPlaintext<&mut [u8]> { + LWEPlaintext { + data: self.data.to_mut(), + basek: self.basek, + k: self.k, + } + } +} \ No newline at end of file From 6a006b442a03af203f26774e451a59751ed0ce8c Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Wed, 18 Jun 2025 22:23:32 +0200 Subject: [PATCH 09/23] working block binary BR --- core/src/blind_rotation/ccgi.rs | 26 ++++++----- core/src/blind_rotation/test_fft64/cggi.rs | 51 ++++++++++++++++------ core/src/lwe/ciphertext.rs | 2 +- 3 files changed, 53 insertions(+), 26 deletions(-) diff --git a/core/src/blind_rotation/ccgi.rs b/core/src/blind_rotation/ccgi.rs index 341045e..769ebb5 100644 --- a/core/src/blind_rotation/ccgi.rs +++ b/core/src/blind_rotation/ccgi.rs @@ -4,8 +4,7 @@ use backend::{MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnxDftOps, Scratch, use itertools::izip; use crate::{ - GGSWCiphertext, GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, GLWEPlaintext, Infos, LWECiphertext, - ScratchCore, blind_rotation::key::BlindRotationKeyCGGI, lwe::ciphertext::LWECiphertextToRef, + blind_rotation::key::BlindRotationKeyCGGI, lwe::ciphertext::LWECiphertextToRef, FourierGLWESecret, GGSWCiphertext, GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, GLWEPlaintext, GLWESecret, Infos, LWECiphertext, LWESecret, ScratchCore }; pub fn cggi_blind_rotate_scratch_space( @@ -34,8 +33,7 @@ pub fn cggi_blind_rotate( DataIn: AsRef<[u8]>, DataLUT: AsRef<[u8]>, { - - println!("{}", lwe.n()); + let basek = res.basek(); let mut lwe_2n: Vec = vec![0i64; lwe.n() + 1]; // TODO: from scratch space let mut out_mut: GLWECiphertext<&mut [u8]> = res.to_mut(); @@ -60,7 +58,7 @@ pub fn cggi_blind_rotate( let (mut acc_dft, scratch1) = scratch.tmp_glwe_fourier(module, brk.basek(), out_mut.k(), out_mut.rank()); let (mut acc_add_dft, scratch2) = scratch1.tmp_glwe_fourier(module, brk.basek(), out_mut.k(), out_mut.rank()); - let (mut vmp_res, scratch3) = scratch2.tmp_vec_znx_dft(module, acc_dft.rank()+1, acc_dft.size()); + let (mut vmp_res, scratch3) = scratch2.tmp_glwe_fourier(module, basek, out_mut.k(), out_mut.rank()); let (mut xai_minus_one, scratch4) = scratch3.tmp_scalar_znx(module, 1); let (mut xai_minus_one_dft, scratch5) = scratch4.tmp_scalar_znx_dft(module, 1); @@ -73,13 +71,13 @@ pub fn cggi_blind_rotate( out_mut.dft(module, &mut acc_dft); acc_add_dft.data.zero(); - + izip!(ai.iter(), ski.iter()) .enumerate() .for_each(|(i, (aii, skii))| { // vmp_res = DFT(acc) * BRK[i] - module.vmp_apply(&mut vmp_res, &acc_dft.data, &skii.data, scratch5); + module.vmp_apply(&mut vmp_res.data, &acc_dft.data, &skii.data, scratch5); // DFT(X^ai -1) xai_minus_one.zero(); @@ -90,19 +88,23 @@ pub fn cggi_blind_rotate( // DFT(X^ai -1) * (DFT(acc) * BRK[i]) (0..cols).for_each(|i|{ - module.svp_apply_inplace(&mut vmp_res, i, &xai_minus_one_dft, 0); - module.vec_znx_dft_add_inplace(&mut acc_add_dft.data, i, &vmp_res, i); + module.svp_apply_inplace(&mut vmp_res.data, i, &xai_minus_one_dft, 0); + module.vec_znx_dft_add_inplace(&mut acc_add_dft.data, i, &vmp_res.data, i); }); - }); - acc_add_dft.idft(module, &mut out_mut, scratch5); + (0..cols).for_each(|i|{ + module.vec_znx_dft_add_inplace(&mut acc_dft.data, i, &acc_add_dft.data, i); + }); + + acc_dft.idft(module, &mut out_mut, scratch5); + }); let duration: std::time::Duration = start.elapsed(); println!("external products: {} us", duration.as_micros()); } -fn mod_switch_2n(module: &Module, res: &mut [i64], lwe: &LWECiphertext<&[u8]>) { +pub(crate) fn mod_switch_2n(module: &Module, res: &mut [i64], lwe: &LWECiphertext<&[u8]>) { let basek: usize = lwe.basek(); let log2n: usize = module.log_n() + 1; diff --git a/core/src/blind_rotation/test_fft64/cggi.rs b/core/src/blind_rotation/test_fft64/cggi.rs index 8a9246f..746ec1e 100644 --- a/core/src/blind_rotation/test_fft64/cggi.rs +++ b/core/src/blind_rotation/test_fft64/cggi.rs @@ -1,24 +1,23 @@ -use core::time; use std::time::Instant; -use backend::{Encoding, Module, ScratchOwned, FFT64}; +use backend::{Encoding, Module, ScratchOwned, Stats, ZnxView, ZnxViewMut, FFT64}; use sampling::source::Source; use crate::{ - blind_rotation::{ccgi::{cggi_blind_rotate, cggi_blind_rotate_scratch_space}, key::BlindRotationKeyCGGI}, lwe::LWEPlaintext, FourierGLWESecret, GLWECiphertext, GLWEPlaintext, GLWESecret, LWECiphertext, LWESecret + blind_rotation::{ccgi::{cggi_blind_rotate, cggi_blind_rotate_scratch_space, mod_switch_2n}, key::BlindRotationKeyCGGI}, lwe::{ciphertext::{LWECiphertextToMut, LWECiphertextToRef}, LWEPlaintext}, FourierGLWESecret, GLWECiphertext, GLWEOps, GLWEPlaintext, GLWESecret, Infos, LWECiphertext, LWESecret }; #[test] fn blind_rotation() { let module: Module = Module::::new(2048); - let basek: usize = 17; + let basek: usize = 20; let n_lwe: usize = 1071; let k_lwe: usize = 22; - let k_brk: usize = 54; - let rows_brk: usize = 1; - let k_lut: usize = 44; + let k_brk: usize = 60; + let rows_brk: usize = 2; + let k_lut: usize = 60; let rank: usize = 1; let block_size: usize = 7; @@ -55,7 +54,10 @@ fn blind_rotation() { let mut pt_lwe: LWEPlaintext> = LWEPlaintext::alloc(basek, k_lwe); - pt_lwe.data.encode_coeff_i64(0, basek, 7, 0, 63, 7); + let x: i64 = 0; + let bits: usize = 6; + + pt_lwe.data.encode_coeff_i64(0, basek, bits, 0, x, bits); println!("{}", pt_lwe.data); @@ -65,21 +67,44 @@ fn blind_rotation() { println!("{}", pt_lwe.data); - let lut: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_lut); + let mut lut: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_lut); + + + lut.data.at_mut(0, 0)[0] = 0; + (1..module.n()).for_each(|i|{ + lut.data.at_mut(0, 0)[i] = - ((module.n() as i64 - i as i64 - 1)<<(basek - module.log_n() - 1)); + }); + + + let mut res: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_lut, rank); let start: Instant = Instant::now(); - (0..32).for_each(|i|{ + (0..1).for_each(|i|{ cggi_blind_rotate(&module, &mut res, &lwe, &lut, &brk, scratch.borrow()); }); let duration: std::time::Duration = start.elapsed(); println!("blind-rotate: {} ms", duration.as_millis()); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_lut); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_lut); - res.decrypt(&module , &mut pt, &sk_glwe_dft, scratch.borrow()); + res.decrypt(&module , &mut pt_have, &sk_glwe_dft, scratch.borrow()); - println!("{}", pt.data); + let mut lwe_2n: Vec = vec![0i64; lwe.n() + 1]; // TODO: from scratch space + + mod_switch_2n(&module, &mut lwe_2n, &lwe.to_ref()); + + let pt_want: i64 = (lwe_2n[0] + lwe_2n[1..].iter().zip(sk_lwe.data.at(0, 0)).map(|(x, y)| x * y).sum::()) % (module.n() as i64 * 2); + + lut.rotate_inplace(&module, pt_want); + + lut.sub_inplace_ab(&module, &pt_have); + + let noise: f64 = lut.data.std(0, basek); + + println!("noise: {}", noise); + + } diff --git a/core/src/lwe/ciphertext.rs b/core/src/lwe/ciphertext.rs index 432a866..4b8c85d 100644 --- a/core/src/lwe/ciphertext.rs +++ b/core/src/lwe/ciphertext.rs @@ -11,7 +11,7 @@ pub struct LWECiphertext { impl LWECiphertext> { pub fn alloc(n: usize, basek: usize, k: usize) -> Self { Self { - data: VecZnx::new::(n, 1, k.div_ceil(basek)), + data: VecZnx::new::(n+1, 1, k.div_ceil(basek)), k: k, basek: basek, } From 4c1a84d7024de2745698122f700aecf1595183e9 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Thu, 19 Jun 2025 16:33:47 +0200 Subject: [PATCH 10/23] Added support for arbitrary extended LUT --- backend/src/lib.rs | 2 +- backend/src/scalar_znx.rs | 2 +- backend/src/vec_znx.rs | 5 +- core/src/blind_rotation/ccgi.rs | 58 +++++++------- core/src/blind_rotation/key.rs | 3 + core/src/blind_rotation/lut.rs | 84 ++++++++++++++++++++ core/src/blind_rotation/mod.rs | 1 + core/src/blind_rotation/test_fft64/cggi.rs | 61 ++++++++------ core/src/glwe/test_fft64/automorphism.rs | 3 +- core/src/glwe/test_fft64/external_product.rs | 4 +- core/src/glwe/test_fft64/keyswitch.rs | 3 +- core/src/glwe/test_fft64/trace.rs | 3 +- core/src/lwe/ciphertext.rs | 6 +- core/src/lwe/decryption.rs | 29 +++++-- core/src/lwe/encryption.rs | 41 +++++++--- core/src/lwe/mod.rs | 6 +- core/src/lwe/plaintext.rs | 4 +- 17 files changed, 219 insertions(+), 96 deletions(-) create mode 100644 core/src/blind_rotation/lut.rs diff --git a/backend/src/lib.rs b/backend/src/lib.rs index 09e5556..d55ba8a 100644 --- a/backend/src/lib.rs +++ b/backend/src/lib.rs @@ -103,7 +103,7 @@ fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec { /// Size of T * size msut be a multiple of [DEFAULTALIGN]. pub fn alloc_aligned_custom(size: usize, align: usize) -> Vec { assert_eq!( - (size * size_of::()) % (align/ size_of::()), + (size * size_of::()) % (align / size_of::()), 0, "size={} must be a multiple of align={}", size, diff --git a/backend/src/scalar_znx.rs b/backend/src/scalar_znx.rs index 0a5bb64..2cc1797 100644 --- a/backend/src/scalar_znx.rs +++ b/backend/src/scalar_znx.rs @@ -93,7 +93,7 @@ impl + AsRef<[u8]>> ScalarZnx { pub fn fill_binary_block(&mut self, col: usize, block_size: usize, source: &mut Source) { assert!(self.n() % block_size == 0); let max_idx: u64 = (block_size + 1) as u64; - let mask_idx: u64 = (1<<((u64::BITS - max_idx.leading_zeros())as u64)) - 1 ; + let mask_idx: u64 = (1 << ((u64::BITS - max_idx.leading_zeros()) as u64)) - 1; for block in self.at_mut(col, 0).chunks_mut(block_size) { let idx: usize = source.next_u64n(max_idx, mask_idx) as usize; if idx != block_size { diff --git a/backend/src/vec_znx.rs b/backend/src/vec_znx.rs index 8213d5e..6189bad 100644 --- a/backend/src/vec_znx.rs +++ b/backend/src/vec_znx.rs @@ -243,13 +243,12 @@ fn normalize_tmp_bytes(n: usize) -> usize { n * std::mem::size_of::() } -impl + AsMut<[u8]>> VecZnx{ - pub fn normalize(&mut self, basek: usize, a_col: usize, tmp_bytes: &mut [u8]){ +impl + AsMut<[u8]>> VecZnx { + pub fn normalize(&mut self, basek: usize, a_col: usize, tmp_bytes: &mut [u8]) { normalize(basek, self, a_col, tmp_bytes); } } - fn normalize + AsRef<[u8]>>(basek: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]) { let n: usize = a.n(); diff --git a/core/src/blind_rotation/ccgi.rs b/core/src/blind_rotation/ccgi.rs index 769ebb5..f560d09 100644 --- a/core/src/blind_rotation/ccgi.rs +++ b/core/src/blind_rotation/ccgi.rs @@ -1,10 +1,15 @@ use std::time::Instant; -use backend::{MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnxDftOps, Scratch, VecZnxDftOps, VecZnxOps, ZnxView, ZnxViewMut, ZnxZero, FFT64}; +use backend::{ + FFT64, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnxDftOps, Scratch, VecZnxDftOps, VecZnxOps, ZnxView, ZnxViewMut, + ZnxZero, +}; use itertools::izip; use crate::{ - blind_rotation::key::BlindRotationKeyCGGI, lwe::ciphertext::LWECiphertextToRef, FourierGLWESecret, GGSWCiphertext, GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, GLWEPlaintext, GLWESecret, Infos, LWECiphertext, LWESecret, ScratchCore + GGSWCiphertext, GLWECiphertext, GLWECiphertextToMut, Infos, LWECiphertext, ScratchCore, + blind_rotation::{key::BlindRotationKeyCGGI, lut::LookUpTable}, + lwe::ciphertext::LWECiphertextToRef, }; pub fn cggi_blind_rotate_scratch_space( @@ -21,26 +26,24 @@ pub fn cggi_blind_rotate_scratch_space( | GLWECiphertext::external_product_inplace_scratch_space(module, basek, k_lut, k_brk, 1, rank)) } -pub fn cggi_blind_rotate( +pub fn cggi_blind_rotate( module: &Module, res: &mut GLWECiphertext, lwe: &LWECiphertext, - lut: &GLWEPlaintext, + lut: &LookUpTable, brk: &BlindRotationKeyCGGI, scratch: &mut Scratch, ) where DataRes: AsRef<[u8]> + AsMut<[u8]>, DataIn: AsRef<[u8]>, - DataLUT: AsRef<[u8]>, { let basek = res.basek(); let mut lwe_2n: Vec = vec![0i64; lwe.n() + 1]; // TODO: from scratch space let mut out_mut: GLWECiphertext<&mut [u8]> = res.to_mut(); let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref(); - let lut_ref: GLWECiphertext<&[u8]> = lut.to_ref(); - let cols = out_mut.rank()+1; + let cols: usize = out_mut.rank() + 1; mod_switch_2n(module, &mut lwe_2n, &lwe_ref); @@ -50,7 +53,7 @@ pub fn cggi_blind_rotate( out_mut.data.zero(); // Initialize out to X^{b} * LUT(X) - module.vec_znx_rotate(b, &mut out_mut.data, 0, &lut_ref.data, 0); + module.vec_znx_rotate(b, &mut out_mut.data, 0, &lut.data[0], 0); let block_size: usize = brk.block_size(); @@ -68,37 +71,32 @@ pub fn cggi_blind_rotate( brk.data.chunks_exact(block_size) ) .for_each(|(ai, ski)| { - out_mut.dft(module, &mut acc_dft); acc_add_dft.data.zero(); - izip!(ai.iter(), ski.iter()) - .enumerate() - .for_each(|(i, (aii, skii))| { + izip!(ai.iter(), ski.iter()).for_each(|(aii, skii)| { + // vmp_res = DFT(acc) * BRK[i] + module.vmp_apply(&mut vmp_res.data, &acc_dft.data, &skii.data, scratch5); - // vmp_res = DFT(acc) * BRK[i] - module.vmp_apply(&mut vmp_res.data, &acc_dft.data, &skii.data, scratch5); + // DFT(X^ai -1) + xai_minus_one.zero(); + xai_minus_one.at_mut(0, 0)[0] = 1; + module.vec_znx_rotate_inplace(*aii, &mut xai_minus_one, 0); + xai_minus_one.at_mut(0, 0)[0] -= 1; + module.svp_prepare(&mut xai_minus_one_dft, 0, &xai_minus_one, 0); - // DFT(X^ai -1) - xai_minus_one.zero(); - xai_minus_one.at_mut(0, 0)[0] = 1; - module.vec_znx_rotate_inplace(*aii, &mut xai_minus_one, 0); - xai_minus_one.at_mut(0, 0)[0] -= 1; - module.svp_prepare(&mut xai_minus_one_dft, 0, &xai_minus_one, 0); - - // DFT(X^ai -1) * (DFT(acc) * BRK[i]) - (0..cols).for_each(|i|{ - module.svp_apply_inplace(&mut vmp_res.data, i, &xai_minus_one_dft, 0); - module.vec_znx_dft_add_inplace(&mut acc_add_dft.data, i, &vmp_res.data, i); - }); + // DFT(X^ai -1) * (DFT(acc) * BRK[i]) + (0..cols).for_each(|i| { + module.svp_apply_inplace(&mut vmp_res.data, i, &xai_minus_one_dft, 0); + module.vec_znx_dft_add_inplace(&mut acc_add_dft.data, i, &vmp_res.data, i); }); - - (0..cols).for_each(|i|{ + }); + + (0..cols).for_each(|i| { module.vec_znx_dft_add_inplace(&mut acc_dft.data, i, &acc_add_dft.data, i); }); - - acc_dft.idft(module, &mut out_mut, scratch5); + acc_dft.idft(module, &mut out_mut, scratch5); }); let duration: std::time::Duration = start.elapsed(); println!("external products: {} us", duration.as_micros()); diff --git a/core/src/blind_rotation/key.rs b/core/src/blind_rotation/key.rs index 9f23c61..8cb24cb 100644 --- a/core/src/blind_rotation/key.rs +++ b/core/src/blind_rotation/key.rs @@ -69,14 +69,17 @@ impl BlindRotationKeyCGGI { } } + #[allow(dead_code)] pub(crate) fn rows(&self) -> usize { self.data[0].rows() } + #[allow(dead_code)] pub(crate) fn k(&self) -> usize { self.data[0].k() } + #[allow(dead_code)] pub(crate) fn rank(&self) -> usize { self.data[0].rank() } diff --git a/core/src/blind_rotation/lut.rs b/core/src/blind_rotation/lut.rs new file mode 100644 index 0000000..56b65a0 --- /dev/null +++ b/core/src/blind_rotation/lut.rs @@ -0,0 +1,84 @@ +use backend::{FFT64, Module, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxOps, ZnxInfos, ZnxViewMut, alloc_aligned}; + +pub struct LookUpTable { + pub(crate) data: Vec>>, + pub(crate) basek: usize, + pub(crate) k: usize, +} + +impl LookUpTable { + pub fn alloc(module: &Module, basek: usize, k: usize, extend_factor: usize) -> Self { + let size: usize = k.div_ceil(basek); + let mut data: Vec>> = Vec::with_capacity(extend_factor); + (0..extend_factor).for_each(|_| { + data.push(module.new_vec_znx(1, size)); + }); + Self { data, basek, k } + } + + pub fn set(&mut self, module: &Module, f: fn(i64) -> i64, message_modulus: usize) { + let basek: usize = self.basek; + + // Get the number minimum limb to store the message modulus + let limbs: usize = message_modulus.div_ceil(1 << basek); + + // Scaling factor + let scale: i64 = (1 << (basek * limbs - 1)).div_round(message_modulus) as i64; + + // Updates function + let f_scaled = |x: i64| (f(x) % message_modulus as i64) * scale; + + // If LUT size > module.n() + let domain_size: usize = self.data[0].n() * self.data.len(); + + let size: usize = self.k.div_ceil(self.basek); + + // Equivalent to AUTO([f(0), f(1), ..., f(n-1)], -1) + let mut lut_full: VecZnx> = VecZnx::new::(domain_size, 1, size); + { + let lut_at: &mut [i64] = lut_full.at_mut(0, limbs - 1); + + let start: usize = 0; + let end: usize = (domain_size).div_round(message_modulus); + + let y: i64 = f_scaled(0); + (start..end).for_each(|i| { + lut_at[i] = y; + }); + + (1..message_modulus).for_each(|x| { + let start: usize = (x * domain_size).div_round(message_modulus); + let end: usize = ((x + 1) * domain_size).div_round(message_modulus); + let y: i64 = f_scaled(x as i64); + (start..end).for_each(|i| { + lut_at[domain_size - i] = -y; + }) + }); + } + + // Rotates half the step to the left + let half_step: usize = domain_size.div_round(message_modulus << 1); + module.vec_znx_rotate_inplace(-(half_step as i64), &mut lut_full, 0); + + let mut tmp_bytes: Vec = alloc_aligned(lut_full.n() * size_of::()); + lut_full.normalize(self.basek, 0, &mut tmp_bytes); + + if self.data.len() > 1 { + let mut scratch: ScratchOwned = ScratchOwned::new(module.bytes_of_vec_znx(1, size)); + module.vec_znx_split(&mut self.data, 0, &lut_full, 0, scratch.borrow()); + } else { + module.vec_znx_copy(&mut self.data[0], 0, &lut_full, 0); + } + } +} + +pub trait DivRound { + fn div_round(self, rhs: Self) -> Self; +} + +impl DivRound for usize { + #[inline] + fn div_round(self, rhs: Self) -> Self { + (self + rhs / 2) / rhs + } +} diff --git a/core/src/blind_rotation/mod.rs b/core/src/blind_rotation/mod.rs index 63fb3fd..1b3d5ff 100644 --- a/core/src/blind_rotation/mod.rs +++ b/core/src/blind_rotation/mod.rs @@ -1,6 +1,7 @@ // pub mod cggi; pub mod ccgi; pub mod key; +pub mod lut; #[cfg(test)] pub mod test_fft64; diff --git a/core/src/blind_rotation/test_fft64/cggi.rs b/core/src/blind_rotation/test_fft64/cggi.rs index 746ec1e..bb98cea 100644 --- a/core/src/blind_rotation/test_fft64/cggi.rs +++ b/core/src/blind_rotation/test_fft64/cggi.rs @@ -1,10 +1,16 @@ use std::time::Instant; -use backend::{Encoding, Module, ScratchOwned, Stats, ZnxView, ZnxViewMut, FFT64}; +use backend::{Encoding, FFT64, Module, ScratchOwned, Stats, VecZnxOps, ZnxView}; use sampling::source::Source; use crate::{ - blind_rotation::{ccgi::{cggi_blind_rotate, cggi_blind_rotate_scratch_space, mod_switch_2n}, key::BlindRotationKeyCGGI}, lwe::{ciphertext::{LWECiphertextToMut, LWECiphertextToRef}, LWEPlaintext}, FourierGLWESecret, GLWECiphertext, GLWEOps, GLWEPlaintext, GLWESecret, Infos, LWECiphertext, LWESecret + FourierGLWESecret, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, LWECiphertext, LWESecret, + blind_rotation::{ + ccgi::{cggi_blind_rotate, cggi_blind_rotate_scratch_space, mod_switch_2n}, + key::BlindRotationKeyCGGI, + lut::LookUpTable, + }, + lwe::{LWEPlaintext, ciphertext::LWECiphertextToRef}, }; #[test] @@ -21,6 +27,8 @@ fn blind_rotation() { let rank: usize = 1; let block_size: usize = 7; + let message_modulus: usize = 64; + let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); @@ -32,12 +40,14 @@ fn blind_rotation() { let mut sk_lwe: LWESecret> = LWESecret::alloc(n_lwe); sk_lwe.fill_binary_block(block_size, &mut source_xs); - let mut scratch: ScratchOwned = ScratchOwned::new(BlindRotationKeyCGGI::generate_from_sk_scratch_space( - &module, basek, k_brk, rank, - ) | cggi_blind_rotate_scratch_space(&module, basek, k_lut, k_brk, rows_brk, rank)); + let mut scratch: ScratchOwned = ScratchOwned::new( + BlindRotationKeyCGGI::generate_from_sk_scratch_space(&module, basek, k_brk, rank) + | cggi_blind_rotate_scratch_space(&module, basek, k_lut, k_brk, rows_brk, rank), + ); let start: Instant = Instant::now(); let mut brk: BlindRotationKeyCGGI = BlindRotationKeyCGGI::allocate(&module, n_lwe, basek, k_brk, rows_brk, rank); + brk.generate_from_sk( &module, &sk_glwe_dft, @@ -47,6 +57,7 @@ fn blind_rotation() { 3.2, scratch.borrow(), ); + let duration: std::time::Duration = start.elapsed(); println!("brk-gen: {} ms", duration.as_millis()); @@ -67,44 +78,48 @@ fn blind_rotation() { println!("{}", pt_lwe.data); - let mut lut: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_lut); - - - lut.data.at_mut(0, 0)[0] = 0; - (1..module.n()).for_each(|i|{ - lut.data.at_mut(0, 0)[i] = - ((module.n() as i64 - i as i64 - 1)<<(basek - module.log_n() - 1)); - }); - - + fn lut_fn(x: i64) -> i64 { + 2 * x + 1 + } + let mut lut: LookUpTable = LookUpTable::alloc(&module, basek, k_lut, 1); + lut.set(&module, lut_fn, message_modulus); let mut res: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_lut, rank); let start: Instant = Instant::now(); - (0..1).for_each(|i|{ + (0..1).for_each(|_| { cggi_blind_rotate(&module, &mut res, &lwe, &lut, &brk, scratch.borrow()); }); - + let duration: std::time::Duration = start.elapsed(); println!("blind-rotate: {} ms", duration.as_millis()); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_lut); - res.decrypt(&module , &mut pt_have, &sk_glwe_dft, scratch.borrow()); + res.decrypt(&module, &mut pt_have, &sk_glwe_dft, scratch.borrow()); + + println!("pt_have: {}", pt_have.data); let mut lwe_2n: Vec = vec![0i64; lwe.n() + 1]; // TODO: from scratch space mod_switch_2n(&module, &mut lwe_2n, &lwe.to_ref()); - let pt_want: i64 = (lwe_2n[0] + lwe_2n[1..].iter().zip(sk_lwe.data.at(0, 0)).map(|(x, y)| x * y).sum::()) % (module.n() as i64 * 2); + let pt_want: i64 = (lwe_2n[0] + + lwe_2n[1..] + .iter() + .zip(sk_lwe.data.at(0, 0)) + .map(|(x, y)| x * y) + .sum::()) + % (module.n() as i64 * 2); - lut.rotate_inplace(&module, pt_want); + module.vec_znx_rotate_inplace(pt_want, &mut lut.data[0], 0); - lut.sub_inplace_ab(&module, &pt_have); + println!("pt_want: {}", lut.data[0]); - let noise: f64 = lut.data.std(0, basek); + module.vec_znx_sub_ab_inplace(&mut lut.data[0], 0, &pt_have.data, 0); + + let noise: f64 = lut.data[0].std(0, basek); println!("noise: {}", noise); - - } diff --git a/core/src/glwe/test_fft64/automorphism.rs b/core/src/glwe/test_fft64/automorphism.rs index 1b46e63..ba739c6 100644 --- a/core/src/glwe/test_fft64/automorphism.rs +++ b/core/src/glwe/test_fft64/automorphism.rs @@ -3,8 +3,7 @@ use backend::{FFT64, FillUniform, Module, ScratchOwned, Stats, VecZnxOps}; use sampling::source::Source; use crate::{ - FourierGLWESecret, GLWEAutomorphismKey, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, - noise::log2_std_noise_gglwe_product, + FourierGLWESecret, GLWEAutomorphismKey, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, noise::log2_std_noise_gglwe_product, }; #[test] diff --git a/core/src/glwe/test_fft64/external_product.rs b/core/src/glwe/test_fft64/external_product.rs index aec23b6..e1f6b19 100644 --- a/core/src/glwe/test_fft64/external_product.rs +++ b/core/src/glwe/test_fft64/external_product.rs @@ -1,9 +1,7 @@ use backend::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, ZnxViewMut}; use sampling::source::Source; -use crate::{ - FourierGLWESecret, GGSWCiphertext, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, noise::noise_ggsw_product, -}; +use crate::{FourierGLWESecret, GGSWCiphertext, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, noise::noise_ggsw_product}; #[test] fn apply() { diff --git a/core/src/glwe/test_fft64/keyswitch.rs b/core/src/glwe/test_fft64/keyswitch.rs index 732ef03..fb54204 100644 --- a/core/src/glwe/test_fft64/keyswitch.rs +++ b/core/src/glwe/test_fft64/keyswitch.rs @@ -2,8 +2,7 @@ use backend::{FFT64, FillUniform, Module, ScratchOwned, Stats, VecZnxOps}; use sampling::source::Source; use crate::{ - FourierGLWESecret, GLWECiphertext, GLWEPlaintext, GLWESecret, GLWESwitchingKey, Infos, - noise::log2_std_noise_gglwe_product, + FourierGLWESecret, GLWECiphertext, GLWEPlaintext, GLWESecret, GLWESwitchingKey, Infos, noise::log2_std_noise_gglwe_product, }; #[test] diff --git a/core/src/glwe/test_fft64/trace.rs b/core/src/glwe/test_fft64/trace.rs index 24b0a75..fe4e1eb 100644 --- a/core/src/glwe/test_fft64/trace.rs +++ b/core/src/glwe/test_fft64/trace.rs @@ -4,8 +4,7 @@ use backend::{FFT64, FillUniform, Module, ScratchOwned, Stats, VecZnxOps, ZnxVie use sampling::source::Source; use crate::{ - FourierGLWESecret, GLWEAutomorphismKey, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, - noise::var_noise_gglwe_product, + FourierGLWESecret, GLWEAutomorphismKey, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, noise::var_noise_gglwe_product, }; #[test] diff --git a/core/src/lwe/ciphertext.rs b/core/src/lwe/ciphertext.rs index 4b8c85d..1e97eb4 100644 --- a/core/src/lwe/ciphertext.rs +++ b/core/src/lwe/ciphertext.rs @@ -11,7 +11,7 @@ pub struct LWECiphertext { impl LWECiphertext> { pub fn alloc(n: usize, basek: usize, k: usize) -> Self { Self { - data: VecZnx::new::(n+1, 1, k.div_ceil(basek)), + data: VecZnx::new::(n + 1, 1, k.div_ceil(basek)), k: k, basek: basek, } @@ -21,8 +21,8 @@ impl LWECiphertext> { impl Infos for LWECiphertext { type Inner = VecZnx; - fn n(&self) -> usize{ - &self.inner().n-1 + fn n(&self) -> usize { + &self.inner().n - 1 } fn inner(&self) -> &Self::Inner { diff --git a/core/src/lwe/decryption.rs b/core/src/lwe/decryption.rs index 5f7c201..3ed9d2b 100644 --- a/core/src/lwe/decryption.rs +++ b/core/src/lwe/decryption.rs @@ -1,15 +1,28 @@ -use backend::{alloc_aligned, ZnxView, ZnxViewMut}; +use backend::{ZnxView, ZnxViewMut, alloc_aligned}; -use crate::{lwe::{LWEPlaintext}, Infos, LWECiphertext, LWESecret, SetMetaData}; +use crate::{Infos, LWECiphertext, LWESecret, SetMetaData, lwe::LWEPlaintext}; -impl LWECiphertext where DataSelf: AsRef<[u8]>{ - pub fn decrypt(&self, pt: &mut LWEPlaintext, sk: &LWESecret) where DataPt: AsRef<[u8]> + AsMut<[u8]>, DataSk: AsRef<[u8]>{ - #[cfg(debug_assertions)]{ +impl LWECiphertext +where + DataSelf: AsRef<[u8]>, +{ + pub fn decrypt(&self, pt: &mut LWEPlaintext, sk: &LWESecret) + where + DataPt: AsRef<[u8]> + AsMut<[u8]>, + DataSk: AsRef<[u8]>, + { + #[cfg(debug_assertions)] + { assert_eq!(self.n(), sk.n()); } - (0..pt.size().min(self.size())).for_each(|i|{ - pt.data.at_mut(0, i)[0] = self.data.at(0, i)[0] + self.data.at(0, i)[1..].iter().zip(sk.data.at(0, 0)).map(|(x, y)| x * y).sum::(); + (0..pt.size().min(self.size())).for_each(|i| { + pt.data.at_mut(0, i)[0] = self.data.at(0, i)[0] + + self.data.at(0, i)[1..] + .iter() + .zip(sk.data.at(0, 0)) + .map(|(x, y)| x * y) + .sum::(); }); let mut tmp_bytes: Vec = alloc_aligned(size_of::()); @@ -18,4 +31,4 @@ impl LWECiphertext where DataSelf: AsRef<[u8]>{ pt.set_basek(self.basek()); pt.set_k(self.k().min(pt.size() * self.basek())); } -} \ No newline at end of file +} diff --git a/core/src/lwe/encryption.rs b/core/src/lwe/encryption.rs index 5e993d9..148e5c4 100644 --- a/core/src/lwe/encryption.rs +++ b/core/src/lwe/encryption.rs @@ -1,14 +1,25 @@ -use backend::{alloc_aligned, AddNormal, FillUniform, VecZnx, ZnxView, ZnxViewMut}; +use backend::{AddNormal, FillUniform, VecZnx, ZnxView, ZnxViewMut, alloc_aligned}; use sampling::source::Source; -use crate::{lwe::LWEPlaintext, Infos, LWECiphertext, LWESecret, SIX_SIGMA}; +use crate::{Infos, LWECiphertext, LWESecret, SIX_SIGMA, lwe::LWEPlaintext}; - - -impl LWECiphertext where DataSelf: AsMut<[u8]> + AsRef<[u8]>{ - pub fn encrypt_sk(&mut self, pt: &LWEPlaintext, sk: &LWESecret, source_xa: &mut Source, source_xe: &mut Source, sigma: f64) where DataPt: AsRef<[u8]>, DataSk: AsRef<[u8]>{ - - #[cfg(debug_assertions)]{ +impl LWECiphertext +where + DataSelf: AsMut<[u8]> + AsRef<[u8]>, +{ + pub fn encrypt_sk( + &mut self, + pt: &LWEPlaintext, + sk: &LWESecret, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + ) where + DataPt: AsRef<[u8]>, + DataSk: AsRef<[u8]>, + { + #[cfg(debug_assertions)] + { assert_eq!(self.n(), sk.n()) } @@ -17,19 +28,23 @@ impl LWECiphertext where DataSelf: AsMut<[u8]> + AsRef<[u8]> self.data.fill_uniform(basek, 0, self.size(), source_xa); let mut tmp_znx: VecZnx> = VecZnx::>::new::(1, 1, self.size()); - (0..self.size()).for_each(|i|{ - tmp_znx.at_mut(0, i)[0] = pt.data.at(0, i)[0] - self.data.at(0, i)[1..].iter().zip(sk.data.at(0, 0)).map(|(x, y)| x * y).sum::(); + (0..self.size()).for_each(|i| { + tmp_znx.at_mut(0, i)[0] = pt.data.at(0, i)[0] + - self.data.at(0, i)[1..] + .iter() + .zip(sk.data.at(0, 0)) + .map(|(x, y)| x * y) + .sum::(); }); - tmp_znx.add_normal(basek, 0, self.k(), source_xe, sigma, sigma*SIX_SIGMA); + tmp_znx.add_normal(basek, 0, self.k(), source_xe, sigma, sigma * SIX_SIGMA); let mut tmp_bytes: Vec = alloc_aligned(size_of::()); tmp_znx.normalize(basek, 0, &mut tmp_bytes); - (0..self.size()).for_each(|i|{ + (0..self.size()).for_each(|i| { self.data.at_mut(0, i)[0] = tmp_znx.at(0, i)[0]; }); - } } diff --git a/core/src/lwe/mod.rs b/core/src/lwe/mod.rs index 19a8362..b7bb7ed 100644 --- a/core/src/lwe/mod.rs +++ b/core/src/lwe/mod.rs @@ -1,9 +1,9 @@ pub mod ciphertext; -pub mod secret; -pub mod encryption; pub mod decryption; +pub mod encryption; pub mod plaintext; +pub mod secret; pub use ciphertext::LWECiphertext; -pub use secret::LWESecret; pub use plaintext::LWEPlaintext; +pub use secret::LWESecret; diff --git a/core/src/lwe/plaintext.rs b/core/src/lwe/plaintext.rs index a1c0482..7c73351 100644 --- a/core/src/lwe/plaintext.rs +++ b/core/src/lwe/plaintext.rs @@ -2,7 +2,7 @@ use backend::{VecZnx, VecZnxToMut, VecZnxToRef}; use crate::{Infos, SetMetaData}; -pub struct LWEPlaintext{ +pub struct LWEPlaintext { pub(crate) data: VecZnx, pub(crate) k: usize, pub(crate) basek: usize, @@ -70,4 +70,4 @@ impl + AsRef<[u8]>> LWEPlaintextToMut for LWEPlaintext { k: self.k, } } -} \ No newline at end of file +} From 52154d6f8a85ffbf6ad08b7b327df7edff9d7d53 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Fri, 20 Jun 2025 11:00:04 +0200 Subject: [PATCH 11/23] wip CGGI BR for extended LUT --- core/src/blind_rotation/ccgi.rs | 239 ++++++++++++++++++++- core/src/blind_rotation/lut.rs | 8 +- core/src/blind_rotation/test_fft64/cggi.rs | 2 +- 3 files changed, 239 insertions(+), 10 deletions(-) diff --git a/core/src/blind_rotation/ccgi.rs b/core/src/blind_rotation/ccgi.rs index f560d09..38c8bf2 100644 --- a/core/src/blind_rotation/ccgi.rs +++ b/core/src/blind_rotation/ccgi.rs @@ -1,13 +1,13 @@ use std::time::Instant; use backend::{ - FFT64, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnxDftOps, Scratch, VecZnxDftOps, VecZnxOps, ZnxView, ZnxViewMut, - ZnxZero, + FFT64, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnxAlloc, ScalarZnxDftAlloc, ScalarZnxDftOps, Scratch, VecZnxDftOps, + VecZnxOps, ZnxView, ZnxViewMut, ZnxZero, }; use itertools::izip; use crate::{ - GGSWCiphertext, GLWECiphertext, GLWECiphertextToMut, Infos, LWECiphertext, ScratchCore, + FourierGLWECiphertext, GGSWCiphertext, GLWECiphertext, GLWECiphertextToMut, Infos, LWECiphertext, ScratchCore, blind_rotation::{key::BlindRotationKeyCGGI, lut::LookUpTable}, lwe::ciphertext::LWECiphertextToRef, }; @@ -37,7 +37,232 @@ pub fn cggi_blind_rotate( DataRes: AsRef<[u8]> + AsMut<[u8]>, DataIn: AsRef<[u8]>, { - let basek = res.basek(); + if lut.data.len() > 1 { + cggi_blind_rotate_block_binary_exnteded(module, res, lwe, lut, brk, scratch); + } else if brk.block_size() > 1 { + cggi_blind_rotate_block_binary(module, res, lwe, lut, brk, scratch); + } else { + todo!("implement this case") + } +} + +pub(crate) fn cggi_blind_rotate_block_binary_exnteded( + module: &Module, + res: &mut GLWECiphertext, + lwe: &LWECiphertext, + lut: &LookUpTable, + brk: &BlindRotationKeyCGGI, + scratch: &mut Scratch, +) where + DataRes: AsRef<[u8]> + AsMut<[u8]>, + DataIn: AsRef<[u8]>, +{ + let mut lwe_2n: Vec = vec![0i64; lwe.n() + 1]; // TODO: from scratch space + let mut out_mut: GLWECiphertext<&mut [u8]> = res.to_mut(); + let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref(); + let basek: usize = out_mut.basek(); + + let cols: usize = out_mut.rank() + 1; + + mod_switch_2n( + 2 * module.n() * lut.extension_factor(), + &mut lwe_2n, + &lwe_ref, + ); + + let extension_factor: i64 = lut.extension_factor() as i64; + + let mut acc: Vec>> = Vec::with_capacity(lut.extension_factor()); + + for _ in 0..extension_factor { + acc.push(GLWECiphertext::alloc( + module, + basek, + out_mut.k(), + out_mut.rank(), + )); + } + + let a: &[i64] = &lwe_2n[1..]; + let b: i64 = lwe_2n[0]; + + let b_inner: i64 = b / extension_factor; + let b_outer: i64 = b % extension_factor; + + for (i, j) in (0..b_outer).zip(extension_factor - b_outer..extension_factor) { + module.vec_znx_rotate( + b_inner + 1, + &mut acc[j as usize].data, + 0, + &lut.data[i as usize], + 0, + ); + } + for (i, j) in (b_outer..extension_factor).zip(0..extension_factor - b_outer) { + module.vec_znx_rotate( + b_inner, + &mut acc[j as usize].data, + 0, + &lut.data[i as usize], + 0, + ); + } + + let block_size: usize = brk.block_size(); + + let mut acc_dft: Vec, FFT64>> = Vec::with_capacity(lut.extension_factor()); + + for _ in 0..extension_factor { + acc_dft.push(FourierGLWECiphertext::alloc( + module, + basek, + out_mut.k(), + out_mut.rank(), + )); + } + + let mut vmp_res: Vec, FFT64>> = Vec::with_capacity(lut.extension_factor()); + + for _ in 0..extension_factor { + vmp_res.push(FourierGLWECiphertext::alloc( + module, + basek, + out_mut.k(), + out_mut.rank(), + )); + } + + let mut acc_add_dft: Vec, FFT64>> = Vec::with_capacity(lut.extension_factor()); + + for _ in 0..extension_factor { + acc_add_dft.push(FourierGLWECiphertext::alloc( + module, + basek, + out_mut.k(), + out_mut.rank(), + )); + } + + let mut xai_minus_one: backend::ScalarZnx> = module.new_scalar_znx(1); + let mut xai_minus_one_dft: backend::ScalarZnxDft, FFT64> = module.new_scalar_znx_dft(1); + + izip!( + a.chunks_exact(block_size), + brk.data.chunks_exact(block_size) + ) + .enumerate() + .for_each(|(i, (ai, ski))| { + (0..lut.extension_factor()).for_each(|i| { + acc[i].dft(module, &mut acc_dft[i]); + acc_add_dft[i].data.zero(); + }); + + izip!(ai.iter(), ski.iter()).for_each(|(aii, skii)| { + let aii_inner: i64 = aii / extension_factor; + let aii_outer: i64 = aii % extension_factor; + + // vmp_res = DFT(acc) * BRK[i] + (0..lut.extension_factor()).for_each(|i| { + module.vmp_apply(&mut vmp_res[i].data, &acc_dft[i].data, &skii.data, scratch); + }); + + if aii_outer == 0 { + xai_minus_one.zero(); + xai_minus_one.at_mut(0, 0)[0] = 1; + module.vec_znx_rotate_inplace(aii_inner, &mut xai_minus_one, 0); + xai_minus_one.at_mut(0, 0)[0] -= 1; + module.svp_prepare(&mut xai_minus_one_dft, 0, &xai_minus_one, 0); + + (0..lut.extension_factor()).for_each(|j| { + (0..cols).for_each(|i| { + module.svp_apply_inplace(&mut vmp_res[j].data, i, &xai_minus_one_dft, 0); + module.vec_znx_dft_add_inplace(&mut acc_add_dft[j].data, i, &vmp_res[j].data, i); + }); + }) + } else { + xai_minus_one.zero(); + xai_minus_one.at_mut(0, 0)[0] = 1; + module.vec_znx_rotate_inplace(aii_inner + 1, &mut xai_minus_one, 0); + xai_minus_one.at_mut(0, 0)[0] -= 1; + module.svp_prepare(&mut xai_minus_one_dft, 0, &xai_minus_one, 0); + + for (i, j) in (0..aii_outer).zip(extension_factor - aii_outer..extension_factor) { + module.vec_znx_rotate( + b_inner + 1, + &mut acc[j as usize].data, + 0, + &lut.data[i as usize], + 0, + ); + + (0..cols).for_each(|k| { + module.svp_apply_inplace(&mut vmp_res[i as usize].data, k, &xai_minus_one_dft, 0); + module.vec_znx_dft_add_inplace( + &mut acc_add_dft[j as usize].data, + k, + &vmp_res[i as usize].data, + k, + ); + }); + } + + xai_minus_one.zero(); + xai_minus_one.at_mut(0, 0)[0] = 1; + module.vec_znx_rotate_inplace(aii_inner, &mut xai_minus_one, 0); + xai_minus_one.at_mut(0, 0)[0] -= 1; + module.svp_prepare(&mut xai_minus_one_dft, 0, &xai_minus_one, 0); + + for (i, j) in (aii_outer..extension_factor).zip(0..extension_factor - aii_outer) { + module.vec_znx_rotate( + b_inner, + &mut acc[j as usize].data, + 0, + &lut.data[i as usize], + 0, + ); + + (0..cols).for_each(|k| { + module.svp_apply_inplace(&mut vmp_res[i as usize].data, k, &xai_minus_one_dft, 0); + module.vec_znx_dft_add_inplace( + &mut acc_add_dft[j as usize].data, + k, + &vmp_res[i as usize].data, + k, + ); + }); + } + } + }); + + if i == lwe.n() - block_size { + (0..cols).for_each(|i| { + module.vec_znx_dft_add_inplace(&mut acc_dft[0].data, i, &acc_add_dft[0].data, i); + }); + acc_dft[0].idft(module, &mut out_mut, scratch); + } else { + (0..lut.extension_factor()).for_each(|j| { + (0..cols).for_each(|i| { + module.vec_znx_dft_add_inplace(&mut acc_dft[j].data, i, &acc_add_dft[j].data, i); + }); + + acc_dft[j].idft(module, &mut acc[j], scratch); + }) + } + }); +} + +pub(crate) fn cggi_blind_rotate_block_binary( + module: &Module, + res: &mut GLWECiphertext, + lwe: &LWECiphertext, + lut: &LookUpTable, + brk: &BlindRotationKeyCGGI, + scratch: &mut Scratch, +) where + DataRes: AsRef<[u8]> + AsMut<[u8]>, + DataIn: AsRef<[u8]>, +{ + let basek: usize = res.basek(); let mut lwe_2n: Vec = vec![0i64; lwe.n() + 1]; // TODO: from scratch space let mut out_mut: GLWECiphertext<&mut [u8]> = res.to_mut(); @@ -45,7 +270,7 @@ pub fn cggi_blind_rotate( let cols: usize = out_mut.rank() + 1; - mod_switch_2n(module, &mut lwe_2n, &lwe_ref); + mod_switch_2n(2 * module.n(), &mut lwe_2n, &lwe_ref); let a: &[i64] = &lwe_2n[1..]; let b: i64 = lwe_2n[0]; @@ -102,10 +327,10 @@ pub fn cggi_blind_rotate( println!("external products: {} us", duration.as_micros()); } -pub(crate) fn mod_switch_2n(module: &Module, res: &mut [i64], lwe: &LWECiphertext<&[u8]>) { +pub(crate) fn mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWECiphertext<&[u8]>) { let basek: usize = lwe.basek(); - let log2n: usize = module.log_n() + 1; + let log2n: usize = usize::BITS as usize - (n - 1).leading_zeros() as usize + 1; res.copy_from_slice(&lwe.data.at(0, 0)); diff --git a/core/src/blind_rotation/lut.rs b/core/src/blind_rotation/lut.rs index 56b65a0..743036b 100644 --- a/core/src/blind_rotation/lut.rs +++ b/core/src/blind_rotation/lut.rs @@ -16,6 +16,10 @@ impl LookUpTable { Self { data, basek, k } } + pub fn extension_factor(&self) -> usize { + self.data.len() + } + pub fn set(&mut self, module: &Module, f: fn(i64) -> i64, message_modulus: usize) { let basek: usize = self.basek; @@ -29,7 +33,7 @@ impl LookUpTable { let f_scaled = |x: i64| (f(x) % message_modulus as i64) * scale; // If LUT size > module.n() - let domain_size: usize = self.data[0].n() * self.data.len(); + let domain_size: usize = self.data[0].n() * self.extension_factor(); let size: usize = self.k.div_ceil(self.basek); @@ -63,7 +67,7 @@ impl LookUpTable { let mut tmp_bytes: Vec = alloc_aligned(lut_full.n() * size_of::()); lut_full.normalize(self.basek, 0, &mut tmp_bytes); - if self.data.len() > 1 { + if self.extension_factor() > 1 { let mut scratch: ScratchOwned = ScratchOwned::new(module.bytes_of_vec_znx(1, size)); module.vec_znx_split(&mut self.data, 0, &lut_full, 0, scratch.borrow()); } else { diff --git a/core/src/blind_rotation/test_fft64/cggi.rs b/core/src/blind_rotation/test_fft64/cggi.rs index bb98cea..6322c75 100644 --- a/core/src/blind_rotation/test_fft64/cggi.rs +++ b/core/src/blind_rotation/test_fft64/cggi.rs @@ -103,7 +103,7 @@ fn blind_rotation() { let mut lwe_2n: Vec = vec![0i64; lwe.n() + 1]; // TODO: from scratch space - mod_switch_2n(&module, &mut lwe_2n, &lwe.to_ref()); + mod_switch_2n(module.n() * 2, &mut lwe_2n, &lwe.to_ref()); let pt_want: i64 = (lwe_2n[0] + lwe_2n[1..] From c98bf75b61d9edecb9487728d3200d4ead8ecdca Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Wed, 2 Jul 2025 12:25:22 +0200 Subject: [PATCH 12/23] Fixed lut & added test for lut --- backend/src/vec_znx.rs | 10 ++++ core/src/blind_rotation/lut.rs | 52 +++++++++++++---- core/src/blind_rotation/test_fft64/lut.rs | 69 +++++++++++++++++++++++ core/src/blind_rotation/test_fft64/mod.rs | 1 + 4 files changed, 121 insertions(+), 11 deletions(-) create mode 100644 core/src/blind_rotation/test_fft64/lut.rs diff --git a/backend/src/vec_znx.rs b/backend/src/vec_znx.rs index 6189bad..00568dd 100644 --- a/backend/src/vec_znx.rs +++ b/backend/src/vec_znx.rs @@ -111,6 +111,16 @@ impl + AsRef<[u8]>> VecZnx { } } + pub fn rotate(&mut self, k: i64){ + unsafe{ + (0..self.cols()).for_each(|i|{ + (0..self.size()).for_each(|j|{ + znx::znx_rotate_inplace_i64(self.n() as u64, k, self.at_mut_ptr(i, j)); + }); + }) + } + } + pub fn rsh(&mut self, basek: usize, k: usize, scratch: &mut Scratch) { let n: usize = self.n(); let cols: usize = self.cols(); diff --git a/core/src/blind_rotation/lut.rs b/core/src/blind_rotation/lut.rs index 743036b..b6e9a7f 100644 --- a/core/src/blind_rotation/lut.rs +++ b/core/src/blind_rotation/lut.rs @@ -1,4 +1,4 @@ -use backend::{FFT64, Module, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxOps, ZnxInfos, ZnxViewMut, alloc_aligned}; +use backend::{FFT64, Module, VecZnx, VecZnxAlloc, VecZnxOps, ZnxInfos, ZnxViewMut, alloc_aligned}; pub struct LookUpTable { pub(crate) data: Vec>>, @@ -7,10 +7,10 @@ pub struct LookUpTable { } impl LookUpTable { - pub fn alloc(module: &Module, basek: usize, k: usize, extend_factor: usize) -> Self { + pub fn alloc(module: &Module, basek: usize, k: usize, extension_factor: usize) -> Self { let size: usize = k.div_ceil(basek); - let mut data: Vec>> = Vec::with_capacity(extend_factor); - (0..extend_factor).for_each(|_| { + let mut data: Vec>> = Vec::with_capacity(extension_factor); + (0..extension_factor).for_each(|_| { data.push(module.new_vec_znx(1, size)); }); Self { data, basek, k } @@ -20,6 +20,10 @@ impl LookUpTable { self.data.len() } + pub fn domain_size(&self) -> usize { + self.data.len() * self.data[0].n() + } + pub fn set(&mut self, module: &Module, f: fn(i64) -> i64, message_modulus: usize) { let basek: usize = self.basek; @@ -33,11 +37,11 @@ impl LookUpTable { let f_scaled = |x: i64| (f(x) % message_modulus as i64) * scale; // If LUT size > module.n() - let domain_size: usize = self.data[0].n() * self.extension_factor(); + let domain_size: usize = self.domain_size(); let size: usize = self.k.div_ceil(self.basek); - // Equivalent to AUTO([f(0), f(1), ..., f(n-1)], -1) + // Equivalent to AUTO([f(0), -f(n-1), -f(n-2), ..., -f(1)], -1) let mut lut_full: VecZnx> = VecZnx::new::(domain_size, 1, size); { let lut_at: &mut [i64] = lut_full.at_mut(0, limbs - 1); @@ -55,28 +59,54 @@ impl LookUpTable { let end: usize = ((x + 1) * domain_size).div_round(message_modulus); let y: i64 = f_scaled(x as i64); (start..end).for_each(|i| { - lut_at[domain_size - i] = -y; + lut_at[i] = y; }) }); } // Rotates half the step to the left let half_step: usize = domain_size.div_round(message_modulus << 1); - module.vec_znx_rotate_inplace(-(half_step as i64), &mut lut_full, 0); + + lut_full.rotate(-(half_step as i64)); let mut tmp_bytes: Vec = alloc_aligned(lut_full.n() * size_of::()); lut_full.normalize(self.basek, 0, &mut tmp_bytes); if self.extension_factor() > 1 { - let mut scratch: ScratchOwned = ScratchOwned::new(module.bytes_of_vec_znx(1, size)); - module.vec_znx_split(&mut self.data, 0, &lut_full, 0, scratch.borrow()); + (0..self.extension_factor()).for_each(|i| { + module.switch_degree(&mut self.data[i], 0, &lut_full, 0); + if i < self.extension_factor() { + lut_full.rotate(-1); + } + }); } else { module.vec_znx_copy(&mut self.data[0], 0, &lut_full, 0); } } + + pub(crate) fn rotate(&mut self, k: i64) { + let extension_factor: usize = self.extension_factor(); + let two_n: usize = 2 * self.data[0].n(); + let two_n_ext: usize = two_n * extension_factor; + + let k_pos: usize = ((k + two_n_ext as i64) % two_n_ext as i64) as usize; + + let k_hi: usize = k_pos / extension_factor; + let k_lo: usize = k_pos % extension_factor; + + (0..extension_factor - k_lo).for_each(|i| { + self.data[i].rotate(k_hi as i64); + }); + + (extension_factor - k_lo..extension_factor).for_each(|i| { + self.data[i].rotate(k_hi as i64 + 1); + }); + + self.data.rotate_right(k_lo as usize); + } } -pub trait DivRound { +pub(crate) trait DivRound { fn div_round(self, rhs: Self) -> Self; } diff --git a/core/src/blind_rotation/test_fft64/lut.rs b/core/src/blind_rotation/test_fft64/lut.rs new file mode 100644 index 0000000..9377d76 --- /dev/null +++ b/core/src/blind_rotation/test_fft64/lut.rs @@ -0,0 +1,69 @@ +use backend::{FFT64, Module, ZnxView}; + +use crate::blind_rotation::lut::{DivRound, LookUpTable}; + +#[test] +fn standard() { + let module: Module = Module::::new(32); + let basek: usize = 20; + let k_lut: usize = 40; + let message_modulus: usize = 16; + let extension_factor: usize = 1; + + let scale: usize = (1 << (basek - 1)) / message_modulus; + + fn lut_fn(x: i64) -> i64 { + x - 8 + } + + let mut lut: LookUpTable = LookUpTable::alloc(&module, basek, k_lut, extension_factor); + lut.set(&module, lut_fn, message_modulus); + + let half_step: i64 = lut.domain_size().div_round(message_modulus << 1) as i64; + lut.rotate(half_step); + + let step: usize = lut.domain_size().div_round(message_modulus); + + (0..lut.domain_size()).step_by(step).for_each(|i| { + (0..step).for_each(|j| { + assert_eq!( + lut_fn((i / step) as i64) % message_modulus as i64, + lut.data[0].raw()[0] / scale as i64 + ); + lut.rotate(-1); + }); + }); +} + +#[test] +fn extended() { + let module: Module = Module::::new(32); + let basek: usize = 20; + let k_lut: usize = 40; + let message_modulus: usize = 16; + let extension_factor: usize = 4; + + let scale: usize = (1 << (basek - 1)) / message_modulus; + + fn lut_fn(x: i64) -> i64 { + x - 8 + } + + let mut lut: LookUpTable = LookUpTable::alloc(&module, basek, k_lut, extension_factor); + lut.set(&module, lut_fn, message_modulus); + + let half_step: i64 = lut.domain_size().div_round(message_modulus << 1) as i64; + lut.rotate(half_step); + + let step: usize = lut.domain_size().div_round(message_modulus); + + (0..lut.domain_size()).step_by(step).for_each(|i| { + (0..step).for_each(|j| { + assert_eq!( + lut_fn((i / step) as i64) % message_modulus as i64, + lut.data[0].raw()[0] / scale as i64 + ); + lut.rotate(-1); + }); + }); +} diff --git a/core/src/blind_rotation/test_fft64/mod.rs b/core/src/blind_rotation/test_fft64/mod.rs index 1a23dff..18ac93c 100644 --- a/core/src/blind_rotation/test_fft64/mod.rs +++ b/core/src/blind_rotation/test_fft64/mod.rs @@ -1 +1,2 @@ pub mod cggi; +pub mod lut; From 81fb7101651e53c54b2ddc3c58e69abc2b9de711 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Thu, 3 Jul 2025 11:38:25 +0200 Subject: [PATCH 13/23] wip on extended br + bug fixing --- backend/src/scalar_znx.rs | 2 +- core/src/blind_rotation/ccgi.rs | 319 ++++++++++----------- core/src/blind_rotation/key.rs | 14 +- core/src/blind_rotation/lut.rs | 27 +- core/src/blind_rotation/test_fft64/cggi.rs | 61 ++-- core/src/blind_rotation/test_fft64/lut.rs | 4 +- core/src/gglwe/automorphism.rs | 6 +- core/src/gglwe/encryption.rs | 2 +- core/src/gglwe/external_product.rs | 6 +- core/src/gglwe/keyswitch.rs | 6 +- core/src/ggsw/ciphertext.rs | 14 +- core/src/lib.rs | 58 +++- 12 files changed, 303 insertions(+), 216 deletions(-) diff --git a/backend/src/scalar_znx.rs b/backend/src/scalar_znx.rs index 2cc1797..4acedb5 100644 --- a/backend/src/scalar_znx.rs +++ b/backend/src/scalar_znx.rs @@ -109,7 +109,7 @@ impl>> ScalarZnx { } pub fn new(n: usize, cols: usize) -> Self { - let data = alloc_aligned::(Self::bytes_of(n, cols)); + let data: Vec = alloc_aligned::(Self::bytes_of(n, cols)); Self { data: data.into(), n, diff --git a/core/src/blind_rotation/ccgi.rs b/core/src/blind_rotation/ccgi.rs index 38c8bf2..87471e9 100644 --- a/core/src/blind_rotation/ccgi.rs +++ b/core/src/blind_rotation/ccgi.rs @@ -1,29 +1,46 @@ use std::time::Instant; use backend::{ - FFT64, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnxAlloc, ScalarZnxDftAlloc, ScalarZnxDftOps, Scratch, VecZnxDftOps, - VecZnxOps, ZnxView, ZnxViewMut, ZnxZero, + FFT64, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, + Scratch, VecZnxBigAlloc, VecZnxBigOps, VecZnxDftOps, VecZnxOps, ZnxInfos, ZnxView, ZnxViewMut, ZnxZero, }; use itertools::izip; use crate::{ - FourierGLWECiphertext, GGSWCiphertext, GLWECiphertext, GLWECiphertextToMut, Infos, LWECiphertext, ScratchCore, + FourierGLWECiphertext, FourierGLWESecret, GLWECiphertext, GLWECiphertextToMut, GLWEPlaintext, Infos, LWECiphertext, + ScratchCore, blind_rotation::{key::BlindRotationKeyCGGI, lut::LookUpTable}, lwe::ciphertext::LWECiphertextToRef, }; pub fn cggi_blind_rotate_scratch_space( module: &Module, + extension_factor: usize, basek: usize, k_lut: usize, k_brk: usize, rows: usize, rank: usize, ) -> usize { - let size = k_brk.div_ceil(basek); - GGSWCiphertext::, FFT64>::bytes_of(module, basek, k_brk, rows, 1, rank) - + (module.mat_znx_dft_mul_x_pow_minus_one_scratch_space(size, rank + 1) - | GLWECiphertext::external_product_inplace_scratch_space(module, basek, k_lut, k_brk, 1, rank)) + let lut_size: usize = k_lut.div_ceil(basek); + let brk_size: usize = k_brk.div_ceil(basek); + + let acc_dft: usize = FourierGLWECiphertext::bytes_of(module, basek, k_brk, rank) * extension_factor; + let acc_big: usize = module.bytes_of_vec_znx_big(rank + 1, brk_size); + let acc_dft_add: usize = FourierGLWECiphertext::bytes_of(module, basek, k_brk, rank) * extension_factor; + let vmp_res: usize = FourierGLWECiphertext::bytes_of(module, basek, k_brk, rank) * extension_factor; + let xai_plus_y: usize = module.bytes_of_scalar_znx(1); + let xai_plus_y_dft: usize = module.bytes_of_scalar_znx_dft(1); + let vmp: usize = module.vmp_apply_tmp_bytes(lut_size, lut_size, rows, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2) + + let acc: usize; + if extension_factor > 1 { + acc = GLWECiphertext::bytes_of(module, basek, k_lut, rank) * extension_factor; + } else { + acc = 0; + } + + return acc + acc_dft + acc_dft_add + vmp_res + xai_plus_y + xai_plus_y_dft + (vmp | acc_big); } pub fn cggi_blind_rotate( @@ -37,8 +54,8 @@ pub fn cggi_blind_rotate( DataRes: AsRef<[u8]> + AsMut<[u8]>, DataIn: AsRef<[u8]>, { - if lut.data.len() > 1 { - cggi_blind_rotate_block_binary_exnteded(module, res, lwe, lut, brk, scratch); + if lut.extension_factor() > 1 { + cggi_blind_rotate_block_binary_extended(module, res, lwe, lut, brk, scratch); } else if brk.block_size() > 1 { cggi_blind_rotate_block_binary(module, res, lwe, lut, brk, scratch); } else { @@ -46,7 +63,7 @@ pub fn cggi_blind_rotate( } } -pub(crate) fn cggi_blind_rotate_block_binary_exnteded( +pub(crate) fn cggi_blind_rotate_block_binary_extended( module: &Module, res: &mut GLWECiphertext, lwe: &LWECiphertext, @@ -57,198 +74,164 @@ pub(crate) fn cggi_blind_rotate_block_binary_exnteded( DataRes: AsRef<[u8]> + AsMut<[u8]>, DataIn: AsRef<[u8]>, { + let extension_factor: usize = lut.extension_factor(); + let basek: usize = res.basek(); + + let (mut acc, scratch1) = scratch.tmp_vec_glwe_ct(extension_factor, module, basek, res.k(), res.rank()); + let (mut acc_dft, scratch2) = scratch1.tmp_vec_fourier_glwe_ct(extension_factor, module, basek, brk.k(), res.rank()); + let (mut vmp_res, scratch3) = scratch2.tmp_vec_fourier_glwe_ct(extension_factor, module, basek, brk.k(), res.rank()); + let (mut acc_add_dft, scratch4) = scratch3.tmp_vec_fourier_glwe_ct(extension_factor, module, basek, brk.k(), res.rank()); + + (0..extension_factor).for_each(|i| { + acc[i].data.zero(); + }); + + let (mut xai_plus_y, scratch5) = scratch4.tmp_scalar_znx(module, 1); + let (mut xai_plus_y_dft, scratch6) = scratch5.tmp_scalar_znx_dft(module, 1); + let (mut acc_add_big, scratch7) = scratch6.tmp_vec_znx_big(module, 1, brk.size()); + let mut lwe_2n: Vec = vec![0i64; lwe.n() + 1]; // TODO: from scratch space - let mut out_mut: GLWECiphertext<&mut [u8]> = res.to_mut(); let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref(); - let basek: usize = out_mut.basek(); - let cols: usize = out_mut.rank() + 1; + let two_n_ext: usize = 2 * lut.domain_size(); - mod_switch_2n( - 2 * module.n() * lut.extension_factor(), - &mut lwe_2n, - &lwe_ref, - ); + let cols: usize = res.rank() + 1; - let extension_factor: i64 = lut.extension_factor() as i64; - - let mut acc: Vec>> = Vec::with_capacity(lut.extension_factor()); - - for _ in 0..extension_factor { - acc.push(GLWECiphertext::alloc( - module, - basek, - out_mut.k(), - out_mut.rank(), - )); - } + negate_and_mod_switch_2n(two_n_ext, &mut lwe_2n, &lwe_ref); let a: &[i64] = &lwe_2n[1..]; - let b: i64 = lwe_2n[0]; + let b_pos: usize = ((lwe_2n[0] + two_n_ext as i64) % two_n_ext as i64) as usize; - let b_inner: i64 = b / extension_factor; - let b_outer: i64 = b % extension_factor; + let b_hi: usize = b_pos / extension_factor; + let b_lo: usize = b_pos % extension_factor; - for (i, j) in (0..b_outer).zip(extension_factor - b_outer..extension_factor) { - module.vec_znx_rotate( - b_inner + 1, - &mut acc[j as usize].data, - 0, - &lut.data[i as usize], - 0, - ); + for (i, j) in (0..b_lo).zip(extension_factor - b_lo..extension_factor) { + module.vec_znx_rotate(b_hi as i64 + 1, &mut acc[i].data, 0, &lut.data[j], 0); } - for (i, j) in (b_outer..extension_factor).zip(0..extension_factor - b_outer) { - module.vec_znx_rotate( - b_inner, - &mut acc[j as usize].data, - 0, - &lut.data[i as usize], - 0, - ); + for (i, j) in (b_lo..extension_factor).zip(0..extension_factor - b_lo) { + module.vec_znx_rotate(b_hi as i64, &mut acc[i].data, 0, &lut.data[j], 0); } let block_size: usize = brk.block_size(); - let mut acc_dft: Vec, FFT64>> = Vec::with_capacity(lut.extension_factor()); - - for _ in 0..extension_factor { - acc_dft.push(FourierGLWECiphertext::alloc( - module, - basek, - out_mut.k(), - out_mut.rank(), - )); - } - - let mut vmp_res: Vec, FFT64>> = Vec::with_capacity(lut.extension_factor()); - - for _ in 0..extension_factor { - vmp_res.push(FourierGLWECiphertext::alloc( - module, - basek, - out_mut.k(), - out_mut.rank(), - )); - } - - let mut acc_add_dft: Vec, FFT64>> = Vec::with_capacity(lut.extension_factor()); - - for _ in 0..extension_factor { - acc_add_dft.push(FourierGLWECiphertext::alloc( - module, - basek, - out_mut.k(), - out_mut.rank(), - )); - } - - let mut xai_minus_one: backend::ScalarZnx> = module.new_scalar_znx(1); - let mut xai_minus_one_dft: backend::ScalarZnxDft, FFT64> = module.new_scalar_znx_dft(1); - izip!( a.chunks_exact(block_size), brk.data.chunks_exact(block_size) ) .enumerate() .for_each(|(i, (ai, ski))| { - (0..lut.extension_factor()).for_each(|i| { - acc[i].dft(module, &mut acc_dft[i]); + (0..extension_factor).for_each(|i| { + (0..cols).for_each(|j| { + module.vec_znx_dft(1, 0, &mut acc_dft[i].data, j, &acc[i].data, j); + }); acc_add_dft[i].data.zero(); }); + // TODO: first & last iterations can be optimized izip!(ai.iter(), ski.iter()).for_each(|(aii, skii)| { - let aii_inner: i64 = aii / extension_factor; - let aii_outer: i64 = aii % extension_factor; + let ai_pos: usize = ((aii + two_n_ext as i64) % two_n_ext as i64) as usize; + let ai_hi: usize = ai_pos / extension_factor; + let ai_lo: usize = ai_pos % extension_factor; // vmp_res = DFT(acc) * BRK[i] - (0..lut.extension_factor()).for_each(|i| { - module.vmp_apply(&mut vmp_res[i].data, &acc_dft[i].data, &skii.data, scratch); + (0..extension_factor).for_each(|i| { + module.vmp_apply(&mut vmp_res[i].data, &acc_dft[i].data, &skii.data, scratch7); }); - if aii_outer == 0 { - xai_minus_one.zero(); - xai_minus_one.at_mut(0, 0)[0] = 1; - module.vec_znx_rotate_inplace(aii_inner, &mut xai_minus_one, 0); - xai_minus_one.at_mut(0, 0)[0] -= 1; - module.svp_prepare(&mut xai_minus_one_dft, 0, &xai_minus_one, 0); + // Trivial case: no rotation between polynomials, we can directly multiply with (X^{-ai} - 1) + if ai_lo == 0 { + // DFT X^{-ai} + set_xai_plus_y( + module, + ai_hi as i64, + -1, + &mut xai_plus_y_dft, + &mut xai_plus_y, + ); - (0..lut.extension_factor()).for_each(|j| { + // Sets acc_add_dft[i] = (acc[i] * sk) * (X^{-ai} - 1) + (0..extension_factor).for_each(|j| { (0..cols).for_each(|i| { - module.svp_apply_inplace(&mut vmp_res[j].data, i, &xai_minus_one_dft, 0); + module.svp_apply_inplace(&mut vmp_res[j].data, i, &xai_plus_y_dft, 0); module.vec_znx_dft_add_inplace(&mut acc_add_dft[j].data, i, &vmp_res[j].data, i); }); - }) + }); + // Non trivial case: rotation between polynomials + // In this case we can't directly multiply with (X^{-ai} - 1) because of the + // ring homomorphism R^{N} -> prod R^{N/extension_factor}, so we split the + // computation in two steps: acc_add_dft = (acc * sk) * (-1) + (acc * sk) * X^{-ai} } else { - xai_minus_one.zero(); - xai_minus_one.at_mut(0, 0)[0] = 1; - module.vec_znx_rotate_inplace(aii_inner + 1, &mut xai_minus_one, 0); - xai_minus_one.at_mut(0, 0)[0] -= 1; - module.svp_prepare(&mut xai_minus_one_dft, 0, &xai_minus_one, 0); - - for (i, j) in (0..aii_outer).zip(extension_factor - aii_outer..extension_factor) { - module.vec_znx_rotate( - b_inner + 1, - &mut acc[j as usize].data, - 0, - &lut.data[i as usize], - 0, - ); - + // Sets acc_add_dft[i] = acc[i] * sk + (0..extension_factor).for_each(|i| { (0..cols).for_each(|k| { - module.svp_apply_inplace(&mut vmp_res[i as usize].data, k, &xai_minus_one_dft, 0); - module.vec_znx_dft_add_inplace( - &mut acc_add_dft[j as usize].data, - k, - &vmp_res[i as usize].data, - k, - ); + module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[i].data, k, &vmp_res[i].data, k); + }) + }); + + // DFT X^{-ai+1} + set_xai_plus_y( + module, + ai_hi as i64 + 1, + 0, + &mut xai_plus_y_dft, + &mut xai_plus_y, + ); + + // Sets acc_add_dft[0..ai_lo] += (acc[extension_factor - ai_lo..extension_factor] * sk) * X^{-ai+1} + for (i, j) in (0..ai_lo).zip(extension_factor - ai_lo..extension_factor) { + module.vec_znx_rotate(b_hi as i64 + 1, &mut acc[i].data, 0, &lut.data[j], 0); + (0..cols).for_each(|k| { + module.svp_apply_inplace(&mut vmp_res[j].data, k, &xai_plus_y_dft, 0); + module.vec_znx_dft_add_inplace(&mut acc_add_dft[i].data, k, &vmp_res[j].data, k); }); } - xai_minus_one.zero(); - xai_minus_one.at_mut(0, 0)[0] = 1; - module.vec_znx_rotate_inplace(aii_inner, &mut xai_minus_one, 0); - xai_minus_one.at_mut(0, 0)[0] -= 1; - module.svp_prepare(&mut xai_minus_one_dft, 0, &xai_minus_one, 0); - - for (i, j) in (aii_outer..extension_factor).zip(0..extension_factor - aii_outer) { - module.vec_znx_rotate( - b_inner, - &mut acc[j as usize].data, - 0, - &lut.data[i as usize], - 0, - ); + // DFT X^{-ai} + set_xai_plus_y( + module, + ai_hi as i64, + 0, + &mut xai_plus_y_dft, + &mut xai_plus_y, + ); + // Sets acc_add_dft[ai_lo..extension_factor] += (acc[0..extension_factor - ai_lo] * sk) * X^{-ai} + for (i, j) in (ai_lo..extension_factor).zip(0..extension_factor - ai_lo) { + module.vec_znx_rotate(b_hi as i64, &mut acc[i].data, 0, &lut.data[j], 0); (0..cols).for_each(|k| { - module.svp_apply_inplace(&mut vmp_res[i as usize].data, k, &xai_minus_one_dft, 0); - module.vec_znx_dft_add_inplace( - &mut acc_add_dft[j as usize].data, - k, - &vmp_res[i as usize].data, - k, - ); + module.svp_apply_inplace(&mut vmp_res[j].data, k, &xai_plus_y_dft, 0); + module.vec_znx_dft_add_inplace(&mut acc_add_dft[i].data, k, &vmp_res[j].data, k); }); } } }); - if i == lwe.n() - block_size { + (0..extension_factor).for_each(|j| { (0..cols).for_each(|i| { - module.vec_znx_dft_add_inplace(&mut acc_dft[0].data, i, &acc_add_dft[0].data, i); + module.vec_znx_dft_add_inplace(&mut acc_dft[j].data, i, &acc_add_dft[j].data, i); + module.vec_znx_idft(&mut acc_add_big, 0, &acc_dft[j].data, i, scratch7); + module.vec_znx_big_normalize(basek, &mut acc[j].data, i, &acc_add_big, 0, scratch7); }); - acc_dft[0].idft(module, &mut out_mut, scratch); - } else { - (0..lut.extension_factor()).for_each(|j| { - (0..cols).for_each(|i| { - module.vec_znx_dft_add_inplace(&mut acc_dft[j].data, i, &acc_add_dft[j].data, i); - }); - - acc_dft[j].idft(module, &mut acc[j], scratch); - }) - } + }); }); + + (0..cols).for_each(|i| { + module.vec_znx_copy(&mut res.data, i, &acc[0].data, i); + }); +} + +fn set_xai_plus_y( + module: &Module, + k: i64, + y: i64, + res: &mut ScalarZnxDft<&mut [u8], FFT64>, + buf: &mut ScalarZnx<&mut [u8]>, +) { + buf.zero(); + buf.at_mut(0, 0)[0] = 1; + module.vec_znx_rotate_inplace(k, buf, 0); + buf.at_mut(0, 0)[0] += y; + module.svp_prepare(res, 0, buf, 0); } pub(crate) fn cggi_blind_rotate_block_binary( @@ -270,7 +253,7 @@ pub(crate) fn cggi_blind_rotate_block_binary( let cols: usize = out_mut.rank() + 1; - mod_switch_2n(2 * module.n(), &mut lwe_2n, &lwe_ref); + negate_and_mod_switch_2n(2 * lut.domain_size(), &mut lwe_2n, &lwe_ref); let a: &[i64] = &lwe_2n[1..]; let b: i64 = lwe_2n[0]; @@ -278,17 +261,17 @@ pub(crate) fn cggi_blind_rotate_block_binary( out_mut.data.zero(); // Initialize out to X^{b} * LUT(X) - module.vec_znx_rotate(b, &mut out_mut.data, 0, &lut.data[0], 0); + module.vec_znx_rotate(-b, &mut out_mut.data, 0, &lut.data[0], 0); let block_size: usize = brk.block_size(); // ACC + [sum DFT(X^ai -1) * (DFT(ACC) x BRKi)] - let (mut acc_dft, scratch1) = scratch.tmp_glwe_fourier(module, brk.basek(), out_mut.k(), out_mut.rank()); - let (mut acc_add_dft, scratch2) = scratch1.tmp_glwe_fourier(module, brk.basek(), out_mut.k(), out_mut.rank()); - let (mut vmp_res, scratch3) = scratch2.tmp_glwe_fourier(module, basek, out_mut.k(), out_mut.rank()); - let (mut xai_minus_one, scratch4) = scratch3.tmp_scalar_znx(module, 1); - let (mut xai_minus_one_dft, scratch5) = scratch4.tmp_scalar_znx_dft(module, 1); + let (mut acc_dft, scratch1) = scratch.tmp_fourier_glwe_ct(module, brk.basek(), out_mut.k(), out_mut.rank()); + let (mut acc_add_dft, scratch2) = scratch1.tmp_fourier_glwe_ct(module, brk.basek(), out_mut.k(), out_mut.rank()); + let (mut vmp_res, scratch3) = scratch2.tmp_fourier_glwe_ct(module, basek, out_mut.k(), out_mut.rank()); + let (mut xai_plus_y, scratch4) = scratch3.tmp_scalar_znx(module, 1); + let (mut xai_plus_y_dft, scratch5) = scratch4.tmp_scalar_znx_dft(module, 1); let start: Instant = Instant::now(); izip!( @@ -304,15 +287,11 @@ pub(crate) fn cggi_blind_rotate_block_binary( module.vmp_apply(&mut vmp_res.data, &acc_dft.data, &skii.data, scratch5); // DFT(X^ai -1) - xai_minus_one.zero(); - xai_minus_one.at_mut(0, 0)[0] = 1; - module.vec_znx_rotate_inplace(*aii, &mut xai_minus_one, 0); - xai_minus_one.at_mut(0, 0)[0] -= 1; - module.svp_prepare(&mut xai_minus_one_dft, 0, &xai_minus_one, 0); + set_xai_plus_y(module, *aii, -1, &mut xai_plus_y_dft, &mut xai_plus_y); // DFT(X^ai -1) * (DFT(acc) * BRK[i]) (0..cols).for_each(|i| { - module.svp_apply_inplace(&mut vmp_res.data, i, &xai_minus_one_dft, 0); + module.svp_apply_inplace(&mut vmp_res.data, i, &xai_plus_y_dft, 0); module.vec_znx_dft_add_inplace(&mut acc_add_dft.data, i, &vmp_res.data, i); }); }); @@ -324,15 +303,15 @@ pub(crate) fn cggi_blind_rotate_block_binary( acc_dft.idft(module, &mut out_mut, scratch5); }); let duration: std::time::Duration = start.elapsed(); - println!("external products: {} us", duration.as_micros()); } -pub(crate) fn mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWECiphertext<&[u8]>) { +pub(crate) fn negate_and_mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWECiphertext<&[u8]>) { let basek: usize = lwe.basek(); let log2n: usize = usize::BITS as usize - (n - 1).leading_zeros() as usize + 1; res.copy_from_slice(&lwe.data.at(0, 0)); + res.iter_mut().for_each(|x| *x = -*x); if basek > log2n { let diff: usize = basek - log2n; diff --git a/core/src/blind_rotation/key.rs b/core/src/blind_rotation/key.rs index 8cb24cb..b7f9c3f 100644 --- a/core/src/blind_rotation/key.rs +++ b/core/src/blind_rotation/key.rs @@ -46,8 +46,13 @@ impl BlindRotationKeyCGGI { assert_eq!(sk_glwe.n(), module.n()); assert_eq!(sk_glwe.rank(), self.data[0].rank()); match sk_lwe.dist { - Distribution::BinaryBlock(_) | Distribution::BinaryFixed(_) | Distribution::BinaryProb(_) => {} - _ => panic!("invalid GLWESecret distribution: must be BinaryBlock, BinaryFixed or BinaryProb"), + Distribution::BinaryBlock(_) + | Distribution::BinaryFixed(_) + | Distribution::BinaryProb(_) + | Distribution::ZERO => {} + _ => panic!( + "invalid GLWESecret distribution: must be BinaryBlock, BinaryFixed or BinaryProb (or ZERO for debugging)" + ), } } @@ -79,6 +84,11 @@ impl BlindRotationKeyCGGI { self.data[0].k() } + #[allow(dead_code)] + pub(crate) fn size(&self) -> usize { + self.data[0].size() + } + #[allow(dead_code)] pub(crate) fn rank(&self) -> usize { self.data[0].rank() diff --git a/core/src/blind_rotation/lut.rs b/core/src/blind_rotation/lut.rs index b6e9a7f..ed54dd2 100644 --- a/core/src/blind_rotation/lut.rs +++ b/core/src/blind_rotation/lut.rs @@ -1,4 +1,4 @@ -use backend::{FFT64, Module, VecZnx, VecZnxAlloc, VecZnxOps, ZnxInfos, ZnxViewMut, alloc_aligned}; +use backend::{FFT64, Module, ScalarZnx, VecZnx, VecZnxAlloc, VecZnxOps, ZnxInfos, ZnxView, ZnxViewMut, alloc_aligned}; pub struct LookUpTable { pub(crate) data: Vec>>, @@ -84,6 +84,31 @@ impl LookUpTable { } } + pub fn set_raw(&mut self, module: &Module, lut: &ScalarZnx) + where + D: AsRef<[u8]>, + { + let domain_size: usize = self.domain_size(); + + let size: usize = self.k.div_ceil(self.basek); + + let mut lut_full: VecZnx> = VecZnx::new::(domain_size, 1, size); + + lut_full.at_mut(0, 0).copy_from_slice(lut.raw()); + + if self.extension_factor() > 1 { + (0..self.extension_factor()).for_each(|i| { + module.switch_degree(&mut self.data[i], 0, &lut_full, 0); + if i < self.extension_factor() { + lut_full.rotate(-1); + } + }); + } else { + module.vec_znx_copy(&mut self.data[0], 0, &lut_full, 0); + } + } + + #[allow(dead_code)] pub(crate) fn rotate(&mut self, k: i64) { let extension_factor: usize = self.extension_factor(); let two_n: usize = 2 * self.data[0].n(); diff --git a/core/src/blind_rotation/test_fft64/cggi.rs b/core/src/blind_rotation/test_fft64/cggi.rs index 6322c75..5deb497 100644 --- a/core/src/blind_rotation/test_fft64/cggi.rs +++ b/core/src/blind_rotation/test_fft64/cggi.rs @@ -1,12 +1,12 @@ use std::time::Instant; -use backend::{Encoding, FFT64, Module, ScratchOwned, Stats, VecZnxOps, ZnxView}; +use backend::{Encoding, FFT64, Module, ScalarZnx, ScratchOwned, Stats, VecZnxOps, ZnxView, ZnxViewMut}; use sampling::source::Source; use crate::{ FourierGLWESecret, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, LWECiphertext, LWESecret, blind_rotation::{ - ccgi::{cggi_blind_rotate, cggi_blind_rotate_scratch_space, mod_switch_2n}, + ccgi::{cggi_blind_rotate, cggi_blind_rotate_scratch_space, negate_and_mod_switch_2n}, key::BlindRotationKeyCGGI, lut::LookUpTable, }, @@ -16,22 +16,24 @@ use crate::{ #[test] fn blind_rotation() { let module: Module = Module::::new(2048); - let basek: usize = 20; + let basek: usize = 18; let n_lwe: usize = 1071; - let k_lwe: usize = 22; - let k_brk: usize = 60; + let k_lwe: usize = 24; + let k_brk: usize = 3 * basek; let rows_brk: usize = 2; - let k_lut: usize = 60; + let k_lut: usize = 2 * basek; let rank: usize = 1; let block_size: usize = 7; - let message_modulus: usize = 64; + let extension_factor: usize = 2; - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); + let message_modulus: usize = 1 << 6; + + let mut source_xs: Source = Source::new([1u8; 32]); + let mut source_xe: Source = Source::new([1u8; 32]); + let mut source_xa: Source = Source::new([1u8; 32]); let mut sk_glwe: GLWESecret> = GLWESecret::alloc(&module, rank); sk_glwe.fill_ternary_prob(0.5, &mut source_xs); @@ -40,9 +42,21 @@ fn blind_rotation() { let mut sk_lwe: LWESecret> = LWESecret::alloc(n_lwe); sk_lwe.fill_binary_block(block_size, &mut source_xs); + sk_lwe.data.raw_mut()[0] = 0; + + println!("sk_lwe: {:?}", sk_lwe.data.raw()); + let mut scratch: ScratchOwned = ScratchOwned::new( BlindRotationKeyCGGI::generate_from_sk_scratch_space(&module, basek, k_brk, rank) - | cggi_blind_rotate_scratch_space(&module, basek, k_lut, k_brk, rows_brk, rank), + | cggi_blind_rotate_scratch_space( + &module, + extension_factor, + basek, + k_lut, + k_brk, + rows_brk, + rank, + ), ); let start: Instant = Instant::now(); @@ -65,8 +79,8 @@ fn blind_rotation() { let mut pt_lwe: LWEPlaintext> = LWEPlaintext::alloc(basek, k_lwe); - let x: i64 = 0; - let bits: usize = 6; + let x: i64 = 1; + let bits: usize = 8; pt_lwe.data.encode_coeff_i64(0, basek, bits, 0, x, bits); @@ -82,7 +96,7 @@ fn blind_rotation() { 2 * x + 1 } - let mut lut: LookUpTable = LookUpTable::alloc(&module, basek, k_lut, 1); + let mut lut: LookUpTable = LookUpTable::alloc(&module, basek, k_lut, extension_factor); lut.set(&module, lut_fn, message_modulus); let mut res: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_lut, rank); @@ -103,7 +117,7 @@ fn blind_rotation() { let mut lwe_2n: Vec = vec![0i64; lwe.n() + 1]; // TODO: from scratch space - mod_switch_2n(module.n() * 2, &mut lwe_2n, &lwe.to_ref()); + negate_and_mod_switch_2n(2 * lut.domain_size(), &mut lwe_2n, &lwe.to_ref()); let pt_want: i64 = (lwe_2n[0] + lwe_2n[1..] @@ -111,15 +125,22 @@ fn blind_rotation() { .zip(sk_lwe.data.at(0, 0)) .map(|(x, y)| x * y) .sum::()) - % (module.n() as i64 * 2); + % (2 * lut.domain_size()) as i64; - module.vec_znx_rotate_inplace(pt_want, &mut lut.data[0], 0); + println!("pt_want: {}", pt_want); - println!("pt_want: {}", lut.data[0]); + lut.rotate(pt_want); + lut.data.iter().for_each(|d| { + println!("{}", d); + }); + + // First limb should be exactly equal (test are parameterized such that the noise does not reach + // the first limb) + // assert_eq!(pt_have.data.at_mut(0, 0), lut.data[0].at_mut(0, 0)); + + // Then checks the noise module.vec_znx_sub_ab_inplace(&mut lut.data[0], 0, &pt_have.data, 0); - let noise: f64 = lut.data[0].std(0, basek); - println!("noise: {}", noise); } diff --git a/core/src/blind_rotation/test_fft64/lut.rs b/core/src/blind_rotation/test_fft64/lut.rs index 9377d76..58c393b 100644 --- a/core/src/blind_rotation/test_fft64/lut.rs +++ b/core/src/blind_rotation/test_fft64/lut.rs @@ -25,7 +25,7 @@ fn standard() { let step: usize = lut.domain_size().div_round(message_modulus); (0..lut.domain_size()).step_by(step).for_each(|i| { - (0..step).for_each(|j| { + (0..step).for_each(|_| { assert_eq!( lut_fn((i / step) as i64) % message_modulus as i64, lut.data[0].raw()[0] / scale as i64 @@ -58,7 +58,7 @@ fn extended() { let step: usize = lut.domain_size().div_round(message_modulus); (0..lut.domain_size()).step_by(step).for_each(|i| { - (0..step).for_each(|j| { + (0..step).for_each(|_| { assert_eq!( lut_fn((i / step) as i64) % message_modulus as i64, lut.data[0].raw()[0] / scale as i64 diff --git a/core/src/gglwe/automorphism.rs b/core/src/gglwe/automorphism.rs index e18e65a..07fcb14 100644 --- a/core/src/gglwe/automorphism.rs +++ b/core/src/gglwe/automorphism.rs @@ -77,7 +77,7 @@ impl + AsRef<[u8]>> GLWEAutomorphismKey { let (mut tmp_idft_data, scratct1) = scratch.tmp_vec_znx_big(module, cols_out, self.size()); { - let (mut tmp_dft, scratch2) = scratct1.tmp_glwe_fourier(module, lhs.basek(), lhs.k(), lhs.rank()); + let (mut tmp_dft, scratch2) = scratct1.tmp_fourier_glwe_ct(module, lhs.basek(), lhs.k(), lhs.rank()); // Extracts relevant row lhs.get_row(module, row_j, col_i, &mut tmp_dft); @@ -109,7 +109,7 @@ impl + AsRef<[u8]>> GLWEAutomorphismKey { tmp_idft.keyswitch_inplace(module, &rhs.key, scratct1); { - let (mut tmp_dft, _) = scratct1.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); + let (mut tmp_dft, _) = scratct1.tmp_fourier_glwe_ct(module, self.basek(), self.k(), self.rank()); // Applies back the automorphism X^{k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) -> (-pi^{-1}_{k'+k}(s)a + s, a) // and switches back to DFT domain @@ -124,7 +124,7 @@ impl + AsRef<[u8]>> GLWEAutomorphismKey { }); }); - let (mut tmp_dft, _) = scratch.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); + let (mut tmp_dft, _) = scratch.tmp_fourier_glwe_ct(module, self.basek(), self.k(), self.rank()); tmp_dft.data.zero(); (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| { diff --git a/core/src/gglwe/encryption.rs b/core/src/gglwe/encryption.rs index 6c31b2a..bc1137c 100644 --- a/core/src/gglwe/encryption.rs +++ b/core/src/gglwe/encryption.rs @@ -70,7 +70,7 @@ impl + AsRef<[u8]>> GGLWECiphertext { let (mut tmp_pt, scrach_1) = scratch.tmp_glwe_pt(module, basek, k); let (mut tmp_ct, scrach_2) = scrach_1.tmp_glwe_ct(module, basek, k, rank_out); - let (mut tmp_ct_dft, scratch_3) = scrach_2.tmp_glwe_fourier(module, basek, k, rank_out); + let (mut tmp_ct_dft, scratch_3) = scrach_2.tmp_fourier_glwe_ct(module, basek, k, rank_out); // For each input column (i.e. rank) produces a GGLWE ciphertext of rank_out+1 columns // diff --git a/core/src/gglwe/external_product.rs b/core/src/gglwe/external_product.rs index 2e063ef..26a8c92 100644 --- a/core/src/gglwe/external_product.rs +++ b/core/src/gglwe/external_product.rs @@ -66,8 +66,8 @@ impl + AsRef<[u8]>> GLWESwitchingKey { ); } - let (mut tmp_in, scratch1) = scratch.tmp_glwe_fourier(module, lhs.basek(), lhs.k(), lhs.rank()); - let (mut tmp_out, scratch2) = scratch1.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); + let (mut tmp_in, scratch1) = scratch.tmp_fourier_glwe_ct(module, lhs.basek(), lhs.k(), lhs.rank()); + let (mut tmp_out, scratch2) = scratch1.tmp_fourier_glwe_ct(module, self.basek(), self.k(), self.rank()); (0..self.rank_in()).for_each(|col_i| { (0..self.rows()).for_each(|row_j| { @@ -103,7 +103,7 @@ impl + AsRef<[u8]>> GLWESwitchingKey { ); } - let (mut tmp, scratch1) = scratch.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); + let (mut tmp, scratch1) = scratch.tmp_fourier_glwe_ct(module, self.basek(), self.k(), self.rank()); println!("tmp: {}", tmp.size()); (0..self.rank_in()).for_each(|col_i| { (0..self.rows()).for_each(|row_j| { diff --git a/core/src/gglwe/keyswitch.rs b/core/src/gglwe/keyswitch.rs index 632309d..fe4a3f6 100644 --- a/core/src/gglwe/keyswitch.rs +++ b/core/src/gglwe/keyswitch.rs @@ -113,8 +113,8 @@ impl + AsRef<[u8]>> GLWESwitchingKey { ); } - let (mut tmp_in, scratch1) = scratch.tmp_glwe_fourier(module, lhs.basek(), lhs.k(), lhs.rank()); - let (mut tmp_out, scratch2) = scratch1.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); + let (mut tmp_in, scratch1) = scratch.tmp_fourier_glwe_ct(module, lhs.basek(), lhs.k(), lhs.rank()); + let (mut tmp_out, scratch2) = scratch1.tmp_fourier_glwe_ct(module, self.basek(), self.k(), self.rank()); (0..self.rank_in()).for_each(|col_i| { (0..self.rows()).for_each(|row_j| { @@ -150,7 +150,7 @@ impl + AsRef<[u8]>> GLWESwitchingKey { ); } - let (mut tmp, scratch1) = scratch.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); + let (mut tmp, scratch1) = scratch.tmp_fourier_glwe_ct(module, self.basek(), self.k(), self.rank()); (0..self.rank_in()).for_each(|col_i| { (0..self.rows()).for_each(|row_j| { diff --git a/core/src/ggsw/ciphertext.rs b/core/src/ggsw/ciphertext.rs index 12e9723..e52f305 100644 --- a/core/src/ggsw/ciphertext.rs +++ b/core/src/ggsw/ciphertext.rs @@ -290,7 +290,7 @@ impl + AsRef<[u8]>> GGSWCiphertext { // Switch vec_znx_ct into DFT domain { - let (mut tmp_ct_dft, _) = scratch2.tmp_glwe_fourier(module, basek, k, rank); + let (mut tmp_ct_dft, _) = scratch2.tmp_fourier_glwe_ct(module, basek, k, rank); tmp_ct.dft(module, &mut tmp_ct_dft); self.set_row(module, row_i, col_j, &tmp_ct_dft); } @@ -438,7 +438,7 @@ impl + AsRef<[u8]>> GGSWCiphertext { // col 3: (-(d0s0' + d1s1' + d2s2') , d0 , d1 , d2 + M[i]) (1..cols).for_each(|col_j| { self.expand_row(module, col_j, &mut tmp_res.data, &ci_dft, tsk, scratch2); - let (mut tmp_res_dft, _) = scratch2.tmp_glwe_fourier(module, basek, self.k(), rank); + let (mut tmp_res_dft, _) = scratch2.tmp_fourier_glwe_ct(module, basek, self.k(), rank); tmp_res.dft(module, &mut tmp_res_dft); self.set_row(module, row_i, col_j, &tmp_res_dft); }); @@ -541,7 +541,7 @@ impl + AsRef<[u8]>> GGSWCiphertext { tensor_key, scratch2, ); - let (mut tmp_res_dft, _) = scratch2.tmp_glwe_fourier(module, basek, self.k(), rank); + let (mut tmp_res_dft, _) = scratch2.tmp_fourier_glwe_ct(module, basek, self.k(), rank); tmp_res.dft(module, &mut tmp_res_dft); self.set_row(module, row_i, col_j, &tmp_res_dft); }); @@ -599,8 +599,8 @@ impl + AsRef<[u8]>> GGSWCiphertext { ) } - let (mut tmp_ct_in, scratch1) = scratch.tmp_glwe_fourier(module, lhs.basek(), lhs.k(), lhs.rank()); - let (mut tmp_ct_out, scratch2) = scratch1.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); + let (mut tmp_ct_in, scratch1) = scratch.tmp_fourier_glwe_ct(module, lhs.basek(), lhs.k(), lhs.rank()); + let (mut tmp_ct_out, scratch2) = scratch1.tmp_fourier_glwe_ct(module, self.basek(), self.k(), self.rank()); (0..self.rank() + 1).for_each(|col_i| { (0..self.rows()).for_each(|row_j| { @@ -636,7 +636,7 @@ impl + AsRef<[u8]>> GGSWCiphertext { ); } - let (mut tmp_ct, scratch1) = scratch.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); + let (mut tmp_ct, scratch1) = scratch.tmp_fourier_glwe_ct(module, self.basek(), self.k(), self.rank()); (0..self.rank() + 1).for_each(|col_i| { (0..self.rows()).for_each(|row_j| { @@ -674,7 +674,7 @@ impl> GGSWCiphertext { ) ) } - let (mut tmp_dft_dft, scratch1) = scratch.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); + let (mut tmp_dft_dft, scratch1) = scratch.tmp_fourier_glwe_ct(module, self.basek(), self.k(), self.rank()); self.get_row(module, row_i, 0, &mut tmp_dft_dft); res.keyswitch_from_fourier(module, &tmp_dft_dft, ksk, scratch1); } diff --git a/core/src/lib.rs b/core/src/lib.rs index afdcc55..ba28589 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -28,6 +28,14 @@ pub(crate) const SIX_SIGMA: f64 = 6.0; pub trait ScratchCore { fn tmp_glwe_ct(&mut self, module: &Module, basek: usize, k: usize, rank: usize) -> (GLWECiphertext<&mut [u8]>, &mut Self); + fn tmp_vec_glwe_ct( + &mut self, + size: usize, + module: &Module, + basek: usize, + k: usize, + rank: usize, + ) -> (Vec>, &mut Self); fn tmp_glwe_pt(&mut self, module: &Module, basek: usize, k: usize) -> (GLWEPlaintext<&mut [u8]>, &mut Self); fn tmp_gglwe( &mut self, @@ -48,13 +56,21 @@ pub trait ScratchCore { digits: usize, rank: usize, ) -> (GGSWCiphertext<&mut [u8], B>, &mut Self); - fn tmp_glwe_fourier( + fn tmp_fourier_glwe_ct( &mut self, module: &Module, basek: usize, k: usize, rank: usize, ) -> (FourierGLWECiphertext<&mut [u8], B>, &mut Self); + fn tmp_vec_fourier_glwe_ct( + &mut self, + size: usize, + module: &Module, + basek: usize, + k: usize, + rank: usize, + ) -> (Vec>, &mut Self); fn tmp_sk(&mut self, module: &Module, rank: usize) -> (GLWESecret<&mut [u8]>, &mut Self); fn tmp_fourier_sk(&mut self, module: &Module, rank: usize) -> (FourierGLWESecret<&mut [u8], B>, &mut Self); fn tmp_glwe_pk( @@ -106,6 +122,24 @@ impl ScratchCore for Scratch { (GLWECiphertext { data, basek, k }, scratch) } + fn tmp_vec_glwe_ct( + &mut self, + size: usize, + module: &Module, + basek: usize, + k: usize, + rank: usize, + ) -> (Vec>, &mut Self) { + let mut scratch: &mut Scratch = self; + let mut cts: Vec> = Vec::with_capacity(size); + for _ in 0..size { + let (ct, new_scratch) = scratch.tmp_glwe_ct(module, basek, k, rank); + scratch = new_scratch; + cts.push(ct); + } + (cts, scratch) + } + fn tmp_glwe_pt(&mut self, module: &Module, basek: usize, k: usize) -> (GLWEPlaintext<&mut [u8]>, &mut Self) { let (data, scratch) = self.tmp_vec_znx(module, 1, k.div_ceil(basek)); (GLWEPlaintext { data, basek, k }, scratch) @@ -166,7 +200,7 @@ impl ScratchCore for Scratch { ) } - fn tmp_glwe_fourier( + fn tmp_fourier_glwe_ct( &mut self, module: &Module, basek: usize, @@ -177,6 +211,24 @@ impl ScratchCore for Scratch { (FourierGLWECiphertext { data, basek, k }, scratch) } + fn tmp_vec_fourier_glwe_ct( + &mut self, + size: usize, + module: &Module, + basek: usize, + k: usize, + rank: usize, + ) -> (Vec>, &mut Self) { + let mut scratch: &mut Scratch = self; + let mut cts: Vec> = Vec::with_capacity(size); + for _ in 0..size { + let (ct, new_scratch) = scratch.tmp_fourier_glwe_ct(module, basek, k, rank); + scratch = new_scratch; + cts.push(ct); + } + (cts, scratch) + } + fn tmp_glwe_pk( &mut self, module: &Module, @@ -184,7 +236,7 @@ impl ScratchCore for Scratch { k: usize, rank: usize, ) -> (GLWEPublicKey<&mut [u8], FFT64>, &mut Self) { - let (data, scratch) = self.tmp_glwe_fourier(module, basek, k, rank); + let (data, scratch) = self.tmp_fourier_glwe_ct(module, basek, k, rank); ( GLWEPublicKey { data, From e8454cd5f1801320f9883496a7c9c9ecc5f41c9f Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Thu, 3 Jul 2025 11:39:46 +0200 Subject: [PATCH 14/23] small fix on scratch space size --- core/src/blind_rotation/ccgi.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/core/src/blind_rotation/ccgi.rs b/core/src/blind_rotation/ccgi.rs index 87471e9..cde4682 100644 --- a/core/src/blind_rotation/ccgi.rs +++ b/core/src/blind_rotation/ccgi.rs @@ -1,8 +1,7 @@ use std::time::Instant; use backend::{ - FFT64, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, - Scratch, VecZnxBigAlloc, VecZnxBigOps, VecZnxDftOps, VecZnxOps, ZnxInfos, ZnxView, ZnxViewMut, ZnxZero, + MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, Scratch, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDftOps, VecZnxOps, ZnxInfos, ZnxView, ZnxViewMut, ZnxZero, FFT64 }; use itertools::izip; @@ -40,7 +39,7 @@ pub fn cggi_blind_rotate_scratch_space( acc = 0; } - return acc + acc_dft + acc_dft_add + vmp_res + xai_plus_y + xai_plus_y_dft + (vmp | acc_big); + return acc + acc_big + acc_dft + acc_dft_add + vmp_res + xai_plus_y + xai_plus_y_dft + (vmp | module.vec_znx_big_normalize_tmp_bytes()); } pub fn cggi_blind_rotate( From c4a517e9c3ef4eb4214d6894e035b685f6f665b9 Mon Sep 17 00:00:00 2001 From: Janmajaya Mall Date: Fri, 4 Jul 2025 16:03:46 +0530 Subject: [PATCH 15/23] Fix `decode_vec_i64` to handle the case `k < basek` --- backend/src/encoding.rs | 63 ++++++++++++++++++++++++----------------- 1 file changed, 37 insertions(+), 26 deletions(-) diff --git a/backend/src/encoding.rs b/backend/src/encoding.rs index 73b86a3..55bba09 100644 --- a/backend/src/encoding.rs +++ b/backend/src/encoding.rs @@ -157,18 +157,22 @@ fn decode_vec_i64>(a: &VecZnx, col_i: usize, basek: usize, k: } data.copy_from_slice(a.at(col_i, 0)); let rem: usize = basek - (k % basek); - (1..size).for_each(|i| { - if i == size - 1 && rem != basek { - let k_rem: usize = basek - rem; - izip!(a.at(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| { - *y = (*y << k_rem) + (x >> rem); - }); - } else { - izip!(a.at(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| { - *y = (*y << basek) + x; - }); - } - }) + if k < basek { + data.iter_mut().for_each(|x| *x >>= rem); + } else { + (1..size).for_each(|i| { + if i == size - 1 && rem != basek { + let k_rem: usize = basek - rem; + izip!(a.at(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| { + *y = (*y << k_rem) + (x >> rem); + }); + } else { + izip!(a.at(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| { + *y = (*y << basek) + x; + }); + } + }) + } } fn decode_vec_float>(a: &VecZnx, col_i: usize, basek: usize, data: &mut [Float]) { @@ -268,7 +272,7 @@ fn decode_coeff_i64>(a: &VecZnx, col_i: usize, basek: usize, k let mut res: i64 = data[i]; let rem: usize = basek - (k % basek); let slice_size: usize = a.n() * a.cols(); - (1..size).for_each(|i| { + (0..size).for_each(|i| { let x: i64 = data[i * slice_size]; if i == size - 1 && rem != basek { let k_rem: usize = basek - rem; @@ -316,18 +320,25 @@ mod tests { let module: Module = Module::::new(n); let basek: usize = 17; let size: usize = 5; - let k: usize = size * basek - 5; - let mut a: VecZnx<_> = module.new_vec_znx(2, size); - let mut source = Source::new([0u8; 32]); - let raw: &mut [i64] = a.raw_mut(); - raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); - (0..a.cols()).for_each(|col_i| { - let mut have: Vec = vec![i64::default(); n]; - have.iter_mut().for_each(|x| *x = source.next_i64()); - a.encode_vec_i64(col_i, basek, k, &have, 64); - let mut want = vec![i64::default(); n]; - a.decode_vec_i64(col_i, basek, k, &mut want); - izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); - }) + for k in [size * basek - 5] { + let mut a: VecZnx<_> = module.new_vec_znx(2, size); + let mut source = Source::new([0u8; 32]); + let raw: &mut [i64] = a.raw_mut(); + raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); + (0..a.cols()).for_each(|col_i| { + let mut have: Vec = vec![i64::default(); n]; + have.iter_mut().for_each(|x| { + if k < 64 { + *x = source.next_u64n(1 << k, (1 << k) - 1) as i64; + } else { + *x = source.next_i64(); + } + }); + a.encode_vec_i64(col_i, basek, k, &have, std::cmp::min(k, 64)); + let mut want = vec![i64::default(); n]; + a.decode_vec_i64(col_i, basek, k, &mut want); + izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); + }) + } } } From 5234c3fc6382a22e296db1d309b808daa4f2c078 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Mon, 7 Jul 2025 11:09:04 +0200 Subject: [PATCH 16/23] Added LWE-GLWE conversion & LWE Keyswitch, improved LUT generation --- backend/src/lib.rs | 34 ++ backend/src/znx_base.rs | 8 +- core/benches/keyswitch_glwe_fft64.rs | 8 +- core/src/blind_rotation/ccgi.rs | 216 ++++---- core/src/blind_rotation/lut.rs | 64 +-- core/src/blind_rotation/test_fft64/cggi.rs | 48 +- core/src/blind_rotation/test_fft64/lut.rs | 24 +- core/src/fourier_glwe/test_fft64/keyswitch.rs | 8 +- core/src/gglwe/automorphism_key.rs | 4 +- core/src/gglwe/encryption.rs | 125 +++-- core/src/gglwe/keyswitch_key.rs | 40 +- core/src/gglwe/test_fft64/automorphism_key.rs | 12 +- core/src/gglwe/test_fft64/gglwe.rs | 31 +- core/src/gglwe/test_fft64/tensor_key.rs | 4 +- core/src/ggsw/ciphertext.rs | 2 +- core/src/ggsw/test_fft64/ggsw.rs | 32 +- core/src/glwe/keyswitch.rs | 24 +- core/src/glwe/test_fft64/automorphism.rs | 8 +- core/src/glwe/test_fft64/keyswitch.rs | 8 +- core/src/glwe/test_fft64/packing.rs | 4 +- core/src/glwe/test_fft64/trace.rs | 4 +- core/src/lib.rs | 25 +- core/src/lwe/encryption.rs | 12 +- core/src/lwe/keyswtich.rs | 313 ++++++++++++ core/src/lwe/mod.rs | 4 + core/src/lwe/test_fft64/conversion.rs | 220 ++++++++ core/src/lwe/test_fft64/mod.rs | 1 + core/src/test_fft64/glwe_fourier.rs | 478 ------------------ 28 files changed, 979 insertions(+), 782 deletions(-) create mode 100644 core/src/lwe/keyswtich.rs create mode 100644 core/src/lwe/test_fft64/conversion.rs create mode 100644 core/src/lwe/test_fft64/mod.rs delete mode 100644 core/src/test_fft64/glwe_fourier.rs diff --git a/backend/src/lib.rs b/backend/src/lib.rs index d55ba8a..e7c8e5e 100644 --- a/backend/src/lib.rs +++ b/backend/src/lib.rs @@ -231,6 +231,23 @@ impl Scratch { ) } + pub fn tmp_slice_vec_znx_dft( + &mut self, + slice_size: usize, + module: &Module, + cols: usize, + size: usize, + ) -> (Vec>, &mut Self) { + let mut scratch: &mut Scratch = self; + let mut slice: Vec> = Vec::with_capacity(slice_size); + for _ in 0..slice_size{ + let (znx, new_scratch) = scratch.tmp_vec_znx_dft(module, cols, size); + scratch = new_scratch; + slice.push(znx); + }; + (slice, scratch) + } + pub fn tmp_vec_znx_big( &mut self, module: &Module, @@ -253,6 +270,23 @@ impl Scratch { ) } + pub fn tmp_slice_vec_znx( + &mut self, + slice_size: usize, + module: &Module, + cols: usize, + size: usize, + ) -> (Vec>, &mut Self) { + let mut scratch: &mut Scratch = self; + let mut slice: Vec> = Vec::with_capacity(slice_size); + for _ in 0..slice_size{ + let (znx, new_scratch) = scratch.tmp_vec_znx(module, cols, size); + scratch = new_scratch; + slice.push(znx); + }; + (slice, scratch) + } + pub fn tmp_mat_znx_dft( &mut self, module: &Module, diff --git a/backend/src/znx_base.rs b/backend/src/znx_base.rs index daa313d..fa7ec49 100644 --- a/backend/src/znx_base.rs +++ b/backend/src/znx_base.rs @@ -57,8 +57,8 @@ pub trait ZnxView: ZnxInfos + DataView> { fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar { #[cfg(debug_assertions)] { - assert!(i < self.cols()); - assert!(j < self.size()); + assert!(i < self.cols(), "{} >= {}", i, self.cols()); + assert!(j < self.size(), "{} >= {}", j, self.size()); } let offset: usize = self.n() * (j * self.cols() + i); unsafe { self.as_ptr().add(offset) } @@ -85,8 +85,8 @@ pub trait ZnxViewMut: ZnxView + DataViewMut> { fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar { #[cfg(debug_assertions)] { - assert!(i < self.cols()); - assert!(j < self.size()); + assert!(i < self.cols(), "{} >= {}", i, self.cols()); + assert!(j < self.size(), "{} >= {}", j, self.size()); } let offset: usize = self.n() * (j * self.cols() + i); unsafe { self.as_mut_ptr().add(offset) } diff --git a/core/benches/keyswitch_glwe_fft64.rs b/core/benches/keyswitch_glwe_fft64.rs index 9de1e9c..4acc754 100644 --- a/core/benches/keyswitch_glwe_fft64.rs +++ b/core/benches/keyswitch_glwe_fft64.rs @@ -38,7 +38,7 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_rlwe_out, rank_out); let mut scratch = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ksk.k(), rank_out) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ksk.k(), rank_in, rank_out) | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_in.k()) | GLWECiphertext::keyswitch_scratch_space( &module, @@ -63,7 +63,7 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { let mut sk_out: GLWESecret> = GLWESecret::alloc(&module, rank_out); sk_out.fill_ternary_prob(0.5, &mut source_xs); - ksk.generate_from_sk( + ksk.encrypt_sk( &module, -1, &sk_in, @@ -139,7 +139,7 @@ fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) { let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); let mut scratch = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ksk.k(), rank) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ksk.k(), rank, rank) | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k()) | GLWECiphertext::keyswitch_inplace_scratch_space(&module, basek, ct.k(), ksk.k(), digits, rank), ); @@ -156,7 +156,7 @@ fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) { sk_out.fill_ternary_prob(0.5, &mut source_xs); let sk_out_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_out); - ksk.generate_from_sk( + ksk.encrypt_sk( &module, &sk_in, &sk_out_dft, diff --git a/core/src/blind_rotation/ccgi.rs b/core/src/blind_rotation/ccgi.rs index cde4682..684ef72 100644 --- a/core/src/blind_rotation/ccgi.rs +++ b/core/src/blind_rotation/ccgi.rs @@ -1,13 +1,12 @@ -use std::time::Instant; - use backend::{ - MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, Scratch, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDftOps, VecZnxOps, ZnxInfos, ZnxView, ZnxViewMut, ZnxZero, FFT64 + FFT64, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, + Scratch, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxView, + ZnxViewMut, ZnxZero, }; use itertools::izip; use crate::{ - FourierGLWECiphertext, FourierGLWESecret, GLWECiphertext, GLWECiphertextToMut, GLWEPlaintext, Infos, LWECiphertext, - ScratchCore, + GLWECiphertext, GLWECiphertextToMut, Infos, LWECiphertext, blind_rotation::{key::BlindRotationKeyCGGI, lut::LookUpTable}, lwe::ciphertext::LWECiphertextToRef, }; @@ -21,25 +20,31 @@ pub fn cggi_blind_rotate_scratch_space( rows: usize, rank: usize, ) -> usize { - let lut_size: usize = k_lut.div_ceil(basek); + let cols: usize = rank + 1; let brk_size: usize = k_brk.div_ceil(basek); - let acc_dft: usize = FourierGLWECiphertext::bytes_of(module, basek, k_brk, rank) * extension_factor; - let acc_big: usize = module.bytes_of_vec_znx_big(rank + 1, brk_size); - let acc_dft_add: usize = FourierGLWECiphertext::bytes_of(module, basek, k_brk, rank) * extension_factor; - let vmp_res: usize = FourierGLWECiphertext::bytes_of(module, basek, k_brk, rank) * extension_factor; + let acc_dft: usize = module.bytes_of_vec_znx_dft(cols, rows) * extension_factor; + let acc_big: usize = module.bytes_of_vec_znx_big(1, brk_size); + let vmp_res: usize = module.bytes_of_vec_znx_dft(cols, brk_size) * extension_factor; + let acc_dft_add: usize = vmp_res; let xai_plus_y: usize = module.bytes_of_scalar_znx(1); let xai_plus_y_dft: usize = module.bytes_of_scalar_znx_dft(1); - let vmp: usize = module.vmp_apply_tmp_bytes(lut_size, lut_size, rows, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2) + let vmp: usize = module.vmp_apply_tmp_bytes(brk_size, rows, rows, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2) let acc: usize; if extension_factor > 1 { - acc = GLWECiphertext::bytes_of(module, basek, k_lut, rank) * extension_factor; + acc = module.bytes_of_vec_znx(cols, k_lut.div_ceil(basek)) * extension_factor; } else { acc = 0; } - return acc + acc_big + acc_dft + acc_dft_add + vmp_res + xai_plus_y + xai_plus_y_dft + (vmp | module.vec_znx_big_normalize_tmp_bytes()); + return acc + + acc_dft + + acc_dft_add + + vmp_res + + xai_plus_y + + xai_plus_y_dft + + (vmp | (acc_big + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes()))); } pub fn cggi_blind_rotate( @@ -62,6 +67,7 @@ pub fn cggi_blind_rotate( } } +// TODO: ENSURE DOMAIN EXTENSION AS pub(crate) fn cggi_blind_rotate_block_binary_extended( module: &Module, res: &mut GLWECiphertext, @@ -75,27 +81,25 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended( { let extension_factor: usize = lut.extension_factor(); let basek: usize = res.basek(); + let rows: usize = brk.rows(); + let cols: usize = res.rank() + 1; - let (mut acc, scratch1) = scratch.tmp_vec_glwe_ct(extension_factor, module, basek, res.k(), res.rank()); - let (mut acc_dft, scratch2) = scratch1.tmp_vec_fourier_glwe_ct(extension_factor, module, basek, brk.k(), res.rank()); - let (mut vmp_res, scratch3) = scratch2.tmp_vec_fourier_glwe_ct(extension_factor, module, basek, brk.k(), res.rank()); - let (mut acc_add_dft, scratch4) = scratch3.tmp_vec_fourier_glwe_ct(extension_factor, module, basek, brk.k(), res.rank()); - - (0..extension_factor).for_each(|i| { - acc[i].data.zero(); - }); - + let (mut acc, scratch1) = scratch.tmp_slice_vec_znx(extension_factor, module, cols, res.size()); + let (mut acc_dft, scratch2) = scratch1.tmp_slice_vec_znx_dft(extension_factor, module, cols, rows); + let (mut vmp_res, scratch3) = scratch2.tmp_slice_vec_znx_dft(extension_factor, module, cols, brk.size()); + let (mut acc_add_dft, scratch4) = scratch3.tmp_slice_vec_znx_dft(extension_factor, module, cols, brk.size()); let (mut xai_plus_y, scratch5) = scratch4.tmp_scalar_znx(module, 1); let (mut xai_plus_y_dft, scratch6) = scratch5.tmp_scalar_znx_dft(module, 1); - let (mut acc_add_big, scratch7) = scratch6.tmp_vec_znx_big(module, 1, brk.size()); + + (0..extension_factor).for_each(|i| { + acc[i].zero(); + }); let mut lwe_2n: Vec = vec![0i64; lwe.n() + 1]; // TODO: from scratch space let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref(); let two_n_ext: usize = 2 * lut.domain_size(); - let cols: usize = res.rank() + 1; - negate_and_mod_switch_2n(two_n_ext, &mut lwe_2n, &lwe_ref); let a: &[i64] = &lwe_2n[1..]; @@ -105,10 +109,10 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended( let b_lo: usize = b_pos % extension_factor; for (i, j) in (0..b_lo).zip(extension_factor - b_lo..extension_factor) { - module.vec_znx_rotate(b_hi as i64 + 1, &mut acc[i].data, 0, &lut.data[j], 0); + module.vec_znx_rotate(b_hi as i64 + 1, &mut acc[i], 0, &lut.data[j], 0); } for (i, j) in (b_lo..extension_factor).zip(0..extension_factor - b_lo) { - module.vec_znx_rotate(b_hi as i64, &mut acc[i].data, 0, &lut.data[j], 0); + module.vec_znx_rotate(b_hi as i64, &mut acc[i], 0, &lut.data[j], 0); } let block_size: usize = brk.block_size(); @@ -121,9 +125,9 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended( .for_each(|(i, (ai, ski))| { (0..extension_factor).for_each(|i| { (0..cols).for_each(|j| { - module.vec_znx_dft(1, 0, &mut acc_dft[i].data, j, &acc[i].data, j); + module.vec_znx_dft(1, 0, &mut acc_dft[i], j, &acc[i], j); }); - acc_add_dft[i].data.zero(); + acc_add_dft[i].zero(); }); // TODO: first & last iterations can be optimized @@ -134,25 +138,19 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended( // vmp_res = DFT(acc) * BRK[i] (0..extension_factor).for_each(|i| { - module.vmp_apply(&mut vmp_res[i].data, &acc_dft[i].data, &skii.data, scratch7); + module.vmp_apply(&mut vmp_res[i], &acc_dft[i], &skii.data, scratch6); }); // Trivial case: no rotation between polynomials, we can directly multiply with (X^{-ai} - 1) if ai_lo == 0 { // DFT X^{-ai} - set_xai_plus_y( - module, - ai_hi as i64, - -1, - &mut xai_plus_y_dft, - &mut xai_plus_y, - ); + set_xai_plus_y(module, ai_hi, -1, &mut xai_plus_y_dft, &mut xai_plus_y); // Sets acc_add_dft[i] = (acc[i] * sk) * (X^{-ai} - 1) (0..extension_factor).for_each(|j| { (0..cols).for_each(|i| { - module.svp_apply_inplace(&mut vmp_res[j].data, i, &xai_plus_y_dft, 0); - module.vec_znx_dft_add_inplace(&mut acc_add_dft[j].data, i, &vmp_res[j].data, i); + module.svp_apply_inplace(&mut vmp_res[j], i, &xai_plus_y_dft, 0); + module.vec_znx_dft_add_inplace(&mut acc_add_dft[j], i, &vmp_res[j], i); }); }); // Non trivial case: rotation between polynomials @@ -163,74 +161,83 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended( // Sets acc_add_dft[i] = acc[i] * sk (0..extension_factor).for_each(|i| { (0..cols).for_each(|k| { - module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[i].data, k, &vmp_res[i].data, k); + module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[i], k, &vmp_res[i], k); }) }); - // DFT X^{-ai+1} - set_xai_plus_y( - module, - ai_hi as i64 + 1, - 0, - &mut xai_plus_y_dft, - &mut xai_plus_y, - ); + // DFT X^{-ai} + set_xai_plus_y(module, ai_hi + 1, 0, &mut xai_plus_y_dft, &mut xai_plus_y); // Sets acc_add_dft[0..ai_lo] += (acc[extension_factor - ai_lo..extension_factor] * sk) * X^{-ai+1} for (i, j) in (0..ai_lo).zip(extension_factor - ai_lo..extension_factor) { - module.vec_znx_rotate(b_hi as i64 + 1, &mut acc[i].data, 0, &lut.data[j], 0); (0..cols).for_each(|k| { - module.svp_apply_inplace(&mut vmp_res[j].data, k, &xai_plus_y_dft, 0); - module.vec_znx_dft_add_inplace(&mut acc_add_dft[i].data, k, &vmp_res[j].data, k); + module.svp_apply_inplace(&mut vmp_res[j], k, &xai_plus_y_dft, 0); + module.vec_znx_dft_add_inplace(&mut acc_add_dft[i], k, &vmp_res[j], k); }); } // DFT X^{-ai} - set_xai_plus_y( - module, - ai_hi as i64, - 0, - &mut xai_plus_y_dft, - &mut xai_plus_y, - ); + set_xai_plus_y(module, ai_hi, 0, &mut xai_plus_y_dft, &mut xai_plus_y); // Sets acc_add_dft[ai_lo..extension_factor] += (acc[0..extension_factor - ai_lo] * sk) * X^{-ai} for (i, j) in (ai_lo..extension_factor).zip(0..extension_factor - ai_lo) { - module.vec_znx_rotate(b_hi as i64, &mut acc[i].data, 0, &lut.data[j], 0); (0..cols).for_each(|k| { - module.svp_apply_inplace(&mut vmp_res[j].data, k, &xai_plus_y_dft, 0); - module.vec_znx_dft_add_inplace(&mut acc_add_dft[i].data, k, &vmp_res[j].data, k); + module.svp_apply_inplace(&mut vmp_res[j], k, &xai_plus_y_dft, 0); + module.vec_znx_dft_add_inplace(&mut acc_add_dft[i], k, &vmp_res[j], k); }); } } }); - (0..extension_factor).for_each(|j| { - (0..cols).for_each(|i| { - module.vec_znx_dft_add_inplace(&mut acc_dft[j].data, i, &acc_add_dft[j].data, i); - module.vec_znx_idft(&mut acc_add_big, 0, &acc_dft[j].data, i, scratch7); - module.vec_znx_big_normalize(basek, &mut acc[j].data, i, &acc_add_big, 0, scratch7); + { + let (mut acc_add_big, scratch7) = scratch6.tmp_vec_znx_big(module, 1, brk.size()); + + (0..extension_factor).for_each(|j| { + (0..cols).for_each(|i| { + module.vec_znx_idft(&mut acc_add_big, 0, &acc_add_dft[j], i, scratch7); + module.vec_znx_big_add_small_inplace(&mut acc_add_big, 0, &acc[j], i); + module.vec_znx_big_normalize(basek, &mut acc[j], i, &acc_add_big, 0, scratch7); + }); }); - }); + } }); (0..cols).for_each(|i| { - module.vec_znx_copy(&mut res.data, i, &acc[0].data, i); + module.vec_znx_copy(&mut res.data, i, &acc[0], i); }); } fn set_xai_plus_y( module: &Module, - k: i64, + ai: usize, y: i64, res: &mut ScalarZnxDft<&mut [u8], FFT64>, buf: &mut ScalarZnx<&mut [u8]>, ) { - buf.zero(); - buf.at_mut(0, 0)[0] = 1; - module.vec_znx_rotate_inplace(k, buf, 0); - buf.at_mut(0, 0)[0] += y; + let n: usize = module.n(); + + { + let raw: &mut [i64] = buf.at_mut(0, 0); + if ai < n { + raw[ai] = 1; + } else { + raw[(ai - n) & (n - 1)] = -1; + } + raw[0] += y; + } + module.svp_prepare(res, 0, buf, 0); + + { + let raw: &mut [i64] = buf.at_mut(0, 0); + + if ai < n { + raw[ai] = 0; + } else { + raw[(ai - n) & (n - 1)] = 0; + } + raw[0] = 0; + } } pub(crate) fn cggi_blind_rotate_block_binary( @@ -244,11 +251,12 @@ pub(crate) fn cggi_blind_rotate_block_binary( DataRes: AsRef<[u8]> + AsMut<[u8]>, DataIn: AsRef<[u8]>, { - let basek: usize = res.basek(); - let mut lwe_2n: Vec = vec![0i64; lwe.n() + 1]; // TODO: from scratch space let mut out_mut: GLWECiphertext<&mut [u8]> = res.to_mut(); let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref(); + let two_n: usize = module.n() << 1; + let basek: usize = brk.basek(); + let rows = brk.rows(); let cols: usize = out_mut.rank() + 1; @@ -260,48 +268,59 @@ pub(crate) fn cggi_blind_rotate_block_binary( out_mut.data.zero(); // Initialize out to X^{b} * LUT(X) - module.vec_znx_rotate(-b, &mut out_mut.data, 0, &lut.data[0], 0); + module.vec_znx_rotate(b, &mut out_mut.data, 0, &lut.data[0], 0); let block_size: usize = brk.block_size(); // ACC + [sum DFT(X^ai -1) * (DFT(ACC) x BRKi)] - let (mut acc_dft, scratch1) = scratch.tmp_fourier_glwe_ct(module, brk.basek(), out_mut.k(), out_mut.rank()); - let (mut acc_add_dft, scratch2) = scratch1.tmp_fourier_glwe_ct(module, brk.basek(), out_mut.k(), out_mut.rank()); - let (mut vmp_res, scratch3) = scratch2.tmp_fourier_glwe_ct(module, basek, out_mut.k(), out_mut.rank()); + let (mut acc_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rows); + let (mut vmp_res, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, brk.size()); + let (mut acc_add_dft, scratch3) = scratch2.tmp_vec_znx_dft(module, cols, brk.size()); let (mut xai_plus_y, scratch4) = scratch3.tmp_scalar_znx(module, 1); let (mut xai_plus_y_dft, scratch5) = scratch4.tmp_scalar_znx_dft(module, 1); - let start: Instant = Instant::now(); izip!( a.chunks_exact(block_size), brk.data.chunks_exact(block_size) ) .for_each(|(ai, ski)| { - out_mut.dft(module, &mut acc_dft); - acc_add_dft.data.zero(); + (0..cols).for_each(|j| { + module.vec_znx_dft(1, 0, &mut acc_dft, j, &out_mut.data, j); + }); + + acc_add_dft.zero(); izip!(ai.iter(), ski.iter()).for_each(|(aii, skii)| { + let ai_pos: usize = ((aii + two_n as i64) % two_n as i64) as usize; + // vmp_res = DFT(acc) * BRK[i] - module.vmp_apply(&mut vmp_res.data, &acc_dft.data, &skii.data, scratch5); + module.vmp_apply(&mut vmp_res, &acc_dft, &skii.data, scratch5); // DFT(X^ai -1) - set_xai_plus_y(module, *aii, -1, &mut xai_plus_y_dft, &mut xai_plus_y); + set_xai_plus_y(module, ai_pos, -1, &mut xai_plus_y_dft, &mut xai_plus_y); // DFT(X^ai -1) * (DFT(acc) * BRK[i]) (0..cols).for_each(|i| { - module.svp_apply_inplace(&mut vmp_res.data, i, &xai_plus_y_dft, 0); - module.vec_znx_dft_add_inplace(&mut acc_add_dft.data, i, &vmp_res.data, i); + module.svp_apply_inplace(&mut vmp_res, i, &xai_plus_y_dft, 0); + module.vec_znx_dft_add_inplace(&mut acc_add_dft, i, &vmp_res, i); }); }); (0..cols).for_each(|i| { - module.vec_znx_dft_add_inplace(&mut acc_dft.data, i, &acc_add_dft.data, i); + module.vec_znx_dft_add_inplace(&mut acc_dft, i, &acc_add_dft, i); }); - acc_dft.idft(module, &mut out_mut, scratch5); + { + let (mut acc_add_big, scratch6) = scratch5.tmp_vec_znx_big(module, 1, brk.size()); + + (0..cols).for_each(|i| { + module.vec_znx_idft(&mut acc_add_big, 0, &acc_add_dft, i, scratch6); + module.vec_znx_big_add_small_inplace(&mut acc_add_big, 0, &out_mut.data, i); + module.vec_znx_big_normalize(basek, &mut out_mut.data, i, &acc_add_big, 0, scratch6); + }); + } }); - let duration: std::time::Duration = start.elapsed(); } pub(crate) fn negate_and_mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWECiphertext<&[u8]>) { @@ -315,7 +334,7 @@ pub(crate) fn negate_and_mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWECiphe if basek > log2n { let diff: usize = basek - log2n; res.iter_mut().for_each(|x| { - *x = div_signed_by_pow2(x, diff); + *x = div_ceil_signed_by_pow2(x, diff); }) } else { let rem: usize = basek - (log2n % basek); @@ -336,7 +355,22 @@ pub(crate) fn negate_and_mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWECiphe } #[inline(always)] -fn div_signed_by_pow2(x: &i64, k: usize) -> i64 { +fn div_round_by_pow2(x: &i64, k: usize) -> i64 { + if x >= &0 { + (x + (1 << (k - 1))) >> k + } else { + (x + (-1 << (k - 1))) >> k + } +} + +// #[inline(always)] +// fn div_floor_signed_by_pow2(x: &i64, k: usize) -> i64{ +// let bias: i64 = (1 << k) - 1; +// (x + ((x >> 63) & bias)) >> k +// } + +#[inline(always)] +fn div_ceil_signed_by_pow2(x: &i64, k: usize) -> i64 { let bias: i64 = (1 << k) - 1; (x + ((x >> 63) & bias)) >> k } diff --git a/core/src/blind_rotation/lut.rs b/core/src/blind_rotation/lut.rs index ed54dd2..a7fe003 100644 --- a/core/src/blind_rotation/lut.rs +++ b/core/src/blind_rotation/lut.rs @@ -1,4 +1,4 @@ -use backend::{FFT64, Module, ScalarZnx, VecZnx, VecZnxAlloc, VecZnxOps, ZnxInfos, ZnxView, ZnxViewMut, alloc_aligned}; +use backend::{FFT64, Module, VecZnx, VecZnxAlloc, VecZnxOps, ZnxInfos, ZnxViewMut, alloc_aligned}; pub struct LookUpTable { pub(crate) data: Vec>>, @@ -24,17 +24,19 @@ impl LookUpTable { self.data.len() * self.data[0].n() } - pub fn set(&mut self, module: &Module, f: fn(i64) -> i64, message_modulus: usize) { + pub fn set(&mut self, module: &Module, f: &Vec, k: usize) { + assert!(f.len() <= module.n()); + let basek: usize = self.basek; // Get the number minimum limb to store the message modulus - let limbs: usize = message_modulus.div_ceil(1 << basek); + let limbs: usize = k.div_ceil(1 << basek); // Scaling factor - let scale: i64 = (1 << (basek * limbs - 1)).div_round(message_modulus) as i64; + let scale: i64 = (1 << (basek * limbs - 1)).div_round(k) as i64; - // Updates function - let f_scaled = |x: i64| (f(x) % message_modulus as i64) * scale; + // #elements in lookup table + let f_len: usize = f.len(); // If LUT size > module.n() let domain_size: usize = self.domain_size(); @@ -43,29 +45,17 @@ impl LookUpTable { // Equivalent to AUTO([f(0), -f(n-1), -f(n-2), ..., -f(1)], -1) let mut lut_full: VecZnx> = VecZnx::new::(domain_size, 1, size); - { - let lut_at: &mut [i64] = lut_full.at_mut(0, limbs - 1); - let start: usize = 0; - let end: usize = (domain_size).div_round(message_modulus); + let lut_at: &mut [i64] = lut_full.at_mut(0, limbs - 1); - let y: i64 = f_scaled(0); - (start..end).for_each(|i| { - lut_at[i] = y; - }); - - (1..message_modulus).for_each(|x| { - let start: usize = (x * domain_size).div_round(message_modulus); - let end: usize = ((x + 1) * domain_size).div_round(message_modulus); - let y: i64 = f_scaled(x as i64); - (start..end).for_each(|i| { - lut_at[i] = y; - }) - }); - } + f.iter().enumerate().for_each(|(i, fi)| { + let start: usize = (i * domain_size).div_round(f_len); + let end: usize = ((i + 1) * domain_size).div_round(f_len); + lut_at[start..end].fill(fi * scale); + }); // Rotates half the step to the left - let half_step: usize = domain_size.div_round(message_modulus << 1); + let half_step: usize = domain_size.div_round(f_len << 1); lut_full.rotate(-(half_step as i64)); @@ -84,30 +74,6 @@ impl LookUpTable { } } - pub fn set_raw(&mut self, module: &Module, lut: &ScalarZnx) - where - D: AsRef<[u8]>, - { - let domain_size: usize = self.domain_size(); - - let size: usize = self.k.div_ceil(self.basek); - - let mut lut_full: VecZnx> = VecZnx::new::(domain_size, 1, size); - - lut_full.at_mut(0, 0).copy_from_slice(lut.raw()); - - if self.extension_factor() > 1 { - (0..self.extension_factor()).for_each(|i| { - module.switch_degree(&mut self.data[i], 0, &lut_full, 0); - if i < self.extension_factor() { - lut_full.rotate(-1); - } - }); - } else { - module.vec_znx_copy(&mut self.data[0], 0, &lut_full, 0); - } - } - #[allow(dead_code)] pub(crate) fn rotate(&mut self, k: i64) { let extension_factor: usize = self.extension_factor(); diff --git a/core/src/blind_rotation/test_fft64/cggi.rs b/core/src/blind_rotation/test_fft64/cggi.rs index 5deb497..4a5c319 100644 --- a/core/src/blind_rotation/test_fft64/cggi.rs +++ b/core/src/blind_rotation/test_fft64/cggi.rs @@ -1,6 +1,6 @@ use std::time::Instant; -use backend::{Encoding, FFT64, Module, ScalarZnx, ScratchOwned, Stats, VecZnxOps, ZnxView, ZnxViewMut}; +use backend::{Encoding, FFT64, Module, ScratchOwned, Stats, VecZnxOps, ZnxView}; use sampling::source::Source; use crate::{ @@ -16,13 +16,13 @@ use crate::{ #[test] fn blind_rotation() { let module: Module = Module::::new(2048); - let basek: usize = 18; + let basek: usize = 19; let n_lwe: usize = 1071; let k_lwe: usize = 24; let k_brk: usize = 3 * basek; - let rows_brk: usize = 2; + let rows_brk: usize = 1; let k_lut: usize = 2 * basek; let rank: usize = 1; let block_size: usize = 7; @@ -42,22 +42,19 @@ fn blind_rotation() { let mut sk_lwe: LWESecret> = LWESecret::alloc(n_lwe); sk_lwe.fill_binary_block(block_size, &mut source_xs); - sk_lwe.data.raw_mut()[0] = 0; + let mut scratch: ScratchOwned = ScratchOwned::new(BlindRotationKeyCGGI::generate_from_sk_scratch_space( + &module, basek, k_brk, rank, + )); - println!("sk_lwe: {:?}", sk_lwe.data.raw()); - - let mut scratch: ScratchOwned = ScratchOwned::new( - BlindRotationKeyCGGI::generate_from_sk_scratch_space(&module, basek, k_brk, rank) - | cggi_blind_rotate_scratch_space( - &module, - extension_factor, - basek, - k_lut, - k_brk, - rows_brk, - rank, - ), - ); + let mut scratch_br: ScratchOwned = ScratchOwned::new(cggi_blind_rotate_scratch_space( + &module, + extension_factor, + basek, + k_lut, + k_brk, + rows_brk, + rank, + )); let start: Instant = Instant::now(); let mut brk: BlindRotationKeyCGGI = BlindRotationKeyCGGI::allocate(&module, n_lwe, basek, k_brk, rows_brk, rank); @@ -79,7 +76,7 @@ fn blind_rotation() { let mut pt_lwe: LWEPlaintext> = LWEPlaintext::alloc(basek, k_lwe); - let x: i64 = 1; + let x: i64 = 2; let bits: usize = 8; pt_lwe.data.encode_coeff_i64(0, basek, bits, 0, x, bits); @@ -92,18 +89,19 @@ fn blind_rotation() { println!("{}", pt_lwe.data); - fn lut_fn(x: i64) -> i64 { - 2 * x + 1 - } + let mut f: Vec = vec![0i64; message_modulus]; + f.iter_mut() + .enumerate() + .for_each(|(i, x)| *x = 2 * (i as i64) + 1); let mut lut: LookUpTable = LookUpTable::alloc(&module, basek, k_lut, extension_factor); - lut.set(&module, lut_fn, message_modulus); + lut.set(&module, &f, message_modulus); let mut res: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_lut, rank); let start: Instant = Instant::now(); - (0..1).for_each(|_| { - cggi_blind_rotate(&module, &mut res, &lwe, &lut, &brk, scratch.borrow()); + (0..32).for_each(|_| { + cggi_blind_rotate(&module, &mut res, &lwe, &lut, &brk, scratch_br.borrow()); }); let duration: std::time::Duration = start.elapsed(); diff --git a/core/src/blind_rotation/test_fft64/lut.rs b/core/src/blind_rotation/test_fft64/lut.rs index 58c393b..3738b62 100644 --- a/core/src/blind_rotation/test_fft64/lut.rs +++ b/core/src/blind_rotation/test_fft64/lut.rs @@ -1,3 +1,5 @@ +use std::vec; + use backend::{FFT64, Module, ZnxView}; use crate::blind_rotation::lut::{DivRound, LookUpTable}; @@ -12,12 +14,13 @@ fn standard() { let scale: usize = (1 << (basek - 1)) / message_modulus; - fn lut_fn(x: i64) -> i64 { - x - 8 - } + let mut f: Vec = vec![0i64; message_modulus]; + f.iter_mut() + .enumerate() + .for_each(|(i, x)| *x = (i as i64) - 8); let mut lut: LookUpTable = LookUpTable::alloc(&module, basek, k_lut, extension_factor); - lut.set(&module, lut_fn, message_modulus); + lut.set(&module, &f, message_modulus); let half_step: i64 = lut.domain_size().div_round(message_modulus << 1) as i64; lut.rotate(half_step); @@ -27,7 +30,7 @@ fn standard() { (0..lut.domain_size()).step_by(step).for_each(|i| { (0..step).for_each(|_| { assert_eq!( - lut_fn((i / step) as i64) % message_modulus as i64, + f[i / step] % message_modulus as i64, lut.data[0].raw()[0] / scale as i64 ); lut.rotate(-1); @@ -45,12 +48,13 @@ fn extended() { let scale: usize = (1 << (basek - 1)) / message_modulus; - fn lut_fn(x: i64) -> i64 { - x - 8 - } + let mut f: Vec = vec![0i64; message_modulus]; + f.iter_mut() + .enumerate() + .for_each(|(i, x)| *x = (i as i64) - 8); let mut lut: LookUpTable = LookUpTable::alloc(&module, basek, k_lut, extension_factor); - lut.set(&module, lut_fn, message_modulus); + lut.set(&module, &f, message_modulus); let half_step: i64 = lut.domain_size().div_round(message_modulus << 1) as i64; lut.rotate(half_step); @@ -60,7 +64,7 @@ fn extended() { (0..lut.domain_size()).step_by(step).for_each(|i| { (0..step).for_each(|_| { assert_eq!( - lut_fn((i / step) as i64) % message_modulus as i64, + f[i / step] % message_modulus as i64, lut.data[0].raw()[0] / scale as i64 ); lut.rotate(-1); diff --git a/core/src/fourier_glwe/test_fft64/keyswitch.rs b/core/src/fourier_glwe/test_fft64/keyswitch.rs index 50d56ee..61c8b27 100644 --- a/core/src/fourier_glwe/test_fft64/keyswitch.rs +++ b/core/src/fourier_glwe/test_fft64/keyswitch.rs @@ -76,7 +76,7 @@ fn test_apply( .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank_out) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank_in, rank_out) | GLWECiphertext::decrypt_scratch_space(&module, basek, k_out) | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, k_in) | FourierGLWECiphertext::keyswitch_scratch_space( @@ -99,7 +99,7 @@ fn test_apply( sk_out.fill_ternary_prob(0.5, &mut source_xs); let sk_out_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_out); - ksk.generate_from_sk( + ksk.encrypt_sk( &module, &sk_in, &sk_out_dft, @@ -170,7 +170,7 @@ fn test_apply_inplace(log_n: usize, basek: usize, k_ct: usize, k_ksk: usize, dig .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ksk.k(), rank) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ksk.k(), rank, rank) | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_glwe.k()) | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_glwe.k()) | FourierGLWECiphertext::keyswitch_inplace_scratch_space(&module, basek, ct_rlwe_dft.k(), ksk.k(), digits, rank), @@ -184,7 +184,7 @@ fn test_apply_inplace(log_n: usize, basek: usize, k_ct: usize, k_ksk: usize, dig sk_out.fill_ternary_prob(0.5, &mut source_xs); let sk_out_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_out); - ksk.generate_from_sk( + ksk.encrypt_sk( &module, &sk_in, &sk_out_dft, diff --git a/core/src/gglwe/automorphism_key.rs b/core/src/gglwe/automorphism_key.rs index 26fea52..38a7f6e 100644 --- a/core/src/gglwe/automorphism_key.rs +++ b/core/src/gglwe/automorphism_key.rs @@ -66,7 +66,7 @@ impl> GetRow for GLWEAutomorphismKey { col_j: usize, res: &mut FourierGLWECiphertext, ) { - module.mat_znx_dft_get_row(&mut res.data, &self.key.0.data, row_i, col_j); + module.mat_znx_dft_get_row(&mut res.data, &self.key.key.data, row_i, col_j); } } @@ -78,6 +78,6 @@ impl + AsRef<[u8]>> SetRow for GLWEAutomorphismKey, ) { - module.mat_znx_dft_set_row(&mut self.key.0.data, row_i, col_j, &a.data); + module.mat_znx_dft_set_row(&mut self.key.key.data, row_i, col_j, &a.data); } } diff --git a/core/src/gglwe/encryption.rs b/core/src/gglwe/encryption.rs index bc1137c..3e0a7f4 100644 --- a/core/src/gglwe/encryption.rs +++ b/core/src/gglwe/encryption.rs @@ -1,5 +1,6 @@ use backend::{ - FFT64, Module, ScalarZnx, ScalarZnxDftOps, ScalarZnxOps, Scratch, VecZnxAlloc, VecZnxDftAlloc, VecZnxOps, ZnxInfos, ZnxZero, + FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScalarZnxOps, Scratch, VecZnxAlloc, VecZnxDftAlloc, VecZnxOps, + ZnxInfos, ZnxView, ZnxViewMut, ZnxZero, }; use sampling::source::Source; @@ -9,7 +10,7 @@ use crate::{ }; impl GGLWECiphertext, FFT64> { - pub fn generate_from_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { let size = k.div_ceil(basek); GLWECiphertext::encrypt_sk_scratch_space(module, basek, k) + module.bytes_of_vec_znx(rank + 1, size) @@ -17,7 +18,7 @@ impl GGLWECiphertext, FFT64> { + module.bytes_of_vec_znx_dft(rank + 1, size) } - pub fn generate_from_pk_scratch_space(_module: &Module, _basek: usize, _k: usize, _rank: usize) -> usize { + pub fn encrypt_pk_scratch_space(_module: &Module, _basek: usize, _k: usize, _rank: usize) -> usize { unimplemented!() } } @@ -35,20 +36,30 @@ impl + AsRef<[u8]>> GGLWECiphertext { ) { #[cfg(debug_assertions)] { - assert_eq!(self.rank_in(), pt.cols()); - assert_eq!(self.rank_out(), sk.rank()); + assert_eq!( + self.rank_in(), + pt.cols(), + "self.rank_in(): {} != pt.cols(): {}", + self.rank_in(), + pt.cols() + ); + assert_eq!( + self.rank_out(), + sk.rank(), + "self.rank_out(): {} != sk.rank(): {}", + self.rank_out(), + sk.rank() + ); assert_eq!(self.n(), module.n()); assert_eq!(sk.n(), module.n()); assert_eq!(pt.n(), module.n()); assert!( - scratch.available() - >= GGLWECiphertext::generate_from_sk_scratch_space(module, self.basek(), self.k(), self.rank()), - "scratch.available: {} < GGLWECiphertext::generate_from_sk_scratch_space(module, self.rank()={}, \ - self.size()={}): {}", + scratch.available() >= GGLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k(), self.rank()), + "scratch.available: {} < GGLWECiphertext::encrypt_sk_scratch_space(module, self.rank()={}, self.size()={}): {}", scratch.available(), self.rank(), self.size(), - GGLWECiphertext::generate_from_sk_scratch_space(module, self.basek(), self.k(), self.rank()) + GGLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k(), self.rank()) ); assert!( self.rows() * self.digits() * self.basek() <= self.k(), @@ -110,17 +121,25 @@ impl + AsRef<[u8]>> GGLWECiphertext { } impl GLWESwitchingKey, FFT64> { - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { - GGLWECiphertext::generate_from_sk_scratch_space(module, basek, k, rank) + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank_in: usize, rank_out: usize) -> usize { + GGLWECiphertext::encrypt_sk_scratch_space(module, basek, k, rank_out) + + module.bytes_of_scalar_znx(rank_in) + + FourierGLWESecret::bytes_of(module, rank_out) } - pub fn encrypt_pk_scratch_space(module: &Module, _basek: usize, _k: usize, _rank: usize) -> usize { - GGLWECiphertext::generate_from_pk_scratch_space(module, _basek, _k, _rank) + pub fn encrypt_pk_scratch_space( + module: &Module, + _basek: usize, + _k: usize, + _rank_in: usize, + _rank_out: usize, + ) -> usize { + GGLWECiphertext::encrypt_pk_scratch_space(module, _basek, _k, _rank_out) } } impl + AsRef<[u8]>> GLWESwitchingKey { - pub fn generate_from_sk, DataSkOut: AsRef<[u8]>>( + pub fn encrypt_sk, DataSkOut: AsRef<[u8]>>( &mut self, module: &Module, sk_in: &GLWESecret, @@ -130,30 +149,62 @@ impl + AsRef<[u8]>> GLWESwitchingKey { sigma: f64, scratch: &mut Scratch, ) { - self.0.encrypt_sk( + #[cfg(debug_assertions)] + { + assert!(sk_in.n() <= module.n()); + assert!(sk_out.n() <= module.n()); + } + + let (mut sk_in_tmp, scratch1) = scratch.tmp_scalar_znx(module, sk_in.rank()); + sk_in_tmp.zero(); + + (0..sk_in.rank()).for_each(|i| { + sk_in_tmp + .at_mut(i, 0) + .iter_mut() + .step_by(module.n() / sk_in.n()) + .zip(sk_in.data.at(i, 0).iter()) + .for_each(|(x, y)| *x = *y); + }); + + let (mut sk_out_tmp, scratch2) = scratch1.tmp_fourier_glwe_secret(module, sk_out.rank()); + (0..sk_out.rank()).for_each(|i| { + sk_out_tmp + .data + .at_mut(i, 0) + .chunks_exact_mut(module.n() / sk_out.n()) + .zip(sk_out.data.at(i, 0).iter()) + .for_each(|(a_chunk, &b_elem)| { + a_chunk.fill(b_elem); + }); + }); + + self.key.encrypt_sk( module, - &sk_in.data, - sk_out, + &sk_in_tmp, + &sk_out_tmp, source_xa, source_xe, sigma, - scratch, + scratch2, ); + self.sk_in_n = sk_in.n(); + self.sk_out_n = sk_out.n(); } } impl GLWEAutomorphismKey, FFT64> { - pub fn generate_from_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { - GGLWECiphertext::generate_from_sk_scratch_space(module, basek, k, rank) + GLWESecret::bytes_of(module, rank) + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { + GLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, rank, rank) + GLWESecret::bytes_of(module, rank) } - pub fn generate_from_pk_scratch_space(module: &Module, _basek: usize, _k: usize, _rank: usize) -> usize { - GGLWECiphertext::generate_from_pk_scratch_space(module, _basek, _k, _rank) + pub fn encrypt_pk_scratch_space(module: &Module, _basek: usize, _k: usize, _rank: usize) -> usize { + GLWESwitchingKey::encrypt_pk_scratch_space(module, _basek, _k, _rank, _rank) } } impl + AsRef<[u8]>> GLWEAutomorphismKey { - pub fn generate_from_sk>( + pub fn encrypt_sk>( &mut self, module: &Module, p: i64, @@ -170,21 +221,19 @@ impl + AsRef<[u8]>> GLWEAutomorphismKey { assert_eq!(self.rank_out(), self.rank_in()); assert_eq!(sk.rank(), self.rank()); assert!( - scratch.available() - >= GLWEAutomorphismKey::generate_from_sk_scratch_space(module, self.basek(), self.k(), self.rank()), - "scratch.available(): {} < AutomorphismKey::generate_from_sk_scratch_space(module, self.rank()={}, \ - self.size()={}): {}", + scratch.available() >= GLWEAutomorphismKey::encrypt_sk_scratch_space(module, self.basek(), self.k(), self.rank()), + "scratch.available(): {} < AutomorphismKey::encrypt_sk_scratch_space(module, self.rank()={}, self.size()={}): {}", scratch.available(), self.rank(), self.size(), - GLWEAutomorphismKey::generate_from_sk_scratch_space(module, self.basek(), self.k(), self.rank()) + GLWEAutomorphismKey::encrypt_sk_scratch_space(module, self.basek(), self.k(), self.rank()) ) } - let (mut sk_out_dft, scratch_1) = scratch.tmp_fourier_sk(module, sk.rank()); + let (mut sk_out_dft, scratch_1) = scratch.tmp_fourier_glwe_secret(module, sk.rank()); { - let (mut sk_out, _) = scratch_1.tmp_sk(module, sk.rank()); + let (mut sk_out, _) = scratch_1.tmp_glwe_secret(module, sk.rank()); (0..self.rank()).for_each(|i| { module.scalar_znx_automorphism( module.galois_element_inv(p), @@ -197,7 +246,7 @@ impl + AsRef<[u8]>> GLWEAutomorphismKey { sk_out_dft.set(module, &sk_out); } - self.key.generate_from_sk( + self.key.encrypt_sk( module, &sk, &sk_out_dft, @@ -212,15 +261,15 @@ impl + AsRef<[u8]>> GLWEAutomorphismKey { } impl GLWETensorKey, FFT64> { - pub fn generate_from_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { GLWESecret::bytes_of(module, 1) + FourierGLWESecret::bytes_of(module, 1) - + GLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, rank) + + GLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, rank, rank) } } impl + AsRef<[u8]>> GLWETensorKey { - pub fn generate_from_sk>( + pub fn encrypt_sk>( &mut self, module: &Module, sk: &FourierGLWESecret, @@ -238,15 +287,15 @@ impl + AsRef<[u8]>> GLWETensorKey { let rank: usize = self.rank(); - let (mut sk_ij, scratch1) = scratch.tmp_sk(module, 1); - let (mut sk_ij_dft, scratch2) = scratch1.tmp_fourier_sk(module, 1); + let (mut sk_ij, scratch1) = scratch.tmp_glwe_secret(module, 1); + let (mut sk_ij_dft, scratch2) = scratch1.tmp_fourier_glwe_secret(module, 1); (0..rank).for_each(|i| { (i..rank).for_each(|j| { module.svp_apply(&mut sk_ij_dft.data, 0, &sk.data, i, &sk.data, j); module.scalar_znx_idft(&mut sk_ij.data, 0, &sk_ij_dft.data, 0, scratch2); self.at_mut(i, j) - .generate_from_sk(module, &sk_ij, sk, source_xa, source_xe, sigma, scratch2); + .encrypt_sk(module, &sk_ij, sk, source_xa, source_xe, sigma, scratch2); }); }) } diff --git a/core/src/gglwe/keyswitch_key.rs b/core/src/gglwe/keyswitch_key.rs index 965d596..cb1c1fb 100644 --- a/core/src/gglwe/keyswitch_key.rs +++ b/core/src/gglwe/keyswitch_key.rs @@ -2,7 +2,11 @@ use backend::{Backend, FFT64, MatZnxDft, MatZnxDftOps, Module}; use crate::{FourierGLWECiphertext, GGLWECiphertext, GetRow, Infos, SetRow}; -pub struct GLWESwitchingKey(pub(crate) GGLWECiphertext); +pub struct GLWESwitchingKey { + pub(crate) key: GGLWECiphertext, + pub(crate) sk_in_n: usize, // Degree of sk_in + pub(crate) sk_out_n: usize, // Degree of sk_out +} impl GLWESwitchingKey, FFT64> { pub fn alloc( @@ -14,9 +18,11 @@ impl GLWESwitchingKey, FFT64> { rank_in: usize, rank_out: usize, ) -> Self { - GLWESwitchingKey(GGLWECiphertext::alloc( - module, basek, k, rows, digits, rank_in, rank_out, - )) + GLWESwitchingKey { + key: GGLWECiphertext::alloc(module, basek, k, rows, digits, rank_in, rank_out), + sk_in_n: 0, + sk_out_n: 0, + } } pub fn bytes_of( @@ -36,33 +42,41 @@ impl Infos for GLWESwitchingKey { type Inner = MatZnxDft; fn inner(&self) -> &Self::Inner { - self.0.inner() + self.key.inner() } fn basek(&self) -> usize { - self.0.basek() + self.key.basek() } fn k(&self) -> usize { - self.0.k() + self.key.k() } } impl GLWESwitchingKey { pub fn rank(&self) -> usize { - self.0.data.cols_out() - 1 + self.key.data.cols_out() - 1 } pub fn rank_in(&self) -> usize { - self.0.data.cols_in() + self.key.data.cols_in() } pub fn rank_out(&self) -> usize { - self.0.data.cols_out() - 1 + self.key.data.cols_out() - 1 } pub fn digits(&self) -> usize { - self.0.digits() + self.key.digits() + } + + pub fn sk_degree_in(&self) -> usize { + self.sk_in_n + } + + pub fn sk_degree_out(&self) -> usize { + self.sk_out_n } } @@ -74,7 +88,7 @@ impl> GetRow for GLWESwitchingKey { col_j: usize, res: &mut FourierGLWECiphertext, ) { - module.mat_znx_dft_get_row(&mut res.data, &self.0.data, row_i, col_j); + module.mat_znx_dft_get_row(&mut res.data, &self.key.data, row_i, col_j); } } @@ -86,6 +100,6 @@ impl + AsRef<[u8]>> SetRow for GLWESwitchingKey col_j: usize, a: &FourierGLWECiphertext, ) { - module.mat_znx_dft_set_row(&mut self.0.data, row_i, col_j, &a.data); + module.mat_znx_dft_set_row(&mut self.key.data, row_i, col_j, &a.data); } } diff --git a/core/src/gglwe/test_fft64/automorphism_key.rs b/core/src/gglwe/test_fft64/automorphism_key.rs index 61f6825..c6dc212 100644 --- a/core/src/gglwe/test_fft64/automorphism_key.rs +++ b/core/src/gglwe/test_fft64/automorphism_key.rs @@ -70,7 +70,7 @@ fn test_automorphism( let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWEAutomorphismKey::generate_from_sk_scratch_space(&module, basek, k_apply, rank) + GLWEAutomorphismKey::encrypt_sk_scratch_space(&module, basek, k_apply, rank) | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_out) | GLWEAutomorphismKey::automorphism_scratch_space(&module, basek, k_out, k_in, k_apply, digits, rank), ); @@ -79,7 +79,7 @@ fn test_automorphism( sk.fill_ternary_prob(0.5, &mut source_xs); // gglwe_{s1}(s0) = s0 -> s1 - auto_key_in.generate_from_sk( + auto_key_in.encrypt_sk( &module, p0, &sk, @@ -90,7 +90,7 @@ fn test_automorphism( ); // gglwe_{s2}(s1) -> s1 -> s2 - auto_key_apply.generate_from_sk( + auto_key_apply.encrypt_sk( &module, p1, &sk, @@ -185,7 +185,7 @@ fn test_automorphism_inplace( let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWEAutomorphismKey::generate_from_sk_scratch_space(&module, basek, k_apply, rank) + GLWEAutomorphismKey::encrypt_sk_scratch_space(&module, basek, k_apply, rank) | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_in) | GLWEAutomorphismKey::automorphism_inplace_scratch_space(&module, basek, k_in, k_apply, digits, rank), ); @@ -194,7 +194,7 @@ fn test_automorphism_inplace( sk.fill_ternary_prob(0.5, &mut source_xs); // gglwe_{s1}(s0) = s0 -> s1 - auto_key.generate_from_sk( + auto_key.encrypt_sk( &module, p0, &sk, @@ -205,7 +205,7 @@ fn test_automorphism_inplace( ); // gglwe_{s2}(s1) -> s1 -> s2 - auto_key_apply.generate_from_sk( + auto_key_apply.encrypt_sk( &module, p1, &sk, diff --git a/core/src/gglwe/test_fft64/gglwe.rs b/core/src/gglwe/test_fft64/gglwe.rs index 0e3796f..492d3b8 100644 --- a/core/src/gglwe/test_fft64/gglwe.rs +++ b/core/src/gglwe/test_fft64/gglwe.rs @@ -144,7 +144,7 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k_ksk: usize, digits: usize, rank let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank_out) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank_in, rank_out) | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_ksk), ); @@ -155,7 +155,7 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k_ksk: usize, digits: usize, rank sk_out.fill_ternary_prob(0.5, &mut source_xs); let sk_out_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_out); - ksk.generate_from_sk( + ksk.encrypt_sk( &module, &sk_in, &sk_out_dft, @@ -234,8 +234,13 @@ fn test_key_switch( let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank_in_s0s1 | rank_out_s0s1) - | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_out) + GLWESwitchingKey::encrypt_sk_scratch_space( + &module, + basek, + k_ksk, + rank_in_s0s1, + rank_in_s0s1 | rank_out_s0s1, + ) | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_out) | GLWESwitchingKey::keyswitch_scratch_space( &module, basek, @@ -260,7 +265,7 @@ fn test_key_switch( let sk2_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk2); // gglwe_{s1}(s0) = s0 -> s1 - ct_gglwe_s0s1.generate_from_sk( + ct_gglwe_s0s1.encrypt_sk( &module, &sk0, &sk1_dft, @@ -271,7 +276,7 @@ fn test_key_switch( ); // gglwe_{s2}(s1) -> s1 -> s2 - ct_gglwe_s1s2.generate_from_sk( + ct_gglwe_s1s2.encrypt_sk( &module, &sk1, &sk2_dft, @@ -348,7 +353,7 @@ fn test_key_switch_inplace( let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank_out) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank_in, rank_out) | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_ksk) | GLWESwitchingKey::keyswitch_inplace_scratch_space(&module, basek, k_ct, k_ksk, digits, rank_out), ); @@ -365,7 +370,7 @@ fn test_key_switch_inplace( let sk2_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk2); // gglwe_{s1}(s0) = s0 -> s1 - ct_gglwe_s0s1.generate_from_sk( + ct_gglwe_s0s1.encrypt_sk( &module, &sk0, &sk1_dft, @@ -376,7 +381,7 @@ fn test_key_switch_inplace( ); // gglwe_{s2}(s1) -> s1 -> s2 - ct_gglwe_s1s2.generate_from_sk( + ct_gglwe_s1s2.encrypt_sk( &module, &sk1, &sk2_dft, @@ -459,7 +464,7 @@ fn test_external_product( let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_in, rank_out) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_in, rank_in, rank_out) | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_out) | GLWESwitchingKey::external_product_scratch_space(&module, basek, k_out, k_in, k_ggsw, digits, rank_out) | GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_ggsw, rank_out), @@ -477,7 +482,7 @@ fn test_external_product( let sk_out_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_out); // gglwe_{s1}(s0) = s0 -> s1 - ct_gglwe_in.generate_from_sk( + ct_gglwe_in.encrypt_sk( &module, &sk_in, &sk_out_dft, @@ -580,7 +585,7 @@ fn test_external_product_inplace( let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ct, rank_out) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ct, rank_in, rank_out) | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_ct) | GLWESwitchingKey::external_product_inplace_scratch_space(&module, basek, k_ct, k_ggsw, digits, rank_out) | GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_ggsw, rank_out), @@ -598,7 +603,7 @@ fn test_external_product_inplace( let sk_out_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_out); // gglwe_{s1}(s0) = s0 -> s1 - ct_gglwe.generate_from_sk( + ct_gglwe.encrypt_sk( &module, &sk_in, &sk_out_dft, diff --git a/core/src/gglwe/test_fft64/tensor_key.rs b/core/src/gglwe/test_fft64/tensor_key.rs index be69625..ab1d191 100644 --- a/core/src/gglwe/test_fft64/tensor_key.rs +++ b/core/src/gglwe/test_fft64/tensor_key.rs @@ -23,7 +23,7 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k: usize, sigma: f64, rank: usize let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::new(GLWETensorKey::generate_from_sk_scratch_space( + let mut scratch: ScratchOwned = ScratchOwned::new(GLWETensorKey::encrypt_sk_scratch_space( &module, basek, tensor_key.k(), @@ -34,7 +34,7 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k: usize, sigma: f64, rank: usize sk.fill_ternary_prob(0.5, &mut source_xs); let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); - tensor_key.generate_from_sk( + tensor_key.encrypt_sk( &module, &sk_dft, &mut source_xa, diff --git a/core/src/ggsw/ciphertext.rs b/core/src/ggsw/ciphertext.rs index e52f305..8239617 100644 --- a/core/src/ggsw/ciphertext.rs +++ b/core/src/ggsw/ciphertext.rs @@ -361,7 +361,7 @@ impl + AsRef<[u8]>> GGSWCiphertext { // = // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2) (1..cols).for_each(|col_i| { - let pmat: &MatZnxDft = &tsk.at(col_i - 1, col_j - 1).0.data; // Selects Enc(s[i]s[j]) + let pmat: &MatZnxDft = &tsk.at(col_i - 1, col_j - 1).key.data; // Selects Enc(s[i]s[j]) // Extracts a[i] and multipies with Enc(s[i]s[j]) (0..digits).for_each(|di| { diff --git a/core/src/ggsw/test_fft64/ggsw.rs b/core/src/ggsw/test_fft64/ggsw.rs index 714ed2c..dc84eb6 100644 --- a/core/src/ggsw/test_fft64/ggsw.rs +++ b/core/src/ggsw/test_fft64/ggsw.rs @@ -223,8 +223,8 @@ fn test_keyswitch( let mut scratch: ScratchOwned = ScratchOwned::new( GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_in, rank) | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_out) - | GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank) - | GLWETensorKey::generate_from_sk_scratch_space(&module, basek, k_tsk, rank) + | GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank, rank) + | GLWETensorKey::encrypt_sk_scratch_space(&module, basek, k_tsk, rank) | GGSWCiphertext::keyswitch_scratch_space( &module, basek, k_out, k_in, k_ksk, digits, k_tsk, digits, rank, ), @@ -240,7 +240,7 @@ fn test_keyswitch( sk_out.fill_ternary_prob(var_xs, &mut source_xs); let sk_out_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_out); - ksk.generate_from_sk( + ksk.encrypt_sk( &module, &sk_in, &sk_out_dft, @@ -249,7 +249,7 @@ fn test_keyswitch( sigma, scratch.borrow(), ); - tsk.generate_from_sk( + tsk.encrypt_sk( &module, &sk_out_dft, &mut source_xa, @@ -352,8 +352,8 @@ fn test_keyswitch_inplace( let mut scratch: ScratchOwned = ScratchOwned::new( GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_ct, rank) | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_ct) - | GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank) - | GLWETensorKey::generate_from_sk_scratch_space(&module, basek, k_tsk, rank) + | GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank, rank) + | GLWETensorKey::encrypt_sk_scratch_space(&module, basek, k_tsk, rank) | GGSWCiphertext::keyswitch_inplace_scratch_space(&module, basek, k_ct, k_ksk, digits, k_tsk, digits, rank), ); @@ -367,7 +367,7 @@ fn test_keyswitch_inplace( sk_out.fill_ternary_prob(var_xs, &mut source_xs); let sk_out_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_out); - ksk.generate_from_sk( + ksk.encrypt_sk( &module, &sk_in, &sk_out_dft, @@ -376,7 +376,7 @@ fn test_keyswitch_inplace( sigma, scratch.borrow(), ); - tsk.generate_from_sk( + tsk.encrypt_sk( &module, &sk_out_dft, &mut source_xa, @@ -489,8 +489,8 @@ fn test_automorphism( let mut scratch: ScratchOwned = ScratchOwned::new( GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_in, rank) | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_out) - | GLWEAutomorphismKey::generate_from_sk_scratch_space(&module, basek, k_ksk, rank) - | GLWETensorKey::generate_from_sk_scratch_space(&module, basek, k_tsk, rank) + | GLWEAutomorphismKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank) + | GLWETensorKey::encrypt_sk_scratch_space(&module, basek, k_tsk, rank) | GGSWCiphertext::automorphism_scratch_space( &module, basek, k_out, k_in, k_ksk, digits, k_tsk, digits, rank, ), @@ -502,7 +502,7 @@ fn test_automorphism( sk.fill_ternary_prob(var_xs, &mut source_xs); let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); - auto_key.generate_from_sk( + auto_key.encrypt_sk( &module, p, &sk, @@ -511,7 +511,7 @@ fn test_automorphism( sigma, scratch.borrow(), ); - tensor_key.generate_from_sk( + tensor_key.encrypt_sk( &module, &sk_dft, &mut source_xa, @@ -615,8 +615,8 @@ fn test_automorphism_inplace( let mut scratch: ScratchOwned = ScratchOwned::new( GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_ct, rank) | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_ct) - | GLWEAutomorphismKey::generate_from_sk_scratch_space(&module, basek, k_ksk, rank) - | GLWETensorKey::generate_from_sk_scratch_space(&module, basek, k_tsk, rank) + | GLWEAutomorphismKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank) + | GLWETensorKey::encrypt_sk_scratch_space(&module, basek, k_tsk, rank) | GGSWCiphertext::automorphism_inplace_scratch_space(&module, basek, k_ct, k_ksk, digits, k_tsk, digits, rank), ); @@ -626,7 +626,7 @@ fn test_automorphism_inplace( sk.fill_ternary_prob(var_xs, &mut source_xs); let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); - auto_key.generate_from_sk( + auto_key.encrypt_sk( &module, p, &sk, @@ -635,7 +635,7 @@ fn test_automorphism_inplace( sigma, scratch.borrow(), ); - tensor_key.generate_from_sk( + tensor_key.encrypt_sk( &module, &sk_dft, &mut source_xa, diff --git a/core/src/glwe/keyswitch.rs b/core/src/glwe/keyswitch.rs index 5fb12fe..ca6dcad 100644 --- a/core/src/glwe/keyswitch.rs +++ b/core/src/glwe/keyswitch.rs @@ -87,8 +87,20 @@ impl + AsMut<[u8]>> GLWECiphertext { #[cfg(debug_assertions)] { - assert_eq!(lhs.rank(), rhs.rank_in()); - assert_eq!(self.rank(), rhs.rank_out()); + assert_eq!( + lhs.rank(), + rhs.rank_in(), + "lhs.rank(): {} != rhs.rank_in(): {}", + lhs.rank(), + rhs.rank_in() + ); + assert_eq!( + self.rank(), + rhs.rank_out(), + "self.rank(): {} != rhs.rank_out(): {}", + self.rank(), + rhs.rank_out() + ); assert_eq!(self.basek(), basek); assert_eq!(lhs.basek(), basek); assert_eq!(rhs.n(), module.n()); @@ -141,9 +153,9 @@ impl + AsMut<[u8]>> GLWECiphertext { }); if di == 0 { - module.vmp_apply(&mut res_dft, &ai_dft, &rhs.0.data, scratch2); + module.vmp_apply(&mut res_dft, &ai_dft, &rhs.key.data, scratch2); } else { - module.vmp_apply_add(&mut res_dft, &ai_dft, &rhs.0.data, di, scratch2); + module.vmp_apply_add(&mut res_dft, &ai_dft, &rhs.key.data, di, scratch2); } }); } @@ -225,9 +237,9 @@ impl + AsMut<[u8]>> GLWECiphertext { }); if di == 0 { - module.vmp_apply(&mut res_dft, &ai_dft, &rhs.0.data, scratch2); + module.vmp_apply(&mut res_dft, &ai_dft, &rhs.key.data, scratch2); } else { - module.vmp_apply_add(&mut res_dft, &ai_dft, &rhs.0.data, di, scratch2); + module.vmp_apply_add(&mut res_dft, &ai_dft, &rhs.key.data, di, scratch2); } }); } diff --git a/core/src/glwe/test_fft64/automorphism.rs b/core/src/glwe/test_fft64/automorphism.rs index ba739c6..0b917ef 100644 --- a/core/src/glwe/test_fft64/automorphism.rs +++ b/core/src/glwe/test_fft64/automorphism.rs @@ -67,7 +67,7 @@ fn test_automorphism( .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWEAutomorphismKey::generate_from_sk_scratch_space(&module, basek, autokey.k(), rank) + GLWEAutomorphismKey::encrypt_sk_scratch_space(&module, basek, autokey.k(), rank) | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_out.k()) | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_in.k()) | GLWECiphertext::automorphism_scratch_space( @@ -85,7 +85,7 @@ fn test_automorphism( sk.fill_ternary_prob(0.5, &mut source_xs); let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); - autokey.generate_from_sk( + autokey.encrypt_sk( &module, p, &sk, @@ -164,7 +164,7 @@ fn test_automorphism_inplace( .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWEAutomorphismKey::generate_from_sk_scratch_space(&module, basek, autokey.k(), rank) + GLWEAutomorphismKey::encrypt_sk_scratch_space(&module, basek, autokey.k(), rank) | GLWECiphertext::decrypt_scratch_space(&module, basek, ct.k()) | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k()) | GLWECiphertext::automorphism_inplace_scratch_space(&module, basek, ct.k(), autokey.k(), digits, rank), @@ -174,7 +174,7 @@ fn test_automorphism_inplace( sk.fill_ternary_prob(0.5, &mut source_xs); let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); - autokey.generate_from_sk( + autokey.encrypt_sk( &module, p, &sk, diff --git a/core/src/glwe/test_fft64/keyswitch.rs b/core/src/glwe/test_fft64/keyswitch.rs index fb54204..9142292 100644 --- a/core/src/glwe/test_fft64/keyswitch.rs +++ b/core/src/glwe/test_fft64/keyswitch.rs @@ -72,7 +72,7 @@ fn test_keyswitch( .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ksk.k(), rank_out) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ksk.k(), rank_in, rank_out) | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_out.k()) | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_in.k()) | GLWECiphertext::keyswitch_scratch_space( @@ -95,7 +95,7 @@ fn test_keyswitch( sk_out.fill_ternary_prob(0.5, &mut source_xs); let sk_out_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_out); - ksk.generate_from_sk( + ksk.encrypt_sk( &module, &sk_in, &sk_out_dft, @@ -163,7 +163,7 @@ fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ct: usize, k_ksk: usize, .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ct_grlwe.k(), rank) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ct_grlwe.k(), rank, rank) | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_glwe.k()) | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_glwe.k()) | GLWECiphertext::keyswitch_inplace_scratch_space(&module, basek, ct_glwe.k(), ct_grlwe.k(), digits, rank), @@ -177,7 +177,7 @@ fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ct: usize, k_ksk: usize, sk_out.fill_ternary_prob(0.5, &mut source_xs); let sk_out_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_out); - ct_grlwe.generate_from_sk( + ct_grlwe.encrypt_sk( &module, &sk_in, &sk_out_dft, diff --git a/core/src/glwe/test_fft64/packing.rs b/core/src/glwe/test_fft64/packing.rs index 0e9ee71..592d38e 100644 --- a/core/src/glwe/test_fft64/packing.rs +++ b/core/src/glwe/test_fft64/packing.rs @@ -26,7 +26,7 @@ fn apply() { let mut scratch: ScratchOwned = ScratchOwned::new( GLWECiphertext::encrypt_sk_scratch_space(&module, basek, k_ct) | GLWECiphertext::decrypt_scratch_space(&module, basek, k_ct) - | GLWEAutomorphismKey::generate_from_sk_scratch_space(&module, basek, k_ksk, rank) + | GLWEAutomorphismKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank) | GLWEPacker::scratch_space(&module, basek, k_ct, k_ksk, digits, rank), ); @@ -46,7 +46,7 @@ fn apply() { let mut auto_keys: HashMap, FFT64>> = HashMap::new(); gal_els.iter().for_each(|gal_el| { let mut key: GLWEAutomorphismKey, FFT64> = GLWEAutomorphismKey::alloc(&module, basek, k_ksk, rows, digits, rank); - key.generate_from_sk( + key.encrypt_sk( &module, *gal_el, &sk, diff --git a/core/src/glwe/test_fft64/trace.rs b/core/src/glwe/test_fft64/trace.rs index fe4e1eb..e34e260 100644 --- a/core/src/glwe/test_fft64/trace.rs +++ b/core/src/glwe/test_fft64/trace.rs @@ -35,7 +35,7 @@ fn test_trace_inplace(log_n: usize, basek: usize, k: usize, sigma: f64, rank: us let mut scratch: ScratchOwned = ScratchOwned::new( GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k()) | GLWECiphertext::decrypt_scratch_space(&module, basek, ct.k()) - | GLWEAutomorphismKey::generate_from_sk_scratch_space(&module, basek, k_autokey, rank) + | GLWEAutomorphismKey::encrypt_sk_scratch_space(&module, basek, k_autokey, rank) | GLWECiphertext::trace_inplace_scratch_space(&module, basek, ct.k(), k_autokey, digits, rank), ); @@ -68,7 +68,7 @@ fn test_trace_inplace(log_n: usize, basek: usize, k: usize, sigma: f64, rank: us gal_els.iter().for_each(|gal_el| { let mut key: GLWEAutomorphismKey, FFT64> = GLWEAutomorphismKey::alloc(&module, basek, k_autokey, rows, digits, rank); - key.generate_from_sk( + key.encrypt_sk( &module, *gal_el, &sk, diff --git a/core/src/lib.rs b/core/src/lib.rs index ba28589..f94963c 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -63,7 +63,7 @@ pub trait ScratchCore { k: usize, rank: usize, ) -> (FourierGLWECiphertext<&mut [u8], B>, &mut Self); - fn tmp_vec_fourier_glwe_ct( + fn tmp_slice_fourier_glwe_ct( &mut self, size: usize, module: &Module, @@ -71,8 +71,8 @@ pub trait ScratchCore { k: usize, rank: usize, ) -> (Vec>, &mut Self); - fn tmp_sk(&mut self, module: &Module, rank: usize) -> (GLWESecret<&mut [u8]>, &mut Self); - fn tmp_fourier_sk(&mut self, module: &Module, rank: usize) -> (FourierGLWESecret<&mut [u8], B>, &mut Self); + fn tmp_glwe_secret(&mut self, module: &Module, rank: usize) -> (GLWESecret<&mut [u8]>, &mut Self); + fn tmp_fourier_glwe_secret(&mut self, module: &Module, rank: usize) -> (FourierGLWESecret<&mut [u8], B>, &mut Self); fn tmp_glwe_pk( &mut self, module: &Module, @@ -211,7 +211,7 @@ impl ScratchCore for Scratch { (FourierGLWECiphertext { data, basek, k }, scratch) } - fn tmp_vec_fourier_glwe_ct( + fn tmp_slice_fourier_glwe_ct( &mut self, size: usize, module: &Module, @@ -246,7 +246,7 @@ impl ScratchCore for Scratch { ) } - fn tmp_sk(&mut self, module: &Module, rank: usize) -> (GLWESecret<&mut [u8]>, &mut Self) { + fn tmp_glwe_secret(&mut self, module: &Module, rank: usize) -> (GLWESecret<&mut [u8]>, &mut Self) { let (data, scratch) = self.tmp_scalar_znx(module, rank); ( GLWESecret { @@ -257,7 +257,11 @@ impl ScratchCore for Scratch { ) } - fn tmp_fourier_sk(&mut self, module: &Module, rank: usize) -> (FourierGLWESecret<&mut [u8], FFT64>, &mut Self) { + fn tmp_fourier_glwe_secret( + &mut self, + module: &Module, + rank: usize, + ) -> (FourierGLWESecret<&mut [u8], FFT64>, &mut Self) { let (data, scratch) = self.tmp_scalar_znx_dft(module, rank); ( FourierGLWESecret { @@ -279,7 +283,14 @@ impl ScratchCore for Scratch { rank_out: usize, ) -> (GLWESwitchingKey<&mut [u8], FFT64>, &mut Self) { let (data, scratch) = self.tmp_gglwe(module, basek, k, rows, digits, rank_in, rank_out); - (GLWESwitchingKey(data), scratch) + ( + GLWESwitchingKey { + key: data, + sk_in_n: 0, + sk_out_n: 0, + }, + scratch, + ) } fn tmp_autokey( diff --git a/core/src/lwe/encryption.rs b/core/src/lwe/encryption.rs index 148e5c4..00d814f 100644 --- a/core/src/lwe/encryption.rs +++ b/core/src/lwe/encryption.rs @@ -28,7 +28,9 @@ where self.data.fill_uniform(basek, 0, self.size(), source_xa); let mut tmp_znx: VecZnx> = VecZnx::>::new::(1, 1, self.size()); - (0..self.size()).for_each(|i| { + let min_size = self.size().min(pt.size()); + + (0..min_size).for_each(|i| { tmp_znx.at_mut(0, i)[0] = pt.data.at(0, i)[0] - self.data.at(0, i)[1..] .iter() @@ -37,6 +39,14 @@ where .sum::(); }); + (min_size..self.size()).for_each(|i| { + tmp_znx.at_mut(0, i)[0] -= self.data.at(0, i)[1..] + .iter() + .zip(sk.data.at(0, 0)) + .map(|(x, y)| x * y) + .sum::(); + }); + tmp_znx.add_normal(basek, 0, self.k(), source_xe, sigma, sigma * SIX_SIGMA); let mut tmp_bytes: Vec = alloc_aligned(size_of::()); diff --git a/core/src/lwe/keyswtich.rs b/core/src/lwe/keyswtich.rs new file mode 100644 index 0000000..d06b7aa --- /dev/null +++ b/core/src/lwe/keyswtich.rs @@ -0,0 +1,313 @@ +use backend::{Backend, FFT64, Module, Scratch, VecZnxOps, ZnxView, ZnxViewMut, ZnxZero}; +use sampling::source::Source; + +use crate::{FourierGLWESecret, GLWECiphertext, GLWESecret, GLWESwitchingKey, Infos, LWECiphertext, LWESecret, ScratchCore}; + +/// A special [GLWESwitchingKey] required to for the conversion from [GLWECiphertext] to [LWECiphertext]. +pub struct GLWEToLWESwitchingKey(GLWESwitchingKey); + +impl GLWEToLWESwitchingKey, FFT64> { + pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, rank: usize) -> Self { + Self(GLWESwitchingKey::alloc(module, basek, k, rows, 1, rank, 1)) + } + + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { + FourierGLWESecret::bytes_of(module, rank) + + (GLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, rank, rank) | GLWESecret::bytes_of(module, rank)) + } +} + +impl + AsRef<[u8]>> GLWEToLWESwitchingKey { + pub fn encrypt_sk( + &mut self, + module: &Module, + sk_lwe: &LWESecret, + sk_glwe: &GLWESecret, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) where + DLwe: AsRef<[u8]>, + DGlwe: AsRef<[u8]>, + { + #[cfg(debug_assertions)] + { + assert!(sk_lwe.n() <= module.n()); + } + + let (mut sk_lwe_as_glwe_dft, scratch1) = scratch.tmp_fourier_glwe_secret(module, 1); + + { + let (mut sk_lwe_as_glwe, _) = scratch1.tmp_glwe_secret(module, 1); + sk_lwe_as_glwe.data.zero(); + sk_lwe_as_glwe.data.at_mut(0, 0)[..sk_lwe.n()].copy_from_slice(sk_lwe.data.at(0, 0)); + module.vec_znx_automorphism_inplace(-1, &mut sk_lwe_as_glwe.data, 0); + sk_lwe_as_glwe_dft.set(module, &sk_lwe_as_glwe); + } + + self.0.encrypt_sk( + module, + sk_glwe, + &sk_lwe_as_glwe_dft, + source_xa, + source_xe, + sigma, + scratch1, + ); + } +} + +/// A special [GLWESwitchingKey] required to for the conversion from [LWECiphertext] to [GLWECiphertext]. +pub struct LWEToGLWESwitchingKey(GLWESwitchingKey); + +impl LWEToGLWESwitchingKey, FFT64> { + pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, rank: usize) -> Self { + Self(GLWESwitchingKey::alloc(module, basek, k, rows, 1, 1, rank)) + } + + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { + GLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, 1, rank) + GLWESecret::bytes_of(module, 1) + } +} + +impl + AsRef<[u8]>> LWEToGLWESwitchingKey { + pub fn encrypt_sk( + &mut self, + module: &Module, + sk_lwe: &LWESecret, + sk_glwe: &FourierGLWESecret, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) where + DLwe: AsRef<[u8]>, + DGlwe: AsRef<[u8]>, + { + #[cfg(debug_assertions)] + { + assert!(sk_lwe.n() <= module.n()); + } + + let (mut sk_lwe_as_glwe, scratch1) = scratch.tmp_glwe_secret(module, 1); + sk_lwe_as_glwe.data.at_mut(0, 0)[..sk_lwe.n()].copy_from_slice(sk_lwe.data.at(0, 0)); + sk_lwe_as_glwe.data.at_mut(0, 0)[sk_lwe.n()..].fill(0); + module.vec_znx_automorphism_inplace(-1, &mut sk_lwe_as_glwe.data, 0); + + self.0.encrypt_sk( + module, + &sk_lwe_as_glwe, + &sk_glwe, + source_xa, + source_xe, + sigma, + scratch1, + ); + } +} + +pub struct LWESwitchingKey(GLWESwitchingKey); + +impl LWESwitchingKey, FFT64> { + pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize) -> Self { + Self(GLWESwitchingKey::alloc(module, basek, k, rows, 1, 1, 1)) + } + + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize) -> usize { + GLWESecret::bytes_of(module, 1) + + FourierGLWESecret::bytes_of(module, 1) + + GLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, 1, 1) + } +} + +impl + AsRef<[u8]>> LWESwitchingKey { + pub fn encrypt_sk( + &mut self, + module: &Module, + sk_lwe_in: &LWESecret, + sk_lwe_out: &LWESecret, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) where + DIn: AsRef<[u8]>, + DOut: AsRef<[u8]>, + { + #[cfg(debug_assertions)] + { + assert!(sk_lwe_in.n() <= module.n()); + assert!(sk_lwe_out.n() <= module.n()); + } + + let (mut sk_in_glwe, scratch1) = scratch.tmp_glwe_secret(module, 1); + let (mut sk_out_glwe, scratch2) = scratch1.tmp_fourier_glwe_secret(module, 1); + + sk_in_glwe.data.at_mut(0, 0)[..sk_lwe_out.n()].copy_from_slice(sk_lwe_out.data.at(0, 0)); + sk_in_glwe.data.at_mut(0, 0)[sk_lwe_out.n()..].fill(0); + module.vec_znx_automorphism_inplace(-1, &mut sk_in_glwe.data, 0); + sk_out_glwe.set(module, &sk_in_glwe); + sk_in_glwe.data.at_mut(0, 0)[..sk_lwe_in.n()].copy_from_slice(sk_lwe_in.data.at(0, 0)); + sk_in_glwe.data.at_mut(0, 0)[sk_lwe_in.n()..].fill(0); + module.vec_znx_automorphism_inplace(-1, &mut sk_in_glwe.data, 0); + + self.0.encrypt_sk( + module, + &sk_in_glwe, + &sk_out_glwe, + source_xa, + source_xe, + sigma, + scratch2, + ); + } +} + +impl LWECiphertext> { + pub fn from_glwe_scratch_space( + module: &Module, + basek: usize, + k_lwe: usize, + k_glwe: usize, + k_ksk: usize, + rank: usize, + ) -> usize { + GLWECiphertext::bytes_of(module, basek, k_lwe, 1) + + GLWECiphertext::keyswitch_scratch_space(module, basek, k_lwe, k_glwe, k_ksk, 1, rank, 1) + } + + pub fn keyswitch_scratch_space( + module: &Module, + basek: usize, + k_lwe_out: usize, + k_lwe_in: usize, + k_ksk: usize, + ) -> usize { + GLWECiphertext::bytes_of(module, basek, k_lwe_out.max(k_lwe_in), 1) + + GLWECiphertext::keyswitch_inplace_scratch_space(module, basek, k_lwe_out, k_ksk, 1, 1) + } +} + +impl + AsMut<[u8]>> LWECiphertext { + pub fn sample_extract(&mut self, a: &GLWECiphertext) + where + DGlwe: AsRef<[u8]>, + { + #[cfg(debug_assertions)] + { + assert!(self.n() <= a.n()); + } + + let min_size: usize = self.size().min(a.size()); + let n: usize = self.n(); + + self.data.zero(); + (0..min_size).for_each(|i| { + let data_lwe: &mut [i64] = self.data.at_mut(0, i); + data_lwe[0] = a.data.at(0, i)[0]; + data_lwe[1..].copy_from_slice(&a.data.at(1, i)[..n]); + }); + } + + pub fn from_glwe( + &mut self, + module: &Module, + a: &GLWECiphertext, + ks: &GLWEToLWESwitchingKey, + scratch: &mut Scratch, + ) where + DGlwe: AsRef<[u8]>, + DKs: AsRef<[u8]>, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.basek(), a.basek()); + } + let (mut tmp_glwe, scratch1) = scratch.tmp_glwe_ct(module, a.basek(), self.k(), 1); + tmp_glwe.keyswitch(module, a, &ks.0, scratch1); + self.sample_extract(&tmp_glwe); + } + + pub fn keyswitch( + &mut self, + module: &Module, + a: &LWECiphertext, + ksk: &LWESwitchingKey, + scratch: &mut Scratch, + ) where + A: AsRef<[u8]>, + DKs: AsRef<[u8]>, + { + #[cfg(debug_assertions)] + { + assert!(self.n() <= module.n()); + assert!(a.n() <= module.n()); + assert_eq!(self.basek(), a.basek()); + } + + let max_k: usize = self.k().max(a.k()); + let basek: usize = self.basek(); + + let (mut glwe, scratch1) = scratch.tmp_glwe_ct(&module, basek, max_k, 1); + glwe.data.zero(); + + let n_lwe: usize = a.n(); + + (0..a.size()).for_each(|i| { + let data_lwe: &[i64] = a.data.at(0, i); + glwe.data.at_mut(0, i)[0] = data_lwe[0]; + glwe.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]); + }); + + glwe.keyswitch_inplace(module, &ksk.0, scratch1); + + self.sample_extract(&glwe); + } +} + +impl GLWECiphertext> { + pub fn from_lwe_scratch_space( + module: &Module, + basek: usize, + k_lwe: usize, + k_glwe: usize, + k_ksk: usize, + rank: usize, + ) -> usize { + GLWECiphertext::keyswitch_scratch_space(module, basek, k_glwe, k_lwe, k_ksk, 1, 1, rank) + + GLWECiphertext::bytes_of(module, basek, k_lwe, 1) + } +} + +impl + AsMut<[u8]>> GLWECiphertext { + pub fn from_lwe( + &mut self, + module: &Module, + lwe: &LWECiphertext, + ksk: &LWEToGLWESwitchingKey, + scratch: &mut Scratch, + ) where + DLwe: AsRef<[u8]>, + DKsk: AsRef<[u8]>, + { + #[cfg(debug_assertions)] + { + assert!(lwe.n() <= self.n()); + assert_eq!(self.basek(), self.basek()); + } + + let (mut glwe, scratch1) = scratch.tmp_glwe_ct(module, lwe.basek(), lwe.k(), 1); + glwe.data.zero(); + + let n_lwe: usize = lwe.n(); + + (0..lwe.size()).for_each(|i| { + let data_lwe: &[i64] = lwe.data.at(0, i); + glwe.data.at_mut(0, i)[0] = data_lwe[0]; + glwe.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]); + }); + + self.keyswitch(module, &glwe, &ksk.0, scratch1); + } +} diff --git a/core/src/lwe/mod.rs b/core/src/lwe/mod.rs index b7bb7ed..1e3d351 100644 --- a/core/src/lwe/mod.rs +++ b/core/src/lwe/mod.rs @@ -1,9 +1,13 @@ pub mod ciphertext; pub mod decryption; pub mod encryption; +pub mod keyswtich; pub mod plaintext; pub mod secret; pub use ciphertext::LWECiphertext; pub use plaintext::LWEPlaintext; pub use secret::LWESecret; + +#[cfg(test)] +pub mod test_fft64; diff --git a/core/src/lwe/test_fft64/conversion.rs b/core/src/lwe/test_fft64/conversion.rs new file mode 100644 index 0000000..1fbd4cb --- /dev/null +++ b/core/src/lwe/test_fft64/conversion.rs @@ -0,0 +1,220 @@ +use backend::{Encoding, FFT64, Module, ScratchOwned, ZnxView}; +use sampling::source::Source; + +use crate::{ + FourierGLWESecret, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, LWECiphertext, LWESecret, + lwe::{ + LWEPlaintext, + keyswtich::{GLWEToLWESwitchingKey, LWESwitchingKey, LWEToGLWESwitchingKey}, + }, +}; + +#[test] +fn lwe_to_glwe() { + let n: usize = 1 << 5; + let module: Module = Module::::new(n); + let basek: usize = 17; + let sigma: f64 = 3.2; + + let rank: usize = 2; + + let n_lwe: usize = 22; + let k_lwe_ct: usize = 2 * basek; + let k_lwe_pt: usize = 8; + + let k_glwe_ct: usize = 3 * basek; + + let k_ksk: usize = k_lwe_ct + basek; + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + LWEToGLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank) + | GLWECiphertext::from_lwe_scratch_space(&module, basek, k_lwe_ct, k_glwe_ct, k_ksk, rank) + | GLWECiphertext::decrypt_scratch_space(&module, basek, k_glwe_ct), + ); + + let mut sk_glwe: GLWESecret> = GLWESecret::alloc(&module, rank); + sk_glwe.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_glwe_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::alloc(&module, rank); + sk_glwe_dft.set(&module, &sk_glwe); + + let mut sk_lwe = LWESecret::alloc(n_lwe); + sk_lwe.fill_ternary_prob(0.5, &mut source_xs); + + let data: i64 = 17; + + let mut lwe_pt: LWEPlaintext> = LWEPlaintext::alloc(basek, k_lwe_pt); + lwe_pt + .data + .encode_coeff_i64(0, basek, k_lwe_pt, 0, data, k_lwe_pt); + + let mut lwe_ct: LWECiphertext> = LWECiphertext::alloc(n_lwe, basek, k_lwe_ct); + lwe_ct.encrypt_sk(&lwe_pt, &sk_lwe, &mut source_xa, &mut source_xe, sigma); + + let mut ksk: LWEToGLWESwitchingKey, FFT64> = LWEToGLWESwitchingKey::alloc(&module, basek, k_ksk, lwe_ct.size(), rank); + + ksk.encrypt_sk( + &module, + &sk_lwe, + &sk_glwe_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + let mut glwe_ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_glwe_ct, rank); + glwe_ct.from_lwe(&module, &lwe_ct, &ksk, scratch.borrow()); + + let mut glwe_pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_glwe_ct); + glwe_ct.decrypt(&module, &mut glwe_pt, &sk_glwe_dft, scratch.borrow()); + + assert_eq!(glwe_pt.data.at(0, 0)[0], lwe_pt.data.at(0, 0)[0]); +} + +#[test] +fn glwe_to_lwe() { + let n: usize = 1 << 5; + let module: Module = Module::::new(n); + let basek: usize = 17; + let sigma: f64 = 3.2; + + let rank: usize = 2; + + let n_lwe: usize = 22; + let k_lwe_ct: usize = 2 * basek; + let k_lwe_pt: usize = 8; + + let k_glwe_ct: usize = 3 * basek; + + let k_ksk: usize = k_lwe_ct + basek; + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + LWEToGLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank) + | GLWECiphertext::from_lwe_scratch_space(&module, basek, k_lwe_ct, k_glwe_ct, k_ksk, rank) + | GLWECiphertext::decrypt_scratch_space(&module, basek, k_glwe_ct), + ); + + let mut sk_glwe: GLWESecret> = GLWESecret::alloc(&module, rank); + sk_glwe.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_glwe_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::alloc(&module, rank); + sk_glwe_dft.set(&module, &sk_glwe); + + let mut sk_lwe = LWESecret::alloc(n_lwe); + sk_lwe.fill_ternary_prob(0.5, &mut source_xs); + + let data: i64 = 17; + let mut glwe_pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_glwe_ct); + + glwe_pt + .data + .encode_coeff_i64(0, basek, k_lwe_pt, 0, data, k_lwe_pt); + + let mut glwe_ct = GLWECiphertext::alloc(&module, basek, k_glwe_ct, rank); + glwe_ct.encrypt_sk( + &module, + &glwe_pt, + &sk_glwe_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + let mut ksk: GLWEToLWESwitchingKey, FFT64> = + GLWEToLWESwitchingKey::alloc(&module, basek, k_ksk, glwe_ct.size(), rank); + + ksk.encrypt_sk( + &module, + &sk_lwe, + &sk_glwe, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + let mut lwe_ct: LWECiphertext> = LWECiphertext::alloc(n_lwe, basek, k_lwe_ct); + lwe_ct.from_glwe(&module, &glwe_ct, &ksk, scratch.borrow()); + + let mut lwe_pt: LWEPlaintext> = LWEPlaintext::alloc(basek, k_lwe_ct); + lwe_ct.decrypt(&mut lwe_pt, &sk_lwe); + + assert_eq!(glwe_pt.data.at(0, 0)[0], lwe_pt.data.at(0, 0)[0]); +} + +#[test] +fn keyswitch() { + let n: usize = 1 << 5; + let module: Module = Module::::new(n); + let basek: usize = 17; + let sigma: f64 = 3.2; + + let n_lwe_in: usize = 22; + let n_lwe_out: usize = 30; + let k_lwe_ct: usize = 2 * basek; + let k_lwe_pt: usize = 8; + + let k_ksk: usize = k_lwe_ct + basek; + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + LWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk) + | LWECiphertext::keyswitch_scratch_space(&module, basek, k_lwe_ct, k_lwe_ct, k_ksk), + ); + + let mut sk_lwe_in: LWESecret> = LWESecret::alloc(n_lwe_in); + sk_lwe_in.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_lwe_out: LWESecret> = LWESecret::alloc(n_lwe_out); + sk_lwe_out.fill_ternary_prob(0.5, &mut source_xs); + + let data: i64 = 17; + + let mut lwe_pt_in: LWEPlaintext> = LWEPlaintext::alloc(basek, k_lwe_pt); + lwe_pt_in + .data + .encode_coeff_i64(0, basek, k_lwe_pt, 0, data, k_lwe_pt); + + let mut lwe_ct_in: LWECiphertext> = LWECiphertext::alloc(n_lwe_in, basek, k_lwe_ct); + lwe_ct_in.encrypt_sk( + &lwe_pt_in, + &sk_lwe_in, + &mut source_xa, + &mut source_xe, + sigma, + ); + + let mut ksk: LWESwitchingKey, FFT64> = LWESwitchingKey::alloc(&module, basek, k_ksk, lwe_ct_in.size()); + + ksk.encrypt_sk( + &module, + &sk_lwe_in, + &sk_lwe_out, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + let mut lwe_ct_out: LWECiphertext> = LWECiphertext::alloc(n_lwe_out, basek, k_lwe_ct); + + lwe_ct_out.keyswitch(&module, &lwe_ct_in, &ksk, scratch.borrow()); + + let mut lwe_pt_out: LWEPlaintext> = LWEPlaintext::alloc(basek, k_lwe_ct); + lwe_ct_out.decrypt(&mut lwe_pt_out, &sk_lwe_out); + + assert_eq!(lwe_pt_in.data.at(0, 0)[0], lwe_pt_out.data.at(0, 0)[0]); +} diff --git a/core/src/lwe/test_fft64/mod.rs b/core/src/lwe/test_fft64/mod.rs new file mode 100644 index 0000000..11eb2fc --- /dev/null +++ b/core/src/lwe/test_fft64/mod.rs @@ -0,0 +1 @@ +pub mod conversion; diff --git a/core/src/test_fft64/glwe_fourier.rs b/core/src/test_fft64/glwe_fourier.rs deleted file mode 100644 index fd54f57..0000000 --- a/core/src/test_fft64/glwe_fourier.rs +++ /dev/null @@ -1,478 +0,0 @@ -use crate::{ - FourierGLWECiphertext, GGSWCiphertext, GLWECiphertext, GLWEOps, GLWEPlaintext, GLWESecret, GLWESwitchingKey, Infos, div_ceil, - test_fft64::{log2_std_noise_gglwe_product, noise_ggsw_product}, -}; -use backend::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, ZnxViewMut}; -use sampling::source::Source; - -#[test] -fn keyswitch() { - let log_n: usize = 8; - let basek: usize = 12; - let k_in: usize = 45; - let digits: usize = k_in.div_ceil(basek); - (1..4).for_each(|rank_in| { - (1..4).for_each(|rank_out| { - (1..digits + 1).for_each(|di| { - let k_ksk: usize = k_in + basek * di; - println!( - "test keyswitch digits: {} rank_in: {} rank_out: {}", - di, rank_in, rank_out - ); - let k_out: usize = k_ksk; // Better capture noise. - test_keyswitch(log_n, basek, k_in, k_out, k_ksk, di, rank_in, rank_out, 3.2); - }) - }); - }); -} - -#[test] -fn keyswitch_inplace() { - let log_n: usize = 8; - let basek: usize = 12; - let k_ct: usize = 45; - let digits: usize = k_ct.div_ceil(basek); - (1..4).for_each(|rank| { - (1..digits + 1).for_each(|di| { - let k_ksk: usize = k_ct + basek * di; - println!("test keyswitch_inplace digits: {} rank: {}", di, rank); - test_keyswitch_inplace(log_n, basek, k_ct, k_ksk, di, rank, 3.2); - }); - }); -} - -#[test] -fn external_product() { - let log_n: usize = 8; - let basek: usize = 12; - let k_in: usize = 45; - let digits: usize = k_in.div_ceil(basek); - (1..4).for_each(|rank| { - (1..digits + 1).for_each(|di| { - let k_ggsw: usize = k_in + basek * di; - println!("test external_product digits: {} rank: {}", di, rank); - let k_out: usize = k_ggsw; // Better capture noise. - test_external_product(log_n, basek, k_out, k_in, k_ggsw, di, rank, 3.2); - }); - }); -} - -#[test] -fn external_product_inplace() { - let log_n: usize = 8; - let basek: usize = 12; - let k_ct: usize = 60; - let digits: usize = k_ct.div_ceil(basek); - (1..4).for_each(|rank| { - (1..digits + 1).for_each(|di| { - let k_ggsw: usize = k_ct + basek * di; - println!("test external_product digits: {} rank: {}", di, rank); - test_external_product_inplace(log_n, basek, k_ct, k_ggsw, di, rank, 3.2); - }); - }); -} - -fn test_keyswitch( - log_n: usize, - basek: usize, - k_in: usize, - k_out: usize, - k_ksk: usize, - digits: usize, - rank_in: usize, - rank_out: usize, - sigma: f64, -) { - let module: Module = Module::::new(1 << log_n); - - let rows: usize = k_in.div_ceil(basek * digits); - - let mut ksk: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank_in, rank_out); - let mut ct_glwe_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_in, rank_in); - let mut ct_glwe_dft_in: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_in, rank_in); - let mut ct_glwe_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_out, rank_out); - let mut ct_glwe_dft_out: FourierGLWECiphertext, FFT64> = - FourierGLWECiphertext::alloc(&module, basek, k_out, rank_out); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_in); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank_out) - | GLWECiphertext::decrypt_scratch_space(&module, basek, k_out) - | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, k_in) - | FourierGLWECiphertext::keyswitch_scratch_space( - &module, - basek, - ct_glwe_out.k(), - ksk.k(), - ct_glwe_in.k(), - digits, - rank_in, - rank_out, - ), - ); - - let mut sk_in: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_in); - sk_in.fill_ternary_prob(&module, 0.5, &mut source_xs); - - let mut sk_out: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_out); - sk_out.fill_ternary_prob(&module, 0.5, &mut source_xs); - - ksk.generate_from_sk( - &module, - &sk_in, - &sk_out, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_glwe_in.encrypt_sk( - &module, - &pt_want, - &sk_in, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_glwe_in.dft(&module, &mut ct_glwe_dft_in); - ct_glwe_dft_out.keyswitch(&module, &ct_glwe_dft_in, &ksk, scratch.borrow()); - ct_glwe_dft_out.idft(&module, &mut ct_glwe_out, scratch.borrow()); - - ct_glwe_out.decrypt(&module, &mut pt_have, &sk_out, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); - - let noise_have: f64 = pt_have.data.std(0, basek).log2(); - let noise_want: f64 = log2_std_noise_gglwe_product( - module.n() as f64, - basek * digits, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - rank_in as f64, - k_in, - k_ksk, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.5, - "{} {}", - noise_have, - noise_want - ); -} - -fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ct: usize, k_ksk: usize, digits: usize, rank: usize, sigma: f64) { - let module: Module = Module::::new(1 << log_n); - - let rows: usize = k_ct.div_ceil(basek * digits); - - let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank, rank); - let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); - let mut ct_rlwe_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_ct, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ksk.k(), rank) - | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_glwe.k()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_glwe.k()) - | FourierGLWECiphertext::keyswitch_inplace_scratch_space(&module, basek, ct_rlwe_dft.k(), ksk.k(), digits, rank), - ); - - let mut sk_in: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk_in.fill_ternary_prob(&module, 0.5, &mut source_xs); - - let mut sk_out: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk_out.fill_ternary_prob(&module, 0.5, &mut source_xs); - - ksk.generate_from_sk( - &module, - &sk_in, - &sk_out, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_glwe.encrypt_sk( - &module, - &pt_want, - &sk_in, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_glwe.dft(&module, &mut ct_rlwe_dft); - ct_rlwe_dft.keyswitch_inplace(&module, &ksk, scratch.borrow()); - ct_rlwe_dft.idft(&module, &mut ct_glwe, scratch.borrow()); - - ct_glwe.decrypt(&module, &mut pt_have, &sk_out, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); - - let noise_have: f64 = pt_have.data.std(0, basek).log2(); - let noise_want: f64 = log2_std_noise_gglwe_product( - module.n() as f64, - basek * digits, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - rank as f64, - k_ct, - k_ksk, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.5, - "{} {}", - noise_have, - noise_want - ); -} - -fn test_external_product( - log_n: usize, - basek: usize, - k_out: usize, - k_in: usize, - k_ggsw: usize, - digits: usize, - rank: usize, - sigma: f64, -) { - let module: Module = Module::::new(1 << log_n); - - let rows: usize = k_in.div_ceil(basek * digits); - - let mut ct_ggsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank); - let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_in, rank); - let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_out, rank); - let mut ct_in_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_in, rank); - let mut ct_out_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_out, rank); - let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_in); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); - - pt_want.data.at_mut(0, 0)[1] = 1; - - let k: i64 = 1; - - pt_rgsw.raw_mut()[0] = 1; // X^{0} - module.vec_znx_rotate_inplace(k, &mut pt_rgsw, 0); // X^{k} - - let mut scratch: ScratchOwned = ScratchOwned::new( - GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, ct_ggsw.k(), rank) - | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_out.k()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_in.k()) - | FourierGLWECiphertext::external_product_scratch_space( - &module, - basek, - ct_out.k(), - ct_in.k(), - ct_ggsw.k(), - digits, - rank, - ), - ); - - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, 0.5, &mut source_xs); - - ct_ggsw.encrypt_sk( - &module, - &pt_rgsw, - &sk, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_in.encrypt_sk( - &module, - &pt_want, - &sk, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_in.dft(&module, &mut ct_in_dft); - ct_out_dft.external_product(&module, &ct_in_dft, &ct_ggsw, scratch.borrow()); - ct_out_dft.idft(&module, &mut ct_out, scratch.borrow()); - - ct_out.decrypt(&module, &mut pt_have, &sk, scratch.borrow()); - - pt_want.rotate_inplace(&module, k); - pt_have.sub_inplace_ab(&module, &pt_want); - - let noise_have: f64 = pt_have.data.std(0, basek).log2(); - - let var_gct_err_lhs: f64 = sigma * sigma; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} - let var_a0_err: f64 = sigma * sigma; - let var_a1_err: f64 = 1f64 / 12f64; - - let noise_want: f64 = noise_ggsw_product( - module.n() as f64, - basek * digits, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - rank as f64, - k_in, - k_ggsw, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.5, - "{} {}", - noise_have, - noise_want - ); -} - -fn test_external_product_inplace(log_n: usize, basek: usize, k_ct: usize, k_ggsw: usize, digits: usize, rank: usize, sigma: f64) { - let module: Module = Module::::new(1 << log_n); - let rows: usize = k_ct.div_ceil(basek * digits); - - let mut ct_ggsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank); - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); - let mut ct_rlwe_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_ct, rank); - let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); - - pt_want.data.at_mut(0, 0)[1] = 1; - - let k: i64 = 1; - - pt_rgsw.raw_mut()[0] = 1; // X^{0} - module.vec_znx_rotate_inplace(k, &mut pt_rgsw, 0); // X^{k} - - let mut scratch: ScratchOwned = ScratchOwned::new( - GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, ct_ggsw.k(), rank) - | GLWECiphertext::decrypt_scratch_space(&module, basek, ct.k()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k()) - | FourierGLWECiphertext::external_product_inplace_scratch_space(&module, basek, ct.k(), ct_ggsw.k(), digits, rank), - ); - - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, 0.5, &mut source_xs); - - ct_ggsw.encrypt_sk( - &module, - &pt_rgsw, - &sk, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct.encrypt_sk( - &module, - &pt_want, - &sk, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct.dft(&module, &mut ct_rlwe_dft); - ct_rlwe_dft.external_product_inplace(&module, &ct_ggsw, scratch.borrow()); - ct_rlwe_dft.idft(&module, &mut ct, scratch.borrow()); - - ct.decrypt(&module, &mut pt_have, &sk, scratch.borrow()); - - pt_want.rotate_inplace(&module, k); - pt_have.sub_inplace_ab(&module, &pt_want); - - let noise_have: f64 = pt_have.data.std(0, basek).log2(); - - let var_gct_err_lhs: f64 = sigma * sigma; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} - let var_a0_err: f64 = sigma * sigma; - let var_a1_err: f64 = 1f64 / 12f64; - - let noise_want: f64 = noise_ggsw_product( - module.n() as f64, - basek * digits, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - rank as f64, - k_ct, - k_ggsw, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.5, - "{} {}", - noise_have, - noise_want - ); - - println!("{} {}", noise_have, noise_want); -} From 992cb3fa37825fbb2ee8441c2338e6133305d566 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 8 Jul 2025 13:23:38 +0200 Subject: [PATCH 17/23] Added missing tests for CGGI & added standard blind rotation --- core/src/blind_rotation/ccgi.rs | 97 +++++++++++++++++++++- core/src/blind_rotation/key.rs | 5 ++ core/src/blind_rotation/test_fft64/cggi.rs | 64 +++++++------- 3 files changed, 130 insertions(+), 36 deletions(-) diff --git a/core/src/blind_rotation/ccgi.rs b/core/src/blind_rotation/ccgi.rs index 684ef72..2b5e877 100644 --- a/core/src/blind_rotation/ccgi.rs +++ b/core/src/blind_rotation/ccgi.rs @@ -6,7 +6,7 @@ use backend::{ use itertools::izip; use crate::{ - GLWECiphertext, GLWECiphertextToMut, Infos, LWECiphertext, + GLWECiphertext, GLWECiphertextToMut, GLWEOps, Infos, LWECiphertext, ScratchCore, blind_rotation::{key::BlindRotationKeyCGGI, lut::LookUpTable}, lwe::ciphertext::LWECiphertextToRef, }; @@ -63,7 +63,7 @@ pub fn cggi_blind_rotate( } else if brk.block_size() > 1 { cggi_blind_rotate_block_binary(module, res, lwe, lut, brk, scratch); } else { - todo!("implement this case") + cggi_blind_rotate_standard(module, res, lwe, lut, brk, scratch); } } @@ -121,8 +121,7 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended( a.chunks_exact(block_size), brk.data.chunks_exact(block_size) ) - .enumerate() - .for_each(|(i, (ai, ski))| { + .for_each(|(ai, ski)| { (0..extension_factor).for_each(|i| { (0..cols).for_each(|j| { module.vec_znx_dft(1, 0, &mut acc_dft[i], j, &acc[i], j); @@ -323,6 +322,96 @@ pub(crate) fn cggi_blind_rotate_block_binary( }); } +pub(crate) fn cggi_blind_rotate_standard( + module: &Module, + res: &mut GLWECiphertext, + lwe: &LWECiphertext, + lut: &LookUpTable, + brk: &BlindRotationKeyCGGI, + scratch: &mut Scratch, +) where + DataRes: AsRef<[u8]> + AsMut<[u8]>, + DataIn: AsRef<[u8]>, +{ + #[cfg(debug_assertions)] + { + assert_eq!( + res.n(), + module.n(), + "res.n(): {} != brk.n(): {}", + res.n(), + module.n() + ); + assert_eq!( + lut.domain_size(), + module.n(), + "lut.n(): {} != brk.n(): {}", + lut.domain_size(), + module.n() + ); + assert_eq!( + brk.n(), + module.n(), + "brk.n(): {} != brk.n(): {}", + brk.n(), + module.n() + ); + assert_eq!( + res.rank(), + brk.rank(), + "res.rank(): {} != brk.rank(): {}", + res.rank(), + brk.rank() + ); + assert_eq!( + lwe.n(), + brk.data.len(), + "lwe.n(): {} != brk.data.len(): {}", + lwe.n(), + brk.data.len() + ); + } + + let mut lwe_2n: Vec = vec![0i64; lwe.n() + 1]; // TODO: from scratch space + let mut out_mut: GLWECiphertext<&mut [u8]> = res.to_mut(); + let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref(); + let basek: usize = brk.basek(); + + negate_and_mod_switch_2n(2 * lut.domain_size(), &mut lwe_2n, &lwe_ref); + + let a: &[i64] = &lwe_2n[1..]; + let b: i64 = lwe_2n[0]; + + out_mut.data.zero(); + + // Initialize out to X^{b} * LUT(X) + module.vec_znx_rotate(b, &mut out_mut.data, 0, &lut.data[0], 0); + + // ACC + [sum DFT(X^ai -1) * (DFT(ACC) x BRKi)] + let (mut acc_tmp, scratch1) = scratch.tmp_glwe_ct(module, basek, out_mut.k(), out_mut.rank()); + let (mut acc_tmp_rot, scratch2) = scratch1.tmp_glwe_ct(module, basek, out_mut.k(), out_mut.rank()); + + // TODO: see if faster by skipping normalization in external product and keeping acc in big coeffs + // TODO: first iteration can be optimized to be a gglwe product + izip!(a.iter(), brk.data.iter()).for_each(|(ai, ski)| { + // acc_tmp = sk[i] * acc + acc_tmp.external_product(module, &out_mut, ski, scratch2); + + // acc_tmp = (sk[i] * acc) * X^{ai} + acc_tmp_rot.rotate(module, *ai, &acc_tmp); + + // acc = acc + (sk[i] * acc) * X^{ai} + out_mut.add_inplace(module, &acc_tmp_rot); + + // acc = acc + (sk[i] * acc) * X^{ai} - (sk[i] * acc) = acc + (sk[i] * acc) * (X^{ai} - 1) + out_mut.sub_inplace_ab(module, &acc_tmp); + }); + + // We can normalize only at the end because we add normalized values in [-2^{basek-1}, 2^{basek-1}] + // on top of each others, thus ~ 2^{63-basek} additions are supported before overflow. + out_mut.normalize_inplace(module, scratch2); +} + pub(crate) fn negate_and_mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWECiphertext<&[u8]>) { let basek: usize = lwe.basek(); diff --git a/core/src/blind_rotation/key.rs b/core/src/blind_rotation/key.rs index b7f9c3f..b83d60c 100644 --- a/core/src/blind_rotation/key.rs +++ b/core/src/blind_rotation/key.rs @@ -74,6 +74,11 @@ impl BlindRotationKeyCGGI { } } + #[allow(dead_code)] + pub(crate) fn n(&self) -> usize { + self.data[0].n() + } + #[allow(dead_code)] pub(crate) fn rows(&self) -> usize { self.data[0].rows() diff --git a/core/src/blind_rotation/test_fft64/cggi.rs b/core/src/blind_rotation/test_fft64/cggi.rs index 4a5c319..785246e 100644 --- a/core/src/blind_rotation/test_fft64/cggi.rs +++ b/core/src/blind_rotation/test_fft64/cggi.rs @@ -1,6 +1,4 @@ -use std::time::Instant; - -use backend::{Encoding, FFT64, Module, ScratchOwned, Stats, VecZnxOps, ZnxView}; +use backend::{Encoding, FFT64, Module, ScratchOwned, ZnxView}; use sampling::source::Source; use crate::{ @@ -14,22 +12,31 @@ use crate::{ }; #[test] -fn blind_rotation() { - let module: Module = Module::::new(2048); - let basek: usize = 19; +fn standard() { + blind_rotatio_test(224, 1, 1); +} - let n_lwe: usize = 1071; +#[test] +fn block_binary() { + blind_rotatio_test(224, 7, 1); +} + +#[test] +fn block_binary_extended() { + blind_rotatio_test(224, 7, 2); +} + +fn blind_rotatio_test(n_lwe: usize, block_size: usize, extension_factor: usize) { + let module: Module = Module::::new(512); + let basek: usize = 19; let k_lwe: usize = 24; let k_brk: usize = 3 * basek; - let rows_brk: usize = 1; + let rows_brk: usize = 2; // Ensures first limb is noise-free. let k_lut: usize = 2 * basek; let rank: usize = 1; - let block_size: usize = 7; - let extension_factor: usize = 2; - - let message_modulus: usize = 1 << 6; + let message_modulus: usize = 1 << 4; let mut source_xs: Source = Source::new([1u8; 32]); let mut source_xe: Source = Source::new([1u8; 32]); @@ -56,7 +63,6 @@ fn blind_rotation() { rank, )); - let start: Instant = Instant::now(); let mut brk: BlindRotationKeyCGGI = BlindRotationKeyCGGI::allocate(&module, n_lwe, basek, k_brk, rows_brk, rank); brk.generate_from_sk( @@ -69,9 +75,6 @@ fn blind_rotation() { scratch.borrow(), ); - let duration: std::time::Duration = start.elapsed(); - println!("brk-gen: {} ms", duration.as_millis()); - let mut lwe: LWECiphertext> = LWECiphertext::alloc(n_lwe, basek, k_lwe); let mut pt_lwe: LWEPlaintext> = LWEPlaintext::alloc(basek, k_lwe); @@ -81,13 +84,13 @@ fn blind_rotation() { pt_lwe.data.encode_coeff_i64(0, basek, bits, 0, x, bits); - println!("{}", pt_lwe.data); + // println!("{}", pt_lwe.data); lwe.encrypt_sk(&pt_lwe, &sk_lwe, &mut source_xa, &mut source_xe, 3.2); lwe.decrypt(&mut pt_lwe, &sk_lwe); - println!("{}", pt_lwe.data); + // println!("{}", pt_lwe.data); let mut f: Vec = vec![0i64; message_modulus]; f.iter_mut() @@ -99,13 +102,9 @@ fn blind_rotation() { let mut res: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_lut, rank); - let start: Instant = Instant::now(); - (0..32).for_each(|_| { - cggi_blind_rotate(&module, &mut res, &lwe, &lut, &brk, scratch_br.borrow()); - }); + cggi_blind_rotate(&module, &mut res, &lwe, &lut, &brk, scratch_br.borrow()); - let duration: std::time::Duration = start.elapsed(); - println!("blind-rotate: {} ms", duration.as_millis()); + println!("out_mut.data: {}", res.data); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_lut); @@ -125,20 +124,21 @@ fn blind_rotation() { .sum::()) % (2 * lut.domain_size()) as i64; - println!("pt_want: {}", pt_want); + // println!("pt_want: {}", pt_want); lut.rotate(pt_want); - lut.data.iter().for_each(|d| { - println!("{}", d); - }); + // lut.data.iter().for_each(|d| { + // println!("{}", d); + // }); // First limb should be exactly equal (test are parameterized such that the noise does not reach // the first limb) - // assert_eq!(pt_have.data.at_mut(0, 0), lut.data[0].at_mut(0, 0)); + assert_eq!(pt_have.data.at(0, 0), lut.data[0].at(0, 0)); // Then checks the noise - module.vec_znx_sub_ab_inplace(&mut lut.data[0], 0, &pt_have.data, 0); - let noise: f64 = lut.data[0].std(0, basek); - println!("noise: {}", noise); + // module.vec_znx_sub_ab_inplace(&mut lut.data[0], 0, &pt_have.data, 0); + // let noise: f64 = lut.data[0].std(0, basek); + // println!("noise: {}", noise); + // assert!(noise < 1e-3); } From f7c94cd84acc80c62cd58916c9678c3d4cbe7077 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 8 Jul 2025 13:37:35 +0200 Subject: [PATCH 18/23] fixed standard binary cggi blind rotation & fixed GLWECiphertext::external_product_scratch_space returning too small values --- core/src/blind_rotation/ccgi.rs | 71 +++++++++++++--------- core/src/blind_rotation/test_fft64/cggi.rs | 10 +-- core/src/glwe/external_product.rs | 2 +- 3 files changed, 50 insertions(+), 33 deletions(-) diff --git a/core/src/blind_rotation/ccgi.rs b/core/src/blind_rotation/ccgi.rs index 2b5e877..0d4f6fd 100644 --- a/core/src/blind_rotation/ccgi.rs +++ b/core/src/blind_rotation/ccgi.rs @@ -8,43 +8,50 @@ use itertools::izip; use crate::{ GLWECiphertext, GLWECiphertextToMut, GLWEOps, Infos, LWECiphertext, ScratchCore, blind_rotation::{key::BlindRotationKeyCGGI, lut::LookUpTable}, + dist::Distribution, lwe::ciphertext::LWECiphertextToRef, }; pub fn cggi_blind_rotate_scratch_space( module: &Module, + block_size: usize, extension_factor: usize, basek: usize, - k_lut: usize, + k_res: usize, k_brk: usize, rows: usize, rank: usize, ) -> usize { - let cols: usize = rank + 1; let brk_size: usize = k_brk.div_ceil(basek); - let acc_dft: usize = module.bytes_of_vec_znx_dft(cols, rows) * extension_factor; - let acc_big: usize = module.bytes_of_vec_znx_big(1, brk_size); - let vmp_res: usize = module.bytes_of_vec_znx_dft(cols, brk_size) * extension_factor; - let acc_dft_add: usize = vmp_res; - let xai_plus_y: usize = module.bytes_of_scalar_znx(1); - let xai_plus_y_dft: usize = module.bytes_of_scalar_znx_dft(1); - let vmp: usize = module.vmp_apply_tmp_bytes(brk_size, rows, rows, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2) + if block_size > 1 { + let cols: usize = rank + 1; + let acc_dft: usize = module.bytes_of_vec_znx_dft(cols, rows) * extension_factor; + let acc_big: usize = module.bytes_of_vec_znx_big(1, brk_size); + let vmp_res: usize = module.bytes_of_vec_znx_dft(cols, brk_size) * extension_factor; + let acc_dft_add: usize = vmp_res; + let xai_plus_y: usize = module.bytes_of_scalar_znx(1); + let xai_plus_y_dft: usize = module.bytes_of_scalar_znx_dft(1); + let vmp: usize = module.vmp_apply_tmp_bytes(brk_size, rows, rows, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2) - let acc: usize; - if extension_factor > 1 { - acc = module.bytes_of_vec_znx(cols, k_lut.div_ceil(basek)) * extension_factor; + let acc: usize; + if extension_factor > 1 { + acc = module.bytes_of_vec_znx(cols, k_res.div_ceil(basek)) * extension_factor; + } else { + acc = 0; + } + + return acc + + acc_dft + + acc_dft_add + + vmp_res + + xai_plus_y + + xai_plus_y_dft + + (vmp | (acc_big + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes()))); } else { - acc = 0; + 2 * GLWECiphertext::bytes_of(module, basek, k_res, rank) + + GLWECiphertext::external_product_scratch_space(module, basek, k_res, k_res, k_brk, 1, rank) } - - return acc - + acc_dft - + acc_dft_add - + vmp_res - + xai_plus_y - + xai_plus_y_dft - + (vmp | (acc_big + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes()))); } pub fn cggi_blind_rotate( @@ -58,12 +65,20 @@ pub fn cggi_blind_rotate( DataRes: AsRef<[u8]> + AsMut<[u8]>, DataIn: AsRef<[u8]>, { - if lut.extension_factor() > 1 { - cggi_blind_rotate_block_binary_extended(module, res, lwe, lut, brk, scratch); - } else if brk.block_size() > 1 { - cggi_blind_rotate_block_binary(module, res, lwe, lut, brk, scratch); - } else { - cggi_blind_rotate_standard(module, res, lwe, lut, brk, scratch); + match brk.dist { + Distribution::BinaryBlock(_) | Distribution::BinaryFixed(_) | Distribution::BinaryProb(_) | Distribution::ZERO => { + if lut.extension_factor() > 1 { + cggi_blind_rotate_block_binary_extended(module, res, lwe, lut, brk, scratch); + } else if brk.block_size() > 1 { + cggi_blind_rotate_block_binary(module, res, lwe, lut, brk, scratch); + } else { + cggi_blind_rotate_binary_standard(module, res, lwe, lut, brk, scratch); + } + } + // TODO: ternary distribution ? + _ => panic!( + "invalid BlindRotationKeyCGGI distribution: must be BinaryBlock, BinaryFixed or BinaryProb (or ZERO for debugging)" + ), } } @@ -322,7 +337,7 @@ pub(crate) fn cggi_blind_rotate_block_binary( }); } -pub(crate) fn cggi_blind_rotate_standard( +pub(crate) fn cggi_blind_rotate_binary_standard( module: &Module, res: &mut GLWECiphertext, lwe: &LWECiphertext, diff --git a/core/src/blind_rotation/test_fft64/cggi.rs b/core/src/blind_rotation/test_fft64/cggi.rs index 785246e..ea8291a 100644 --- a/core/src/blind_rotation/test_fft64/cggi.rs +++ b/core/src/blind_rotation/test_fft64/cggi.rs @@ -33,7 +33,8 @@ fn blind_rotatio_test(n_lwe: usize, block_size: usize, extension_factor: usize) let k_lwe: usize = 24; let k_brk: usize = 3 * basek; let rows_brk: usize = 2; // Ensures first limb is noise-free. - let k_lut: usize = 2 * basek; + let k_lut: usize = 1 * basek; + let k_res: usize = 2 * basek; let rank: usize = 1; let message_modulus: usize = 1 << 4; @@ -55,9 +56,10 @@ fn blind_rotatio_test(n_lwe: usize, block_size: usize, extension_factor: usize) let mut scratch_br: ScratchOwned = ScratchOwned::new(cggi_blind_rotate_scratch_space( &module, + block_size, extension_factor, basek, - k_lut, + k_res, k_brk, rows_brk, rank, @@ -100,13 +102,13 @@ fn blind_rotatio_test(n_lwe: usize, block_size: usize, extension_factor: usize) let mut lut: LookUpTable = LookUpTable::alloc(&module, basek, k_lut, extension_factor); lut.set(&module, &f, message_modulus); - let mut res: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_lut, rank); + let mut res: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_res, rank); cggi_blind_rotate(&module, &mut res, &lwe, &lut, &brk, scratch_br.borrow()); println!("out_mut.data: {}", res.data); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_lut); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_res); res.decrypt(&module, &mut pt_have, &sk_glwe_dft, scratch.borrow()); diff --git a/core/src/glwe/external_product.rs b/core/src/glwe/external_product.rs index 3ebf339..e7ee778 100644 --- a/core/src/glwe/external_product.rs +++ b/core/src/glwe/external_product.rs @@ -14,7 +14,7 @@ impl GLWECiphertext> { digits: usize, rank: usize, ) -> usize { - let res_dft: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank); + let res_dft: usize = FourierGLWECiphertext::bytes_of(module, basek, k_ggsw, rank); let in_size: usize = k_in.div_ceil(basek).div_ceil(digits); let out_size: usize = k_out.div_ceil(basek); let ggsw_size: usize = k_ggsw.div_ceil(basek); From af5bbbb55db93cd815afe232bb7d588069d80edd Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 8 Jul 2025 13:48:51 +0200 Subject: [PATCH 19/23] fixed modulus switching rounding --- core/src/blind_rotation/ccgi.rs | 20 ++------------------ 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/core/src/blind_rotation/ccgi.rs b/core/src/blind_rotation/ccgi.rs index 0d4f6fd..62ec4ac 100644 --- a/core/src/blind_rotation/ccgi.rs +++ b/core/src/blind_rotation/ccgi.rs @@ -438,7 +438,7 @@ pub(crate) fn negate_and_mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWECiphe if basek > log2n { let diff: usize = basek - log2n; res.iter_mut().for_each(|x| { - *x = div_ceil_signed_by_pow2(x, diff); + *x = div_round_by_pow2(x, diff); }) } else { let rem: usize = basek - (log2n % basek); @@ -460,21 +460,5 @@ pub(crate) fn negate_and_mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWECiphe #[inline(always)] fn div_round_by_pow2(x: &i64, k: usize) -> i64 { - if x >= &0 { - (x + (1 << (k - 1))) >> k - } else { - (x + (-1 << (k - 1))) >> k - } -} - -// #[inline(always)] -// fn div_floor_signed_by_pow2(x: &i64, k: usize) -> i64{ -// let bias: i64 = (1 << k) - 1; -// (x + ((x >> 63) & bias)) >> k -// } - -#[inline(always)] -fn div_ceil_signed_by_pow2(x: &i64, k: usize) -> i64 { - let bias: i64 = (1 << k) - 1; - (x + ((x >> 63) & bias)) >> k + (x + (1 << (k - 1))) >> k } From 2e0e7e11b43bd488afe7833e070ab1812a1b831e Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 8 Jul 2025 13:54:20 +0200 Subject: [PATCH 20/23] Enforce extension factor to be a power of two --- core/src/blind_rotation/ccgi.rs | 1 - core/src/blind_rotation/lut.rs | 8 ++++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/core/src/blind_rotation/ccgi.rs b/core/src/blind_rotation/ccgi.rs index 62ec4ac..3e0dc33 100644 --- a/core/src/blind_rotation/ccgi.rs +++ b/core/src/blind_rotation/ccgi.rs @@ -82,7 +82,6 @@ pub fn cggi_blind_rotate( } } -// TODO: ENSURE DOMAIN EXTENSION AS pub(crate) fn cggi_blind_rotate_block_binary_extended( module: &Module, res: &mut GLWECiphertext, diff --git a/core/src/blind_rotation/lut.rs b/core/src/blind_rotation/lut.rs index a7fe003..96c1422 100644 --- a/core/src/blind_rotation/lut.rs +++ b/core/src/blind_rotation/lut.rs @@ -8,6 +8,14 @@ pub struct LookUpTable { impl LookUpTable { pub fn alloc(module: &Module, basek: usize, k: usize, extension_factor: usize) -> Self { + #[cfg(debug_assertions)] + { + assert!( + extension_factor & (extension_factor - 1) == 0, + "extension_factor must be a power of two but is: {}", + extension_factor + ); + } let size: usize = k.div_ceil(basek); let mut data: Vec>> = Vec::with_capacity(extension_factor); (0..extension_factor).for_each(|_| { From 0e65df979595d03fc3f1d364d5bafc7c94987220 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 8 Jul 2025 17:00:42 +0200 Subject: [PATCH 21/23] export blind rotation --- core/src/blind_rotation/{ccgi.rs => cggi.rs} | 0 core/src/blind_rotation/mod.rs | 7 +++++-- core/src/blind_rotation/test_fft64/cggi.rs | 2 +- core/src/lib.rs | 4 +++- 4 files changed, 9 insertions(+), 4 deletions(-) rename core/src/blind_rotation/{ccgi.rs => cggi.rs} (100%) diff --git a/core/src/blind_rotation/ccgi.rs b/core/src/blind_rotation/cggi.rs similarity index 100% rename from core/src/blind_rotation/ccgi.rs rename to core/src/blind_rotation/cggi.rs diff --git a/core/src/blind_rotation/mod.rs b/core/src/blind_rotation/mod.rs index 1b3d5ff..bbbdd2c 100644 --- a/core/src/blind_rotation/mod.rs +++ b/core/src/blind_rotation/mod.rs @@ -1,7 +1,10 @@ -// pub mod cggi; -pub mod ccgi; +pub mod cggi; pub mod key; pub mod lut; +pub use cggi::{cggi_blind_rotate, cggi_blind_rotate_scratch_space}; +pub use key::BlindRotationKeyCGGI; +pub use lut::LookUpTable; + #[cfg(test)] pub mod test_fft64; diff --git a/core/src/blind_rotation/test_fft64/cggi.rs b/core/src/blind_rotation/test_fft64/cggi.rs index ea8291a..e544494 100644 --- a/core/src/blind_rotation/test_fft64/cggi.rs +++ b/core/src/blind_rotation/test_fft64/cggi.rs @@ -4,7 +4,7 @@ use sampling::source::Source; use crate::{ FourierGLWESecret, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, LWECiphertext, LWESecret, blind_rotation::{ - ccgi::{cggi_blind_rotate, cggi_blind_rotate_scratch_space, negate_and_mod_switch_2n}, + cggi::{cggi_blind_rotate, cggi_blind_rotate_scratch_space, negate_and_mod_switch_2n}, key::BlindRotationKeyCGGI, lut::LookUpTable, }, diff --git a/core/src/lib.rs b/core/src/lib.rs index f94963c..9f8c114 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -11,14 +11,16 @@ pub mod noise; use backend::Backend; use backend::FFT64; use backend::Module; +pub use blind_rotation::{BlindRotationKeyCGGI, LookUpTable, cggi_blind_rotate, cggi_blind_rotate_scratch_space}; pub use elem::{GetRow, Infos, SetMetaData, SetRow}; pub use fourier_glwe::{FourierGLWECiphertext, FourierGLWESecret}; pub use gglwe::{GGLWECiphertext, GLWEAutomorphismKey, GLWESwitchingKey, GLWETensorKey}; pub use ggsw::GGSWCiphertext; pub use glwe::{GLWECiphertext, GLWEOps, GLWEPacker, GLWEPlaintext, GLWEPublicKey, GLWESecret}; -pub(crate) use glwe::{GLWECiphertextToMut, GLWECiphertextToRef}; pub use lwe::{LWECiphertext, LWESecret}; +pub(crate) use glwe::{GLWECiphertextToMut, GLWECiphertextToRef}; + pub use backend::Scratch; pub use backend::ScratchOwned; From 38df06f7abb4e1d028e5b1c6d76f1f0697e8a2f0 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 8 Jul 2025 18:50:04 +0200 Subject: [PATCH 22/23] Fixed lut generation --- core/src/blind_rotation/lut.rs | 7 ++++++- core/src/blind_rotation/test_fft64/lut.rs | 12 ++++++------ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/core/src/blind_rotation/lut.rs b/core/src/blind_rotation/lut.rs index 96c1422..7446e9a 100644 --- a/core/src/blind_rotation/lut.rs +++ b/core/src/blind_rotation/lut.rs @@ -40,8 +40,13 @@ impl LookUpTable { // Get the number minimum limb to store the message modulus let limbs: usize = k.div_ceil(1 << basek); + #[cfg(debug_assertions)] + { + assert!(limbs <= self.data[0].size()); + } + // Scaling factor - let scale: i64 = (1 << (basek * limbs - 1)).div_round(k) as i64; + let scale: i64 = 1 << (k % basek) as i64; // #elements in lookup table let f_len: usize = f.len(); diff --git a/core/src/blind_rotation/test_fft64/lut.rs b/core/src/blind_rotation/test_fft64/lut.rs index 3738b62..02f710d 100644 --- a/core/src/blind_rotation/test_fft64/lut.rs +++ b/core/src/blind_rotation/test_fft64/lut.rs @@ -12,7 +12,7 @@ fn standard() { let message_modulus: usize = 16; let extension_factor: usize = 1; - let scale: usize = (1 << (basek - 1)) / message_modulus; + let log_scale: usize = basek + 1; let mut f: Vec = vec![0i64; message_modulus]; f.iter_mut() @@ -20,7 +20,7 @@ fn standard() { .for_each(|(i, x)| *x = (i as i64) - 8); let mut lut: LookUpTable = LookUpTable::alloc(&module, basek, k_lut, extension_factor); - lut.set(&module, &f, message_modulus); + lut.set(&module, &f, log_scale); let half_step: i64 = lut.domain_size().div_round(message_modulus << 1) as i64; lut.rotate(half_step); @@ -31,7 +31,7 @@ fn standard() { (0..step).for_each(|_| { assert_eq!( f[i / step] % message_modulus as i64, - lut.data[0].raw()[0] / scale as i64 + lut.data[0].raw()[0] / (1 << (log_scale % basek)) as i64 ); lut.rotate(-1); }); @@ -46,7 +46,7 @@ fn extended() { let message_modulus: usize = 16; let extension_factor: usize = 4; - let scale: usize = (1 << (basek - 1)) / message_modulus; + let log_scale: usize = basek + 1; let mut f: Vec = vec![0i64; message_modulus]; f.iter_mut() @@ -54,7 +54,7 @@ fn extended() { .for_each(|(i, x)| *x = (i as i64) - 8); let mut lut: LookUpTable = LookUpTable::alloc(&module, basek, k_lut, extension_factor); - lut.set(&module, &f, message_modulus); + lut.set(&module, &f, log_scale); let half_step: i64 = lut.domain_size().div_round(message_modulus << 1) as i64; lut.rotate(half_step); @@ -65,7 +65,7 @@ fn extended() { (0..step).for_each(|_| { assert_eq!( f[i / step] % message_modulus as i64, - lut.data[0].raw()[0] / scale as i64 + lut.data[0].raw()[0] / (1 << (log_scale % basek)) as i64 ); lut.rotate(-1); }); From 52a6a130a595cd47023784d071e32d2e0105caa0 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Fri, 11 Jul 2025 12:29:49 +0200 Subject: [PATCH 23/23] Fixes after meeting --- backend/src/lib.rs | 8 +- backend/src/vec_znx.rs | 8 +- core/src/blind_rotation/cggi.rs | 168 ++++++++++----------- core/src/blind_rotation/key.rs | 120 +++++++++++---- core/src/blind_rotation/lut.rs | 4 + core/src/blind_rotation/test_fft64/cggi.rs | 31 +--- 6 files changed, 188 insertions(+), 151 deletions(-) diff --git a/backend/src/lib.rs b/backend/src/lib.rs index e7c8e5e..8ac50ce 100644 --- a/backend/src/lib.rs +++ b/backend/src/lib.rs @@ -240,11 +240,11 @@ impl Scratch { ) -> (Vec>, &mut Self) { let mut scratch: &mut Scratch = self; let mut slice: Vec> = Vec::with_capacity(slice_size); - for _ in 0..slice_size{ + for _ in 0..slice_size { let (znx, new_scratch) = scratch.tmp_vec_znx_dft(module, cols, size); scratch = new_scratch; slice.push(znx); - }; + } (slice, scratch) } @@ -279,11 +279,11 @@ impl Scratch { ) -> (Vec>, &mut Self) { let mut scratch: &mut Scratch = self; let mut slice: Vec> = Vec::with_capacity(slice_size); - for _ in 0..slice_size{ + for _ in 0..slice_size { let (znx, new_scratch) = scratch.tmp_vec_znx(module, cols, size); scratch = new_scratch; slice.push(znx); - }; + } (slice, scratch) } diff --git a/backend/src/vec_znx.rs b/backend/src/vec_znx.rs index 00568dd..74f9f86 100644 --- a/backend/src/vec_znx.rs +++ b/backend/src/vec_znx.rs @@ -111,10 +111,10 @@ impl + AsRef<[u8]>> VecZnx { } } - pub fn rotate(&mut self, k: i64){ - unsafe{ - (0..self.cols()).for_each(|i|{ - (0..self.size()).for_each(|j|{ + pub fn rotate(&mut self, k: i64) { + unsafe { + (0..self.cols()).for_each(|i| { + (0..self.size()).for_each(|j| { znx::znx_rotate_inplace_i64(self.n() as u64, k, self.at_mut_ptr(i, j)); }); }) diff --git a/core/src/blind_rotation/cggi.rs b/core/src/blind_rotation/cggi.rs index 3e0dc33..2cf1d88 100644 --- a/core/src/blind_rotation/cggi.rs +++ b/core/src/blind_rotation/cggi.rs @@ -1,5 +1,5 @@ use backend::{ - FFT64, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, + FFT64, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, Scratch, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxView, ZnxViewMut, ZnxZero, }; @@ -30,7 +30,7 @@ pub fn cggi_blind_rotate_scratch_space( let acc_big: usize = module.bytes_of_vec_znx_big(1, brk_size); let vmp_res: usize = module.bytes_of_vec_znx_dft(cols, brk_size) * extension_factor; let acc_dft_add: usize = vmp_res; - let xai_plus_y: usize = module.bytes_of_scalar_znx(1); + let xai_plus_y: usize = module.bytes_of_scalar_znx_dft(1); let xai_plus_y_dft: usize = module.bytes_of_scalar_znx_dft(1); let vmp: usize = module.vmp_apply_tmp_bytes(brk_size, rows, rows, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2) @@ -54,16 +54,17 @@ pub fn cggi_blind_rotate_scratch_space( } } -pub fn cggi_blind_rotate( +pub fn cggi_blind_rotate( module: &Module, res: &mut GLWECiphertext, lwe: &LWECiphertext, lut: &LookUpTable, - brk: &BlindRotationKeyCGGI, + brk: &BlindRotationKeyCGGI, scratch: &mut Scratch, ) where DataRes: AsRef<[u8]> + AsMut<[u8]>, DataIn: AsRef<[u8]>, + DataBrk: AsRef<[u8]>, { match brk.dist { Distribution::BinaryBlock(_) | Distribution::BinaryFixed(_) | Distribution::BinaryProb(_) | Distribution::ZERO => { @@ -82,16 +83,17 @@ pub fn cggi_blind_rotate( } } -pub(crate) fn cggi_blind_rotate_block_binary_extended( +pub(crate) fn cggi_blind_rotate_block_binary_extended( module: &Module, res: &mut GLWECiphertext, lwe: &LWECiphertext, lut: &LookUpTable, - brk: &BlindRotationKeyCGGI, + brk: &BlindRotationKeyCGGI, scratch: &mut Scratch, ) where DataRes: AsRef<[u8]> + AsMut<[u8]>, DataIn: AsRef<[u8]>, + DataBrk: AsRef<[u8]>, { let extension_factor: usize = lut.extension_factor(); let basek: usize = res.basek(); @@ -102,25 +104,35 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended( let (mut acc_dft, scratch2) = scratch1.tmp_slice_vec_znx_dft(extension_factor, module, cols, rows); let (mut vmp_res, scratch3) = scratch2.tmp_slice_vec_znx_dft(extension_factor, module, cols, brk.size()); let (mut acc_add_dft, scratch4) = scratch3.tmp_slice_vec_znx_dft(extension_factor, module, cols, brk.size()); - let (mut xai_plus_y, scratch5) = scratch4.tmp_scalar_znx(module, 1); + let (mut minus_one, scratch5) = scratch4.tmp_scalar_znx_dft(module, 1); let (mut xai_plus_y_dft, scratch6) = scratch5.tmp_scalar_znx_dft(module, 1); + minus_one.raw_mut()[..module.n() >> 1].fill(-1.0); + (0..extension_factor).for_each(|i| { acc[i].zero(); }); + let x_pow_a: &Vec, FFT64>>; + if let Some(b) = &brk.x_pow_a { + x_pow_a = b + } else { + panic!("invalid key: x_pow_a has not been initialized") + } + let mut lwe_2n: Vec = vec![0i64; lwe.n() + 1]; // TODO: from scratch space let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref(); + let two_n: usize = 2 * module.n(); let two_n_ext: usize = 2 * lut.domain_size(); negate_and_mod_switch_2n(two_n_ext, &mut lwe_2n, &lwe_ref); let a: &[i64] = &lwe_2n[1..]; - let b_pos: usize = ((lwe_2n[0] + two_n_ext as i64) % two_n_ext as i64) as usize; + let b_pos: usize = ((lwe_2n[0] + two_n_ext as i64) & (two_n_ext - 1) as i64) as usize; let b_hi: usize = b_pos / extension_factor; - let b_lo: usize = b_pos % extension_factor; + let b_lo: usize = b_pos & (extension_factor - 1); for (i, j) in (0..b_lo).zip(extension_factor - b_lo..extension_factor) { module.vec_znx_rotate(b_hi as i64 + 1, &mut acc[i], 0, &lut.data[j], 0); @@ -145,9 +157,9 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended( // TODO: first & last iterations can be optimized izip!(ai.iter(), ski.iter()).for_each(|(aii, skii)| { - let ai_pos: usize = ((aii + two_n_ext as i64) % two_n_ext as i64) as usize; + let ai_pos: usize = ((aii + two_n_ext as i64) & (two_n_ext - 1) as i64) as usize; let ai_hi: usize = ai_pos / extension_factor; - let ai_lo: usize = ai_pos % extension_factor; + let ai_lo: usize = ai_pos & (extension_factor - 1); // vmp_res = DFT(acc) * BRK[i] (0..extension_factor).for_each(|i| { @@ -156,48 +168,62 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended( // Trivial case: no rotation between polynomials, we can directly multiply with (X^{-ai} - 1) if ai_lo == 0 { - // DFT X^{-ai} - set_xai_plus_y(module, ai_hi, -1, &mut xai_plus_y_dft, &mut xai_plus_y); - // Sets acc_add_dft[i] = (acc[i] * sk) * (X^{-ai} - 1) - (0..extension_factor).for_each(|j| { - (0..cols).for_each(|i| { - module.svp_apply_inplace(&mut vmp_res[j], i, &xai_plus_y_dft, 0); - module.vec_znx_dft_add_inplace(&mut acc_add_dft[j], i, &vmp_res[j], i); + if ai_hi != 0 { + // DFT X^{-ai} + module.vec_znx_dft_add(&mut xai_plus_y_dft, 0, &x_pow_a[ai_hi], 0, &minus_one, 0); + (0..extension_factor).for_each(|j| { + (0..cols).for_each(|i| { + module.svp_apply_inplace(&mut vmp_res[j], i, &xai_plus_y_dft, 0); + module.vec_znx_dft_add_inplace(&mut acc_add_dft[j], i, &vmp_res[j], i); + }); }); - }); + } + // Non trivial case: rotation between polynomials // In this case we can't directly multiply with (X^{-ai} - 1) because of the // ring homomorphism R^{N} -> prod R^{N/extension_factor}, so we split the // computation in two steps: acc_add_dft = (acc * sk) * (-1) + (acc * sk) * X^{-ai} } else { // Sets acc_add_dft[i] = acc[i] * sk - (0..extension_factor).for_each(|i| { - (0..cols).for_each(|k| { - module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[i], k, &vmp_res[i], k); - }) - }); - // DFT X^{-ai} - set_xai_plus_y(module, ai_hi + 1, 0, &mut xai_plus_y_dft, &mut xai_plus_y); - - // Sets acc_add_dft[0..ai_lo] += (acc[extension_factor - ai_lo..extension_factor] * sk) * X^{-ai+1} - for (i, j) in (0..ai_lo).zip(extension_factor - ai_lo..extension_factor) { - (0..cols).for_each(|k| { - module.svp_apply_inplace(&mut vmp_res[j], k, &xai_plus_y_dft, 0); - module.vec_znx_dft_add_inplace(&mut acc_add_dft[i], k, &vmp_res[j], k); - }); + // Sets acc_add_dft[0..ai_lo] -= acc[..ai_lo] * sk + if (ai_hi + 1) & (two_n - 1) != 0 { + for i in 0..ai_lo { + (0..cols).for_each(|k| { + module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[i], k, &vmp_res[i], k); + }); + } } - // DFT X^{-ai} - set_xai_plus_y(module, ai_hi, 0, &mut xai_plus_y_dft, &mut xai_plus_y); + // Sets acc_add_dft[ai_lo..extension_factor] -= acc[ai_lo..extension_factor] * sk + if ai_hi != 0 { + for i in ai_lo..extension_factor { + (0..cols).for_each(|k: usize| { + module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[i], k, &vmp_res[i], k); + }); + } + } + + // Sets acc_add_dft[0..ai_lo] += (acc[extension_factor - ai_lo..extension_factor] * sk) * X^{-ai+1} + if (ai_hi + 1) & (two_n - 1) != 0 { + for (i, j) in (0..ai_lo).zip(extension_factor - ai_lo..extension_factor) { + (0..cols).for_each(|k| { + module.svp_apply_inplace(&mut vmp_res[j], k, &x_pow_a[ai_hi + 1], 0); + module.vec_znx_dft_add_inplace(&mut acc_add_dft[i], k, &vmp_res[j], k); + }); + } + } // Sets acc_add_dft[ai_lo..extension_factor] += (acc[0..extension_factor - ai_lo] * sk) * X^{-ai} - for (i, j) in (ai_lo..extension_factor).zip(0..extension_factor - ai_lo) { - (0..cols).for_each(|k| { - module.svp_apply_inplace(&mut vmp_res[j], k, &xai_plus_y_dft, 0); - module.vec_znx_dft_add_inplace(&mut acc_add_dft[i], k, &vmp_res[j], k); - }); + if ai_hi != 0 { + // Sets acc_add_dft[ai_lo..extension_factor] += (acc[0..extension_factor - ai_lo] * sk) * X^{-ai} + for (i, j) in (ai_lo..extension_factor).zip(0..extension_factor - ai_lo) { + (0..cols).for_each(|k| { + module.svp_apply_inplace(&mut vmp_res[j], k, &x_pow_a[ai_hi], 0); + module.vec_znx_dft_add_inplace(&mut acc_add_dft[i], k, &vmp_res[j], k); + }); + } } } }); @@ -220,49 +246,17 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended( }); } -fn set_xai_plus_y( - module: &Module, - ai: usize, - y: i64, - res: &mut ScalarZnxDft<&mut [u8], FFT64>, - buf: &mut ScalarZnx<&mut [u8]>, -) { - let n: usize = module.n(); - - { - let raw: &mut [i64] = buf.at_mut(0, 0); - if ai < n { - raw[ai] = 1; - } else { - raw[(ai - n) & (n - 1)] = -1; - } - raw[0] += y; - } - - module.svp_prepare(res, 0, buf, 0); - - { - let raw: &mut [i64] = buf.at_mut(0, 0); - - if ai < n { - raw[ai] = 0; - } else { - raw[(ai - n) & (n - 1)] = 0; - } - raw[0] = 0; - } -} - -pub(crate) fn cggi_blind_rotate_block_binary( +pub(crate) fn cggi_blind_rotate_block_binary( module: &Module, res: &mut GLWECiphertext, lwe: &LWECiphertext, lut: &LookUpTable, - brk: &BlindRotationKeyCGGI, + brk: &BlindRotationKeyCGGI, scratch: &mut Scratch, ) where DataRes: AsRef<[u8]> + AsMut<[u8]>, DataIn: AsRef<[u8]>, + DataBrk: AsRef<[u8]>, { let mut lwe_2n: Vec = vec![0i64; lwe.n() + 1]; // TODO: from scratch space let mut out_mut: GLWECiphertext<&mut [u8]> = res.to_mut(); @@ -290,9 +284,18 @@ pub(crate) fn cggi_blind_rotate_block_binary( let (mut acc_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rows); let (mut vmp_res, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, brk.size()); let (mut acc_add_dft, scratch3) = scratch2.tmp_vec_znx_dft(module, cols, brk.size()); - let (mut xai_plus_y, scratch4) = scratch3.tmp_scalar_znx(module, 1); + let (mut minus_one, scratch4) = scratch3.tmp_scalar_znx_dft(module, 1); let (mut xai_plus_y_dft, scratch5) = scratch4.tmp_scalar_znx_dft(module, 1); + minus_one.raw_mut()[..module.n() >> 1].fill(-1.0); + + let x_pow_a: &Vec, FFT64>>; + if let Some(b) = &brk.x_pow_a { + x_pow_a = b + } else { + panic!("invalid key: x_pow_a has not been initialized") + } + izip!( a.chunks_exact(block_size), brk.data.chunks_exact(block_size) @@ -305,13 +308,13 @@ pub(crate) fn cggi_blind_rotate_block_binary( acc_add_dft.zero(); izip!(ai.iter(), ski.iter()).for_each(|(aii, skii)| { - let ai_pos: usize = ((aii + two_n as i64) % two_n as i64) as usize; + let ai_pos: usize = ((aii + two_n as i64) & (two_n - 1) as i64) as usize; // vmp_res = DFT(acc) * BRK[i] module.vmp_apply(&mut vmp_res, &acc_dft, &skii.data, scratch5); // DFT(X^ai -1) - set_xai_plus_y(module, ai_pos, -1, &mut xai_plus_y_dft, &mut xai_plus_y); + module.vec_znx_dft_add(&mut xai_plus_y_dft, 0, &x_pow_a[ai_pos], 0, &minus_one, 0); // DFT(X^ai -1) * (DFT(acc) * BRK[i]) (0..cols).for_each(|i| { @@ -320,10 +323,6 @@ pub(crate) fn cggi_blind_rotate_block_binary( }); }); - (0..cols).for_each(|i| { - module.vec_znx_dft_add_inplace(&mut acc_dft, i, &acc_add_dft, i); - }); - { let (mut acc_add_big, scratch6) = scratch5.tmp_vec_znx_big(module, 1, brk.size()); @@ -336,16 +335,17 @@ pub(crate) fn cggi_blind_rotate_block_binary( }); } -pub(crate) fn cggi_blind_rotate_binary_standard( +pub(crate) fn cggi_blind_rotate_binary_standard( module: &Module, res: &mut GLWECiphertext, lwe: &LWECiphertext, lut: &LookUpTable, - brk: &BlindRotationKeyCGGI, + brk: &BlindRotationKeyCGGI, scratch: &mut Scratch, ) where DataRes: AsRef<[u8]> + AsMut<[u8]>, DataIn: AsRef<[u8]>, + DataBrk: AsRef<[u8]>, { #[cfg(debug_assertions)] { diff --git a/core/src/blind_rotation/key.rs b/core/src/blind_rotation/key.rs index b83d60c..01511c3 100644 --- a/core/src/blind_rotation/key.rs +++ b/core/src/blind_rotation/key.rs @@ -1,11 +1,15 @@ -use backend::{Backend, FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxToRef, Scratch, ZnxView, ZnxViewMut}; +use backend::{ + Backend, FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxToRef, Scratch, + ZnxView, ZnxViewMut, +}; use sampling::source::Source; use crate::{Distribution, FourierGLWESecret, GGSWCiphertext, Infos, LWESecret}; -pub struct BlindRotationKeyCGGI { - pub(crate) data: Vec, B>>, +pub struct BlindRotationKeyCGGI { + pub(crate) data: Vec>, pub(crate) dist: Distribution, + pub(crate) x_pow_a: Option, B>>>, } // pub struct BlindRotationKeyFHEW { @@ -13,20 +17,61 @@ pub struct BlindRotationKeyCGGI { // pub(crate) auto: Vec, B>>, //} -impl BlindRotationKeyCGGI { +impl BlindRotationKeyCGGI, FFT64> { pub fn allocate(module: &Module, n_lwe: usize, basek: usize, k: usize, rows: usize, rank: usize) -> Self { let mut data: Vec, FFT64>> = Vec::with_capacity(n_lwe); (0..n_lwe).for_each(|_| data.push(GGSWCiphertext::alloc(module, basek, k, rows, 1, rank))); Self { data, dist: Distribution::NONE, + x_pow_a: None::, FFT64>>>, } } pub fn generate_from_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k, rank) } +} +impl> BlindRotationKeyCGGI { + #[allow(dead_code)] + pub(crate) fn n(&self) -> usize { + self.data[0].n() + } + + #[allow(dead_code)] + pub(crate) fn rows(&self) -> usize { + self.data[0].rows() + } + + #[allow(dead_code)] + pub(crate) fn k(&self) -> usize { + self.data[0].k() + } + + #[allow(dead_code)] + pub(crate) fn size(&self) -> usize { + self.data[0].size() + } + + #[allow(dead_code)] + pub(crate) fn rank(&self) -> usize { + self.data[0].rank() + } + + pub(crate) fn basek(&self) -> usize { + self.data[0].basek() + } + + pub(crate) fn block_size(&self) -> usize { + match self.dist { + Distribution::BinaryBlock(value) => value, + _ => 1, + } + } +} + +impl + AsMut<[u8]>> BlindRotationKeyCGGI { pub fn generate_from_sk( &mut self, module: &Module, @@ -64,42 +109,51 @@ impl BlindRotationKeyCGGI { self.data.iter_mut().enumerate().for_each(|(i, ggsw)| { pt.at_mut(0, 0)[0] = sk_ref.at(0, 0)[i]; ggsw.encrypt_sk(module, &pt, sk_glwe, source_xa, source_xe, sigma, scratch); - }) - } + }); - pub(crate) fn block_size(&self) -> usize { - match self.dist { - Distribution::BinaryBlock(value) => value, - _ => 1, + match sk_lwe.dist { + Distribution::BinaryBlock(_) => { + let mut x_pow_a: Vec, FFT64>> = Vec::with_capacity(module.n() << 1); + let mut buf: ScalarZnx> = module.new_scalar_znx(1); + (0..module.n() << 1).for_each(|i| { + let mut res: ScalarZnxDft, FFT64> = module.new_scalar_znx_dft(1); + set_xai_plus_y(module, i, 0, &mut res, &mut buf); + x_pow_a.push(res); + }); + self.x_pow_a = Some(x_pow_a); + } + _ => {} } } +} - #[allow(dead_code)] - pub(crate) fn n(&self) -> usize { - self.data[0].n() +pub fn set_xai_plus_y(module: &Module, ai: usize, y: i64, res: &mut ScalarZnxDft, buf: &mut ScalarZnx) +where + A: AsRef<[u8]> + AsMut<[u8]>, + B: AsRef<[u8]> + AsMut<[u8]>, +{ + let n: usize = module.n(); + + { + let raw: &mut [i64] = buf.at_mut(0, 0); + if ai < n { + raw[ai] = 1; + } else { + raw[(ai - n) & (n - 1)] = -1; + } + raw[0] += y; } - #[allow(dead_code)] - pub(crate) fn rows(&self) -> usize { - self.data[0].rows() - } + module.svp_prepare(res, 0, buf, 0); - #[allow(dead_code)] - pub(crate) fn k(&self) -> usize { - self.data[0].k() - } + { + let raw: &mut [i64] = buf.at_mut(0, 0); - #[allow(dead_code)] - pub(crate) fn size(&self) -> usize { - self.data[0].size() - } - - #[allow(dead_code)] - pub(crate) fn rank(&self) -> usize { - self.data[0].rank() - } - - pub(crate) fn basek(&self) -> usize { - self.data[0].basek() + if ai < n { + raw[ai] = 0; + } else { + raw[(ai - n) & (n - 1)] = 0; + } + raw[0] = 0; } } diff --git a/core/src/blind_rotation/lut.rs b/core/src/blind_rotation/lut.rs index 7446e9a..300aa6e 100644 --- a/core/src/blind_rotation/lut.rs +++ b/core/src/blind_rotation/lut.rs @@ -24,6 +24,10 @@ impl LookUpTable { Self { data, basek, k } } + pub fn log_extension_factor(&self) -> usize { + (usize::BITS - (self.extension_factor() - 1).leading_zeros()) as _ + } + pub fn extension_factor(&self) -> usize { self.data.len() } diff --git a/core/src/blind_rotation/test_fft64/cggi.rs b/core/src/blind_rotation/test_fft64/cggi.rs index e544494..2fbad48 100644 --- a/core/src/blind_rotation/test_fft64/cggi.rs +++ b/core/src/blind_rotation/test_fft64/cggi.rs @@ -39,8 +39,8 @@ fn blind_rotatio_test(n_lwe: usize, block_size: usize, extension_factor: usize) let message_modulus: usize = 1 << 4; - let mut source_xs: Source = Source::new([1u8; 32]); - let mut source_xe: Source = Source::new([1u8; 32]); + let mut source_xs: Source = Source::new([2u8; 32]); + let mut source_xe: Source = Source::new([2u8; 32]); let mut source_xa: Source = Source::new([1u8; 32]); let mut sk_glwe: GLWESecret> = GLWESecret::alloc(&module, rank); @@ -65,7 +65,8 @@ fn blind_rotatio_test(n_lwe: usize, block_size: usize, extension_factor: usize) rank, )); - let mut brk: BlindRotationKeyCGGI = BlindRotationKeyCGGI::allocate(&module, n_lwe, basek, k_brk, rows_brk, rank); + let mut brk: BlindRotationKeyCGGI, FFT64> = + BlindRotationKeyCGGI::allocate(&module, n_lwe, basek, k_brk, rows_brk, rank); brk.generate_from_sk( &module, @@ -86,14 +87,8 @@ fn blind_rotatio_test(n_lwe: usize, block_size: usize, extension_factor: usize) pt_lwe.data.encode_coeff_i64(0, basek, bits, 0, x, bits); - // println!("{}", pt_lwe.data); - lwe.encrypt_sk(&pt_lwe, &sk_lwe, &mut source_xa, &mut source_xe, 3.2); - lwe.decrypt(&mut pt_lwe, &sk_lwe); - - // println!("{}", pt_lwe.data); - let mut f: Vec = vec![0i64; message_modulus]; f.iter_mut() .enumerate() @@ -106,14 +101,10 @@ fn blind_rotatio_test(n_lwe: usize, block_size: usize, extension_factor: usize) cggi_blind_rotate(&module, &mut res, &lwe, &lut, &brk, scratch_br.borrow()); - println!("out_mut.data: {}", res.data); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_res); res.decrypt(&module, &mut pt_have, &sk_glwe_dft, scratch.borrow()); - println!("pt_have: {}", pt_have.data); - let mut lwe_2n: Vec = vec![0i64; lwe.n() + 1]; // TODO: from scratch space negate_and_mod_switch_2n(2 * lut.domain_size(), &mut lwe_2n, &lwe.to_ref()); @@ -124,23 +115,11 @@ fn blind_rotatio_test(n_lwe: usize, block_size: usize, extension_factor: usize) .zip(sk_lwe.data.at(0, 0)) .map(|(x, y)| x * y) .sum::()) - % (2 * lut.domain_size()) as i64; - - // println!("pt_want: {}", pt_want); + & (2 * lut.domain_size() - 1) as i64; lut.rotate(pt_want); - // lut.data.iter().for_each(|d| { - // println!("{}", d); - // }); - // First limb should be exactly equal (test are parameterized such that the noise does not reach // the first limb) assert_eq!(pt_have.data.at(0, 0), lut.data[0].at(0, 0)); - - // Then checks the noise - // module.vec_znx_sub_ab_inplace(&mut lut.data[0], 0, &pt_have.data, 0); - // let noise: f64 = lut.data[0].std(0, basek); - // println!("noise: {}", noise); - // assert!(noise < 1e-3); }