diff --git a/backend/spqlios-arithmetic b/backend/spqlios-arithmetic index 173b980..0ae9a7b 160000 --- a/backend/spqlios-arithmetic +++ b/backend/spqlios-arithmetic @@ -1 +1 @@ -Subproject commit 173b980c7b8a4f0523d04c2aed061c2e046e846c +Subproject commit 0ae9a7b5adf07ce0b1797562528dab8e28192238 diff --git a/backend/src/encoding.rs b/backend/src/encoding.rs index 73b86a3..55bba09 100644 --- a/backend/src/encoding.rs +++ b/backend/src/encoding.rs @@ -157,18 +157,22 @@ fn decode_vec_i64>(a: &VecZnx, col_i: usize, basek: usize, k: } data.copy_from_slice(a.at(col_i, 0)); let rem: usize = basek - (k % basek); - (1..size).for_each(|i| { - if i == size - 1 && rem != basek { - let k_rem: usize = basek - rem; - izip!(a.at(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| { - *y = (*y << k_rem) + (x >> rem); - }); - } else { - izip!(a.at(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| { - *y = (*y << basek) + x; - }); - } - }) + if k < basek { + data.iter_mut().for_each(|x| *x >>= rem); + } else { + (1..size).for_each(|i| { + if i == size - 1 && rem != basek { + let k_rem: usize = basek - rem; + izip!(a.at(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| { + *y = (*y << k_rem) + (x >> rem); + }); + } else { + izip!(a.at(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| { + *y = (*y << basek) + x; + }); + } + }) + } } fn decode_vec_float>(a: &VecZnx, col_i: usize, basek: usize, data: &mut [Float]) { @@ -268,7 +272,7 @@ fn decode_coeff_i64>(a: &VecZnx, col_i: usize, basek: usize, k let mut res: i64 = data[i]; let rem: usize = basek - (k % basek); let slice_size: usize = a.n() * a.cols(); - (1..size).for_each(|i| { + (0..size).for_each(|i| { let x: i64 = data[i * slice_size]; if i == size - 1 && rem != basek { let k_rem: usize = basek - rem; @@ -316,18 +320,25 @@ mod tests { let module: Module = Module::::new(n); let basek: usize = 17; let size: usize = 5; - let k: usize = size * basek - 5; - let mut a: VecZnx<_> = module.new_vec_znx(2, size); - let mut source = Source::new([0u8; 32]); - let raw: &mut [i64] = a.raw_mut(); - raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); - (0..a.cols()).for_each(|col_i| { - let mut have: Vec = vec![i64::default(); n]; - have.iter_mut().for_each(|x| *x = source.next_i64()); - a.encode_vec_i64(col_i, basek, k, &have, 64); - let mut want = vec![i64::default(); n]; - a.decode_vec_i64(col_i, basek, k, &mut want); - izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); - }) + for k in [size * basek - 5] { + let mut a: VecZnx<_> = module.new_vec_znx(2, size); + let mut source = Source::new([0u8; 32]); + let raw: &mut [i64] = a.raw_mut(); + raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); + (0..a.cols()).for_each(|col_i| { + let mut have: Vec = vec![i64::default(); n]; + have.iter_mut().for_each(|x| { + if k < 64 { + *x = source.next_u64n(1 << k, (1 << k) - 1) as i64; + } else { + *x = source.next_i64(); + } + }); + a.encode_vec_i64(col_i, basek, k, &have, std::cmp::min(k, 64)); + let mut want = vec![i64::default(); n]; + a.decode_vec_i64(col_i, basek, k, &mut want); + izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); + }) + } } } diff --git a/backend/src/lib.rs b/backend/src/lib.rs index dcf4325..8ac50ce 100644 --- a/backend/src/lib.rs +++ b/backend/src/lib.rs @@ -103,7 +103,7 @@ fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec { /// Size of T * size msut be a multiple of [DEFAULTALIGN]. pub fn alloc_aligned_custom(size: usize, align: usize) -> Vec { assert_eq!( - (size * size_of::()) % align, + (size * size_of::()) % (align / size_of::()), 0, "size={} must be a multiple of align={}", size, @@ -121,7 +121,7 @@ pub fn alloc_aligned_custom(size: usize, align: usize) -> Vec { /// of [DEFAULTALIGN]/size_of::() that is equal or greater to `size`. pub fn alloc_aligned(size: usize) -> Vec { alloc_aligned_custom::( - size + (size % (DEFAULTALIGN / size_of::())), + size + (DEFAULTALIGN - (size % (DEFAULTALIGN / size_of::()))), DEFAULTALIGN, ) } @@ -231,6 +231,23 @@ impl Scratch { ) } + pub fn tmp_slice_vec_znx_dft( + &mut self, + slice_size: usize, + module: &Module, + cols: usize, + size: usize, + ) -> (Vec>, &mut Self) { + let mut scratch: &mut Scratch = self; + let mut slice: Vec> = Vec::with_capacity(slice_size); + for _ in 0..slice_size { + let (znx, new_scratch) = scratch.tmp_vec_znx_dft(module, cols, size); + scratch = new_scratch; + slice.push(znx); + } + (slice, scratch) + } + pub fn tmp_vec_znx_big( &mut self, module: &Module, @@ -253,6 +270,23 @@ impl Scratch { ) } + pub fn tmp_slice_vec_znx( + &mut self, + slice_size: usize, + module: &Module, + cols: usize, + size: usize, + ) -> (Vec>, &mut Self) { + let mut scratch: &mut Scratch = self; + let mut slice: Vec> = Vec::with_capacity(slice_size); + for _ in 0..slice_size { + let (znx, new_scratch) = scratch.tmp_vec_znx(module, cols, size); + scratch = new_scratch; + slice.push(znx); + } + (slice, scratch) + } + pub fn tmp_mat_znx_dft( &mut self, module: &Module, diff --git a/backend/src/mat_znx_dft_ops.rs b/backend/src/mat_znx_dft_ops.rs index 9656dfb..b48cb1a 100644 --- a/backend/src/mat_znx_dft_ops.rs +++ b/backend/src/mat_znx_dft_ops.rs @@ -2,8 +2,8 @@ use crate::ffi::vec_znx_dft::vec_znx_dft_t; use crate::ffi::vmp; use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; use crate::{ - Backend, FFT64, MatZnxDft, MatZnxDftOwned, MatZnxToMut, MatZnxToRef, Module, Scratch, VecZnxDft, VecZnxDftToMut, - VecZnxDftToRef, + Backend, FFT64, MatZnxDft, MatZnxDftOwned, MatZnxToMut, MatZnxToRef, Module, ScalarZnxAlloc, ScalarZnxDftAlloc, + ScalarZnxDftOps, Scratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, }; pub trait MatZnxDftAlloc { @@ -38,6 +38,8 @@ pub trait MatZnxDftScratch { b_cols_out: usize, b_size: usize, ) -> usize; + + fn mat_znx_dft_mul_x_pow_minus_one_scratch_space(&self, size: usize, cols_out: usize) -> usize; } /// This trait implements methods for vector matrix product, @@ -52,7 +54,7 @@ pub trait MatZnxDftOps { /// * `row_i`: the index of the row to prepare. /// /// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes]. - fn vmp_prepare_row(&self, res: &mut R, res_row: usize, res_col_in: usize, a: &A) + fn mat_znx_dft_set_row(&self, res: &mut R, res_row: usize, res_col_in: usize, a: &A) where R: MatZnxToMut, A: VecZnxDftToRef; @@ -64,11 +66,28 @@ pub trait MatZnxDftOps { /// * `res`: the [VecZnxDft] to on which to extract the row of the [MatZnxDft]. /// * `a`: [MatZnxDft] on which the values are encoded. /// * `row_i`: the index of the row to extract. - fn vmp_extract_row(&self, res: &mut R, a: &A, a_row: usize, a_col_in: usize) + fn mat_znx_dft_get_row(&self, res: &mut R, a: &A, a_row: usize, a_col_in: usize) where R: VecZnxDftToMut, A: MatZnxToRef; + /// Multiplies A by (X^{k} - 1) and stores the result on R. + fn mat_znx_dft_mul_x_pow_minus_one(&self, k: i64, res: &mut R, a: &A, scratch: &mut Scratch) + where + R: MatZnxToMut, + A: MatZnxToRef; + + /// Multiplies A by (X^{k} - 1). + fn mat_znx_dft_mul_x_pow_minus_one_inplace(&self, k: i64, a: &mut A, scratch: &mut Scratch) + where + A: MatZnxToMut; + + /// Multiplies A by (X^{k} - 1). + fn mat_znx_dft_mul_x_pow_minus_one_add_inplace(&self, k: i64, res: &mut R, a: &A, scratch: &mut Scratch) + where + R: MatZnxToMut, + A: MatZnxToRef; + /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft]. /// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. /// @@ -149,10 +168,142 @@ impl MatZnxDftScratch for Module { ) as usize } } + + fn mat_znx_dft_mul_x_pow_minus_one_scratch_space(&self, size: usize, cols_out: usize) -> usize { + let xpm1_dft: usize = self.bytes_of_scalar_znx(1); + let xpm1: usize = self.bytes_of_scalar_znx_dft(1); + let tmp: usize = self.bytes_of_vec_znx_dft(cols_out, size); + xpm1_dft + (xpm1 | 2 * tmp) + } } impl MatZnxDftOps for Module { - fn vmp_prepare_row(&self, res: &mut R, res_row: usize, res_col_in: usize, a: &A) + fn mat_znx_dft_mul_x_pow_minus_one(&self, k: i64, res: &mut R, a: &A, scratch: &mut Scratch) + where + R: MatZnxToMut, + A: MatZnxToRef, + { + let mut res: MatZnxDft<&mut [u8], FFT64> = res.to_mut(); + let a: MatZnxDft<&[u8], FFT64> = a.to_ref(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.n(), self.n()); + assert_eq!(a.n(), self.n()); + assert_eq!(res.rows(), a.rows()); + assert_eq!(res.cols_in(), a.cols_in()); + assert_eq!(res.cols_out(), a.cols_out()); + } + + let (mut xpm1_dft, scratch1) = scratch.tmp_scalar_znx_dft(self, 1); + + { + let (mut xpm1, _) = scratch1.tmp_scalar_znx(self, 1); + xpm1.data[0] = 1; + self.vec_znx_rotate_inplace(k, &mut xpm1, 0); + self.svp_prepare(&mut xpm1_dft, 0, &xpm1, 0); + } + + let (mut tmp_0, scratch2) = scratch1.tmp_vec_znx_dft(self, res.cols_out(), res.size()); + let (mut tmp_1, _) = scratch2.tmp_vec_znx_dft(self, res.cols_out(), res.size()); + + (0..res.rows()).for_each(|row_i| { + (0..res.cols_in()).for_each(|col_j| { + self.mat_znx_dft_get_row(&mut tmp_0, &a, row_i, col_j); + + (0..tmp_0.cols()).for_each(|i| { + self.svp_apply(&mut tmp_1, i, &xpm1_dft, 0, &tmp_0, i); + self.vec_znx_dft_sub_ab_inplace(&mut tmp_1, i, &tmp_0, i); + }); + + self.mat_znx_dft_set_row(&mut res, row_i, col_j, &tmp_1); + }); + }); + } + + fn mat_znx_dft_mul_x_pow_minus_one_inplace(&self, k: i64, a: &mut A, scratch: &mut Scratch) + where + A: MatZnxToMut, + { + let mut a: MatZnxDft<&mut [u8], FFT64> = a.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + } + + let (mut xpm1_dft, scratch1) = scratch.tmp_scalar_znx_dft(self, 1); + + { + let (mut xpm1, _) = scratch1.tmp_scalar_znx(self, 1); + xpm1.data[0] = 1; + self.vec_znx_rotate_inplace(k, &mut xpm1, 0); + self.svp_prepare(&mut xpm1_dft, 0, &xpm1, 0); + } + + let (mut tmp_0, scratch2) = scratch1.tmp_vec_znx_dft(self, a.cols_out(), a.size()); + let (mut tmp_1, _) = scratch2.tmp_vec_znx_dft(self, a.cols_out(), a.size()); + + (0..a.rows()).for_each(|row_i| { + (0..a.cols_in()).for_each(|col_j| { + self.mat_znx_dft_get_row(&mut tmp_0, &a, row_i, col_j); + + (0..tmp_0.cols()).for_each(|i| { + self.svp_apply(&mut tmp_1, i, &xpm1_dft, 0, &tmp_0, i); + self.vec_znx_dft_sub_ab_inplace(&mut tmp_1, i, &tmp_0, i); + }); + + self.mat_znx_dft_set_row(&mut a, row_i, col_j, &tmp_1); + }); + }); + } + + fn mat_znx_dft_mul_x_pow_minus_one_add_inplace(&self, k: i64, res: &mut R, a: &A, scratch: &mut Scratch) + where + R: MatZnxToMut, + A: MatZnxToRef, + { + let mut res: MatZnxDft<&mut [u8], FFT64> = res.to_mut(); + let a: MatZnxDft<&[u8], FFT64> = a.to_ref(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + } + + let (mut xpm1_dft, scratch1) = scratch.tmp_scalar_znx_dft(self, 1); + + { + let (mut xpm1, _) = scratch1.tmp_scalar_znx(self, 1); + xpm1.data[0] = 1; + self.vec_znx_rotate_inplace(k, &mut xpm1, 0); + self.svp_prepare(&mut xpm1_dft, 0, &xpm1, 0); + } + + let (mut tmp_0, scratch2) = scratch1.tmp_vec_znx_dft(self, a.cols_out(), a.size()); + let (mut tmp_1, _) = scratch2.tmp_vec_znx_dft(self, a.cols_out(), a.size()); + + (0..a.rows()).for_each(|row_i| { + (0..a.cols_in()).for_each(|col_j| { + self.mat_znx_dft_get_row(&mut tmp_0, &a, row_i, col_j); + + (0..tmp_0.cols()).for_each(|i| { + self.svp_apply(&mut tmp_1, i, &xpm1_dft, 0, &tmp_0, i); + self.vec_znx_dft_sub_ab_inplace(&mut tmp_1, i, &tmp_0, i); + }); + + self.mat_znx_dft_get_row(&mut tmp_0, &res, row_i, col_j); + + (0..tmp_0.cols()).for_each(|i| { + self.vec_znx_dft_add_inplace(&mut tmp_0, i, &tmp_1, i); + }); + + self.mat_znx_dft_set_row(&mut res, row_i, col_j, &tmp_0); + }); + }); + } + + fn mat_znx_dft_set_row(&self, res: &mut R, res_row: usize, res_col_in: usize, a: &A) where R: MatZnxToMut, A: VecZnxDftToRef, @@ -204,7 +355,7 @@ impl MatZnxDftOps for Module { } } - fn vmp_extract_row(&self, res: &mut R, a: &A, a_row: usize, a_col_in: usize) + fn mat_znx_dft_get_row(&self, res: &mut R, a: &A, a_row: usize, a_col_in: usize) where R: VecZnxDftToMut, A: MatZnxToRef, @@ -376,7 +527,7 @@ mod tests { use super::{MatZnxDftAlloc, MatZnxDftScratch}; #[test] - fn vmp_prepare_row() { + fn vmp_set_row() { let module: Module = Module::::new(16); let basek: usize = 8; let mat_rows: usize = 4; @@ -395,8 +546,8 @@ mod tests { a.fill_uniform(basek, col_out, mat_size, &mut source); module.vec_znx_dft(1, 0, &mut a_dft, col_out, &a, col_out); }); - module.vmp_prepare_row(&mut mat, row_i, col_in, &a_dft); - module.vmp_extract_row(&mut b_dft, &mat, row_i, col_in); + module.mat_znx_dft_set_row(&mut mat, row_i, col_in, &a_dft); + module.mat_znx_dft_get_row(&mut b_dft, &mat, row_i, col_in); assert_eq!(a_dft.raw(), b_dft.raw()); } } @@ -413,10 +564,10 @@ mod tests { let mat_size: usize = 6; let res_size: usize = a_size; - [1, 2].iter().for_each(|in_cols| { - [1, 2].iter().for_each(|out_cols| { - let a_cols: usize = *in_cols; - let res_cols: usize = *out_cols; + [1, 2].iter().for_each(|cols_in| { + [1, 2].iter().for_each(|cols_out| { + let a_cols: usize = *cols_in; + let res_cols: usize = *cols_out; let mat_rows: usize = a_size; let mat_cols_in: usize = a_cols; @@ -456,7 +607,7 @@ mod tests { module.vec_znx_dft(1, 0, &mut c_dft, col_out_i, &tmp, col_out_i); tmp.at_mut(col_out_i, row_i)[idx] = 0 as i64; }); - module.vmp_prepare_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft); + module.mat_znx_dft_set_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft); }); }); @@ -499,11 +650,11 @@ mod tests { let res_size: usize = a_size; let mut source: Source = Source::new([0u8; 32]); - [1, 2].iter().for_each(|in_cols| { - [1, 2].iter().for_each(|out_cols| { + [1, 2].iter().for_each(|cols_in| { + [1, 2].iter().for_each(|cols_out| { (0..res_size).for_each(|shift| { - let a_cols: usize = *in_cols; - let res_cols: usize = *out_cols; + let a_cols: usize = *cols_in; + let res_cols: usize = *cols_out; let mat_rows: usize = a_size; let mat_cols_in: usize = a_cols; @@ -543,7 +694,7 @@ mod tests { module.vec_znx_dft(1, 0, &mut c_dft, col_out_i, &tmp, col_out_i); tmp.at_mut(col_out_i, row_i)[idx] = 0 as i64; }); - module.vmp_prepare_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft); + module.mat_znx_dft_set_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft); }); }); @@ -601,13 +752,13 @@ mod tests { let mat_size: usize = 6; let res_size: usize = a_size; - [1, 2].iter().for_each(|in_cols| { - [1, 2].iter().for_each(|out_cols| { + [1, 2].iter().for_each(|cols_in| { + [1, 2].iter().for_each(|cols_out| { [1, 3, 6].iter().for_each(|digits| { let mut source: Source = Source::new([0u8; 32]); - let a_cols: usize = *in_cols; - let res_cols: usize = *out_cols; + let a_cols: usize = *cols_in; + let res_cols: usize = *cols_out; let mat_rows: usize = a_size; let mat_cols_in: usize = a_cols; @@ -652,7 +803,7 @@ mod tests { module.vec_znx_dft(1, 0, &mut c_dft, col_out_i, &tmp, col_out_i); tmp.at_mut(col_out_i, limb)[idx] = 0 as i64; }); - module.vmp_prepare_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft); + module.mat_znx_dft_set_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft); }); }); @@ -697,4 +848,149 @@ mod tests { }); }); } + + #[test] + fn mat_znx_dft_mul_x_pow_minus_one() { + let log_n: i32 = 5; + let n: usize = 1 << log_n; + + let module: Module = Module::::new(n); + let basek: usize = 8; + let rows: usize = 2; + let cols_in: usize = 2; + let cols_out: usize = 2; + let size: usize = 4; + + let mut scratch: ScratchOwned = ScratchOwned::new(module.mat_znx_dft_mul_x_pow_minus_one_scratch_space(size, cols_out)); + + let mut mat_want: MatZnxDft, FFT64> = module.new_mat_znx_dft(rows, cols_in, cols_out, size); + let mut mat_have: MatZnxDft, FFT64> = module.new_mat_znx_dft(rows, cols_in, cols_out, size); + + let mut tmp: VecZnx> = module.new_vec_znx(1, size); + let mut tmp_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(cols_out, size); + + let mut source: Source = Source::new([0u8; 32]); + + (0..mat_want.rows()).for_each(|row_i| { + (0..mat_want.cols_in()).for_each(|col_i| { + (0..cols_out).for_each(|j| { + tmp.fill_uniform(basek, 0, size, &mut source); + module.vec_znx_dft(1, 0, &mut tmp_dft, j, &tmp, 0); + }); + + module.mat_znx_dft_set_row(&mut mat_want, row_i, col_i, &tmp_dft); + }); + }); + + let k: i64 = 1; + + module.mat_znx_dft_mul_x_pow_minus_one(k, &mut mat_have, &mat_want, scratch.borrow()); + + let mut have: VecZnx> = module.new_vec_znx(cols_out, size); + let mut want: VecZnx> = module.new_vec_znx(cols_out, size); + let mut tmp_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, size); + + (0..mat_want.rows()).for_each(|row_i| { + (0..mat_want.cols_in()).for_each(|col_i| { + module.mat_znx_dft_get_row(&mut tmp_dft, &mat_want, row_i, col_i); + + (0..cols_out).for_each(|j| { + module.vec_znx_idft(&mut tmp_big, 0, &tmp_dft, j, scratch.borrow()); + module.vec_znx_big_normalize(basek, &mut tmp, 0, &tmp_big, 0, scratch.borrow()); + module.vec_znx_rotate(k, &mut want, j, &tmp, 0); + module.vec_znx_sub_ab_inplace(&mut want, j, &tmp, 0); + module.vec_znx_normalize_inplace(basek, &mut want, j, scratch.borrow()); + }); + + module.mat_znx_dft_get_row(&mut tmp_dft, &mat_have, row_i, col_i); + + (0..cols_out).for_each(|j| { + module.vec_znx_idft(&mut tmp_big, 0, &tmp_dft, j, scratch.borrow()); + module.vec_znx_big_normalize(basek, &mut have, j, &tmp_big, 0, scratch.borrow()); + }); + + assert_eq!(have, want) + }); + }); + } + + #[test] + fn mat_znx_dft_mul_x_pow_minus_one_add_inplace() { + let log_n: i32 = 5; + let n: usize = 1 << log_n; + + let module: Module = Module::::new(n); + let basek: usize = 8; + let rows: usize = 2; + let cols_in: usize = 2; + let cols_out: usize = 2; + let size: usize = 4; + + let mut scratch: ScratchOwned = ScratchOwned::new(module.mat_znx_dft_mul_x_pow_minus_one_scratch_space(size, cols_out)); + + let mut mat_want: MatZnxDft, FFT64> = module.new_mat_znx_dft(rows, cols_in, cols_out, size); + let mut mat_have: MatZnxDft, FFT64> = module.new_mat_znx_dft(rows, cols_in, cols_out, size); + + let mut tmp: VecZnx> = module.new_vec_znx(1, size); + let mut tmp_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(cols_out, size); + + let mut source: Source = Source::new([0u8; 32]); + + (0..mat_have.rows()).for_each(|row_i| { + (0..mat_have.cols_in()).for_each(|col_i| { + (0..cols_out).for_each(|j| { + tmp.fill_uniform(basek, 0, size, &mut source); + module.vec_znx_dft(1, 0, &mut tmp_dft, j, &tmp, 0); + }); + + module.mat_znx_dft_set_row(&mut mat_have, row_i, col_i, &tmp_dft); + }); + }); + + (0..mat_want.rows()).for_each(|row_i| { + (0..mat_want.cols_in()).for_each(|col_i| { + (0..cols_out).for_each(|j| { + tmp.fill_uniform(basek, 0, size, &mut source); + module.vec_znx_dft(1, 0, &mut tmp_dft, j, &tmp, 0); + }); + + module.mat_znx_dft_set_row(&mut mat_want, row_i, col_i, &tmp_dft); + }); + }); + + let k: i64 = 1; + + module.mat_znx_dft_mul_x_pow_minus_one_add_inplace(k, &mut mat_have, &mat_want, scratch.borrow()); + + let mut have: VecZnx> = module.new_vec_znx(cols_out, size); + let mut want: VecZnx> = module.new_vec_znx(cols_out, size); + let mut tmp_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, size); + + let mut source: Source = Source::new([0u8; 32]); + (0..mat_want.rows()).for_each(|row_i| { + (0..mat_want.cols_in()).for_each(|col_i| { + module.mat_znx_dft_get_row(&mut tmp_dft, &mat_want, row_i, col_i); + + (0..cols_out).for_each(|j| { + module.vec_znx_idft(&mut tmp_big, 0, &tmp_dft, j, scratch.borrow()); + module.vec_znx_big_normalize(basek, &mut tmp, 0, &tmp_big, 0, scratch.borrow()); + module.vec_znx_rotate(k, &mut want, j, &tmp, 0); + module.vec_znx_sub_ab_inplace(&mut want, j, &tmp, 0); + + tmp.fill_uniform(basek, 0, size, &mut source); + module.vec_znx_add_inplace(&mut want, j, &tmp, 0); + module.vec_znx_normalize_inplace(basek, &mut want, j, scratch.borrow()); + }); + + module.mat_znx_dft_get_row(&mut tmp_dft, &mat_have, row_i, col_i); + + (0..cols_out).for_each(|j| { + module.vec_znx_idft(&mut tmp_big, 0, &tmp_dft, j, scratch.borrow()); + module.vec_znx_big_normalize(basek, &mut have, j, &tmp_big, 0, scratch.borrow()); + }); + + assert_eq!(have, want) + }); + }); + } } diff --git a/backend/src/scalar_znx.rs b/backend/src/scalar_znx.rs index cb51e0d..4acedb5 100644 --- a/backend/src/scalar_znx.rs +++ b/backend/src/scalar_znx.rs @@ -72,15 +72,44 @@ impl + AsRef<[u8]>> ScalarZnx { .for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1); self.at_mut(col, 0).shuffle(source); } + + pub fn fill_binary_prob(&mut self, col: usize, prob: f64, source: &mut Source) { + let choices: [i64; 2] = [0, 1]; + let weights: [f64; 2] = [1.0 - prob, prob]; + let dist: WeightedIndex = WeightedIndex::new(&weights).unwrap(); + self.at_mut(col, 0) + .iter_mut() + .for_each(|x: &mut i64| *x = choices[dist.sample(source)]); + } + + pub fn fill_binary_hw(&mut self, col: usize, hw: usize, source: &mut Source) { + assert!(hw <= self.n()); + self.at_mut(col, 0)[..hw] + .iter_mut() + .for_each(|x: &mut i64| *x = (source.next_u32() & 1) as i64); + self.at_mut(col, 0).shuffle(source); + } + + pub fn fill_binary_block(&mut self, col: usize, block_size: usize, source: &mut Source) { + assert!(self.n() % block_size == 0); + let max_idx: u64 = (block_size + 1) as u64; + let mask_idx: u64 = (1 << ((u64::BITS - max_idx.leading_zeros()) as u64)) - 1; + for block in self.at_mut(col, 0).chunks_mut(block_size) { + let idx: usize = source.next_u64n(max_idx, mask_idx) as usize; + if idx != block_size { + block[idx] = 1; + } + } + } } impl>> ScalarZnx { - pub(crate) fn bytes_of(n: usize, cols: usize) -> usize { - n * cols * size_of::() + pub(crate) fn bytes_of(n: usize, cols: usize) -> usize { + n * cols * size_of::() } - pub(crate) fn new(n: usize, cols: usize) -> Self { - let data = alloc_aligned::(Self::bytes_of::(n, cols)); + pub fn new(n: usize, cols: usize) -> Self { + let data: Vec = alloc_aligned::(Self::bytes_of(n, cols)); Self { data: data.into(), n, @@ -88,9 +117,9 @@ impl>> ScalarZnx { } } - pub(crate) fn new_from_bytes(n: usize, cols: usize, bytes: impl Into>) -> Self { + pub(crate) fn new_from_bytes(n: usize, cols: usize, bytes: impl Into>) -> Self { let data: Vec = bytes.into(); - assert!(data.len() == Self::bytes_of::(n, cols)); + assert!(data.len() == Self::bytes_of(n, cols)); Self { data: data.into(), n, @@ -102,7 +131,7 @@ impl>> ScalarZnx { pub type ScalarZnxOwned = ScalarZnx>; pub(crate) fn bytes_of_scalar_znx(module: &Module, cols: usize) -> usize { - ScalarZnxOwned::bytes_of::(module.n(), cols) + ScalarZnxOwned::bytes_of(module.n(), cols) } pub trait ScalarZnxAlloc { @@ -113,13 +142,13 @@ pub trait ScalarZnxAlloc { impl ScalarZnxAlloc for Module { fn bytes_of_scalar_znx(&self, cols: usize) -> usize { - ScalarZnxOwned::bytes_of::(self.n(), cols) + ScalarZnxOwned::bytes_of(self.n(), cols) } fn new_scalar_znx(&self, cols: usize) -> ScalarZnxOwned { - ScalarZnxOwned::new::(self.n(), cols) + ScalarZnxOwned::new(self.n(), cols) } fn new_scalar_znx_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxOwned { - ScalarZnxOwned::new_from_bytes::(self.n(), cols, bytes) + ScalarZnxOwned::new_from_bytes(self.n(), cols, bytes) } } diff --git a/backend/src/scalar_znx_dft_ops.rs b/backend/src/scalar_znx_dft_ops.rs index 6d227c7..c89808d 100644 --- a/backend/src/scalar_znx_dft_ops.rs +++ b/backend/src/scalar_znx_dft_ops.rs @@ -29,7 +29,7 @@ pub trait ScalarZnxDftOps { R: VecZnxDftToMut, A: ScalarZnxDftToRef; - fn svp_idft(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) + fn scalar_znx_idft(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) where R: ScalarZnxToMut, A: ScalarZnxDftToRef; @@ -50,7 +50,7 @@ impl ScalarZnxDftAlloc for Module { } impl ScalarZnxDftOps for Module { - fn svp_idft(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) + fn scalar_znx_idft(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) where R: ScalarZnxToMut, A: ScalarZnxDftToRef, diff --git a/backend/src/vec_znx.rs b/backend/src/vec_znx.rs index 84b9a84..74f9f86 100644 --- a/backend/src/vec_znx.rs +++ b/backend/src/vec_znx.rs @@ -111,6 +111,16 @@ impl + AsRef<[u8]>> VecZnx { } } + pub fn rotate(&mut self, k: i64) { + unsafe { + (0..self.cols()).for_each(|i| { + (0..self.size()).for_each(|j| { + znx::znx_rotate_inplace_i64(self.n() as u64, k, self.at_mut_ptr(i, j)); + }); + }) + } + } + pub fn rsh(&mut self, basek: usize, k: usize, scratch: &mut Scratch) { let n: usize = self.n(); let cols: usize = self.cols(); @@ -177,7 +187,7 @@ impl>> VecZnx { n * cols * size * size_of::() } - pub(crate) fn new(n: usize, cols: usize, size: usize) -> Self { + pub fn new(n: usize, cols: usize, size: usize) -> Self { let data = alloc_aligned::(Self::bytes_of::(n, cols, size)); Self { data: data.into(), @@ -243,7 +253,12 @@ fn normalize_tmp_bytes(n: usize) -> usize { n * std::mem::size_of::() } -#[allow(dead_code)] +impl + AsMut<[u8]>> VecZnx { + pub fn normalize(&mut self, basek: usize, a_col: usize, tmp_bytes: &mut [u8]) { + normalize(basek, self, a_col, tmp_bytes); + } +} + fn normalize + AsRef<[u8]>>(basek: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]) { let n: usize = a.n(); diff --git a/backend/src/vec_znx_dft_ops.rs b/backend/src/vec_znx_dft_ops.rs index 963de18..5892155 100644 --- a/backend/src/vec_znx_dft_ops.rs +++ b/backend/src/vec_znx_dft_ops.rs @@ -53,6 +53,22 @@ pub trait VecZnxDftOps { R: VecZnxDftToMut, A: VecZnxDftToRef; + fn vec_znx_dft_sub(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + D: VecZnxDftToRef; + + fn vec_znx_dft_sub_ab_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef; + + fn vec_znx_dft_sub_ba_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef; + fn vec_znx_dft_copy(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxDftToMut, @@ -150,6 +166,86 @@ impl VecZnxDftOps for Module { } } + fn vec_znx_dft_sub(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + D: VecZnxDftToRef, + { + let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); + let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); + let b_ref: VecZnxDft<&[u8], FFT64> = b.to_ref(); + + let min_size: usize = res_mut.size().min(a_ref.size()).min(b_ref.size()); + + unsafe { + (0..min_size).for_each(|j| { + vec_znx_dft::vec_dft_sub( + self.ptr, + res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, + 1, + a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + b_ref.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + ); + }); + } + (min_size..res_mut.size()).for_each(|j| { + res_mut.zero_at(res_col, j); + }) + } + + fn vec_znx_dft_sub_ab_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + { + let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); + let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); + + let min_size: usize = res_mut.size().min(a_ref.size()); + + unsafe { + (0..min_size).for_each(|j| { + vec_znx_dft::vec_dft_sub( + self.ptr, + res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, + 1, + res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + ); + }); + } + } + + fn vec_znx_dft_sub_ba_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + { + let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); + let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); + + let min_size: usize = res_mut.size().min(a_ref.size()); + + unsafe { + (0..min_size).for_each(|j| { + vec_znx_dft::vec_dft_sub( + self.ptr, + res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, + 1, + a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + ); + }); + } + } + fn vec_znx_dft_copy(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxDftToMut, diff --git a/backend/src/znx_base.rs b/backend/src/znx_base.rs index daa313d..fa7ec49 100644 --- a/backend/src/znx_base.rs +++ b/backend/src/znx_base.rs @@ -57,8 +57,8 @@ pub trait ZnxView: ZnxInfos + DataView> { fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar { #[cfg(debug_assertions)] { - assert!(i < self.cols()); - assert!(j < self.size()); + assert!(i < self.cols(), "{} >= {}", i, self.cols()); + assert!(j < self.size(), "{} >= {}", j, self.size()); } let offset: usize = self.n() * (j * self.cols() + i); unsafe { self.as_ptr().add(offset) } @@ -85,8 +85,8 @@ pub trait ZnxViewMut: ZnxView + DataViewMut> { fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar { #[cfg(debug_assertions)] { - assert!(i < self.cols()); - assert!(j < self.size()); + assert!(i < self.cols(), "{} >= {}", i, self.cols()); + assert!(j < self.size(), "{} >= {}", j, self.size()); } let offset: usize = self.n() * (j * self.cols() + i); unsafe { self.as_mut_ptr().add(offset) } diff --git a/core/benches/external_product_glwe_fft64.rs b/core/benches/external_product_glwe_fft64.rs index e81829c..fd6508a 100644 --- a/core/benches/external_product_glwe_fft64.rs +++ b/core/benches/external_product_glwe_fft64.rs @@ -1,5 +1,5 @@ use backend::{FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned}; -use core::{GGSWCiphertext, GLWECiphertext, GLWESecret, Infos}; +use core::{FourierGLWESecret, GGSWCiphertext, GLWECiphertext, GLWESecret, Infos}; use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; use sampling::source::Source; use std::hint::black_box; @@ -26,7 +26,7 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) { let rank: usize = p.rank; let digits: usize = 1; - let rows: usize = (p.k_ct_in + p.basek - 1) / p.basek; + let rows: usize = 1; //(p.k_ct_in + p.basek - 1) / p.basek; let sigma: f64 = 3.2; let mut ct_ggsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank); @@ -52,13 +52,14 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) { let mut source_xe = Source::new([0u8; 32]); let mut source_xa = Source::new([0u8; 32]); - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); ct_ggsw.encrypt_sk( &module, &pt_rgsw, - &sk, + &sk_dft, &mut source_xa, &mut source_xe, sigma, @@ -67,7 +68,7 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) { ct_glwe_in.encrypt_zero_sk( &module, - &sk, + &sk_dft, &mut source_xa, &mut source_xe, sigma, @@ -80,11 +81,11 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) { } let params_set: Vec = vec![Params { - log_n: 10, - basek: 7, - k_ct_in: 27, - k_ct_out: 27, - k_ggsw: 27, + log_n: 11, + basek: 22, + k_ct_in: 44, + k_ct_out: 44, + k_ggsw: 54, rank: 1, }]; @@ -134,13 +135,14 @@ fn bench_external_product_glwe_inplace_fft64(c: &mut Criterion) { let mut source_xe = Source::new([0u8; 32]); let mut source_xa = Source::new([0u8; 32]); - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); ct_ggsw.encrypt_sk( &module, &pt_rgsw, - &sk, + &sk_dft, &mut source_xa, &mut source_xe, sigma, @@ -149,7 +151,7 @@ fn bench_external_product_glwe_inplace_fft64(c: &mut Criterion) { ct_glwe.encrypt_zero_sk( &module, - &sk, + &sk_dft, &mut source_xa, &mut source_xe, sigma, diff --git a/core/benches/keyswitch_glwe_fft64.rs b/core/benches/keyswitch_glwe_fft64.rs index 66a1d02..4acc754 100644 --- a/core/benches/keyswitch_glwe_fft64.rs +++ b/core/benches/keyswitch_glwe_fft64.rs @@ -1,5 +1,5 @@ use backend::{FFT64, Module, ScratchOwned}; -use core::{AutomorphismKey, GLWECiphertext, GLWESecret, GLWESwitchingKey, Infos}; +use core::{FourierGLWESecret, GLWEAutomorphismKey, GLWECiphertext, GLWESecret, GLWESwitchingKey, Infos}; use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; use sampling::source::Source; use std::{hint::black_box, time::Duration}; @@ -32,12 +32,13 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { let rows: usize = (p.k_ct_in + (p.basek * digits) - 1) / (p.basek * digits); let sigma: f64 = 3.2; - let mut ksk: AutomorphismKey, FFT64> = AutomorphismKey::alloc(&module, basek, k_grlwe, rows, digits, rank_out); + let mut ksk: GLWEAutomorphismKey, FFT64> = + GLWEAutomorphismKey::alloc(&module, basek, k_grlwe, rows, digits, rank_out); let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_rlwe_in, rank_in); let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_rlwe_out, rank_out); let mut scratch = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ksk.k(), rank_out) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ksk.k(), rank_in, rank_out) | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_in.k()) | GLWECiphertext::keyswitch_scratch_space( &module, @@ -55,13 +56,14 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { let mut source_xe = Source::new([0u8; 32]); let mut source_xa = Source::new([0u8; 32]); - let mut sk_in: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_in); - sk_in.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk_in: GLWESecret> = GLWESecret::alloc(&module, rank_in); + sk_in.fill_ternary_prob(0.5, &mut source_xs); + let sk_in_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_in); - let mut sk_out: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_out); - sk_out.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk_out: GLWESecret> = GLWESecret::alloc(&module, rank_out); + sk_out.fill_ternary_prob(0.5, &mut source_xs); - ksk.generate_from_sk( + ksk.encrypt_sk( &module, -1, &sk_in, @@ -73,7 +75,7 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { ct_in.encrypt_zero_sk( &module, - &sk_in, + &sk_in_dft, &mut source_xa, &mut source_xe, sigma, @@ -137,7 +139,7 @@ fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) { let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); let mut scratch = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ksk.k(), rank) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ksk.k(), rank, rank) | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k()) | GLWECiphertext::keyswitch_inplace_scratch_space(&module, basek, ct.k(), ksk.k(), digits, rank), ); @@ -146,16 +148,18 @@ fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) { let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut sk_in: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk_in.fill_ternary_prob(&&module, 0.5, &mut source_xs); + let mut sk_in: GLWESecret> = GLWESecret::alloc(&module, rank); + sk_in.fill_ternary_prob(0.5, &mut source_xs); + let sk_in_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_in); - let mut sk_out: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk_out.fill_ternary_prob(&&module, 0.5, &mut source_xs); + let mut sk_out: GLWESecret> = GLWESecret::alloc(&module, rank); + sk_out.fill_ternary_prob(0.5, &mut source_xs); + let sk_out_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_out); - ksk.generate_from_sk( + ksk.encrypt_sk( &module, &sk_in, - &sk_out, + &sk_out_dft, &mut source_xa, &mut source_xe, sigma, @@ -164,7 +168,7 @@ fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) { ct.encrypt_zero_sk( &module, - &sk_in, + &sk_in_dft, &mut source_xa, &mut source_xe, sigma, diff --git a/core/src/automorphism.rs b/core/src/automorphism.rs index 27ea44a..3032532 100644 --- a/core/src/automorphism.rs +++ b/core/src/automorphism.rs @@ -2,7 +2,7 @@ use backend::{Backend, FFT64, MatZnxDft, MatZnxDftOps, Module, ScalarZnxOps, Scr use sampling::source::Source; use crate::{ - GGLWECiphertext, GGSWCiphertext, GLWECiphertext, GLWECiphertextFourier, GLWESecret, GLWESwitchingKey, GetRow, Infos, + FourierGLWECiphertext, GGLWECiphertext, GGSWCiphertext, GLWECiphertext, GLWESecret, GLWESwitchingKey, GetRow, Infos, ScratchCore, SetRow, }; @@ -68,9 +68,9 @@ impl> GetRow for AutomorphismKey { module: &Module, row_i: usize, col_j: usize, - res: &mut GLWECiphertextFourier, + res: &mut FourierGLWECiphertext, ) { - module.vmp_extract_row(&mut res.data, &self.key.0.data, row_i, col_j); + module.mat_znx_dft_get_row(&mut res.data, &self.key.0.data, row_i, col_j); } } @@ -80,9 +80,9 @@ impl + AsRef<[u8]>> SetRow for AutomorphismKey { module: &Module, row_i: usize, col_j: usize, - a: &GLWECiphertextFourier, + a: &FourierGLWECiphertext, ) { - module.vmp_prepare_row(&mut self.key.0.data, row_i, col_j, &a.data); + module.mat_znx_dft_set_row(&mut self.key.0.data, row_i, col_j, &a.data); } } @@ -127,8 +127,8 @@ impl AutomorphismKey, FFT64> { digits: usize, rank: usize, ) -> usize { - let tmp_dft: usize = GLWECiphertextFourier::bytes_of(module, basek, k_in, rank); - let tmp_idft: usize = GLWECiphertextFourier::bytes_of(module, basek, k_out, rank); + let tmp_dft: usize = FourierGLWECiphertext::bytes_of(module, basek, k_in, rank); + let tmp_idft: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank); let idft: usize = module.vec_znx_idft_tmp_bytes(); let keyswitch: usize = GLWECiphertext::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank); tmp_dft + tmp_idft + idft + keyswitch diff --git a/core/src/blind_rotation/cggi.rs b/core/src/blind_rotation/cggi.rs new file mode 100644 index 0000000..2cf1d88 --- /dev/null +++ b/core/src/blind_rotation/cggi.rs @@ -0,0 +1,463 @@ +use backend::{ + FFT64, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, + Scratch, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxView, + ZnxViewMut, ZnxZero, +}; +use itertools::izip; + +use crate::{ + GLWECiphertext, GLWECiphertextToMut, GLWEOps, Infos, LWECiphertext, ScratchCore, + blind_rotation::{key::BlindRotationKeyCGGI, lut::LookUpTable}, + dist::Distribution, + lwe::ciphertext::LWECiphertextToRef, +}; + +pub fn cggi_blind_rotate_scratch_space( + module: &Module, + block_size: usize, + extension_factor: usize, + basek: usize, + k_res: usize, + k_brk: usize, + rows: usize, + rank: usize, +) -> usize { + let brk_size: usize = k_brk.div_ceil(basek); + + if block_size > 1 { + let cols: usize = rank + 1; + let acc_dft: usize = module.bytes_of_vec_znx_dft(cols, rows) * extension_factor; + let acc_big: usize = module.bytes_of_vec_znx_big(1, brk_size); + let vmp_res: usize = module.bytes_of_vec_znx_dft(cols, brk_size) * extension_factor; + let acc_dft_add: usize = vmp_res; + let xai_plus_y: usize = module.bytes_of_scalar_znx_dft(1); + let xai_plus_y_dft: usize = module.bytes_of_scalar_znx_dft(1); + let vmp: usize = module.vmp_apply_tmp_bytes(brk_size, rows, rows, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2) + + let acc: usize; + if extension_factor > 1 { + acc = module.bytes_of_vec_znx(cols, k_res.div_ceil(basek)) * extension_factor; + } else { + acc = 0; + } + + return acc + + acc_dft + + acc_dft_add + + vmp_res + + xai_plus_y + + xai_plus_y_dft + + (vmp | (acc_big + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes()))); + } else { + 2 * GLWECiphertext::bytes_of(module, basek, k_res, rank) + + GLWECiphertext::external_product_scratch_space(module, basek, k_res, k_res, k_brk, 1, rank) + } +} + +pub fn cggi_blind_rotate( + module: &Module, + res: &mut GLWECiphertext, + lwe: &LWECiphertext, + lut: &LookUpTable, + brk: &BlindRotationKeyCGGI, + scratch: &mut Scratch, +) where + DataRes: AsRef<[u8]> + AsMut<[u8]>, + DataIn: AsRef<[u8]>, + DataBrk: AsRef<[u8]>, +{ + match brk.dist { + Distribution::BinaryBlock(_) | Distribution::BinaryFixed(_) | Distribution::BinaryProb(_) | Distribution::ZERO => { + if lut.extension_factor() > 1 { + cggi_blind_rotate_block_binary_extended(module, res, lwe, lut, brk, scratch); + } else if brk.block_size() > 1 { + cggi_blind_rotate_block_binary(module, res, lwe, lut, brk, scratch); + } else { + cggi_blind_rotate_binary_standard(module, res, lwe, lut, brk, scratch); + } + } + // TODO: ternary distribution ? + _ => panic!( + "invalid BlindRotationKeyCGGI distribution: must be BinaryBlock, BinaryFixed or BinaryProb (or ZERO for debugging)" + ), + } +} + +pub(crate) fn cggi_blind_rotate_block_binary_extended( + module: &Module, + res: &mut GLWECiphertext, + lwe: &LWECiphertext, + lut: &LookUpTable, + brk: &BlindRotationKeyCGGI, + scratch: &mut Scratch, +) where + DataRes: AsRef<[u8]> + AsMut<[u8]>, + DataIn: AsRef<[u8]>, + DataBrk: AsRef<[u8]>, +{ + let extension_factor: usize = lut.extension_factor(); + let basek: usize = res.basek(); + let rows: usize = brk.rows(); + let cols: usize = res.rank() + 1; + + let (mut acc, scratch1) = scratch.tmp_slice_vec_znx(extension_factor, module, cols, res.size()); + let (mut acc_dft, scratch2) = scratch1.tmp_slice_vec_znx_dft(extension_factor, module, cols, rows); + let (mut vmp_res, scratch3) = scratch2.tmp_slice_vec_znx_dft(extension_factor, module, cols, brk.size()); + let (mut acc_add_dft, scratch4) = scratch3.tmp_slice_vec_znx_dft(extension_factor, module, cols, brk.size()); + let (mut minus_one, scratch5) = scratch4.tmp_scalar_znx_dft(module, 1); + let (mut xai_plus_y_dft, scratch6) = scratch5.tmp_scalar_znx_dft(module, 1); + + minus_one.raw_mut()[..module.n() >> 1].fill(-1.0); + + (0..extension_factor).for_each(|i| { + acc[i].zero(); + }); + + let x_pow_a: &Vec, FFT64>>; + if let Some(b) = &brk.x_pow_a { + x_pow_a = b + } else { + panic!("invalid key: x_pow_a has not been initialized") + } + + let mut lwe_2n: Vec = vec![0i64; lwe.n() + 1]; // TODO: from scratch space + let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref(); + + let two_n: usize = 2 * module.n(); + let two_n_ext: usize = 2 * lut.domain_size(); + + negate_and_mod_switch_2n(two_n_ext, &mut lwe_2n, &lwe_ref); + + let a: &[i64] = &lwe_2n[1..]; + let b_pos: usize = ((lwe_2n[0] + two_n_ext as i64) & (two_n_ext - 1) as i64) as usize; + + let b_hi: usize = b_pos / extension_factor; + let b_lo: usize = b_pos & (extension_factor - 1); + + for (i, j) in (0..b_lo).zip(extension_factor - b_lo..extension_factor) { + module.vec_znx_rotate(b_hi as i64 + 1, &mut acc[i], 0, &lut.data[j], 0); + } + for (i, j) in (b_lo..extension_factor).zip(0..extension_factor - b_lo) { + module.vec_znx_rotate(b_hi as i64, &mut acc[i], 0, &lut.data[j], 0); + } + + let block_size: usize = brk.block_size(); + + izip!( + a.chunks_exact(block_size), + brk.data.chunks_exact(block_size) + ) + .for_each(|(ai, ski)| { + (0..extension_factor).for_each(|i| { + (0..cols).for_each(|j| { + module.vec_znx_dft(1, 0, &mut acc_dft[i], j, &acc[i], j); + }); + acc_add_dft[i].zero(); + }); + + // TODO: first & last iterations can be optimized + izip!(ai.iter(), ski.iter()).for_each(|(aii, skii)| { + let ai_pos: usize = ((aii + two_n_ext as i64) & (two_n_ext - 1) as i64) as usize; + let ai_hi: usize = ai_pos / extension_factor; + let ai_lo: usize = ai_pos & (extension_factor - 1); + + // vmp_res = DFT(acc) * BRK[i] + (0..extension_factor).for_each(|i| { + module.vmp_apply(&mut vmp_res[i], &acc_dft[i], &skii.data, scratch6); + }); + + // Trivial case: no rotation between polynomials, we can directly multiply with (X^{-ai} - 1) + if ai_lo == 0 { + // Sets acc_add_dft[i] = (acc[i] * sk) * (X^{-ai} - 1) + if ai_hi != 0 { + // DFT X^{-ai} + module.vec_znx_dft_add(&mut xai_plus_y_dft, 0, &x_pow_a[ai_hi], 0, &minus_one, 0); + (0..extension_factor).for_each(|j| { + (0..cols).for_each(|i| { + module.svp_apply_inplace(&mut vmp_res[j], i, &xai_plus_y_dft, 0); + module.vec_znx_dft_add_inplace(&mut acc_add_dft[j], i, &vmp_res[j], i); + }); + }); + } + + // Non trivial case: rotation between polynomials + // In this case we can't directly multiply with (X^{-ai} - 1) because of the + // ring homomorphism R^{N} -> prod R^{N/extension_factor}, so we split the + // computation in two steps: acc_add_dft = (acc * sk) * (-1) + (acc * sk) * X^{-ai} + } else { + // Sets acc_add_dft[i] = acc[i] * sk + + // Sets acc_add_dft[0..ai_lo] -= acc[..ai_lo] * sk + if (ai_hi + 1) & (two_n - 1) != 0 { + for i in 0..ai_lo { + (0..cols).for_each(|k| { + module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[i], k, &vmp_res[i], k); + }); + } + } + + // Sets acc_add_dft[ai_lo..extension_factor] -= acc[ai_lo..extension_factor] * sk + if ai_hi != 0 { + for i in ai_lo..extension_factor { + (0..cols).for_each(|k: usize| { + module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[i], k, &vmp_res[i], k); + }); + } + } + + // Sets acc_add_dft[0..ai_lo] += (acc[extension_factor - ai_lo..extension_factor] * sk) * X^{-ai+1} + if (ai_hi + 1) & (two_n - 1) != 0 { + for (i, j) in (0..ai_lo).zip(extension_factor - ai_lo..extension_factor) { + (0..cols).for_each(|k| { + module.svp_apply_inplace(&mut vmp_res[j], k, &x_pow_a[ai_hi + 1], 0); + module.vec_znx_dft_add_inplace(&mut acc_add_dft[i], k, &vmp_res[j], k); + }); + } + } + + // Sets acc_add_dft[ai_lo..extension_factor] += (acc[0..extension_factor - ai_lo] * sk) * X^{-ai} + if ai_hi != 0 { + // Sets acc_add_dft[ai_lo..extension_factor] += (acc[0..extension_factor - ai_lo] * sk) * X^{-ai} + for (i, j) in (ai_lo..extension_factor).zip(0..extension_factor - ai_lo) { + (0..cols).for_each(|k| { + module.svp_apply_inplace(&mut vmp_res[j], k, &x_pow_a[ai_hi], 0); + module.vec_znx_dft_add_inplace(&mut acc_add_dft[i], k, &vmp_res[j], k); + }); + } + } + } + }); + + { + let (mut acc_add_big, scratch7) = scratch6.tmp_vec_znx_big(module, 1, brk.size()); + + (0..extension_factor).for_each(|j| { + (0..cols).for_each(|i| { + module.vec_znx_idft(&mut acc_add_big, 0, &acc_add_dft[j], i, scratch7); + module.vec_znx_big_add_small_inplace(&mut acc_add_big, 0, &acc[j], i); + module.vec_znx_big_normalize(basek, &mut acc[j], i, &acc_add_big, 0, scratch7); + }); + }); + } + }); + + (0..cols).for_each(|i| { + module.vec_znx_copy(&mut res.data, i, &acc[0], i); + }); +} + +pub(crate) fn cggi_blind_rotate_block_binary( + module: &Module, + res: &mut GLWECiphertext, + lwe: &LWECiphertext, + lut: &LookUpTable, + brk: &BlindRotationKeyCGGI, + scratch: &mut Scratch, +) where + DataRes: AsRef<[u8]> + AsMut<[u8]>, + DataIn: AsRef<[u8]>, + DataBrk: AsRef<[u8]>, +{ + let mut lwe_2n: Vec = vec![0i64; lwe.n() + 1]; // TODO: from scratch space + let mut out_mut: GLWECiphertext<&mut [u8]> = res.to_mut(); + let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref(); + let two_n: usize = module.n() << 1; + let basek: usize = brk.basek(); + let rows = brk.rows(); + + let cols: usize = out_mut.rank() + 1; + + negate_and_mod_switch_2n(2 * lut.domain_size(), &mut lwe_2n, &lwe_ref); + + let a: &[i64] = &lwe_2n[1..]; + let b: i64 = lwe_2n[0]; + + out_mut.data.zero(); + + // Initialize out to X^{b} * LUT(X) + module.vec_znx_rotate(b, &mut out_mut.data, 0, &lut.data[0], 0); + + let block_size: usize = brk.block_size(); + + // ACC + [sum DFT(X^ai -1) * (DFT(ACC) x BRKi)] + + let (mut acc_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rows); + let (mut vmp_res, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, brk.size()); + let (mut acc_add_dft, scratch3) = scratch2.tmp_vec_znx_dft(module, cols, brk.size()); + let (mut minus_one, scratch4) = scratch3.tmp_scalar_znx_dft(module, 1); + let (mut xai_plus_y_dft, scratch5) = scratch4.tmp_scalar_znx_dft(module, 1); + + minus_one.raw_mut()[..module.n() >> 1].fill(-1.0); + + let x_pow_a: &Vec, FFT64>>; + if let Some(b) = &brk.x_pow_a { + x_pow_a = b + } else { + panic!("invalid key: x_pow_a has not been initialized") + } + + izip!( + a.chunks_exact(block_size), + brk.data.chunks_exact(block_size) + ) + .for_each(|(ai, ski)| { + (0..cols).for_each(|j| { + module.vec_znx_dft(1, 0, &mut acc_dft, j, &out_mut.data, j); + }); + + acc_add_dft.zero(); + + izip!(ai.iter(), ski.iter()).for_each(|(aii, skii)| { + let ai_pos: usize = ((aii + two_n as i64) & (two_n - 1) as i64) as usize; + + // vmp_res = DFT(acc) * BRK[i] + module.vmp_apply(&mut vmp_res, &acc_dft, &skii.data, scratch5); + + // DFT(X^ai -1) + module.vec_znx_dft_add(&mut xai_plus_y_dft, 0, &x_pow_a[ai_pos], 0, &minus_one, 0); + + // DFT(X^ai -1) * (DFT(acc) * BRK[i]) + (0..cols).for_each(|i| { + module.svp_apply_inplace(&mut vmp_res, i, &xai_plus_y_dft, 0); + module.vec_znx_dft_add_inplace(&mut acc_add_dft, i, &vmp_res, i); + }); + }); + + { + let (mut acc_add_big, scratch6) = scratch5.tmp_vec_znx_big(module, 1, brk.size()); + + (0..cols).for_each(|i| { + module.vec_znx_idft(&mut acc_add_big, 0, &acc_add_dft, i, scratch6); + module.vec_znx_big_add_small_inplace(&mut acc_add_big, 0, &out_mut.data, i); + module.vec_znx_big_normalize(basek, &mut out_mut.data, i, &acc_add_big, 0, scratch6); + }); + } + }); +} + +pub(crate) fn cggi_blind_rotate_binary_standard( + module: &Module, + res: &mut GLWECiphertext, + lwe: &LWECiphertext, + lut: &LookUpTable, + brk: &BlindRotationKeyCGGI, + scratch: &mut Scratch, +) where + DataRes: AsRef<[u8]> + AsMut<[u8]>, + DataIn: AsRef<[u8]>, + DataBrk: AsRef<[u8]>, +{ + #[cfg(debug_assertions)] + { + assert_eq!( + res.n(), + module.n(), + "res.n(): {} != brk.n(): {}", + res.n(), + module.n() + ); + assert_eq!( + lut.domain_size(), + module.n(), + "lut.n(): {} != brk.n(): {}", + lut.domain_size(), + module.n() + ); + assert_eq!( + brk.n(), + module.n(), + "brk.n(): {} != brk.n(): {}", + brk.n(), + module.n() + ); + assert_eq!( + res.rank(), + brk.rank(), + "res.rank(): {} != brk.rank(): {}", + res.rank(), + brk.rank() + ); + assert_eq!( + lwe.n(), + brk.data.len(), + "lwe.n(): {} != brk.data.len(): {}", + lwe.n(), + brk.data.len() + ); + } + + let mut lwe_2n: Vec = vec![0i64; lwe.n() + 1]; // TODO: from scratch space + let mut out_mut: GLWECiphertext<&mut [u8]> = res.to_mut(); + let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref(); + let basek: usize = brk.basek(); + + negate_and_mod_switch_2n(2 * lut.domain_size(), &mut lwe_2n, &lwe_ref); + + let a: &[i64] = &lwe_2n[1..]; + let b: i64 = lwe_2n[0]; + + out_mut.data.zero(); + + // Initialize out to X^{b} * LUT(X) + module.vec_znx_rotate(b, &mut out_mut.data, 0, &lut.data[0], 0); + + // ACC + [sum DFT(X^ai -1) * (DFT(ACC) x BRKi)] + let (mut acc_tmp, scratch1) = scratch.tmp_glwe_ct(module, basek, out_mut.k(), out_mut.rank()); + let (mut acc_tmp_rot, scratch2) = scratch1.tmp_glwe_ct(module, basek, out_mut.k(), out_mut.rank()); + + // TODO: see if faster by skipping normalization in external product and keeping acc in big coeffs + // TODO: first iteration can be optimized to be a gglwe product + izip!(a.iter(), brk.data.iter()).for_each(|(ai, ski)| { + // acc_tmp = sk[i] * acc + acc_tmp.external_product(module, &out_mut, ski, scratch2); + + // acc_tmp = (sk[i] * acc) * X^{ai} + acc_tmp_rot.rotate(module, *ai, &acc_tmp); + + // acc = acc + (sk[i] * acc) * X^{ai} + out_mut.add_inplace(module, &acc_tmp_rot); + + // acc = acc + (sk[i] * acc) * X^{ai} - (sk[i] * acc) = acc + (sk[i] * acc) * (X^{ai} - 1) + out_mut.sub_inplace_ab(module, &acc_tmp); + }); + + // We can normalize only at the end because we add normalized values in [-2^{basek-1}, 2^{basek-1}] + // on top of each others, thus ~ 2^{63-basek} additions are supported before overflow. + out_mut.normalize_inplace(module, scratch2); +} + +pub(crate) fn negate_and_mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWECiphertext<&[u8]>) { + let basek: usize = lwe.basek(); + + let log2n: usize = usize::BITS as usize - (n - 1).leading_zeros() as usize + 1; + + res.copy_from_slice(&lwe.data.at(0, 0)); + res.iter_mut().for_each(|x| *x = -*x); + + if basek > log2n { + let diff: usize = basek - log2n; + res.iter_mut().for_each(|x| { + *x = div_round_by_pow2(x, diff); + }) + } else { + let rem: usize = basek - (log2n % basek); + let size: usize = log2n.div_ceil(basek); + (1..size).for_each(|i| { + if i == size - 1 && rem != basek { + let k_rem: usize = basek - rem; + izip!(lwe.data.at(0, i).iter(), res.iter_mut()).for_each(|(x, y)| { + *y = (*y << k_rem) + (x >> rem); + }); + } else { + izip!(lwe.data.at(0, i).iter(), res.iter_mut()).for_each(|(x, y)| { + *y = (*y << basek) + x; + }); + } + }) + } +} + +#[inline(always)] +fn div_round_by_pow2(x: &i64, k: usize) -> i64 { + (x + (1 << (k - 1))) >> k +} diff --git a/core/src/blind_rotation/key.rs b/core/src/blind_rotation/key.rs new file mode 100644 index 0000000..01511c3 --- /dev/null +++ b/core/src/blind_rotation/key.rs @@ -0,0 +1,159 @@ +use backend::{ + Backend, FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxToRef, Scratch, + ZnxView, ZnxViewMut, +}; +use sampling::source::Source; + +use crate::{Distribution, FourierGLWESecret, GGSWCiphertext, Infos, LWESecret}; + +pub struct BlindRotationKeyCGGI { + pub(crate) data: Vec>, + pub(crate) dist: Distribution, + pub(crate) x_pow_a: Option, B>>>, +} + +// pub struct BlindRotationKeyFHEW { +// pub(crate) data: Vec, B>>, +// pub(crate) auto: Vec, B>>, +//} + +impl BlindRotationKeyCGGI, FFT64> { + pub fn allocate(module: &Module, n_lwe: usize, basek: usize, k: usize, rows: usize, rank: usize) -> Self { + let mut data: Vec, FFT64>> = Vec::with_capacity(n_lwe); + (0..n_lwe).for_each(|_| data.push(GGSWCiphertext::alloc(module, basek, k, rows, 1, rank))); + Self { + data, + dist: Distribution::NONE, + x_pow_a: None::, FFT64>>>, + } + } + + pub fn generate_from_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { + GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k, rank) + } +} + +impl> BlindRotationKeyCGGI { + #[allow(dead_code)] + pub(crate) fn n(&self) -> usize { + self.data[0].n() + } + + #[allow(dead_code)] + pub(crate) fn rows(&self) -> usize { + self.data[0].rows() + } + + #[allow(dead_code)] + pub(crate) fn k(&self) -> usize { + self.data[0].k() + } + + #[allow(dead_code)] + pub(crate) fn size(&self) -> usize { + self.data[0].size() + } + + #[allow(dead_code)] + pub(crate) fn rank(&self) -> usize { + self.data[0].rank() + } + + pub(crate) fn basek(&self) -> usize { + self.data[0].basek() + } + + pub(crate) fn block_size(&self) -> usize { + match self.dist { + Distribution::BinaryBlock(value) => value, + _ => 1, + } + } +} + +impl + AsMut<[u8]>> BlindRotationKeyCGGI { + pub fn generate_from_sk( + &mut self, + module: &Module, + sk_glwe: &FourierGLWESecret, + sk_lwe: &LWESecret, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) where + DataSkGLWE: AsRef<[u8]>, + DataSkLWE: AsRef<[u8]>, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.data.len(), sk_lwe.n()); + assert_eq!(sk_glwe.n(), module.n()); + assert_eq!(sk_glwe.rank(), self.data[0].rank()); + match sk_lwe.dist { + Distribution::BinaryBlock(_) + | Distribution::BinaryFixed(_) + | Distribution::BinaryProb(_) + | Distribution::ZERO => {} + _ => panic!( + "invalid GLWESecret distribution: must be BinaryBlock, BinaryFixed or BinaryProb (or ZERO for debugging)" + ), + } + } + + self.dist = sk_lwe.dist; + + let mut pt: ScalarZnx> = module.new_scalar_znx(1); + let sk_ref: ScalarZnx<&[u8]> = sk_lwe.data.to_ref(); + + self.data.iter_mut().enumerate().for_each(|(i, ggsw)| { + pt.at_mut(0, 0)[0] = sk_ref.at(0, 0)[i]; + ggsw.encrypt_sk(module, &pt, sk_glwe, source_xa, source_xe, sigma, scratch); + }); + + match sk_lwe.dist { + Distribution::BinaryBlock(_) => { + let mut x_pow_a: Vec, FFT64>> = Vec::with_capacity(module.n() << 1); + let mut buf: ScalarZnx> = module.new_scalar_znx(1); + (0..module.n() << 1).for_each(|i| { + let mut res: ScalarZnxDft, FFT64> = module.new_scalar_znx_dft(1); + set_xai_plus_y(module, i, 0, &mut res, &mut buf); + x_pow_a.push(res); + }); + self.x_pow_a = Some(x_pow_a); + } + _ => {} + } + } +} + +pub fn set_xai_plus_y(module: &Module, ai: usize, y: i64, res: &mut ScalarZnxDft, buf: &mut ScalarZnx) +where + A: AsRef<[u8]> + AsMut<[u8]>, + B: AsRef<[u8]> + AsMut<[u8]>, +{ + let n: usize = module.n(); + + { + let raw: &mut [i64] = buf.at_mut(0, 0); + if ai < n { + raw[ai] = 1; + } else { + raw[(ai - n) & (n - 1)] = -1; + } + raw[0] += y; + } + + module.svp_prepare(res, 0, buf, 0); + + { + let raw: &mut [i64] = buf.at_mut(0, 0); + + if ai < n { + raw[ai] = 0; + } else { + raw[(ai - n) & (n - 1)] = 0; + } + raw[0] = 0; + } +} diff --git a/core/src/blind_rotation/lut.rs b/core/src/blind_rotation/lut.rs new file mode 100644 index 0000000..300aa6e --- /dev/null +++ b/core/src/blind_rotation/lut.rs @@ -0,0 +1,126 @@ +use backend::{FFT64, Module, VecZnx, VecZnxAlloc, VecZnxOps, ZnxInfos, ZnxViewMut, alloc_aligned}; + +pub struct LookUpTable { + pub(crate) data: Vec>>, + pub(crate) basek: usize, + pub(crate) k: usize, +} + +impl LookUpTable { + pub fn alloc(module: &Module, basek: usize, k: usize, extension_factor: usize) -> Self { + #[cfg(debug_assertions)] + { + assert!( + extension_factor & (extension_factor - 1) == 0, + "extension_factor must be a power of two but is: {}", + extension_factor + ); + } + let size: usize = k.div_ceil(basek); + let mut data: Vec>> = Vec::with_capacity(extension_factor); + (0..extension_factor).for_each(|_| { + data.push(module.new_vec_znx(1, size)); + }); + Self { data, basek, k } + } + + pub fn log_extension_factor(&self) -> usize { + (usize::BITS - (self.extension_factor() - 1).leading_zeros()) as _ + } + + pub fn extension_factor(&self) -> usize { + self.data.len() + } + + pub fn domain_size(&self) -> usize { + self.data.len() * self.data[0].n() + } + + pub fn set(&mut self, module: &Module, f: &Vec, k: usize) { + assert!(f.len() <= module.n()); + + let basek: usize = self.basek; + + // Get the number minimum limb to store the message modulus + let limbs: usize = k.div_ceil(1 << basek); + + #[cfg(debug_assertions)] + { + assert!(limbs <= self.data[0].size()); + } + + // Scaling factor + let scale: i64 = 1 << (k % basek) as i64; + + // #elements in lookup table + let f_len: usize = f.len(); + + // If LUT size > module.n() + let domain_size: usize = self.domain_size(); + + let size: usize = self.k.div_ceil(self.basek); + + // Equivalent to AUTO([f(0), -f(n-1), -f(n-2), ..., -f(1)], -1) + let mut lut_full: VecZnx> = VecZnx::new::(domain_size, 1, size); + + let lut_at: &mut [i64] = lut_full.at_mut(0, limbs - 1); + + f.iter().enumerate().for_each(|(i, fi)| { + let start: usize = (i * domain_size).div_round(f_len); + let end: usize = ((i + 1) * domain_size).div_round(f_len); + lut_at[start..end].fill(fi * scale); + }); + + // Rotates half the step to the left + let half_step: usize = domain_size.div_round(f_len << 1); + + lut_full.rotate(-(half_step as i64)); + + let mut tmp_bytes: Vec = alloc_aligned(lut_full.n() * size_of::()); + lut_full.normalize(self.basek, 0, &mut tmp_bytes); + + if self.extension_factor() > 1 { + (0..self.extension_factor()).for_each(|i| { + module.switch_degree(&mut self.data[i], 0, &lut_full, 0); + if i < self.extension_factor() { + lut_full.rotate(-1); + } + }); + } else { + module.vec_znx_copy(&mut self.data[0], 0, &lut_full, 0); + } + } + + #[allow(dead_code)] + pub(crate) fn rotate(&mut self, k: i64) { + let extension_factor: usize = self.extension_factor(); + let two_n: usize = 2 * self.data[0].n(); + let two_n_ext: usize = two_n * extension_factor; + + let k_pos: usize = ((k + two_n_ext as i64) % two_n_ext as i64) as usize; + + let k_hi: usize = k_pos / extension_factor; + let k_lo: usize = k_pos % extension_factor; + + (0..extension_factor - k_lo).for_each(|i| { + self.data[i].rotate(k_hi as i64); + }); + + (extension_factor - k_lo..extension_factor).for_each(|i| { + self.data[i].rotate(k_hi as i64 + 1); + }); + + self.data.rotate_right(k_lo as usize); + } +} + +pub(crate) trait DivRound { + fn div_round(self, rhs: Self) -> Self; +} + +impl DivRound for usize { + #[inline] + fn div_round(self, rhs: Self) -> Self { + (self + rhs / 2) / rhs + } +} diff --git a/core/src/blind_rotation/mod.rs b/core/src/blind_rotation/mod.rs new file mode 100644 index 0000000..bbbdd2c --- /dev/null +++ b/core/src/blind_rotation/mod.rs @@ -0,0 +1,10 @@ +pub mod cggi; +pub mod key; +pub mod lut; + +pub use cggi::{cggi_blind_rotate, cggi_blind_rotate_scratch_space}; +pub use key::BlindRotationKeyCGGI; +pub use lut::LookUpTable; + +#[cfg(test)] +pub mod test_fft64; diff --git a/core/src/blind_rotation/test_fft64/cggi.rs b/core/src/blind_rotation/test_fft64/cggi.rs new file mode 100644 index 0000000..2fbad48 --- /dev/null +++ b/core/src/blind_rotation/test_fft64/cggi.rs @@ -0,0 +1,125 @@ +use backend::{Encoding, FFT64, Module, ScratchOwned, ZnxView}; +use sampling::source::Source; + +use crate::{ + FourierGLWESecret, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, LWECiphertext, LWESecret, + blind_rotation::{ + cggi::{cggi_blind_rotate, cggi_blind_rotate_scratch_space, negate_and_mod_switch_2n}, + key::BlindRotationKeyCGGI, + lut::LookUpTable, + }, + lwe::{LWEPlaintext, ciphertext::LWECiphertextToRef}, +}; + +#[test] +fn standard() { + blind_rotatio_test(224, 1, 1); +} + +#[test] +fn block_binary() { + blind_rotatio_test(224, 7, 1); +} + +#[test] +fn block_binary_extended() { + blind_rotatio_test(224, 7, 2); +} + +fn blind_rotatio_test(n_lwe: usize, block_size: usize, extension_factor: usize) { + let module: Module = Module::::new(512); + let basek: usize = 19; + + let k_lwe: usize = 24; + let k_brk: usize = 3 * basek; + let rows_brk: usize = 2; // Ensures first limb is noise-free. + let k_lut: usize = 1 * basek; + let k_res: usize = 2 * basek; + let rank: usize = 1; + + let message_modulus: usize = 1 << 4; + + let mut source_xs: Source = Source::new([2u8; 32]); + let mut source_xe: Source = Source::new([2u8; 32]); + let mut source_xa: Source = Source::new([1u8; 32]); + + let mut sk_glwe: GLWESecret> = GLWESecret::alloc(&module, rank); + sk_glwe.fill_ternary_prob(0.5, &mut source_xs); + let sk_glwe_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_glwe); + + let mut sk_lwe: LWESecret> = LWESecret::alloc(n_lwe); + sk_lwe.fill_binary_block(block_size, &mut source_xs); + + let mut scratch: ScratchOwned = ScratchOwned::new(BlindRotationKeyCGGI::generate_from_sk_scratch_space( + &module, basek, k_brk, rank, + )); + + let mut scratch_br: ScratchOwned = ScratchOwned::new(cggi_blind_rotate_scratch_space( + &module, + block_size, + extension_factor, + basek, + k_res, + k_brk, + rows_brk, + rank, + )); + + let mut brk: BlindRotationKeyCGGI, FFT64> = + BlindRotationKeyCGGI::allocate(&module, n_lwe, basek, k_brk, rows_brk, rank); + + brk.generate_from_sk( + &module, + &sk_glwe_dft, + &sk_lwe, + &mut source_xa, + &mut source_xe, + 3.2, + scratch.borrow(), + ); + + let mut lwe: LWECiphertext> = LWECiphertext::alloc(n_lwe, basek, k_lwe); + + let mut pt_lwe: LWEPlaintext> = LWEPlaintext::alloc(basek, k_lwe); + + let x: i64 = 2; + let bits: usize = 8; + + pt_lwe.data.encode_coeff_i64(0, basek, bits, 0, x, bits); + + lwe.encrypt_sk(&pt_lwe, &sk_lwe, &mut source_xa, &mut source_xe, 3.2); + + let mut f: Vec = vec![0i64; message_modulus]; + f.iter_mut() + .enumerate() + .for_each(|(i, x)| *x = 2 * (i as i64) + 1); + + let mut lut: LookUpTable = LookUpTable::alloc(&module, basek, k_lut, extension_factor); + lut.set(&module, &f, message_modulus); + + let mut res: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_res, rank); + + cggi_blind_rotate(&module, &mut res, &lwe, &lut, &brk, scratch_br.borrow()); + + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_res); + + res.decrypt(&module, &mut pt_have, &sk_glwe_dft, scratch.borrow()); + + let mut lwe_2n: Vec = vec![0i64; lwe.n() + 1]; // TODO: from scratch space + + negate_and_mod_switch_2n(2 * lut.domain_size(), &mut lwe_2n, &lwe.to_ref()); + + let pt_want: i64 = (lwe_2n[0] + + lwe_2n[1..] + .iter() + .zip(sk_lwe.data.at(0, 0)) + .map(|(x, y)| x * y) + .sum::()) + & (2 * lut.domain_size() - 1) as i64; + + lut.rotate(pt_want); + + // First limb should be exactly equal (test are parameterized such that the noise does not reach + // the first limb) + assert_eq!(pt_have.data.at(0, 0), lut.data[0].at(0, 0)); +} diff --git a/core/src/blind_rotation/test_fft64/lut.rs b/core/src/blind_rotation/test_fft64/lut.rs new file mode 100644 index 0000000..02f710d --- /dev/null +++ b/core/src/blind_rotation/test_fft64/lut.rs @@ -0,0 +1,73 @@ +use std::vec; + +use backend::{FFT64, Module, ZnxView}; + +use crate::blind_rotation::lut::{DivRound, LookUpTable}; + +#[test] +fn standard() { + let module: Module = Module::::new(32); + let basek: usize = 20; + let k_lut: usize = 40; + let message_modulus: usize = 16; + let extension_factor: usize = 1; + + let log_scale: usize = basek + 1; + + let mut f: Vec = vec![0i64; message_modulus]; + f.iter_mut() + .enumerate() + .for_each(|(i, x)| *x = (i as i64) - 8); + + let mut lut: LookUpTable = LookUpTable::alloc(&module, basek, k_lut, extension_factor); + lut.set(&module, &f, log_scale); + + let half_step: i64 = lut.domain_size().div_round(message_modulus << 1) as i64; + lut.rotate(half_step); + + let step: usize = lut.domain_size().div_round(message_modulus); + + (0..lut.domain_size()).step_by(step).for_each(|i| { + (0..step).for_each(|_| { + assert_eq!( + f[i / step] % message_modulus as i64, + lut.data[0].raw()[0] / (1 << (log_scale % basek)) as i64 + ); + lut.rotate(-1); + }); + }); +} + +#[test] +fn extended() { + let module: Module = Module::::new(32); + let basek: usize = 20; + let k_lut: usize = 40; + let message_modulus: usize = 16; + let extension_factor: usize = 4; + + let log_scale: usize = basek + 1; + + let mut f: Vec = vec![0i64; message_modulus]; + f.iter_mut() + .enumerate() + .for_each(|(i, x)| *x = (i as i64) - 8); + + let mut lut: LookUpTable = LookUpTable::alloc(&module, basek, k_lut, extension_factor); + lut.set(&module, &f, log_scale); + + let half_step: i64 = lut.domain_size().div_round(message_modulus << 1) as i64; + lut.rotate(half_step); + + let step: usize = lut.domain_size().div_round(message_modulus); + + (0..lut.domain_size()).step_by(step).for_each(|i| { + (0..step).for_each(|_| { + assert_eq!( + f[i / step] % message_modulus as i64, + lut.data[0].raw()[0] / (1 << (log_scale % basek)) as i64 + ); + lut.rotate(-1); + }); + }); +} diff --git a/core/src/blind_rotation/test_fft64/mod.rs b/core/src/blind_rotation/test_fft64/mod.rs new file mode 100644 index 0000000..18ac93c --- /dev/null +++ b/core/src/blind_rotation/test_fft64/mod.rs @@ -0,0 +1,2 @@ +pub mod cggi; +pub mod lut; diff --git a/core/src/dist.rs b/core/src/dist.rs new file mode 100644 index 0000000..4a97369 --- /dev/null +++ b/core/src/dist.rs @@ -0,0 +1,10 @@ +#[derive(Clone, Copy, Debug)] +pub(crate) enum Distribution { + TernaryFixed(usize), // Ternary with fixed Hamming weight + TernaryProb(f64), // Ternary with probabilistic Hamming weight + BinaryFixed(usize), // Binary with fixed Hamming weight + BinaryProb(f64), // Binary with probabilistic Hamming weight + BinaryBlock(usize), // Binary split in block of size 2^k + ZERO, // Debug mod + NONE, // Unitialized +} diff --git a/core/src/elem.rs b/core/src/elem.rs index ae7e5f7..9a1de39 100644 --- a/core/src/elem.rs +++ b/core/src/elem.rs @@ -1,6 +1,6 @@ use backend::{Backend, Module, ZnxInfos}; -use crate::GLWECiphertextFourier; +use crate::FourierGLWECiphertext; pub trait Infos { type Inner: ZnxInfos; @@ -56,13 +56,13 @@ pub trait SetMetaData { } pub trait GetRow { - fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut GLWECiphertextFourier) + fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut FourierGLWECiphertext) where R: AsMut<[u8]> + AsRef<[u8]>; } pub trait SetRow { - fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &GLWECiphertextFourier) + fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &FourierGLWECiphertext) where R: AsRef<[u8]>; } diff --git a/core/src/fourier_glwe/ciphertext.rs b/core/src/fourier_glwe/ciphertext.rs new file mode 100644 index 0000000..a742e31 --- /dev/null +++ b/core/src/fourier_glwe/ciphertext.rs @@ -0,0 +1,45 @@ +use backend::{Backend, Module, VecZnxDft, VecZnxDftAlloc}; + +use crate::Infos; + +pub struct FourierGLWECiphertext { + pub data: VecZnxDft, + pub basek: usize, + pub k: usize, +} + +impl FourierGLWECiphertext, B> { + pub fn alloc(module: &Module, basek: usize, k: usize, rank: usize) -> Self { + Self { + data: module.new_vec_znx_dft(rank + 1, k.div_ceil(basek)), + basek: basek, + k: k, + } + } + + pub fn bytes_of(module: &Module, basek: usize, k: usize, rank: usize) -> usize { + module.bytes_of_vec_znx_dft(rank + 1, k.div_ceil(basek)) + } +} + +impl Infos for FourierGLWECiphertext { + type Inner = VecZnxDft; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn basek(&self) -> usize { + self.basek + } + + fn k(&self) -> usize { + self.k + } +} + +impl FourierGLWECiphertext { + pub fn rank(&self) -> usize { + self.cols() - 1 + } +} diff --git a/core/src/fourier_glwe/decryption.rs b/core/src/fourier_glwe/decryption.rs new file mode 100644 index 0000000..6c18383 --- /dev/null +++ b/core/src/fourier_glwe/decryption.rs @@ -0,0 +1,84 @@ +use backend::{ + FFT64, Module, ScalarZnxDftOps, Scratch, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDftAlloc, + VecZnxDftOps, ZnxZero, +}; + +use crate::{FourierGLWECiphertext, FourierGLWESecret, GLWECiphertext, GLWEPlaintext, Infos}; + +impl FourierGLWECiphertext, FFT64> { + pub fn decrypt_scratch_space(module: &Module, basek: usize, k: usize) -> usize { + let size: usize = k.div_ceil(basek); + (module.vec_znx_big_normalize_tmp_bytes() + | module.bytes_of_vec_znx_dft(1, size) + | (module.bytes_of_vec_znx_big(1, size) + module.vec_znx_idft_tmp_bytes())) + + module.bytes_of_vec_znx_big(1, size) + } +} + +impl> FourierGLWECiphertext { + pub fn decrypt + AsMut<[u8]>, DataSk: AsRef<[u8]>>( + &self, + module: &Module, + pt: &mut GLWEPlaintext, + sk: &FourierGLWESecret, + scratch: &mut Scratch, + ) { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), sk.rank()); + assert_eq!(self.n(), module.n()); + assert_eq!(pt.n(), module.n()); + assert_eq!(sk.n(), module.n()); + } + + let cols = self.rank() + 1; + + let (mut pt_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, self.size()); // TODO optimize size when pt << ct + pt_big.zero(); + + { + (1..cols).for_each(|i| { + let (mut ci_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, self.size()); // TODO optimize size when pt << ct + module.svp_apply(&mut ci_dft, 0, &sk.data, i - 1, &self.data, i); + let ci_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(ci_dft); + module.vec_znx_big_add_inplace(&mut pt_big, 0, &ci_big, 0); + }); + } + + { + let (mut c0_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, self.size()); + // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e) + module.vec_znx_idft(&mut c0_big, 0, &self.data, 0, scratch_2); + module.vec_znx_big_add_inplace(&mut pt_big, 0, &c0_big, 0); + } + + // pt = norm(BIG(m + e)) + module.vec_znx_big_normalize(self.basek(), &mut pt.data, 0, &mut pt_big, 0, scratch_1); + + pt.basek = self.basek(); + pt.k = pt.k().min(self.k()); + } + + #[allow(dead_code)] + pub(crate) fn idft + AsMut<[u8]>>( + &self, + module: &Module, + res: &mut GLWECiphertext, + scratch: &mut Scratch, + ) { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), res.rank()); + assert_eq!(self.basek(), res.basek()) + } + + let min_size: usize = self.size().min(res.size()); + + let (mut res_big, scratch1) = scratch.tmp_vec_znx_big(module, 1, min_size); + + (0..self.rank() + 1).for_each(|i| { + module.vec_znx_idft(&mut res_big, 0, &self.data, i, scratch1); + module.vec_znx_big_normalize(self.basek(), &mut res.data, i, &res_big, 0, scratch1); + }); + } +} diff --git a/core/src/fourier_glwe/encryption.rs b/core/src/fourier_glwe/encryption.rs new file mode 100644 index 0000000..fd08709 --- /dev/null +++ b/core/src/fourier_glwe/encryption.rs @@ -0,0 +1,32 @@ +use backend::{FFT64, Module, Scratch, VecZnxAlloc, VecZnxBigScratch, VecZnxDftOps}; +use sampling::source::Source; + +use crate::{FourierGLWECiphertext, FourierGLWESecret, GLWECiphertext, Infos, ScratchCore}; + +impl FourierGLWECiphertext, FFT64> { + #[allow(dead_code)] + pub(crate) fn idft_scratch_space(module: &Module, basek: usize, k: usize) -> usize { + module.bytes_of_vec_znx(1, k.div_ceil(basek)) + + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes()) + } + + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { + module.bytes_of_vec_znx(rank + 1, k.div_ceil(basek)) + GLWECiphertext::encrypt_sk_scratch_space(module, basek, k) + } +} + +impl + AsRef<[u8]>> FourierGLWECiphertext { + pub fn encrypt_zero_sk>( + &mut self, + module: &Module, + sk: &FourierGLWESecret, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) { + let (mut tmp_ct, scratch1) = scratch.tmp_glwe_ct(module, self.basek(), self.k(), self.rank()); + tmp_ct.encrypt_zero_sk(module, sk, source_xa, source_xe, sigma, scratch1); + tmp_ct.dft(module, self); + } +} diff --git a/core/src/fourier_glwe/external_product.rs b/core/src/fourier_glwe/external_product.rs new file mode 100644 index 0000000..01a7371 --- /dev/null +++ b/core/src/fourier_glwe/external_product.rs @@ -0,0 +1,129 @@ +use backend::{ + FFT64, MatZnxDftOps, MatZnxDftScratch, Module, Scratch, VecZnxAlloc, VecZnxBig, VecZnxBigOps, VecZnxBigScratch, + VecZnxDftAlloc, VecZnxDftOps, +}; + +use crate::{FourierGLWECiphertext, GGSWCiphertext, Infos}; + +impl FourierGLWECiphertext, FFT64> { + // WARNING TODO: UPDATE + pub fn external_product_scratch_space( + module: &Module, + basek: usize, + _k_out: usize, + k_in: usize, + k_ggsw: usize, + digits: usize, + rank: usize, + ) -> usize { + let ggsw_size: usize = k_ggsw.div_ceil(basek); + let res_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, ggsw_size); + let in_size: usize = k_in.div_ceil(basek).div_ceil(digits); + let ggsw_size: usize = k_ggsw.div_ceil(basek); + let vmp: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size) + + module.vmp_apply_tmp_bytes(ggsw_size, in_size, in_size, rank + 1, rank + 1, ggsw_size); + let res_small: usize = module.bytes_of_vec_znx(rank + 1, ggsw_size); + let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(); + res_dft + (vmp | (res_small + normalize)) + } + + pub fn external_product_inplace_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_ggsw: usize, + digits: usize, + rank: usize, + ) -> usize { + Self::external_product_scratch_space(module, basek, k_out, k_out, k_ggsw, digits, rank) + } +} + +impl + AsRef<[u8]>> FourierGLWECiphertext { + pub fn external_product, DataRhs: AsRef<[u8]>>( + &mut self, + module: &Module, + lhs: &FourierGLWECiphertext, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) { + let basek: usize = self.basek(); + + #[cfg(debug_assertions)] + { + assert_eq!(rhs.rank(), lhs.rank()); + assert_eq!(rhs.rank(), self.rank()); + assert_eq!(self.basek(), basek); + assert_eq!(lhs.basek(), basek); + assert_eq!(rhs.n(), module.n()); + assert_eq!(self.n(), module.n()); + assert_eq!(lhs.n(), module.n()); + assert!( + scratch.available() + >= FourierGLWECiphertext::external_product_scratch_space( + module, + self.basek(), + self.k(), + lhs.k(), + rhs.k(), + rhs.digits(), + rhs.rank(), + ) + ); + } + + let cols: usize = rhs.rank() + 1; + let digits = rhs.digits(); + + // Space for VMP result in DFT domain and high precision + let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); + let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, (lhs.size() + digits - 1) / digits); + + { + (0..digits).for_each(|di| { + a_dft.set_size((lhs.size() + di) / digits); + + // Small optimization for digits > 2 + // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then + // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}. + // As such we can ignore the last digits-2 limbs safely of the sum of vmp products. + // It is possible to further ignore the last digits-1 limbs, but this introduce + // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same + // noise is kept with respect to the ideal functionality. + res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize); + + (0..cols).for_each(|col_i| { + module.vec_znx_dft_copy(digits, digits - 1 - di, &mut a_dft, col_i, &lhs.data, col_i); + }); + + if di == 0 { + module.vmp_apply(&mut res_dft, &a_dft, &rhs.data, scratch2); + } else { + module.vmp_apply_add(&mut res_dft, &a_dft, &rhs.data, di, scratch2); + } + }); + } + + // VMP result in high precision + let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft); + + // Space for VMP result normalized + let (mut res_small, scratch2) = scratch1.tmp_vec_znx(module, cols, rhs.size()); + (0..cols).for_each(|i| { + module.vec_znx_big_normalize(basek, &mut res_small, i, &res_big, i, scratch2); + module.vec_znx_dft(1, 0, &mut self.data, i, &res_small, i); + }); + } + + pub fn external_product_inplace>( + &mut self, + module: &Module, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) { + unsafe { + let self_ptr: *mut FourierGLWECiphertext = self as *mut FourierGLWECiphertext; + self.external_product(&module, &*self_ptr, rhs, scratch); + } + } +} diff --git a/core/src/fourier_glwe/keyswitch.rs b/core/src/fourier_glwe/keyswitch.rs new file mode 100644 index 0000000..3abb26e --- /dev/null +++ b/core/src/fourier_glwe/keyswitch.rs @@ -0,0 +1,56 @@ +use backend::{FFT64, Module, Scratch}; + +use crate::{FourierGLWECiphertext, GLWECiphertext, GLWESwitchingKey, Infos, ScratchCore}; + +impl FourierGLWECiphertext, FFT64> { + pub fn keyswitch_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_in: usize, + k_ksk: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + ) -> usize { + GLWECiphertext::bytes_of(module, basek, k_out, rank_out) + + GLWECiphertext::keyswitch_from_fourier_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank_in, rank_out) + } + + pub fn keyswitch_inplace_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_ksk: usize, + digits: usize, + rank: usize, + ) -> usize { + Self::keyswitch_scratch_space(module, basek, k_out, k_out, k_ksk, digits, rank, rank) + } +} + +impl + AsRef<[u8]>> FourierGLWECiphertext { + pub fn keyswitch, DataRhs: AsRef<[u8]>>( + &mut self, + module: &Module, + lhs: &FourierGLWECiphertext, + rhs: &GLWESwitchingKey, + scratch: &mut Scratch, + ) { + let (mut tmp_ct, scratch1) = scratch.tmp_glwe_ct(module, self.basek(), self.k(), self.rank()); + tmp_ct.keyswitch_from_fourier(module, lhs, rhs, scratch1); + tmp_ct.dft(module, self); + } + + pub fn keyswitch_inplace>( + &mut self, + module: &Module, + rhs: &GLWESwitchingKey, + scratch: &mut Scratch, + ) { + unsafe { + let self_ptr: *mut FourierGLWECiphertext = self as *mut FourierGLWECiphertext; + self.keyswitch(&module, &*self_ptr, rhs, scratch); + } + } +} diff --git a/core/src/fourier_glwe/mod.rs b/core/src/fourier_glwe/mod.rs new file mode 100644 index 0000000..35c9905 --- /dev/null +++ b/core/src/fourier_glwe/mod.rs @@ -0,0 +1,12 @@ +pub mod ciphertext; +pub mod decryption; +pub mod encryption; +pub mod external_product; +pub mod keyswitch; +pub mod secret; + +pub use ciphertext::FourierGLWECiphertext; +pub use secret::FourierGLWESecret; + +#[cfg(test)] +pub mod test_fft64; diff --git a/core/src/fourier_glwe/secret.rs b/core/src/fourier_glwe/secret.rs new file mode 100644 index 0000000..0f28939 --- /dev/null +++ b/core/src/fourier_glwe/secret.rs @@ -0,0 +1,58 @@ +use backend::{Backend, FFT64, Module, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ZnxInfos}; + +use crate::{GLWESecret, dist::Distribution}; + +pub struct FourierGLWESecret { + pub(crate) data: ScalarZnxDft, + pub(crate) dist: Distribution, +} + +impl FourierGLWESecret, B> { + pub fn alloc(module: &Module, rank: usize) -> Self { + Self { + data: module.new_scalar_znx_dft(rank), + dist: Distribution::NONE, + } + } + + pub fn bytes_of(module: &Module, rank: usize) -> usize { + module.bytes_of_scalar_znx_dft(rank) + } +} + +impl FourierGLWESecret, FFT64> { + pub fn from(module: &Module, sk: &GLWESecret) -> Self + where + D: AsRef<[u8]>, + { + let mut sk_dft: FourierGLWESecret, FFT64> = Self::alloc(module, sk.rank()); + sk_dft.set(module, sk); + sk_dft + } +} + +impl FourierGLWESecret { + pub fn n(&self) -> usize { + self.data.n() + } + + pub fn log_n(&self) -> usize { + self.data.log_n() + } + + pub fn rank(&self) -> usize { + self.data.cols() + } +} + +impl + AsRef<[u8]>> FourierGLWESecret { + pub(crate) fn set(&mut self, module: &Module, sk: &GLWESecret) + where + D: AsRef<[u8]>, + { + (0..self.rank()).for_each(|i| { + module.svp_prepare(&mut self.data, i, &sk.data, i); + }); + self.dist = sk.dist + } +} diff --git a/core/src/fourier_glwe/test_fft64/external_product.rs b/core/src/fourier_glwe/test_fft64/external_product.rs new file mode 100644 index 0000000..80c9c9a --- /dev/null +++ b/core/src/fourier_glwe/test_fft64/external_product.rs @@ -0,0 +1,246 @@ +use crate::{ + FourierGLWECiphertext, FourierGLWESecret, GGSWCiphertext, GLWECiphertext, GLWEOps, GLWEPlaintext, GLWESecret, Infos, + noise::noise_ggsw_product, +}; +use backend::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, ZnxViewMut}; +use sampling::source::Source; + +#[test] +fn apply() { + let log_n: usize = 8; + let basek: usize = 12; + let k_in: usize = 45; + let digits: usize = k_in.div_ceil(basek); + (1..4).for_each(|rank| { + (1..digits + 1).for_each(|di| { + let k_ggsw: usize = k_in + basek * di; + println!("test external_product digits: {} rank: {}", di, rank); + let k_out: usize = k_ggsw; // Better capture noise. + test_apply(log_n, basek, k_out, k_in, k_ggsw, di, rank, 3.2); + }); + }); +} + +#[test] +fn apply_inplace() { + let log_n: usize = 8; + let basek: usize = 12; + let k_ct: usize = 60; + let digits: usize = k_ct.div_ceil(basek); + (1..4).for_each(|rank| { + (1..digits + 1).for_each(|di| { + let k_ggsw: usize = k_ct + basek * di; + println!("test external_product digits: {} rank: {}", di, rank); + test_apply_inplace(log_n, basek, k_ct, k_ggsw, di, rank, 3.2); + }); + }); +} + +fn test_apply(log_n: usize, basek: usize, k_out: usize, k_in: usize, k_ggsw: usize, digits: usize, rank: usize, sigma: f64) { + let module: Module = Module::::new(1 << log_n); + + let rows: usize = k_in.div_ceil(digits * basek); + + let mut ct_ggsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank); + let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_in, rank); + let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_out, rank); + let mut ct_in_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_in, rank); + let mut ct_out_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_out, rank); + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + + pt_want.data.at_mut(0, 0)[1] = 1; + + let k: i64 = 1; + + pt_rgsw.raw_mut()[0] = 1; // X^{0} + module.vec_znx_rotate_inplace(k, &mut pt_rgsw, 0); // X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::new( + GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, ct_ggsw.k(), rank) + | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_out.k()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_in.k()) + | FourierGLWECiphertext::external_product_scratch_space( + &module, + basek, + ct_out.k(), + ct_in.k(), + ct_ggsw.k(), + digits, + rank, + ), + ); + + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); + + ct_ggsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_in.encrypt_sk( + &module, + &pt_want, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_in.dft(&module, &mut ct_in_dft); + ct_out_dft.external_product(&module, &ct_in_dft, &ct_ggsw, scratch.borrow()); + ct_out_dft.idft(&module, &mut ct_out, scratch.borrow()); + + ct_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + pt_want.rotate_inplace(&module, k); + pt_have.sub_inplace_ab(&module, &pt_want); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_ggsw_product( + module.n() as f64, + basek * digits, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + rank as f64, + k_in, + k_ggsw, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.5, + "{} {}", + noise_have, + noise_want + ); +} + +fn test_apply_inplace(log_n: usize, basek: usize, k_ct: usize, k_ggsw: usize, digits: usize, rank: usize, sigma: f64) { + let module: Module = Module::::new(1 << log_n); + let rows: usize = k_ct.div_ceil(digits * basek); + + let mut ct_ggsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank); + let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); + let mut ct_rlwe_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_ct, rank); + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + + pt_want.data.at_mut(0, 0)[1] = 1; + + let k: i64 = 1; + + pt_rgsw.raw_mut()[0] = 1; // X^{0} + module.vec_znx_rotate_inplace(k, &mut pt_rgsw, 0); // X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::new( + GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, ct_ggsw.k(), rank) + | GLWECiphertext::decrypt_scratch_space(&module, basek, ct.k()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k()) + | FourierGLWECiphertext::external_product_inplace_scratch_space(&module, basek, ct.k(), ct_ggsw.k(), digits, rank), + ); + + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); + + ct_ggsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct.encrypt_sk( + &module, + &pt_want, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct.dft(&module, &mut ct_rlwe_dft); + ct_rlwe_dft.external_product_inplace(&module, &ct_ggsw, scratch.borrow()); + ct_rlwe_dft.idft(&module, &mut ct, scratch.borrow()); + + ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + pt_want.rotate_inplace(&module, k); + pt_have.sub_inplace_ab(&module, &pt_want); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_ggsw_product( + module.n() as f64, + basek * digits, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + rank as f64, + k_ct, + k_ggsw, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.5, + "{} {}", + noise_have, + noise_want + ); + + println!("{} {}", noise_have, noise_want); +} diff --git a/core/src/fourier_glwe/test_fft64/keyswitch.rs b/core/src/fourier_glwe/test_fft64/keyswitch.rs new file mode 100644 index 0000000..61c8b27 --- /dev/null +++ b/core/src/fourier_glwe/test_fft64/keyswitch.rs @@ -0,0 +1,235 @@ +use crate::{ + FourierGLWECiphertext, FourierGLWESecret, GLWECiphertext, GLWEPlaintext, GLWESecret, GLWESwitchingKey, Infos, + noise::log2_std_noise_gglwe_product, +}; +use backend::{FFT64, FillUniform, Module, ScratchOwned, Stats, VecZnxOps}; +use sampling::source::Source; + +#[test] +fn apply() { + let log_n: usize = 8; + let basek: usize = 12; + let k_in: usize = 45; + let digits: usize = k_in.div_ceil(basek); + (1..4).for_each(|rank_in| { + (1..4).for_each(|rank_out| { + (1..digits + 1).for_each(|di| { + let k_ksk: usize = k_in + basek * di; + println!( + "test keyswitch digits: {} rank_in: {} rank_out: {}", + di, rank_in, rank_out + ); + let k_out: usize = k_ksk; // Better capture noise. + test_apply(log_n, basek, k_in, k_out, k_ksk, di, rank_in, rank_out, 3.2); + }) + }); + }); +} + +#[test] +fn apply_inplace() { + let log_n: usize = 8; + let basek: usize = 12; + let k_ct: usize = 45; + let digits: usize = k_ct.div_ceil(basek); + (1..4).for_each(|rank| { + (1..digits + 1).for_each(|di| { + let k_ksk: usize = k_ct + basek * di; + println!("test keyswitch_inplace digits: {} rank: {}", di, rank); + test_apply_inplace(log_n, basek, k_ct, k_ksk, di, rank, 3.2); + }); + }); +} + +fn test_apply( + log_n: usize, + basek: usize, + k_in: usize, + k_out: usize, + k_ksk: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + sigma: f64, +) { + let module: Module = Module::::new(1 << log_n); + + let rows: usize = k_in.div_ceil(basek * digits); + + let mut ksk: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank_in, rank_out); + let mut ct_glwe_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_in, rank_in); + let mut ct_glwe_dft_in: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_in, rank_in); + let mut ct_glwe_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_out, rank_out); + let mut ct_glwe_dft_out: FourierGLWECiphertext, FFT64> = + FourierGLWECiphertext::alloc(&module, basek, k_out, rank_out); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank_in, rank_out) + | GLWECiphertext::decrypt_scratch_space(&module, basek, k_out) + | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, k_in) + | FourierGLWECiphertext::keyswitch_scratch_space( + &module, + basek, + ct_glwe_out.k(), + ksk.k(), + ct_glwe_in.k(), + digits, + rank_in, + rank_out, + ), + ); + + let mut sk_in: GLWESecret> = GLWESecret::alloc(&module, rank_in); + sk_in.fill_ternary_prob(0.5, &mut source_xs); + let sk_in_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_in); + + let mut sk_out: GLWESecret> = GLWESecret::alloc(&module, rank_out); + sk_out.fill_ternary_prob(0.5, &mut source_xs); + let sk_out_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_out); + + ksk.encrypt_sk( + &module, + &sk_in, + &sk_out_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_glwe_in.encrypt_sk( + &module, + &pt_want, + &sk_in_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_glwe_in.dft(&module, &mut ct_glwe_dft_in); + ct_glwe_dft_out.keyswitch(&module, &ct_glwe_dft_in, &ksk, scratch.borrow()); + ct_glwe_dft_out.idft(&module, &mut ct_glwe_out, scratch.borrow()); + + ct_glwe_out.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + let noise_want: f64 = log2_std_noise_gglwe_product( + module.n() as f64, + basek * digits, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + rank_in as f64, + k_in, + k_ksk, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.5, + "{} {}", + noise_have, + noise_want + ); +} + +fn test_apply_inplace(log_n: usize, basek: usize, k_ct: usize, k_ksk: usize, digits: usize, rank: usize, sigma: f64) { + let module: Module = Module::::new(1 << log_n); + + let rows: usize = k_ct.div_ceil(basek * digits); + + let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank, rank); + let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); + let mut ct_rlwe_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_ct, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ksk.k(), rank, rank) + | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_glwe.k()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_glwe.k()) + | FourierGLWECiphertext::keyswitch_inplace_scratch_space(&module, basek, ct_rlwe_dft.k(), ksk.k(), digits, rank), + ); + + let mut sk_in: GLWESecret> = GLWESecret::alloc(&module, rank); + sk_in.fill_ternary_prob(0.5, &mut source_xs); + let sk_in_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_in); + + let mut sk_out: GLWESecret> = GLWESecret::alloc(&module, rank); + sk_out.fill_ternary_prob(0.5, &mut source_xs); + let sk_out_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_out); + + ksk.encrypt_sk( + &module, + &sk_in, + &sk_out_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_glwe.encrypt_sk( + &module, + &pt_want, + &sk_in_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_glwe.dft(&module, &mut ct_rlwe_dft); + ct_rlwe_dft.keyswitch_inplace(&module, &ksk, scratch.borrow()); + ct_rlwe_dft.idft(&module, &mut ct_glwe, scratch.borrow()); + + ct_glwe.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + let noise_want: f64 = log2_std_noise_gglwe_product( + module.n() as f64, + basek * digits, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + rank as f64, + k_ct, + k_ksk, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.5, + "{} {}", + noise_have, + noise_want + ); +} diff --git a/core/src/fourier_glwe/test_fft64/mod.rs b/core/src/fourier_glwe/test_fft64/mod.rs new file mode 100644 index 0000000..784c37c --- /dev/null +++ b/core/src/fourier_glwe/test_fft64/mod.rs @@ -0,0 +1,2 @@ +pub mod external_product; +pub mod keyswitch; diff --git a/core/src/gglwe/automorphism.rs b/core/src/gglwe/automorphism.rs new file mode 100644 index 0000000..07fcb14 --- /dev/null +++ b/core/src/gglwe/automorphism.rs @@ -0,0 +1,150 @@ +use backend::{FFT64, Module, Scratch, VecZnx, VecZnxDftOps, VecZnxOps, ZnxZero}; + +use crate::{FourierGLWECiphertext, GLWEAutomorphismKey, GLWECiphertext, GetRow, Infos, ScratchCore, SetRow}; + +impl GLWEAutomorphismKey, FFT64> { + pub fn automorphism_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_in: usize, + k_ksk: usize, + digits: usize, + rank: usize, + ) -> usize { + let tmp_dft: usize = FourierGLWECiphertext::bytes_of(module, basek, k_in, rank); + let tmp_idft: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank); + let idft: usize = module.vec_znx_idft_tmp_bytes(); + let keyswitch: usize = GLWECiphertext::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank); + tmp_dft + tmp_idft + idft + keyswitch + } + + pub fn automorphism_inplace_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_ksk: usize, + digits: usize, + rank: usize, + ) -> usize { + GLWEAutomorphismKey::automorphism_scratch_space(module, basek, k_out, k_out, k_ksk, digits, rank) + } +} + +impl + AsRef<[u8]>> GLWEAutomorphismKey { + pub fn automorphism, DataRhs: AsRef<[u8]>>( + &mut self, + module: &Module, + lhs: &GLWEAutomorphismKey, + rhs: &GLWEAutomorphismKey, + scratch: &mut Scratch, + ) { + #[cfg(debug_assertions)] + { + assert_eq!( + self.rank_in(), + lhs.rank_in(), + "ksk_out input rank: {} != ksk_in input rank: {}", + self.rank_in(), + lhs.rank_in() + ); + assert_eq!( + lhs.rank_out(), + rhs.rank_in(), + "ksk_in output rank: {} != ksk_apply input rank: {}", + self.rank_out(), + rhs.rank_in() + ); + assert_eq!( + self.rank_out(), + rhs.rank_out(), + "ksk_out output rank: {} != ksk_apply output rank: {}", + self.rank_out(), + rhs.rank_out() + ); + assert!( + self.k() <= lhs.k(), + "output k={} cannot be greater than input k={}", + self.k(), + lhs.k() + ) + } + + let cols_out: usize = rhs.rank_out() + 1; + + (0..self.rank_in()).for_each(|col_i| { + (0..self.rows()).for_each(|row_j| { + let (mut tmp_idft_data, scratct1) = scratch.tmp_vec_znx_big(module, cols_out, self.size()); + + { + let (mut tmp_dft, scratch2) = scratct1.tmp_fourier_glwe_ct(module, lhs.basek(), lhs.k(), lhs.rank()); + + // Extracts relevant row + lhs.get_row(module, row_j, col_i, &mut tmp_dft); + + // Get a VecZnxBig from scratch space + + // Switches input outside of DFT + (0..cols_out).for_each(|i| { + module.vec_znx_idft(&mut tmp_idft_data, i, &tmp_dft.data, i, scratch2); + }); + } + + // Consumes to small vec znx + let mut tmp_idft_small_data: VecZnx<&mut [u8]> = tmp_idft_data.to_vec_znx_small(); + + // Reverts the automorphis key from (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a) + (0..cols_out).for_each(|i| { + module.vec_znx_automorphism_inplace(lhs.p(), &mut tmp_idft_small_data, i); + }); + + // Wraps into ciphertext + let mut tmp_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { + data: tmp_idft_small_data, + basek: self.basek(), + k: self.k(), + }; + + // Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) + tmp_idft.keyswitch_inplace(module, &rhs.key, scratct1); + + { + let (mut tmp_dft, _) = scratct1.tmp_fourier_glwe_ct(module, self.basek(), self.k(), self.rank()); + + // Applies back the automorphism X^{k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) -> (-pi^{-1}_{k'+k}(s)a + s, a) + // and switches back to DFT domain + (0..self.rank_out() + 1).for_each(|i| { + module.vec_znx_automorphism_inplace(lhs.p(), &mut tmp_idft.data, i); + module.vec_znx_dft(1, 0, &mut tmp_dft.data, i, &tmp_idft.data, i); + }); + + // Sets back the relevant row + self.set_row(module, row_j, col_i, &tmp_dft); + } + }); + }); + + let (mut tmp_dft, _) = scratch.tmp_fourier_glwe_ct(module, self.basek(), self.k(), self.rank()); + tmp_dft.data.zero(); + + (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| { + (0..self.rank_in()).for_each(|col_j| { + self.set_row(module, row_i, col_j, &tmp_dft); + }); + }); + + self.p = (lhs.p * rhs.p) % (module.cyclotomic_order() as i64); + } + + pub fn automorphism_inplace>( + &mut self, + module: &Module, + rhs: &GLWEAutomorphismKey, + scratch: &mut Scratch, + ) { + unsafe { + let self_ptr: *mut GLWEAutomorphismKey = self as *mut GLWEAutomorphismKey; + self.automorphism(&module, &*self_ptr, rhs, scratch); + } + } +} diff --git a/core/src/gglwe/automorphism_key.rs b/core/src/gglwe/automorphism_key.rs new file mode 100644 index 0000000..38a7f6e --- /dev/null +++ b/core/src/gglwe/automorphism_key.rs @@ -0,0 +1,83 @@ +use backend::{Backend, FFT64, MatZnxDft, MatZnxDftOps, Module}; + +use crate::{FourierGLWECiphertext, GLWESwitchingKey, GetRow, Infos, SetRow}; + +pub struct GLWEAutomorphismKey { + pub(crate) key: GLWESwitchingKey, + pub(crate) p: i64, +} + +impl GLWEAutomorphismKey, FFT64> { + pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self { + GLWEAutomorphismKey { + key: GLWESwitchingKey::alloc(module, basek, k, rows, digits, rank, rank), + p: 0, + } + } + + pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize { + GLWESwitchingKey::, FFT64>::bytes_of(module, basek, k, rows, digits, rank, rank) + } +} + +impl Infos for GLWEAutomorphismKey { + type Inner = MatZnxDft; + + fn inner(&self) -> &Self::Inner { + &self.key.inner() + } + + fn basek(&self) -> usize { + self.key.basek() + } + + fn k(&self) -> usize { + self.key.k() + } +} + +impl GLWEAutomorphismKey { + pub fn p(&self) -> i64 { + self.p + } + + pub fn digits(&self) -> usize { + self.key.digits() + } + + pub fn rank(&self) -> usize { + self.key.rank() + } + + pub fn rank_in(&self) -> usize { + self.key.rank_in() + } + + pub fn rank_out(&self) -> usize { + self.key.rank_out() + } +} + +impl> GetRow for GLWEAutomorphismKey { + fn get_row + AsRef<[u8]>>( + &self, + module: &Module, + row_i: usize, + col_j: usize, + res: &mut FourierGLWECiphertext, + ) { + module.mat_znx_dft_get_row(&mut res.data, &self.key.key.data, row_i, col_j); + } +} + +impl + AsRef<[u8]>> SetRow for GLWEAutomorphismKey { + fn set_row>( + &mut self, + module: &Module, + row_i: usize, + col_j: usize, + a: &FourierGLWECiphertext, + ) { + module.mat_znx_dft_set_row(&mut self.key.key.data, row_i, col_j, &a.data); + } +} diff --git a/core/src/gglwe/ciphertext.rs b/core/src/gglwe/ciphertext.rs new file mode 100644 index 0000000..340b897 --- /dev/null +++ b/core/src/gglwe/ciphertext.rs @@ -0,0 +1,131 @@ +use backend::{Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, Module}; + +use crate::{FourierGLWECiphertext, GetRow, Infos, SetRow}; + +pub struct GGLWECiphertext { + pub(crate) data: MatZnxDft, + pub(crate) basek: usize, + pub(crate) k: usize, + pub(crate) digits: usize, +} + +impl GGLWECiphertext, B> { + pub fn alloc( + module: &Module, + basek: usize, + k: usize, + rows: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + ) -> Self { + let size: usize = k.div_ceil(basek); + debug_assert!( + size > digits, + "invalid gglwe: ceil(k/basek): {} <= digits: {}", + size, + digits + ); + + assert!( + rows * digits <= size, + "invalid gglwe: rows: {} * digits:{} > ceil(k/basek): {}", + rows, + digits, + size + ); + + Self { + data: module.new_mat_znx_dft(rows, rank_in, rank_out + 1, size), + basek: basek, + k, + digits, + } + } + + pub fn bytes_of( + module: &Module, + basek: usize, + k: usize, + rows: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + ) -> usize { + let size: usize = k.div_ceil(basek); + debug_assert!( + size > digits, + "invalid gglwe: ceil(k/basek): {} <= digits: {}", + size, + digits + ); + + assert!( + rows * digits <= size, + "invalid gglwe: rows: {} * digits:{} > ceil(k/basek): {}", + rows, + digits, + size + ); + + module.bytes_of_mat_znx_dft(rows, rank_in, rank_out + 1, rows) + } +} + +impl Infos for GGLWECiphertext { + type Inner = MatZnxDft; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn basek(&self) -> usize { + self.basek + } + + fn k(&self) -> usize { + self.k + } +} + +impl GGLWECiphertext { + pub fn rank(&self) -> usize { + self.data.cols_out() - 1 + } + + pub fn digits(&self) -> usize { + self.digits + } + + pub fn rank_in(&self) -> usize { + self.data.cols_in() + } + + pub fn rank_out(&self) -> usize { + self.data.cols_out() - 1 + } +} + +impl> GetRow for GGLWECiphertext { + fn get_row + AsRef<[u8]>>( + &self, + module: &Module, + row_i: usize, + col_j: usize, + res: &mut FourierGLWECiphertext, + ) { + module.mat_znx_dft_get_row(&mut res.data, &self.data, row_i, col_j); + } +} + +impl + AsRef<[u8]>> SetRow for GGLWECiphertext { + fn set_row>( + &mut self, + module: &Module, + row_i: usize, + col_j: usize, + a: &FourierGLWECiphertext, + ) { + module.mat_znx_dft_set_row(&mut self.data, row_i, col_j, &a.data); + } +} diff --git a/core/src/gglwe/encryption.rs b/core/src/gglwe/encryption.rs new file mode 100644 index 0000000..3e0a7f4 --- /dev/null +++ b/core/src/gglwe/encryption.rs @@ -0,0 +1,302 @@ +use backend::{ + FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScalarZnxOps, Scratch, VecZnxAlloc, VecZnxDftAlloc, VecZnxOps, + ZnxInfos, ZnxView, ZnxViewMut, ZnxZero, +}; +use sampling::source::Source; + +use crate::{ + FourierGLWESecret, GGLWECiphertext, GLWEAutomorphismKey, GLWECiphertext, GLWESecret, GLWESwitchingKey, GLWETensorKey, Infos, + ScratchCore, SetRow, +}; + +impl GGLWECiphertext, FFT64> { + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { + let size = k.div_ceil(basek); + GLWECiphertext::encrypt_sk_scratch_space(module, basek, k) + + module.bytes_of_vec_znx(rank + 1, size) + + module.bytes_of_vec_znx(1, size) + + module.bytes_of_vec_znx_dft(rank + 1, size) + } + + pub fn encrypt_pk_scratch_space(_module: &Module, _basek: usize, _k: usize, _rank: usize) -> usize { + unimplemented!() + } +} + +impl + AsRef<[u8]>> GGLWECiphertext { + pub fn encrypt_sk, DataSk: AsRef<[u8]>>( + &mut self, + module: &Module, + pt: &ScalarZnx, + sk: &FourierGLWESecret, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) { + #[cfg(debug_assertions)] + { + assert_eq!( + self.rank_in(), + pt.cols(), + "self.rank_in(): {} != pt.cols(): {}", + self.rank_in(), + pt.cols() + ); + assert_eq!( + self.rank_out(), + sk.rank(), + "self.rank_out(): {} != sk.rank(): {}", + self.rank_out(), + sk.rank() + ); + assert_eq!(self.n(), module.n()); + assert_eq!(sk.n(), module.n()); + assert_eq!(pt.n(), module.n()); + assert!( + scratch.available() >= GGLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k(), self.rank()), + "scratch.available: {} < GGLWECiphertext::encrypt_sk_scratch_space(module, self.rank()={}, self.size()={}): {}", + scratch.available(), + self.rank(), + self.size(), + GGLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k(), self.rank()) + ); + assert!( + self.rows() * self.digits() * self.basek() <= self.k(), + "self.rows() : {} * self.digits() : {} * self.basek() : {} = {} >= self.k() = {}", + self.rows(), + self.digits(), + self.basek(), + self.rows() * self.digits() * self.basek(), + self.k() + ); + } + + let rows: usize = self.rows(); + let digits: usize = self.digits(); + let basek: usize = self.basek(); + let k: usize = self.k(); + let rank_in: usize = self.rank_in(); + let rank_out: usize = self.rank_out(); + + let (mut tmp_pt, scrach_1) = scratch.tmp_glwe_pt(module, basek, k); + let (mut tmp_ct, scrach_2) = scrach_1.tmp_glwe_ct(module, basek, k, rank_out); + let (mut tmp_ct_dft, scratch_3) = scrach_2.tmp_fourier_glwe_ct(module, basek, k, rank_out); + + // For each input column (i.e. rank) produces a GGLWE ciphertext of rank_out+1 columns + // + // Example for ksk rank 2 to rank 3: + // + // (-(a0*s0 + a1*s1 + a2*s2) + s0', a0, a1, a2) + // (-(b0*s0 + b1*s1 + b2*s2) + s0', b0, b1, b2) + // + // Example ksk rank 2 to rank 1 + // + // (-(a*s) + s0, a) + // (-(b*s) + s1, b) + (0..rank_in).for_each(|col_i| { + (0..rows).for_each(|row_i| { + // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt + tmp_pt.data.zero(); // zeroes for next iteration + module.vec_znx_add_scalar_inplace( + &mut tmp_pt.data, + 0, + (digits - 1) + row_i * digits, + pt, + col_i, + ); + module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scratch_3); + + // rlwe encrypt of vec_znx_pt into vec_znx_ct + tmp_ct.encrypt_sk(module, &tmp_pt, sk, source_xa, source_xe, sigma, scratch_3); + + // Switch vec_znx_ct into DFT domain + tmp_ct.dft(module, &mut tmp_ct_dft); + + // Stores vec_znx_dft_ct into thw i-th row of the MatZnxDft + self.set_row(module, row_i, col_i, &tmp_ct_dft); + }); + }); + } +} + +impl GLWESwitchingKey, FFT64> { + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank_in: usize, rank_out: usize) -> usize { + GGLWECiphertext::encrypt_sk_scratch_space(module, basek, k, rank_out) + + module.bytes_of_scalar_znx(rank_in) + + FourierGLWESecret::bytes_of(module, rank_out) + } + + pub fn encrypt_pk_scratch_space( + module: &Module, + _basek: usize, + _k: usize, + _rank_in: usize, + _rank_out: usize, + ) -> usize { + GGLWECiphertext::encrypt_pk_scratch_space(module, _basek, _k, _rank_out) + } +} + +impl + AsRef<[u8]>> GLWESwitchingKey { + pub fn encrypt_sk, DataSkOut: AsRef<[u8]>>( + &mut self, + module: &Module, + sk_in: &GLWESecret, + sk_out: &FourierGLWESecret, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) { + #[cfg(debug_assertions)] + { + assert!(sk_in.n() <= module.n()); + assert!(sk_out.n() <= module.n()); + } + + let (mut sk_in_tmp, scratch1) = scratch.tmp_scalar_znx(module, sk_in.rank()); + sk_in_tmp.zero(); + + (0..sk_in.rank()).for_each(|i| { + sk_in_tmp + .at_mut(i, 0) + .iter_mut() + .step_by(module.n() / sk_in.n()) + .zip(sk_in.data.at(i, 0).iter()) + .for_each(|(x, y)| *x = *y); + }); + + let (mut sk_out_tmp, scratch2) = scratch1.tmp_fourier_glwe_secret(module, sk_out.rank()); + (0..sk_out.rank()).for_each(|i| { + sk_out_tmp + .data + .at_mut(i, 0) + .chunks_exact_mut(module.n() / sk_out.n()) + .zip(sk_out.data.at(i, 0).iter()) + .for_each(|(a_chunk, &b_elem)| { + a_chunk.fill(b_elem); + }); + }); + + self.key.encrypt_sk( + module, + &sk_in_tmp, + &sk_out_tmp, + source_xa, + source_xe, + sigma, + scratch2, + ); + self.sk_in_n = sk_in.n(); + self.sk_out_n = sk_out.n(); + } +} + +impl GLWEAutomorphismKey, FFT64> { + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { + GLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, rank, rank) + GLWESecret::bytes_of(module, rank) + } + + pub fn encrypt_pk_scratch_space(module: &Module, _basek: usize, _k: usize, _rank: usize) -> usize { + GLWESwitchingKey::encrypt_pk_scratch_space(module, _basek, _k, _rank, _rank) + } +} + +impl + AsRef<[u8]>> GLWEAutomorphismKey { + pub fn encrypt_sk>( + &mut self, + module: &Module, + p: i64, + sk: &GLWESecret, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) { + #[cfg(debug_assertions)] + { + assert_eq!(self.n(), module.n()); + assert_eq!(sk.n(), module.n()); + assert_eq!(self.rank_out(), self.rank_in()); + assert_eq!(sk.rank(), self.rank()); + assert!( + scratch.available() >= GLWEAutomorphismKey::encrypt_sk_scratch_space(module, self.basek(), self.k(), self.rank()), + "scratch.available(): {} < AutomorphismKey::encrypt_sk_scratch_space(module, self.rank()={}, self.size()={}): {}", + scratch.available(), + self.rank(), + self.size(), + GLWEAutomorphismKey::encrypt_sk_scratch_space(module, self.basek(), self.k(), self.rank()) + ) + } + + let (mut sk_out_dft, scratch_1) = scratch.tmp_fourier_glwe_secret(module, sk.rank()); + + { + let (mut sk_out, _) = scratch_1.tmp_glwe_secret(module, sk.rank()); + (0..self.rank()).for_each(|i| { + module.scalar_znx_automorphism( + module.galois_element_inv(p), + &mut sk_out.data, + i, + &sk.data, + i, + ); + }); + sk_out_dft.set(module, &sk_out); + } + + self.key.encrypt_sk( + module, + &sk, + &sk_out_dft, + source_xa, + source_xe, + sigma, + scratch_1, + ); + + self.p = p; + } +} + +impl GLWETensorKey, FFT64> { + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { + GLWESecret::bytes_of(module, 1) + + FourierGLWESecret::bytes_of(module, 1) + + GLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, rank, rank) + } +} + +impl + AsRef<[u8]>> GLWETensorKey { + pub fn encrypt_sk>( + &mut self, + module: &Module, + sk: &FourierGLWESecret, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), sk.rank()); + assert_eq!(self.n(), module.n()); + assert_eq!(sk.n(), module.n()); + } + + let rank: usize = self.rank(); + + let (mut sk_ij, scratch1) = scratch.tmp_glwe_secret(module, 1); + let (mut sk_ij_dft, scratch2) = scratch1.tmp_fourier_glwe_secret(module, 1); + + (0..rank).for_each(|i| { + (i..rank).for_each(|j| { + module.svp_apply(&mut sk_ij_dft.data, 0, &sk.data, i, &sk.data, j); + module.scalar_znx_idft(&mut sk_ij.data, 0, &sk_ij_dft.data, 0, scratch2); + self.at_mut(i, j) + .encrypt_sk(module, &sk_ij, sk, source_xa, source_xe, sigma, scratch2); + }); + }) + } +} diff --git a/core/src/gglwe/external_product.rs b/core/src/gglwe/external_product.rs new file mode 100644 index 0000000..26a8c92 --- /dev/null +++ b/core/src/gglwe/external_product.rs @@ -0,0 +1,162 @@ +use backend::{FFT64, Module, Scratch, ZnxZero}; + +use crate::{FourierGLWECiphertext, GGSWCiphertext, GLWEAutomorphismKey, GLWESwitchingKey, GetRow, Infos, ScratchCore, SetRow}; + +impl GLWESwitchingKey, FFT64> { + pub fn external_product_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_in: usize, + k_ggsw: usize, + digits: usize, + rank: usize, + ) -> usize { + let tmp_in: usize = FourierGLWECiphertext::bytes_of(module, basek, k_in, rank); + let tmp_out: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank); + let ggsw: usize = FourierGLWECiphertext::external_product_scratch_space(module, basek, k_out, k_in, k_ggsw, digits, rank); + tmp_in + tmp_out + ggsw + } + + pub fn external_product_inplace_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_ggsw: usize, + digits: usize, + rank: usize, + ) -> usize { + let tmp: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank); + let ggsw: usize = + FourierGLWECiphertext::external_product_inplace_scratch_space(module, basek, k_out, k_ggsw, digits, rank); + tmp + ggsw + } +} + +impl + AsRef<[u8]>> GLWESwitchingKey { + pub fn external_product, DataRhs: AsRef<[u8]>>( + &mut self, + module: &Module, + lhs: &GLWESwitchingKey, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) { + #[cfg(debug_assertions)] + { + assert_eq!( + self.rank_in(), + lhs.rank_in(), + "ksk_out input rank: {} != ksk_in input rank: {}", + self.rank_in(), + lhs.rank_in() + ); + assert_eq!( + lhs.rank_out(), + rhs.rank(), + "ksk_in output rank: {} != ggsw rank: {}", + self.rank_out(), + rhs.rank() + ); + assert_eq!( + self.rank_out(), + rhs.rank(), + "ksk_out output rank: {} != ggsw rank: {}", + self.rank_out(), + rhs.rank() + ); + } + + let (mut tmp_in, scratch1) = scratch.tmp_fourier_glwe_ct(module, lhs.basek(), lhs.k(), lhs.rank()); + let (mut tmp_out, scratch2) = scratch1.tmp_fourier_glwe_ct(module, self.basek(), self.k(), self.rank()); + + (0..self.rank_in()).for_each(|col_i| { + (0..self.rows()).for_each(|row_j| { + lhs.get_row(module, row_j, col_i, &mut tmp_in); + tmp_out.external_product(module, &tmp_in, rhs, scratch2); + self.set_row(module, row_j, col_i, &tmp_out); + }); + }); + + tmp_out.data.zero(); + + (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| { + (0..self.rank_in()).for_each(|col_j| { + self.set_row(module, row_i, col_j, &tmp_out); + }); + }); + } + + pub fn external_product_inplace>( + &mut self, + module: &Module, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) { + #[cfg(debug_assertions)] + { + assert_eq!( + self.rank_out(), + rhs.rank(), + "ksk_out output rank: {} != ggsw rank: {}", + self.rank_out(), + rhs.rank() + ); + } + + let (mut tmp, scratch1) = scratch.tmp_fourier_glwe_ct(module, self.basek(), self.k(), self.rank()); + println!("tmp: {}", tmp.size()); + (0..self.rank_in()).for_each(|col_i| { + (0..self.rows()).for_each(|row_j| { + self.get_row(module, row_j, col_i, &mut tmp); + tmp.external_product_inplace(module, rhs, scratch1); + self.set_row(module, row_j, col_i, &tmp); + }); + }); + } +} + +impl GLWEAutomorphismKey, FFT64> { + pub fn external_product_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_in: usize, + ggsw_k: usize, + digits: usize, + rank: usize, + ) -> usize { + GLWESwitchingKey::external_product_scratch_space(module, basek, k_out, k_in, ggsw_k, digits, rank) + } + + pub fn external_product_inplace_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + ggsw_k: usize, + digits: usize, + rank: usize, + ) -> usize { + GLWESwitchingKey::external_product_inplace_scratch_space(module, basek, k_out, ggsw_k, digits, rank) + } +} + +impl + AsRef<[u8]>> GLWEAutomorphismKey { + pub fn external_product, DataRhs: AsRef<[u8]>>( + &mut self, + module: &Module, + lhs: &GLWEAutomorphismKey, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) { + self.key.external_product(module, &lhs.key, rhs, scratch); + } + + pub fn external_product_inplace>( + &mut self, + module: &Module, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) { + self.key.external_product_inplace(module, rhs, scratch); + } +} diff --git a/core/src/gglwe/keyswitch.rs b/core/src/gglwe/keyswitch.rs new file mode 100644 index 0000000..fe4a3f6 --- /dev/null +++ b/core/src/gglwe/keyswitch.rs @@ -0,0 +1,163 @@ +use backend::{FFT64, Module, Scratch, ZnxZero}; + +use crate::{FourierGLWECiphertext, GLWEAutomorphismKey, GLWESwitchingKey, GetRow, Infos, ScratchCore, SetRow}; + +impl GLWEAutomorphismKey, FFT64> { + pub fn keyswitch_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_in: usize, + k_ksk: usize, + digits: usize, + rank: usize, + ) -> usize { + GLWESwitchingKey::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank, rank) + } + + pub fn keyswitch_inplace_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_ksk: usize, + digits: usize, + rank: usize, + ) -> usize { + GLWESwitchingKey::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank) + } +} + +impl + AsRef<[u8]>> GLWEAutomorphismKey { + pub fn keyswitch, DataRhs: AsRef<[u8]>>( + &mut self, + module: &Module, + lhs: &GLWEAutomorphismKey, + rhs: &GLWESwitchingKey, + scratch: &mut Scratch, + ) { + self.key.keyswitch(module, &lhs.key, rhs, scratch); + } + + pub fn keyswitch_inplace>( + &mut self, + module: &Module, + rhs: &GLWEAutomorphismKey, + scratch: &mut Scratch, + ) { + self.key.keyswitch_inplace(module, &rhs.key, scratch); + } +} + +impl GLWESwitchingKey, FFT64> { + pub fn keyswitch_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_in: usize, + k_ksk: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + ) -> usize { + let tmp_in: usize = FourierGLWECiphertext::bytes_of(module, basek, k_in, rank_in); + let tmp_out: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank_out); + let ksk: usize = + FourierGLWECiphertext::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank_in, rank_out); + tmp_in + tmp_out + ksk + } + + pub fn keyswitch_inplace_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_ksk: usize, + digits: usize, + rank: usize, + ) -> usize { + let tmp: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank); + let ksk: usize = FourierGLWECiphertext::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank); + tmp + ksk + } +} + +impl + AsRef<[u8]>> GLWESwitchingKey { + pub fn keyswitch, DataRhs: AsRef<[u8]>>( + &mut self, + module: &Module, + lhs: &GLWESwitchingKey, + rhs: &GLWESwitchingKey, + scratch: &mut Scratch, + ) { + #[cfg(debug_assertions)] + { + assert_eq!( + self.rank_in(), + lhs.rank_in(), + "ksk_out input rank: {} != ksk_in input rank: {}", + self.rank_in(), + lhs.rank_in() + ); + assert_eq!( + lhs.rank_out(), + rhs.rank_in(), + "ksk_in output rank: {} != ksk_apply input rank: {}", + self.rank_out(), + rhs.rank_in() + ); + assert_eq!( + self.rank_out(), + rhs.rank_out(), + "ksk_out output rank: {} != ksk_apply output rank: {}", + self.rank_out(), + rhs.rank_out() + ); + } + + let (mut tmp_in, scratch1) = scratch.tmp_fourier_glwe_ct(module, lhs.basek(), lhs.k(), lhs.rank()); + let (mut tmp_out, scratch2) = scratch1.tmp_fourier_glwe_ct(module, self.basek(), self.k(), self.rank()); + + (0..self.rank_in()).for_each(|col_i| { + (0..self.rows()).for_each(|row_j| { + lhs.get_row(module, row_j, col_i, &mut tmp_in); + tmp_out.keyswitch(module, &tmp_in, rhs, scratch2); + self.set_row(module, row_j, col_i, &tmp_out); + }); + }); + + tmp_out.data.zero(); + + (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| { + (0..self.rank_in()).for_each(|col_j| { + self.set_row(module, row_i, col_j, &tmp_out); + }); + }); + } + + pub fn keyswitch_inplace>( + &mut self, + module: &Module, + rhs: &GLWESwitchingKey, + scratch: &mut Scratch, + ) { + #[cfg(debug_assertions)] + { + assert_eq!( + self.rank_out(), + rhs.rank_out(), + "ksk_out output rank: {} != ksk_apply output rank: {}", + self.rank_out(), + rhs.rank_out() + ); + } + + let (mut tmp, scratch1) = scratch.tmp_fourier_glwe_ct(module, self.basek(), self.k(), self.rank()); + + (0..self.rank_in()).for_each(|col_i| { + (0..self.rows()).for_each(|row_j| { + self.get_row(module, row_j, col_i, &mut tmp); + tmp.keyswitch_inplace(module, rhs, scratch1); + self.set_row(module, row_j, col_i, &tmp); + }); + }); + } +} diff --git a/core/src/gglwe/keyswitch_key.rs b/core/src/gglwe/keyswitch_key.rs new file mode 100644 index 0000000..cb1c1fb --- /dev/null +++ b/core/src/gglwe/keyswitch_key.rs @@ -0,0 +1,105 @@ +use backend::{Backend, FFT64, MatZnxDft, MatZnxDftOps, Module}; + +use crate::{FourierGLWECiphertext, GGLWECiphertext, GetRow, Infos, SetRow}; + +pub struct GLWESwitchingKey { + pub(crate) key: GGLWECiphertext, + pub(crate) sk_in_n: usize, // Degree of sk_in + pub(crate) sk_out_n: usize, // Degree of sk_out +} + +impl GLWESwitchingKey, FFT64> { + pub fn alloc( + module: &Module, + basek: usize, + k: usize, + rows: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + ) -> Self { + GLWESwitchingKey { + key: GGLWECiphertext::alloc(module, basek, k, rows, digits, rank_in, rank_out), + sk_in_n: 0, + sk_out_n: 0, + } + } + + pub fn bytes_of( + module: &Module, + basek: usize, + k: usize, + rows: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + ) -> usize { + GGLWECiphertext::, FFT64>::bytes_of(module, basek, k, rows, digits, rank_in, rank_out) + } +} + +impl Infos for GLWESwitchingKey { + type Inner = MatZnxDft; + + fn inner(&self) -> &Self::Inner { + self.key.inner() + } + + fn basek(&self) -> usize { + self.key.basek() + } + + fn k(&self) -> usize { + self.key.k() + } +} + +impl GLWESwitchingKey { + pub fn rank(&self) -> usize { + self.key.data.cols_out() - 1 + } + + pub fn rank_in(&self) -> usize { + self.key.data.cols_in() + } + + pub fn rank_out(&self) -> usize { + self.key.data.cols_out() - 1 + } + + pub fn digits(&self) -> usize { + self.key.digits() + } + + pub fn sk_degree_in(&self) -> usize { + self.sk_in_n + } + + pub fn sk_degree_out(&self) -> usize { + self.sk_out_n + } +} + +impl> GetRow for GLWESwitchingKey { + fn get_row + AsRef<[u8]>>( + &self, + module: &Module, + row_i: usize, + col_j: usize, + res: &mut FourierGLWECiphertext, + ) { + module.mat_znx_dft_get_row(&mut res.data, &self.key.data, row_i, col_j); + } +} + +impl + AsRef<[u8]>> SetRow for GLWESwitchingKey { + fn set_row>( + &mut self, + module: &Module, + row_i: usize, + col_j: usize, + a: &FourierGLWECiphertext, + ) { + module.mat_znx_dft_set_row(&mut self.key.data, row_i, col_j, &a.data); + } +} diff --git a/core/src/gglwe/mod.rs b/core/src/gglwe/mod.rs new file mode 100644 index 0000000..4c2d20a --- /dev/null +++ b/core/src/gglwe/mod.rs @@ -0,0 +1,16 @@ +pub mod automorphism; +pub mod automorphism_key; +pub mod ciphertext; +pub mod encryption; +pub mod external_product; +pub mod keyswitch; +pub mod keyswitch_key; +pub mod tensor_key; + +pub use automorphism_key::GLWEAutomorphismKey; +pub use ciphertext::GGLWECiphertext; +pub use keyswitch_key::GLWESwitchingKey; +pub use tensor_key::GLWETensorKey; + +#[cfg(test)] +mod test_fft64; diff --git a/core/src/tensor_key.rs b/core/src/gglwe/tensor_key.rs similarity index 52% rename from core/src/tensor_key.rs rename to core/src/gglwe/tensor_key.rs index c0887c9..c12c1f5 100644 --- a/core/src/tensor_key.rs +++ b/core/src/gglwe/tensor_key.rs @@ -1,13 +1,12 @@ -use backend::{Backend, FFT64, MatZnxDft, Module, ScalarZnxDftOps, Scratch}; -use sampling::source::Source; +use backend::{Backend, FFT64, MatZnxDft, Module}; -use crate::{GLWESecret, GLWESwitchingKey, Infos, ScratchCore}; +use crate::{GLWESwitchingKey, Infos}; -pub struct TensorKey { +pub struct GLWETensorKey { pub(crate) keys: Vec>, } -impl TensorKey, FFT64> { +impl GLWETensorKey, FFT64> { pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self { let mut keys: Vec, FFT64>> = Vec::new(); let pairs: usize = (((rank + 1) * rank) >> 1).max(1); @@ -25,7 +24,7 @@ impl TensorKey, FFT64> { } } -impl Infos for TensorKey { +impl Infos for GLWETensorKey { type Inner = MatZnxDft; fn inner(&self) -> &Self::Inner { @@ -41,7 +40,7 @@ impl Infos for TensorKey { } } -impl TensorKey { +impl GLWETensorKey { pub fn rank(&self) -> usize { self.keys[0].rank() } @@ -59,50 +58,7 @@ impl TensorKey { } } -impl TensorKey, FFT64> { - pub fn generate_from_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { - GLWESecret::bytes_of(module, 1) + GLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, rank) - } -} - -impl + AsRef<[u8]>> TensorKey { - pub fn generate_from_sk>( - &mut self, - module: &Module, - sk: &GLWESecret, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - scratch: &mut Scratch, - ) { - #[cfg(debug_assertions)] - { - assert_eq!(self.rank(), sk.rank()); - assert_eq!(self.n(), module.n()); - assert_eq!(sk.n(), module.n()); - } - - let rank: usize = self.rank(); - - let (mut sk_ij, scratch1) = scratch.tmp_sk(module, 1); - - (0..rank).for_each(|i| { - (i..rank).for_each(|j| { - module.svp_apply( - &mut sk_ij.data_fourier, - 0, - &sk.data_fourier, - i, - &sk.data_fourier, - j, - ); - module.svp_idft(&mut sk_ij.data, 0, &sk_ij.data_fourier, 0, scratch1); - self.at_mut(i, j) - .generate_from_sk(module, &sk_ij, sk, source_xa, source_xe, sigma, scratch1); - }); - }) - } - +impl + AsRef<[u8]>> GLWETensorKey { // Returns a mutable reference to GLWESwitchingKey_{s}(s[i] * s[j]) pub fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GLWESwitchingKey { if i > j { @@ -113,7 +69,7 @@ impl + AsRef<[u8]>> TensorKey { } } -impl> TensorKey { +impl> GLWETensorKey { // Returns a reference to GLWESwitchingKey_{s}(s[i] * s[j]) pub fn at(&self, mut i: usize, mut j: usize) -> &GLWESwitchingKey { if i > j { diff --git a/core/src/test_fft64/automorphism_key.rs b/core/src/gglwe/test_fft64/automorphism_key.rs similarity index 70% rename from core/src/test_fft64/automorphism_key.rs rename to core/src/gglwe/test_fft64/automorphism_key.rs index f23b619..c6dc212 100644 --- a/core/src/test_fft64/automorphism_key.rs +++ b/core/src/gglwe/test_fft64/automorphism_key.rs @@ -2,7 +2,8 @@ use backend::{FFT64, Module, ScalarZnxOps, ScratchOwned, Stats, VecZnxOps}; use sampling::source::Source; use crate::{ - AutomorphismKey, GLWECiphertextFourier, GLWEPlaintext, GLWESecret, GetRow, Infos, test_fft64::log2_std_noise_gglwe_product, + FourierGLWECiphertext, FourierGLWESecret, GLWEAutomorphismKey, GLWEPlaintext, GLWESecret, GetRow, Infos, + noise::log2_std_noise_gglwe_product, }; #[test] @@ -57,27 +58,28 @@ fn test_automorphism( let rows_in: usize = k_in / (basek * digits); let rows_apply: usize = k_in.div_ceil(basek * digits); - let mut auto_key_in: AutomorphismKey, FFT64> = AutomorphismKey::alloc(&module, basek, k_in, rows_in, digits_in, rank); - let mut auto_key_out: AutomorphismKey, FFT64> = - AutomorphismKey::alloc(&module, basek, k_out, rows_in, digits_in, rank); - let mut auto_key_apply: AutomorphismKey, FFT64> = - AutomorphismKey::alloc(&module, basek, k_apply, rows_apply, digits, rank); + let mut auto_key_in: GLWEAutomorphismKey, FFT64> = + GLWEAutomorphismKey::alloc(&module, basek, k_in, rows_in, digits_in, rank); + let mut auto_key_out: GLWEAutomorphismKey, FFT64> = + GLWEAutomorphismKey::alloc(&module, basek, k_out, rows_in, digits_in, rank); + let mut auto_key_apply: GLWEAutomorphismKey, FFT64> = + GLWEAutomorphismKey::alloc(&module, basek, k_apply, rows_apply, digits, rank); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - AutomorphismKey::generate_from_sk_scratch_space(&module, basek, k_apply, rank) - | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_out) - | AutomorphismKey::automorphism_scratch_space(&module, basek, k_out, k_in, k_apply, digits, rank), + GLWEAutomorphismKey::encrypt_sk_scratch_space(&module, basek, k_apply, rank) + | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_out) + | GLWEAutomorphismKey::automorphism_scratch_space(&module, basek, k_out, k_in, k_apply, digits, rank), ); - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); // gglwe_{s1}(s0) = s0 -> s1 - auto_key_in.generate_from_sk( + auto_key_in.encrypt_sk( &module, p0, &sk, @@ -88,7 +90,7 @@ fn test_automorphism( ); // gglwe_{s2}(s1) -> s1 -> s2 - auto_key_apply.generate_from_sk( + auto_key_apply.encrypt_sk( &module, p1, &sk, @@ -101,10 +103,10 @@ fn test_automorphism( // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) auto_key_out.automorphism(&module, &auto_key_in, &auto_key_apply, scratch.borrow()); - let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_out, rank); + let mut ct_glwe_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_out, rank); let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); - let mut sk_auto: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); + let mut sk_auto: GLWESecret> = GLWESecret::alloc(&module, rank); sk_auto.fill_zero(); // Necessary to avoid panic of unfilled sk (0..rank).for_each(|i| { module.scalar_znx_automorphism( @@ -116,12 +118,12 @@ fn test_automorphism( ); }); - sk_auto.prep_fourier(&module); + let sk_auto_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_auto); (0..auto_key_out.rank_in()).for_each(|col_i| { (0..auto_key_out.rows()).for_each(|row_i| { auto_key_out.get_row(&module, row_i, col_i, &mut ct_glwe_dft); - ct_glwe_dft.decrypt(&module, &mut pt, &sk_auto, scratch.borrow()); + ct_glwe_dft.decrypt(&module, &mut pt, &sk_auto_dft, scratch.borrow()); module.vec_znx_sub_scalar_inplace( &mut pt.data, @@ -173,25 +175,26 @@ fn test_automorphism_inplace( let rows_in: usize = k_in / (basek * digits); let rows_apply: usize = k_in.div_ceil(basek * digits); - let mut auto_key: AutomorphismKey, FFT64> = AutomorphismKey::alloc(&module, basek, k_in, rows_in, digits_in, rank); - let mut auto_key_apply: AutomorphismKey, FFT64> = - AutomorphismKey::alloc(&module, basek, k_apply, rows_apply, digits, rank); + let mut auto_key: GLWEAutomorphismKey, FFT64> = + GLWEAutomorphismKey::alloc(&module, basek, k_in, rows_in, digits_in, rank); + let mut auto_key_apply: GLWEAutomorphismKey, FFT64> = + GLWEAutomorphismKey::alloc(&module, basek, k_apply, rows_apply, digits, rank); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - AutomorphismKey::generate_from_sk_scratch_space(&module, basek, k_apply, rank) - | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_in) - | AutomorphismKey::automorphism_inplace_scratch_space(&module, basek, k_in, k_apply, digits, rank), + GLWEAutomorphismKey::encrypt_sk_scratch_space(&module, basek, k_apply, rank) + | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_in) + | GLWEAutomorphismKey::automorphism_inplace_scratch_space(&module, basek, k_in, k_apply, digits, rank), ); - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); // gglwe_{s1}(s0) = s0 -> s1 - auto_key.generate_from_sk( + auto_key.encrypt_sk( &module, p0, &sk, @@ -202,7 +205,7 @@ fn test_automorphism_inplace( ); // gglwe_{s2}(s1) -> s1 -> s2 - auto_key_apply.generate_from_sk( + auto_key_apply.encrypt_sk( &module, p1, &sk, @@ -215,11 +218,12 @@ fn test_automorphism_inplace( // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) auto_key.automorphism_inplace(&module, &auto_key_apply, scratch.borrow()); - let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_in, rank); + let mut ct_glwe_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_in, rank); let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_in); - let mut sk_auto: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); + let mut sk_auto: GLWESecret> = GLWESecret::alloc(&module, rank); sk_auto.fill_zero(); // Necessary to avoid panic of unfilled sk + (0..rank).for_each(|i| { module.scalar_znx_automorphism( module.galois_element_inv(p0 * p1), @@ -230,13 +234,13 @@ fn test_automorphism_inplace( ); }); - sk_auto.prep_fourier(&module); + let sk_auto_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_auto); (0..auto_key.rank_in()).for_each(|col_i| { (0..auto_key.rows()).for_each(|row_i| { auto_key.get_row(&module, row_i, col_i, &mut ct_glwe_dft); - ct_glwe_dft.decrypt(&module, &mut pt, &sk_auto, scratch.borrow()); + ct_glwe_dft.decrypt(&module, &mut pt, &sk_auto_dft, scratch.borrow()); module.vec_znx_sub_scalar_inplace( &mut pt.data, 0, diff --git a/core/src/test_fft64/gglwe.rs b/core/src/gglwe/test_fft64/gglwe.rs similarity index 80% rename from core/src/test_fft64/gglwe.rs rename to core/src/gglwe/test_fft64/gglwe.rs index 91798d2..492d3b8 100644 --- a/core/src/test_fft64/gglwe.rs +++ b/core/src/gglwe/test_fft64/gglwe.rs @@ -2,8 +2,8 @@ use backend::{FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxToMut, ScratchO use sampling::source::Source; use crate::{ - GGSWCiphertext, GLWECiphertextFourier, GLWEPlaintext, GLWESecret, GLWESwitchingKey, GetRow, Infos, - test_fft64::{log2_std_noise_gglwe_product, noise_ggsw_product}, + FourierGLWECiphertext, FourierGLWESecret, GGSWCiphertext, GLWEPlaintext, GLWESecret, GLWESwitchingKey, GetRow, Infos, + noise::{log2_std_noise_gglwe_product, noise_ggsw_product}, }; #[test] @@ -144,33 +144,34 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k_ksk: usize, digits: usize, rank let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank_out) - | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_ksk), + GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank_in, rank_out) + | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_ksk), ); - let mut sk_in: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_in); - sk_in.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk_in: GLWESecret> = GLWESecret::alloc(&module, rank_in); + sk_in.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_out: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_out); - sk_out.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk_out: GLWESecret> = GLWESecret::alloc(&module, rank_out); + sk_out.fill_ternary_prob(0.5, &mut source_xs); + let sk_out_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_out); - ksk.generate_from_sk( + ksk.encrypt_sk( &module, &sk_in, - &sk_out, + &sk_out_dft, &mut source_xa, &mut source_xe, sigma, scratch.borrow(), ); - let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::alloc(&module, basek, k_ksk, rank_out); + let mut ct_glwe_fourier: FourierGLWECiphertext, FFT64> = + FourierGLWECiphertext::alloc(&module, basek, k_ksk, rank_out); (0..ksk.rank_in()).for_each(|col_i| { (0..ksk.rows()).for_each(|row_i| { ksk.get_row(&module, row_i, col_i, &mut ct_glwe_fourier); - ct_glwe_fourier.decrypt(&module, &mut pt, &sk_out, scratch.borrow()); + ct_glwe_fourier.decrypt(&module, &mut pt, &sk_out_dft, scratch.borrow()); module.vec_znx_sub_scalar_inplace( &mut pt.data, 0, @@ -233,8 +234,13 @@ fn test_key_switch( let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank_in_s0s1 | rank_out_s0s1) - | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_out) + GLWESwitchingKey::encrypt_sk_scratch_space( + &module, + basek, + k_ksk, + rank_in_s0s1, + rank_in_s0s1 | rank_out_s0s1, + ) | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_out) | GLWESwitchingKey::keyswitch_scratch_space( &module, basek, @@ -247,20 +253,22 @@ fn test_key_switch( ), ); - let mut sk0: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_in_s0s1); - sk0.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk0: GLWESecret> = GLWESecret::alloc(&module, rank_in_s0s1); + sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_out_s0s1); - sk1.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk1: GLWESecret> = GLWESecret::alloc(&module, rank_out_s0s1); + sk1.fill_ternary_prob(0.5, &mut source_xs); + let sk1_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk1); - let mut sk2: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_out_s1s2); - sk2.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk2: GLWESecret> = GLWESecret::alloc(&module, rank_out_s1s2); + sk2.fill_ternary_prob(0.5, &mut source_xs); + let sk2_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk2); // gglwe_{s1}(s0) = s0 -> s1 - ct_gglwe_s0s1.generate_from_sk( + ct_gglwe_s0s1.encrypt_sk( &module, &sk0, - &sk1, + &sk1_dft, &mut source_xa, &mut source_xe, sigma, @@ -268,10 +276,10 @@ fn test_key_switch( ); // gglwe_{s2}(s1) -> s1 -> s2 - ct_gglwe_s1s2.generate_from_sk( + ct_gglwe_s1s2.encrypt_sk( &module, &sk1, - &sk2, + &sk2_dft, &mut source_xa, &mut source_xe, sigma, @@ -281,14 +289,14 @@ fn test_key_switch( // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) ct_gglwe_s0s2.keyswitch(&module, &ct_gglwe_s0s1, &ct_gglwe_s1s2, scratch.borrow()); - let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::alloc(&module, basek, k_out, rank_out_s1s2); + let mut ct_glwe_dft: FourierGLWECiphertext, FFT64> = + FourierGLWECiphertext::alloc(&module, basek, k_out, rank_out_s1s2); let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); (0..ct_gglwe_s0s2.rank_in()).for_each(|col_i| { (0..ct_gglwe_s0s2.rows()).for_each(|row_i| { ct_gglwe_s0s2.get_row(&module, row_i, col_i, &mut ct_glwe_dft); - ct_glwe_dft.decrypt(&module, &mut pt, &sk2, scratch.borrow()); + ct_glwe_dft.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow()); module.vec_znx_sub_scalar_inplace( &mut pt.data, 0, @@ -345,25 +353,27 @@ fn test_key_switch_inplace( let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank_out) - | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_ksk) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank_in, rank_out) + | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_ksk) | GLWESwitchingKey::keyswitch_inplace_scratch_space(&module, basek, k_ct, k_ksk, digits, rank_out), ); - let mut sk0: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_in); - sk0.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk0: GLWESecret> = GLWESecret::alloc(&module, rank_in); + sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_out); - sk1.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk1: GLWESecret> = GLWESecret::alloc(&module, rank_out); + sk1.fill_ternary_prob(0.5, &mut source_xs); + let sk1_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk1); - let mut sk2: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_out); - sk2.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk2: GLWESecret> = GLWESecret::alloc(&module, rank_out); + sk2.fill_ternary_prob(0.5, &mut source_xs); + let sk2_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk2); // gglwe_{s1}(s0) = s0 -> s1 - ct_gglwe_s0s1.generate_from_sk( + ct_gglwe_s0s1.encrypt_sk( &module, &sk0, - &sk1, + &sk1_dft, &mut source_xa, &mut source_xe, sigma, @@ -371,10 +381,10 @@ fn test_key_switch_inplace( ); // gglwe_{s2}(s1) -> s1 -> s2 - ct_gglwe_s1s2.generate_from_sk( + ct_gglwe_s1s2.encrypt_sk( &module, &sk1, - &sk2, + &sk2_dft, &mut source_xa, &mut source_xe, sigma, @@ -386,13 +396,13 @@ fn test_key_switch_inplace( let ct_gglwe_s0s2: GLWESwitchingKey, FFT64> = ct_gglwe_s0s1; - let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_ct, rank_out); + let mut ct_glwe_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_ct, rank_out); let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); (0..ct_gglwe_s0s2.rank_in()).for_each(|col_i| { (0..ct_gglwe_s0s2.rows()).for_each(|row_i| { ct_gglwe_s0s2.get_row(&module, row_i, col_i, &mut ct_glwe_dft); - ct_glwe_dft.decrypt(&module, &mut pt, &sk2, scratch.borrow()); + ct_glwe_dft.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow()); module.vec_znx_sub_scalar_inplace( &mut pt.data, 0, @@ -454,8 +464,8 @@ fn test_external_product( let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_in, rank_out) - | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_out) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_in, rank_in, rank_out) + | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_out) | GLWESwitchingKey::external_product_scratch_space(&module, basek, k_out, k_in, k_ggsw, digits, rank_out) | GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_ggsw, rank_out), ); @@ -464,17 +474,18 @@ fn test_external_product( pt_rgsw.to_mut().raw_mut()[r] = 1; // X^{r} - let mut sk_in: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_in); - sk_in.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk_in: GLWESecret> = GLWESecret::alloc(&module, rank_in); + sk_in.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_out: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_out); - sk_out.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk_out: GLWESecret> = GLWESecret::alloc(&module, rank_out); + sk_out.fill_ternary_prob(0.5, &mut source_xs); + let sk_out_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_out); // gglwe_{s1}(s0) = s0 -> s1 - ct_gglwe_in.generate_from_sk( + ct_gglwe_in.encrypt_sk( &module, &sk_in, - &sk_out, + &sk_out_dft, &mut source_xa, &mut source_xe, sigma, @@ -484,7 +495,7 @@ fn test_external_product( ct_rgsw.encrypt_sk( &module, &pt_rgsw, - &sk_out, + &sk_out_dft, &mut source_xa, &mut source_xe, sigma, @@ -494,7 +505,7 @@ fn test_external_product( // gglwe_(m) (x) RGSW_(X^k) = gglwe_(m * X^k) ct_gglwe_out.external_product(&module, &ct_gglwe_in, &ct_rgsw, scratch.borrow()); - let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_out, rank_out); + let mut ct_glwe_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_out, rank_out); let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); (0..rank_in).for_each(|i| { @@ -504,7 +515,7 @@ fn test_external_product( (0..rank_in).for_each(|col_i| { (0..ct_gglwe_out.rows()).for_each(|row_i| { ct_gglwe_out.get_row(&module, row_i, col_i, &mut ct_glwe_dft); - ct_glwe_dft.decrypt(&module, &mut pt, &sk_out, scratch.borrow()); + ct_glwe_dft.decrypt(&module, &mut pt, &sk_out_dft, scratch.borrow()); module.vec_znx_sub_scalar_inplace( &mut pt.data, @@ -574,8 +585,8 @@ fn test_external_product_inplace( let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ct, rank_out) - | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_ct) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ct, rank_in, rank_out) + | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_ct) | GLWESwitchingKey::external_product_inplace_scratch_space(&module, basek, k_ct, k_ggsw, digits, rank_out) | GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_ggsw, rank_out), ); @@ -584,17 +595,18 @@ fn test_external_product_inplace( pt_rgsw.to_mut().raw_mut()[r] = 1; // X^{r} - let mut sk_in: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_in); - sk_in.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk_in: GLWESecret> = GLWESecret::alloc(&module, rank_in); + sk_in.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_out: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_out); - sk_out.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk_out: GLWESecret> = GLWESecret::alloc(&module, rank_out); + sk_out.fill_ternary_prob(0.5, &mut source_xs); + let sk_out_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_out); // gglwe_{s1}(s0) = s0 -> s1 - ct_gglwe.generate_from_sk( + ct_gglwe.encrypt_sk( &module, &sk_in, - &sk_out, + &sk_out_dft, &mut source_xa, &mut source_xe, sigma, @@ -604,7 +616,7 @@ fn test_external_product_inplace( ct_rgsw.encrypt_sk( &module, &pt_rgsw, - &sk_out, + &sk_out_dft, &mut source_xa, &mut source_xe, sigma, @@ -614,7 +626,7 @@ fn test_external_product_inplace( // gglwe_(m) (x) RGSW_(X^k) = gglwe_(m * X^k) ct_gglwe.external_product_inplace(&module, &ct_rgsw, scratch.borrow()); - let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_ct, rank_out); + let mut ct_glwe_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_ct, rank_out); let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); (0..rank_in).for_each(|i| { @@ -624,7 +636,7 @@ fn test_external_product_inplace( (0..rank_in).for_each(|col_i| { (0..ct_gglwe.rows()).for_each(|row_i| { ct_gglwe.get_row(&module, row_i, col_i, &mut ct_glwe_dft); - ct_glwe_dft.decrypt(&module, &mut pt, &sk_out, scratch.borrow()); + ct_glwe_dft.decrypt(&module, &mut pt, &sk_out_dft, scratch.borrow()); module.vec_znx_sub_scalar_inplace( &mut pt.data, diff --git a/core/src/gglwe/test_fft64/mod.rs b/core/src/gglwe/test_fft64/mod.rs new file mode 100644 index 0000000..49d23cd --- /dev/null +++ b/core/src/gglwe/test_fft64/mod.rs @@ -0,0 +1,3 @@ +pub mod automorphism_key; +pub mod gglwe; +pub mod tensor_key; diff --git a/core/src/test_fft64/tensor_key.rs b/core/src/gglwe/test_fft64/tensor_key.rs similarity index 60% rename from core/src/test_fft64/tensor_key.rs rename to core/src/gglwe/test_fft64/tensor_key.rs index 579fec4..ab1d191 100644 --- a/core/src/test_fft64/tensor_key.rs +++ b/core/src/gglwe/test_fft64/tensor_key.rs @@ -1,7 +1,7 @@ use backend::{FFT64, Module, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxOps}; use sampling::source::Source; -use crate::{GLWECiphertextFourier, GLWEPlaintext, GLWESecret, GetRow, Infos, TensorKey}; +use crate::{FourierGLWECiphertext, FourierGLWESecret, GLWEPlaintext, GLWESecret, GLWETensorKey, GetRow, Infos}; #[test] fn encrypt_sk() { @@ -17,53 +17,48 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k: usize, sigma: f64, rank: usize let rows: usize = k / basek; - let mut tensor_key: TensorKey, FFT64> = TensorKey::alloc(&module, basek, k, rows, 1, rank); + let mut tensor_key: GLWETensorKey, FFT64> = GLWETensorKey::alloc(&module, basek, k, rows, 1, rank); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::new(TensorKey::generate_from_sk_scratch_space( + let mut scratch: ScratchOwned = ScratchOwned::new(GLWETensorKey::encrypt_sk_scratch_space( &module, basek, tensor_key.k(), rank, )); - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); - tensor_key.generate_from_sk( + tensor_key.encrypt_sk( &module, - &sk, + &sk_dft, &mut source_xa, &mut source_xe, sigma, scratch.borrow(), ); - let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k, rank); + let mut ct_glwe_fourier: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k, rank); let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k); - let mut sk_ij = GLWESecret::alloc(&module, 1); + let mut sk_ij_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::alloc(&module, 1); + let mut sk_ij: GLWESecret> = GLWESecret::alloc(&module, 1); (0..rank).for_each(|i| { (0..rank).for_each(|j| { - module.svp_apply( - &mut sk_ij.data_fourier, - 0, - &sk.data_fourier, - i, - &sk.data_fourier, - j, - ); - module.svp_idft(&mut sk_ij.data, 0, &sk_ij.data_fourier, 0, scratch.borrow()); + module.svp_apply(&mut sk_ij_dft.data, 0, &sk_dft.data, i, &sk_dft.data, j); + module.scalar_znx_idft(&mut sk_ij.data, 0, &sk_ij_dft.data, 0, scratch.borrow()); (0..tensor_key.rank_in()).for_each(|col_i| { (0..tensor_key.rows()).for_each(|row_i| { tensor_key .at(i, j) .get_row(&module, row_i, col_i, &mut ct_glwe_fourier); - ct_glwe_fourier.decrypt(&module, &mut pt, &sk, scratch.borrow()); + ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); module.vec_znx_sub_scalar_inplace(&mut pt.data, 0, row_i, &sk_ij.data, col_i); let std_pt: f64 = pt.data.std(0, basek) * (k as f64).exp2(); assert!((sigma - std_pt).abs() <= 0.5, "{} {}", sigma, std_pt); diff --git a/core/src/gglwe_ciphertext.rs b/core/src/gglwe_ciphertext.rs deleted file mode 100644 index 22d6749..0000000 --- a/core/src/gglwe_ciphertext.rs +++ /dev/null @@ -1,236 +0,0 @@ -use backend::{ - Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, Module, ScalarZnx, Scratch, VecZnxAlloc, VecZnxDftAlloc, VecZnxOps, - ZnxInfos, ZnxZero, -}; -use sampling::source::Source; - -use crate::{GLWECiphertext, GLWECiphertextFourier, GLWESecret, GetRow, Infos, ScratchCore, SetRow}; - -pub struct GGLWECiphertext { - pub(crate) data: MatZnxDft, - pub(crate) basek: usize, - pub(crate) k: usize, - pub(crate) digits: usize, -} - -impl GGLWECiphertext, B> { - pub fn alloc( - module: &Module, - basek: usize, - k: usize, - rows: usize, - digits: usize, - rank_in: usize, - rank_out: usize, - ) -> Self { - let size: usize = k.div_ceil(basek); - debug_assert!( - size > digits, - "invalid gglwe: ceil(k/basek): {} <= digits: {}", - size, - digits - ); - - debug_assert!( - rows * digits <= size, - "invalid gglwe: rows: {} * digits:{} > ceil(k/basek): {}", - rows, - digits, - size - ); - - Self { - data: module.new_mat_znx_dft(rows, rank_in, rank_out + 1, size), - basek: basek, - k, - digits, - } - } - - pub fn bytes_of( - module: &Module, - basek: usize, - k: usize, - rows: usize, - digits: usize, - rank_in: usize, - rank_out: usize, - ) -> usize { - let size: usize = k.div_ceil(basek); - debug_assert!( - size > digits, - "invalid gglwe: ceil(k/basek): {} <= digits: {}", - size, - digits - ); - - assert!( - rows * digits <= size, - "invalid gglwe: rows: {} * digits:{} > ceil(k/basek): {}", - rows, - digits, - size - ); - - module.bytes_of_mat_znx_dft(rows, rank_in, rank_out + 1, size) - } -} - -impl Infos for GGLWECiphertext { - type Inner = MatZnxDft; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn basek(&self) -> usize { - self.basek - } - - fn k(&self) -> usize { - self.k - } -} - -impl GGLWECiphertext { - pub fn rank(&self) -> usize { - self.data.cols_out() - 1 - } - - pub fn digits(&self) -> usize { - self.digits - } - - pub fn rank_in(&self) -> usize { - self.data.cols_in() - } - - pub fn rank_out(&self) -> usize { - self.data.cols_out() - 1 - } -} - -impl GGLWECiphertext, FFT64> { - pub fn generate_from_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { - let size = k.div_ceil(basek); - GLWECiphertext::encrypt_sk_scratch_space(module, basek, k) - + module.bytes_of_vec_znx(rank + 1, size) - + module.bytes_of_vec_znx(1, size) - + module.bytes_of_vec_znx_dft(rank + 1, size) - } - - pub fn generate_from_pk_scratch_space(_module: &Module, _basek: usize, _k: usize, _rank: usize) -> usize { - unimplemented!() - } -} - -impl + AsRef<[u8]>> GGLWECiphertext { - pub fn encrypt_sk, DataSk: AsRef<[u8]>>( - &mut self, - module: &Module, - pt: &ScalarZnx, - sk: &GLWESecret, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - scratch: &mut Scratch, - ) { - #[cfg(debug_assertions)] - { - assert_eq!(self.rank_in(), pt.cols()); - assert_eq!(self.rank_out(), sk.rank()); - assert_eq!(self.n(), module.n()); - assert_eq!(sk.n(), module.n()); - assert_eq!(pt.n(), module.n()); - assert!( - scratch.available() - >= GGLWECiphertext::generate_from_sk_scratch_space(module, self.basek(), self.k(), self.rank()), - "scratch.available: {} < GGLWECiphertext::generate_from_sk_scratch_space(module, self.rank()={}, \ - self.size()={}): {}", - scratch.available(), - self.rank(), - self.size(), - GGLWECiphertext::generate_from_sk_scratch_space(module, self.basek(), self.k(), self.rank()) - ); - assert!( - self.rows() * self.digits() * self.basek() <= self.k(), - "self.rows() : {} * self.digits() : {} * self.basek() : {} = {} >= self.k() = {}", - self.rows(), - self.digits(), - self.basek(), - self.rows() * self.digits() * self.basek(), - self.k() - ); - } - - let rows: usize = self.rows(); - let digits: usize = self.digits(); - let basek: usize = self.basek(); - let k: usize = self.k(); - let rank_in: usize = self.rank_in(); - let rank_out: usize = self.rank_out(); - - let (mut tmp_pt, scrach_1) = scratch.tmp_glwe_pt(module, basek, k); - let (mut tmp_ct, scrach_2) = scrach_1.tmp_glwe_ct(module, basek, k, rank_out); - let (mut tmp_ct_dft, scratch_3) = scrach_2.tmp_glwe_fourier(module, basek, k, rank_out); - - // For each input column (i.e. rank) produces a GGLWE ciphertext of rank_out+1 columns - // - // Example for ksk rank 2 to rank 3: - // - // (-(a0*s0 + a1*s1 + a2*s2) + s0', a0, a1, a2) - // (-(b0*s0 + b1*s1 + b2*s2) + s0', b0, b1, b2) - // - // Example ksk rank 2 to rank 1 - // - // (-(a*s) + s0, a) - // (-(b*s) + s1, b) - (0..rank_in).for_each(|col_i| { - (0..rows).for_each(|row_i| { - // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt - tmp_pt.data.zero(); // zeroes for next iteration - module.vec_znx_add_scalar_inplace( - &mut tmp_pt.data, - 0, - (digits - 1) + row_i * digits, - pt, - col_i, - ); - module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scratch_3); - - // rlwe encrypt of vec_znx_pt into vec_znx_ct - tmp_ct.encrypt_sk(module, &tmp_pt, sk, source_xa, source_xe, sigma, scratch_3); - - // Switch vec_znx_ct into DFT domain - tmp_ct.dft(module, &mut tmp_ct_dft); - - // Stores vec_znx_dft_ct into thw i-th row of the MatZnxDft - self.set_row(module, row_i, col_i, &tmp_ct_dft); - }); - }); - } -} - -impl> GetRow for GGLWECiphertext { - fn get_row + AsRef<[u8]>>( - &self, - module: &Module, - row_i: usize, - col_j: usize, - res: &mut GLWECiphertextFourier, - ) { - module.vmp_extract_row(&mut res.data, &self.data, row_i, col_j); - } -} - -impl + AsRef<[u8]>> SetRow for GGLWECiphertext { - fn set_row>( - &mut self, - module: &Module, - row_i: usize, - col_j: usize, - a: &GLWECiphertextFourier, - ) { - module.vmp_prepare_row(&mut self.data, row_i, col_j, &a.data); - } -} diff --git a/core/src/ggsw_ciphertext.rs b/core/src/ggsw/ciphertext.rs similarity index 91% rename from core/src/ggsw_ciphertext.rs rename to core/src/ggsw/ciphertext.rs index 82e0e81..8239617 100644 --- a/core/src/ggsw_ciphertext.rs +++ b/core/src/ggsw/ciphertext.rs @@ -6,8 +6,8 @@ use backend::{ use sampling::source::Source; use crate::{ - AutomorphismKey, GLWECiphertext, GLWECiphertextFourier, GLWESecret, GLWESwitchingKey, GetRow, Infos, ScratchCore, SetRow, - TensorKey, + FourierGLWECiphertext, FourierGLWESecret, GLWEAutomorphismKey, GLWECiphertext, GLWESwitchingKey, GLWETensorKey, GetRow, + Infos, ScratchCore, SetRow, }; pub struct GGSWCiphertext { @@ -17,8 +17,8 @@ pub struct GGSWCiphertext { pub(crate) digits: usize, } -impl GGSWCiphertext, B> { - pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self { +impl GGSWCiphertext, FFT64> { + pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self { let size: usize = k.div_ceil(basek); debug_assert!( size > digits, @@ -220,9 +220,9 @@ impl GGSWCiphertext, FFT64> { digits: usize, rank: usize, ) -> usize { - let tmp_in: usize = GLWECiphertextFourier::bytes_of(module, basek, k_in, rank); - let tmp_out: usize = GLWECiphertextFourier::bytes_of(module, basek, k_out, rank); - let ggsw: usize = GLWECiphertextFourier::external_product_scratch_space(module, basek, k_out, k_in, k_ggsw, digits, rank); + let tmp_in: usize = FourierGLWECiphertext::bytes_of(module, basek, k_in, rank); + let tmp_out: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank); + let ggsw: usize = FourierGLWECiphertext::external_product_scratch_space(module, basek, k_out, k_in, k_ggsw, digits, rank); tmp_in + tmp_out + ggsw } @@ -234,9 +234,9 @@ impl GGSWCiphertext, FFT64> { digits: usize, rank: usize, ) -> usize { - let tmp: usize = GLWECiphertextFourier::bytes_of(module, basek, k_out, rank); + let tmp: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank); let ggsw: usize = - GLWECiphertextFourier::external_product_inplace_scratch_space(module, basek, k_out, k_ggsw, digits, rank); + FourierGLWECiphertext::external_product_inplace_scratch_space(module, basek, k_out, k_ggsw, digits, rank); tmp + ggsw } } @@ -246,7 +246,7 @@ impl + AsRef<[u8]>> GGSWCiphertext { &mut self, module: &Module, pt: &ScalarZnx, - sk: &GLWESecret, + sk: &FourierGLWESecret, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, @@ -290,7 +290,7 @@ impl + AsRef<[u8]>> GGSWCiphertext { // Switch vec_znx_ct into DFT domain { - let (mut tmp_ct_dft, _) = scratch2.tmp_glwe_fourier(module, basek, k, rank); + let (mut tmp_ct_dft, _) = scratch2.tmp_fourier_glwe_ct(module, basek, k, rank); tmp_ct.dft(module, &mut tmp_ct_dft); self.set_row(module, row_i, col_j, &tmp_ct_dft); } @@ -304,7 +304,7 @@ impl + AsRef<[u8]>> GGSWCiphertext { col_j: usize, res: &mut R, ci_dft: &VecZnxDft, - tsk: &TensorKey, + tsk: &GLWETensorKey, scratch: &mut Scratch, ) where R: VecZnxToMut, @@ -361,7 +361,7 @@ impl + AsRef<[u8]>> GGSWCiphertext { // = // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2) (1..cols).for_each(|col_i| { - let pmat: &MatZnxDft = &tsk.at(col_i - 1, col_j - 1).0.data; // Selects Enc(s[i]s[j]) + let pmat: &MatZnxDft = &tsk.at(col_i - 1, col_j - 1).key.data; // Selects Enc(s[i]s[j]) // Extracts a[i] and multipies with Enc(s[i]s[j]) (0..digits).for_each(|di| { @@ -408,7 +408,7 @@ impl + AsRef<[u8]>> GGSWCiphertext { module: &Module, lhs: &GGSWCiphertext, ksk: &GLWESwitchingKey, - tsk: &TensorKey, + tsk: &GLWETensorKey, scratch: &mut Scratch, ) { let rank: usize = self.rank(); @@ -429,7 +429,7 @@ impl + AsRef<[u8]>> GGSWCiphertext { module.vec_znx_dft(1, 0, &mut ci_dft, col_i, &tmp_res.data, col_i); }); - module.vmp_prepare_row(&mut self.data, row_i, 0, &ci_dft); + module.mat_znx_dft_set_row(&mut self.data, row_i, 0, &ci_dft); // Generates // @@ -438,7 +438,7 @@ impl + AsRef<[u8]>> GGSWCiphertext { // col 3: (-(d0s0' + d1s1' + d2s2') , d0 , d1 , d2 + M[i]) (1..cols).for_each(|col_j| { self.expand_row(module, col_j, &mut tmp_res.data, &ci_dft, tsk, scratch2); - let (mut tmp_res_dft, _) = scratch2.tmp_glwe_fourier(module, basek, self.k(), rank); + let (mut tmp_res_dft, _) = scratch2.tmp_fourier_glwe_ct(module, basek, self.k(), rank); tmp_res.dft(module, &mut tmp_res_dft); self.set_row(module, row_i, col_j, &tmp_res_dft); }); @@ -449,7 +449,7 @@ impl + AsRef<[u8]>> GGSWCiphertext { &mut self, module: &Module, ksk: &GLWESwitchingKey, - tsk: &TensorKey, + tsk: &GLWETensorKey, scratch: &mut Scratch, ) { unsafe { @@ -462,8 +462,8 @@ impl + AsRef<[u8]>> GGSWCiphertext { &mut self, module: &Module, lhs: &GGSWCiphertext, - auto_key: &AutomorphismKey, - tensor_key: &TensorKey, + auto_key: &GLWEAutomorphismKey, + tensor_key: &GLWETensorKey, scratch: &mut Scratch, ) { #[cfg(debug_assertions)] @@ -525,7 +525,7 @@ impl + AsRef<[u8]>> GGSWCiphertext { module.vec_znx_dft(1, 0, &mut ci_dft, col_i, &tmp_res.data, col_i); }); - module.vmp_prepare_row(&mut self.data, row_i, 0, &ci_dft); + module.mat_znx_dft_set_row(&mut self.data, row_i, 0, &ci_dft); // Generates // @@ -541,7 +541,7 @@ impl + AsRef<[u8]>> GGSWCiphertext { tensor_key, scratch2, ); - let (mut tmp_res_dft, _) = scratch2.tmp_glwe_fourier(module, basek, self.k(), rank); + let (mut tmp_res_dft, _) = scratch2.tmp_fourier_glwe_ct(module, basek, self.k(), rank); tmp_res.dft(module, &mut tmp_res_dft); self.set_row(module, row_i, col_j, &tmp_res_dft); }); @@ -551,8 +551,8 @@ impl + AsRef<[u8]>> GGSWCiphertext { pub fn automorphism_inplace, DataTsk: AsRef<[u8]>>( &mut self, module: &Module, - auto_key: &AutomorphismKey, - tensor_key: &TensorKey, + auto_key: &GLWEAutomorphismKey, + tensor_key: &GLWETensorKey, scratch: &mut Scratch, ) { unsafe { @@ -599,8 +599,8 @@ impl + AsRef<[u8]>> GGSWCiphertext { ) } - let (mut tmp_ct_in, scratch1) = scratch.tmp_glwe_fourier(module, lhs.basek(), lhs.k(), lhs.rank()); - let (mut tmp_ct_out, scratch2) = scratch1.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); + let (mut tmp_ct_in, scratch1) = scratch.tmp_fourier_glwe_ct(module, lhs.basek(), lhs.k(), lhs.rank()); + let (mut tmp_ct_out, scratch2) = scratch1.tmp_fourier_glwe_ct(module, self.basek(), self.k(), self.rank()); (0..self.rank() + 1).for_each(|col_i| { (0..self.rows()).for_each(|row_j| { @@ -636,7 +636,7 @@ impl + AsRef<[u8]>> GGSWCiphertext { ); } - let (mut tmp_ct, scratch1) = scratch.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); + let (mut tmp_ct, scratch1) = scratch.tmp_fourier_glwe_ct(module, self.basek(), self.k(), self.rank()); (0..self.rank() + 1).for_each(|col_i| { (0..self.rows()).for_each(|row_j| { @@ -674,7 +674,7 @@ impl> GGSWCiphertext { ) ) } - let (mut tmp_dft_dft, scratch1) = scratch.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); + let (mut tmp_dft_dft, scratch1) = scratch.tmp_fourier_glwe_ct(module, self.basek(), self.k(), self.rank()); self.get_row(module, row_i, 0, &mut tmp_dft_dft); res.keyswitch_from_fourier(module, &tmp_dft_dft, ksk, scratch1); } @@ -686,9 +686,9 @@ impl> GetRow for GGSWCiphertext { module: &Module, row_i: usize, col_j: usize, - res: &mut GLWECiphertextFourier, + res: &mut FourierGLWECiphertext, ) { - module.vmp_extract_row(&mut res.data, &self.data, row_i, col_j); + module.mat_znx_dft_get_row(&mut res.data, &self.data, row_i, col_j); } } @@ -698,8 +698,8 @@ impl + AsRef<[u8]>> SetRow for GGSWCiphertext, row_i: usize, col_j: usize, - a: &GLWECiphertextFourier, + a: &FourierGLWECiphertext, ) { - module.vmp_prepare_row(&mut self.data, row_i, col_j, &a.data); + module.mat_znx_dft_set_row(&mut self.data, row_i, col_j, &a.data); } } diff --git a/core/src/ggsw/mod.rs b/core/src/ggsw/mod.rs new file mode 100644 index 0000000..f27b96b --- /dev/null +++ b/core/src/ggsw/mod.rs @@ -0,0 +1,6 @@ +pub mod ciphertext; + +pub use ciphertext::GGSWCiphertext; + +#[cfg(test)] +mod test_fft64; diff --git a/core/src/test_fft64/ggsw.rs b/core/src/ggsw/test_fft64/ggsw.rs similarity index 82% rename from core/src/test_fft64/ggsw.rs rename to core/src/ggsw/test_fft64/ggsw.rs index 3e6ee8b..dc84eb6 100644 --- a/core/src/test_fft64/ggsw.rs +++ b/core/src/ggsw/test_fft64/ggsw.rs @@ -5,9 +5,9 @@ use backend::{ use sampling::source::Source; use crate::{ - GGSWCiphertext, GLWECiphertextFourier, GLWEPlaintext, GLWESecret, GLWESwitchingKey, GetRow, Infos, TensorKey, - automorphism::AutomorphismKey, - test_fft64::{noise_ggsw_keyswitch, noise_ggsw_product}, + FourierGLWECiphertext, FourierGLWESecret, GGSWCiphertext, GLWEAutomorphismKey, GLWEPlaintext, GLWESecret, GLWESwitchingKey, + GLWETensorKey, GetRow, Infos, + noise::{noise_ggsw_keyswitch, noise_ggsw_product}, }; #[test] @@ -139,23 +139,24 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k: usize, digits: usize, rank: us let mut scratch: ScratchOwned = ScratchOwned::new( GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k, rank) - | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k), + | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k), ); - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); ct.encrypt_sk( &module, &pt_scalar, - &sk, + &sk_dft, &mut source_xa, &mut source_xe, sigma, scratch.borrow(), ); - let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k, rank); + let mut ct_glwe_fourier: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k, rank); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct.size()); let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct.size()); @@ -172,14 +173,14 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k: usize, digits: usize, rank: us // mul with sk[col_j-1] if col_j > 0 { module.vec_znx_dft(1, 0, &mut pt_dft, 0, &pt_want.data, 0); - module.svp_apply_inplace(&mut pt_dft, 0, &sk.data_fourier, col_j - 1); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft.data, col_j - 1); module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); module.vec_znx_big_normalize(basek, &mut pt_want.data, 0, &pt_big, 0, scratch.borrow()); } ct.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); - ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk, scratch.borrow()); + ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); @@ -209,7 +210,7 @@ fn test_keyswitch( let mut ct_in: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_in, rows, digits_in, rank); let mut ct_out: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_out, rows, digits_in, rank); - let mut tsk: TensorKey, FFT64> = TensorKey::alloc(&module, basek, k_ksk, rows, digits, rank); + let mut tsk: GLWETensorKey, FFT64> = GLWETensorKey::alloc(&module, basek, k_ksk, rows, digits, rank); let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank, rank); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); @@ -221,9 +222,9 @@ fn test_keyswitch( let mut scratch: ScratchOwned = ScratchOwned::new( GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_in, rank) - | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_out) - | GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank) - | TensorKey::generate_from_sk_scratch_space(&module, basek, k_tsk, rank) + | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_out) + | GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank, rank) + | GLWETensorKey::encrypt_sk_scratch_space(&module, basek, k_tsk, rank) | GGSWCiphertext::keyswitch_scratch_space( &module, basek, k_out, k_in, k_ksk, digits, k_tsk, digits, rank, ), @@ -231,24 +232,26 @@ fn test_keyswitch( let var_xs: f64 = 0.5; - let mut sk_in: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk_in.fill_ternary_prob(&module, var_xs, &mut source_xs); + let mut sk_in: GLWESecret> = GLWESecret::alloc(&module, rank); + sk_in.fill_ternary_prob(var_xs, &mut source_xs); + let sk_in_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_in); - let mut sk_out: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk_out.fill_ternary_prob(&module, var_xs, &mut source_xs); + let mut sk_out: GLWESecret> = GLWESecret::alloc(&module, rank); + sk_out.fill_ternary_prob(var_xs, &mut source_xs); + let sk_out_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_out); - ksk.generate_from_sk( + ksk.encrypt_sk( &module, &sk_in, - &sk_out, + &sk_out_dft, &mut source_xa, &mut source_xe, sigma, scratch.borrow(), ); - tsk.generate_from_sk( + tsk.encrypt_sk( &module, - &sk_out, + &sk_out_dft, &mut source_xa, &mut source_xe, sigma, @@ -260,7 +263,7 @@ fn test_keyswitch( ct_in.encrypt_sk( &module, &pt_scalar, - &sk_in, + &sk_in_dft, &mut source_xa, &mut source_xe, sigma, @@ -269,7 +272,7 @@ fn test_keyswitch( ct_out.keyswitch(&module, &ct_in, &ksk, &tsk, scratch.borrow()); - let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_out, rank); + let mut ct_glwe_fourier: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_out, rank); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_out.size()); let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_out.size()); @@ -280,14 +283,14 @@ fn test_keyswitch( // mul with sk[col_j-1] if col_j > 0 { module.vec_znx_dft(1, 0, &mut pt_dft, 0, &pt_want.data, 0); - module.svp_apply_inplace(&mut pt_dft, 0, &sk_out.data_fourier, col_j - 1); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_out_dft.data, col_j - 1); module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); module.vec_znx_big_normalize(basek, &mut pt_want.data, 0, &pt_big, 0, scratch.borrow()); } ct_out.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); - ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_out, scratch.borrow()); + ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow()); module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); @@ -336,7 +339,7 @@ fn test_keyswitch_inplace( let digits_in: usize = 1; let mut ct: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ct, rows, digits_in, rank); - let mut tsk: TensorKey, FFT64> = TensorKey::alloc(&module, basek, k_tsk, rows, digits, rank); + let mut tsk: GLWETensorKey, FFT64> = GLWETensorKey::alloc(&module, basek, k_tsk, rows, digits, rank); let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank, rank); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); @@ -348,32 +351,34 @@ fn test_keyswitch_inplace( let mut scratch: ScratchOwned = ScratchOwned::new( GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_ct, rank) - | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_ct) - | GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank) - | TensorKey::generate_from_sk_scratch_space(&module, basek, k_tsk, rank) + | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_ct) + | GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank, rank) + | GLWETensorKey::encrypt_sk_scratch_space(&module, basek, k_tsk, rank) | GGSWCiphertext::keyswitch_inplace_scratch_space(&module, basek, k_ct, k_ksk, digits, k_tsk, digits, rank), ); let var_xs: f64 = 0.5; - let mut sk_in: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk_in.fill_ternary_prob(&module, var_xs, &mut source_xs); + let mut sk_in: GLWESecret> = GLWESecret::alloc(&module, rank); + sk_in.fill_ternary_prob(var_xs, &mut source_xs); + let sk_in_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_in); - let mut sk_out: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk_out.fill_ternary_prob(&module, var_xs, &mut source_xs); + let mut sk_out: GLWESecret> = GLWESecret::alloc(&module, rank); + sk_out.fill_ternary_prob(var_xs, &mut source_xs); + let sk_out_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_out); - ksk.generate_from_sk( + ksk.encrypt_sk( &module, &sk_in, - &sk_out, + &sk_out_dft, &mut source_xa, &mut source_xe, sigma, scratch.borrow(), ); - tsk.generate_from_sk( + tsk.encrypt_sk( &module, - &sk_out, + &sk_out_dft, &mut source_xa, &mut source_xe, sigma, @@ -385,7 +390,7 @@ fn test_keyswitch_inplace( ct.encrypt_sk( &module, &pt_scalar, - &sk_in, + &sk_in_dft, &mut source_xa, &mut source_xe, sigma, @@ -394,7 +399,7 @@ fn test_keyswitch_inplace( ct.keyswitch_inplace(&module, &ksk, &tsk, scratch.borrow()); - let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_ct, rank); + let mut ct_glwe_fourier: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_ct, rank); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct.size()); let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct.size()); @@ -411,14 +416,14 @@ fn test_keyswitch_inplace( // mul with sk[col_j-1] if col_j > 0 { module.vec_znx_dft(1, 0, &mut pt_dft, 0, &pt_want.data, 0); - module.svp_apply_inplace(&mut pt_dft, 0, &sk_out.data_fourier, col_j - 1); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_out_dft.data, col_j - 1); module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); module.vec_znx_big_normalize(basek, &mut pt_want.data, 0, &pt_big, 0, scratch.borrow()); } ct.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); - ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_out, scratch.borrow()); + ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow()); module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); @@ -471,8 +476,8 @@ fn test_automorphism( let mut ct_in: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_in, rows_in, digits_in, rank); let mut ct_out: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_out, rows_in, digits_in, rank); - let mut tensor_key: TensorKey, FFT64> = TensorKey::alloc(&module, basek, k_tsk, rows, digits, rank); - let mut auto_key: AutomorphismKey, FFT64> = AutomorphismKey::alloc(&module, basek, k_ksk, rows, digits, rank); + let mut tensor_key: GLWETensorKey, FFT64> = GLWETensorKey::alloc(&module, basek, k_tsk, rows, digits, rank); + let mut auto_key: GLWEAutomorphismKey, FFT64> = GLWEAutomorphismKey::alloc(&module, basek, k_ksk, rows, digits, rank); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); @@ -483,9 +488,9 @@ fn test_automorphism( let mut scratch: ScratchOwned = ScratchOwned::new( GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_in, rank) - | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_out) - | AutomorphismKey::generate_from_sk_scratch_space(&module, basek, k_ksk, rank) - | TensorKey::generate_from_sk_scratch_space(&module, basek, k_tsk, rank) + | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_out) + | GLWEAutomorphismKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank) + | GLWETensorKey::encrypt_sk_scratch_space(&module, basek, k_tsk, rank) | GGSWCiphertext::automorphism_scratch_space( &module, basek, k_out, k_in, k_ksk, digits, k_tsk, digits, rank, ), @@ -493,10 +498,11 @@ fn test_automorphism( let var_xs: f64 = 0.5; - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, var_xs, &mut source_xs); + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(var_xs, &mut source_xs); + let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); - auto_key.generate_from_sk( + auto_key.encrypt_sk( &module, p, &sk, @@ -505,9 +511,9 @@ fn test_automorphism( sigma, scratch.borrow(), ); - tensor_key.generate_from_sk( + tensor_key.encrypt_sk( &module, - &sk, + &sk_dft, &mut source_xa, &mut source_xe, sigma, @@ -519,7 +525,7 @@ fn test_automorphism( ct_in.encrypt_sk( &module, &pt_scalar, - &sk, + &sk_dft, &mut source_xa, &mut source_xe, sigma, @@ -530,7 +536,7 @@ fn test_automorphism( module.scalar_znx_automorphism_inplace(p, &mut pt_scalar, 0); - let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_out, rank); + let mut ct_glwe_fourier: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_out, rank); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_out.size()); let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_out.size()); @@ -541,14 +547,14 @@ fn test_automorphism( // mul with sk[col_j-1] if col_j > 0 { module.vec_znx_dft(1, 0, &mut pt_dft, 0, &pt_want.data, 0); - module.svp_apply_inplace(&mut pt_dft, 0, &sk.data_fourier, col_j - 1); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft.data, col_j - 1); module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); module.vec_znx_big_normalize(basek, &mut pt_want.data, 0, &pt_big, 0, scratch.borrow()); } ct_out.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); - ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk, scratch.borrow()); + ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); @@ -596,8 +602,8 @@ fn test_automorphism_inplace( let digits_in: usize = 1; let mut ct: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ct, rows_in, digits_in, rank); - let mut tensor_key: TensorKey, FFT64> = TensorKey::alloc(&module, basek, k_tsk, rows, digits, rank); - let mut auto_key: AutomorphismKey, FFT64> = AutomorphismKey::alloc(&module, basek, k_ksk, rows, digits, rank); + let mut tensor_key: GLWETensorKey, FFT64> = GLWETensorKey::alloc(&module, basek, k_tsk, rows, digits, rank); + let mut auto_key: GLWEAutomorphismKey, FFT64> = GLWEAutomorphismKey::alloc(&module, basek, k_ksk, rows, digits, rank); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); @@ -608,18 +614,19 @@ fn test_automorphism_inplace( let mut scratch: ScratchOwned = ScratchOwned::new( GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_ct, rank) - | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_ct) - | AutomorphismKey::generate_from_sk_scratch_space(&module, basek, k_ksk, rank) - | TensorKey::generate_from_sk_scratch_space(&module, basek, k_tsk, rank) + | FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_ct) + | GLWEAutomorphismKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank) + | GLWETensorKey::encrypt_sk_scratch_space(&module, basek, k_tsk, rank) | GGSWCiphertext::automorphism_inplace_scratch_space(&module, basek, k_ct, k_ksk, digits, k_tsk, digits, rank), ); let var_xs: f64 = 0.5; - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, var_xs, &mut source_xs); + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(var_xs, &mut source_xs); + let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); - auto_key.generate_from_sk( + auto_key.encrypt_sk( &module, p, &sk, @@ -628,9 +635,9 @@ fn test_automorphism_inplace( sigma, scratch.borrow(), ); - tensor_key.generate_from_sk( + tensor_key.encrypt_sk( &module, - &sk, + &sk_dft, &mut source_xa, &mut source_xe, sigma, @@ -642,7 +649,7 @@ fn test_automorphism_inplace( ct.encrypt_sk( &module, &pt_scalar, - &sk, + &sk_dft, &mut source_xa, &mut source_xe, sigma, @@ -653,7 +660,7 @@ fn test_automorphism_inplace( module.scalar_znx_automorphism_inplace(p, &mut pt_scalar, 0); - let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_ct, rank); + let mut ct_glwe_fourier: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_ct, rank); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct.size()); let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct.size()); @@ -664,14 +671,14 @@ fn test_automorphism_inplace( // mul with sk[col_j-1] if col_j > 0 { module.vec_znx_dft(1, 0, &mut pt_dft, 0, &pt_want.data, 0); - module.svp_apply_inplace(&mut pt_dft, 0, &sk.data_fourier, col_j - 1); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft.data, col_j - 1); module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); module.vec_znx_big_normalize(basek, &mut pt_want.data, 0, &pt_big, 0, scratch.borrow()); } ct.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); - ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk, scratch.borrow()); + ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); @@ -737,18 +744,19 @@ fn test_external_product( pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} let mut scratch: ScratchOwned = ScratchOwned::new( - GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_out) + FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_out) | GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_ggsw, rank) | GGSWCiphertext::external_product_scratch_space(&module, basek, k_out, k_in, k_ggsw, digits, rank), ); - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); ct_ggsw_rhs.encrypt_sk( &module, &pt_ggsw_rhs, - &sk, + &sk_dft, &mut source_xa, &mut source_xe, sigma, @@ -758,7 +766,7 @@ fn test_external_product( ct_ggsw_lhs_in.encrypt_sk( &module, &pt_ggsw_lhs, - &sk, + &sk_dft, &mut source_xa, &mut source_xe, sigma, @@ -767,7 +775,7 @@ fn test_external_product( ct_ggsw_lhs_out.external_product(&module, &ct_ggsw_lhs_in, &ct_ggsw_rhs, scratch.borrow()); - let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_out, rank); + let mut ct_glwe_fourier: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_out, rank); let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_ggsw_lhs_out.size()); let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_ggsw_lhs_out.size()); @@ -787,13 +795,13 @@ fn test_external_product( if col_j > 0 { module.vec_znx_dft(1, 0, &mut pt_dft, 0, &pt_want.data, 0); - module.svp_apply_inplace(&mut pt_dft, 0, &sk.data_fourier, col_j - 1); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft.data, col_j - 1); module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); module.vec_znx_big_normalize(basek, &mut pt_want.data, 0, &pt_big, 0, scratch.borrow()); } ct_ggsw_lhs_out.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); - ct_glwe_fourier.decrypt(&module, &mut pt, &sk, scratch.borrow()); + ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); module.vec_znx_sub_ab_inplace(&mut pt.data, 0, &pt_want.data, 0); @@ -857,18 +865,19 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ct: usize, k_ggsw pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} let mut scratch: ScratchOwned = ScratchOwned::new( - GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_ct) + FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_ct) | GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_ggsw, rank) | GGSWCiphertext::external_product_inplace_scratch_space(&module, basek, k_ct, k_ggsw, digits, rank), ); - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); ct_ggsw_rhs.encrypt_sk( &module, &pt_ggsw_rhs, - &sk, + &sk_dft, &mut source_xa, &mut source_xe, sigma, @@ -878,7 +887,7 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ct: usize, k_ggsw ct_ggsw_lhs.encrypt_sk( &module, &pt_ggsw_lhs, - &sk, + &sk_dft, &mut source_xa, &mut source_xe, sigma, @@ -887,7 +896,7 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ct: usize, k_ggsw ct_ggsw_lhs.external_product_inplace(&module, &ct_ggsw_rhs, scratch.borrow()); - let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_ct, rank); + let mut ct_glwe_fourier: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_ct, rank); let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_ggsw_lhs.size()); let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_ggsw_lhs.size()); @@ -907,13 +916,13 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ct: usize, k_ggsw if col_j > 0 { module.vec_znx_dft(1, 0, &mut pt_dft, 0, &pt_want.data, 0); - module.svp_apply_inplace(&mut pt_dft, 0, &sk.data_fourier, col_j - 1); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft.data, col_j - 1); module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); module.vec_znx_big_normalize(basek, &mut pt_want.data, 0, &pt_big, 0, scratch.borrow()); } ct_ggsw_lhs.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); - ct_glwe_fourier.decrypt(&module, &mut pt, &sk, scratch.borrow()); + ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); module.vec_znx_sub_ab_inplace(&mut pt.data, 0, &pt_want.data, 0); diff --git a/core/src/ggsw/test_fft64/mod.rs b/core/src/ggsw/test_fft64/mod.rs new file mode 100644 index 0000000..3326f10 --- /dev/null +++ b/core/src/ggsw/test_fft64/mod.rs @@ -0,0 +1 @@ +mod ggsw; diff --git a/core/src/glwe/automorphism.rs b/core/src/glwe/automorphism.rs new file mode 100644 index 0000000..1513362 --- /dev/null +++ b/core/src/glwe/automorphism.rs @@ -0,0 +1,121 @@ +use backend::{FFT64, Module, Scratch, VecZnxOps}; + +use crate::{GLWEAutomorphismKey, GLWECiphertext}; + +impl GLWECiphertext> { + pub fn automorphism_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_in: usize, + k_ksk: usize, + digits: usize, + rank: usize, + ) -> usize { + Self::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank, rank) + } + + pub fn automorphism_inplace_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_ksk: usize, + digits: usize, + rank: usize, + ) -> usize { + Self::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank) + } +} + +impl + AsMut<[u8]>> GLWECiphertext { + pub fn automorphism, DataRhs: AsRef<[u8]>>( + &mut self, + module: &Module, + lhs: &GLWECiphertext, + rhs: &GLWEAutomorphismKey, + scratch: &mut Scratch, + ) { + self.keyswitch(module, lhs, &rhs.key, scratch); + (0..self.rank() + 1).for_each(|i| { + module.vec_znx_automorphism_inplace(rhs.p(), &mut self.data, i); + }) + } + + pub fn automorphism_inplace>( + &mut self, + module: &Module, + rhs: &GLWEAutomorphismKey, + scratch: &mut Scratch, + ) { + self.keyswitch_inplace(module, &rhs.key, scratch); + (0..self.rank() + 1).for_each(|i| { + module.vec_znx_automorphism_inplace(rhs.p(), &mut self.data, i); + }) + } + + pub fn automorphism_add, DataRhs: AsRef<[u8]>>( + &mut self, + module: &Module, + lhs: &GLWECiphertext, + rhs: &GLWEAutomorphismKey, + scratch: &mut Scratch, + ) { + Self::keyswitch_private::<_, _, 1>(self, rhs.p(), module, lhs, &rhs.key, scratch); + } + + pub fn automorphism_add_inplace>( + &mut self, + module: &Module, + rhs: &GLWEAutomorphismKey, + scratch: &mut Scratch, + ) { + unsafe { + let self_ptr: *mut GLWECiphertext = self as *mut GLWECiphertext; + Self::keyswitch_private::<_, _, 1>(self, rhs.p(), module, &*self_ptr, &rhs.key, scratch); + } + } + + pub fn automorphism_sub_ab, DataRhs: AsRef<[u8]>>( + &mut self, + module: &Module, + lhs: &GLWECiphertext, + rhs: &GLWEAutomorphismKey, + scratch: &mut Scratch, + ) { + Self::keyswitch_private::<_, _, 2>(self, rhs.p(), module, lhs, &rhs.key, scratch); + } + + pub fn automorphism_sub_ab_inplace>( + &mut self, + module: &Module, + rhs: &GLWEAutomorphismKey, + scratch: &mut Scratch, + ) { + unsafe { + let self_ptr: *mut GLWECiphertext = self as *mut GLWECiphertext; + Self::keyswitch_private::<_, _, 2>(self, rhs.p(), module, &*self_ptr, &rhs.key, scratch); + } + } + + pub fn automorphism_sub_ba, DataRhs: AsRef<[u8]>>( + &mut self, + module: &Module, + lhs: &GLWECiphertext, + rhs: &GLWEAutomorphismKey, + scratch: &mut Scratch, + ) { + Self::keyswitch_private::<_, _, 3>(self, rhs.p(), module, lhs, &rhs.key, scratch); + } + + pub fn automorphism_sub_ba_inplace>( + &mut self, + module: &Module, + rhs: &GLWEAutomorphismKey, + scratch: &mut Scratch, + ) { + unsafe { + let self_ptr: *mut GLWECiphertext = self as *mut GLWECiphertext; + Self::keyswitch_private::<_, _, 3>(self, rhs.p(), module, &*self_ptr, &rhs.key, scratch); + } + } +} diff --git a/core/src/glwe/ciphertext.rs b/core/src/glwe/ciphertext.rs new file mode 100644 index 0000000..d0fb39c --- /dev/null +++ b/core/src/glwe/ciphertext.rs @@ -0,0 +1,115 @@ +use backend::{Backend, FFT64, Module, VecZnx, VecZnxAlloc, VecZnxDftOps, VecZnxToMut, VecZnxToRef}; + +use crate::{FourierGLWECiphertext, GLWEOps, Infos, SetMetaData}; + +pub struct GLWECiphertext { + pub data: VecZnx, + pub basek: usize, + pub k: usize, +} + +impl GLWECiphertext> { + pub fn alloc(module: &Module, basek: usize, k: usize, rank: usize) -> Self { + Self { + data: module.new_vec_znx(rank + 1, k.div_ceil(basek)), + basek, + k, + } + } + + pub fn bytes_of(module: &Module, basek: usize, k: usize, rank: usize) -> usize { + module.bytes_of_vec_znx(rank + 1, k.div_ceil(basek)) + } +} + +impl Infos for GLWECiphertext { + type Inner = VecZnx; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn basek(&self) -> usize { + self.basek + } + + fn k(&self) -> usize { + self.k + } +} + +impl GLWECiphertext { + pub fn rank(&self) -> usize { + self.cols() - 1 + } +} + +impl> GLWECiphertext { + #[allow(dead_code)] + pub(crate) fn dft + AsRef<[u8]>>(&self, module: &Module, res: &mut FourierGLWECiphertext) { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), res.rank()); + assert_eq!(self.basek(), res.basek()) + } + + (0..self.rank() + 1).for_each(|i| { + module.vec_znx_dft(1, 0, &mut res.data, i, &self.data, i); + }) + } +} + +impl> GLWECiphertext { + pub fn clone(&self) -> GLWECiphertext> { + GLWECiphertext { + data: self.data.clone(), + basek: self.basek(), + k: self.k(), + } + } +} + +impl + AsRef<[u8]>> SetMetaData for GLWECiphertext { + fn set_k(&mut self, k: usize) { + self.k = k + } + + fn set_basek(&mut self, basek: usize) { + self.basek = basek + } +} + +pub trait GLWECiphertextToRef { + fn to_ref(&self) -> GLWECiphertext<&[u8]>; +} + +impl> GLWECiphertextToRef for GLWECiphertext { + fn to_ref(&self) -> GLWECiphertext<&[u8]> { + GLWECiphertext { + data: self.data.to_ref(), + basek: self.basek, + k: self.k, + } + } +} + +pub trait GLWECiphertextToMut { + fn to_mut(&mut self) -> GLWECiphertext<&mut [u8]>; +} + +impl + AsRef<[u8]>> GLWECiphertextToMut for GLWECiphertext { + fn to_mut(&mut self) -> GLWECiphertext<&mut [u8]> { + GLWECiphertext { + data: self.data.to_mut(), + basek: self.basek, + k: self.k, + } + } +} + +impl GLWEOps for GLWECiphertext +where + D: AsRef<[u8]> + AsMut<[u8]>, + GLWECiphertext: GLWECiphertextToMut + Infos + SetMetaData, +{ +} diff --git a/core/src/glwe/decryption.rs b/core/src/glwe/decryption.rs new file mode 100644 index 0000000..e543963 --- /dev/null +++ b/core/src/glwe/decryption.rs @@ -0,0 +1,58 @@ +use backend::{ + FFT64, Module, ScalarZnxDftOps, Scratch, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDftAlloc, VecZnxDftOps, + ZnxZero, +}; + +use crate::{FourierGLWESecret, GLWECiphertext, GLWEPlaintext, Infos}; + +impl GLWECiphertext> { + pub fn decrypt_scratch_space(module: &Module, basek: usize, k: usize) -> usize { + let size: usize = k.div_ceil(basek); + (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) + } +} + +impl> GLWECiphertext { + pub fn decrypt + AsRef<[u8]>, DataSk: AsRef<[u8]>>( + &self, + module: &Module, + pt: &mut GLWEPlaintext, + sk: &FourierGLWESecret, + scratch: &mut Scratch, + ) { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), sk.rank()); + assert_eq!(self.n(), module.n()); + assert_eq!(pt.n(), module.n()); + assert_eq!(sk.n(), module.n()); + } + + let cols: usize = self.rank() + 1; + + let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, self.size()); // TODO optimize size when pt << ct + c0_big.zero(); + + { + (1..cols).for_each(|i| { + // ci_dft = DFT(a[i]) * DFT(s[i]) + let (mut ci_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, self.size()); // TODO optimize size when pt << ct + module.vec_znx_dft(1, 0, &mut ci_dft, 0, &self.data, i); + module.svp_apply_inplace(&mut ci_dft, 0, &sk.data, i - 1); + let ci_big = module.vec_znx_idft_consume(ci_dft); + + // c0_big += a[i] * s[i] + module.vec_znx_big_add_inplace(&mut c0_big, 0, &ci_big, 0); + }); + } + + // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e) + module.vec_znx_big_add_small_inplace(&mut c0_big, 0, &self.data, 0); + + // pt = norm(BIG(m + e)) + module.vec_znx_big_normalize(self.basek(), &mut pt.data, 0, &mut c0_big, 0, scratch_1); + + pt.basek = self.basek(); + pt.k = pt.k().min(self.k()); + } +} diff --git a/core/src/glwe/encryption.rs b/core/src/glwe/encryption.rs new file mode 100644 index 0000000..b0a7615 --- /dev/null +++ b/core/src/glwe/encryption.rs @@ -0,0 +1,254 @@ +use backend::{ + AddNormal, FFT64, FillUniform, Module, ScalarZnxAlloc, ScalarZnxDftAlloc, ScalarZnxDftOps, Scratch, VecZnxAlloc, VecZnxBig, + VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxZero, +}; +use sampling::source::Source; + +use crate::{FourierGLWESecret, GLWECiphertext, GLWEPlaintext, GLWEPublicKey, Infos, SIX_SIGMA, dist::Distribution}; + +impl GLWECiphertext> { + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize) -> usize { + let size: usize = k.div_ceil(basek); + module.vec_znx_big_normalize_tmp_bytes() + module.bytes_of_vec_znx_dft(1, size) + module.bytes_of_vec_znx(1, size) + } + pub fn encrypt_pk_scratch_space(module: &Module, basek: usize, k: usize) -> usize { + let size: usize = k.div_ceil(basek); + ((module.bytes_of_vec_znx_dft(1, size) + module.bytes_of_vec_znx_big(1, size)) | module.bytes_of_scalar_znx(1)) + + module.bytes_of_scalar_znx_dft(1) + + module.vec_znx_big_normalize_tmp_bytes() + } +} + +impl + AsMut<[u8]>> GLWECiphertext { + pub fn encrypt_sk, DataSk: AsRef<[u8]>>( + &mut self, + module: &Module, + pt: &GLWEPlaintext, + sk: &FourierGLWESecret, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) { + self.encrypt_sk_private( + module, + Some((pt, 0)), + sk, + source_xa, + source_xe, + sigma, + scratch, + ); + } + + pub fn encrypt_zero_sk>( + &mut self, + module: &Module, + sk: &FourierGLWESecret, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) { + self.encrypt_sk_private( + module, + None::<(&GLWEPlaintext>, usize)>, + sk, + source_xa, + source_xe, + sigma, + scratch, + ); + } + + pub fn encrypt_pk, DataPk: AsRef<[u8]>>( + &mut self, + module: &Module, + pt: &GLWEPlaintext, + pk: &GLWEPublicKey, + source_xu: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) { + self.encrypt_pk_private::( + module, + Some((pt, 0)), + pk, + source_xu, + source_xe, + sigma, + scratch, + ); + } + + pub fn encrypt_zero_pk>( + &mut self, + module: &Module, + pk: &GLWEPublicKey, + source_xu: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) { + self.encrypt_pk_private::, DataPk>( + module, + None::<(&GLWEPlaintext>, usize)>, + pk, + source_xu, + source_xe, + sigma, + scratch, + ); + } + + pub(crate) fn encrypt_sk_private, DataSk: AsRef<[u8]>>( + &mut self, + module: &Module, + pt: Option<(&GLWEPlaintext, usize)>, + sk: &FourierGLWESecret, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), sk.rank()); + assert_eq!(sk.n(), module.n()); + assert_eq!(self.n(), module.n()); + if let Some((pt, col)) = pt { + assert_eq!(pt.n(), module.n()); + assert!(col < self.rank() + 1); + } + assert!( + scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k()), + "scratch.available(): {} < GLWECiphertext::encrypt_sk_scratch_space: {}", + scratch.available(), + GLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k()) + ) + } + + let basek: usize = self.basek(); + let k: usize = self.k(); + let size: usize = self.size(); + let cols: usize = self.rank() + 1; + + let (mut c0_big, scratch_1) = scratch.tmp_vec_znx(module, 1, size); + c0_big.zero(); + + { + // c[i] = uniform + // c[0] -= c[i] * s[i], + (1..cols).for_each(|i| { + let (mut ci_dft, scratch_2) = scratch_1.tmp_vec_znx_dft(module, 1, size); + + // c[i] = uniform + self.data.fill_uniform(basek, i, size, source_xa); + + // c[i] = norm(IDFT(DFT(c[i]) * DFT(s[i]))) + module.vec_znx_dft(1, 0, &mut ci_dft, 0, &self.data, i); + module.svp_apply_inplace(&mut ci_dft, 0, &sk.data, i - 1); + let ci_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(ci_dft); + + // use c[0] as buffer, which is overwritten later by the normalization step + module.vec_znx_big_normalize(basek, &mut self.data, 0, &ci_big, 0, scratch_2); + + // c0_tmp = -c[i] * s[i] (use c[0] as buffer) + module.vec_znx_sub_ab_inplace(&mut c0_big, 0, &self.data, 0); + + // c[i] += m if col = i + if let Some((pt, col)) = pt { + if i == col { + module.vec_znx_add_inplace(&mut self.data, i, &pt.data, 0); + module.vec_znx_normalize_inplace(basek, &mut self.data, i, scratch_2); + } + } + }); + } + + // c[0] += e + c0_big.add_normal(basek, 0, k, source_xe, sigma, sigma * SIX_SIGMA); + + // c[0] += m if col = 0 + if let Some((pt, col)) = pt { + if col == 0 { + module.vec_znx_add_inplace(&mut c0_big, 0, &pt.data, 0); + } + } + + // c[0] = norm(c[0]) + module.vec_znx_normalize(basek, &mut self.data, 0, &c0_big, 0, scratch_1); + } + + pub(crate) fn encrypt_pk_private, DataPk: AsRef<[u8]>>( + &mut self, + module: &Module, + pt: Option<(&GLWEPlaintext, usize)>, + pk: &GLWEPublicKey, + source_xu: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) { + #[cfg(debug_assertions)] + { + assert_eq!(self.basek(), pk.basek()); + assert_eq!(self.n(), module.n()); + assert_eq!(pk.n(), module.n()); + assert_eq!(self.rank(), pk.rank()); + if let Some((pt, _)) = pt { + assert_eq!(pt.basek(), pk.basek()); + assert_eq!(pt.n(), module.n()); + } + } + + let basek: usize = pk.basek(); + let size_pk: usize = pk.size(); + let cols: usize = self.rank() + 1; + + // Generates u according to the underlying secret distribution. + let (mut u_dft, scratch_1) = scratch.tmp_scalar_znx_dft(module, 1); + + { + let (mut u, _) = scratch_1.tmp_scalar_znx(module, 1); + match pk.dist { + Distribution::NONE => panic!( + "invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through \ + Self::generate" + ), + Distribution::TernaryFixed(hw) => u.fill_ternary_hw(0, hw, source_xu), + Distribution::TernaryProb(prob) => u.fill_ternary_prob(0, prob, source_xu), + Distribution::BinaryFixed(hw) => u.fill_binary_hw(0, hw, source_xu), + Distribution::BinaryProb(prob) => u.fill_binary_prob(0, prob, source_xu), + Distribution::BinaryBlock(block_size) => u.fill_binary_block(0, block_size, source_xu), + Distribution::ZERO => {} + } + + module.svp_prepare(&mut u_dft, 0, &u, 0); + } + + // ct[i] = pk[i] * u + ei (+ m if col = i) + (0..cols).for_each(|i| { + let (mut ci_dft, scratch_2) = scratch_1.tmp_vec_znx_dft(module, 1, size_pk); + // ci_dft = DFT(u) * DFT(pk[i]) + module.svp_apply(&mut ci_dft, 0, &u_dft, 0, &pk.data.data, i); + + // ci_big = u * p[i] + let mut ci_big = module.vec_znx_idft_consume(ci_dft); + + // ci_big = u * pk[i] + e + ci_big.add_normal(basek, 0, pk.k(), source_xe, sigma, sigma * SIX_SIGMA); + + // ci_big = u * pk[i] + e + m (if col = i) + if let Some((pt, col)) = pt { + if col == i { + module.vec_znx_big_add_small_inplace(&mut ci_big, 0, &pt.data, 0); + } + } + + // ct[i] = norm(ci_big) + module.vec_znx_big_normalize(basek, &mut self.data, i, &ci_big, 0, scratch_2); + }); + } +} diff --git a/core/src/glwe/external_product.rs b/core/src/glwe/external_product.rs new file mode 100644 index 0000000..e7ee778 --- /dev/null +++ b/core/src/glwe/external_product.rs @@ -0,0 +1,129 @@ +use backend::{ + FFT64, MatZnxDftOps, MatZnxDftScratch, Module, Scratch, VecZnxBig, VecZnxBigOps, VecZnxDftAlloc, VecZnxDftOps, VecZnxScratch, +}; + +use crate::{FourierGLWECiphertext, GGSWCiphertext, GLWECiphertext, Infos}; + +impl GLWECiphertext> { + pub fn external_product_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_in: usize, + k_ggsw: usize, + digits: usize, + rank: usize, + ) -> usize { + let res_dft: usize = FourierGLWECiphertext::bytes_of(module, basek, k_ggsw, rank); + let in_size: usize = k_in.div_ceil(basek).div_ceil(digits); + let out_size: usize = k_out.div_ceil(basek); + let ggsw_size: usize = k_ggsw.div_ceil(basek); + let vmp: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size) + + module.vmp_apply_tmp_bytes( + out_size, + in_size, + in_size, // rows + rank + 1, // cols in + rank + 1, // cols out + ggsw_size, + ); + let normalize: usize = module.vec_znx_normalize_tmp_bytes(); + res_dft + (vmp | normalize) + } + + pub fn external_product_inplace_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_ggsw: usize, + digits: usize, + rank: usize, + ) -> usize { + Self::external_product_scratch_space(module, basek, k_out, k_out, k_ggsw, digits, rank) + } +} + +impl + AsMut<[u8]>> GLWECiphertext { + pub fn external_product, DataRhs: AsRef<[u8]>>( + &mut self, + module: &Module, + lhs: &GLWECiphertext, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) { + let basek: usize = self.basek(); + + #[cfg(debug_assertions)] + { + assert_eq!(rhs.rank(), lhs.rank()); + assert_eq!(rhs.rank(), self.rank()); + assert_eq!(self.basek(), basek); + assert_eq!(lhs.basek(), basek); + assert_eq!(rhs.n(), module.n()); + assert_eq!(self.n(), module.n()); + assert_eq!(lhs.n(), module.n()); + assert!( + scratch.available() + >= GLWECiphertext::external_product_scratch_space( + module, + self.basek(), + self.k(), + lhs.k(), + rhs.k(), + rhs.digits(), + rhs.rank(), + ) + ); + } + + let cols: usize = rhs.rank() + 1; + let digits: usize = rhs.digits(); + + let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); // Todo optimise + let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, (lhs.size() + digits - 1) / digits); + + { + (0..digits).for_each(|di| { + // (lhs.size() + di) / digits = (a - (digit - di - 1) + digit - 1) / digits + a_dft.set_size((lhs.size() + di) / digits); + + // Small optimization for digits > 2 + // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then + // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}. + // As such we can ignore the last digits-2 limbs safely of the sum of vmp products. + // It is possible to further ignore the last digits-1 limbs, but this introduce + // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same + // noise is kept with respect to the ideal functionality. + res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize); + + (0..cols).for_each(|col_i| { + module.vec_znx_dft(digits, digits - 1 - di, &mut a_dft, col_i, &lhs.data, col_i); + }); + + if di == 0 { + module.vmp_apply(&mut res_dft, &a_dft, &rhs.data, scratch2); + } else { + module.vmp_apply_add(&mut res_dft, &a_dft, &rhs.data, di, scratch2); + } + }); + } + + let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft); + + (0..cols).for_each(|i| { + module.vec_znx_big_normalize(basek, &mut self.data, i, &res_big, i, scratch1); + }); + } + + pub fn external_product_inplace>( + &mut self, + module: &Module, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) { + unsafe { + let self_ptr: *mut GLWECiphertext = self as *mut GLWECiphertext; + self.external_product(&module, &*self_ptr, rhs, scratch); + } + } +} diff --git a/core/src/glwe/keyswitch.rs b/core/src/glwe/keyswitch.rs new file mode 100644 index 0000000..ca6dcad --- /dev/null +++ b/core/src/glwe/keyswitch.rs @@ -0,0 +1,256 @@ +use backend::{ + FFT64, MatZnxDftOps, MatZnxDftScratch, Module, Scratch, VecZnxBig, VecZnxBigOps, VecZnxBigScratch, VecZnxDftAlloc, + VecZnxDftOps, ZnxZero, +}; + +use crate::{FourierGLWECiphertext, GLWECiphertext, GLWESwitchingKey, Infos}; + +impl GLWECiphertext> { + pub fn keyswitch_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_in: usize, + k_ksk: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + ) -> usize { + let res_dft: usize = FourierGLWECiphertext::bytes_of(module, basek, k_out, rank_out + 1); + let in_size: usize = k_in.div_ceil(basek).div_ceil(digits); + let out_size: usize = k_out.div_ceil(basek); + let ksk_size: usize = k_ksk.div_ceil(basek); + let ai_dft: usize = module.bytes_of_vec_znx_dft(rank_in, in_size); + let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, rank_in, rank_out + 1, ksk_size) + + module.bytes_of_vec_znx_dft(rank_in, in_size); + let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(); + return res_dft + ((ai_dft + vmp) | normalize); + } + + pub fn keyswitch_from_fourier_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_in: usize, + k_ksk: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + ) -> usize { + Self::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank_in, rank_out) + } + + pub fn keyswitch_inplace_scratch_space( + module: &Module, + basek: usize, + k_out: usize, + k_ksk: usize, + digits: usize, + rank: usize, + ) -> usize { + Self::keyswitch_scratch_space(module, basek, k_out, k_out, k_ksk, digits, rank, rank) + } +} + +impl + AsMut<[u8]>> GLWECiphertext { + pub fn keyswitch, DataRhs: AsRef<[u8]>>( + &mut self, + module: &Module, + lhs: &GLWECiphertext, + rhs: &GLWESwitchingKey, + scratch: &mut Scratch, + ) { + Self::keyswitch_private::<_, _, 0>(self, 0, module, lhs, rhs, scratch); + } + + pub fn keyswitch_inplace>( + &mut self, + module: &Module, + rhs: &GLWESwitchingKey, + scratch: &mut Scratch, + ) { + unsafe { + let self_ptr: *mut GLWECiphertext = self as *mut GLWECiphertext; + self.keyswitch(&module, &*self_ptr, rhs, scratch); + } + } + + pub(crate) fn keyswitch_private, DataRhs: AsRef<[u8]>, const OP: u8>( + &mut self, + apply_auto: i64, + module: &Module, + lhs: &GLWECiphertext, + rhs: &GLWESwitchingKey, + scratch: &mut Scratch, + ) { + let basek: usize = self.basek(); + + #[cfg(debug_assertions)] + { + assert_eq!( + lhs.rank(), + rhs.rank_in(), + "lhs.rank(): {} != rhs.rank_in(): {}", + lhs.rank(), + rhs.rank_in() + ); + assert_eq!( + self.rank(), + rhs.rank_out(), + "self.rank(): {} != rhs.rank_out(): {}", + self.rank(), + rhs.rank_out() + ); + assert_eq!(self.basek(), basek); + assert_eq!(lhs.basek(), basek); + assert_eq!(rhs.n(), module.n()); + assert_eq!(self.n(), module.n()); + assert_eq!(lhs.n(), module.n()); + assert!( + scratch.available() + >= GLWECiphertext::keyswitch_scratch_space( + module, + self.basek(), + self.k(), + lhs.k(), + rhs.k(), + rhs.digits(), + rhs.rank_in(), + rhs.rank_out(), + ) + ); + } + + let cols_in: usize = rhs.rank_in(); + let cols_out: usize = rhs.rank_out() + 1; + let digits: usize = rhs.digits(); + + let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise + let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, (lhs.size() + digits - 1) / digits); + ai_dft.zero(); + { + (0..digits).for_each(|di| { + ai_dft.set_size((lhs.size() + di) / digits); + + // Small optimization for digits > 2 + // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then + // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}. + // As such we can ignore the last digits-2 limbs safely of the sum of vmp products. + // It is possible to further ignore the last digits-1 limbs, but this introduce + // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same + // noise is kept with respect to the ideal functionality. + res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize); + + (0..cols_in).for_each(|col_i| { + module.vec_znx_dft( + digits, + digits - di - 1, + &mut ai_dft, + col_i, + &lhs.data, + col_i + 1, + ); + }); + + if di == 0 { + module.vmp_apply(&mut res_dft, &ai_dft, &rhs.key.data, scratch2); + } else { + module.vmp_apply_add(&mut res_dft, &ai_dft, &rhs.key.data, di, scratch2); + } + }); + } + + let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft); + + module.vec_znx_big_add_small_inplace(&mut res_big, 0, &lhs.data, 0); + + (0..cols_out).for_each(|i| { + if apply_auto != 0 { + module.vec_znx_big_automorphism_inplace(apply_auto, &mut res_big, i); + } + + match OP { + 1 => module.vec_znx_big_add_small_inplace(&mut res_big, i, &lhs.data, i), + 2 => module.vec_znx_big_sub_small_a_inplace(&mut res_big, i, &lhs.data, i), + 3 => module.vec_znx_big_sub_small_b_inplace(&mut res_big, i, &lhs.data, i), + _ => {} + } + module.vec_znx_big_normalize(basek, &mut self.data, i, &res_big, i, scratch1); + }); + } + + pub(crate) fn keyswitch_from_fourier, DataRhs: AsRef<[u8]>>( + &mut self, + module: &Module, + lhs: &FourierGLWECiphertext, + rhs: &GLWESwitchingKey, + scratch: &mut Scratch, + ) { + let basek: usize = self.basek(); + + #[cfg(debug_assertions)] + { + assert_eq!(lhs.rank(), rhs.rank_in()); + assert_eq!(self.rank(), rhs.rank_out()); + assert_eq!(self.basek(), basek); + assert_eq!(lhs.basek(), basek); + assert_eq!(rhs.n(), module.n()); + assert_eq!(self.n(), module.n()); + assert_eq!(lhs.n(), module.n()); + assert!( + scratch.available() + >= GLWECiphertext::keyswitch_from_fourier_scratch_space( + module, + self.basek(), + self.k(), + lhs.k(), + rhs.k(), + rhs.digits(), + rhs.rank_in(), + rhs.rank_out(), + ) + ); + } + + let cols_in: usize = rhs.rank_in(); + let cols_out: usize = rhs.rank_out() + 1; + + // Buffer of the result of VMP in DFT + let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise + + { + let digits = rhs.digits(); + + (0..digits).for_each(|di| { + // (lhs.size() + di) / digits = (a - (digit - di - 1) + digit - 1) / digits + let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, (lhs.size() + di) / digits); + + (0..cols_in).for_each(|col_i| { + module.vec_znx_dft_copy( + digits, + digits - 1 - di, + &mut ai_dft, + col_i, + &lhs.data, + col_i + 1, + ); + }); + + if di == 0 { + module.vmp_apply(&mut res_dft, &ai_dft, &rhs.key.data, scratch2); + } else { + module.vmp_apply_add(&mut res_dft, &ai_dft, &rhs.key.data, di, scratch2); + } + }); + } + + module.vec_znx_dft_add_inplace(&mut res_dft, 0, &lhs.data, 0); + + // Switches result of VMP outside of DFT + let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft); + + (0..cols_out).for_each(|i| { + module.vec_znx_big_normalize(basek, &mut self.data, i, &res_big, i, scratch1); + }); + } +} diff --git a/core/src/glwe/mod.rs b/core/src/glwe/mod.rs new file mode 100644 index 0000000..e3879cd --- /dev/null +++ b/core/src/glwe/mod.rs @@ -0,0 +1,23 @@ +pub mod automorphism; +pub mod ciphertext; +pub mod decryption; +pub mod encryption; +pub mod external_product; +pub mod keyswitch; +pub mod ops; +pub mod packing; +pub mod plaintext; +pub mod public_key; +pub mod secret; +pub mod trace; + +pub use ciphertext::GLWECiphertext; +pub(crate) use ciphertext::{GLWECiphertextToMut, GLWECiphertextToRef}; +pub use ops::GLWEOps; +pub use packing::GLWEPacker; +pub use plaintext::GLWEPlaintext; +pub use public_key::GLWEPublicKey; +pub use secret::GLWESecret; + +#[cfg(test)] +mod test_fft64; diff --git a/core/src/glwe_ops.rs b/core/src/glwe/ops.rs similarity index 100% rename from core/src/glwe_ops.rs rename to core/src/glwe/ops.rs diff --git a/core/src/glwe_packing.rs b/core/src/glwe/packing.rs similarity index 96% rename from core/src/glwe_packing.rs rename to core/src/glwe/packing.rs index 85aceb6..3496994 100644 --- a/core/src/glwe_packing.rs +++ b/core/src/glwe/packing.rs @@ -1,4 +1,4 @@ -use crate::{AutomorphismKey, GLWECiphertext, GLWEOps, Infos, ScratchCore}; +use crate::{GLWEAutomorphismKey, GLWECiphertext, GLWEOps, Infos, ScratchCore}; use std::collections::HashMap; use backend::{FFT64, Module, Scratch}; @@ -7,7 +7,7 @@ use backend::{FFT64, Module, Scratch}; /// with constant memory of Log(N) ciphertexts. /// Main difference with usual GLWE packing is that /// the output is bit-reversed. -pub struct StreamPacker { +pub struct GLWEPacker { accumulators: Vec, log_batch: usize, counter: usize, @@ -39,7 +39,7 @@ impl Accumulator { } } -impl StreamPacker { +impl GLWEPacker { /// Instantiates a new [StreamPacker]. /// /// #Arguments @@ -98,7 +98,7 @@ impl StreamPacker { module: &Module, res: &mut Vec>>, a: Option<&GLWECiphertext>, - auto_keys: &HashMap>, + auto_keys: &HashMap>, scratch: &mut Scratch, ) { pack_core( @@ -125,7 +125,7 @@ impl StreamPacker { &mut self, module: &Module, res: &mut Vec>>, - auto_keys: &HashMap>, + auto_keys: &HashMap>, scratch: &mut Scratch, ) { if self.counter != 0 { @@ -151,7 +151,7 @@ fn pack_core, DataAK: AsRef<[u8]>>( a: Option<&GLWECiphertext>, accumulators: &mut [Accumulator], i: usize, - auto_keys: &HashMap>, + auto_keys: &HashMap>, scratch: &mut Scratch, ) { let log_n: usize = module.log_n(); @@ -215,7 +215,7 @@ fn combine, DataAK: AsRef<[u8]>>( acc: &mut Accumulator, b: Option<&GLWECiphertext>, i: usize, - auto_keys: &HashMap>, + auto_keys: &HashMap>, scratch: &mut Scratch, ) { let log_n: usize = module.log_n(); diff --git a/core/src/glwe_plaintext.rs b/core/src/glwe/plaintext.rs similarity index 100% rename from core/src/glwe_plaintext.rs rename to core/src/glwe/plaintext.rs diff --git a/core/src/glwe/public_key.rs b/core/src/glwe/public_key.rs new file mode 100644 index 0000000..f4871ad --- /dev/null +++ b/core/src/glwe/public_key.rs @@ -0,0 +1,75 @@ +use backend::{Backend, FFT64, Module, ScratchOwned, VecZnxDft}; +use sampling::source::Source; + +use crate::{FourierGLWECiphertext, FourierGLWESecret, Infos, dist::Distribution}; + +pub struct GLWEPublicKey { + pub(crate) data: FourierGLWECiphertext, + pub(crate) dist: Distribution, +} + +impl GLWEPublicKey, B> { + pub fn alloc(module: &Module, basek: usize, k: usize, rank: usize) -> Self { + Self { + data: FourierGLWECiphertext::alloc(module, basek, k, rank), + dist: Distribution::NONE, + } + } + + pub fn bytes_of(module: &Module, basek: usize, k: usize, rank: usize) -> usize { + FourierGLWECiphertext::, B>::bytes_of(module, basek, k, rank) + } +} + +impl Infos for GLWEPublicKey { + type Inner = VecZnxDft; + + fn inner(&self) -> &Self::Inner { + &self.data.data + } + + fn basek(&self) -> usize { + self.data.basek + } + + fn k(&self) -> usize { + self.data.k + } +} + +impl GLWEPublicKey { + pub fn rank(&self) -> usize { + self.cols() - 1 + } +} + +impl + AsMut<[u8]>> GLWEPublicKey { + pub fn generate_from_sk>( + &mut self, + module: &Module, + sk: &FourierGLWESecret, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + ) { + #[cfg(debug_assertions)] + { + match sk.dist { + Distribution::NONE => panic!("invalid sk: SecretDistribution::NONE"), + _ => {} + } + } + + // Its ok to allocate scratch space here since pk is usually generated only once. + let mut scratch: ScratchOwned = ScratchOwned::new(FourierGLWECiphertext::encrypt_sk_scratch_space( + module, + self.basek(), + self.k(), + self.rank(), + )); + + self.data + .encrypt_zero_sk(module, sk, source_xa, source_xe, sigma, scratch.borrow()); + self.dist = sk.dist; + } +} diff --git a/core/src/glwe/secret.rs b/core/src/glwe/secret.rs new file mode 100644 index 0000000..5073d2b --- /dev/null +++ b/core/src/glwe/secret.rs @@ -0,0 +1,84 @@ +use backend::{Backend, Module, ScalarZnx, ScalarZnxAlloc, ZnxInfos, ZnxZero}; +use sampling::source::Source; + +use crate::dist::Distribution; + +pub struct GLWESecret { + pub(crate) data: ScalarZnx, + pub(crate) dist: Distribution, +} + +impl GLWESecret> { + pub fn alloc(module: &Module, rank: usize) -> Self { + Self { + data: module.new_scalar_znx(rank), + dist: Distribution::NONE, + } + } + + pub fn bytes_of(module: &Module, rank: usize) -> usize { + module.bytes_of_scalar_znx(rank) + } +} + +impl GLWESecret { + pub fn n(&self) -> usize { + self.data.n() + } + + pub fn log_n(&self) -> usize { + self.data.log_n() + } + + pub fn rank(&self) -> usize { + self.data.cols() + } +} + +impl + AsRef<[u8]>> GLWESecret { + pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) { + (0..self.rank()).for_each(|i| { + self.data.fill_ternary_prob(i, prob, source); + }); + self.dist = Distribution::TernaryProb(prob); + } + + pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) { + (0..self.rank()).for_each(|i| { + self.data.fill_ternary_hw(i, hw, source); + }); + self.dist = Distribution::TernaryFixed(hw); + } + + pub fn fill_binary_prob(&mut self, prob: f64, source: &mut Source) { + (0..self.rank()).for_each(|i| { + self.data.fill_binary_prob(i, prob, source); + }); + self.dist = Distribution::BinaryProb(prob); + } + + pub fn fill_binary_hw(&mut self, hw: usize, source: &mut Source) { + (0..self.rank()).for_each(|i| { + self.data.fill_binary_hw(i, hw, source); + }); + self.dist = Distribution::BinaryFixed(hw); + } + + pub fn fill_binary_block(&mut self, block_size: usize, source: &mut Source) { + (0..self.rank()).for_each(|i| { + self.data.fill_binary_block(i, block_size, source); + }); + self.dist = Distribution::BinaryBlock(block_size); + } + + pub fn fill_zero(&mut self) { + self.data.zero(); + self.dist = Distribution::ZERO; + } + + // pub(crate) fn prep_fourier(&mut self, module: &Module) { + // (0..self.rank()).for_each(|i| { + // module.svp_prepare(&mut self.data_fourier, i, &self.data, i); + // }); + // } +} diff --git a/core/src/glwe/test_fft64/automorphism.rs b/core/src/glwe/test_fft64/automorphism.rs new file mode 100644 index 0000000..0b917ef --- /dev/null +++ b/core/src/glwe/test_fft64/automorphism.rs @@ -0,0 +1,223 @@ +use backend::{FFT64, FillUniform, Module, ScratchOwned, Stats, VecZnxOps}; + +use sampling::source::Source; + +use crate::{ + FourierGLWESecret, GLWEAutomorphismKey, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, noise::log2_std_noise_gglwe_product, +}; + +#[test] +fn apply_inplace() { + let log_n: usize = 8; + let basek: usize = 12; + let k_ct: usize = 60; + let digits: usize = k_ct.div_ceil(basek); + (1..4).for_each(|rank| { + (1..digits + 1).for_each(|di| { + let k_ksk: usize = k_ct + basek * di; + println!("test automorphism_inplace digits: {} rank: {}", di, rank); + test_automorphism_inplace(log_n, basek, -5, k_ct, k_ksk, di, rank, 3.2); + }); + }); +} + +#[test] +fn apply() { + let log_n: usize = 8; + let basek: usize = 12; + let k_in: usize = 60; + let digits: usize = k_in.div_ceil(basek); + (1..4).for_each(|rank| { + (1..digits + 1).for_each(|di| { + let k_ksk: usize = k_in + basek * di; + let k_out: usize = k_ksk; // Better capture noise. + println!("test automorphism digits: {} rank: {}", di, rank); + test_automorphism(log_n, basek, -5, k_out, k_in, k_ksk, di, rank, 3.2); + }) + }); +} + +fn test_automorphism( + log_n: usize, + basek: usize, + p: i64, + k_out: usize, + k_in: usize, + k_ksk: usize, + digits: usize, + rank: usize, + sigma: f64, +) { + let module: Module = Module::::new(1 << log_n); + + let rows: usize = k_in.div_ceil(basek * digits); + + let mut autokey: GLWEAutomorphismKey, FFT64> = GLWEAutomorphismKey::alloc(&module, basek, k_ksk, rows, digits, rank); + let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_in, rank); + let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_out, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + pt_want + .data + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWEAutomorphismKey::encrypt_sk_scratch_space(&module, basek, autokey.k(), rank) + | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_out.k()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_in.k()) + | GLWECiphertext::automorphism_scratch_space( + &module, + basek, + ct_out.k(), + ct_in.k(), + autokey.k(), + digits, + rank, + ), + ); + + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); + + autokey.encrypt_sk( + &module, + p, + &sk, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_in.encrypt_sk( + &module, + &pt_want, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_out.automorphism(&module, &ct_in, &autokey, scratch.borrow()); + ct_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + module.vec_znx_automorphism_inplace(p, &mut pt_want.data, 0); + module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); + module.vec_znx_normalize_inplace(basek, &mut pt_have.data, 0, scratch.borrow()); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + + println!("{}", noise_have); + + let noise_want: f64 = log2_std_noise_gglwe_product( + module.n() as f64, + basek * digits, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + rank as f64, + k_in, + k_ksk, + ); + + assert!( + noise_have <= noise_want + 1.0, + "{} {}", + noise_have, + noise_want + ); +} + +fn test_automorphism_inplace( + log_n: usize, + basek: usize, + p: i64, + k_ct: usize, + k_ksk: usize, + digits: usize, + rank: usize, + sigma: f64, +) { + let module: Module = Module::::new(1 << log_n); + + let rows: usize = k_ct.div_ceil(basek * digits); + + let mut autokey: GLWEAutomorphismKey, FFT64> = GLWEAutomorphismKey::alloc(&module, basek, k_ksk, rows, digits, rank); + let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + pt_want + .data + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWEAutomorphismKey::encrypt_sk_scratch_space(&module, basek, autokey.k(), rank) + | GLWECiphertext::decrypt_scratch_space(&module, basek, ct.k()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k()) + | GLWECiphertext::automorphism_inplace_scratch_space(&module, basek, ct.k(), autokey.k(), digits, rank), + ); + + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); + + autokey.encrypt_sk( + &module, + p, + &sk, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct.encrypt_sk( + &module, + &pt_want, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct.automorphism_inplace(&module, &autokey, scratch.borrow()); + ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + module.vec_znx_automorphism_inplace(p, &mut pt_want.data, 0); + module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); + module.vec_znx_normalize_inplace(basek, &mut pt_have.data, 0, scratch.borrow()); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + let noise_want: f64 = log2_std_noise_gglwe_product( + module.n() as f64, + basek * digits, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + rank as f64, + k_ct, + k_ksk, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.5, + "{} {}", + noise_have, + noise_want + ); +} diff --git a/core/src/glwe/test_fft64/encryption.rs b/core/src/glwe/test_fft64/encryption.rs new file mode 100644 index 0000000..2d82575 --- /dev/null +++ b/core/src/glwe/test_fft64/encryption.rs @@ -0,0 +1,183 @@ +use backend::{Decoding, Encoding, FFT64, Module, ScratchOwned, Stats, VecZnxOps, ZnxZero}; +use itertools::izip; +use sampling::source::Source; + +use crate::{FourierGLWECiphertext, FourierGLWESecret, GLWECiphertext, GLWEPlaintext, GLWEPublicKey, GLWESecret, Infos}; + +#[test] +fn encrypt_sk() { + let log_n: usize = 8; + (1..4).for_each(|rank| { + println!("test encrypt_sk rank: {}", rank); + test_encrypt_sk(log_n, 8, 54, 30, 3.2, rank); + }); +} + +#[test] +fn encrypt_zero_sk() { + let log_n: usize = 8; + (1..4).for_each(|rank| { + println!("test encrypt_zero_sk rank: {}", rank); + test_encrypt_zero_sk(log_n, 8, 64, 3.2, rank); + }); +} + +#[test] +fn encrypt_pk() { + let log_n: usize = 8; + (1..4).for_each(|rank| { + println!("test encrypt_pk rank: {}", rank); + test_encrypt_pk(log_n, 8, 64, 64, 3.2, rank) + }); +} + +fn test_encrypt_sk(log_n: usize, basek: usize, k_ct: usize, k_pt: usize, sigma: f64, rank: usize) { + let module: Module = Module::::new(1 << log_n); + + let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_pt); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k()) + | GLWECiphertext::decrypt_scratch_space(&module, basek, ct.k()), + ); + + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); + + let mut data_want: Vec = vec![0i64; module.n()]; + + data_want + .iter_mut() + .for_each(|x| *x = source_xa.next_i64() & 0xFF); + + pt.data.encode_vec_i64(0, basek, k_pt, &data_want, 10); + + ct.encrypt_sk( + &module, + &pt, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + pt.data.zero(); + + ct.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + + let mut data_have: Vec = vec![0i64; module.n()]; + + pt.data + .decode_vec_i64(0, basek, pt.size() * basek, &mut data_have); + + // TODO: properly assert the decryption noise through std(dec(ct) - pt) + let scale: f64 = (1 << (pt.size() * basek - k_pt)) as f64; + izip!(data_want.iter(), data_have.iter()).for_each(|(a, b)| { + let b_scaled = (*b as f64) / scale; + assert!( + (*a as f64 - b_scaled).abs() < 0.1, + "{} {}", + *a as f64, + b_scaled + ) + }); +} + +fn test_encrypt_zero_sk(log_n: usize, basek: usize, k_ct: usize, sigma: f64, rank: usize) { + let module: Module = Module::::new(1 << log_n); + + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([1u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); + + let mut ct_dft: FourierGLWECiphertext, FFT64> = FourierGLWECiphertext::alloc(&module, basek, k_ct, rank); + + let mut scratch: ScratchOwned = ScratchOwned::new( + FourierGLWECiphertext::decrypt_scratch_space(&module, basek, k_ct) + | FourierGLWECiphertext::encrypt_sk_scratch_space(&module, basek, k_ct, rank), + ); + + ct_dft.encrypt_zero_sk( + &module, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + ct_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + + assert!((sigma - pt.data.std(0, basek) * (k_ct as f64).exp2()) <= 0.2); +} + +fn test_encrypt_pk(log_n: usize, basek: usize, k_ct: usize, k_pk: usize, sigma: f64, rank: usize) { + let module: Module = Module::::new(1 << log_n); + + let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + let mut source_xu: Source = Source::new([0u8; 32]); + + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); + + let mut pk: GLWEPublicKey, FFT64> = GLWEPublicKey::alloc(&module, basek, k_pk, rank); + pk.generate_from_sk(&module, &sk_dft, &mut source_xa, &mut source_xe, sigma); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k()) + | GLWECiphertext::decrypt_scratch_space(&module, basek, ct.k()) + | GLWECiphertext::encrypt_pk_scratch_space(&module, basek, pk.k()), + ); + + let mut data_want: Vec = vec![0i64; module.n()]; + + data_want + .iter_mut() + .for_each(|x| *x = source_xa.next_i64() & 0); + + pt_want.data.encode_vec_i64(0, basek, k_ct, &data_want, 10); + + ct.encrypt_pk( + &module, + &pt_want, + &pk, + &mut source_xu, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); + + ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_want.data, 0, &pt_have.data, 0); + + let noise_have: f64 = pt_want.data.std(0, basek).log2(); + let noise_want: f64 = ((((rank as f64) + 1.0) * module.n() as f64 * 0.5 * sigma * sigma).sqrt()).log2() - (k_ct as f64); + + assert!( + (noise_have - noise_want).abs() < 0.2, + "{} {}", + noise_have, + noise_want + ); +} diff --git a/core/src/glwe/test_fft64/external_product.rs b/core/src/glwe/test_fft64/external_product.rs new file mode 100644 index 0000000..e1f6b19 --- /dev/null +++ b/core/src/glwe/test_fft64/external_product.rs @@ -0,0 +1,244 @@ +use backend::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, ZnxViewMut}; +use sampling::source::Source; + +use crate::{FourierGLWESecret, GGSWCiphertext, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, noise::noise_ggsw_product}; + +#[test] +fn apply() { + let log_n: usize = 8; + let basek: usize = 12; + let k_in: usize = 45; + let digits: usize = k_in.div_ceil(basek); + (1..4).for_each(|rank| { + (1..digits + 1).for_each(|di| { + let k_ggsw: usize = k_in + basek * di; + let k_out: usize = k_ggsw; // Better capture noise + println!("test external_product digits: {} rank: {}", di, rank); + test_external_product(log_n, basek, k_out, k_in, k_ggsw, di, rank, 3.2); + }); + }); +} + +#[test] +fn apply_inplace() { + let log_n: usize = 8; + let basek: usize = 12; + let k_ct: usize = 60; + let digits: usize = k_ct.div_ceil(basek); + (1..4).for_each(|rank| { + (1..digits + 1).for_each(|di| { + let k_ggsw: usize = k_ct + basek * di; + println!("test external_product digits: {} rank: {}", di, rank); + test_external_product_inplace(log_n, basek, k_ct, k_ggsw, di, rank, 3.2); + }); + }); +} + +fn test_external_product( + log_n: usize, + basek: usize, + k_out: usize, + k_in: usize, + k_ggsw: usize, + digits: usize, + rank: usize, + sigma: f64, +) { + let module: Module = Module::::new(1 << log_n); + + let rows: usize = k_in.div_ceil(basek * digits); + + let mut ct_ggsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank); + let mut ct_glwe_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_in, rank); + let mut ct_glwe_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_out, rank); + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + + pt_want.data.at_mut(0, 0)[1] = 1; + + let k: usize = 1; + + pt_rgsw.raw_mut()[k] = 1; // X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::new( + GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, ct_ggsw.k(), rank) + | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_glwe_out.k()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_glwe_in.k()) + | GLWECiphertext::external_product_scratch_space( + &module, + basek, + ct_glwe_out.k(), + ct_glwe_in.k(), + ct_ggsw.k(), + digits, + rank, + ), + ); + + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); + + ct_ggsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_glwe_in.encrypt_sk( + &module, + &pt_want, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_glwe_out.external_product(&module, &ct_glwe_in, &ct_ggsw, scratch.borrow()); + + ct_glwe_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_want.data, 0); + + module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_ggsw_product( + module.n() as f64, + basek * digits, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + rank as f64, + k_in, + k_ggsw, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.5, + "{} {}", + noise_have, + noise_want + ); +} + +fn test_external_product_inplace(log_n: usize, basek: usize, k_ct: usize, k_ggsw: usize, digits: usize, rank: usize, sigma: f64) { + let module: Module = Module::::new(1 << log_n); + let rows: usize = k_ct.div_ceil(basek * digits); + + let mut ct_ggsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank); + let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + + pt_want.data.at_mut(0, 0)[1] = 1; + + let k: usize = 1; + + pt_rgsw.raw_mut()[k] = 1; // X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::new( + GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, ct_ggsw.k(), rank) + | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_glwe.k()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_glwe.k()) + | GLWECiphertext::external_product_inplace_scratch_space(&module, basek, ct_glwe.k(), ct_ggsw.k(), digits, rank), + ); + + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); + + ct_ggsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_glwe.encrypt_sk( + &module, + &pt_want, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_glwe.external_product_inplace(&module, &ct_ggsw, scratch.borrow()); + + ct_glwe.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_want.data, 0); + + module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_ggsw_product( + module.n() as f64, + basek * digits, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + rank as f64, + k_ct, + k_ggsw, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.5, + "{} {}", + noise_have, + noise_want + ); +} diff --git a/core/src/glwe/test_fft64/keyswitch.rs b/core/src/glwe/test_fft64/keyswitch.rs new file mode 100644 index 0000000..9142292 --- /dev/null +++ b/core/src/glwe/test_fft64/keyswitch.rs @@ -0,0 +1,226 @@ +use backend::{FFT64, FillUniform, Module, ScratchOwned, Stats, VecZnxOps}; +use sampling::source::Source; + +use crate::{ + FourierGLWESecret, GLWECiphertext, GLWEPlaintext, GLWESecret, GLWESwitchingKey, Infos, noise::log2_std_noise_gglwe_product, +}; + +#[test] +fn apply() { + let log_n: usize = 8; + let basek: usize = 12; + let k_in: usize = 45; + let digits: usize = k_in.div_ceil(basek); + (1..4).for_each(|rank_in| { + (1..4).for_each(|rank_out| { + (1..digits + 1).for_each(|di| { + let k_ksk: usize = k_in + basek * di; + let k_out: usize = k_ksk; // better capture noise + println!( + "test keyswitch digits: {} rank_in: {} rank_out: {}", + di, rank_in, rank_out + ); + test_keyswitch(log_n, basek, k_out, k_in, k_ksk, di, rank_in, rank_out, 3.2); + }) + }); + }); +} + +#[test] +fn apply_inplace() { + let log_n: usize = 8; + let basek: usize = 12; + let k_ct: usize = 45; + let digits: usize = k_ct.div_ceil(basek); + (1..4).for_each(|rank| { + (1..digits + 1).for_each(|di| { + let k_ksk: usize = k_ct + basek * di; + println!("test keyswitch_inplace digits: {} rank: {}", di, rank); + test_keyswitch_inplace(log_n, basek, k_ct, k_ksk, di, rank, 3.2); + }); + }); +} + +fn test_keyswitch( + log_n: usize, + basek: usize, + k_out: usize, + k_in: usize, + k_ksk: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + sigma: f64, +) { + let module: Module = Module::::new(1 << log_n); + + let rows: usize = k_in.div_ceil(basek * digits); + + let mut ksk: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank_in, rank_out); + let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_in, rank_in); + let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_out, rank_out); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + pt_want + .data + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ksk.k(), rank_in, rank_out) + | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_out.k()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_in.k()) + | GLWECiphertext::keyswitch_scratch_space( + &module, + basek, + ct_out.k(), + ct_in.k(), + ksk.k(), + digits, + rank_in, + rank_out, + ), + ); + + let mut sk_in: GLWESecret> = GLWESecret::alloc(&module, rank_in); + sk_in.fill_ternary_prob(0.5, &mut source_xs); + let sk_in_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_in); + + let mut sk_out: GLWESecret> = GLWESecret::alloc(&module, rank_out); + sk_out.fill_ternary_prob(0.5, &mut source_xs); + let sk_out_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_out); + + ksk.encrypt_sk( + &module, + &sk_in, + &sk_out_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_in.encrypt_sk( + &module, + &pt_want, + &sk_in_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_out.keyswitch(&module, &ct_in, &ksk, scratch.borrow()); + ct_out.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + let noise_want: f64 = log2_std_noise_gglwe_product( + module.n() as f64, + basek * digits, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + rank_in as f64, + k_in, + k_ksk, + ); + + println!("{} vs. {}", noise_have, noise_want); + + assert!( + (noise_have - noise_want).abs() <= 0.5, + "{} {}", + noise_have, + noise_want + ); +} + +fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ct: usize, k_ksk: usize, digits: usize, rank: usize, sigma: f64) { + let module: Module = Module::::new(1 << log_n); + + let rows: usize = k_ct.div_ceil(basek * digits); + + let mut ct_grlwe: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank, rank); + let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + pt_want + .data + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ct_grlwe.k(), rank, rank) + | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_glwe.k()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_glwe.k()) + | GLWECiphertext::keyswitch_inplace_scratch_space(&module, basek, ct_glwe.k(), ct_grlwe.k(), digits, rank), + ); + + let mut sk_in: GLWESecret> = GLWESecret::alloc(&module, rank); + sk_in.fill_ternary_prob(0.5, &mut source_xs); + let sk_in_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_in); + + let mut sk_out: GLWESecret> = GLWESecret::alloc(&module, rank); + sk_out.fill_ternary_prob(0.5, &mut source_xs); + let sk_out_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk_out); + + ct_grlwe.encrypt_sk( + &module, + &sk_in, + &sk_out_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_glwe.encrypt_sk( + &module, + &pt_want, + &sk_in_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_glwe.keyswitch_inplace(&module, &ct_grlwe, scratch.borrow()); + + ct_glwe.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + let noise_want: f64 = log2_std_noise_gglwe_product( + module.n() as f64, + basek * digits, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + rank as f64, + k_ct, + k_ksk, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.5, + "{} {}", + noise_have, + noise_want + ); +} diff --git a/core/src/glwe/test_fft64/mod.rs b/core/src/glwe/test_fft64/mod.rs new file mode 100644 index 0000000..e3f5425 --- /dev/null +++ b/core/src/glwe/test_fft64/mod.rs @@ -0,0 +1,6 @@ +pub mod automorphism; +pub mod encryption; +pub mod external_product; +pub mod keyswitch; +pub mod packing; +pub mod trace; diff --git a/core/src/test_fft64/glwe_packing.rs b/core/src/glwe/test_fft64/packing.rs similarity index 77% rename from core/src/test_fft64/glwe_packing.rs rename to core/src/glwe/test_fft64/packing.rs index 107461c..592d38e 100644 --- a/core/src/test_fft64/glwe_packing.rs +++ b/core/src/glwe/test_fft64/packing.rs @@ -1,11 +1,11 @@ -use crate::{AutomorphismKey, GLWECiphertext, GLWEOps, GLWEPlaintext, GLWESecret, StreamPacker}; +use crate::{FourierGLWESecret, GLWEAutomorphismKey, GLWECiphertext, GLWEOps, GLWEPacker, GLWEPlaintext, GLWESecret}; use std::collections::HashMap; use backend::{Encoding, FFT64, Module, ScratchOwned, Stats}; use sampling::source::Source; #[test] -fn packing() { +fn apply() { let log_n: usize = 5; let module: Module = Module::::new(1 << log_n); @@ -26,12 +26,13 @@ fn packing() { let mut scratch: ScratchOwned = ScratchOwned::new( GLWECiphertext::encrypt_sk_scratch_space(&module, basek, k_ct) | GLWECiphertext::decrypt_scratch_space(&module, basek, k_ct) - | AutomorphismKey::generate_from_sk_scratch_space(&module, basek, k_ksk, rank) - | StreamPacker::scratch_space(&module, basek, k_ct, k_ksk, digits, rank), + | GLWEAutomorphismKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank) + | GLWEPacker::scratch_space(&module, basek, k_ct, k_ksk, digits, rank), ); - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, 0.5, &mut source_xs); + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); let mut data: Vec = vec![0i64; module.n()]; @@ -40,12 +41,12 @@ fn packing() { }); pt.data.encode_vec_i64(0, basek, pt_k, &data, 32); - let gal_els: Vec = StreamPacker::galois_elements(&module); + let gal_els: Vec = GLWEPacker::galois_elements(&module); - let mut auto_keys: HashMap, FFT64>> = HashMap::new(); + let mut auto_keys: HashMap, FFT64>> = HashMap::new(); gal_els.iter().for_each(|gal_el| { - let mut key: AutomorphismKey, FFT64> = AutomorphismKey::alloc(&module, basek, k_ksk, rows, digits, rank); - key.generate_from_sk( + let mut key: GLWEAutomorphismKey, FFT64> = GLWEAutomorphismKey::alloc(&module, basek, k_ksk, rows, digits, rank); + key.encrypt_sk( &module, *gal_el, &sk, @@ -59,14 +60,14 @@ fn packing() { let log_batch: usize = 0; - let mut packer: StreamPacker = StreamPacker::new(&module, log_batch, basek, k_ct, rank); + let mut packer: GLWEPacker = GLWEPacker::new(&module, log_batch, basek, k_ct, rank); let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); ct.encrypt_sk( &module, &pt, - &sk, + &sk_dft, &mut source_xa, &mut source_xe, sigma, @@ -79,7 +80,7 @@ fn packing() { ct.encrypt_sk( &module, &pt, - &sk, + &sk_dft, &mut source_xa, &mut source_xe, sigma, @@ -115,7 +116,7 @@ fn packing() { }); pt_want.data.encode_vec_i64(0, basek, pt_k, &data, 32); - res_i.decrypt(&module, &mut pt, &sk, scratch.borrow()); + res_i.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); if i & 1 == 0 { pt.sub_inplace_ab(&module, &pt_want); diff --git a/core/src/test_fft64/trace.rs b/core/src/glwe/test_fft64/trace.rs similarity index 79% rename from core/src/test_fft64/trace.rs rename to core/src/glwe/test_fft64/trace.rs index 885aa90..e34e260 100644 --- a/core/src/test_fft64/trace.rs +++ b/core/src/glwe/test_fft64/trace.rs @@ -3,10 +3,12 @@ use std::collections::HashMap; use backend::{FFT64, FillUniform, Module, ScratchOwned, Stats, VecZnxOps, ZnxView, ZnxViewMut}; use sampling::source::Source; -use crate::{AutomorphismKey, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, test_fft64::var_noise_gglwe_product}; +use crate::{ + FourierGLWESecret, GLWEAutomorphismKey, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, noise::var_noise_gglwe_product, +}; #[test] -fn trace_inplace() { +fn apply_inplace() { let log_n: usize = 8; (1..4).for_each(|rank| { println!("test trace_inplace rank: {}", rank); @@ -33,12 +35,13 @@ fn test_trace_inplace(log_n: usize, basek: usize, k: usize, sigma: f64, rank: us let mut scratch: ScratchOwned = ScratchOwned::new( GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k()) | GLWECiphertext::decrypt_scratch_space(&module, basek, ct.k()) - | AutomorphismKey::generate_from_sk_scratch_space(&module, basek, k_autokey, rank) + | GLWEAutomorphismKey::encrypt_sk_scratch_space(&module, basek, k_autokey, rank) | GLWECiphertext::trace_inplace_scratch_space(&module, basek, ct.k(), k_autokey, digits, rank), ); - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&&module, 0.5, &mut source_xs); + let mut sk: GLWESecret> = GLWESecret::alloc(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::from(&module, &sk); let mut data_want: Vec = vec![0i64; module.n()]; @@ -53,18 +56,19 @@ fn test_trace_inplace(log_n: usize, basek: usize, k: usize, sigma: f64, rank: us ct.encrypt_sk( &module, &pt_have, - &sk, + &sk_dft, &mut source_xa, &mut source_xe, sigma, scratch.borrow(), ); - let mut auto_keys: HashMap, FFT64>> = HashMap::new(); + let mut auto_keys: HashMap, FFT64>> = HashMap::new(); let gal_els: Vec = GLWECiphertext::trace_galois_elements(&module); gal_els.iter().for_each(|gal_el| { - let mut key: AutomorphismKey, FFT64> = AutomorphismKey::alloc(&module, basek, k_autokey, rows, digits, rank); - key.generate_from_sk( + let mut key: GLWEAutomorphismKey, FFT64> = + GLWEAutomorphismKey::alloc(&module, basek, k_autokey, rows, digits, rank); + key.encrypt_sk( &module, *gal_el, &sk, @@ -81,7 +85,7 @@ fn test_trace_inplace(log_n: usize, basek: usize, k: usize, sigma: f64, rank: us (0..pt_want.size()).for_each(|i| pt_want.data.at_mut(0, i)[0] = pt_have.data.at(0, i)[0]); - ct.decrypt(&module, &mut pt_have, &sk, scratch.borrow()); + ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); module.vec_znx_sub_ab_inplace(&mut pt_want.data, 0, &pt_have.data, 0); module.vec_znx_normalize_inplace(basek, &mut pt_want.data, 0, scratch.borrow()); diff --git a/core/src/trace.rs b/core/src/glwe/trace.rs similarity index 89% rename from core/src/trace.rs rename to core/src/glwe/trace.rs index 3c6a5bb..c702489 100644 --- a/core/src/trace.rs +++ b/core/src/glwe/trace.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use backend::{FFT64, Module, Scratch}; -use crate::{AutomorphismKey, GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, GLWEOps, Infos, SetMetaData}; +use crate::{GLWEAutomorphismKey, GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, GLWEOps, Infos, SetMetaData}; impl GLWECiphertext> { pub fn trace_galois_elements(module: &Module) -> Vec { @@ -51,7 +51,7 @@ where start: usize, end: usize, lhs: &GLWECiphertext, - auto_keys: &HashMap>, + auto_keys: &HashMap>, scratch: &mut Scratch, ) where GLWECiphertext: GLWECiphertextToRef + Infos, @@ -65,7 +65,7 @@ where module: &Module, start: usize, end: usize, - auto_keys: &HashMap>, + auto_keys: &HashMap>, scratch: &mut Scratch, ) { (start..end).for_each(|i| { diff --git a/core/src/glwe_ciphertext.rs b/core/src/glwe_ciphertext.rs deleted file mode 100644 index 28c1724..0000000 --- a/core/src/glwe_ciphertext.rs +++ /dev/null @@ -1,883 +0,0 @@ -use backend::{ - AddNormal, Backend, FFT64, FillUniform, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnxAlloc, ScalarZnxDftAlloc, - ScalarZnxDftOps, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDftAlloc, - VecZnxDftOps, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxZero, -}; -use sampling::source::Source; - -use crate::{ - AutomorphismKey, GGSWCiphertext, GLWECiphertextFourier, GLWEOps, GLWEPlaintext, GLWEPublicKey, GLWESecret, GLWESwitchingKey, - Infos, SIX_SIGMA, SecretDistribution, SetMetaData, -}; - -pub struct GLWECiphertext { - pub data: VecZnx, - pub basek: usize, - pub k: usize, -} - -impl GLWECiphertext> { - pub fn alloc(module: &Module, basek: usize, k: usize, rank: usize) -> Self { - Self { - data: module.new_vec_znx(rank + 1, k.div_ceil(basek)), - basek, - k, - } - } - - pub fn bytes_of(module: &Module, basek: usize, k: usize, rank: usize) -> usize { - module.bytes_of_vec_znx(rank + 1, k.div_ceil(basek)) - } -} - -impl Infos for GLWECiphertext { - type Inner = VecZnx; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn basek(&self) -> usize { - self.basek - } - - fn k(&self) -> usize { - self.k - } -} - -impl GLWECiphertext { - pub fn rank(&self) -> usize { - self.cols() - 1 - } -} - -impl> GLWECiphertext { - #[allow(dead_code)] - pub(crate) fn dft + AsRef<[u8]>>(&self, module: &Module, res: &mut GLWECiphertextFourier) { - #[cfg(debug_assertions)] - { - assert_eq!(self.rank(), res.rank()); - assert_eq!(self.basek(), res.basek()) - } - - (0..self.rank() + 1).for_each(|i| { - module.vec_znx_dft(1, 0, &mut res.data, i, &self.data, i); - }) - } -} - -impl GLWECiphertext> { - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize) -> usize { - let size: usize = k.div_ceil(basek); - module.vec_znx_big_normalize_tmp_bytes() + module.bytes_of_vec_znx_dft(1, size) + module.bytes_of_vec_znx(1, size) - } - pub fn encrypt_pk_scratch_space(module: &Module, basek: usize, k: usize) -> usize { - let size: usize = k.div_ceil(basek); - ((module.bytes_of_vec_znx_dft(1, size) + module.bytes_of_vec_znx_big(1, size)) | module.bytes_of_scalar_znx(1)) - + module.bytes_of_scalar_znx_dft(1) - + module.vec_znx_big_normalize_tmp_bytes() - } - - pub fn decrypt_scratch_space(module: &Module, basek: usize, k: usize) -> usize { - let size: usize = k.div_ceil(basek); - (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) - } - - pub fn keyswitch_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_in: usize, - k_ksk: usize, - digits: usize, - rank_in: usize, - rank_out: usize, - ) -> usize { - let res_dft: usize = GLWECiphertextFourier::bytes_of(module, basek, k_out, rank_out + 1); - let in_size: usize = k_in.div_ceil(basek).div_ceil(digits); - let out_size: usize = k_out.div_ceil(basek); - let ksk_size: usize = k_ksk.div_ceil(basek); - let ai_dft: usize = module.bytes_of_vec_znx_dft(rank_in, in_size); - let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, rank_in, rank_out + 1, ksk_size) - + module.bytes_of_vec_znx_dft(rank_in, in_size); - let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(); - return res_dft + ((ai_dft + vmp) | normalize); - } - - pub fn keyswitch_from_fourier_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_in: usize, - k_ksk: usize, - digits: usize, - rank_in: usize, - rank_out: usize, - ) -> usize { - Self::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank_in, rank_out) - } - - pub fn keyswitch_inplace_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_ksk: usize, - digits: usize, - rank: usize, - ) -> usize { - Self::keyswitch_scratch_space(module, basek, k_out, k_out, k_ksk, digits, rank, rank) - } - - pub fn automorphism_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_in: usize, - k_ksk: usize, - digits: usize, - rank: usize, - ) -> usize { - Self::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank, rank) - } - - pub fn automorphism_inplace_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_ksk: usize, - digits: usize, - rank: usize, - ) -> usize { - Self::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank) - } - - pub fn external_product_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_in: usize, - ggsw_k: usize, - digits: usize, - rank: usize, - ) -> usize { - let res_dft: usize = GLWECiphertextFourier::bytes_of(module, basek, k_out, rank); - let in_size: usize = k_in.div_ceil(basek).div_ceil(digits); - let out_size: usize = k_out.div_ceil(basek); - let ggsw_size: usize = ggsw_k.div_ceil(basek); - let vmp: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size) - + module.vmp_apply_tmp_bytes( - out_size, - in_size, - in_size, // rows - rank + 1, // cols in - rank + 1, // cols out - ggsw_size, - ); - let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(); - res_dft + (vmp | normalize) - } - - pub fn external_product_inplace_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - ggsw_k: usize, - digits: usize, - rank: usize, - ) -> usize { - Self::external_product_scratch_space(module, basek, k_out, k_out, ggsw_k, digits, rank) - } -} - -impl + AsRef<[u8]>> SetMetaData for GLWECiphertext { - fn set_k(&mut self, k: usize) { - self.k = k - } - - fn set_basek(&mut self, basek: usize) { - self.basek = basek - } -} - -impl + AsMut<[u8]>> GLWECiphertext { - pub fn encrypt_sk, DataSk: AsRef<[u8]>>( - &mut self, - module: &Module, - pt: &GLWEPlaintext, - sk: &GLWESecret, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - scratch: &mut Scratch, - ) { - self.encrypt_sk_private( - module, - Some((pt, 0)), - sk, - source_xa, - source_xe, - sigma, - scratch, - ); - } - - pub fn encrypt_zero_sk>( - &mut self, - module: &Module, - sk: &GLWESecret, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - scratch: &mut Scratch, - ) { - self.encrypt_sk_private( - module, - None::<(&GLWEPlaintext>, usize)>, - sk, - source_xa, - source_xe, - sigma, - scratch, - ); - } - - pub fn encrypt_pk, DataPk: AsRef<[u8]>>( - &mut self, - module: &Module, - pt: &GLWEPlaintext, - pk: &GLWEPublicKey, - source_xu: &mut Source, - source_xe: &mut Source, - sigma: f64, - scratch: &mut Scratch, - ) { - self.encrypt_pk_private( - module, - Some((pt, 0)), - pk, - source_xu, - source_xe, - sigma, - scratch, - ); - } - - pub fn encrypt_zero_pk>( - &mut self, - module: &Module, - pk: &GLWEPublicKey, - source_xu: &mut Source, - source_xe: &mut Source, - sigma: f64, - scratch: &mut Scratch, - ) { - self.encrypt_pk_private( - module, - None::<(&GLWEPlaintext>, usize)>, - pk, - source_xu, - source_xe, - sigma, - scratch, - ); - } - - pub fn automorphism, DataRhs: AsRef<[u8]>>( - &mut self, - module: &Module, - lhs: &GLWECiphertext, - rhs: &AutomorphismKey, - scratch: &mut Scratch, - ) { - self.keyswitch(module, lhs, &rhs.key, scratch); - (0..self.rank() + 1).for_each(|i| { - module.vec_znx_automorphism_inplace(rhs.p(), &mut self.data, i); - }) - } - - pub fn automorphism_inplace>( - &mut self, - module: &Module, - rhs: &AutomorphismKey, - scratch: &mut Scratch, - ) { - self.keyswitch_inplace(module, &rhs.key, scratch); - (0..self.rank() + 1).for_each(|i| { - module.vec_znx_automorphism_inplace(rhs.p(), &mut self.data, i); - }) - } - - pub fn automorphism_add, DataRhs: AsRef<[u8]>>( - &mut self, - module: &Module, - lhs: &GLWECiphertext, - rhs: &AutomorphismKey, - scratch: &mut Scratch, - ) { - Self::keyswitch_private::<_, _, 1>(self, rhs.p(), module, lhs, &rhs.key, scratch); - } - - pub fn automorphism_add_inplace>( - &mut self, - module: &Module, - rhs: &AutomorphismKey, - scratch: &mut Scratch, - ) { - unsafe { - let self_ptr: *mut GLWECiphertext = self as *mut GLWECiphertext; - Self::keyswitch_private::<_, _, 1>(self, rhs.p(), module, &*self_ptr, &rhs.key, scratch); - } - } - - pub fn automorphism_sub_ab, DataRhs: AsRef<[u8]>>( - &mut self, - module: &Module, - lhs: &GLWECiphertext, - rhs: &AutomorphismKey, - scratch: &mut Scratch, - ) { - Self::keyswitch_private::<_, _, 2>(self, rhs.p(), module, lhs, &rhs.key, scratch); - } - - pub fn automorphism_sub_ab_inplace>( - &mut self, - module: &Module, - rhs: &AutomorphismKey, - scratch: &mut Scratch, - ) { - unsafe { - let self_ptr: *mut GLWECiphertext = self as *mut GLWECiphertext; - Self::keyswitch_private::<_, _, 2>(self, rhs.p(), module, &*self_ptr, &rhs.key, scratch); - } - } - - pub fn automorphism_sub_ba, DataRhs: AsRef<[u8]>>( - &mut self, - module: &Module, - lhs: &GLWECiphertext, - rhs: &AutomorphismKey, - scratch: &mut Scratch, - ) { - Self::keyswitch_private::<_, _, 3>(self, rhs.p(), module, lhs, &rhs.key, scratch); - } - - pub fn automorphism_sub_ba_inplace>( - &mut self, - module: &Module, - rhs: &AutomorphismKey, - scratch: &mut Scratch, - ) { - unsafe { - let self_ptr: *mut GLWECiphertext = self as *mut GLWECiphertext; - Self::keyswitch_private::<_, _, 3>(self, rhs.p(), module, &*self_ptr, &rhs.key, scratch); - } - } - - pub(crate) fn keyswitch_from_fourier, DataRhs: AsRef<[u8]>>( - &mut self, - module: &Module, - lhs: &GLWECiphertextFourier, - rhs: &GLWESwitchingKey, - scratch: &mut Scratch, - ) { - let basek: usize = self.basek(); - - #[cfg(debug_assertions)] - { - assert_eq!(lhs.rank(), rhs.rank_in()); - assert_eq!(self.rank(), rhs.rank_out()); - assert_eq!(self.basek(), basek); - assert_eq!(lhs.basek(), basek); - assert_eq!(rhs.n(), module.n()); - assert_eq!(self.n(), module.n()); - assert_eq!(lhs.n(), module.n()); - assert!( - scratch.available() - >= GLWECiphertext::keyswitch_from_fourier_scratch_space( - module, - self.basek(), - self.k(), - lhs.k(), - rhs.k(), - rhs.digits(), - rhs.rank_in(), - rhs.rank_out(), - ) - ); - } - - let cols_in: usize = rhs.rank_in(); - let cols_out: usize = rhs.rank_out() + 1; - - // Buffer of the result of VMP in DFT - let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise - - { - let digits = rhs.digits(); - - (0..digits).for_each(|di| { - // (lhs.size() + di) / digits = (a - (digit - di - 1) + digit - 1) / digits - let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, (lhs.size() + di) / digits); - - (0..cols_in).for_each(|col_i| { - module.vec_znx_dft_copy( - digits, - digits - 1 - di, - &mut ai_dft, - col_i, - &lhs.data, - col_i + 1, - ); - }); - - if di == 0 { - module.vmp_apply(&mut res_dft, &ai_dft, &rhs.0.data, scratch2); - } else { - module.vmp_apply_add(&mut res_dft, &ai_dft, &rhs.0.data, di, scratch2); - } - }); - } - - module.vec_znx_dft_add_inplace(&mut res_dft, 0, &lhs.data, 0); - - // Switches result of VMP outside of DFT - let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft); - - (0..cols_out).for_each(|i| { - module.vec_znx_big_normalize(basek, &mut self.data, i, &res_big, i, scratch1); - }); - } - - pub fn keyswitch, DataRhs: AsRef<[u8]>>( - &mut self, - module: &Module, - lhs: &GLWECiphertext, - rhs: &GLWESwitchingKey, - scratch: &mut Scratch, - ) { - Self::keyswitch_private::<_, _, 0>(self, 0, module, lhs, rhs, scratch); - } - - pub(crate) fn keyswitch_private, DataRhs: AsRef<[u8]>, const OP: u8>( - &mut self, - apply_auto: i64, - module: &Module, - lhs: &GLWECiphertext, - rhs: &GLWESwitchingKey, - scratch: &mut Scratch, - ) { - let basek: usize = self.basek(); - - #[cfg(debug_assertions)] - { - assert_eq!(lhs.rank(), rhs.rank_in()); - assert_eq!(self.rank(), rhs.rank_out()); - assert_eq!(self.basek(), basek); - assert_eq!(lhs.basek(), basek); - assert_eq!(rhs.n(), module.n()); - assert_eq!(self.n(), module.n()); - assert_eq!(lhs.n(), module.n()); - assert!( - scratch.available() - >= GLWECiphertext::keyswitch_scratch_space( - module, - self.basek(), - self.k(), - lhs.k(), - rhs.k(), - rhs.digits(), - rhs.rank_in(), - rhs.rank_out(), - ) - ); - } - - let cols_in: usize = rhs.rank_in(); - let cols_out: usize = rhs.rank_out() + 1; - let digits: usize = rhs.digits(); - - let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise - let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, (lhs.size() + digits - 1) / digits); - ai_dft.zero(); - { - (0..digits).for_each(|di| { - ai_dft.set_size((lhs.size() + di) / digits); - - // Small optimization for digits > 2 - // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then - // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}. - // As such we can ignore the last digits-2 limbs safely of the sum of vmp products. - // It is possible to further ignore the last digits-1 limbs, but this introduce - // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same - // noise is kept with respect to the ideal functionality. - res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize); - - (0..cols_in).for_each(|col_i| { - module.vec_znx_dft( - digits, - digits - di - 1, - &mut ai_dft, - col_i, - &lhs.data, - col_i + 1, - ); - }); - - if di == 0 { - module.vmp_apply(&mut res_dft, &ai_dft, &rhs.0.data, scratch2); - } else { - module.vmp_apply_add(&mut res_dft, &ai_dft, &rhs.0.data, di, scratch2); - } - }); - } - - let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft); - - module.vec_znx_big_add_small_inplace(&mut res_big, 0, &lhs.data, 0); - - (0..cols_out).for_each(|i| { - if apply_auto != 0 { - module.vec_znx_big_automorphism_inplace(apply_auto, &mut res_big, i); - } - - match OP { - 1 => module.vec_znx_big_add_small_inplace(&mut res_big, i, &lhs.data, i), - 2 => module.vec_znx_big_sub_small_a_inplace(&mut res_big, i, &lhs.data, i), - 3 => module.vec_znx_big_sub_small_b_inplace(&mut res_big, i, &lhs.data, i), - _ => {} - } - module.vec_znx_big_normalize(basek, &mut self.data, i, &res_big, i, scratch1); - }); - } - - pub fn keyswitch_inplace>( - &mut self, - module: &Module, - rhs: &GLWESwitchingKey, - scratch: &mut Scratch, - ) { - unsafe { - let self_ptr: *mut GLWECiphertext = self as *mut GLWECiphertext; - self.keyswitch(&module, &*self_ptr, rhs, scratch); - } - } - - pub fn external_product, DataRhs: AsRef<[u8]>>( - &mut self, - module: &Module, - lhs: &GLWECiphertext, - rhs: &GGSWCiphertext, - scratch: &mut Scratch, - ) { - let basek: usize = self.basek(); - - #[cfg(debug_assertions)] - { - assert_eq!(rhs.rank(), lhs.rank()); - assert_eq!(rhs.rank(), self.rank()); - assert_eq!(self.basek(), basek); - assert_eq!(lhs.basek(), basek); - assert_eq!(rhs.n(), module.n()); - assert_eq!(self.n(), module.n()); - assert_eq!(lhs.n(), module.n()); - assert!( - scratch.available() - >= GLWECiphertext::external_product_scratch_space( - module, - self.basek(), - self.k(), - lhs.k(), - rhs.k(), - rhs.digits(), - rhs.rank(), - ) - ); - } - - let cols: usize = rhs.rank() + 1; - let digits: usize = rhs.digits(); - - let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); // Todo optimise - let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, (lhs.size() + digits - 1) / digits); - - { - (0..digits).for_each(|di| { - // (lhs.size() + di) / digits = (a - (digit - di - 1) + digit - 1) / digits - a_dft.set_size((lhs.size() + di) / digits); - - // Small optimization for digits > 2 - // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then - // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}. - // As such we can ignore the last digits-2 limbs safely of the sum of vmp products. - // It is possible to further ignore the last digits-1 limbs, but this introduce - // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same - // noise is kept with respect to the ideal functionality. - res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize); - - (0..cols).for_each(|col_i| { - module.vec_znx_dft(digits, digits - 1 - di, &mut a_dft, col_i, &lhs.data, col_i); - }); - - if di == 0 { - module.vmp_apply(&mut res_dft, &a_dft, &rhs.data, scratch2); - } else { - module.vmp_apply_add(&mut res_dft, &a_dft, &rhs.data, di, scratch2); - } - }); - } - - let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft); - - (0..cols).for_each(|i| { - module.vec_znx_big_normalize(basek, &mut self.data, i, &res_big, i, scratch1); - }); - } - - pub fn external_product_inplace>( - &mut self, - module: &Module, - rhs: &GGSWCiphertext, - scratch: &mut Scratch, - ) { - unsafe { - let self_ptr: *mut GLWECiphertext = self as *mut GLWECiphertext; - self.external_product(&module, &*self_ptr, rhs, scratch); - } - } - - pub(crate) fn encrypt_sk_private, DataSk: AsRef<[u8]>>( - &mut self, - module: &Module, - pt: Option<(&GLWEPlaintext, usize)>, - sk: &GLWESecret, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - scratch: &mut Scratch, - ) { - #[cfg(debug_assertions)] - { - assert_eq!(self.rank(), sk.rank()); - assert_eq!(sk.n(), module.n()); - assert_eq!(self.n(), module.n()); - if let Some((pt, col)) = pt { - assert_eq!(pt.n(), module.n()); - assert!(col < self.rank() + 1); - } - assert!( - scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k()), - "scratch.available(): {} < GLWECiphertext::encrypt_sk_scratch_space: {}", - scratch.available(), - GLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k()) - ) - } - - let basek: usize = self.basek(); - let k: usize = self.k(); - let size: usize = self.size(); - let cols: usize = self.rank() + 1; - - let (mut c0_big, scratch_1) = scratch.tmp_vec_znx(module, 1, size); - c0_big.zero(); - - { - // c[i] = uniform - // c[0] -= c[i] * s[i], - (1..cols).for_each(|i| { - let (mut ci_dft, scratch_2) = scratch_1.tmp_vec_znx_dft(module, 1, size); - - // c[i] = uniform - self.data.fill_uniform(basek, i, size, source_xa); - - // c[i] = norm(IDFT(DFT(c[i]) * DFT(s[i]))) - module.vec_znx_dft(1, 0, &mut ci_dft, 0, &self.data, i); - module.svp_apply_inplace(&mut ci_dft, 0, &sk.data_fourier, i - 1); - let ci_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(ci_dft); - - // use c[0] as buffer, which is overwritten later by the normalization step - module.vec_znx_big_normalize(basek, &mut self.data, 0, &ci_big, 0, scratch_2); - - // c0_tmp = -c[i] * s[i] (use c[0] as buffer) - module.vec_znx_sub_ab_inplace(&mut c0_big, 0, &self.data, 0); - - // c[i] += m if col = i - if let Some((pt, col)) = pt { - if i == col { - module.vec_znx_add_inplace(&mut self.data, i, &pt.data, 0); - module.vec_znx_normalize_inplace(basek, &mut self.data, i, scratch_2); - } - } - }); - } - - // c[0] += e - c0_big.add_normal(basek, 0, k, source_xe, sigma, sigma * SIX_SIGMA); - - // c[0] += m if col = 0 - if let Some((pt, col)) = pt { - if col == 0 { - module.vec_znx_add_inplace(&mut c0_big, 0, &pt.data, 0); - } - } - - // c[0] = norm(c[0]) - module.vec_znx_normalize(basek, &mut self.data, 0, &c0_big, 0, scratch_1); - } - - pub(crate) fn encrypt_pk_private, DataPk: AsRef<[u8]>>( - &mut self, - module: &Module, - pt: Option<(&GLWEPlaintext, usize)>, - pk: &GLWEPublicKey, - source_xu: &mut Source, - source_xe: &mut Source, - sigma: f64, - scratch: &mut Scratch, - ) { - #[cfg(debug_assertions)] - { - assert_eq!(self.basek(), pk.basek()); - assert_eq!(self.n(), module.n()); - assert_eq!(pk.n(), module.n()); - assert_eq!(self.rank(), pk.rank()); - if let Some((pt, _)) = pt { - assert_eq!(pt.basek(), pk.basek()); - assert_eq!(pt.n(), module.n()); - } - } - - let basek: usize = pk.basek(); - let size_pk: usize = pk.size(); - let cols: usize = self.rank() + 1; - - // Generates u according to the underlying secret distribution. - let (mut u_dft, scratch_1) = scratch.tmp_scalar_znx_dft(module, 1); - - { - let (mut u, _) = scratch_1.tmp_scalar_znx(module, 1); - match pk.dist { - SecretDistribution::NONE => panic!( - "invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through \ - Self::generate" - ), - SecretDistribution::TernaryFixed(hw) => u.fill_ternary_hw(0, hw, source_xu), - SecretDistribution::TernaryProb(prob) => u.fill_ternary_prob(0, prob, source_xu), - SecretDistribution::ZERO => {} - } - - module.svp_prepare(&mut u_dft, 0, &u, 0); - } - - // ct[i] = pk[i] * u + ei (+ m if col = i) - (0..cols).for_each(|i| { - let (mut ci_dft, scratch_2) = scratch_1.tmp_vec_znx_dft(module, 1, size_pk); - // ci_dft = DFT(u) * DFT(pk[i]) - module.svp_apply(&mut ci_dft, 0, &u_dft, 0, &pk.data.data, i); - - // ci_big = u * p[i] - let mut ci_big = module.vec_znx_idft_consume(ci_dft); - - // ci_big = u * pk[i] + e - ci_big.add_normal(basek, 0, pk.k(), source_xe, sigma, sigma * SIX_SIGMA); - - // ci_big = u * pk[i] + e + m (if col = i) - if let Some((pt, col)) = pt { - if col == i { - module.vec_znx_big_add_small_inplace(&mut ci_big, 0, &pt.data, 0); - } - } - - // ct[i] = norm(ci_big) - module.vec_znx_big_normalize(basek, &mut self.data, i, &ci_big, 0, scratch_2); - }); - } -} - -impl> GLWECiphertext { - pub fn clone(&self) -> GLWECiphertext> { - GLWECiphertext { - data: self.data.clone(), - basek: self.basek(), - k: self.k(), - } - } - - pub fn decrypt + AsRef<[u8]>, DataSk: AsRef<[u8]>>( - &self, - module: &Module, - pt: &mut GLWEPlaintext, - sk: &GLWESecret, - scratch: &mut Scratch, - ) { - #[cfg(debug_assertions)] - { - assert_eq!(self.rank(), sk.rank()); - assert_eq!(self.n(), module.n()); - assert_eq!(pt.n(), module.n()); - assert_eq!(sk.n(), module.n()); - } - - let cols: usize = self.rank() + 1; - - let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, self.size()); // TODO optimize size when pt << ct - c0_big.zero(); - - { - (1..cols).for_each(|i| { - // ci_dft = DFT(a[i]) * DFT(s[i]) - let (mut ci_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, self.size()); // TODO optimize size when pt << ct - module.vec_znx_dft(1, 0, &mut ci_dft, 0, &self.data, i); - module.svp_apply_inplace(&mut ci_dft, 0, &sk.data_fourier, i - 1); - let ci_big = module.vec_znx_idft_consume(ci_dft); - - // c0_big += a[i] * s[i] - module.vec_znx_big_add_inplace(&mut c0_big, 0, &ci_big, 0); - }); - } - - // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e) - module.vec_znx_big_add_small_inplace(&mut c0_big, 0, &self.data, 0); - - // pt = norm(BIG(m + e)) - module.vec_znx_big_normalize(self.basek(), &mut pt.data, 0, &mut c0_big, 0, scratch_1); - - pt.basek = self.basek(); - pt.k = pt.k().min(self.k()); - } -} - -pub trait GLWECiphertextToRef { - fn to_ref(&self) -> GLWECiphertext<&[u8]>; -} - -impl> GLWECiphertextToRef for GLWECiphertext { - fn to_ref(&self) -> GLWECiphertext<&[u8]> { - GLWECiphertext { - data: self.data.to_ref(), - basek: self.basek, - k: self.k, - } - } -} - -pub trait GLWECiphertextToMut { - fn to_mut(&mut self) -> GLWECiphertext<&mut [u8]>; -} - -impl + AsRef<[u8]>> GLWECiphertextToMut for GLWECiphertext { - fn to_mut(&mut self) -> GLWECiphertext<&mut [u8]> { - GLWECiphertext { - data: self.data.to_mut(), - basek: self.basek, - k: self.k, - } - } -} - -impl GLWEOps for GLWECiphertext -where - D: AsRef<[u8]> + AsMut<[u8]>, - GLWECiphertext: GLWECiphertextToMut + Infos + SetMetaData, -{ -} diff --git a/core/src/glwe_ciphertext_fourier.rs b/core/src/glwe_ciphertext_fourier.rs deleted file mode 100644 index 19582f6..0000000 --- a/core/src/glwe_ciphertext_fourier.rs +++ /dev/null @@ -1,321 +0,0 @@ -use backend::{ - Backend, FFT64, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnxDftOps, Scratch, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, - VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, ZnxZero, -}; -use sampling::source::Source; - -use crate::{GGSWCiphertext, GLWECiphertext, GLWEPlaintext, GLWESecret, GLWESwitchingKey, Infos, ScratchCore}; - -pub struct GLWECiphertextFourier { - pub data: VecZnxDft, - pub basek: usize, - pub k: usize, -} - -impl GLWECiphertextFourier, B> { - pub fn alloc(module: &Module, basek: usize, k: usize, rank: usize) -> Self { - Self { - data: module.new_vec_znx_dft(rank + 1, k.div_ceil(basek)), - basek: basek, - k: k, - } - } - - pub fn bytes_of(module: &Module, basek: usize, k: usize, rank: usize) -> usize { - module.bytes_of_vec_znx_dft(rank + 1, k.div_ceil(basek)) - } -} - -impl Infos for GLWECiphertextFourier { - type Inner = VecZnxDft; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn basek(&self) -> usize { - self.basek - } - - fn k(&self) -> usize { - self.k - } -} - -impl GLWECiphertextFourier { - pub fn rank(&self) -> usize { - self.cols() - 1 - } -} - -impl GLWECiphertextFourier, FFT64> { - #[allow(dead_code)] - pub(crate) fn idft_scratch_space(module: &Module, basek: usize, k: usize) -> usize { - module.bytes_of_vec_znx(1, k.div_ceil(basek)) - + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes()) - } - - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { - module.bytes_of_vec_znx(rank + 1, k.div_ceil(basek)) + GLWECiphertext::encrypt_sk_scratch_space(module, basek, k) - } - - pub fn decrypt_scratch_space(module: &Module, basek: usize, k: usize) -> usize { - let size: usize = k.div_ceil(basek); - (module.vec_znx_big_normalize_tmp_bytes() - | module.bytes_of_vec_znx_dft(1, size) - | (module.bytes_of_vec_znx_big(1, size) + module.vec_znx_idft_tmp_bytes())) - + module.bytes_of_vec_znx_big(1, size) - } - - pub fn keyswitch_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_in: usize, - k_ksk: usize, - digits: usize, - rank_in: usize, - rank_out: usize, - ) -> usize { - GLWECiphertext::bytes_of(module, basek, k_out, rank_out) - + GLWECiphertext::keyswitch_from_fourier_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank_in, rank_out) - } - - pub fn keyswitch_inplace_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_ksk: usize, - digits: usize, - rank: usize, - ) -> usize { - Self::keyswitch_scratch_space(module, basek, k_out, k_out, k_ksk, digits, rank, rank) - } - - // WARNING TODO: UPDATE - pub fn external_product_scratch_space( - module: &Module, - basek: usize, - _k_out: usize, - k_in: usize, - k_ggsw: usize, - digits: usize, - rank: usize, - ) -> usize { - let ggsw_size: usize = k_ggsw.div_ceil(basek); - let res_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, ggsw_size); - let in_size: usize = k_in.div_ceil(basek).div_ceil(digits); - let ggsw_size: usize = k_ggsw.div_ceil(basek); - let vmp: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size) - + module.vmp_apply_tmp_bytes(ggsw_size, in_size, in_size, rank + 1, rank + 1, ggsw_size); - let res_small: usize = module.bytes_of_vec_znx(rank + 1, ggsw_size); - let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(); - res_dft + (vmp | (res_small + normalize)) - } - - pub fn external_product_inplace_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_ggsw: usize, - digits: usize, - rank: usize, - ) -> usize { - Self::external_product_scratch_space(module, basek, k_out, k_out, k_ggsw, digits, rank) - } -} - -impl + AsRef<[u8]>> GLWECiphertextFourier { - pub fn encrypt_zero_sk>( - &mut self, - module: &Module, - sk: &GLWESecret, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - scratch: &mut Scratch, - ) { - let (mut tmp_ct, scratch1) = scratch.tmp_glwe_ct(module, self.basek(), self.k(), self.rank()); - tmp_ct.encrypt_zero_sk(module, sk, source_xa, source_xe, sigma, scratch1); - tmp_ct.dft(module, self); - } - - pub fn keyswitch, DataRhs: AsRef<[u8]>>( - &mut self, - module: &Module, - lhs: &GLWECiphertextFourier, - rhs: &GLWESwitchingKey, - scratch: &mut Scratch, - ) { - let (mut tmp_ct, scratch1) = scratch.tmp_glwe_ct(module, self.basek(), self.k(), self.rank()); - tmp_ct.keyswitch_from_fourier(module, lhs, rhs, scratch1); - tmp_ct.dft(module, self); - } - - pub fn keyswitch_inplace>( - &mut self, - module: &Module, - rhs: &GLWESwitchingKey, - scratch: &mut Scratch, - ) { - unsafe { - let self_ptr: *mut GLWECiphertextFourier = self as *mut GLWECiphertextFourier; - self.keyswitch(&module, &*self_ptr, rhs, scratch); - } - } - - pub fn external_product, DataRhs: AsRef<[u8]>>( - &mut self, - module: &Module, - lhs: &GLWECiphertextFourier, - rhs: &GGSWCiphertext, - scratch: &mut Scratch, - ) { - let basek: usize = self.basek(); - - #[cfg(debug_assertions)] - { - assert_eq!(rhs.rank(), lhs.rank()); - assert_eq!(rhs.rank(), self.rank()); - assert_eq!(self.basek(), basek); - assert_eq!(lhs.basek(), basek); - assert_eq!(rhs.n(), module.n()); - assert_eq!(self.n(), module.n()); - assert_eq!(lhs.n(), module.n()); - assert!( - scratch.available() - >= GLWECiphertextFourier::external_product_scratch_space( - module, - self.basek(), - self.k(), - lhs.k(), - rhs.k(), - rhs.digits(), - rhs.rank(), - ) - ); - } - - let cols: usize = rhs.rank() + 1; - let digits = rhs.digits(); - - // Space for VMP result in DFT domain and high precision - let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); - let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, (lhs.size() + digits - 1) / digits); - - { - (0..digits).for_each(|di| { - a_dft.set_size((lhs.size() + di) / digits); - - // Small optimization for digits > 2 - // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then - // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}. - // As such we can ignore the last digits-2 limbs safely of the sum of vmp products. - // It is possible to further ignore the last digits-1 limbs, but this introduce - // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same - // noise is kept with respect to the ideal functionality. - res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize); - - (0..cols).for_each(|col_i| { - module.vec_znx_dft_copy(digits, digits - 1 - di, &mut a_dft, col_i, &lhs.data, col_i); - }); - - if di == 0 { - module.vmp_apply(&mut res_dft, &a_dft, &rhs.data, scratch2); - } else { - module.vmp_apply_add(&mut res_dft, &a_dft, &rhs.data, di, scratch2); - } - }); - } - - // VMP result in high precision - let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft); - - // Space for VMP result normalized - let (mut res_small, scratch2) = scratch1.tmp_vec_znx(module, cols, rhs.size()); - (0..cols).for_each(|i| { - module.vec_znx_big_normalize(basek, &mut res_small, i, &res_big, i, scratch2); - module.vec_znx_dft(1, 0, &mut self.data, i, &res_small, i); - }); - } - - pub fn external_product_inplace>( - &mut self, - module: &Module, - rhs: &GGSWCiphertext, - scratch: &mut Scratch, - ) { - unsafe { - let self_ptr: *mut GLWECiphertextFourier = self as *mut GLWECiphertextFourier; - self.external_product(&module, &*self_ptr, rhs, scratch); - } - } -} - -impl> GLWECiphertextFourier { - pub fn decrypt + AsMut<[u8]>, DataSk: AsRef<[u8]>>( - &self, - module: &Module, - pt: &mut GLWEPlaintext, - sk: &GLWESecret, - scratch: &mut Scratch, - ) { - #[cfg(debug_assertions)] - { - assert_eq!(self.rank(), sk.rank()); - assert_eq!(self.n(), module.n()); - assert_eq!(pt.n(), module.n()); - assert_eq!(sk.n(), module.n()); - } - - let cols = self.rank() + 1; - - let (mut pt_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, self.size()); // TODO optimize size when pt << ct - pt_big.zero(); - - { - (1..cols).for_each(|i| { - let (mut ci_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, self.size()); // TODO optimize size when pt << ct - module.svp_apply(&mut ci_dft, 0, &sk.data_fourier, i - 1, &self.data, i); - let ci_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(ci_dft); - module.vec_znx_big_add_inplace(&mut pt_big, 0, &ci_big, 0); - }); - } - - { - let (mut c0_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, self.size()); - // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e) - module.vec_znx_idft(&mut c0_big, 0, &self.data, 0, scratch_2); - module.vec_znx_big_add_inplace(&mut pt_big, 0, &c0_big, 0); - } - - // pt = norm(BIG(m + e)) - module.vec_znx_big_normalize(self.basek(), &mut pt.data, 0, &mut pt_big, 0, scratch_1); - - pt.basek = self.basek(); - pt.k = pt.k().min(self.k()); - } - - #[allow(dead_code)] - pub(crate) fn idft + AsMut<[u8]>>( - &self, - module: &Module, - res: &mut GLWECiphertext, - scratch: &mut Scratch, - ) { - #[cfg(debug_assertions)] - { - assert_eq!(self.rank(), res.rank()); - assert_eq!(self.basek(), res.basek()) - } - - let min_size: usize = self.size().min(res.size()); - - let (mut res_big, scratch1) = scratch.tmp_vec_znx_big(module, 1, min_size); - - (0..self.rank() + 1).for_each(|i| { - module.vec_znx_idft(&mut res_big, 0, &self.data, i, scratch1); - module.vec_znx_big_normalize(self.basek(), &mut res.data, i, &res_big, 0, scratch1); - }); - } -} diff --git a/core/src/glwe_keys.rs b/core/src/glwe_keys.rs deleted file mode 100644 index 8f04408..0000000 --- a/core/src/glwe_keys.rs +++ /dev/null @@ -1,149 +0,0 @@ -use backend::{ - Backend, FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScratchOwned, VecZnxDft, - ZnxInfos, ZnxZero, -}; -use sampling::source::Source; - -use crate::{GLWECiphertextFourier, Infos}; - -#[derive(Clone, Copy, Debug)] -pub(crate) enum SecretDistribution { - TernaryFixed(usize), // Ternary with fixed Hamming weight - TernaryProb(f64), // Ternary with probabilistic Hamming weight - ZERO, // Debug mod - NONE, -} - -pub struct GLWESecret { - pub(crate) data: ScalarZnx, - pub(crate) data_fourier: ScalarZnxDft, - pub(crate) dist: SecretDistribution, -} - -impl GLWESecret, B> { - pub fn alloc(module: &Module, rank: usize) -> Self { - Self { - data: module.new_scalar_znx(rank), - data_fourier: module.new_scalar_znx_dft(rank), - dist: SecretDistribution::NONE, - } - } - - pub fn bytes_of(module: &Module, rank: usize) -> usize { - module.bytes_of_scalar_znx(rank) + module.bytes_of_scalar_znx_dft(rank) - } -} - -impl GLWESecret { - pub fn n(&self) -> usize { - self.data.n() - } - - pub fn log_n(&self) -> usize { - self.data.log_n() - } - - pub fn rank(&self) -> usize { - self.data.cols() - } -} - -impl + AsRef<[u8]>> GLWESecret { - pub fn fill_ternary_prob(&mut self, module: &Module, prob: f64, source: &mut Source) { - (0..self.rank()).for_each(|i| { - self.data.fill_ternary_prob(i, prob, source); - }); - self.prep_fourier(module); - self.dist = SecretDistribution::TernaryProb(prob); - } - - pub fn fill_ternary_hw(&mut self, module: &Module, hw: usize, source: &mut Source) { - (0..self.rank()).for_each(|i| { - self.data.fill_ternary_hw(i, hw, source); - }); - self.prep_fourier(module); - self.dist = SecretDistribution::TernaryFixed(hw); - } - - pub fn fill_zero(&mut self) { - self.data.zero(); - self.dist = SecretDistribution::ZERO; - } - - pub(crate) fn prep_fourier(&mut self, module: &Module) { - (0..self.rank()).for_each(|i| { - module.svp_prepare(&mut self.data_fourier, i, &self.data, i); - }); - } -} - -pub struct GLWEPublicKey { - pub(crate) data: GLWECiphertextFourier, - pub(crate) dist: SecretDistribution, -} - -impl GLWEPublicKey, B> { - pub fn alloc(module: &Module, basek: usize, k: usize, rank: usize) -> Self { - Self { - data: GLWECiphertextFourier::alloc(module, basek, k, rank), - dist: SecretDistribution::NONE, - } - } - - pub fn bytes_of(module: &Module, basek: usize, k: usize, rank: usize) -> usize { - GLWECiphertextFourier::, B>::bytes_of(module, basek, k, rank) - } -} - -impl Infos for GLWEPublicKey { - type Inner = VecZnxDft; - - fn inner(&self) -> &Self::Inner { - &self.data.data - } - - fn basek(&self) -> usize { - self.data.basek - } - - fn k(&self) -> usize { - self.data.k - } -} - -impl GLWEPublicKey { - pub fn rank(&self) -> usize { - self.cols() - 1 - } -} - -impl + AsMut<[u8]>> GLWEPublicKey { - pub fn generate_from_sk>( - &mut self, - module: &Module, - sk: &GLWESecret, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - ) { - #[cfg(debug_assertions)] - { - match sk.dist { - SecretDistribution::NONE => panic!("invalid sk: SecretDistribution::NONE"), - _ => {} - } - } - - // Its ok to allocate scratch space here since pk is usually generated only once. - let mut scratch: ScratchOwned = ScratchOwned::new(GLWECiphertextFourier::encrypt_sk_scratch_space( - module, - self.basek(), - self.k(), - self.rank(), - )); - - self.data - .encrypt_zero_sk(module, sk, source_xa, source_xe, sigma, scratch.borrow()); - self.dist = sk.dist; - } -} diff --git a/core/src/keyswitch_key.rs b/core/src/keyswitch_key.rs deleted file mode 100644 index 56d42b4..0000000 --- a/core/src/keyswitch_key.rs +++ /dev/null @@ -1,343 +0,0 @@ -use backend::{Backend, FFT64, MatZnxDft, MatZnxDftOps, Module, Scratch, ZnxZero}; -use sampling::source::Source; - -use crate::{GGLWECiphertext, GGSWCiphertext, GLWECiphertextFourier, GLWESecret, GetRow, Infos, ScratchCore, SetRow}; - -pub struct GLWESwitchingKey(pub(crate) GGLWECiphertext); - -impl GLWESwitchingKey, FFT64> { - pub fn alloc( - module: &Module, - basek: usize, - k: usize, - rows: usize, - digits: usize, - rank_in: usize, - rank_out: usize, - ) -> Self { - GLWESwitchingKey(GGLWECiphertext::alloc( - module, basek, k, rows, digits, rank_in, rank_out, - )) - } - - pub fn bytes_of( - module: &Module, - basek: usize, - k: usize, - rows: usize, - digits: usize, - rank_in: usize, - rank_out: usize, - ) -> usize { - GGLWECiphertext::, FFT64>::bytes_of(module, basek, k, rows, digits, rank_in, rank_out) - } -} - -impl Infos for GLWESwitchingKey { - type Inner = MatZnxDft; - - fn inner(&self) -> &Self::Inner { - self.0.inner() - } - - fn basek(&self) -> usize { - self.0.basek() - } - - fn k(&self) -> usize { - self.0.k() - } -} - -impl GLWESwitchingKey { - pub fn rank(&self) -> usize { - self.0.data.cols_out() - 1 - } - - pub fn rank_in(&self) -> usize { - self.0.data.cols_in() - } - - pub fn rank_out(&self) -> usize { - self.0.data.cols_out() - 1 - } - - pub fn digits(&self) -> usize { - self.0.digits() - } -} - -impl> GetRow for GLWESwitchingKey { - fn get_row + AsRef<[u8]>>( - &self, - module: &Module, - row_i: usize, - col_j: usize, - res: &mut GLWECiphertextFourier, - ) { - module.vmp_extract_row(&mut res.data, &self.0.data, row_i, col_j); - } -} - -impl + AsRef<[u8]>> SetRow for GLWESwitchingKey { - fn set_row>( - &mut self, - module: &Module, - row_i: usize, - col_j: usize, - a: &GLWECiphertextFourier, - ) { - module.vmp_prepare_row(&mut self.0.data, row_i, col_j, &a.data); - } -} - -impl GLWESwitchingKey, FFT64> { - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { - GGLWECiphertext::generate_from_sk_scratch_space(module, basek, k, rank) - } - - pub fn encrypt_pk_scratch_space(module: &Module, _basek: usize, _k: usize, _rank: usize) -> usize { - GGLWECiphertext::generate_from_pk_scratch_space(module, _basek, _k, _rank) - } - - pub fn keyswitch_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_in: usize, - k_ksk: usize, - digits: usize, - rank_in: usize, - rank_out: usize, - ) -> usize { - let tmp_in: usize = GLWECiphertextFourier::bytes_of(module, basek, k_in, rank_in); - let tmp_out: usize = GLWECiphertextFourier::bytes_of(module, basek, k_out, rank_out); - let ksk: usize = - GLWECiphertextFourier::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank_in, rank_out); - tmp_in + tmp_out + ksk - } - - pub fn keyswitch_inplace_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_ksk: usize, - digits: usize, - rank: usize, - ) -> usize { - let tmp: usize = GLWECiphertextFourier::bytes_of(module, basek, k_out, rank); - let ksk: usize = GLWECiphertextFourier::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank); - tmp + ksk - } - - pub fn external_product_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_in: usize, - k_ggsw: usize, - digits: usize, - rank: usize, - ) -> usize { - let tmp_in: usize = GLWECiphertextFourier::bytes_of(module, basek, k_in, rank); - let tmp_out: usize = GLWECiphertextFourier::bytes_of(module, basek, k_out, rank); - let ggsw: usize = GLWECiphertextFourier::external_product_scratch_space(module, basek, k_out, k_in, k_ggsw, digits, rank); - tmp_in + tmp_out + ggsw - } - - pub fn external_product_inplace_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_ggsw: usize, - digits: usize, - rank: usize, - ) -> usize { - let tmp: usize = GLWECiphertextFourier::bytes_of(module, basek, k_out, rank); - let ggsw: usize = - GLWECiphertextFourier::external_product_inplace_scratch_space(module, basek, k_out, k_ggsw, digits, rank); - tmp + ggsw - } -} -impl + AsRef<[u8]>> GLWESwitchingKey { - pub fn generate_from_sk, DataSkOut: AsRef<[u8]>>( - &mut self, - module: &Module, - sk_in: &GLWESecret, - sk_out: &GLWESecret, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - scratch: &mut Scratch, - ) { - self.0.encrypt_sk( - module, - &sk_in.data, - sk_out, - source_xa, - source_xe, - sigma, - scratch, - ); - } - - pub fn keyswitch, DataRhs: AsRef<[u8]>>( - &mut self, - module: &Module, - lhs: &GLWESwitchingKey, - rhs: &GLWESwitchingKey, - scratch: &mut Scratch, - ) { - #[cfg(debug_assertions)] - { - assert_eq!( - self.rank_in(), - lhs.rank_in(), - "ksk_out input rank: {} != ksk_in input rank: {}", - self.rank_in(), - lhs.rank_in() - ); - assert_eq!( - lhs.rank_out(), - rhs.rank_in(), - "ksk_in output rank: {} != ksk_apply input rank: {}", - self.rank_out(), - rhs.rank_in() - ); - assert_eq!( - self.rank_out(), - rhs.rank_out(), - "ksk_out output rank: {} != ksk_apply output rank: {}", - self.rank_out(), - rhs.rank_out() - ); - } - - let (mut tmp_in, scratch1) = scratch.tmp_glwe_fourier(module, lhs.basek(), lhs.k(), lhs.rank()); - let (mut tmp_out, scratch2) = scratch1.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); - - (0..self.rank_in()).for_each(|col_i| { - (0..self.rows()).for_each(|row_j| { - lhs.get_row(module, row_j, col_i, &mut tmp_in); - tmp_out.keyswitch(module, &tmp_in, rhs, scratch2); - self.set_row(module, row_j, col_i, &tmp_out); - }); - }); - - tmp_out.data.zero(); - - (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| { - (0..self.rank_in()).for_each(|col_j| { - self.set_row(module, row_i, col_j, &tmp_out); - }); - }); - } - - pub fn keyswitch_inplace>( - &mut self, - module: &Module, - rhs: &GLWESwitchingKey, - scratch: &mut Scratch, - ) { - #[cfg(debug_assertions)] - { - assert_eq!( - self.rank_out(), - rhs.rank_out(), - "ksk_out output rank: {} != ksk_apply output rank: {}", - self.rank_out(), - rhs.rank_out() - ); - } - - let (mut tmp, scratch1) = scratch.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); - - (0..self.rank_in()).for_each(|col_i| { - (0..self.rows()).for_each(|row_j| { - self.get_row(module, row_j, col_i, &mut tmp); - tmp.keyswitch_inplace(module, rhs, scratch1); - self.set_row(module, row_j, col_i, &tmp); - }); - }); - } - - pub fn external_product, DataRhs: AsRef<[u8]>>( - &mut self, - module: &Module, - lhs: &GLWESwitchingKey, - rhs: &GGSWCiphertext, - scratch: &mut Scratch, - ) { - #[cfg(debug_assertions)] - { - assert_eq!( - self.rank_in(), - lhs.rank_in(), - "ksk_out input rank: {} != ksk_in input rank: {}", - self.rank_in(), - lhs.rank_in() - ); - assert_eq!( - lhs.rank_out(), - rhs.rank(), - "ksk_in output rank: {} != ggsw rank: {}", - self.rank_out(), - rhs.rank() - ); - assert_eq!( - self.rank_out(), - rhs.rank(), - "ksk_out output rank: {} != ggsw rank: {}", - self.rank_out(), - rhs.rank() - ); - } - - let (mut tmp_in, scratch1) = scratch.tmp_glwe_fourier(module, lhs.basek(), lhs.k(), lhs.rank()); - let (mut tmp_out, scratch2) = scratch1.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); - - (0..self.rank_in()).for_each(|col_i| { - (0..self.rows()).for_each(|row_j| { - lhs.get_row(module, row_j, col_i, &mut tmp_in); - tmp_out.external_product(module, &tmp_in, rhs, scratch2); - self.set_row(module, row_j, col_i, &tmp_out); - }); - }); - - tmp_out.data.zero(); - - (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| { - (0..self.rank_in()).for_each(|col_j| { - self.set_row(module, row_i, col_j, &tmp_out); - }); - }); - } - - pub fn external_product_inplace>( - &mut self, - module: &Module, - rhs: &GGSWCiphertext, - scratch: &mut Scratch, - ) { - #[cfg(debug_assertions)] - { - assert_eq!( - self.rank_out(), - rhs.rank(), - "ksk_out output rank: {} != ggsw rank: {}", - self.rank_out(), - rhs.rank() - ); - } - - let (mut tmp, scratch1) = scratch.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); - println!("tmp: {}", tmp.size()); - (0..self.rank_in()).for_each(|col_i| { - (0..self.rows()).for_each(|row_j| { - self.get_row(module, row_j, col_i, &mut tmp); - tmp.external_product_inplace(module, rhs, scratch1); - self.set_row(module, row_j, col_i, &tmp); - }); - }); - } -} diff --git a/core/src/lib.rs b/core/src/lib.rs index 69eb045..9f8c114 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -1,42 +1,43 @@ -pub mod automorphism; +pub mod blind_rotation; +pub mod dist; pub mod elem; -pub mod gglwe_ciphertext; -pub mod ggsw_ciphertext; -pub mod glwe_ciphertext; -pub mod glwe_ciphertext_fourier; -pub mod glwe_keys; -pub mod glwe_ops; -pub mod glwe_packing; -pub mod glwe_plaintext; -pub mod keyswitch_key; -pub mod tensor_key; -#[cfg(test)] -mod test_fft64; -pub mod trace; +pub mod fourier_glwe; +pub mod gglwe; +pub mod ggsw; +pub mod glwe; +pub mod lwe; +pub mod noise; -pub use automorphism::*; use backend::Backend; use backend::FFT64; use backend::Module; -pub use elem::*; -pub use gglwe_ciphertext::*; -pub use ggsw_ciphertext::*; -pub use glwe_ciphertext::*; -pub use glwe_ciphertext_fourier::*; -pub use glwe_keys::*; -pub use glwe_ops::*; -pub use glwe_packing::*; -pub use glwe_plaintext::*; -pub use keyswitch_key::*; -pub use tensor_key::*; +pub use blind_rotation::{BlindRotationKeyCGGI, LookUpTable, cggi_blind_rotate, cggi_blind_rotate_scratch_space}; +pub use elem::{GetRow, Infos, SetMetaData, SetRow}; +pub use fourier_glwe::{FourierGLWECiphertext, FourierGLWESecret}; +pub use gglwe::{GGLWECiphertext, GLWEAutomorphismKey, GLWESwitchingKey, GLWETensorKey}; +pub use ggsw::GGSWCiphertext; +pub use glwe::{GLWECiphertext, GLWEOps, GLWEPacker, GLWEPlaintext, GLWEPublicKey, GLWESecret}; +pub use lwe::{LWECiphertext, LWESecret}; + +pub(crate) use glwe::{GLWECiphertextToMut, GLWECiphertextToRef}; pub use backend::Scratch; pub use backend::ScratchOwned; +use crate::dist::Distribution; + pub(crate) const SIX_SIGMA: f64 = 6.0; pub trait ScratchCore { fn tmp_glwe_ct(&mut self, module: &Module, basek: usize, k: usize, rank: usize) -> (GLWECiphertext<&mut [u8]>, &mut Self); + fn tmp_vec_glwe_ct( + &mut self, + size: usize, + module: &Module, + basek: usize, + k: usize, + rank: usize, + ) -> (Vec>, &mut Self); fn tmp_glwe_pt(&mut self, module: &Module, basek: usize, k: usize) -> (GLWEPlaintext<&mut [u8]>, &mut Self); fn tmp_gglwe( &mut self, @@ -57,14 +58,23 @@ pub trait ScratchCore { digits: usize, rank: usize, ) -> (GGSWCiphertext<&mut [u8], B>, &mut Self); - fn tmp_glwe_fourier( + fn tmp_fourier_glwe_ct( &mut self, module: &Module, basek: usize, k: usize, rank: usize, - ) -> (GLWECiphertextFourier<&mut [u8], B>, &mut Self); - fn tmp_sk(&mut self, module: &Module, rank: usize) -> (GLWESecret<&mut [u8], B>, &mut Self); + ) -> (FourierGLWECiphertext<&mut [u8], B>, &mut Self); + fn tmp_slice_fourier_glwe_ct( + &mut self, + size: usize, + module: &Module, + basek: usize, + k: usize, + rank: usize, + ) -> (Vec>, &mut Self); + fn tmp_glwe_secret(&mut self, module: &Module, rank: usize) -> (GLWESecret<&mut [u8]>, &mut Self); + fn tmp_fourier_glwe_secret(&mut self, module: &Module, rank: usize) -> (FourierGLWESecret<&mut [u8], B>, &mut Self); fn tmp_glwe_pk( &mut self, module: &Module, @@ -90,7 +100,7 @@ pub trait ScratchCore { rows: usize, digits: usize, rank: usize, - ) -> (TensorKey<&mut [u8], B>, &mut Self); + ) -> (GLWETensorKey<&mut [u8], B>, &mut Self); fn tmp_autokey( &mut self, module: &Module, @@ -99,7 +109,7 @@ pub trait ScratchCore { rows: usize, digits: usize, rank: usize, - ) -> (AutomorphismKey<&mut [u8], B>, &mut Self); + ) -> (GLWEAutomorphismKey<&mut [u8], B>, &mut Self); } impl ScratchCore for Scratch { @@ -114,6 +124,24 @@ impl ScratchCore for Scratch { (GLWECiphertext { data, basek, k }, scratch) } + fn tmp_vec_glwe_ct( + &mut self, + size: usize, + module: &Module, + basek: usize, + k: usize, + rank: usize, + ) -> (Vec>, &mut Self) { + let mut scratch: &mut Scratch = self; + let mut cts: Vec> = Vec::with_capacity(size); + for _ in 0..size { + let (ct, new_scratch) = scratch.tmp_glwe_ct(module, basek, k, rank); + scratch = new_scratch; + cts.push(ct); + } + (cts, scratch) + } + fn tmp_glwe_pt(&mut self, module: &Module, basek: usize, k: usize) -> (GLWEPlaintext<&mut [u8]>, &mut Self) { let (data, scratch) = self.tmp_vec_znx(module, 1, k.div_ceil(basek)); (GLWEPlaintext { data, basek, k }, scratch) @@ -174,15 +202,33 @@ impl ScratchCore for Scratch { ) } - fn tmp_glwe_fourier( + fn tmp_fourier_glwe_ct( &mut self, module: &Module, basek: usize, k: usize, rank: usize, - ) -> (GLWECiphertextFourier<&mut [u8], FFT64>, &mut Self) { + ) -> (FourierGLWECiphertext<&mut [u8], FFT64>, &mut Self) { let (data, scratch) = self.tmp_vec_znx_dft(module, rank + 1, k.div_ceil(basek)); - (GLWECiphertextFourier { data, basek, k }, scratch) + (FourierGLWECiphertext { data, basek, k }, scratch) + } + + fn tmp_slice_fourier_glwe_ct( + &mut self, + size: usize, + module: &Module, + basek: usize, + k: usize, + rank: usize, + ) -> (Vec>, &mut Self) { + let mut scratch: &mut Scratch = self; + let mut cts: Vec> = Vec::with_capacity(size); + for _ in 0..size { + let (ct, new_scratch) = scratch.tmp_fourier_glwe_ct(module, basek, k, rank); + scratch = new_scratch; + cts.push(ct); + } + (cts, scratch) } fn tmp_glwe_pk( @@ -192,26 +238,39 @@ impl ScratchCore for Scratch { k: usize, rank: usize, ) -> (GLWEPublicKey<&mut [u8], FFT64>, &mut Self) { - let (data, scratch) = self.tmp_glwe_fourier(module, basek, k, rank); + let (data, scratch) = self.tmp_fourier_glwe_ct(module, basek, k, rank); ( GLWEPublicKey { data, - dist: SecretDistribution::NONE, + dist: Distribution::NONE, }, scratch, ) } - fn tmp_sk(&mut self, module: &Module, rank: usize) -> (GLWESecret<&mut [u8], FFT64>, &mut Self) { + fn tmp_glwe_secret(&mut self, module: &Module, rank: usize) -> (GLWESecret<&mut [u8]>, &mut Self) { let (data, scratch) = self.tmp_scalar_znx(module, rank); - let (data_fourier, scratch1) = scratch.tmp_scalar_znx_dft(module, rank); ( GLWESecret { data, - data_fourier, - dist: SecretDistribution::NONE, + dist: Distribution::NONE, }, - scratch1, + scratch, + ) + } + + fn tmp_fourier_glwe_secret( + &mut self, + module: &Module, + rank: usize, + ) -> (FourierGLWESecret<&mut [u8], FFT64>, &mut Self) { + let (data, scratch) = self.tmp_scalar_znx_dft(module, rank); + ( + FourierGLWESecret { + data, + dist: Distribution::NONE, + }, + scratch, ) } @@ -226,7 +285,14 @@ impl ScratchCore for Scratch { rank_out: usize, ) -> (GLWESwitchingKey<&mut [u8], FFT64>, &mut Self) { let (data, scratch) = self.tmp_gglwe(module, basek, k, rows, digits, rank_in, rank_out); - (GLWESwitchingKey(data), scratch) + ( + GLWESwitchingKey { + key: data, + sk_in_n: 0, + sk_out_n: 0, + }, + scratch, + ) } fn tmp_autokey( @@ -237,9 +303,9 @@ impl ScratchCore for Scratch { rows: usize, digits: usize, rank: usize, - ) -> (AutomorphismKey<&mut [u8], FFT64>, &mut Self) { + ) -> (GLWEAutomorphismKey<&mut [u8], FFT64>, &mut Self) { let (data, scratch) = self.tmp_glwe_ksk(module, basek, k, rows, digits, rank, rank); - (AutomorphismKey { key: data, p: 0 }, scratch) + (GLWEAutomorphismKey { key: data, p: 0 }, scratch) } fn tmp_tsk( @@ -250,7 +316,7 @@ impl ScratchCore for Scratch { rows: usize, digits: usize, rank: usize, - ) -> (TensorKey<&mut [u8], FFT64>, &mut Self) { + ) -> (GLWETensorKey<&mut [u8], FFT64>, &mut Self) { let mut keys: Vec> = Vec::new(); let pairs: usize = (((rank + 1) * rank) >> 1).max(1); @@ -266,6 +332,6 @@ impl ScratchCore for Scratch { scratch = s; keys.push(gglwe); } - (TensorKey { keys }, scratch) + (GLWETensorKey { keys }, scratch) } } diff --git a/core/src/lwe/ciphertext.rs b/core/src/lwe/ciphertext.rs new file mode 100644 index 0000000..1e97eb4 --- /dev/null +++ b/core/src/lwe/ciphertext.rs @@ -0,0 +1,77 @@ +use backend::{VecZnx, VecZnxToMut, VecZnxToRef}; + +use crate::{Infos, SetMetaData}; + +pub struct LWECiphertext { + pub(crate) data: VecZnx, + pub(crate) k: usize, + pub(crate) basek: usize, +} + +impl LWECiphertext> { + pub fn alloc(n: usize, basek: usize, k: usize) -> Self { + Self { + data: VecZnx::new::(n + 1, 1, k.div_ceil(basek)), + k: k, + basek: basek, + } + } +} + +impl Infos for LWECiphertext { + type Inner = VecZnx; + + fn n(&self) -> usize { + &self.inner().n - 1 + } + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn basek(&self) -> usize { + self.basek + } + + fn k(&self) -> usize { + self.k + } +} + +impl + AsRef<[u8]>> SetMetaData for LWECiphertext { + fn set_k(&mut self, k: usize) { + self.k = k + } + + fn set_basek(&mut self, basek: usize) { + self.basek = basek + } +} + +pub trait LWECiphertextToRef { + fn to_ref(&self) -> LWECiphertext<&[u8]>; +} + +impl> LWECiphertextToRef for LWECiphertext { + fn to_ref(&self) -> LWECiphertext<&[u8]> { + LWECiphertext { + data: self.data.to_ref(), + basek: self.basek, + k: self.k, + } + } +} + +pub trait LWECiphertextToMut { + fn to_mut(&mut self) -> LWECiphertext<&mut [u8]>; +} + +impl + AsRef<[u8]>> LWECiphertextToMut for LWECiphertext { + fn to_mut(&mut self) -> LWECiphertext<&mut [u8]> { + LWECiphertext { + data: self.data.to_mut(), + basek: self.basek, + k: self.k, + } + } +} diff --git a/core/src/lwe/decryption.rs b/core/src/lwe/decryption.rs new file mode 100644 index 0000000..3ed9d2b --- /dev/null +++ b/core/src/lwe/decryption.rs @@ -0,0 +1,34 @@ +use backend::{ZnxView, ZnxViewMut, alloc_aligned}; + +use crate::{Infos, LWECiphertext, LWESecret, SetMetaData, lwe::LWEPlaintext}; + +impl LWECiphertext +where + DataSelf: AsRef<[u8]>, +{ + pub fn decrypt(&self, pt: &mut LWEPlaintext, sk: &LWESecret) + where + DataPt: AsRef<[u8]> + AsMut<[u8]>, + DataSk: AsRef<[u8]>, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.n(), sk.n()); + } + + (0..pt.size().min(self.size())).for_each(|i| { + pt.data.at_mut(0, i)[0] = self.data.at(0, i)[0] + + self.data.at(0, i)[1..] + .iter() + .zip(sk.data.at(0, 0)) + .map(|(x, y)| x * y) + .sum::(); + }); + + let mut tmp_bytes: Vec = alloc_aligned(size_of::()); + pt.data.normalize(self.basek(), 0, &mut tmp_bytes); + + pt.set_basek(self.basek()); + pt.set_k(self.k().min(pt.size() * self.basek())); + } +} diff --git a/core/src/lwe/encryption.rs b/core/src/lwe/encryption.rs new file mode 100644 index 0000000..00d814f --- /dev/null +++ b/core/src/lwe/encryption.rs @@ -0,0 +1,60 @@ +use backend::{AddNormal, FillUniform, VecZnx, ZnxView, ZnxViewMut, alloc_aligned}; +use sampling::source::Source; + +use crate::{Infos, LWECiphertext, LWESecret, SIX_SIGMA, lwe::LWEPlaintext}; + +impl LWECiphertext +where + DataSelf: AsMut<[u8]> + AsRef<[u8]>, +{ + pub fn encrypt_sk( + &mut self, + pt: &LWEPlaintext, + sk: &LWESecret, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + ) where + DataPt: AsRef<[u8]>, + DataSk: AsRef<[u8]>, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.n(), sk.n()) + } + + let basek: usize = self.basek(); + + self.data.fill_uniform(basek, 0, self.size(), source_xa); + let mut tmp_znx: VecZnx> = VecZnx::>::new::(1, 1, self.size()); + + let min_size = self.size().min(pt.size()); + + (0..min_size).for_each(|i| { + tmp_znx.at_mut(0, i)[0] = pt.data.at(0, i)[0] + - self.data.at(0, i)[1..] + .iter() + .zip(sk.data.at(0, 0)) + .map(|(x, y)| x * y) + .sum::(); + }); + + (min_size..self.size()).for_each(|i| { + tmp_znx.at_mut(0, i)[0] -= self.data.at(0, i)[1..] + .iter() + .zip(sk.data.at(0, 0)) + .map(|(x, y)| x * y) + .sum::(); + }); + + tmp_znx.add_normal(basek, 0, self.k(), source_xe, sigma, sigma * SIX_SIGMA); + + let mut tmp_bytes: Vec = alloc_aligned(size_of::()); + + tmp_znx.normalize(basek, 0, &mut tmp_bytes); + + (0..self.size()).for_each(|i| { + self.data.at_mut(0, i)[0] = tmp_znx.at(0, i)[0]; + }); + } +} diff --git a/core/src/lwe/keyswtich.rs b/core/src/lwe/keyswtich.rs new file mode 100644 index 0000000..d06b7aa --- /dev/null +++ b/core/src/lwe/keyswtich.rs @@ -0,0 +1,313 @@ +use backend::{Backend, FFT64, Module, Scratch, VecZnxOps, ZnxView, ZnxViewMut, ZnxZero}; +use sampling::source::Source; + +use crate::{FourierGLWESecret, GLWECiphertext, GLWESecret, GLWESwitchingKey, Infos, LWECiphertext, LWESecret, ScratchCore}; + +/// A special [GLWESwitchingKey] required to for the conversion from [GLWECiphertext] to [LWECiphertext]. +pub struct GLWEToLWESwitchingKey(GLWESwitchingKey); + +impl GLWEToLWESwitchingKey, FFT64> { + pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, rank: usize) -> Self { + Self(GLWESwitchingKey::alloc(module, basek, k, rows, 1, rank, 1)) + } + + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { + FourierGLWESecret::bytes_of(module, rank) + + (GLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, rank, rank) | GLWESecret::bytes_of(module, rank)) + } +} + +impl + AsRef<[u8]>> GLWEToLWESwitchingKey { + pub fn encrypt_sk( + &mut self, + module: &Module, + sk_lwe: &LWESecret, + sk_glwe: &GLWESecret, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) where + DLwe: AsRef<[u8]>, + DGlwe: AsRef<[u8]>, + { + #[cfg(debug_assertions)] + { + assert!(sk_lwe.n() <= module.n()); + } + + let (mut sk_lwe_as_glwe_dft, scratch1) = scratch.tmp_fourier_glwe_secret(module, 1); + + { + let (mut sk_lwe_as_glwe, _) = scratch1.tmp_glwe_secret(module, 1); + sk_lwe_as_glwe.data.zero(); + sk_lwe_as_glwe.data.at_mut(0, 0)[..sk_lwe.n()].copy_from_slice(sk_lwe.data.at(0, 0)); + module.vec_znx_automorphism_inplace(-1, &mut sk_lwe_as_glwe.data, 0); + sk_lwe_as_glwe_dft.set(module, &sk_lwe_as_glwe); + } + + self.0.encrypt_sk( + module, + sk_glwe, + &sk_lwe_as_glwe_dft, + source_xa, + source_xe, + sigma, + scratch1, + ); + } +} + +/// A special [GLWESwitchingKey] required to for the conversion from [LWECiphertext] to [GLWECiphertext]. +pub struct LWEToGLWESwitchingKey(GLWESwitchingKey); + +impl LWEToGLWESwitchingKey, FFT64> { + pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, rank: usize) -> Self { + Self(GLWESwitchingKey::alloc(module, basek, k, rows, 1, 1, rank)) + } + + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { + GLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, 1, rank) + GLWESecret::bytes_of(module, 1) + } +} + +impl + AsRef<[u8]>> LWEToGLWESwitchingKey { + pub fn encrypt_sk( + &mut self, + module: &Module, + sk_lwe: &LWESecret, + sk_glwe: &FourierGLWESecret, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) where + DLwe: AsRef<[u8]>, + DGlwe: AsRef<[u8]>, + { + #[cfg(debug_assertions)] + { + assert!(sk_lwe.n() <= module.n()); + } + + let (mut sk_lwe_as_glwe, scratch1) = scratch.tmp_glwe_secret(module, 1); + sk_lwe_as_glwe.data.at_mut(0, 0)[..sk_lwe.n()].copy_from_slice(sk_lwe.data.at(0, 0)); + sk_lwe_as_glwe.data.at_mut(0, 0)[sk_lwe.n()..].fill(0); + module.vec_znx_automorphism_inplace(-1, &mut sk_lwe_as_glwe.data, 0); + + self.0.encrypt_sk( + module, + &sk_lwe_as_glwe, + &sk_glwe, + source_xa, + source_xe, + sigma, + scratch1, + ); + } +} + +pub struct LWESwitchingKey(GLWESwitchingKey); + +impl LWESwitchingKey, FFT64> { + pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize) -> Self { + Self(GLWESwitchingKey::alloc(module, basek, k, rows, 1, 1, 1)) + } + + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize) -> usize { + GLWESecret::bytes_of(module, 1) + + FourierGLWESecret::bytes_of(module, 1) + + GLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, 1, 1) + } +} + +impl + AsRef<[u8]>> LWESwitchingKey { + pub fn encrypt_sk( + &mut self, + module: &Module, + sk_lwe_in: &LWESecret, + sk_lwe_out: &LWESecret, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) where + DIn: AsRef<[u8]>, + DOut: AsRef<[u8]>, + { + #[cfg(debug_assertions)] + { + assert!(sk_lwe_in.n() <= module.n()); + assert!(sk_lwe_out.n() <= module.n()); + } + + let (mut sk_in_glwe, scratch1) = scratch.tmp_glwe_secret(module, 1); + let (mut sk_out_glwe, scratch2) = scratch1.tmp_fourier_glwe_secret(module, 1); + + sk_in_glwe.data.at_mut(0, 0)[..sk_lwe_out.n()].copy_from_slice(sk_lwe_out.data.at(0, 0)); + sk_in_glwe.data.at_mut(0, 0)[sk_lwe_out.n()..].fill(0); + module.vec_znx_automorphism_inplace(-1, &mut sk_in_glwe.data, 0); + sk_out_glwe.set(module, &sk_in_glwe); + sk_in_glwe.data.at_mut(0, 0)[..sk_lwe_in.n()].copy_from_slice(sk_lwe_in.data.at(0, 0)); + sk_in_glwe.data.at_mut(0, 0)[sk_lwe_in.n()..].fill(0); + module.vec_znx_automorphism_inplace(-1, &mut sk_in_glwe.data, 0); + + self.0.encrypt_sk( + module, + &sk_in_glwe, + &sk_out_glwe, + source_xa, + source_xe, + sigma, + scratch2, + ); + } +} + +impl LWECiphertext> { + pub fn from_glwe_scratch_space( + module: &Module, + basek: usize, + k_lwe: usize, + k_glwe: usize, + k_ksk: usize, + rank: usize, + ) -> usize { + GLWECiphertext::bytes_of(module, basek, k_lwe, 1) + + GLWECiphertext::keyswitch_scratch_space(module, basek, k_lwe, k_glwe, k_ksk, 1, rank, 1) + } + + pub fn keyswitch_scratch_space( + module: &Module, + basek: usize, + k_lwe_out: usize, + k_lwe_in: usize, + k_ksk: usize, + ) -> usize { + GLWECiphertext::bytes_of(module, basek, k_lwe_out.max(k_lwe_in), 1) + + GLWECiphertext::keyswitch_inplace_scratch_space(module, basek, k_lwe_out, k_ksk, 1, 1) + } +} + +impl + AsMut<[u8]>> LWECiphertext { + pub fn sample_extract(&mut self, a: &GLWECiphertext) + where + DGlwe: AsRef<[u8]>, + { + #[cfg(debug_assertions)] + { + assert!(self.n() <= a.n()); + } + + let min_size: usize = self.size().min(a.size()); + let n: usize = self.n(); + + self.data.zero(); + (0..min_size).for_each(|i| { + let data_lwe: &mut [i64] = self.data.at_mut(0, i); + data_lwe[0] = a.data.at(0, i)[0]; + data_lwe[1..].copy_from_slice(&a.data.at(1, i)[..n]); + }); + } + + pub fn from_glwe( + &mut self, + module: &Module, + a: &GLWECiphertext, + ks: &GLWEToLWESwitchingKey, + scratch: &mut Scratch, + ) where + DGlwe: AsRef<[u8]>, + DKs: AsRef<[u8]>, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.basek(), a.basek()); + } + let (mut tmp_glwe, scratch1) = scratch.tmp_glwe_ct(module, a.basek(), self.k(), 1); + tmp_glwe.keyswitch(module, a, &ks.0, scratch1); + self.sample_extract(&tmp_glwe); + } + + pub fn keyswitch( + &mut self, + module: &Module, + a: &LWECiphertext, + ksk: &LWESwitchingKey, + scratch: &mut Scratch, + ) where + A: AsRef<[u8]>, + DKs: AsRef<[u8]>, + { + #[cfg(debug_assertions)] + { + assert!(self.n() <= module.n()); + assert!(a.n() <= module.n()); + assert_eq!(self.basek(), a.basek()); + } + + let max_k: usize = self.k().max(a.k()); + let basek: usize = self.basek(); + + let (mut glwe, scratch1) = scratch.tmp_glwe_ct(&module, basek, max_k, 1); + glwe.data.zero(); + + let n_lwe: usize = a.n(); + + (0..a.size()).for_each(|i| { + let data_lwe: &[i64] = a.data.at(0, i); + glwe.data.at_mut(0, i)[0] = data_lwe[0]; + glwe.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]); + }); + + glwe.keyswitch_inplace(module, &ksk.0, scratch1); + + self.sample_extract(&glwe); + } +} + +impl GLWECiphertext> { + pub fn from_lwe_scratch_space( + module: &Module, + basek: usize, + k_lwe: usize, + k_glwe: usize, + k_ksk: usize, + rank: usize, + ) -> usize { + GLWECiphertext::keyswitch_scratch_space(module, basek, k_glwe, k_lwe, k_ksk, 1, 1, rank) + + GLWECiphertext::bytes_of(module, basek, k_lwe, 1) + } +} + +impl + AsMut<[u8]>> GLWECiphertext { + pub fn from_lwe( + &mut self, + module: &Module, + lwe: &LWECiphertext, + ksk: &LWEToGLWESwitchingKey, + scratch: &mut Scratch, + ) where + DLwe: AsRef<[u8]>, + DKsk: AsRef<[u8]>, + { + #[cfg(debug_assertions)] + { + assert!(lwe.n() <= self.n()); + assert_eq!(self.basek(), self.basek()); + } + + let (mut glwe, scratch1) = scratch.tmp_glwe_ct(module, lwe.basek(), lwe.k(), 1); + glwe.data.zero(); + + let n_lwe: usize = lwe.n(); + + (0..lwe.size()).for_each(|i| { + let data_lwe: &[i64] = lwe.data.at(0, i); + glwe.data.at_mut(0, i)[0] = data_lwe[0]; + glwe.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]); + }); + + self.keyswitch(module, &glwe, &ksk.0, scratch1); + } +} diff --git a/core/src/lwe/mod.rs b/core/src/lwe/mod.rs new file mode 100644 index 0000000..1e3d351 --- /dev/null +++ b/core/src/lwe/mod.rs @@ -0,0 +1,13 @@ +pub mod ciphertext; +pub mod decryption; +pub mod encryption; +pub mod keyswtich; +pub mod plaintext; +pub mod secret; + +pub use ciphertext::LWECiphertext; +pub use plaintext::LWEPlaintext; +pub use secret::LWESecret; + +#[cfg(test)] +pub mod test_fft64; diff --git a/core/src/lwe/plaintext.rs b/core/src/lwe/plaintext.rs new file mode 100644 index 0000000..7c73351 --- /dev/null +++ b/core/src/lwe/plaintext.rs @@ -0,0 +1,73 @@ +use backend::{VecZnx, VecZnxToMut, VecZnxToRef}; + +use crate::{Infos, SetMetaData}; + +pub struct LWEPlaintext { + pub(crate) data: VecZnx, + pub(crate) k: usize, + pub(crate) basek: usize, +} + +impl LWEPlaintext> { + pub fn alloc(basek: usize, k: usize) -> Self { + Self { + data: VecZnx::new::(1, 1, k.div_ceil(basek)), + k: k, + basek: basek, + } + } +} + +impl Infos for LWEPlaintext { + type Inner = VecZnx; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn basek(&self) -> usize { + self.basek + } + + fn k(&self) -> usize { + self.k + } +} + +impl + AsRef<[u8]>> SetMetaData for LWEPlaintext { + fn set_k(&mut self, k: usize) { + self.k = k + } + + fn set_basek(&mut self, basek: usize) { + self.basek = basek + } +} + +pub trait LWEPlaintextToRef { + fn to_ref(&self) -> LWEPlaintext<&[u8]>; +} + +impl> LWEPlaintextToRef for LWEPlaintext { + fn to_ref(&self) -> LWEPlaintext<&[u8]> { + LWEPlaintext { + data: self.data.to_ref(), + basek: self.basek, + k: self.k, + } + } +} + +pub trait LWEPlaintextToMut { + fn to_mut(&mut self) -> LWEPlaintext<&mut [u8]>; +} + +impl + AsRef<[u8]>> LWEPlaintextToMut for LWEPlaintext { + fn to_mut(&mut self) -> LWEPlaintext<&mut [u8]> { + LWEPlaintext { + data: self.data.to_mut(), + basek: self.basek, + k: self.k, + } + } +} diff --git a/core/src/lwe/secret.rs b/core/src/lwe/secret.rs new file mode 100644 index 0000000..90776a7 --- /dev/null +++ b/core/src/lwe/secret.rs @@ -0,0 +1,64 @@ +use backend::{ScalarZnx, ZnxInfos, ZnxZero}; +use sampling::source::Source; + +use crate::Distribution; + +pub struct LWESecret { + pub(crate) data: ScalarZnx, + pub(crate) dist: Distribution, +} + +impl LWESecret> { + pub fn alloc(n: usize) -> Self { + Self { + data: ScalarZnx::new(n, 1), + dist: Distribution::NONE, + } + } +} + +impl LWESecret { + pub fn n(&self) -> usize { + self.data.n() + } + + pub fn log_n(&self) -> usize { + self.data.log_n() + } + + pub fn rank(&self) -> usize { + self.data.cols() + } +} + +impl + AsMut<[u8]>> LWESecret { + pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) { + self.data.fill_ternary_prob(0, prob, source); + self.dist = Distribution::TernaryProb(prob); + } + + pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) { + self.data.fill_ternary_hw(0, hw, source); + self.dist = Distribution::TernaryFixed(hw); + } + + pub fn fill_binary_prob(&mut self, prob: f64, source: &mut Source) { + self.data.fill_binary_prob(0, prob, source); + self.dist = Distribution::BinaryProb(prob); + } + + pub fn fill_binary_hw(&mut self, hw: usize, source: &mut Source) { + self.data.fill_binary_hw(0, hw, source); + self.dist = Distribution::BinaryFixed(hw); + } + + pub fn fill_binary_block(&mut self, block_size: usize, source: &mut Source) { + self.data.fill_binary_block(0, block_size, source); + self.dist = Distribution::BinaryBlock(block_size); + } + + pub fn fill_zero(&mut self) { + self.data.zero(); + self.dist = Distribution::ZERO; + } +} diff --git a/core/src/lwe/test_fft64/conversion.rs b/core/src/lwe/test_fft64/conversion.rs new file mode 100644 index 0000000..1fbd4cb --- /dev/null +++ b/core/src/lwe/test_fft64/conversion.rs @@ -0,0 +1,220 @@ +use backend::{Encoding, FFT64, Module, ScratchOwned, ZnxView}; +use sampling::source::Source; + +use crate::{ + FourierGLWESecret, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, LWECiphertext, LWESecret, + lwe::{ + LWEPlaintext, + keyswtich::{GLWEToLWESwitchingKey, LWESwitchingKey, LWEToGLWESwitchingKey}, + }, +}; + +#[test] +fn lwe_to_glwe() { + let n: usize = 1 << 5; + let module: Module = Module::::new(n); + let basek: usize = 17; + let sigma: f64 = 3.2; + + let rank: usize = 2; + + let n_lwe: usize = 22; + let k_lwe_ct: usize = 2 * basek; + let k_lwe_pt: usize = 8; + + let k_glwe_ct: usize = 3 * basek; + + let k_ksk: usize = k_lwe_ct + basek; + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + LWEToGLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank) + | GLWECiphertext::from_lwe_scratch_space(&module, basek, k_lwe_ct, k_glwe_ct, k_ksk, rank) + | GLWECiphertext::decrypt_scratch_space(&module, basek, k_glwe_ct), + ); + + let mut sk_glwe: GLWESecret> = GLWESecret::alloc(&module, rank); + sk_glwe.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_glwe_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::alloc(&module, rank); + sk_glwe_dft.set(&module, &sk_glwe); + + let mut sk_lwe = LWESecret::alloc(n_lwe); + sk_lwe.fill_ternary_prob(0.5, &mut source_xs); + + let data: i64 = 17; + + let mut lwe_pt: LWEPlaintext> = LWEPlaintext::alloc(basek, k_lwe_pt); + lwe_pt + .data + .encode_coeff_i64(0, basek, k_lwe_pt, 0, data, k_lwe_pt); + + let mut lwe_ct: LWECiphertext> = LWECiphertext::alloc(n_lwe, basek, k_lwe_ct); + lwe_ct.encrypt_sk(&lwe_pt, &sk_lwe, &mut source_xa, &mut source_xe, sigma); + + let mut ksk: LWEToGLWESwitchingKey, FFT64> = LWEToGLWESwitchingKey::alloc(&module, basek, k_ksk, lwe_ct.size(), rank); + + ksk.encrypt_sk( + &module, + &sk_lwe, + &sk_glwe_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + let mut glwe_ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_glwe_ct, rank); + glwe_ct.from_lwe(&module, &lwe_ct, &ksk, scratch.borrow()); + + let mut glwe_pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_glwe_ct); + glwe_ct.decrypt(&module, &mut glwe_pt, &sk_glwe_dft, scratch.borrow()); + + assert_eq!(glwe_pt.data.at(0, 0)[0], lwe_pt.data.at(0, 0)[0]); +} + +#[test] +fn glwe_to_lwe() { + let n: usize = 1 << 5; + let module: Module = Module::::new(n); + let basek: usize = 17; + let sigma: f64 = 3.2; + + let rank: usize = 2; + + let n_lwe: usize = 22; + let k_lwe_ct: usize = 2 * basek; + let k_lwe_pt: usize = 8; + + let k_glwe_ct: usize = 3 * basek; + + let k_ksk: usize = k_lwe_ct + basek; + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + LWEToGLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank) + | GLWECiphertext::from_lwe_scratch_space(&module, basek, k_lwe_ct, k_glwe_ct, k_ksk, rank) + | GLWECiphertext::decrypt_scratch_space(&module, basek, k_glwe_ct), + ); + + let mut sk_glwe: GLWESecret> = GLWESecret::alloc(&module, rank); + sk_glwe.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_glwe_dft: FourierGLWESecret, FFT64> = FourierGLWESecret::alloc(&module, rank); + sk_glwe_dft.set(&module, &sk_glwe); + + let mut sk_lwe = LWESecret::alloc(n_lwe); + sk_lwe.fill_ternary_prob(0.5, &mut source_xs); + + let data: i64 = 17; + let mut glwe_pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_glwe_ct); + + glwe_pt + .data + .encode_coeff_i64(0, basek, k_lwe_pt, 0, data, k_lwe_pt); + + let mut glwe_ct = GLWECiphertext::alloc(&module, basek, k_glwe_ct, rank); + glwe_ct.encrypt_sk( + &module, + &glwe_pt, + &sk_glwe_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + let mut ksk: GLWEToLWESwitchingKey, FFT64> = + GLWEToLWESwitchingKey::alloc(&module, basek, k_ksk, glwe_ct.size(), rank); + + ksk.encrypt_sk( + &module, + &sk_lwe, + &sk_glwe, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + let mut lwe_ct: LWECiphertext> = LWECiphertext::alloc(n_lwe, basek, k_lwe_ct); + lwe_ct.from_glwe(&module, &glwe_ct, &ksk, scratch.borrow()); + + let mut lwe_pt: LWEPlaintext> = LWEPlaintext::alloc(basek, k_lwe_ct); + lwe_ct.decrypt(&mut lwe_pt, &sk_lwe); + + assert_eq!(glwe_pt.data.at(0, 0)[0], lwe_pt.data.at(0, 0)[0]); +} + +#[test] +fn keyswitch() { + let n: usize = 1 << 5; + let module: Module = Module::::new(n); + let basek: usize = 17; + let sigma: f64 = 3.2; + + let n_lwe_in: usize = 22; + let n_lwe_out: usize = 30; + let k_lwe_ct: usize = 2 * basek; + let k_lwe_pt: usize = 8; + + let k_ksk: usize = k_lwe_ct + basek; + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + LWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk) + | LWECiphertext::keyswitch_scratch_space(&module, basek, k_lwe_ct, k_lwe_ct, k_ksk), + ); + + let mut sk_lwe_in: LWESecret> = LWESecret::alloc(n_lwe_in); + sk_lwe_in.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_lwe_out: LWESecret> = LWESecret::alloc(n_lwe_out); + sk_lwe_out.fill_ternary_prob(0.5, &mut source_xs); + + let data: i64 = 17; + + let mut lwe_pt_in: LWEPlaintext> = LWEPlaintext::alloc(basek, k_lwe_pt); + lwe_pt_in + .data + .encode_coeff_i64(0, basek, k_lwe_pt, 0, data, k_lwe_pt); + + let mut lwe_ct_in: LWECiphertext> = LWECiphertext::alloc(n_lwe_in, basek, k_lwe_ct); + lwe_ct_in.encrypt_sk( + &lwe_pt_in, + &sk_lwe_in, + &mut source_xa, + &mut source_xe, + sigma, + ); + + let mut ksk: LWESwitchingKey, FFT64> = LWESwitchingKey::alloc(&module, basek, k_ksk, lwe_ct_in.size()); + + ksk.encrypt_sk( + &module, + &sk_lwe_in, + &sk_lwe_out, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + let mut lwe_ct_out: LWECiphertext> = LWECiphertext::alloc(n_lwe_out, basek, k_lwe_ct); + + lwe_ct_out.keyswitch(&module, &lwe_ct_in, &ksk, scratch.borrow()); + + let mut lwe_pt_out: LWEPlaintext> = LWEPlaintext::alloc(basek, k_lwe_ct); + lwe_ct_out.decrypt(&mut lwe_pt_out, &sk_lwe_out); + + assert_eq!(lwe_pt_in.data.at(0, 0)[0], lwe_pt_out.data.at(0, 0)[0]); +} diff --git a/core/src/lwe/test_fft64/mod.rs b/core/src/lwe/test_fft64/mod.rs new file mode 100644 index 0000000..11eb2fc --- /dev/null +++ b/core/src/lwe/test_fft64/mod.rs @@ -0,0 +1 @@ +pub mod conversion; diff --git a/core/src/test_fft64/mod.rs b/core/src/noise.rs similarity index 96% rename from core/src/test_fft64/mod.rs rename to core/src/noise.rs index 73a58e9..cfc7698 100644 --- a/core/src/test_fft64/mod.rs +++ b/core/src/noise.rs @@ -1,12 +1,4 @@ -mod automorphism_key; -mod gglwe; -mod ggsw; -mod glwe; -mod glwe_fourier; -mod glwe_packing; -mod tensor_key; -mod trace; - +#[allow(dead_code)] pub(crate) fn var_noise_gglwe_product( n: f64, basek: usize, @@ -37,6 +29,7 @@ pub(crate) fn var_noise_gglwe_product( noise } +#[allow(dead_code)] pub(crate) fn log2_std_noise_gglwe_product( n: f64, basek: usize, @@ -65,6 +58,7 @@ pub(crate) fn log2_std_noise_gglwe_product( noise.log2().min(-1.0).max(-(a_logq as f64)) // max noise is [-2^{-1}, 2^{-1}] } +#[allow(dead_code)] pub(crate) fn noise_ggsw_product( n: f64, basek: usize, @@ -97,6 +91,7 @@ pub(crate) fn noise_ggsw_product( noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}] } +#[allow(dead_code)] pub(crate) fn noise_ggsw_keyswitch( n: f64, basek: usize, diff --git a/core/src/test_fft64/glwe.rs b/core/src/test_fft64/glwe.rs deleted file mode 100644 index 4859ad1..0000000 --- a/core/src/test_fft64/glwe.rs +++ /dev/null @@ -1,856 +0,0 @@ -use backend::{ - Decoding, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, ZnxViewMut, - ZnxZero, -}; -use itertools::izip; -use sampling::source::Source; - -use crate::{ - GGSWCiphertext, GLWECiphertext, GLWECiphertextFourier, GLWEPlaintext, GLWEPublicKey, GLWESecret, Infos, - automorphism::AutomorphismKey, - keyswitch_key::GLWESwitchingKey, - test_fft64::{log2_std_noise_gglwe_product, noise_ggsw_product}, -}; - -#[test] -fn encrypt_sk() { - let log_n: usize = 8; - (1..4).for_each(|rank| { - println!("test encrypt_sk rank: {}", rank); - test_encrypt_sk(log_n, 8, 54, 30, 3.2, rank); - }); -} - -#[test] -fn encrypt_zero_sk() { - let log_n: usize = 8; - (1..4).for_each(|rank| { - println!("test encrypt_zero_sk rank: {}", rank); - test_encrypt_zero_sk(log_n, 8, 64, 3.2, rank); - }); -} - -#[test] -fn encrypt_pk() { - let log_n: usize = 8; - (1..4).for_each(|rank| { - println!("test encrypt_pk rank: {}", rank); - test_encrypt_pk(log_n, 8, 64, 64, 3.2, rank) - }); -} - -#[test] -fn keyswitch() { - let log_n: usize = 8; - let basek: usize = 12; - let k_in: usize = 45; - let digits: usize = k_in.div_ceil(basek); - (1..4).for_each(|rank_in| { - (1..4).for_each(|rank_out| { - (1..digits + 1).for_each(|di| { - let k_ksk: usize = k_in + basek * di; - let k_out: usize = k_ksk; // better capture noise - println!( - "test keyswitch digits: {} rank_in: {} rank_out: {}", - di, rank_in, rank_out - ); - test_keyswitch(log_n, basek, k_out, k_in, k_ksk, di, rank_in, rank_out, 3.2); - }) - }); - }); -} - -#[test] -fn keyswitch_inplace() { - let log_n: usize = 8; - let basek: usize = 12; - let k_ct: usize = 45; - let digits: usize = k_ct.div_ceil(basek); - (1..4).for_each(|rank| { - (1..digits + 1).for_each(|di| { - let k_ksk: usize = k_ct + basek * di; - println!("test keyswitch_inplace digits: {} rank: {}", di, rank); - test_keyswitch_inplace(log_n, basek, k_ct, k_ksk, di, rank, 3.2); - }); - }); -} - -#[test] -fn external_product() { - let log_n: usize = 8; - let basek: usize = 12; - let k_in: usize = 45; - let digits: usize = k_in.div_ceil(basek); - (1..4).for_each(|rank| { - (1..digits + 1).for_each(|di| { - let k_ggsw: usize = k_in + basek * di; - let k_out: usize = k_ggsw; // Better capture noise - println!("test external_product digits: {} rank: {}", di, rank); - test_external_product(log_n, basek, k_out, k_in, k_ggsw, di, rank, 3.2); - }); - }); -} - -#[test] -fn external_product_inplace() { - let log_n: usize = 8; - let basek: usize = 12; - let k_ct: usize = 60; - let digits: usize = k_ct.div_ceil(basek); - (1..4).for_each(|rank| { - (1..digits + 1).for_each(|di| { - let k_ggsw: usize = k_ct + basek * di; - println!("test external_product digits: {} rank: {}", di, rank); - test_external_product_inplace(log_n, basek, k_ct, k_ggsw, di, rank, 3.2); - }); - }); -} - -#[test] -fn automorphism_inplace() { - let log_n: usize = 8; - let basek: usize = 12; - let k_ct: usize = 60; - let digits: usize = k_ct.div_ceil(basek); - (1..4).for_each(|rank| { - (1..digits + 1).for_each(|di| { - let k_ksk: usize = k_ct + basek * di; - println!("test automorphism_inplace digits: {} rank: {}", di, rank); - test_automorphism_inplace(log_n, basek, -5, k_ct, k_ksk, di, rank, 3.2); - }); - }); -} - -#[test] -fn automorphism() { - let log_n: usize = 8; - let basek: usize = 12; - let k_in: usize = 60; - let digits: usize = k_in.div_ceil(basek); - (1..4).for_each(|rank| { - (1..digits + 1).for_each(|di| { - let k_ksk: usize = k_in + basek * di; - let k_out: usize = k_ksk; // Better capture noise. - println!("test automorphism digits: {} rank: {}", di, rank); - test_automorphism(log_n, basek, -5, k_out, k_in, k_ksk, di, rank, 3.2); - }) - }); -} - -fn test_encrypt_sk(log_n: usize, basek: usize, k_ct: usize, k_pt: usize, sigma: f64, rank: usize) { - let module: Module = Module::::new(1 << log_n); - - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_pt); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k()) - | GLWECiphertext::decrypt_scratch_space(&module, basek, ct.k()), - ); - - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, 0.5, &mut source_xs); - - let mut data_want: Vec = vec![0i64; module.n()]; - - data_want - .iter_mut() - .for_each(|x| *x = source_xa.next_i64() & 0xFF); - - pt.data.encode_vec_i64(0, basek, k_pt, &data_want, 10); - - ct.encrypt_sk( - &module, - &pt, - &sk, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - pt.data.zero(); - - ct.decrypt(&module, &mut pt, &sk, scratch.borrow()); - - let mut data_have: Vec = vec![0i64; module.n()]; - - pt.data - .decode_vec_i64(0, basek, pt.size() * basek, &mut data_have); - - // TODO: properly assert the decryption noise through std(dec(ct) - pt) - let scale: f64 = (1 << (pt.size() * basek - k_pt)) as f64; - izip!(data_want.iter(), data_have.iter()).for_each(|(a, b)| { - let b_scaled = (*b as f64) / scale; - assert!( - (*a as f64 - b_scaled).abs() < 0.1, - "{} {}", - *a as f64, - b_scaled - ) - }); -} - -fn test_encrypt_zero_sk(log_n: usize, basek: usize, k_ct: usize, sigma: f64, rank: usize) { - let module: Module = Module::::new(1 << log_n); - - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([1u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, 0.5, &mut source_xs); - - let mut ct_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_ct, rank); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_ct) - | GLWECiphertextFourier::encrypt_sk_scratch_space(&module, basek, k_ct, rank), - ); - - ct_dft.encrypt_zero_sk( - &module, - &sk, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - ct_dft.decrypt(&module, &mut pt, &sk, scratch.borrow()); - - assert!((sigma - pt.data.std(0, basek) * (k_ct as f64).exp2()) <= 0.2); -} - -fn test_encrypt_pk(log_n: usize, basek: usize, k_ct: usize, k_pk: usize, sigma: f64, rank: usize) { - let module: Module = Module::::new(1 << log_n); - - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - let mut source_xu: Source = Source::new([0u8; 32]); - - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, 0.5, &mut source_xs); - - let mut pk: GLWEPublicKey, FFT64> = GLWEPublicKey::alloc(&module, basek, k_pk, rank); - pk.generate_from_sk(&module, &sk, &mut source_xa, &mut source_xe, sigma); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k()) - | GLWECiphertext::decrypt_scratch_space(&module, basek, ct.k()) - | GLWECiphertext::encrypt_pk_scratch_space(&module, basek, pk.k()), - ); - - let mut data_want: Vec = vec![0i64; module.n()]; - - data_want - .iter_mut() - .for_each(|x| *x = source_xa.next_i64() & 0); - - pt_want.data.encode_vec_i64(0, basek, k_ct, &data_want, 10); - - ct.encrypt_pk( - &module, - &pt_want, - &pk, - &mut source_xu, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - - ct.decrypt(&module, &mut pt_have, &sk, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_want.data, 0, &pt_have.data, 0); - - let noise_have: f64 = pt_want.data.std(0, basek).log2(); - let noise_want: f64 = ((((rank as f64) + 1.0) * module.n() as f64 * 0.5 * sigma * sigma).sqrt()).log2() - (k_ct as f64); - - assert!( - (noise_have - noise_want).abs() < 0.2, - "{} {}", - noise_have, - noise_want - ); -} - -fn test_keyswitch( - log_n: usize, - basek: usize, - k_out: usize, - k_in: usize, - k_ksk: usize, - digits: usize, - rank_in: usize, - rank_out: usize, - sigma: f64, -) { - let module: Module = Module::::new(1 << log_n); - - let rows: usize = k_in.div_ceil(basek * digits); - - let mut ksk: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank_in, rank_out); - let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_in, rank_in); - let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_out, rank_out); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_in); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - pt_want - .data - .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ksk.k(), rank_out) - | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_out.k()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_in.k()) - | GLWECiphertext::keyswitch_scratch_space( - &module, - basek, - ct_out.k(), - ct_in.k(), - ksk.k(), - digits, - rank_in, - rank_out, - ), - ); - - let mut sk_in: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_in); - sk_in.fill_ternary_prob(&module, 0.5, &mut source_xs); - - let mut sk_out: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_out); - sk_out.fill_ternary_prob(&module, 0.5, &mut source_xs); - - ksk.generate_from_sk( - &module, - &sk_in, - &sk_out, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_in.encrypt_sk( - &module, - &pt_want, - &sk_in, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_out.keyswitch(&module, &ct_in, &ksk, scratch.borrow()); - ct_out.decrypt(&module, &mut pt_have, &sk_out, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); - - let noise_have: f64 = pt_have.data.std(0, basek).log2(); - let noise_want: f64 = log2_std_noise_gglwe_product( - module.n() as f64, - basek * digits, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - rank_in as f64, - k_in, - k_ksk, - ); - - println!("{} vs. {}", noise_have, noise_want); - - assert!( - (noise_have - noise_want).abs() <= 0.5, - "{} {}", - noise_have, - noise_want - ); -} - -fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ct: usize, k_ksk: usize, digits: usize, rank: usize, sigma: f64) { - let module: Module = Module::::new(1 << log_n); - - let rows: usize = k_ct.div_ceil(basek * digits); - - let mut ct_grlwe: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank, rank); - let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - pt_want - .data - .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ct_grlwe.k(), rank) - | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_glwe.k()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_glwe.k()) - | GLWECiphertext::keyswitch_inplace_scratch_space(&module, basek, ct_glwe.k(), ct_grlwe.k(), digits, rank), - ); - - let mut sk0: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk0.fill_ternary_prob(&module, 0.5, &mut source_xs); - - let mut sk1: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk1.fill_ternary_prob(&module, 0.5, &mut source_xs); - - ct_grlwe.generate_from_sk( - &module, - &sk0, - &sk1, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_glwe.encrypt_sk( - &module, - &pt_want, - &sk0, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_glwe.keyswitch_inplace(&module, &ct_grlwe, scratch.borrow()); - - ct_glwe.decrypt(&module, &mut pt_have, &sk1, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); - - let noise_have: f64 = pt_have.data.std(0, basek).log2(); - let noise_want: f64 = log2_std_noise_gglwe_product( - module.n() as f64, - basek * digits, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - rank as f64, - k_ct, - k_ksk, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.5, - "{} {}", - noise_have, - noise_want - ); -} - -fn test_automorphism( - log_n: usize, - basek: usize, - p: i64, - k_out: usize, - k_in: usize, - k_ksk: usize, - digits: usize, - rank: usize, - sigma: f64, -) { - let module: Module = Module::::new(1 << log_n); - - let rows: usize = k_in.div_ceil(basek * digits); - - let mut autokey: AutomorphismKey, FFT64> = AutomorphismKey::alloc(&module, basek, k_ksk, rows, digits, rank); - let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_in, rank); - let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_out, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_in); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - pt_want - .data - .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); - - let mut scratch: ScratchOwned = ScratchOwned::new( - AutomorphismKey::generate_from_sk_scratch_space(&module, basek, autokey.k(), rank) - | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_out.k()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_in.k()) - | GLWECiphertext::automorphism_scratch_space( - &module, - basek, - ct_out.k(), - ct_in.k(), - autokey.k(), - digits, - rank, - ), - ); - - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, 0.5, &mut source_xs); - - autokey.generate_from_sk( - &module, - p, - &sk, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_in.encrypt_sk( - &module, - &pt_want, - &sk, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_out.automorphism(&module, &ct_in, &autokey, scratch.borrow()); - ct_out.decrypt(&module, &mut pt_have, &sk, scratch.borrow()); - module.vec_znx_automorphism_inplace(p, &mut pt_want.data, 0); - module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); - module.vec_znx_normalize_inplace(basek, &mut pt_have.data, 0, scratch.borrow()); - - let noise_have: f64 = pt_have.data.std(0, basek).log2(); - - println!("{}", noise_have); - - let noise_want: f64 = log2_std_noise_gglwe_product( - module.n() as f64, - basek * digits, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - rank as f64, - k_in, - k_ksk, - ); - - assert!( - noise_have <= noise_want + 1.0, - "{} {}", - noise_have, - noise_want - ); -} - -fn test_automorphism_inplace( - log_n: usize, - basek: usize, - p: i64, - k_ct: usize, - k_ksk: usize, - digits: usize, - rank: usize, - sigma: f64, -) { - let module: Module = Module::::new(1 << log_n); - - let rows: usize = k_ct.div_ceil(basek * digits); - - let mut autokey: AutomorphismKey, FFT64> = AutomorphismKey::alloc(&module, basek, k_ksk, rows, digits, rank); - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - pt_want - .data - .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); - - let mut scratch: ScratchOwned = ScratchOwned::new( - AutomorphismKey::generate_from_sk_scratch_space(&module, basek, autokey.k(), rank) - | GLWECiphertext::decrypt_scratch_space(&module, basek, ct.k()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k()) - | GLWECiphertext::automorphism_inplace_scratch_space(&module, basek, ct.k(), autokey.k(), digits, rank), - ); - - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, 0.5, &mut source_xs); - - autokey.generate_from_sk( - &module, - p, - &sk, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct.encrypt_sk( - &module, - &pt_want, - &sk, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct.automorphism_inplace(&module, &autokey, scratch.borrow()); - ct.decrypt(&module, &mut pt_have, &sk, scratch.borrow()); - module.vec_znx_automorphism_inplace(p, &mut pt_want.data, 0); - module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); - module.vec_znx_normalize_inplace(basek, &mut pt_have.data, 0, scratch.borrow()); - - let noise_have: f64 = pt_have.data.std(0, basek).log2(); - let noise_want: f64 = log2_std_noise_gglwe_product( - module.n() as f64, - basek * digits, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - rank as f64, - k_ct, - k_ksk, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.5, - "{} {}", - noise_have, - noise_want - ); -} - -fn test_external_product( - log_n: usize, - basek: usize, - k_out: usize, - k_in: usize, - k_ggsw: usize, - digits: usize, - rank: usize, - sigma: f64, -) { - let module: Module = Module::::new(1 << log_n); - - let rows: usize = k_in.div_ceil(basek * digits); - - let mut ct_ggsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank); - let mut ct_glwe_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_in, rank); - let mut ct_glwe_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_out, rank); - let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_in); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); - - pt_want.data.at_mut(0, 0)[1] = 1; - - let k: usize = 1; - - pt_rgsw.raw_mut()[k] = 1; // X^{k} - - let mut scratch: ScratchOwned = ScratchOwned::new( - GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, ct_ggsw.k(), rank) - | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_glwe_out.k()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_glwe_in.k()) - | GLWECiphertext::external_product_scratch_space( - &module, - basek, - ct_glwe_out.k(), - ct_glwe_in.k(), - ct_ggsw.k(), - digits, - rank, - ), - ); - - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, 0.5, &mut source_xs); - - ct_ggsw.encrypt_sk( - &module, - &pt_rgsw, - &sk, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_glwe_in.encrypt_sk( - &module, - &pt_want, - &sk, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_glwe_out.external_product(&module, &ct_glwe_in, &ct_ggsw, scratch.borrow()); - - ct_glwe_out.decrypt(&module, &mut pt_have, &sk, scratch.borrow()); - - module.vec_znx_rotate_inplace(k as i64, &mut pt_want.data, 0); - - module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); - - let noise_have: f64 = pt_have.data.std(0, basek).log2(); - - let var_gct_err_lhs: f64 = sigma * sigma; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} - let var_a0_err: f64 = sigma * sigma; - let var_a1_err: f64 = 1f64 / 12f64; - - let noise_want: f64 = noise_ggsw_product( - module.n() as f64, - basek * digits, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - rank as f64, - k_in, - k_ggsw, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.5, - "{} {}", - noise_have, - noise_want - ); -} - -fn test_external_product_inplace(log_n: usize, basek: usize, k_ct: usize, k_ggsw: usize, digits: usize, rank: usize, sigma: f64) { - let module: Module = Module::::new(1 << log_n); - let rows: usize = k_ct.div_ceil(basek * digits); - - let mut ct_ggsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank); - let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); - let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); - - pt_want.data.at_mut(0, 0)[1] = 1; - - let k: usize = 1; - - pt_rgsw.raw_mut()[k] = 1; // X^{k} - - let mut scratch: ScratchOwned = ScratchOwned::new( - GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, ct_ggsw.k(), rank) - | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_glwe.k()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_glwe.k()) - | GLWECiphertext::external_product_inplace_scratch_space(&module, basek, ct_glwe.k(), ct_ggsw.k(), digits, rank), - ); - - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, 0.5, &mut source_xs); - - ct_ggsw.encrypt_sk( - &module, - &pt_rgsw, - &sk, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_glwe.encrypt_sk( - &module, - &pt_want, - &sk, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_glwe.external_product_inplace(&module, &ct_ggsw, scratch.borrow()); - - ct_glwe.decrypt(&module, &mut pt_have, &sk, scratch.borrow()); - - module.vec_znx_rotate_inplace(k as i64, &mut pt_want.data, 0); - - module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); - - let noise_have: f64 = pt_have.data.std(0, basek).log2(); - - let var_gct_err_lhs: f64 = sigma * sigma; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} - let var_a0_err: f64 = sigma * sigma; - let var_a1_err: f64 = 1f64 / 12f64; - - let noise_want: f64 = noise_ggsw_product( - module.n() as f64, - basek * digits, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - rank as f64, - k_ct, - k_ggsw, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.5, - "{} {}", - noise_have, - noise_want - ); -} diff --git a/core/src/test_fft64/glwe_fourier.rs b/core/src/test_fft64/glwe_fourier.rs deleted file mode 100644 index 48a0f0d..0000000 --- a/core/src/test_fft64/glwe_fourier.rs +++ /dev/null @@ -1,478 +0,0 @@ -use crate::{ - GGSWCiphertext, GLWECiphertext, GLWECiphertextFourier, GLWEOps, GLWEPlaintext, GLWESecret, GLWESwitchingKey, Infos, - test_fft64::{log2_std_noise_gglwe_product, noise_ggsw_product}, -}; -use backend::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, ZnxViewMut}; -use sampling::source::Source; - -#[test] -fn keyswitch() { - let log_n: usize = 8; - let basek: usize = 12; - let k_in: usize = 45; - let digits: usize = k_in.div_ceil(basek); - (1..4).for_each(|rank_in| { - (1..4).for_each(|rank_out| { - (1..digits + 1).for_each(|di| { - let k_ksk: usize = k_in + basek * di; - println!( - "test keyswitch digits: {} rank_in: {} rank_out: {}", - di, rank_in, rank_out - ); - let k_out: usize = k_ksk; // Better capture noise. - test_keyswitch(log_n, basek, k_in, k_out, k_ksk, di, rank_in, rank_out, 3.2); - }) - }); - }); -} - -#[test] -fn keyswitch_inplace() { - let log_n: usize = 8; - let basek: usize = 12; - let k_ct: usize = 45; - let digits: usize = k_ct.div_ceil(basek); - (1..4).for_each(|rank| { - (1..digits + 1).for_each(|di| { - let k_ksk: usize = k_ct + basek * di; - println!("test keyswitch_inplace digits: {} rank: {}", di, rank); - test_keyswitch_inplace(log_n, basek, k_ct, k_ksk, di, rank, 3.2); - }); - }); -} - -#[test] -fn external_product() { - let log_n: usize = 8; - let basek: usize = 12; - let k_in: usize = 45; - let digits: usize = k_in.div_ceil(basek); - (1..4).for_each(|rank| { - (1..digits + 1).for_each(|di| { - let k_ggsw: usize = k_in + basek * di; - println!("test external_product digits: {} rank: {}", di, rank); - let k_out: usize = k_ggsw; // Better capture noise. - test_external_product(log_n, basek, k_out, k_in, k_ggsw, di, rank, 3.2); - }); - }); -} - -#[test] -fn external_product_inplace() { - let log_n: usize = 8; - let basek: usize = 12; - let k_ct: usize = 60; - let digits: usize = k_ct.div_ceil(basek); - (1..4).for_each(|rank| { - (1..digits + 1).for_each(|di| { - let k_ggsw: usize = k_ct + basek * di; - println!("test external_product digits: {} rank: {}", di, rank); - test_external_product_inplace(log_n, basek, k_ct, k_ggsw, di, rank, 3.2); - }); - }); -} - -fn test_keyswitch( - log_n: usize, - basek: usize, - k_in: usize, - k_out: usize, - k_ksk: usize, - digits: usize, - rank_in: usize, - rank_out: usize, - sigma: f64, -) { - let module: Module = Module::::new(1 << log_n); - - let rows: usize = k_in.div_ceil(basek * digits); - - let mut ksk: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank_in, rank_out); - let mut ct_glwe_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_in, rank_in); - let mut ct_glwe_dft_in: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_in, rank_in); - let mut ct_glwe_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_out, rank_out); - let mut ct_glwe_dft_out: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::alloc(&module, basek, k_out, rank_out); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_in); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank_out) - | GLWECiphertext::decrypt_scratch_space(&module, basek, k_out) - | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, k_in) - | GLWECiphertextFourier::keyswitch_scratch_space( - &module, - basek, - ct_glwe_out.k(), - ksk.k(), - ct_glwe_in.k(), - digits, - rank_in, - rank_out, - ), - ); - - let mut sk_in: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_in); - sk_in.fill_ternary_prob(&module, 0.5, &mut source_xs); - - let mut sk_out: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_out); - sk_out.fill_ternary_prob(&module, 0.5, &mut source_xs); - - ksk.generate_from_sk( - &module, - &sk_in, - &sk_out, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_glwe_in.encrypt_sk( - &module, - &pt_want, - &sk_in, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_glwe_in.dft(&module, &mut ct_glwe_dft_in); - ct_glwe_dft_out.keyswitch(&module, &ct_glwe_dft_in, &ksk, scratch.borrow()); - ct_glwe_dft_out.idft(&module, &mut ct_glwe_out, scratch.borrow()); - - ct_glwe_out.decrypt(&module, &mut pt_have, &sk_out, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); - - let noise_have: f64 = pt_have.data.std(0, basek).log2(); - let noise_want: f64 = log2_std_noise_gglwe_product( - module.n() as f64, - basek * digits, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - rank_in as f64, - k_in, - k_ksk, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.5, - "{} {}", - noise_have, - noise_want - ); -} - -fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ct: usize, k_ksk: usize, digits: usize, rank: usize, sigma: f64) { - let module: Module = Module::::new(1 << log_n); - - let rows: usize = k_ct.div_ceil(basek * digits); - - let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank, rank); - let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); - let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_ct, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); - - let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ksk.k(), rank) - | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_glwe.k()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_glwe.k()) - | GLWECiphertextFourier::keyswitch_inplace_scratch_space(&module, basek, ct_rlwe_dft.k(), ksk.k(), digits, rank), - ); - - let mut sk_in: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk_in.fill_ternary_prob(&module, 0.5, &mut source_xs); - - let mut sk_out: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk_out.fill_ternary_prob(&module, 0.5, &mut source_xs); - - ksk.generate_from_sk( - &module, - &sk_in, - &sk_out, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_glwe.encrypt_sk( - &module, - &pt_want, - &sk_in, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_glwe.dft(&module, &mut ct_rlwe_dft); - ct_rlwe_dft.keyswitch_inplace(&module, &ksk, scratch.borrow()); - ct_rlwe_dft.idft(&module, &mut ct_glwe, scratch.borrow()); - - ct_glwe.decrypt(&module, &mut pt_have, &sk_out, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); - - let noise_have: f64 = pt_have.data.std(0, basek).log2(); - let noise_want: f64 = log2_std_noise_gglwe_product( - module.n() as f64, - basek * digits, - 0.5, - 0.5, - 0f64, - sigma * sigma, - 0f64, - rank as f64, - k_ct, - k_ksk, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.5, - "{} {}", - noise_have, - noise_want - ); -} - -fn test_external_product( - log_n: usize, - basek: usize, - k_out: usize, - k_in: usize, - k_ggsw: usize, - digits: usize, - rank: usize, - sigma: f64, -) { - let module: Module = Module::::new(1 << log_n); - - let rows: usize = k_in.div_ceil(basek * digits); - - let mut ct_ggsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank); - let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_in, rank); - let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_out, rank); - let mut ct_in_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_in, rank); - let mut ct_out_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_out, rank); - let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_in); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); - - pt_want.data.at_mut(0, 0)[1] = 1; - - let k: i64 = 1; - - pt_rgsw.raw_mut()[0] = 1; // X^{0} - module.vec_znx_rotate_inplace(k, &mut pt_rgsw, 0); // X^{k} - - let mut scratch: ScratchOwned = ScratchOwned::new( - GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, ct_ggsw.k(), rank) - | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_out.k()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_in.k()) - | GLWECiphertextFourier::external_product_scratch_space( - &module, - basek, - ct_out.k(), - ct_in.k(), - ct_ggsw.k(), - digits, - rank, - ), - ); - - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, 0.5, &mut source_xs); - - ct_ggsw.encrypt_sk( - &module, - &pt_rgsw, - &sk, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_in.encrypt_sk( - &module, - &pt_want, - &sk, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct_in.dft(&module, &mut ct_in_dft); - ct_out_dft.external_product(&module, &ct_in_dft, &ct_ggsw, scratch.borrow()); - ct_out_dft.idft(&module, &mut ct_out, scratch.borrow()); - - ct_out.decrypt(&module, &mut pt_have, &sk, scratch.borrow()); - - pt_want.rotate_inplace(&module, k); - pt_have.sub_inplace_ab(&module, &pt_want); - - let noise_have: f64 = pt_have.data.std(0, basek).log2(); - - let var_gct_err_lhs: f64 = sigma * sigma; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} - let var_a0_err: f64 = sigma * sigma; - let var_a1_err: f64 = 1f64 / 12f64; - - let noise_want: f64 = noise_ggsw_product( - module.n() as f64, - basek * digits, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - rank as f64, - k_in, - k_ggsw, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.5, - "{} {}", - noise_have, - noise_want - ); -} - -fn test_external_product_inplace(log_n: usize, basek: usize, k_ct: usize, k_ggsw: usize, digits: usize, rank: usize, sigma: f64) { - let module: Module = Module::::new(1 << log_n); - let rows: usize = k_ct.div_ceil(basek * digits); - - let mut ct_ggsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank); - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); - let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_ct, rank); - let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - pt_want - .data - .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); - - pt_want.data.at_mut(0, 0)[1] = 1; - - let k: i64 = 1; - - pt_rgsw.raw_mut()[0] = 1; // X^{0} - module.vec_znx_rotate_inplace(k, &mut pt_rgsw, 0); // X^{k} - - let mut scratch: ScratchOwned = ScratchOwned::new( - GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, ct_ggsw.k(), rank) - | GLWECiphertext::decrypt_scratch_space(&module, basek, ct.k()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k()) - | GLWECiphertextFourier::external_product_inplace_scratch_space(&module, basek, ct.k(), ct_ggsw.k(), digits, rank), - ); - - let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); - sk.fill_ternary_prob(&module, 0.5, &mut source_xs); - - ct_ggsw.encrypt_sk( - &module, - &pt_rgsw, - &sk, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct.encrypt_sk( - &module, - &pt_want, - &sk, - &mut source_xa, - &mut source_xe, - sigma, - scratch.borrow(), - ); - - ct.dft(&module, &mut ct_rlwe_dft); - ct_rlwe_dft.external_product_inplace(&module, &ct_ggsw, scratch.borrow()); - ct_rlwe_dft.idft(&module, &mut ct, scratch.borrow()); - - ct.decrypt(&module, &mut pt_have, &sk, scratch.borrow()); - - pt_want.rotate_inplace(&module, k); - pt_have.sub_inplace_ab(&module, &pt_want); - - let noise_have: f64 = pt_have.data.std(0, basek).log2(); - - let var_gct_err_lhs: f64 = sigma * sigma; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} - let var_a0_err: f64 = sigma * sigma; - let var_a1_err: f64 = 1f64 / 12f64; - - let noise_want: f64 = noise_ggsw_product( - module.n() as f64, - basek * digits, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - rank as f64, - k_ct, - k_ggsw, - ); - - assert!( - (noise_have - noise_want).abs() <= 0.5, - "{} {}", - noise_have, - noise_want - ); - - println!("{} {}", noise_have, noise_want); -}