diff --git a/backend/examples/rlwe_encrypt.rs b/backend/examples/rlwe_encrypt.rs index 84f85ad..c56496c 100644 --- a/backend/examples/rlwe_encrypt.rs +++ b/backend/examples/rlwe_encrypt.rs @@ -40,7 +40,7 @@ fn main() { let mut buf_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_size); - module.vec_znx_dft(&mut buf_dft, 0, &ct, 1); + module.vec_znx_dft(1, 0, &mut buf_dft, 0, &ct, 1); // Applies DFT(ct[1]) * DFT(s) module.svp_apply_inplace( @@ -102,7 +102,7 @@ fn main() { // Decryption // DFT(ct[1] * s) - module.vec_znx_dft(&mut buf_dft, 0, &ct, 1); + module.vec_znx_dft(1, 0, &mut buf_dft, 0, &ct, 1); module.svp_apply_inplace( &mut buf_dft, 0, // Selects the first column of res. diff --git a/backend/src/mat_znx_dft_ops.rs b/backend/src/mat_znx_dft_ops.rs index 6af4455..d831f73 100644 --- a/backend/src/mat_znx_dft_ops.rs +++ b/backend/src/mat_znx_dft_ops.rs @@ -393,7 +393,7 @@ mod tests { let mut source: Source = Source::new([0u8; 32]); (0..mat_cols_out).for_each(|col_out| { a.fill_uniform(basek, col_out, mat_size, &mut source); - module.vec_znx_dft(&mut a_dft, col_out, &a, col_out); + module.vec_znx_dft(1, 0, &mut a_dft, col_out, &a, col_out); }); module.vmp_prepare_row(&mut mat, row_i, col_in, &a_dft); module.vmp_extract_row(&mut b_dft, &mat, row_i, col_in); @@ -453,7 +453,7 @@ mod tests { (0..mat_cols_out).for_each(|col_out_i| { let idx = 1 + col_in_i * mat_cols_out + col_out_i; tmp.at_mut(col_out_i, row_i)[idx] = 1 as i64; // X^{idx} - module.vec_znx_dft(&mut c_dft, col_out_i, &tmp, col_out_i); + module.vec_znx_dft(1, 0, &mut c_dft, col_out_i, &tmp, col_out_i); tmp.at_mut(col_out_i, row_i)[idx] = 0 as i64; }); module.vmp_prepare_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft); @@ -462,7 +462,7 @@ mod tests { let mut a_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(a_cols, a_size); (0..a_cols).for_each(|i| { - module.vec_znx_dft(&mut a_dft, i, &a, i); + module.vec_znx_dft(1, 0, &mut a_dft, i, &a, i); }); module.vmp_apply(&mut c_dft, &a_dft, &mat_znx_dft, scratch.borrow()); @@ -489,14 +489,15 @@ mod tests { #[test] fn vmp_apply_add() { - let log_n: i32 = 5; + let log_n: i32 = 4; let n: usize = 1 << log_n; let module: Module = Module::::new(n); let basek: usize = 8; - let a_size: usize = 6; - let mat_size: usize = 6; + let a_size: usize = 5; + let mat_size: usize = 5; let res_size: usize = a_size; + let mut source: Source = Source::new([0u8; 32]); [1, 2].iter().for_each(|in_cols| { [1, 2].iter().for_each(|out_cols| { @@ -521,10 +522,8 @@ mod tests { let mut a: VecZnx> = module.new_vec_znx(a_cols, a_size); - (0..a_cols).for_each(|i| { - (0..a_size).for_each(|j| { - a.at_mut(i, j)[i + 1] = 1 + j as i64; - }); + (0..a_cols).for_each(|col_i| { + a.fill_uniform(basek, col_i, a.size(), &mut source); }); let mut mat_znx_dft: MatZnxDft, FFT64> = @@ -539,9 +538,9 @@ mod tests { (0..a.size()).for_each(|row_i| { (0..mat_cols_in).for_each(|col_in_i| { (0..mat_cols_out).for_each(|col_out_i| { - let idx = 1 + col_in_i * mat_cols_out + col_out_i; + let idx: usize = 1 + col_in_i * mat_cols_out + col_out_i; tmp.at_mut(col_out_i, row_i)[idx] = 1 as i64; // X^{idx} - module.vec_znx_dft(&mut c_dft, col_out_i, &tmp, col_out_i); + module.vec_znx_dft(1, 0, &mut c_dft, col_out_i, &tmp, col_out_i); tmp.at_mut(col_out_i, row_i)[idx] = 0 as i64; }); module.vmp_prepare_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft); @@ -550,12 +549,12 @@ mod tests { let mut a_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(a_cols, a_size); (0..a_cols).for_each(|i| { - module.vec_znx_dft(&mut a_dft, i, &a, i); + module.vec_znx_dft(1, 0, &mut a_dft, i, &a, i); }); c_dft.zero(); (0..c_dft.cols()).for_each(|i| { - module.vec_znx_dft(&mut c_dft, i, &a, 0); + module.vec_znx_dft(1, 0, &mut c_dft, i, &a, 0); }); module.vmp_apply_add(&mut c_dft, &a_dft, &mat_znx_dft, shift, scratch.borrow()); @@ -582,16 +581,118 @@ mod tests { ); (0..res_cols).for_each(|i| { module.vec_znx_add_inplace(&mut res_want, i, &a, 0); + module.vec_znx_normalize_inplace(basek, &mut res_want, i, scratch.borrow()); }); - let mut res_have_vi64: Vec = vec![i64::default(); n]; - let mut res_want_vi64: Vec = vec![i64::default(); n]; + assert_eq!(res_want, res_have); + }); + }); + }); + } - (0..mat_cols_out).for_each(|col_i| { - res_want.decode_vec_i64(col_i, basek, basek * a_size, &mut res_want_vi64); - res_have.decode_vec_i64(col_i, basek, basek * a_size, &mut res_have_vi64); - assert_eq!(res_have_vi64, res_want_vi64); + #[test] + fn vmp_apply_digits() { + let log_n: i32 = 4; + let n: usize = 1 << log_n; + + let module: Module = Module::::new(n); + let basek: usize = 8; + let a_size: usize = 6; + let mat_size: usize = 6; + let res_size: usize = a_size; + + [1, 2].iter().for_each(|in_cols| { + [1, 2].iter().for_each(|out_cols| { + [1, 3, 6].iter().for_each(|digits| { + let mut source: Source = Source::new([0u8; 32]); + + let a_cols: usize = *in_cols; + let res_cols: usize = *out_cols; + + let mat_rows: usize = a_size; + let mat_cols_in: usize = a_cols; + let mat_cols_out: usize = res_cols; + + let mut scratch: ScratchOwned = ScratchOwned::new( + module.vmp_apply_tmp_bytes( + res_size, + a_size, + mat_rows, + mat_cols_in, + mat_cols_out, + mat_size, + ) | module.vec_znx_big_normalize_tmp_bytes(), + ); + + let mut a: VecZnx> = module.new_vec_znx(a_cols, a_size); + + (0..a_cols).for_each(|col_i| { + a.fill_uniform(basek, col_i, a.size(), &mut source); }); + + let mut mat_znx_dft: MatZnxDft, FFT64> = + module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size); + + let mut c_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size); + let mut c_big: VecZnxBig, FFT64> = module.new_vec_znx_big(mat_cols_out, mat_size); + + let mut tmp: VecZnx> = module.new_vec_znx(mat_cols_out, mat_size); + + let rows: usize = a.size() / digits; + + let shift: usize = 1; + + // Construts a [VecZnxMatDft] that performs cyclic rotations on each submatrix. + (0..rows).for_each(|row_i| { + (0..mat_cols_in).for_each(|col_in_i| { + (0..mat_cols_out).for_each(|col_out_i| { + let idx: usize = shift + col_in_i * mat_cols_out + col_out_i; + let limb: usize = (digits - 1) + row_i * digits; + tmp.at_mut(col_out_i, limb)[idx] = 1 as i64; // X^{idx} + module.vec_znx_dft(1, 0, &mut c_dft, col_out_i, &tmp, col_out_i); + tmp.at_mut(col_out_i, limb)[idx] = 0 as i64; + }); + module.vmp_prepare_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft); + }); + }); + + let mut a_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(a_cols, (a_size + digits - 1) / digits); + + (0..*digits).for_each(|di| { + (0..a_cols).for_each(|col_i| { + module.vec_znx_dft(digits - 1 - di, *digits, &mut a_dft, col_i, &a, col_i); + }); + + if di == 0 { + module.vmp_apply(&mut c_dft, &a_dft, &mat_znx_dft, scratch.borrow()); + } else { + module.vmp_apply_add(&mut c_dft, &a_dft, &mat_znx_dft, di, scratch.borrow()); + } + }); + + let mut res_have: VecZnx> = module.new_vec_znx(res_cols, mat_size); + (0..mat_cols_out).for_each(|i| { + module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i); + module.vec_znx_big_normalize(basek, &mut res_have, i, &c_big, i, scratch.borrow()); + }); + + let mut res_want: VecZnx> = module.new_vec_znx(res_cols, mat_size); + let mut tmp: VecZnx> = module.new_vec_znx(res_cols, mat_size); + (0..res_cols).for_each(|col_i| { + (0..a_cols).for_each(|j| { + module.vec_znx_rotate( + (col_i + j * mat_cols_out + shift) as i64, + &mut tmp, + 0, + &a, + j, + ); + module.vec_znx_add_inplace(&mut res_want, col_i, &tmp, 0); + }); + module.vec_znx_normalize_inplace(basek, &mut res_want, col_i, scratch.borrow()); + }); + + assert_eq!(res_have, res_want) }); }); }); diff --git a/backend/src/vec_znx.rs b/backend/src/vec_znx.rs index cbd1d8c..84b9a84 100644 --- a/backend/src/vec_znx.rs +++ b/backend/src/vec_znx.rs @@ -23,6 +23,7 @@ use std::{cmp::min, fmt}; /// Given 3 polynomials (a, b, c) of Zn\[X\], each with 4 columns, then the memory /// layout is: `[a0, b0, c0, a1, b1, c1, a2, b2, c2, a3, b3, c3]`, where ai, bi, ci /// are small polynomials of Zn\[X\]. +#[derive(PartialEq, Eq)] pub struct VecZnx { pub data: D, pub n: usize, @@ -30,6 +31,15 @@ pub struct VecZnx { pub size: usize, } +impl fmt::Debug for VecZnx +where + D: AsRef<[u8]>, +{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self) + } +} + impl ZnxInfos for VecZnx { fn cols(&self) -> usize { self.cols diff --git a/backend/src/vec_znx_dft.rs b/backend/src/vec_znx_dft.rs index 516228a..82b2cf4 100644 --- a/backend/src/vec_znx_dft.rs +++ b/backend/src/vec_znx_dft.rs @@ -3,7 +3,7 @@ use std::marker::PhantomData; use crate::ffi::vec_znx_dft; use crate::znx_base::ZnxInfos; use crate::{ - Backend, DataView, DataViewMut, FFT64, Module, VecZnxBig, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero, alloc_aligned, + Backend, DataView, DataViewMut, FFT64, Module, VecZnxBig, ZnxSliceSize, ZnxView, alloc_aligned, }; use std::fmt; @@ -91,39 +91,6 @@ impl>, B: Backend> VecZnxDft { } } -impl + AsRef<[u8]>> VecZnxDft -where - VecZnxDft: VecZnxDftToMut, -{ - /// Extracts the a_col-th column of 'a' and stores it on the self_col-th column [Self]. - pub fn extract_column>(&mut self, self_col: usize, a: &VecZnxDft, a_col: usize) - where - VecZnxDft: VecZnxDftToRef, - { - #[cfg(debug_assertions)] - { - assert!(self_col < self.cols()); - assert!(a_col < a.cols()); - } - - let min_size: usize = self.size.min(a.size()); - let max_size: usize = self.size; - - let mut self_mut: VecZnxDft<&mut [u8], FFT64> = self.to_mut(); - let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); - - (0..min_size).for_each(|i: usize| { - self_mut - .at_mut(self_col, i) - .copy_from_slice(a_ref.at(a_col, i)); - }); - - (min_size..max_size).for_each(|i| { - self_mut.zero_at(self_col, i); - }); - } -} - pub type VecZnxDftOwned = VecZnxDft, B>; impl VecZnxDft { diff --git a/backend/src/vec_znx_dft_ops.rs b/backend/src/vec_znx_dft_ops.rs index e4d6c33..6bfa9fb 100644 --- a/backend/src/vec_znx_dft_ops.rs +++ b/backend/src/vec_znx_dft_ops.rs @@ -53,7 +53,7 @@ pub trait VecZnxDftOps { R: VecZnxDftToMut, A: VecZnxDftToRef; - fn vec_znx_dft_copy(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_dft_copy(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxDftToMut, A: VecZnxDftToRef; @@ -74,9 +74,9 @@ pub trait VecZnxDftOps { R: VecZnxBigToMut, A: VecZnxDftToRef; - fn vec_znx_dft(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_dft(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize) where - R: VecZnxDftToMut, + R: VecZnxDftToMut, A: VecZnxToRef; } @@ -150,7 +150,7 @@ impl VecZnxDftOps for Module { } } - fn vec_znx_dft_copy(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_dft_copy(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxDftToMut, A: VecZnxDftToRef, @@ -158,14 +158,18 @@ impl VecZnxDftOps for Module { let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); - let min_size: usize = min(res_mut.size(), a_ref.size()); + let steps: usize = (a_ref.size() + step - 1) / step; + let min_steps: usize = min(res_mut.size(), steps); - (0..min_size).for_each(|j| { - res_mut + (0..min_steps).for_each(|j| { + let limb: usize = offset + j * step; + if limb < a_ref.size(){ + res_mut .at_mut(res_col, j) - .copy_from_slice(a_ref.at(a_col, j)); + .copy_from_slice(a_ref.at(a_col, limb)); + } }); - (min_size..res_mut.size()).for_each(|j| { + (min_steps..res_mut.size()).for_each(|j| { res_mut.zero_at(res_col, j); }) } @@ -224,32 +228,30 @@ impl VecZnxDftOps for Module { unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(self.ptr) as usize } } - /// b <- DFT(a) - /// - /// # Panics - /// If b.cols < a_col - fn vec_znx_dft(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_dft(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxDftToMut, A: VecZnxToRef, { let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); let a_ref: crate::VecZnx<&[u8]> = a.to_ref(); - - let min_size: usize = min(res_mut.size(), a_ref.size()); - + let steps: usize = (a_ref.size() + step - 1) / step; + let min_steps: usize = min(res_mut.size(), steps); unsafe { - (0..min_size).for_each(|j| { - vec_znx_dft::vec_znx_dft( - self.ptr, - res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, - 1 as u64, - a_ref.at_ptr(a_col, j), - 1 as u64, - a_ref.sl() as u64, - ) + (0..min_steps).for_each(|j| { + let limb: usize = offset + j * step; + if limb < a_ref.size() { + vec_znx_dft::vec_znx_dft( + self.ptr, + res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, + 1 as u64, + a_ref.at_ptr(a_col, limb), + 1 as u64, + a_ref.sl() as u64, + ) + } }); - (min_size..res_mut.size()).for_each(|j| { + (min_steps..res_mut.size()).for_each(|j| { res_mut.zero_at(res_col, j); }); } diff --git a/core/benches/external_product_glwe_fft64.rs b/core/benches/external_product_glwe_fft64.rs index 2a51387..a1f408f 100644 --- a/core/benches/external_product_glwe_fft64.rs +++ b/core/benches/external_product_glwe_fft64.rs @@ -24,11 +24,12 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) { let k_ct_out: usize = p.k_ct_out; let k_ggsw: usize = p.k_ggsw; let rank: usize = p.rank; + let digits: usize = 1; let rows: usize = (p.k_ct_in + p.basek - 1) / p.basek; let sigma: f64 = 3.2; - let mut ct_ggsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, rank); + let mut ct_ggsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank); let mut ct_glwe_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct_in, rank); let mut ct_glwe_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct_out, rank); let pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); @@ -42,6 +43,7 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) { ct_glwe_out.k(), ct_glwe_in.k(), ct_ggsw.k(), + digits, rank, ), ); @@ -118,18 +120,19 @@ fn bench_external_product_glwe_inplace_fft64(c: &mut Criterion) { let k_glwe: usize = p.k_ct; let k_ggsw: usize = p.k_ggsw; let rank: usize = p.rank; + let digits: usize = 1; let rows: usize = (p.k_ct + p.basek - 1) / p.basek; let sigma: f64 = 3.2; - let mut ct_ggsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, rank); + let mut ct_ggsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank); let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_glwe, rank); let pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); let mut scratch = ScratchOwned::new( GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, ct_ggsw.k(), rank) | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_glwe.k()) - | GLWECiphertext::external_product_inplace_scratch_space(&module, basek, ct_glwe.k(), ct_ggsw.k(), rank), + | GLWECiphertext::external_product_inplace_scratch_space(&module, basek, ct_glwe.k(), ct_ggsw.k(), digits, rank), ); let mut source_xs = Source::new([0u8; 32]); diff --git a/core/benches/keyswitch_glwe_fft64.rs b/core/benches/keyswitch_glwe_fft64.rs index 1841d5b..69651fd 100644 --- a/core/benches/keyswitch_glwe_fft64.rs +++ b/core/benches/keyswitch_glwe_fft64.rs @@ -26,11 +26,13 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { let k_grlwe: usize = p.k_ksk; let rank_in: usize = p.rank_in; let rank_out: usize = p.rank_out; + let digits: usize = 1; let rows: usize = (p.k_ct_in + p.basek - 1) / p.basek; let sigma: f64 = 3.2; - let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k_grlwe, rows, rank_in, rank_out); + let mut ksk: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::alloc(&module, basek, k_grlwe, rows, digits, rank_in, rank_out); let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_rlwe_in, rank_in); let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_rlwe_out, rank_out); @@ -43,6 +45,7 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { ct_out.k(), ct_in.k(), ksk.k(), + digits, rank_in, rank_out, ), @@ -124,17 +127,18 @@ fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) { let k_ct: usize = p.k_ct; let k_ksk: usize = p.k_ksk; let rank: usize = p.rank; + let digits: usize = 1; let rows: usize = (p.k_ct + p.basek - 1) / p.basek; let sigma: f64 = 3.2; - let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, rank, rank); + let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank, rank); let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); let mut scratch = ScratchOwned::new( GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ksk.k(), rank) | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k()) - | GLWECiphertext::keyswitch_inplace_scratch_space(&module, basek, ct.k(), rank, ksk.k()), + | GLWECiphertext::keyswitch_inplace_scratch_space(&module, basek, ct.k(), ksk.k(), digits, rank), ); let mut source_xs: Source = Source::new([0u8; 32]); diff --git a/core/src/automorphism.rs b/core/src/automorphism.rs index a4165e6..e9e4a0b 100644 --- a/core/src/automorphism.rs +++ b/core/src/automorphism.rs @@ -20,7 +20,7 @@ impl AutomorphismKey, FFT64> { } pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize { - GLWESwitchingKey::, FFT64>::bytes_of(module, basek, k, rows, digits,rank, rank) + GLWESwitchingKey::, FFT64>::bytes_of(module, basek, k, rows, digits, rank, rank) } } @@ -98,68 +98,74 @@ impl AutomorphismKey, FFT64> { pub fn keyswitch_scratch_space( module: &Module, basek: usize, - out_k: usize, - in_k: usize, - ksk_k: usize, + k_out: usize, + k_in: usize, + k_ksk: usize, + digits: usize, rank: usize, ) -> usize { - GLWESwitchingKey::keyswitch_scratch_space(module, basek, out_k, rank, in_k, rank, ksk_k) + GLWESwitchingKey::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank, rank) } pub fn keyswitch_inplace_scratch_space( module: &Module, basek: usize, - out_k: usize, - out_rank: usize, - ksk_k: usize, + k_out: usize, + k_ksk: usize, + digits: usize, + rank: usize, ) -> usize { - GLWESwitchingKey::keyswitch_inplace_scratch_space(module, basek, out_k, out_rank, ksk_k) + GLWESwitchingKey::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank) } pub fn automorphism_scratch_space( module: &Module, basek: usize, - out_k: usize, - in_k: usize, - atk_k: usize, + k_out: usize, + k_in: usize, + k_ksk: usize, + digits: usize, rank: usize, ) -> usize { - let tmp_dft: usize = GLWECiphertextFourier::bytes_of(module, basek, in_k, rank); - let tmp_idft: usize = GLWECiphertextFourier::bytes_of(module, basek, out_k, rank); + let tmp_dft: usize = GLWECiphertextFourier::bytes_of(module, basek, k_in, rank); + let tmp_idft: usize = GLWECiphertextFourier::bytes_of(module, basek, k_out, rank); let idft: usize = module.vec_znx_idft_tmp_bytes(); - let keyswitch: usize = GLWECiphertext::keyswitch_inplace_scratch_space(module, basek, out_k, rank, atk_k); + let keyswitch: usize = GLWECiphertext::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank); tmp_dft + tmp_idft + idft + keyswitch } pub fn automorphism_inplace_scratch_space( module: &Module, basek: usize, - out_k: usize, - ksk_k: usize, + k_out: usize, + k_ksk: usize, + digits: usize, rank: usize, ) -> usize { - AutomorphismKey::automorphism_scratch_space(module, basek, out_k, out_k, ksk_k, rank) + AutomorphismKey::automorphism_scratch_space(module, basek, k_out, k_out, k_ksk, digits, rank) } pub fn external_product_scratch_space( module: &Module, basek: usize, - out_k: usize, - in_k: usize, + k_out: usize, + k_in: usize, ggsw_k: usize, + digits: usize, rank: usize, ) -> usize { - GLWESwitchingKey::external_product_scratch_space(module, basek, out_k, in_k, ggsw_k, rank) + GLWESwitchingKey::external_product_scratch_space(module, basek, k_out, k_in, ggsw_k, digits, rank) } pub fn external_product_inplace_scratch_space( module: &Module, basek: usize, - out_k: usize, + k_out: usize, ggsw_k: usize, + digits: usize, rank: usize, ) -> usize { - GLWESwitchingKey::external_product_inplace_scratch_space(module, basek, out_k, ggsw_k, rank) + GLWESwitchingKey::external_product_inplace_scratch_space(module, basek, k_out, ggsw_k, digits, rank) } } @@ -284,7 +290,7 @@ impl + AsRef<[u8]>> AutomorphismKey { // and switches back to DFT domain (0..self.rank_out() + 1).for_each(|i| { module.vec_znx_automorphism_inplace(lhs.p(), &mut tmp_idft.data, i); - module.vec_znx_dft(&mut tmp_dft.data, i, &tmp_idft.data, i); + module.vec_znx_dft(1, 0, &mut tmp_dft.data, i, &tmp_idft.data, i); }); // Sets back the relevant row diff --git a/core/src/elem.rs b/core/src/elem.rs index 9a3a285..c7c939f 100644 --- a/core/src/elem.rs +++ b/core/src/elem.rs @@ -34,7 +34,7 @@ pub trait Infos { /// Returns the number of size per polynomial. fn size(&self) -> usize { let size: usize = self.inner().size(); - debug_assert_eq!(size, div_ceil(self.basek(), self.k())); + debug_assert_eq!(size, div_ceil(self.k(), self.basek())); size } diff --git a/core/src/gglwe_ciphertext.rs b/core/src/gglwe_ciphertext.rs index 1650f59..7f23071 100644 --- a/core/src/gglwe_ciphertext.rs +++ b/core/src/gglwe_ciphertext.rs @@ -14,17 +14,65 @@ pub struct GGLWECiphertext { } impl GGLWECiphertext, B> { - pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize, rank_out: usize) -> Self { + pub fn alloc( + module: &Module, + basek: usize, + k: usize, + rows: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + ) -> Self { + let size: usize = div_ceil(k, basek); + debug_assert!( + size > digits, + "invalid gglwe: ceil(k/basek): {} <= digits: {}", + size, + digits + ); + + assert!( + rows * digits <= size, + "invalid gglwe: rows: {} * digits:{} > ceil(k/basek): {}", + rows, + digits, + size + ); + Self { - data: module.new_mat_znx_dft(div_ceil(rows, digits), rank_in, rank_out + 1, div_ceil(basek, k)), + data: module.new_mat_znx_dft(rows, rank_in, rank_out + 1, size), basek: basek, k, digits, } } - pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize, rank_out: usize) -> usize { - module.bytes_of_mat_znx_dft(div_ceil(rows, digits), rank_in, rank_out + 1, div_ceil(basek, k)) + pub fn bytes_of( + module: &Module, + basek: usize, + k: usize, + rows: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + ) -> usize { + let size: usize = div_ceil(k, basek); + debug_assert!( + size > digits, + "invalid gglwe: ceil(k/basek): {} <= digits: {}", + size, + digits + ); + + assert!( + rows * digits <= size, + "invalid gglwe: rows: {} * digits:{} > ceil(k/basek): {}", + rows, + digits, + size + ); + + module.bytes_of_mat_znx_dft(rows, rank_in, rank_out + 1, rows) } } @@ -49,7 +97,7 @@ impl GGLWECiphertext { self.data.cols_out() - 1 } - pub fn digits(&self) -> usize{ + pub fn digits(&self) -> usize { self.digits } @@ -64,7 +112,7 @@ impl GGLWECiphertext { impl GGLWECiphertext, FFT64> { pub fn generate_from_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { - let size = div_ceil(basek, k); + let size = div_ceil(k, basek); GLWECiphertext::encrypt_sk_scratch_space(module, basek, k) + module.bytes_of_vec_znx(rank + 1, size) + module.bytes_of_vec_znx(1, size) @@ -103,7 +151,16 @@ impl + AsRef<[u8]>> GGLWECiphertext { self.rank(), self.size(), GGLWECiphertext::generate_from_sk_scratch_space(module, self.basek(), self.k(), self.rank()) - ) + ); + assert!( + self.rows() * self.digits() * self.basek() <= self.k(), + "self.rows() : {} * self.digits() : {} * self.basek() : {} = {} >= self.k() = {}", + self.rows(), + self.digits(), + self.basek(), + self.rows() * self.digits() * self.basek(), + self.k() + ); } let rows: usize = self.rows(); @@ -132,7 +189,13 @@ impl + AsRef<[u8]>> GGLWECiphertext { (0..rows).for_each(|row_i| { // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt tmp_pt.data.zero(); // zeroes for next iteration - module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, row_i * digits, pt, col_i); + module.vec_znx_add_scalar_inplace( + &mut tmp_pt.data, + 0, + (digits - 1) + row_i * digits, + pt, + col_i, + ); module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scratch_3); // rlwe encrypt of vec_znx_pt into vec_znx_ct diff --git a/core/src/ggsw_ciphertext.rs b/core/src/ggsw_ciphertext.rs index 98edf89..b15f6fc 100644 --- a/core/src/ggsw_ciphertext.rs +++ b/core/src/ggsw_ciphertext.rs @@ -19,8 +19,24 @@ pub struct GGSWCiphertext { impl GGSWCiphertext, B> { pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self { + let size: usize = div_ceil(k, basek); + debug_assert!( + size > digits, + "invalid ggsw: ceil(k/basek): {} <= digits: {}", + size, + digits + ); + + assert!( + rows * digits <= size, + "invalid ggsw: rows: {} * digits:{} > ceil(k/basek): {}", + rows, + digits, + size + ); + Self { - data: module.new_mat_znx_dft(div_ceil(rows, digits), rank + 1, rank + 1, div_ceil(basek, k)), + data: module.new_mat_znx_dft(rows, rank + 1, rank + 1, div_ceil(k, basek)), basek, k: k, digits, @@ -28,7 +44,23 @@ impl GGSWCiphertext, B> { } pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize { - module.bytes_of_mat_znx_dft(div_ceil(rows, digits), rank + 1, rank + 1, div_ceil(basek, k)) + let size: usize = div_ceil(k, basek); + debug_assert!( + size > digits, + "invalid ggsw: ceil(k/basek): {} <= digits: {}", + size, + digits + ); + + assert!( + rows * digits <= size, + "invalid ggsw: rows: {} * digits:{} > ceil(k/basek): {}", + rows, + digits, + size + ); + + module.bytes_of_mat_znx_dft(rows, rank + 1, rank + 1, size) } } @@ -60,7 +92,7 @@ impl GGSWCiphertext { impl GGSWCiphertext, FFT64> { pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { - let size = div_ceil(basek, k); + let size = div_ceil(k, basek); GLWECiphertext::encrypt_sk_scratch_space(module, basek, k) + module.bytes_of_vec_znx(rank + 1, size) + module.bytes_of_vec_znx(1, size) @@ -71,46 +103,59 @@ impl GGSWCiphertext, FFT64> { module: &Module, basek: usize, self_k: usize, - tsk_k: usize, + k_tsk: usize, + digits: usize, rank: usize, ) -> usize { - let tsk_size: usize = div_ceil(basek, tsk_k); - let self_size: usize = div_ceil(basek, self_k); + let tsk_size: usize = div_ceil(k_tsk, basek); + let self_size_out: usize = div_ceil(self_k, basek); + let self_size_in: usize = div_ceil(self_size_out, digits); let tmp_dft_i: usize = module.bytes_of_vec_znx_dft(rank + 1, tsk_size); - let tmp_dft_col_data: usize = module.bytes_of_vec_znx_dft(1, self_size); - let vmp: usize = tmp_dft_col_data + module.vmp_apply_tmp_bytes(self_size, self_size, self_size, rank, rank, tsk_size); + let tmp_a: usize = module.bytes_of_vec_znx_dft(1, self_size_in); + let vmp: usize = module.vmp_apply_tmp_bytes( + self_size_out, + self_size_in, + self_size_in, + rank, + rank, + tsk_size, + ); let tmp_idft: usize = module.bytes_of_vec_znx_big(1, tsk_size); let norm: usize = module.vec_znx_big_normalize_tmp_bytes(); - tmp_dft_i + ((tmp_dft_col_data + vmp) | (tmp_idft + norm)) + tmp_dft_i + ((tmp_a + vmp) | (tmp_idft + norm)) } pub(crate) fn keyswitch_internal_col0_scratch_space( module: &Module, basek: usize, - out_k: usize, - in_k: usize, - ksk_k: usize, + k_out: usize, + k_in: usize, + k_ksk: usize, + digits: usize, rank: usize, ) -> usize { - GLWECiphertext::keyswitch_from_fourier_scratch_space(module, basek, out_k, rank, in_k, rank, ksk_k) - + module.bytes_of_vec_znx_dft(rank + 1, div_ceil(basek, in_k)) + GLWECiphertext::keyswitch_from_fourier_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank, rank) + + module.bytes_of_vec_znx_dft(rank + 1, div_ceil(k_in, basek)) } pub fn keyswitch_scratch_space( module: &Module, basek: usize, - out_k: usize, - in_k: usize, - ksk_k: usize, - tsk_k: usize, + k_out: usize, + k_in: usize, + k_ksk: usize, + digits_ksk: usize, + k_tsk: usize, + digits_tsk: usize, rank: usize, ) -> usize { - let out_size: usize = div_ceil(basek, out_k); + let out_size: usize = div_ceil(k_out, basek); let res_znx: usize = module.bytes_of_vec_znx(rank + 1, out_size); let ci_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size); - let ks: usize = GGSWCiphertext::keyswitch_internal_col0_scratch_space(module, basek, out_k, in_k, ksk_k, rank); - let expand_rows: usize = GGSWCiphertext::expand_row_scratch_space(module, basek, out_k, tsk_k, rank); + let ks: usize = + GGSWCiphertext::keyswitch_internal_col0_scratch_space(module, basek, k_out, k_in, k_ksk, digits_ksk, rank); + let expand_rows: usize = GGSWCiphertext::expand_row_scratch_space(module, basek, k_out, k_tsk, digits_tsk, rank); let res_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size); res_znx + ci_dft + (ks | expand_rows | res_dft) } @@ -118,67 +163,81 @@ impl GGSWCiphertext, FFT64> { pub fn keyswitch_inplace_scratch_space( module: &Module, basek: usize, - out_k: usize, - ksk_k: usize, - tsk_k: usize, + k_out: usize, + k_ksk: usize, + digits_ksk: usize, + k_tsk: usize, + digits_tsk: usize, rank: usize, ) -> usize { - GGSWCiphertext::keyswitch_scratch_space(module, basek, out_k, out_k, ksk_k, tsk_k, rank) + GGSWCiphertext::keyswitch_scratch_space( + module, basek, k_out, k_out, k_ksk, digits_ksk, k_tsk, digits_tsk, rank, + ) } pub fn automorphism_scratch_space( module: &Module, basek: usize, - out_k: usize, - in_k: usize, - atk_k: usize, - tsk_k: usize, + k_out: usize, + k_in: usize, + k_ksk: usize, + digits_ksk: usize, + k_tsk: usize, + digits_tsk: usize, rank: usize, ) -> usize { let cols: usize = rank + 1; - let out_size: usize = div_ceil(basek, out_k); + let out_size: usize = div_ceil(k_out, basek); let res: usize = module.bytes_of_vec_znx(cols, out_size); let res_dft: usize = module.bytes_of_vec_znx_dft(cols, out_size); let ci_dft: usize = module.bytes_of_vec_znx_dft(cols, out_size); - let ks_internal: usize = GGSWCiphertext::keyswitch_internal_col0_scratch_space(module, basek, out_k, in_k, atk_k, rank); - let expand: usize = GGSWCiphertext::expand_row_scratch_space(module, basek, out_k, tsk_k, rank); + let ks_internal: usize = + GGSWCiphertext::keyswitch_internal_col0_scratch_space(module, basek, k_out, k_in, k_ksk, digits_ksk, rank); + let expand: usize = GGSWCiphertext::expand_row_scratch_space(module, basek, k_out, k_tsk, digits_tsk, rank); res + ci_dft + (ks_internal | expand | res_dft) } pub fn automorphism_inplace_scratch_space( module: &Module, basek: usize, - out_k: usize, - atk_k: usize, - tsk_k: usize, + k_out: usize, + k_ksk: usize, + digits_ksk: usize, + k_tsk: usize, + digits_tsk: usize, rank: usize, ) -> usize { - GGSWCiphertext::automorphism_scratch_space(module, basek, out_k, out_k, atk_k, tsk_k, rank) + GGSWCiphertext::automorphism_scratch_space( + module, basek, k_out, k_out, k_ksk, digits_ksk, k_tsk, digits_tsk, rank, + ) } pub fn external_product_scratch_space( module: &Module, basek: usize, - out_k: usize, - in_k: usize, - ggsw_k: usize, + k_out: usize, + k_in: usize, + k_ggsw: usize, + digits: usize, rank: usize, ) -> usize { - let tmp_in: usize = GLWECiphertextFourier::bytes_of(module, basek, in_k, rank); - let tmp_out: usize = GLWECiphertextFourier::bytes_of(module, basek, out_k, rank); - let ggsw: usize = GLWECiphertextFourier::external_product_scratch_space(module, basek, out_k, in_k, ggsw_k, rank); + let tmp_in: usize = GLWECiphertextFourier::bytes_of(module, basek, k_in, rank); + let tmp_out: usize = GLWECiphertextFourier::bytes_of(module, basek, k_out, rank); + let ggsw: usize = GLWECiphertextFourier::external_product_scratch_space(module, basek, k_out, k_in, k_ggsw, digits, rank); tmp_in + tmp_out + ggsw } pub fn external_product_inplace_scratch_space( module: &Module, basek: usize, - out_k: usize, - ggsw_k: usize, + k_out: usize, + k_ggsw: usize, + digits: usize, rank: usize, ) -> usize { - let tmp: usize = GLWECiphertextFourier::bytes_of(module, basek, out_k, rank); - let ggsw: usize = GLWECiphertextFourier::external_product_inplace_scratch_space(module, basek, out_k, ggsw_k, rank); + let tmp: usize = GLWECiphertextFourier::bytes_of(module, basek, k_out, rank); + let ggsw: usize = + GLWECiphertextFourier::external_product_inplace_scratch_space(module, basek, k_out, k_ggsw, digits, rank); tmp + ggsw } } @@ -214,7 +273,7 @@ impl + AsRef<[u8]>> GGSWCiphertext { tmp_pt.data.zero(); // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt - module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, row_i * digits, pt, 0); + module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (digits - 1) + row_i * digits, pt, 0); module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scratch2); (0..rank + 1).for_each(|col_j| { @@ -254,7 +313,15 @@ impl + AsRef<[u8]>> GGSWCiphertext { let cols: usize = self.rank() + 1; assert!( - scratch.available() >= GGSWCiphertext::expand_row_scratch_space(module, self.basek(), self.k(), tsk.k(), self.rank()) + scratch.available() + >= GGSWCiphertext::expand_row_scratch_space( + module, + self.basek(), + self.k(), + tsk.k(), + tsk.digits(), + tsk.rank() + ) ); // Example for rank 3: @@ -279,8 +346,6 @@ impl + AsRef<[u8]>> GGSWCiphertext { let (mut tmp_dft_i, scratch1) = scratch.tmp_vec_znx_dft(module, cols, tsk.size()); { - let (mut tmp_dft_col_data, scratch2) = scratch1.tmp_vec_znx_dft(module, 1, self.size()); - // Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2 // // # Example for col=1 @@ -293,23 +358,27 @@ impl + AsRef<[u8]>> GGSWCiphertext { // = // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2) (1..cols).for_each(|col_i| { - // Extracts a[i] and multipies with Enc(s[i]s[j]) - tmp_dft_col_data.extract_column(0, ci_dft, col_i); + let digits: usize = tsk.digits(); + let pmat: &MatZnxDft = &tsk.at(col_i - 1, col_j - 1).0.data; // Selects Enc(s[i]s[j]) + + // Extracts a[i] and multipies with Enc(s[i]s[j]) if col_i == 1 { - module.vmp_apply( - &mut tmp_dft_i, - &tmp_dft_col_data, - &tsk.at(col_i - 1, col_j - 1).0.data, // Selects Enc(s[i]s[j]) - scratch2, - ); + (0..digits).for_each(|di| { + let (mut tmp_a, scratch2) = scratch1.tmp_vec_znx_dft(module, 1, (ci_dft.size() + di) / digits); + module.vec_znx_dft_copy(digits, digits - 1 - di, &mut tmp_a, 0, ci_dft, col_i); + if di == 0 { + module.vmp_apply(&mut tmp_dft_i, &tmp_a, pmat, scratch2); + } else { + module.vmp_apply_add(&mut tmp_dft_i, &tmp_a, pmat, di, scratch2); + } + }); } else { - module.vmp_apply_add( - &mut tmp_dft_i, - &tmp_dft_col_data, - &tsk.at(col_i - 1, col_j - 1).0.data, // Selects Enc(s[i]s[j]) - scratch2, - ); + (0..digits).for_each(|di| { + let (mut tmp_a, scratch2) = scratch1.tmp_vec_znx_dft(module, 1, (ci_dft.size() + di) / digits); + module.vec_znx_dft_copy(digits, digits - 1 - di, &mut tmp_a, 0, ci_dft, col_i); + module.vmp_apply_add(&mut tmp_dft_i, &tmp_a, pmat, di, scratch2); + }); } }); } @@ -344,7 +413,7 @@ impl + AsRef<[u8]>> GGSWCiphertext { let basek: usize = self.basek(); let (mut tmp_res, scratch1) = scratch.tmp_glwe_ct(module, basek, self.k(), rank); - let (mut ci_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, lhs.size()); + let (mut ci_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, self.size()); // Keyswitch the j-th row of the col 0 (0..lhs.rows()).for_each(|row_i| { @@ -354,7 +423,7 @@ impl + AsRef<[u8]>> GGSWCiphertext { // Isolates DFT(a[i]) (0..cols).for_each(|col_i| { - module.vec_znx_dft(&mut ci_dft, col_i, &tmp_res.data, col_i); + module.vec_znx_dft(1, 0, &mut ci_dft, col_i, &tmp_res.data, col_i); }); module.vmp_prepare_row(&mut self.data, row_i, 0, &ci_dft); @@ -425,8 +494,10 @@ impl + AsRef<[u8]>> GGSWCiphertext { self.k(), lhs.k(), auto_key.k(), + auto_key.digits(), tensor_key.k(), - self.rank() + tensor_key.digits(), + self.rank(), ) ) }; @@ -436,7 +507,7 @@ impl + AsRef<[u8]>> GGSWCiphertext { let basek: usize = self.basek(); let (mut tmp_res, scratch1) = scratch.tmp_glwe_ct(module, basek, self.k(), rank); - let (mut ci_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, lhs.size()); + let (mut ci_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, self.size()); // Keyswitch the j-th row of the col 0 (0..lhs.rows()).for_each(|row_i| { @@ -448,7 +519,7 @@ impl + AsRef<[u8]>> GGSWCiphertext { (0..cols).for_each(|col_i| { // (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2) -> (-(a0s0 + a1s1 + a2s2) + pi(M[i]), a0, a1, a2) module.vec_znx_automorphism_inplace(auto_key.p(), &mut tmp_res.data, col_i); - module.vec_znx_dft(&mut ci_dft, col_i, &tmp_res.data, col_i); + module.vec_znx_dft(1, 0, &mut ci_dft, col_i, &tmp_res.data, col_i); }); module.vmp_prepare_row(&mut self.data, row_i, 0, &ci_dft); @@ -510,6 +581,19 @@ impl + AsRef<[u8]>> GGSWCiphertext { self.rank(), rhs.rank() ); + + assert!( + scratch.available() + >= GGSWCiphertext::external_product_scratch_space( + module, + self.basek(), + self.k(), + lhs.k(), + rhs.k(), + rhs.digits(), + rhs.rank() + ) + ) } let (mut tmp_ct_in, scratch1) = scratch.tmp_glwe_fourier(module, lhs.basek(), lhs.k(), lhs.rank()); @@ -582,6 +666,7 @@ impl> GGSWCiphertext { res.k(), self.k(), ksk.k(), + ksk.digits(), ksk.rank() ) ) diff --git a/core/src/glwe_ciphertext.rs b/core/src/glwe_ciphertext.rs index a084b91..38e639e 100644 --- a/core/src/glwe_ciphertext.rs +++ b/core/src/glwe_ciphertext.rs @@ -19,14 +19,14 @@ pub struct GLWECiphertext { impl GLWECiphertext> { pub fn alloc(module: &Module, basek: usize, k: usize, rank: usize) -> Self { Self { - data: module.new_vec_znx(rank + 1, div_ceil(basek, k)), + data: module.new_vec_znx(rank + 1, div_ceil(k, basek)), basek, k, } } pub fn bytes_of(module: &Module, basek: usize, k: usize, rank: usize) -> usize { - module.bytes_of_vec_znx(rank + 1, div_ceil(basek, k)) + module.bytes_of_vec_znx(rank + 1, div_ceil(k, basek)) } } @@ -62,43 +62,44 @@ impl> GLWECiphertext { } (0..self.rank() + 1).for_each(|i| { - module.vec_znx_dft(&mut res.data, i, &self.data, i); + module.vec_znx_dft(1, 0, &mut res.data, i, &self.data, i); }) } } impl GLWECiphertext> { pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize) -> usize { - let size: usize = div_ceil(basek, k); + let size: usize = div_ceil(k, basek); module.vec_znx_big_normalize_tmp_bytes() + module.bytes_of_vec_znx_dft(1, size) + module.bytes_of_vec_znx(1, size) } pub fn encrypt_pk_scratch_space(module: &Module, basek: usize, k: usize) -> usize { - let size: usize = div_ceil(basek, k); + let size: usize = div_ceil(k, basek); ((module.bytes_of_vec_znx_dft(1, size) + module.bytes_of_vec_znx_big(1, size)) | module.bytes_of_scalar_znx(1)) + module.bytes_of_scalar_znx_dft(1) + module.vec_znx_big_normalize_tmp_bytes() } pub fn decrypt_scratch_space(module: &Module, basek: usize, k: usize) -> usize { - let size: usize = div_ceil(basek, k); + let size: usize = div_ceil(k, basek); (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) } pub fn keyswitch_scratch_space( module: &Module, basek: usize, - out_k: usize, - out_rank: usize, - in_k: usize, - in_rank: usize, - ksk_k: usize, + k_out: usize, + k_in: usize, + k_ksk: usize, + digits: usize, + rank_in: usize, + rank_out: usize, ) -> usize { - let res_dft: usize = GLWECiphertextFourier::bytes_of(module, basek, out_k, out_rank); - let in_size: usize = div_ceil(basek, in_k); - let out_size: usize = div_ceil(basek, out_k); - let ksk_size: usize = div_ceil(basek, ksk_k); - let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, in_rank, out_rank + 1, ksk_size) - + module.bytes_of_vec_znx_dft(in_rank, in_size); + let res_dft: usize = GLWECiphertextFourier::bytes_of(module, basek, k_out, rank_out); + let in_size: usize = div_ceil(div_ceil(k_in, basek), digits); + let out_size: usize = div_ceil(k_out, basek); + let ksk_size: usize = div_ceil(k_ksk, basek); + let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, rank_in, rank_out + 1, ksk_size) + + module.bytes_of_vec_znx_dft(rank_in, in_size); let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(); return res_dft + (vmp | normalize); } @@ -106,58 +107,63 @@ impl GLWECiphertext> { pub fn keyswitch_from_fourier_scratch_space( module: &Module, basek: usize, - out_k: usize, - out_rank: usize, - in_k: usize, - in_rank: usize, - ksk_k: usize, + k_out: usize, + k_in: usize, + k_ksk: usize, + digits: usize, + rank_in: usize, + rank_out: usize, ) -> usize { - Self::keyswitch_scratch_space(module, basek, out_k, out_rank, in_k, in_rank, ksk_k) + Self::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank_in, rank_out) } pub fn keyswitch_inplace_scratch_space( module: &Module, basek: usize, - out_k: usize, - out_rank: usize, - ksk_k: usize, + k_out: usize, + k_ksk: usize, + digits: usize, + rank: usize, ) -> usize { - Self::keyswitch_scratch_space(module, basek, out_k, out_rank, out_k, out_rank, ksk_k) + Self::keyswitch_scratch_space(module, basek, k_out, k_out, k_ksk, digits, rank, rank) } pub fn automorphism_scratch_space( module: &Module, basek: usize, - out_k: usize, - in_k: usize, - atk_k: usize, + k_out: usize, + k_in: usize, + k_ksk: usize, + digits: usize, rank: usize, ) -> usize { - Self::keyswitch_scratch_space(module, basek, out_k, rank, in_k, rank, atk_k) + Self::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank, rank) } pub fn automorphism_inplace_scratch_space( module: &Module, basek: usize, - out_k: usize, - atk_k: usize, + k_out: usize, + k_ksk: usize, + digits: usize, rank: usize, ) -> usize { - Self::keyswitch_scratch_space(module, basek, out_k, rank, out_k, rank, atk_k) + Self::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank) } pub fn external_product_scratch_space( module: &Module, basek: usize, - out_k: usize, - in_k: usize, + k_out: usize, + k_in: usize, ggsw_k: usize, + digits: usize, rank: usize, ) -> usize { - let res_dft: usize = GLWECiphertextFourier::bytes_of(module, basek, out_k, rank); - let in_size: usize = div_ceil(basek, in_k); - let out_size: usize = div_ceil(basek, out_k); - let ggsw_size: usize = div_ceil(basek, ggsw_k); + let res_dft: usize = GLWECiphertextFourier::bytes_of(module, basek, k_out, rank); + let in_size: usize = div_ceil(div_ceil(k_in, basek), digits); + let out_size: usize = div_ceil(k_out, basek); + let ggsw_size: usize = div_ceil(ggsw_k, basek); let vmp: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size) + module.vmp_apply_tmp_bytes( out_size, @@ -174,11 +180,12 @@ impl GLWECiphertext> { pub fn external_product_inplace_scratch_space( module: &Module, basek: usize, - out_k: usize, + k_out: usize, ggsw_k: usize, + digits: usize, rank: usize, ) -> usize { - Self::external_product_scratch_space(module, basek, out_k, out_k, ggsw_k, rank) + Self::external_product_scratch_space(module, basek, k_out, k_out, ggsw_k, digits, rank) } } @@ -390,10 +397,11 @@ impl + AsMut<[u8]>> GLWECiphertext { module, self.basek(), self.k(), - self.rank(), lhs.k(), - lhs.rank(), rhs.k(), + rhs.digits(), + rhs.rank_in(), + rhs.rank_out(), ) ); } @@ -405,12 +413,29 @@ impl + AsMut<[u8]>> GLWECiphertext { let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise { - // Applies VMP - let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, lhs.size()); - (0..cols_in).for_each(|col_i| { - module.vec_znx_dft_copy(&mut ai_dft, col_i, &lhs.data, col_i + 1); + let digits = rhs.digits(); + + (0..digits).for_each(|di| { + // (lhs.size() + di) / digits = (a - (digit - di - 1) + digit - 1) / digits + let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, (lhs.size() + di) / digits); + + (0..cols_in).for_each(|col_i| { + module.vec_znx_dft_copy( + digits, + digits - 1 - di, + &mut ai_dft, + col_i, + &lhs.data, + col_i + 1, + ); + }); + + if di == 0 { + module.vmp_apply(&mut res_dft, &ai_dft, &rhs.0.data, scratch2); + } else { + module.vmp_apply_add(&mut res_dft, &ai_dft, &rhs.0.data, di, scratch2); + } }); - module.vmp_apply(&mut res_dft, &ai_dft, &rhs.0.data, scratch2); } module.vec_znx_dft_add_inplace(&mut res_dft, 0, &lhs.data, 0); @@ -458,10 +483,11 @@ impl + AsMut<[u8]>> GLWECiphertext { module, self.basek(), self.k(), - self.rank(), lhs.k(), - lhs.rank(), rhs.k(), + rhs.digits(), + rhs.rank_in(), + rhs.rank_out(), ) ); } @@ -472,11 +498,29 @@ impl + AsMut<[u8]>> GLWECiphertext { let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise { - let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, lhs.size()); - (0..cols_in).for_each(|col_i| { - module.vec_znx_dft(&mut ai_dft, col_i, &lhs.data, col_i + 1); + let digits = rhs.digits(); + + (0..digits).for_each(|di| { + // (lhs.size() + di) / digits = (a - (digit - di - 1) + digit - 1) / digits + let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, (lhs.size() + di) / digits); + + (0..cols_in).for_each(|col_i| { + module.vec_znx_dft( + digits, + digits - 1 - di, + &mut ai_dft, + col_i, + &lhs.data, + col_i + 1, + ); + }); + + if di == 0 { + module.vmp_apply(&mut res_dft, &ai_dft, &rhs.0.data, scratch2); + } else { + module.vmp_apply_add(&mut res_dft, &ai_dft, &rhs.0.data, di, scratch2); + } }); - module.vmp_apply(&mut res_dft, &ai_dft, &rhs.0.data, scratch2); } let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft); @@ -528,6 +572,18 @@ impl + AsMut<[u8]>> GLWECiphertext { assert_eq!(rhs.n(), module.n()); assert_eq!(self.n(), module.n()); assert_eq!(lhs.n(), module.n()); + assert!( + scratch.available() + >= GLWECiphertext::external_product_scratch_space( + module, + self.basek(), + self.k(), + lhs.k(), + rhs.k(), + rhs.digits(), + rhs.rank(), + ) + ); } let cols: usize = rhs.rank() + 1; @@ -535,11 +591,22 @@ impl + AsMut<[u8]>> GLWECiphertext { let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); // Todo optimise { - let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, lhs.size()); - (0..cols).for_each(|col_i| { - module.vec_znx_dft(&mut a_dft, col_i, &lhs.data, col_i); + let digits = rhs.digits(); + + (0..digits).for_each(|di| { + // (lhs.size() + di) / digits = (a - (digit - di - 1) + digit - 1) / digits + let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, (lhs.size() + di) / digits); + + (0..cols).for_each(|col_i| { + module.vec_znx_dft(digits, digits - 1 - di, &mut a_dft, col_i, &lhs.data, col_i); + }); + + if di == 0 { + module.vmp_apply(&mut res_dft, &a_dft, &rhs.data, scratch2); + } else { + module.vmp_apply_add(&mut res_dft, &a_dft, &rhs.data, di, scratch2); + } }); - module.vmp_apply(&mut res_dft, &a_dft, &rhs.data, scratch2); } let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft); @@ -606,7 +673,7 @@ impl + AsMut<[u8]>> GLWECiphertext { self.data.fill_uniform(basek, i, size, source_xa); // c[i] = norm(IDFT(DFT(c[i]) * DFT(s[i]))) - module.vec_znx_dft(&mut ci_dft, 0, &self.data, i); + module.vec_znx_dft(1, 0, &mut ci_dft, 0, &self.data, i); module.svp_apply_inplace(&mut ci_dft, 0, &sk.data_fourier, i - 1); let ci_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(ci_dft); @@ -742,7 +809,7 @@ impl> GLWECiphertext { (1..cols).for_each(|i| { // ci_dft = DFT(a[i]) * DFT(s[i]) let (mut ci_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, self.size()); // TODO optimize size when pt << ct - module.vec_znx_dft(&mut ci_dft, 0, &self.data, i); + module.vec_znx_dft(1, 0, &mut ci_dft, 0, &self.data, i); module.svp_apply_inplace(&mut ci_dft, 0, &sk.data_fourier, i - 1); let ci_big = module.vec_znx_idft_consume(ci_dft); diff --git a/core/src/glwe_ciphertext_fourier.rs b/core/src/glwe_ciphertext_fourier.rs index 4fe7e1e..c132a4a 100644 --- a/core/src/glwe_ciphertext_fourier.rs +++ b/core/src/glwe_ciphertext_fourier.rs @@ -15,14 +15,14 @@ pub struct GLWECiphertextFourier { impl GLWECiphertextFourier, B> { pub fn alloc(module: &Module, basek: usize, k: usize, rank: usize) -> Self { Self { - data: module.new_vec_znx_dft(rank + 1, div_ceil(basek, k)), + data: module.new_vec_znx_dft(rank + 1, div_ceil(k, basek)), basek: basek, k: k, } } pub fn bytes_of(module: &Module, basek: usize, k: usize, rank: usize) -> usize { - module.bytes_of_vec_znx_dft(rank + 1, div_ceil(basek, k)) + module.bytes_of_vec_znx_dft(rank + 1, div_ceil(k, basek)) } } @@ -51,16 +51,16 @@ impl GLWECiphertextFourier { impl GLWECiphertextFourier, FFT64> { #[allow(dead_code)] pub(crate) fn idft_scratch_space(module: &Module, basek: usize, k: usize) -> usize { - module.bytes_of_vec_znx(1, div_ceil(basek, k)) + module.bytes_of_vec_znx(1, div_ceil(k, basek)) + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes()) } pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { - module.bytes_of_vec_znx(rank + 1, div_ceil(basek, k)) + GLWECiphertext::encrypt_sk_scratch_space(module, basek, k) + module.bytes_of_vec_znx(rank + 1, div_ceil(k, basek)) + GLWECiphertext::encrypt_sk_scratch_space(module, basek, k) } pub fn decrypt_scratch_space(module: &Module, basek: usize, k: usize) -> usize { - let size: usize = div_ceil(basek, k); + let size: usize = div_ceil(k, basek); (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size) | (module.bytes_of_vec_znx_big(1, size) + module.vec_znx_idft_tmp_bytes())) @@ -70,40 +70,45 @@ impl GLWECiphertextFourier, FFT64> { pub fn keyswitch_scratch_space( module: &Module, basek: usize, - out_k: usize, - out_rank: usize, - in_k: usize, - in_rank: usize, - ksk_k: usize, + k_out: usize, + k_in: usize, + k_ksk: usize, + digits: usize, + rank_in: usize, + rank_out: usize, ) -> usize { - GLWECiphertext::bytes_of(module, basek, out_k, out_rank) - + GLWECiphertext::keyswitch_from_fourier_scratch_space(module, basek, out_k, out_rank, in_k, in_rank, ksk_k) + GLWECiphertext::bytes_of(module, basek, k_out, rank_out) + + GLWECiphertext::keyswitch_from_fourier_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank_in, rank_out) } pub fn keyswitch_inplace_scratch_space( module: &Module, basek: usize, - out_k: usize, - out_rank: usize, - ksk_k: usize, + k_out: usize, + k_ksk: usize, + digits: usize, + rank: usize, ) -> usize { - Self::keyswitch_scratch_space(module, basek, out_k, out_rank, out_k, out_rank, ksk_k) + Self::keyswitch_scratch_space(module, basek, k_out, k_out, k_ksk, digits, rank, rank) } + // WARNING TODO: UPDATE pub fn external_product_scratch_space( module: &Module, basek: usize, - out_k: usize, - in_k: usize, - ggsw_k: usize, + k_out: usize, + k_in: usize, + k_ggsw: usize, + digits: usize, rank: usize, ) -> usize { - let res_dft: usize = GLWECiphertextFourier::bytes_of(module, basek, out_k, rank); - let out_size: usize = div_ceil(basek, out_k); - let in_size: usize = div_ceil(basek, in_k); - let ggsw_size: usize = div_ceil(basek, ggsw_k); - let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, rank + 1, rank + 1, ggsw_size); - let res_small: usize = GLWECiphertext::bytes_of(module, basek, out_k, rank); + let ggsw_size: usize = div_ceil(k_ggsw, basek); + let res_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, ggsw_size); + let in_size: usize = div_ceil(div_ceil(k_in, basek), digits); + let ggsw_size: usize = div_ceil(k_ggsw, basek); + let vmp: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size) + + module.vmp_apply_tmp_bytes(ggsw_size, in_size, in_size, rank + 1, rank + 1, ggsw_size); + let res_small: usize = module.bytes_of_vec_znx(rank + 1, ggsw_size); let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(); res_dft + (vmp | (res_small + normalize)) } @@ -111,11 +116,12 @@ impl GLWECiphertextFourier, FFT64> { pub fn external_product_inplace_scratch_space( module: &Module, basek: usize, - out_k: usize, - ggsw_k: usize, + k_out: usize, + k_ggsw: usize, + digits: usize, rank: usize, ) -> usize { - Self::external_product_scratch_space(module, basek, out_k, out_k, ggsw_k, rank) + Self::external_product_scratch_space(module, basek, k_out, k_out, k_ggsw, digits, rank) } } @@ -176,6 +182,18 @@ impl + AsRef<[u8]>> GLWECiphertextFourier assert_eq!(rhs.n(), module.n()); assert_eq!(self.n(), module.n()); assert_eq!(lhs.n(), module.n()); + assert!( + scratch.available() + >= GLWECiphertextFourier::external_product_scratch_space( + module, + self.basek(), + self.k(), + lhs.k(), + rhs.k(), + rhs.digits(), + rhs.rank(), + ) + ); } let cols: usize = rhs.rank() + 1; @@ -184,7 +202,22 @@ impl + AsRef<[u8]>> GLWECiphertextFourier let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); { - module.vmp_apply(&mut res_dft, &lhs.data, &rhs.data, scratch1); + let digits = rhs.digits(); + + (0..digits).for_each(|di| { + // (lhs.size() + di) / digits = (a - (digit - di - 1) + digit - 1) / digits + let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, (lhs.size() + di) / digits); + + (0..cols).for_each(|col_i| { + module.vec_znx_dft_copy(digits, digits - 1 - di, &mut a_dft, col_i, &lhs.data, col_i); + }); + + if di == 0 { + module.vmp_apply(&mut res_dft, &a_dft, &rhs.data, scratch2); + } else { + module.vmp_apply_add(&mut res_dft, &a_dft, &rhs.data, di, scratch2); + } + }); } // VMP result in high precision @@ -194,7 +227,7 @@ impl + AsRef<[u8]>> GLWECiphertextFourier let (mut res_small, scratch2) = scratch1.tmp_vec_znx(module, cols, rhs.size()); (0..cols).for_each(|i| { module.vec_znx_big_normalize(basek, &mut res_small, i, &res_big, i, scratch2); - module.vec_znx_dft(&mut self.data, i, &res_small, i); + module.vec_znx_dft(1, 0, &mut self.data, i, &res_small, i); }); } diff --git a/core/src/glwe_packing.rs b/core/src/glwe_packing.rs index e4d6f9d..85aceb6 100644 --- a/core/src/glwe_packing.rs +++ b/core/src/glwe_packing.rs @@ -74,8 +74,8 @@ impl StreamPacker { } /// Number of scratch space bytes required to call [Self::add]. - pub fn scratch_space(module: &Module, basek: usize, ct_k: usize, atk_k: usize, rank: usize) -> usize { - pack_core_scratch_space(module, basek, ct_k, atk_k, rank) + pub fn scratch_space(module: &Module, basek: usize, ct_k: usize, k_ksk: usize, digits: usize, rank: usize) -> usize { + pack_core_scratch_space(module, basek, ct_k, k_ksk, digits, rank) } pub fn galois_elements(module: &Module) -> Vec { @@ -142,8 +142,8 @@ impl StreamPacker { } } -fn pack_core_scratch_space(module: &Module, basek: usize, ct_k: usize, atk_k: usize, rank: usize) -> usize { - combine_scratch_space(module, basek, ct_k, atk_k, rank) +fn pack_core_scratch_space(module: &Module, basek: usize, ct_k: usize, k_ksk: usize, digits: usize, rank: usize) -> usize { + combine_scratch_space(module, basek, ct_k, k_ksk, digits, rank) } fn pack_core, DataAK: AsRef<[u8]>>( @@ -203,10 +203,10 @@ fn pack_core, DataAK: AsRef<[u8]>>( } } -fn combine_scratch_space(module: &Module, basek: usize, ct_k: usize, atk_k: usize, rank: usize) -> usize { +fn combine_scratch_space(module: &Module, basek: usize, ct_k: usize, k_ksk: usize, digits: usize, rank: usize) -> usize { GLWECiphertext::bytes_of(module, basek, ct_k, rank) + (GLWECiphertext::rsh_scratch_space(module) - | GLWECiphertext::automorphism_scratch_space(module, basek, ct_k, ct_k, atk_k, rank)) + | GLWECiphertext::automorphism_scratch_space(module, basek, ct_k, ct_k, k_ksk, digits, rank)) } /// [combine] merges two ciphertexts together. diff --git a/core/src/glwe_plaintext.rs b/core/src/glwe_plaintext.rs index 058c9a5..0ab3a46 100644 --- a/core/src/glwe_plaintext.rs +++ b/core/src/glwe_plaintext.rs @@ -37,14 +37,14 @@ impl + AsRef<[u8]>> SetMetaData for GLWEPlaintext> { pub fn alloc(module: &Module, basek: usize, k: usize) -> Self { Self { - data: module.new_vec_znx(1, div_ceil(basek, k)), + data: module.new_vec_znx(1, div_ceil(k, basek)), basek: basek, k, } } pub fn byte_of(module: &Module, basek: usize, k: usize) -> usize { - module.bytes_of_vec_znx(1, div_ceil(basek, k)) + module.bytes_of_vec_znx(1, div_ceil(k, basek)) } } diff --git a/core/src/keyswitch_key.rs b/core/src/keyswitch_key.rs index bc9baae..56d42b4 100644 --- a/core/src/keyswitch_key.rs +++ b/core/src/keyswitch_key.rs @@ -6,13 +6,29 @@ use crate::{GGLWECiphertext, GGSWCiphertext, GLWECiphertextFourier, GLWESecret, pub struct GLWESwitchingKey(pub(crate) GGLWECiphertext); impl GLWESwitchingKey, FFT64> { - pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize, rank_out: usize) -> Self { + pub fn alloc( + module: &Module, + basek: usize, + k: usize, + rows: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + ) -> Self { GLWESwitchingKey(GGLWECiphertext::alloc( module, basek, k, rows, digits, rank_in, rank_out, )) } - pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize, rank_out: usize) -> usize { + pub fn bytes_of( + module: &Module, + basek: usize, + k: usize, + rows: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + ) -> usize { GGLWECiphertext::, FFT64>::bytes_of(module, basek, k, rows, digits, rank_in, rank_out) } } @@ -87,53 +103,59 @@ impl GLWESwitchingKey, FFT64> { pub fn keyswitch_scratch_space( module: &Module, basek: usize, - out_k: usize, - out_rank: usize, - in_k: usize, - in_rank: usize, - ksk_k: usize, + k_out: usize, + k_in: usize, + k_ksk: usize, + digits: usize, + rank_in: usize, + rank_out: usize, ) -> usize { - let tmp_in: usize = GLWECiphertextFourier::bytes_of(module, basek, in_k, in_rank); - let tmp_out: usize = GLWECiphertextFourier::bytes_of(module, basek, out_k, out_rank); - let ksk: usize = GLWECiphertextFourier::keyswitch_scratch_space(module, basek, out_k, out_rank, in_k, in_rank, ksk_k); + let tmp_in: usize = GLWECiphertextFourier::bytes_of(module, basek, k_in, rank_in); + let tmp_out: usize = GLWECiphertextFourier::bytes_of(module, basek, k_out, rank_out); + let ksk: usize = + GLWECiphertextFourier::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank_in, rank_out); tmp_in + tmp_out + ksk } pub fn keyswitch_inplace_scratch_space( module: &Module, basek: usize, - out_k: usize, - out_rank: usize, - ksk_k: usize, + k_out: usize, + k_ksk: usize, + digits: usize, + rank: usize, ) -> usize { - let tmp: usize = GLWECiphertextFourier::bytes_of(module, basek, out_k, out_rank); - let ksk: usize = GLWECiphertextFourier::keyswitch_inplace_scratch_space(module, basek, out_k, out_rank, ksk_k); + let tmp: usize = GLWECiphertextFourier::bytes_of(module, basek, k_out, rank); + let ksk: usize = GLWECiphertextFourier::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank); tmp + ksk } pub fn external_product_scratch_space( module: &Module, basek: usize, - out_k: usize, - in_k: usize, - ggsw_k: usize, + k_out: usize, + k_in: usize, + k_ggsw: usize, + digits: usize, rank: usize, ) -> usize { - let tmp_in: usize = GLWECiphertextFourier::bytes_of(module, basek, in_k, rank); - let tmp_out: usize = GLWECiphertextFourier::bytes_of(module, basek, out_k, rank); - let ggsw: usize = GLWECiphertextFourier::external_product_scratch_space(module, basek, out_k, in_k, ggsw_k, rank); + let tmp_in: usize = GLWECiphertextFourier::bytes_of(module, basek, k_in, rank); + let tmp_out: usize = GLWECiphertextFourier::bytes_of(module, basek, k_out, rank); + let ggsw: usize = GLWECiphertextFourier::external_product_scratch_space(module, basek, k_out, k_in, k_ggsw, digits, rank); tmp_in + tmp_out + ggsw } pub fn external_product_inplace_scratch_space( module: &Module, basek: usize, - out_k: usize, - ggsw_k: usize, + k_out: usize, + k_ggsw: usize, + digits: usize, rank: usize, ) -> usize { - let tmp: usize = GLWECiphertextFourier::bytes_of(module, basek, out_k, rank); - let ggsw: usize = GLWECiphertextFourier::external_product_inplace_scratch_space(module, basek, out_k, ggsw_k, rank); + let tmp: usize = GLWECiphertextFourier::bytes_of(module, basek, k_out, rank); + let ggsw: usize = + GLWECiphertextFourier::external_product_inplace_scratch_space(module, basek, k_out, k_ggsw, digits, rank); tmp + ggsw } } @@ -309,7 +331,7 @@ impl + AsRef<[u8]>> GLWESwitchingKey { } let (mut tmp, scratch1) = scratch.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); - + println!("tmp: {}", tmp.size()); (0..self.rank_in()).for_each(|col_i| { (0..self.rows()).for_each(|row_j| { self.get_row(module, row_j, col_i, &mut tmp); diff --git a/core/src/lib.rs b/core/src/lib.rs index b82f82a..8817f12 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -112,12 +112,12 @@ impl ScratchCore for Scratch { k: usize, rank: usize, ) -> (GLWECiphertext<&mut [u8]>, &mut Self) { - let (data, scratch) = self.tmp_vec_znx(module, rank + 1, div_ceil(basek, k)); + let (data, scratch) = self.tmp_vec_znx(module, rank + 1, div_ceil(k, basek)); (GLWECiphertext { data, basek, k }, scratch) } fn tmp_glwe_pt(&mut self, module: &Module, basek: usize, k: usize) -> (GLWEPlaintext<&mut [u8]>, &mut Self) { - let (data, scratch) = self.tmp_vec_znx(module, 1, div_ceil(basek, k)); + let (data, scratch) = self.tmp_vec_znx(module, 1, div_ceil(k, basek)); (GLWEPlaintext { data, basek, k }, scratch) } @@ -131,7 +131,13 @@ impl ScratchCore for Scratch { rank_in: usize, rank_out: usize, ) -> (GGLWECiphertext<&mut [u8], FFT64>, &mut Self) { - let (data, scratch) = self.tmp_mat_znx_dft(module, div_ceil(rows, digits), rank_in, rank_out + 1, div_ceil(basek, k)); + let (data, scratch) = self.tmp_mat_znx_dft( + module, + div_ceil(rows, digits), + rank_in, + rank_out + 1, + div_ceil(k, basek), + ); ( GGLWECiphertext { data: data, @@ -152,7 +158,13 @@ impl ScratchCore for Scratch { digits: usize, rank: usize, ) -> (GGSWCiphertext<&mut [u8], FFT64>, &mut Self) { - let (data, scratch) = self.tmp_mat_znx_dft(module, div_ceil(rows, digits), rank + 1, rank + 1, div_ceil(basek, k)); + let (data, scratch) = self.tmp_mat_znx_dft( + module, + div_ceil(rows, digits), + rank + 1, + rank + 1, + div_ceil(k, basek), + ); ( GGSWCiphertext { data, @@ -171,7 +183,7 @@ impl ScratchCore for Scratch { k: usize, rank: usize, ) -> (GLWECiphertextFourier<&mut [u8], FFT64>, &mut Self) { - let (data, scratch) = self.tmp_vec_znx_dft(module, rank + 1, div_ceil(basek, k)); + let (data, scratch) = self.tmp_vec_znx_dft(module, rank + 1, div_ceil(k, basek)); (GLWECiphertextFourier { data, basek, k }, scratch) } diff --git a/core/src/tensor_key.rs b/core/src/tensor_key.rs index c7e9bf4..c0887c9 100644 --- a/core/src/tensor_key.rs +++ b/core/src/tensor_key.rs @@ -12,14 +12,16 @@ impl TensorKey, FFT64> { let mut keys: Vec, FFT64>> = Vec::new(); let pairs: usize = (((rank + 1) * rank) >> 1).max(1); (0..pairs).for_each(|_| { - keys.push(GLWESwitchingKey::alloc(module, basek, k, rows, digits,1, rank)); + keys.push(GLWESwitchingKey::alloc( + module, basek, k, rows, digits, 1, rank, + )); }); Self { keys: keys } } pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize { let pairs: usize = (((rank + 1) * rank) >> 1).max(1); - pairs * GLWESwitchingKey::, FFT64>::bytes_of(module, basek, k, rows, digits,1, rank) + pairs * GLWESwitchingKey::, FFT64>::bytes_of(module, basek, k, rows, digits, 1, rank) } } diff --git a/core/src/test_fft64/automorphism_key.rs b/core/src/test_fft64/automorphism_key.rs index 784a2da..ebe855f 100644 --- a/core/src/test_fft64/automorphism_key.rs +++ b/core/src/test_fft64/automorphism_key.rs @@ -2,42 +2,76 @@ use backend::{FFT64, Module, ScalarZnxOps, ScratchOwned, Stats, VecZnxOps}; use sampling::source::Source; use crate::{ - AutomorphismKey, GLWECiphertextFourier, GLWEPlaintext, GLWESecret, GetRow, Infos, - test_fft64::gglwe::log2_std_noise_gglwe_product, + AutomorphismKey, GLWECiphertextFourier, GLWEPlaintext, GLWESecret, GetRow, Infos, div_ceil, + test_fft64::log2_std_noise_gglwe_product, }; #[test] fn automorphism() { + let log_n: usize = 8; + let basek: usize = 12; + let k_in: usize = 60; + let k_out: usize = 60; + let digits: usize = div_ceil(k_in, basek); + let sigma: f64 = 3.2; (1..4).for_each(|rank| { - println!("test automorphism rank: {}", rank); - test_automorphism(-1, 5, 12, 12, 60, 3.2, rank); + (2..digits + 1).for_each(|di| { + println!("test automorphism digits: {} rank: {}", di, rank); + let k_apply: usize = (digits + di) * basek; + test_automorphism(-1, 5, log_n, basek, di, k_in, k_out, k_apply, sigma, rank); + }); }); } #[test] fn automorphism_inplace() { + let log_n: usize = 8; + let basek: usize = 12; + let k_in: usize = 60; + let digits: usize = div_ceil(k_in, basek); + let sigma: f64 = 3.2; (1..4).for_each(|rank| { - println!("test automorphism_inplace rank: {}", rank); - test_automorphism_inplace(-1, 5, 12, 12, 60, 3.2, rank); + (2..digits + 1).for_each(|di| { + println!("test automorphism digits: {} rank: {}", di, rank); + let k_apply: usize = (digits + di) * basek; + test_automorphism_inplace(-1, 5, log_n, basek, di, k_in, k_apply, sigma, rank); + }); }); } -fn test_automorphism(p0: i64, p1: i64, log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank: usize) { +fn test_automorphism( + p0: i64, + p1: i64, + log_n: usize, + basek: usize, + digits: usize, + k_in: usize, + k_out: usize, + k_apply: usize, + sigma: f64, + rank: usize, +) { let module: Module = Module::::new(1 << log_n); - let rows = (k_ksk + basek - 1) / basek; - let mut auto_key_in: AutomorphismKey, FFT64> = AutomorphismKey::alloc(&module, basek, k_ksk, rows, rank); - let mut auto_key_out: AutomorphismKey, FFT64> = AutomorphismKey::alloc(&module, basek, k_ksk, rows, rank); - let mut auto_key_apply: AutomorphismKey, FFT64> = AutomorphismKey::alloc(&module, basek, k_ksk, rows, rank); + let digits_in: usize = 1; + + let rows_in: usize = k_in / (basek * digits); + let rows_apply: usize = div_ceil(k_in, basek * digits); + + let mut auto_key_in: AutomorphismKey, FFT64> = AutomorphismKey::alloc(&module, basek, k_in, rows_in, digits_in, rank); + let mut auto_key_out: AutomorphismKey, FFT64> = + AutomorphismKey::alloc(&module, basek, k_out, rows_in, digits_in, rank); + let mut auto_key_apply: AutomorphismKey, FFT64> = + AutomorphismKey::alloc(&module, basek, k_apply, rows_apply, digits, rank); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - AutomorphismKey::generate_from_sk_scratch_space(&module, basek, k_ksk, rank) - | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_ksk) - | AutomorphismKey::automorphism_scratch_space(&module, basek, k_ksk, k_ksk, k_ksk, rank), + AutomorphismKey::generate_from_sk_scratch_space(&module, basek, k_apply, rank) + | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_out) + | AutomorphismKey::automorphism_scratch_space(&module, basek, k_out, k_in, k_apply, digits, rank), ); let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); @@ -68,8 +102,8 @@ fn test_automorphism(p0: i64, p1: i64, log_n: usize, basek: usize, k_ksk: usize, // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) auto_key_out.automorphism(&module, &auto_key_in, &auto_key_apply, scratch.borrow()); - let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_ksk, rank); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ksk); + let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_out, rank); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); let mut sk_auto: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); sk_auto.fill_zero(); // Necessary to avoid panic of unfilled sk @@ -88,26 +122,32 @@ fn test_automorphism(p0: i64, p1: i64, log_n: usize, basek: usize, k_ksk: usize, (0..auto_key_out.rank_in()).for_each(|col_i| { (0..auto_key_out.rows()).for_each(|row_i| { auto_key_out.get_row(&module, row_i, col_i, &mut ct_glwe_dft); - ct_glwe_dft.decrypt(&module, &mut pt, &sk_auto, scratch.borrow()); - module.vec_znx_sub_scalar_inplace(&mut pt.data, 0, row_i, &sk.data, col_i); + + module.vec_znx_sub_scalar_inplace( + &mut pt.data, + 0, + (digits_in - 1) + row_i * digits_in, + &sk.data, + col_i, + ); let noise_have: f64 = pt.data.std(0, basek).log2(); let noise_want: f64 = log2_std_noise_gglwe_product( module.n() as f64, - basek, + basek * digits, 0.5, 0.5, 0f64, sigma * sigma, 0f64, rank as f64, - k_ksk, - k_ksk, + k_in, + k_apply, ); assert!( - (noise_have - noise_want).abs() <= 0.1, + (noise_have - noise_want).abs() <= 0.5, "{} {}", noise_have, noise_want @@ -116,21 +156,36 @@ fn test_automorphism(p0: i64, p1: i64, log_n: usize, basek: usize, k_ksk: usize, }); } -fn test_automorphism_inplace(p0: i64, p1: i64, log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank: usize) { +fn test_automorphism_inplace( + p0: i64, + p1: i64, + log_n: usize, + basek: usize, + digits: usize, + k_in: usize, + k_apply: usize, + sigma: f64, + rank: usize, +) { let module: Module = Module::::new(1 << log_n); - let rows: usize = (k_ksk + basek - 1) / basek; - let mut auto_key: AutomorphismKey, FFT64> = AutomorphismKey::alloc(&module, basek, k_ksk, rows, rank); - let mut auto_key_apply: AutomorphismKey, FFT64> = AutomorphismKey::alloc(&module, basek, k_ksk, rows, rank); + let digits_in: usize = 1; + + let rows_in: usize = k_in / (basek * digits); + let rows_apply: usize = div_ceil(k_in, basek * digits); + + let mut auto_key: AutomorphismKey, FFT64> = AutomorphismKey::alloc(&module, basek, k_in, rows_in, digits_in, rank); + let mut auto_key_apply: AutomorphismKey, FFT64> = + AutomorphismKey::alloc(&module, basek, k_apply, rows_apply, digits, rank); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - AutomorphismKey::generate_from_sk_scratch_space(&module, basek, k_ksk, rank) - | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_ksk) - | AutomorphismKey::automorphism_inplace_scratch_space(&module, basek, k_ksk, k_ksk, rank), + AutomorphismKey::generate_from_sk_scratch_space(&module, basek, k_apply, rank) + | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_in) + | AutomorphismKey::automorphism_inplace_scratch_space(&module, basek, k_in, k_apply, digits, rank), ); let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); @@ -161,8 +216,8 @@ fn test_automorphism_inplace(p0: i64, p1: i64, log_n: usize, basek: usize, k_ksk // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) auto_key.automorphism_inplace(&module, &auto_key_apply, scratch.borrow()); - let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_ksk, rank); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ksk); + let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_in, rank); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_in); let mut sk_auto: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); sk_auto.fill_zero(); // Necessary to avoid panic of unfilled sk @@ -183,24 +238,30 @@ fn test_automorphism_inplace(p0: i64, p1: i64, log_n: usize, basek: usize, k_ksk auto_key.get_row(&module, row_i, col_i, &mut ct_glwe_dft); ct_glwe_dft.decrypt(&module, &mut pt, &sk_auto, scratch.borrow()); - module.vec_znx_sub_scalar_inplace(&mut pt.data, 0, row_i, &sk.data, col_i); + module.vec_znx_sub_scalar_inplace( + &mut pt.data, + 0, + (digits_in - 1) + row_i * digits_in, + &sk.data, + col_i, + ); let noise_have: f64 = pt.data.std(0, basek).log2(); let noise_want: f64 = log2_std_noise_gglwe_product( module.n() as f64, - basek, + basek * digits, 0.5, 0.5, 0f64, sigma * sigma, 0f64, rank as f64, - k_ksk, - k_ksk, + k_in, + k_apply, ); assert!( - (noise_have - noise_want).abs() <= 0.1, + (noise_have - noise_want).abs() <= 0.5, "{} {}", noise_have, noise_want diff --git a/core/src/test_fft64/gglwe.rs b/core/src/test_fft64/gglwe.rs index 1f722ba..aabce6e 100644 --- a/core/src/test_fft64/gglwe.rs +++ b/core/src/test_fft64/gglwe.rs @@ -2,30 +2,58 @@ use backend::{FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxToMut, ScratchO use sampling::source::Source; use crate::{ - GGSWCiphertext, GLWECiphertextFourier, GLWEPlaintext, GLWESecret, GLWESwitchingKey, GetRow, Infos, - test_fft64::ggsw::noise_ggsw_product, + GGSWCiphertext, GLWECiphertextFourier, GLWEPlaintext, GLWESecret, GLWESwitchingKey, GetRow, Infos, div_ceil, + test_fft64::{log2_std_noise_gglwe_product, noise_ggsw_product}, }; #[test] fn encrypt_sk() { + let log_n: usize = 8; + let basek: usize = 12; + let k_ksk: usize = 54; + let digits: usize = k_ksk / basek; (1..4).for_each(|rank_in| { (1..4).for_each(|rank_out| { - println!("test encrypt_sk rank_in rank_out: {} {}", rank_in, rank_out); - test_encrypt_sk(12, 8, 54, 3.2, rank_in, rank_out); + (1..digits + 1).for_each(|di| { + println!( + "test encrypt_sk digits: {} ranks: ({} {})", + di, rank_in, rank_out + ); + test_encrypt_sk(log_n, basek, k_ksk, di, rank_in, rank_out, 3.2); + }); }); }); } #[test] fn key_switch() { + let log_n: usize = 8; + let basek: usize = 12; + let k_in: usize = 60; + let digits: usize = div_ceil(k_in, basek); (1..4).for_each(|rank_in_s0s1| { (1..4).for_each(|rank_out_s0s1| { (1..4).for_each(|rank_out_s1s2| { - println!( - "test key_switch : ({},{},{})", - rank_in_s0s1, rank_out_s0s1, rank_out_s1s2 - ); - test_key_switch(12, 15, 60, 3.2, rank_in_s0s1, rank_out_s0s1, rank_out_s1s2); + (1..digits + 1).for_each(|di| { + let k_ksk: usize = k_in + basek * di; + println!( + "test key_switch digits: {} ranks: ({},{},{})", + di, rank_in_s0s1, rank_out_s0s1, rank_out_s1s2 + ); + let k_out: usize = k_ksk; // Better capture noise. + test_key_switch( + log_n, + basek, + k_out, + k_in, + k_ksk, + di, + rank_in_s0s1, + rank_out_s0s1, + rank_out_s1s2, + 3.2, + ); + }) }) }); }); @@ -33,45 +61,82 @@ fn key_switch() { #[test] fn key_switch_inplace() { + let log_n: usize = 8; + let basek: usize = 12; + let k_ct: usize = 60; + let digits: usize = div_ceil(k_ct, basek); (1..4).for_each(|rank_in_s0s1| { (1..4).for_each(|rank_out_s0s1| { - println!( - "test key_switch_inplace : ({},{})", - rank_in_s0s1, rank_out_s0s1 - ); - test_key_switch_inplace(12, 15, 60, 3.2, rank_in_s0s1, rank_out_s0s1); + (1..digits + 1).for_each(|di| { + let k_ksk: usize = k_ct + basek * di; + println!( + "test key_switch_inplace digits: {} ranks: ({},{})", + di, rank_in_s0s1, rank_out_s0s1 + ); + test_key_switch_inplace( + log_n, + basek, + k_ct, + k_ksk, + di, + rank_in_s0s1, + rank_out_s0s1, + 3.2, + ); + }); }); }); } #[test] fn external_product() { + let log_n: usize = 8; + let basek: usize = 12; + let k_in: usize = 60; + let digits: usize = div_ceil(k_in, basek); (1..4).for_each(|rank_in| { (1..4).for_each(|rank_out| { - println!("test external_product rank: {} {}", rank_in, rank_out); - test_external_product(12, 12, 60, 3.2, rank_in, rank_out); + (1..digits + 1).for_each(|di| { + let k_ggsw: usize = k_in + basek * di; + println!( + "test external_product digits: {} ranks: ({} {})", + di, rank_in, rank_out + ); + let k_out: usize = k_in; // Better capture noise. + test_external_product( + log_n, basek, k_out, k_in, k_ggsw, di, rank_in, rank_out, 3.2, + ); + }); }); }); } #[test] fn external_product_inplace() { + let log_n: usize = 5; + let basek: usize = 12; + let k_ct: usize = 60; + let digits: usize = div_ceil(k_ct, basek); (1..4).for_each(|rank_in| { (1..4).for_each(|rank_out| { - println!( - "test external_product_inplace rank: {} {}", - rank_in, rank_out - ); - test_external_product_inplace(12, 12, 60, 3.2, rank_in, rank_out); + (1..digits).for_each(|di| { + let k_ggsw: usize = k_ct + basek * di; + println!( + "test external_product_inplace digits: {} ranks: ({} {})", + di, rank_in, rank_out + ); + test_external_product_inplace(log_n, basek, k_ct, k_ggsw, di, rank_in, rank_out, 3.2); + }); }); }); } -fn test_encrypt_sk(log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank_in: usize, rank_out: usize) { +fn test_encrypt_sk(log_n: usize, basek: usize, k_ksk: usize, digits: usize, rank_in: usize, rank_out: usize, sigma: f64) { let module: Module = Module::::new(1 << log_n); - let rows = (k_ksk + basek - 1) / basek; + let rows: usize = (k_ksk - digits * basek) / (digits * basek); - let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, rank_in, rank_out); + let mut ksk: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank_in, rank_out); let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ksk); let mut source_xs: Source = Source::new([0u8; 32]); @@ -106,9 +171,15 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank_in (0..ksk.rows()).for_each(|row_i| { ksk.get_row(&module, row_i, col_i, &mut ct_glwe_fourier); ct_glwe_fourier.decrypt(&module, &mut pt, &sk_out, scratch.borrow()); - module.vec_znx_sub_scalar_inplace(&mut pt.data, 0, row_i, &sk_in.data, col_i); + module.vec_znx_sub_scalar_inplace( + &mut pt.data, + 0, + (digits - 1) + row_i * digits, + &sk_in.data, + col_i, + ); let std_pt: f64 = pt.data.std(0, basek) * (k_ksk as f64).exp2(); - assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); + assert!((sigma - std_pt).abs() <= 0.5, "{} {}", sigma, std_pt); }); }); } @@ -116,21 +187,46 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank_in fn test_key_switch( log_n: usize, basek: usize, + k_out: usize, + k_in: usize, k_ksk: usize, - sigma: f64, + digits: usize, rank_in_s0s1: usize, rank_out_s0s1: usize, rank_out_s1s2: usize, + sigma: f64, ) { let module: Module = Module::::new(1 << log_n); - let rows = (k_ksk + basek - 1) / basek; + let rows: usize = div_ceil(k_in, basek * digits); + let digits_in: usize = 1; - let mut ct_gglwe_s0s1: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, rank_in_s0s1, rank_out_s0s1); - let mut ct_gglwe_s1s2: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, rank_out_s0s1, rank_out_s1s2); - let mut ct_gglwe_s0s2: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, rank_in_s0s1, rank_out_s1s2); + let mut ct_gglwe_s0s1: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc( + &module, + basek, + k_in, + rows, + digits_in, + rank_in_s0s1, + rank_out_s0s1, + ); + let mut ct_gglwe_s1s2: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc( + &module, + basek, + k_ksk, + rows, + digits, + rank_out_s0s1, + rank_out_s1s2, + ); + let mut ct_gglwe_s0s2: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc( + &module, + basek, + k_out, + rows, + digits_in, + rank_in_s0s1, + rank_out_s1s2, + ); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -138,15 +234,16 @@ fn test_key_switch( let mut scratch: ScratchOwned = ScratchOwned::new( GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank_in_s0s1 | rank_out_s0s1) - | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_ksk) + | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_out) | GLWESwitchingKey::keyswitch_scratch_space( &module, basek, - ct_gglwe_s0s2.k(), - ct_gglwe_s0s2.rank(), - ct_gglwe_s0s1.k(), - ct_gglwe_s0s1.rank(), - ct_gglwe_s1s2.k(), + k_out, + k_in, + k_ksk, + digits, + ct_gglwe_s1s2.rank_in(), + ct_gglwe_s1s2.rank_out(), ), ); @@ -185,31 +282,37 @@ fn test_key_switch( ct_gglwe_s0s2.keyswitch(&module, &ct_gglwe_s0s1, &ct_gglwe_s1s2, scratch.borrow()); let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::alloc(&module, basek, k_ksk, rank_out_s1s2); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ksk); + GLWECiphertextFourier::alloc(&module, basek, k_out, rank_out_s1s2); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); (0..ct_gglwe_s0s2.rank_in()).for_each(|col_i| { (0..ct_gglwe_s0s2.rows()).for_each(|row_i| { ct_gglwe_s0s2.get_row(&module, row_i, col_i, &mut ct_glwe_dft); ct_glwe_dft.decrypt(&module, &mut pt, &sk2, scratch.borrow()); - module.vec_znx_sub_scalar_inplace(&mut pt.data, 0, row_i, &sk0.data, col_i); + module.vec_znx_sub_scalar_inplace( + &mut pt.data, + 0, + (digits_in - 1) + row_i * digits_in, + &sk0.data, + col_i, + ); let noise_have: f64 = pt.data.std(0, basek).log2(); let noise_want: f64 = log2_std_noise_gglwe_product( module.n() as f64, - basek, + basek * digits, 0.5, 0.5, 0f64, sigma * sigma, 0f64, rank_out_s0s1 as f64, - k_ksk, + k_in, k_ksk, ); assert!( - (noise_have - noise_want).abs() <= 0.1, + (noise_have - noise_want).abs() <= 1.0, "{} {}", noise_have, noise_want @@ -218,38 +321,42 @@ fn test_key_switch( }); } -fn test_key_switch_inplace(log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank_in_s0s1: usize, rank_out_s0s1: usize) { +fn test_key_switch_inplace( + log_n: usize, + basek: usize, + k_ct: usize, + k_ksk: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + sigma: f64, +) { let module: Module = Module::::new(1 << log_n); - let rows: usize = (k_ksk + basek - 1) / basek; + let rows: usize = div_ceil(k_ct, basek * digits); + let digits_in: usize = 1; let mut ct_gglwe_s0s1: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, rank_in_s0s1, rank_out_s0s1); + GLWESwitchingKey::alloc(&module, basek, k_ct, rows, digits_in, rank_in, rank_out); let mut ct_gglwe_s1s2: GLWESwitchingKey, FFT64> = - GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, rank_out_s0s1, rank_out_s0s1); + GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank_out, rank_out); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank_out_s0s1) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank_out) | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_ksk) - | GLWESwitchingKey::keyswitch_inplace_scratch_space( - &module, - basek, - ct_gglwe_s0s1.k(), - ct_gglwe_s0s1.rank(), - ct_gglwe_s1s2.k(), - ), + | GLWESwitchingKey::keyswitch_inplace_scratch_space(&module, basek, k_ct, k_ksk, digits, rank_out), ); - let mut sk0: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_in_s0s1); + let mut sk0: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_in); sk0.fill_ternary_prob(&module, 0.5, &mut source_xs); - let mut sk1: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_out_s0s1); + let mut sk1: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_out); sk1.fill_ternary_prob(&module, 0.5, &mut source_xs); - let mut sk2: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_out_s0s1); + let mut sk2: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_out); sk2.fill_ternary_prob(&module, 0.5, &mut source_xs); // gglwe_{s1}(s0) = s0 -> s1 @@ -279,32 +386,37 @@ fn test_key_switch_inplace(log_n: usize, basek: usize, k_ksk: usize, sigma: f64, let ct_gglwe_s0s2: GLWESwitchingKey, FFT64> = ct_gglwe_s0s1; - let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::alloc(&module, basek, k_ksk, rank_out_s0s1); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ksk); + let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_ct, rank_out); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); (0..ct_gglwe_s0s2.rank_in()).for_each(|col_i| { (0..ct_gglwe_s0s2.rows()).for_each(|row_i| { ct_gglwe_s0s2.get_row(&module, row_i, col_i, &mut ct_glwe_dft); ct_glwe_dft.decrypt(&module, &mut pt, &sk2, scratch.borrow()); - module.vec_znx_sub_scalar_inplace(&mut pt.data, 0, row_i, &sk0.data, col_i); + module.vec_znx_sub_scalar_inplace( + &mut pt.data, + 0, + (digits_in - 1) + row_i * digits_in, + &sk0.data, + col_i, + ); let noise_have: f64 = pt.data.std(0, basek).log2(); let noise_want: f64 = log2_std_noise_gglwe_product( module.n() as f64, - basek, + basek * digits, 0.5, 0.5, 0f64, sigma * sigma, 0f64, - rank_out_s0s1 as f64, - k_ksk, + rank_out as f64, + k_ct, k_ksk, ); assert!( - (noise_have - noise_want).abs() <= 0.1, + (noise_have - noise_want).abs() <= 1.0, "{} {}", noise_have, noise_want @@ -313,14 +425,27 @@ fn test_key_switch_inplace(log_n: usize, basek: usize, k_ksk: usize, sigma: f64, }); } -fn test_external_product(log_n: usize, basek: usize, k: usize, sigma: f64, rank_in: usize, rank_out: usize) { +fn test_external_product( + log_n: usize, + basek: usize, + k_out: usize, + k_in: usize, + k_ggsw: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + sigma: f64, +) { let module: Module = Module::::new(1 << log_n); - let rows: usize = (k + basek - 1) / basek; + let rows: usize = div_ceil(k_in, basek * digits); + let digits_in: usize = 1; - let mut ct_gglwe_in: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k, rows, rank_in, rank_out); - let mut ct_gglwe_out: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k, rows, rank_in, rank_out); - let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k, rows, rank_out); + let mut ct_gglwe_in: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::alloc(&module, basek, k_in, rows, digits_in, rank_in, rank_out); + let mut ct_gglwe_out: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::alloc(&module, basek, k_out, rows, digits_in, rank_in, rank_out); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank_out); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); @@ -329,17 +454,10 @@ fn test_external_product(log_n: usize, basek: usize, k: usize, sigma: f64, rank_ let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k, rank_out) - | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k) - | GLWESwitchingKey::external_product_scratch_space( - &module, - basek, - ct_gglwe_out.k(), - ct_gglwe_in.k(), - ct_rgsw.k(), - rank_out, - ) - | GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k, rank_out), + GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_in, rank_out) + | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_out) + | GLWESwitchingKey::external_product_scratch_space(&module, basek, k_out, k_in, k_ggsw, digits, rank_out) + | GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_ggsw, rank_out), ); let r: usize = 1; @@ -376,22 +494,8 @@ fn test_external_product(log_n: usize, basek: usize, k: usize, sigma: f64, rank_ // gglwe_(m) (x) RGSW_(X^k) = gglwe_(m * X^k) ct_gglwe_out.external_product(&module, &ct_gglwe_in, &ct_rgsw, scratch.borrow()); - scratch = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k, rank_out) - | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k) - | GLWESwitchingKey::external_product_scratch_space( - &module, - basek, - ct_gglwe_out.k(), - ct_gglwe_in.k(), - ct_rgsw.k(), - rank_out, - ) - | GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k, rank_out), - ); - - let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k, rank_out); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k); + let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_out, rank_out); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); (0..rank_in).for_each(|i| { module.vec_znx_rotate_inplace(r as i64, &mut sk_in.data, i); // * X^{r} @@ -401,7 +505,14 @@ fn test_external_product(log_n: usize, basek: usize, k: usize, sigma: f64, rank_ (0..ct_gglwe_out.rows()).for_each(|row_i| { ct_gglwe_out.get_row(&module, row_i, col_i, &mut ct_glwe_dft); ct_glwe_dft.decrypt(&module, &mut pt, &sk_out, scratch.borrow()); - module.vec_znx_sub_scalar_inplace(&mut pt.data, 0, row_i, &sk_in.data, col_i); + + module.vec_znx_sub_scalar_inplace( + &mut pt.data, + 0, + (digits_in - 1) + row_i * digits_in, + &sk_in.data, + col_i, + ); let noise_have: f64 = pt.data.std(0, basek).log2(); @@ -414,7 +525,7 @@ fn test_external_product(log_n: usize, basek: usize, k: usize, sigma: f64, rank_ let noise_want: f64 = noise_ggsw_product( module.n() as f64, - basek, + basek * digits, 0.5, var_msg, var_a0_err, @@ -422,12 +533,12 @@ fn test_external_product(log_n: usize, basek: usize, k: usize, sigma: f64, rank_ var_gct_err_lhs, var_gct_err_rhs, rank_out as f64, - k, - k, + k_in, + k_ggsw, ); assert!( - (noise_have - noise_want).abs() <= 0.1, + (noise_have - noise_want).abs() <= 1.0, "{} {}", noise_have, noise_want @@ -436,13 +547,25 @@ fn test_external_product(log_n: usize, basek: usize, k: usize, sigma: f64, rank_ }); } -fn test_external_product_inplace(log_n: usize, basek: usize, k: usize, sigma: f64, rank_in: usize, rank_out: usize) { +fn test_external_product_inplace( + log_n: usize, + basek: usize, + k_ct: usize, + k_ggsw: usize, + digits: usize, + rank_in: usize, + rank_out: usize, + sigma: f64, +) { let module: Module = Module::::new(1 << log_n); - let rows: usize = (k + basek - 1) / basek; + let rows: usize = div_ceil(k_ct, basek * digits); - let mut ct_gglwe: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k, rows, rank_in, rank_out); - let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k, rows, rank_out); + let digits_in: usize = 1; + + let mut ct_gglwe: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::alloc(&module, basek, k_ct, rows, digits_in, rank_in, rank_out); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank_out); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); @@ -451,10 +574,10 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k: usize, sigma: f6 let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k, rank_out) - | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k) - | GLWESwitchingKey::external_product_inplace_scratch_space(&module, basek, k, k, rank_out) - | GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k, rank_out), + GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ct, rank_out) + | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_ct) + | GLWESwitchingKey::external_product_inplace_scratch_space(&module, basek, k_ct, k_ggsw, digits, rank_out) + | GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_ggsw, rank_out), ); let r: usize = 1; @@ -491,8 +614,8 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k: usize, sigma: f6 // gglwe_(m) (x) RGSW_(X^k) = gglwe_(m * X^k) ct_gglwe.external_product_inplace(&module, &ct_rgsw, scratch.borrow()); - let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k, rank_out); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k); + let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_ct, rank_out); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); (0..rank_in).for_each(|i| { module.vec_znx_rotate_inplace(r as i64, &mut sk_in.data, i); // * X^{r} @@ -502,7 +625,14 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k: usize, sigma: f6 (0..ct_gglwe.rows()).for_each(|row_i| { ct_gglwe.get_row(&module, row_i, col_i, &mut ct_glwe_dft); ct_glwe_dft.decrypt(&module, &mut pt, &sk_out, scratch.borrow()); - module.vec_znx_sub_scalar_inplace(&mut pt.data, 0, row_i, &sk_in.data, col_i); + + module.vec_znx_sub_scalar_inplace( + &mut pt.data, + 0, + (digits_in - 1) + row_i * digits_in, + &sk_in.data, + col_i, + ); let noise_have: f64 = pt.data.std(0, basek).log2(); @@ -515,7 +645,7 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k: usize, sigma: f6 let noise_want: f64 = noise_ggsw_product( module.n() as f64, - basek, + basek * digits, 0.5, var_msg, var_a0_err, @@ -523,12 +653,12 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k: usize, sigma: f6 var_gct_err_lhs, var_gct_err_rhs, rank_out as f64, - k, - k, + k_ct, + k_ggsw, ); assert!( - (noise_have - noise_want).abs() <= 0.1, + (noise_have - noise_want).abs() <= 1.0, "{} {}", noise_have, noise_want @@ -536,61 +666,3 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k: usize, sigma: f6 }); }); } - -pub(crate) fn var_noise_gglwe_product( - n: f64, - basek: usize, - var_xs: f64, - var_msg: f64, - var_a_err: f64, - var_gct_err_lhs: f64, - var_gct_err_rhs: f64, - rank_in: f64, - a_logq: usize, - b_logq: usize, -) -> f64 { - let a_logq: usize = a_logq.min(b_logq); - let a_cols: usize = (a_logq + basek - 1) / basek; - - let b_scale = 2.0f64.powi(b_logq as i32); - let a_scale: f64 = 2.0f64.powi((b_logq - a_logq) as i32); - - let base: f64 = (1 << (basek)) as f64; - let var_base: f64 = base * base / 12f64; - - // lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2) - // rhs = a_cols * n * var_base * var_gct_err_rhs * var_xs - let mut noise: f64 = (a_cols as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs); - noise += var_msg * var_a_err * a_scale * a_scale * n; - noise *= rank_in; - noise /= b_scale * b_scale; - noise -} - -pub(crate) fn log2_std_noise_gglwe_product( - n: f64, - basek: usize, - var_xs: f64, - var_msg: f64, - var_a_err: f64, - var_gct_err_lhs: f64, - var_gct_err_rhs: f64, - rank_in: f64, - a_logq: usize, - b_logq: usize, -) -> f64 { - let mut noise: f64 = var_noise_gglwe_product( - n, - basek, - var_xs, - var_msg, - var_a_err, - var_gct_err_lhs, - var_gct_err_rhs, - rank_in, - a_logq, - b_logq, - ); - noise = noise.sqrt(); - noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}] -} diff --git a/core/src/test_fft64/ggsw.rs b/core/src/test_fft64/ggsw.rs index 0495431..1d262a7 100644 --- a/core/src/test_fft64/ggsw.rs +++ b/core/src/test_fft64/ggsw.rs @@ -7,72 +7,127 @@ use sampling::source::Source; use crate::{ GGSWCiphertext, GLWECiphertextFourier, GLWEPlaintext, GLWESecret, GLWESwitchingKey, GetRow, Infos, TensorKey, automorphism::AutomorphismKey, + div_ceil, + test_fft64::{noise_ggsw_keyswitch, noise_ggsw_product}, }; -use super::gglwe::var_noise_gglwe_product; - #[test] fn encrypt_sk() { + let log_n: usize = 8; + let basek: usize = 12; + let k_ct: usize = 54; + let digits: usize = k_ct / basek; (1..4).for_each(|rank| { - println!("test encrypt_sk rank: {}", rank); - test_encrypt_sk(11, 8, 54, 3.2, rank); + (1..digits + 1).for_each(|di| { + println!("test encrypt_sk digits: {} rank: {}", di, rank); + test_encrypt_sk(log_n, basek, k_ct, di, rank, 3.2); + }); }); } #[test] fn keyswitch() { + let log_n: usize = 8; + let basek: usize = 12; + let k_in: usize = 54; + let digits: usize = div_ceil(k_in, basek); (1..4).for_each(|rank| { - println!("test keyswitch rank: {}", rank); - test_keyswitch(12, 15, 60, rank, 3.2); + (1..digits + 1).for_each(|di| { + let k_ksk: usize = k_in + basek * di; + let k_tsk: usize = k_ksk; + println!("test keyswitch digits: {} rank: {}", di, rank); + let k_out: usize = k_ksk; // Better capture noise. + test_keyswitch(log_n, basek, k_out, k_in, k_ksk, k_tsk, di, rank, 3.2); + }); }); } #[test] fn keyswitch_inplace() { + let log_n: usize = 8; + let basek: usize = 12; + let k_ct: usize = 54; + let digits: usize = div_ceil(k_ct, basek); (1..4).for_each(|rank| { - println!("test keyswitch_inplace rank: {}", rank); - test_keyswitch_inplace(12, 15, 60, rank, 3.2); + (1..digits + 1).for_each(|di| { + let k_ksk: usize = k_ct + basek * di; + let k_tsk: usize = k_ksk; + println!("test keyswitch_inplace digits: {} rank: {}", di, rank); + test_keyswitch_inplace(log_n, basek, k_ct, k_ksk, k_tsk, di, rank, 3.2); + }); }); } #[test] fn automorphism() { + let log_n: usize = 8; + let basek: usize = 12; + let k_in: usize = 54; + let digits: usize = div_ceil(k_in, basek); (1..4).for_each(|rank| { - println!("test automorphism rank: {}", rank); - test_automorphism(-5, 12, 15, 60, rank, 3.2); + (1..digits + 1).for_each(|di| { + let k_ksk: usize = k_in + basek * di; + let k_tsk: usize = k_ksk; + println!("test automorphism rank: {}", rank); + let k_out: usize = k_ksk; // Better capture noise. + test_automorphism(-5, log_n, basek, k_out, k_in, k_ksk, k_tsk, di, rank, 3.2); + }); }); } #[test] fn automorphism_inplace() { + let log_n: usize = 8; + let basek: usize = 12; + let k_ct: usize = 54; + let digits: usize = div_ceil(k_ct, basek); (1..4).for_each(|rank| { - println!("test automorphism_inplace rank: {}", rank); - test_automorphism_inplace(-5, 12, 15, 60, rank, 3.2); + (1..digits + 1).for_each(|di| { + let k_ksk: usize = k_ct + basek * di; + let k_tsk: usize = k_ksk; + println!("test automorphism_inplace rank: {}", rank); + test_automorphism_inplace(-5, log_n, basek, k_ct, k_ksk, k_tsk, di, rank, 3.2); + }); }); } #[test] fn external_product() { + let log_n: usize = 8; + let basek: usize = 12; + let k_in: usize = 60; + let digits: usize = div_ceil(k_in, basek); (1..4).for_each(|rank| { - println!("test external_product rank: {}", rank); - test_external_product(12, 12, 60, rank, 3.2); + (1..digits + 1).for_each(|di| { + let k_ggsw: usize = k_in + basek * di; + println!("test external_product digits: {} ranks: {}", di, rank); + let k_out: usize = k_in; // Better capture noise. + test_external_product(log_n, basek, k_in, k_out, k_ggsw, di, rank, 3.2); + }); }); } #[test] fn external_product_inplace() { + let log_n: usize = 5; + let basek: usize = 12; + let k_ct: usize = 60; + let digits: usize = div_ceil(k_ct, basek); (1..4).for_each(|rank| { - println!("test external_product rank: {}", rank); - test_external_product_inplace(12, 15, 60, rank, 3.2); + (1..digits).for_each(|di| { + let k_ggsw: usize = k_ct + basek * di; + println!("test external_product digits: {} rank: {}", di, rank); + test_external_product_inplace(log_n, basek, k_ct, k_ggsw, di, rank, 3.2); + }); }); } -fn test_encrypt_sk(log_n: usize, basek: usize, k: usize, sigma: f64, rank: usize) { +fn test_encrypt_sk(log_n: usize, basek: usize, k: usize, digits: usize, rank: usize, sigma: f64) { let module: Module = Module::::new(1 << log_n); - let rows: usize = (k + basek - 1) / basek; + let rows: usize = (k - digits * basek) / (digits * basek); - let mut ct: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k, rows, rank); + let mut ct: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k, rows, digits, rank); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k); let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); @@ -107,11 +162,17 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k: usize, sigma: f64, rank: usize (0..ct.rank() + 1).for_each(|col_j| { (0..ct.rows()).for_each(|row_i| { - module.vec_znx_add_scalar_inplace(&mut pt_want.data, 0, row_i, &pt_scalar, 0); + module.vec_znx_add_scalar_inplace( + &mut pt_want.data, + 0, + (digits - 1) + row_i * digits, + &pt_scalar, + 0, + ); // mul with sk[col_j-1] if col_j > 0 { - module.vec_znx_dft(&mut pt_dft, 0, &pt_want.data, 0); + module.vec_znx_dft(1, 0, &mut pt_dft, 0, &pt_want.data, 0); module.svp_apply_inplace(&mut pt_dft, 0, &sk.data_fourier, col_j - 1); module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); module.vec_znx_big_normalize(basek, &mut pt_want.data, 0, &pt_big, 0, scratch.borrow()); @@ -124,23 +185,35 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k: usize, sigma: f64, rank: usize module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); let std_pt: f64 = pt_have.data.std(0, basek) * (k as f64).exp2(); - assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); + assert!((sigma - std_pt).abs() <= 0.5, "{} {}", sigma, std_pt); pt_want.data.zero(); }); }); } -fn test_keyswitch(log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) { +fn test_keyswitch( + log_n: usize, + basek: usize, + k_out: usize, + k_in: usize, + k_ksk: usize, + k_tsk: usize, + digits: usize, + rank: usize, + sigma: f64, +) { let module: Module = Module::::new(1 << log_n); - let rows: usize = (k + basek - 1) / basek; + let rows: usize = div_ceil(k_in, digits * basek); - let mut ct_in: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k, rows, rank); - let mut ct_out: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k, rows, rank); - let mut tsk: TensorKey, FFT64> = TensorKey::alloc(&module, basek, k, rows, rank); - let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k, rows, rank, rank); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k); + let digits_in: usize = 1; + + let mut ct_in: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_in, rows, digits_in, rank); + let mut ct_out: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_out, rows, digits_in, rank); + let mut tsk: TensorKey, FFT64> = TensorKey::alloc(&module, basek, k_ksk, rows, digits, rank); + let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank, rank); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); let mut source_xs: Source = Source::new([0u8; 32]); @@ -148,18 +221,12 @@ fn test_keyswitch(log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k, rank) - | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k) - | GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k, rank) - | TensorKey::generate_from_sk_scratch_space(&module, basek, k, rank) + GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_in, rank) + | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_out) + | GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank) + | TensorKey::generate_from_sk_scratch_space(&module, basek, k_tsk, rank) | GGSWCiphertext::keyswitch_scratch_space( - &module, - basek, - ct_out.k(), - ct_in.k(), - ksk.k(), - tsk.k(), - rank, + &module, basek, k_out, k_in, k_ksk, digits, k_tsk, digits, rank, ), ); @@ -203,7 +270,7 @@ fn test_keyswitch(log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) ct_out.keyswitch(&module, &ct_in, &ksk, &tsk, scratch.borrow()); - let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k, rank); + let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_out, rank); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_out.size()); let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_out.size()); @@ -213,7 +280,7 @@ fn test_keyswitch(log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) // mul with sk[col_j-1] if col_j > 0 { - module.vec_znx_dft(&mut pt_dft, 0, &pt_want.data, 0); + module.vec_znx_dft(1, 0, &mut pt_dft, 0, &pt_want.data, 0); module.svp_apply_inplace(&mut pt_dft, 0, &sk_out.data_fourier, col_j - 1); module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); module.vec_znx_big_normalize(basek, &mut pt_want.data, 0, &pt_big, 0, scratch.borrow()); @@ -228,21 +295,22 @@ fn test_keyswitch(log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) let noise_have: f64 = pt_have.data.std(0, basek).log2(); let noise_want: f64 = noise_ggsw_keyswitch( module.n() as f64, - basek, + basek * digits, col_j, var_xs, 0f64, sigma * sigma, 0f64, rank as f64, - k, - k, + k_in, + k_ksk, + k_tsk, ); println!("{} {}", noise_have, noise_want); assert!( - (noise_have - noise_want).abs() <= 0.1, + noise_have < noise_want + 0.5, "{} {}", noise_have, noise_want @@ -253,15 +321,26 @@ fn test_keyswitch(log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) }); } -fn test_keyswitch_inplace(log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) { +fn test_keyswitch_inplace( + log_n: usize, + basek: usize, + k_ct: usize, + k_ksk: usize, + k_tsk: usize, + digits: usize, + rank: usize, + sigma: f64, +) { let module: Module = Module::::new(1 << log_n); - let rows: usize = (k + basek - 1) / basek; + let rows: usize = div_ceil(k_ct, digits * basek); - let mut ct: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k, rows, rank); - let mut tsk: TensorKey, FFT64> = TensorKey::alloc(&module, basek, k, rows, rank); - let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k, rows, rank, rank); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k); + let digits_in: usize = 1; + + let mut ct: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ct, rows, digits_in, rank); + let mut tsk: TensorKey, FFT64> = TensorKey::alloc(&module, basek, k_tsk, rows, digits, rank); + let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank, rank); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); let mut source_xs: Source = Source::new([0u8; 32]); @@ -269,11 +348,11 @@ fn test_keyswitch_inplace(log_n: usize, basek: usize, k: usize, rank: usize, sig let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k, rank) - | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k) - | GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k, rank) - | TensorKey::generate_from_sk_scratch_space(&module, basek, k, rank) - | GGSWCiphertext::keyswitch_inplace_scratch_space(&module, basek, ct.k(), ksk.k(), tsk.k(), rank), + GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_ct, rank) + | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_ct) + | GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank) + | TensorKey::generate_from_sk_scratch_space(&module, basek, k_tsk, rank) + | GGSWCiphertext::keyswitch_inplace_scratch_space(&module, basek, k_ct, k_ksk, digits, k_tsk, digits, rank), ); let var_xs: f64 = 0.5; @@ -316,17 +395,23 @@ fn test_keyswitch_inplace(log_n: usize, basek: usize, k: usize, rank: usize, sig ct.keyswitch_inplace(&module, &ksk, &tsk, scratch.borrow()); - let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k, rank); + let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_ct, rank); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct.size()); let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct.size()); (0..ct.rank() + 1).for_each(|col_j| { (0..ct.rows()).for_each(|row_i| { - module.vec_znx_add_scalar_inplace(&mut pt_want.data, 0, row_i, &pt_scalar, 0); + module.vec_znx_add_scalar_inplace( + &mut pt_want.data, + 0, + (digits_in - 1) + row_i * digits_in, + &pt_scalar, + 0, + ); // mul with sk[col_j-1] if col_j > 0 { - module.vec_znx_dft(&mut pt_dft, 0, &pt_want.data, 0); + module.vec_znx_dft(1, 0, &mut pt_dft, 0, &pt_want.data, 0); module.svp_apply_inplace(&mut pt_dft, 0, &sk_out.data_fourier, col_j - 1); module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); module.vec_znx_big_normalize(basek, &mut pt_want.data, 0, &pt_big, 0, scratch.borrow()); @@ -341,21 +426,22 @@ fn test_keyswitch_inplace(log_n: usize, basek: usize, k: usize, rank: usize, sig let noise_have: f64 = pt_have.data.std(0, basek).log2(); let noise_want: f64 = noise_ggsw_keyswitch( module.n() as f64, - basek, + basek * digits, col_j, var_xs, 0f64, sigma * sigma, 0f64, rank as f64, - k, - k, + k_ct, + k_ksk, + k_tsk, ); println!("{} {}", noise_have, noise_want); assert!( - (noise_have - noise_want).abs() <= 0.1, + noise_have < noise_want + 0.5, "{} {}", noise_have, noise_want @@ -366,65 +452,30 @@ fn test_keyswitch_inplace(log_n: usize, basek: usize, k: usize, rank: usize, sig }); } -pub(crate) fn noise_ggsw_keyswitch( - n: f64, +fn test_automorphism( + p: i64, + log_n: usize, basek: usize, - col: usize, - var_xs: f64, - var_a_err: f64, - var_gct_err_lhs: f64, - var_gct_err_rhs: f64, - rank: f64, - a_logq: usize, - b_logq: usize, -) -> f64 { - let var_si_x_sj: f64 = n * var_xs * var_xs; - - // Initial KS for col = 0 - let mut noise: f64 = var_noise_gglwe_product( - n, - basek, - var_xs, - var_xs, - var_a_err, - var_gct_err_lhs, - var_gct_err_rhs, - rank, - a_logq, - b_logq, - ); - - // Other GGSW reconstruction for col > 0 - if col > 0 { - noise += var_noise_gglwe_product( - n, - basek, - var_xs, - var_si_x_sj, - var_a_err + 1f64 / 12.0, - var_gct_err_lhs, - var_gct_err_rhs, - rank, - a_logq, - b_logq, - ); - noise += n * noise * var_xs * 0.5; - } - - noise = noise.sqrt(); - noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}] -} - -fn test_automorphism(p: i64, log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) { + k_out: usize, + k_in: usize, + k_ksk: usize, + k_tsk: usize, + digits: usize, + rank: usize, + sigma: f64, +) { let module: Module = Module::::new(1 << log_n); - let rows: usize = (k + basek - 1) / basek; + let rows: usize = div_ceil(k_in, basek * digits); + let rows_in: usize = k_in / (basek * digits); - let mut ct_in: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k, rows, rank); - let mut ct_out: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k, rows, rank); - let mut tensor_key: TensorKey, FFT64> = TensorKey::alloc(&module, basek, k, rows, rank); - let mut auto_key: AutomorphismKey, FFT64> = AutomorphismKey::alloc(&module, basek, k, rows, rank); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k); + let digits_in: usize = 1; + + let mut ct_in: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_in, rows_in, digits_in, rank); + let mut ct_out: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_out, rows_in, digits_in, rank); + let mut tensor_key: TensorKey, FFT64> = TensorKey::alloc(&module, basek, k_tsk, rows, digits, rank); + let mut auto_key: AutomorphismKey, FFT64> = AutomorphismKey::alloc(&module, basek, k_ksk, rows, digits, rank); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); let mut source_xs: Source = Source::new([0u8; 32]); @@ -432,18 +483,12 @@ fn test_automorphism(p: i64, log_n: usize, basek: usize, k: usize, rank: usize, let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k, rank) - | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k) - | AutomorphismKey::generate_from_sk_scratch_space(&module, basek, k, rank) - | TensorKey::generate_from_sk_scratch_space(&module, basek, k, rank) + GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_in, rank) + | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_out) + | AutomorphismKey::generate_from_sk_scratch_space(&module, basek, k_ksk, rank) + | TensorKey::generate_from_sk_scratch_space(&module, basek, k_tsk, rank) | GGSWCiphertext::automorphism_scratch_space( - &module, - basek, - ct_out.k(), - ct_in.k(), - auto_key.k(), - tensor_key.k(), - rank, + &module, basek, k_out, k_in, k_ksk, digits, k_tsk, digits, rank, ), ); @@ -486,7 +531,7 @@ fn test_automorphism(p: i64, log_n: usize, basek: usize, k: usize, rank: usize, module.scalar_znx_automorphism_inplace(p, &mut pt_scalar, 0); - let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k, rank); + let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_out, rank); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_out.size()); let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_out.size()); @@ -496,7 +541,7 @@ fn test_automorphism(p: i64, log_n: usize, basek: usize, k: usize, rank: usize, // mul with sk[col_j-1] if col_j > 0 { - module.vec_znx_dft(&mut pt_dft, 0, &pt_want.data, 0); + module.vec_znx_dft(1, 0, &mut pt_dft, 0, &pt_want.data, 0); module.svp_apply_inplace(&mut pt_dft, 0, &sk.data_fourier, col_j - 1); module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); module.vec_znx_big_normalize(basek, &mut pt_want.data, 0, &pt_big, 0, scratch.borrow()); @@ -511,19 +556,20 @@ fn test_automorphism(p: i64, log_n: usize, basek: usize, k: usize, rank: usize, let noise_have: f64 = pt_have.data.std(0, basek).log2(); let noise_want: f64 = noise_ggsw_keyswitch( module.n() as f64, - basek, + basek * digits, col_j, var_xs, 0f64, sigma * sigma, 0f64, rank as f64, - k, - k, + k_in, + k_ksk, + k_tsk, ); assert!( - (noise_have - noise_want).abs() <= 0.1, + noise_have < noise_want + 0.5, "{} {}", noise_have, noise_want @@ -534,15 +580,27 @@ fn test_automorphism(p: i64, log_n: usize, basek: usize, k: usize, rank: usize, }); } -fn test_automorphism_inplace(p: i64, log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) { +fn test_automorphism_inplace( + p: i64, + log_n: usize, + basek: usize, + k_ct: usize, + k_ksk: usize, + k_tsk: usize, + digits: usize, + rank: usize, + sigma: f64, +) { let module: Module = Module::::new(1 << log_n); - let rows: usize = (k + basek - 1) / basek; + let rows: usize = div_ceil(k_ct, digits * basek); + let rows_in: usize = k_ct / (basek * digits); + let digits_in: usize = 1; - let mut ct: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k, rows, rank); - let mut tensor_key: TensorKey, FFT64> = TensorKey::alloc(&module, basek, k, rows, rank); - let mut auto_key: AutomorphismKey, FFT64> = AutomorphismKey::alloc(&module, basek, k, rows, rank); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k); + let mut ct: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ct, rows_in, digits_in, rank); + let mut tensor_key: TensorKey, FFT64> = TensorKey::alloc(&module, basek, k_tsk, rows, digits, rank); + let mut auto_key: AutomorphismKey, FFT64> = AutomorphismKey::alloc(&module, basek, k_ksk, rows, digits, rank); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); let mut source_xs: Source = Source::new([0u8; 32]); @@ -550,11 +608,11 @@ fn test_automorphism_inplace(p: i64, log_n: usize, basek: usize, k: usize, rank: let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k, rank) - | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k) - | AutomorphismKey::generate_from_sk_scratch_space(&module, basek, k, rank) - | TensorKey::generate_from_sk_scratch_space(&module, basek, k, rank) - | GGSWCiphertext::automorphism_inplace_scratch_space(&module, basek, ct.k(), auto_key.k(), tensor_key.k(), rank), + GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_ct, rank) + | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_ct) + | AutomorphismKey::generate_from_sk_scratch_space(&module, basek, k_ksk, rank) + | TensorKey::generate_from_sk_scratch_space(&module, basek, k_tsk, rank) + | GGSWCiphertext::automorphism_inplace_scratch_space(&module, basek, k_ct, k_ksk, digits, k_tsk, digits, rank), ); let var_xs: f64 = 0.5; @@ -596,7 +654,7 @@ fn test_automorphism_inplace(p: i64, log_n: usize, basek: usize, k: usize, rank: module.scalar_znx_automorphism_inplace(p, &mut pt_scalar, 0); - let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k, rank); + let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_ct, rank); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct.size()); let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct.size()); @@ -606,7 +664,7 @@ fn test_automorphism_inplace(p: i64, log_n: usize, basek: usize, k: usize, rank: // mul with sk[col_j-1] if col_j > 0 { - module.vec_znx_dft(&mut pt_dft, 0, &pt_want.data, 0); + module.vec_znx_dft(1, 0, &mut pt_dft, 0, &pt_want.data, 0); module.svp_apply_inplace(&mut pt_dft, 0, &sk.data_fourier, col_j - 1); module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); module.vec_znx_big_normalize(basek, &mut pt_want.data, 0, &pt_big, 0, scratch.borrow()); @@ -621,19 +679,20 @@ fn test_automorphism_inplace(p: i64, log_n: usize, basek: usize, k: usize, rank: let noise_have: f64 = pt_have.data.std(0, basek).log2(); let noise_want: f64 = noise_ggsw_keyswitch( module.n() as f64, - basek, + basek * digits, col_j, var_xs, 0f64, sigma * sigma, 0f64, rank as f64, - k, - k, + k_ct, + k_ksk, + k_tsk, ); assert!( - (noise_have - noise_want).abs() <= 0.1, + noise_have <= noise_want + 0.5, "{} {}", noise_have, noise_want @@ -644,14 +703,27 @@ fn test_automorphism_inplace(p: i64, log_n: usize, basek: usize, k: usize, rank: }); } -fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, rank: usize, sigma: f64) { +fn test_external_product( + log_n: usize, + basek: usize, + k_in: usize, + k_out: usize, + k_ggsw: usize, + digits: usize, + rank: usize, + sigma: f64, +) { let module: Module = Module::::new(1 << log_n); - let rows: usize = (k_ggsw + basek - 1) / basek; + let rows: usize = div_ceil(k_in, basek * digits); + let rows_in: usize = k_in / (basek * digits); + let digits_in: usize = 1; - let mut ct_ggsw_rhs: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, rank); - let mut ct_ggsw_lhs_in: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, rank); - let mut ct_ggsw_lhs_out: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, rank); + let mut ct_ggsw_lhs_in: GGSWCiphertext, FFT64> = + GGSWCiphertext::alloc(&module, basek, k_in, rows_in, digits_in, rank); + let mut ct_ggsw_lhs_out: GGSWCiphertext, FFT64> = + GGSWCiphertext::alloc(&module, basek, k_out, rows_in, digits_in, rank); + let mut ct_ggsw_rhs: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank); let mut pt_ggsw_lhs: ScalarZnx> = module.new_scalar_znx(1); let mut pt_ggsw_rhs: ScalarZnx> = module.new_scalar_znx(1); @@ -666,16 +738,9 @@ fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, rank: usize, pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} let mut scratch: ScratchOwned = ScratchOwned::new( - GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k) - | GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k, rank) - | GGSWCiphertext::external_product_scratch_space( - &module, - basek, - ct_ggsw_lhs_out.k(), - ct_ggsw_lhs_in.k(), - ct_ggsw_rhs.k(), - rank, - ), + GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_out) + | GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_ggsw, rank) + | GGSWCiphertext::external_product_scratch_space(&module, basek, k_out, k_in, k_ggsw, digits, rank), ); let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); @@ -703,20 +768,26 @@ fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, rank: usize, ct_ggsw_lhs_out.external_product(&module, &ct_ggsw_lhs_in, &ct_ggsw_rhs, scratch.borrow()); - let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_ggsw, rank); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ggsw); + let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_out, rank); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_ggsw_lhs_out.size()); let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_ggsw_lhs_out.size()); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ggsw); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); module.vec_znx_rotate_inplace(k as i64, &mut pt_ggsw_lhs, 0); (0..ct_ggsw_lhs_out.rank() + 1).for_each(|col_j| { (0..ct_ggsw_lhs_out.rows()).for_each(|row_i| { - module.vec_znx_add_scalar_inplace(&mut pt_want.data, 0, row_i, &pt_ggsw_lhs, 0); + module.vec_znx_add_scalar_inplace( + &mut pt_want.data, + 0, + (digits_in - 1) + row_i * digits_in, + &pt_ggsw_lhs, + 0, + ); if col_j > 0 { - module.vec_znx_dft(&mut pt_dft, 0, &pt_want.data, 0); + module.vec_znx_dft(1, 0, &mut pt_dft, 0, &pt_want.data, 0); module.svp_apply_inplace(&mut pt_dft, 0, &sk.data_fourier, col_j - 1); module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); module.vec_znx_big_normalize(basek, &mut pt_want.data, 0, &pt_big, 0, scratch.borrow()); @@ -738,7 +809,7 @@ fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, rank: usize, let noise_want: f64 = noise_ggsw_product( module.n() as f64, - basek, + basek * digits, 0.5, var_msg, var_a0_err, @@ -746,28 +817,33 @@ fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, rank: usize, var_gct_err_lhs, var_gct_err_rhs, rank as f64, - k_ggsw, + k_in, k_ggsw, ); assert!( - (noise_have - noise_want).abs() <= 0.1, + noise_have <= noise_want + 0.5, "have: {} want: {}", noise_have, noise_want ); + println!("{} {}", noise_have, noise_want); + pt_want.data.zero(); }); }); } -fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, rank: usize, sigma: f64) { +fn test_external_product_inplace(log_n: usize, basek: usize, k_ct: usize, k_ggsw: usize, digits: usize, rank: usize, sigma: f64) { let module: Module = Module::::new(1 << log_n); - let rows: usize = (k_ggsw + basek - 1) / basek; + let rows: usize = div_ceil(k_ct, digits * basek); + let rows_in: usize = k_ct / (basek * digits); + let digits_in: usize = 1; + + let mut ct_ggsw_lhs: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ct, rows_in, digits_in, rank); + let mut ct_ggsw_rhs: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank); - let mut ct_ggsw_rhs: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, rank); - let mut ct_ggsw_lhs: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, rank); let mut pt_ggsw_lhs: ScalarZnx> = module.new_scalar_znx(1); let mut pt_ggsw_rhs: ScalarZnx> = module.new_scalar_znx(1); @@ -782,10 +858,9 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, rank pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k, rank) - | GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k) - | GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k, rank) - | GGSWCiphertext::external_product_inplace_scratch_space(&module, basek, ct_ggsw_lhs.k(), ct_ggsw_rhs.k(), rank), + GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_ct) + | GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, k_ggsw, rank) + | GGSWCiphertext::external_product_inplace_scratch_space(&module, basek, k_ct, k_ggsw, digits, rank), ); let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); @@ -813,20 +888,26 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, rank ct_ggsw_lhs.external_product_inplace(&module, &ct_ggsw_rhs, scratch.borrow()); - let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_ggsw, rank); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ggsw); + let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_ct, rank); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_ggsw_lhs.size()); let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_ggsw_lhs.size()); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ggsw); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); module.vec_znx_rotate_inplace(k as i64, &mut pt_ggsw_lhs, 0); (0..ct_ggsw_lhs.rank() + 1).for_each(|col_j| { (0..ct_ggsw_lhs.rows()).for_each(|row_i| { - module.vec_znx_add_scalar_inplace(&mut pt_want.data, 0, row_i, &pt_ggsw_lhs, 0); + module.vec_znx_add_scalar_inplace( + &mut pt_want.data, + 0, + (digits_in - 1) + row_i * digits_in, + &pt_ggsw_lhs, + 0, + ); if col_j > 0 { - module.vec_znx_dft(&mut pt_dft, 0, &pt_want.data, 0); + module.vec_znx_dft(1, 0, &mut pt_dft, 0, &pt_want.data, 0); module.svp_apply_inplace(&mut pt_dft, 0, &sk.data_fourier, col_j - 1); module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); module.vec_znx_big_normalize(basek, &mut pt_want.data, 0, &pt_big, 0, scratch.borrow()); @@ -848,7 +929,7 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, rank let noise_want: f64 = noise_ggsw_product( module.n() as f64, - basek, + basek * digits, 0.5, var_msg, var_a0_err, @@ -856,12 +937,12 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, rank var_gct_err_lhs, var_gct_err_rhs, rank as f64, - k_ggsw, + k_ct, k_ggsw, ); assert!( - (noise_have - noise_want).abs() <= 0.1, + noise_have <= noise_want + 0.5, "have: {} want: {}", noise_have, noise_want @@ -871,34 +952,3 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, rank }); }); } -pub(crate) fn noise_ggsw_product( - n: f64, - basek: usize, - var_xs: f64, - var_msg: f64, - var_a0_err: f64, - var_a1_err: f64, - var_gct_err_lhs: f64, - var_gct_err_rhs: f64, - rank: f64, - a_logq: usize, - b_logq: usize, -) -> f64 { - let a_logq: usize = a_logq.min(b_logq); - let a_cols: usize = (a_logq + basek - 1) / basek; - - let b_scale = 2.0f64.powi(b_logq as i32); - let a_scale: f64 = 2.0f64.powi((b_logq - a_logq) as i32); - - let base: f64 = (1 << (basek)) as f64; - let var_base: f64 = base * base / 12f64; - - // lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2) - // rhs = a_cols * n * var_base * var_gct_err_rhs * var_xs - let mut noise: f64 = (rank + 1.0) * (a_cols as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs); - noise += var_msg * var_a0_err * a_scale * a_scale * n; - noise += var_msg * var_a1_err * a_scale * a_scale * n * var_xs * rank; - noise = noise.sqrt(); - noise /= b_scale; - noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}] -} diff --git a/core/src/test_fft64/glwe.rs b/core/src/test_fft64/glwe.rs index 2a2c2b9..bc1ca9e 100644 --- a/core/src/test_fft64/glwe.rs +++ b/core/src/test_fft64/glwe.rs @@ -8,88 +8,140 @@ use sampling::source::Source; use crate::{ GGSWCiphertext, GLWECiphertext, GLWECiphertextFourier, GLWEPlaintext, GLWEPublicKey, GLWESecret, Infos, automorphism::AutomorphismKey, + div_ceil, keyswitch_key::GLWESwitchingKey, - test_fft64::{gglwe::log2_std_noise_gglwe_product, ggsw::noise_ggsw_product}, + test_fft64::{log2_std_noise_gglwe_product, noise_ggsw_product}, }; #[test] fn encrypt_sk() { + let log_n: usize = 8; (1..4).for_each(|rank| { println!("test encrypt_sk rank: {}", rank); - test_encrypt_sk(11, 8, 54, 30, 3.2, rank); + test_encrypt_sk(log_n, 8, 54, 30, 3.2, rank); }); } #[test] fn encrypt_zero_sk() { + let log_n: usize = 8; (1..4).for_each(|rank| { println!("test encrypt_zero_sk rank: {}", rank); - test_encrypt_zero_sk(11, 8, 64, 3.2, rank); + test_encrypt_zero_sk(log_n, 8, 64, 3.2, rank); }); } #[test] fn encrypt_pk() { + let log_n: usize = 8; (1..4).for_each(|rank| { println!("test encrypt_pk rank: {}", rank); - test_encrypt_pk(11, 8, 64, 64, 3.2, rank) + test_encrypt_pk(log_n, 8, 64, 64, 3.2, rank) }); } #[test] fn keyswitch() { - (1..4).for_each(|in_rank| { - (1..4).for_each(|out_rank| { - println!("test keyswitch in_rank: {} out_rank: {}", in_rank, out_rank); - test_keyswitch(12, 12, 60, 45, 60, in_rank, out_rank, 3.2); + let log_n: usize = 8; + let basek: usize = 12; + let k_in: usize = 45; + let digits: usize = div_ceil(k_in, basek); + (1..4).for_each(|rank_in| { + (1..4).for_each(|rank_out| { + (1..digits + 1).for_each(|di| { + let k_ksk: usize = k_in + basek * di; + let k_out: usize = k_ksk; // better capture noise + println!( + "test keyswitch digits: {} rank_in: {} rank_out: {}", + di, rank_in, rank_out + ); + test_keyswitch(log_n, basek, k_out, k_in, k_ksk, di, rank_in, rank_out, 3.2); + }) }); }); } #[test] fn keyswitch_inplace() { + let log_n: usize = 8; + let basek: usize = 12; + let k_ct: usize = 45; + let digits: usize = div_ceil(k_ct, basek); (1..4).for_each(|rank| { - println!("test keyswitch_inplace rank: {}", rank); - test_keyswitch_inplace(12, 12, 60, 45, rank, 3.2); + (1..digits + 1).for_each(|di| { + let k_ksk: usize = k_ct + basek * di; + println!("test keyswitch_inplace digits: {} rank: {}", di, rank); + test_keyswitch_inplace(log_n, basek, k_ct, k_ksk, di, rank, 3.2); + }); }); } #[test] fn external_product() { + let log_n: usize = 8; + let basek: usize = 12; + let k_in: usize = 45; + let digits: usize = div_ceil(k_in, basek); (1..4).for_each(|rank| { - println!("test external_product rank: {}", rank); - test_external_product(12, 12, 60, 45, 60, rank, 3.2); + (1..digits + 1).for_each(|di| { + let k_ggsw: usize = k_in + basek * di; + let k_out: usize = k_ggsw; // Better capture noise + println!("test external_product digits: {} rank: {}", di, rank); + test_external_product(log_n, basek, k_out, k_in, k_ggsw, di, rank, 3.2); + }); }); } #[test] fn external_product_inplace() { + let log_n: usize = 8; + let basek: usize = 12; + let k_ct: usize = 60; + let digits: usize = div_ceil(k_ct, basek); (1..4).for_each(|rank| { - println!("test external_product rank: {}", rank); - test_external_product_inplace(12, 15, 60, 60, rank, 3.2); + (1..digits + 1).for_each(|di| { + let k_ggsw: usize = k_ct + basek * di; + println!("test external_product digits: {} rank: {}", di, rank); + test_external_product_inplace(log_n, basek, k_ct, k_ggsw, di, rank, 3.2); + }); }); } #[test] fn automorphism_inplace() { + let log_n: usize = 8; + let basek: usize = 12; + let k_ct: usize = 60; + let digits: usize = div_ceil(k_ct, basek); (1..4).for_each(|rank| { - println!("test automorphism_inplace rank: {}", rank); - test_automorphism_inplace(12, 12, -5, 60, 60, rank, 3.2); + (1..digits + 1).for_each(|di| { + let k_ksk: usize = k_ct + basek * di; + println!("test automorphism_inplace digits: {} rank: {}", di, rank); + test_automorphism_inplace(log_n, basek, -5, k_ct, k_ksk, di, rank, 3.2); + }); }); } #[test] fn automorphism() { + let log_n: usize = 8; + let basek: usize = 12; + let k_in: usize = 60; + let digits: usize = div_ceil(k_in, basek); (1..4).for_each(|rank| { - println!("test automorphism rank: {}", rank); - test_automorphism(12, 12, -5, 60, 45, 60, rank, 3.2); + (1..digits + 1).for_each(|di| { + let k_ksk: usize = k_in + basek * di; + let k_out: usize = k_ksk; // Better capture noise. + println!("test automorphism digits: {} rank: {}", di, rank); + test_automorphism(log_n, basek, -5, k_out, k_in, k_ksk, di, rank, 3.2); + }) }); } -fn test_encrypt_sk(log_n: usize, basek: usize, ct_k: usize, k_pt: usize, sigma: f64, rank: usize) { +fn test_encrypt_sk(log_n: usize, basek: usize, k_ct: usize, k_pt: usize, sigma: f64, rank: usize) { let module: Module = Module::::new(1 << log_n); - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, ct_k, rank); + let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_pt); let mut source_xs: Source = Source::new([0u8; 32]); @@ -144,10 +196,10 @@ fn test_encrypt_sk(log_n: usize, basek: usize, ct_k: usize, k_pt: usize, sigma: }); } -fn test_encrypt_zero_sk(log_n: usize, basek: usize, ct_k: usize, sigma: f64, rank: usize) { +fn test_encrypt_zero_sk(log_n: usize, basek: usize, k_ct: usize, sigma: f64, rank: usize) { let module: Module = Module::::new(1 << log_n); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, ct_k); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([1u8; 32]); @@ -156,11 +208,11 @@ fn test_encrypt_zero_sk(log_n: usize, basek: usize, ct_k: usize, sigma: f64, ran let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); sk.fill_ternary_prob(&module, 0.5, &mut source_xs); - let mut ct_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, ct_k, rank); + let mut ct_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_ct, rank); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWECiphertextFourier::decrypt_scratch_space(&module, basek, ct_k) - | GLWECiphertextFourier::encrypt_sk_scratch_space(&module, basek, ct_k, rank), + GLWECiphertextFourier::decrypt_scratch_space(&module, basek, k_ct) + | GLWECiphertextFourier::encrypt_sk_scratch_space(&module, basek, k_ct, rank), ); ct_dft.encrypt_zero_sk( @@ -173,14 +225,14 @@ fn test_encrypt_zero_sk(log_n: usize, basek: usize, ct_k: usize, sigma: f64, ran ); ct_dft.decrypt(&module, &mut pt, &sk, scratch.borrow()); - assert!((sigma - pt.data.std(0, basek) * (ct_k as f64).exp2()) <= 0.2); + assert!((sigma - pt.data.std(0, basek) * (k_ct as f64).exp2()) <= 0.2); } -fn test_encrypt_pk(log_n: usize, basek: usize, ct_k: usize, k_pk: usize, sigma: f64, rank: usize) { +fn test_encrypt_pk(log_n: usize, basek: usize, k_ct: usize, k_pk: usize, sigma: f64, rank: usize) { let module: Module = Module::::new(1 << log_n); - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, ct_k, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, ct_k); + let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -205,7 +257,7 @@ fn test_encrypt_pk(log_n: usize, basek: usize, ct_k: usize, k_pk: usize, sigma: .iter_mut() .for_each(|x| *x = source_xa.next_i64() & 0); - pt_want.data.encode_vec_i64(0, basek, ct_k, &data_want, 10); + pt_want.data.encode_vec_i64(0, basek, k_ct, &data_want, 10); ct.encrypt_pk( &module, @@ -217,14 +269,14 @@ fn test_encrypt_pk(log_n: usize, basek: usize, ct_k: usize, k_pk: usize, sigma: scratch.borrow(), ); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, ct_k); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); ct.decrypt(&module, &mut pt_have, &sk, scratch.borrow()); module.vec_znx_sub_ab_inplace(&mut pt_want.data, 0, &pt_have.data, 0); let noise_have: f64 = pt_want.data.std(0, basek).log2(); - let noise_want: f64 = ((((rank as f64) + 1.0) * module.n() as f64 * 0.5 * sigma * sigma).sqrt()).log2() - (ct_k as f64); + let noise_want: f64 = ((((rank as f64) + 1.0) * module.n() as f64 * 0.5 * sigma * sigma).sqrt()).log2() - (k_ct as f64); assert!( (noise_have - noise_want).abs() < 0.2, @@ -237,50 +289,53 @@ fn test_encrypt_pk(log_n: usize, basek: usize, ct_k: usize, k_pk: usize, sigma: fn test_keyswitch( log_n: usize, basek: usize, - k_keyswitch: usize, - ct_k_in: usize, - ct_k_out: usize, - in_rank: usize, - out_rank: usize, + k_out: usize, + k_in: usize, + k_ksk: usize, + digits: usize, + rank_in: usize, + rank_out: usize, sigma: f64, ) { let module: Module = Module::::new(1 << log_n); - let rows: usize = (ct_k_in + basek - 1) / basek; - let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k_keyswitch, rows, in_rank, out_rank); - let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, ct_k_in, in_rank); - let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, ct_k_out, out_rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, ct_k_in); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, ct_k_out); + let rows: usize = div_ceil(k_in, basek * digits); + + let mut ksk: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank_in, rank_out); + let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_in, rank_in); + let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_out, rank_out); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - // Random input plaintext pt_want .data .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ksk.k(), out_rank) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ksk.k(), rank_out) | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_out.k()) | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_in.k()) | GLWECiphertext::keyswitch_scratch_space( &module, basek, ct_out.k(), - out_rank, ct_in.k(), - in_rank, ksk.k(), + digits, + rank_in, + rank_out, ), ); - let mut sk_in: GLWESecret, FFT64> = GLWESecret::alloc(&module, in_rank); + let mut sk_in: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_in); sk_in.fill_ternary_prob(&module, 0.5, &mut source_xs); - let mut sk_out: GLWESecret, FFT64> = GLWESecret::alloc(&module, out_rank); + let mut sk_out: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank_out); sk_out.fill_ternary_prob(&module, 0.5, &mut source_xs); ksk.generate_from_sk( @@ -304,7 +359,6 @@ fn test_keyswitch( ); ct_out.keyswitch(&module, &ct_in, &ksk, scratch.borrow()); - ct_out.decrypt(&module, &mut pt_have, &sk_out, scratch.borrow()); module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0); @@ -312,39 +366,41 @@ fn test_keyswitch( let noise_have: f64 = pt_have.data.std(0, basek).log2(); let noise_want: f64 = log2_std_noise_gglwe_product( module.n() as f64, - basek, + basek * digits, 0.5, 0.5, 0f64, sigma * sigma, 0f64, - in_rank as f64, - ct_k_in, - k_keyswitch, + rank_in as f64, + k_in, + k_ksk, ); + println!("{} vs. {}", noise_have, noise_want); + assert!( - (noise_have - noise_want).abs() <= 0.1, + (noise_have - noise_want).abs() <= 0.5, "{} {}", noise_have, noise_want ); } -fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ksk: usize, ct_k: usize, rank: usize, sigma: f64) { +fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ct: usize, k_ksk: usize, digits: usize, rank: usize, sigma: f64) { let module: Module = Module::::new(1 << log_n); - let rows: usize = (ct_k + basek - 1) / basek; - let mut ct_grlwe: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, rank, rank); - let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, ct_k, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, ct_k); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, ct_k); + let rows: usize = div_ceil(k_ct, basek * digits); + + let mut ct_grlwe: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank, rank); + let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - // Random input plaintext pt_want .data .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); @@ -353,7 +409,7 @@ fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ksk: usize, ct_k: usize, GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ct_grlwe.k(), rank) | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_glwe.k()) | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_glwe.k()) - | GLWECiphertext::keyswitch_inplace_scratch_space(&module, basek, ct_glwe.k(), rank, ct_grlwe.k()), + | GLWECiphertext::keyswitch_inplace_scratch_space(&module, basek, ct_glwe.k(), ct_grlwe.k(), digits, rank), ); let mut sk0: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); @@ -391,19 +447,19 @@ fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ksk: usize, ct_k: usize, let noise_have: f64 = pt_have.data.std(0, basek).log2(); let noise_want: f64 = log2_std_noise_gglwe_product( module.n() as f64, - basek, + basek * digits, 0.5, 0.5, 0f64, sigma * sigma, 0f64, rank as f64, - ct_k, + k_ct, k_ksk, ); assert!( - (noise_have - noise_want).abs() <= 0.1, + (noise_have - noise_want).abs() <= 0.5, "{} {}", noise_have, noise_want @@ -414,20 +470,22 @@ fn test_automorphism( log_n: usize, basek: usize, p: i64, - k_autokey: usize, - ct_k_in: usize, - ct_k_out: usize, + k_out: usize, + k_in: usize, + k_ksk: usize, + digits: usize, rank: usize, sigma: f64, ) { let module: Module = Module::::new(1 << log_n); - let rows: usize = (ct_k_in + basek - 1) / basek; - let mut autokey: AutomorphismKey, FFT64> = AutomorphismKey::alloc(&module, basek, k_autokey, rows, rank); - let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, ct_k_in, rank); - let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, ct_k_out, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, ct_k_in); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, ct_k_out); + let rows: usize = div_ceil(k_in, basek * digits); + + let mut autokey: AutomorphismKey, FFT64> = AutomorphismKey::alloc(&module, basek, k_ksk, rows, digits, rank); + let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_in, rank); + let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_out, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -441,7 +499,15 @@ fn test_automorphism( AutomorphismKey::generate_from_sk_scratch_space(&module, basek, autokey.k(), rank) | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_out.k()) | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_in.k()) - | GLWECiphertext::automorphism_scratch_space(&module, basek, ct_out.k(), ct_in.k(), autokey.k(), rank), + | GLWECiphertext::automorphism_scratch_space( + &module, + basek, + ct_out.k(), + ct_in.k(), + autokey.k(), + digits, + rank, + ), ); let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); @@ -479,39 +545,48 @@ fn test_automorphism( let noise_want: f64 = log2_std_noise_gglwe_product( module.n() as f64, - basek, + basek * digits, 0.5, 0.5, 0f64, sigma * sigma, 0f64, rank as f64, - ct_k_in, - k_autokey, + k_in, + k_ksk, ); assert!( - (noise_have - noise_want).abs() <= 0.1, + (noise_have - noise_want).abs() <= 0.5, "{} {}", noise_have, noise_want ); } -fn test_automorphism_inplace(log_n: usize, basek: usize, p: i64, k_autokey: usize, ct_k: usize, rank: usize, sigma: f64) { +fn test_automorphism_inplace( + log_n: usize, + basek: usize, + p: i64, + k_ct: usize, + k_ksk: usize, + digits: usize, + rank: usize, + sigma: f64, +) { let module: Module = Module::::new(1 << log_n); - let rows: usize = (ct_k + basek - 1) / basek; - let mut autokey: AutomorphismKey, FFT64> = AutomorphismKey::alloc(&module, basek, k_autokey, rows, rank); - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, ct_k, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, ct_k); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, ct_k); + let rows: usize = div_ceil(k_ct, basek * digits); + + let mut autokey: AutomorphismKey, FFT64> = AutomorphismKey::alloc(&module, basek, k_ksk, rows, digits, rank); + let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - // Random input plaintext pt_want .data .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); @@ -520,7 +595,7 @@ fn test_automorphism_inplace(log_n: usize, basek: usize, p: i64, k_autokey: usiz AutomorphismKey::generate_from_sk_scratch_space(&module, basek, autokey.k(), rank) | GLWECiphertext::decrypt_scratch_space(&module, basek, ct.k()) | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k()) - | GLWECiphertext::automorphism_inplace_scratch_space(&module, basek, ct.k(), autokey.k(), rank), + | GLWECiphertext::automorphism_inplace_scratch_space(&module, basek, ct.k(), autokey.k(), digits, rank), ); let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); @@ -555,36 +630,45 @@ fn test_automorphism_inplace(log_n: usize, basek: usize, p: i64, k_autokey: usiz let noise_have: f64 = pt_have.data.std(0, basek).log2(); let noise_want: f64 = log2_std_noise_gglwe_product( module.n() as f64, - basek, + basek * digits, 0.5, 0.5, 0f64, sigma * sigma, 0f64, rank as f64, - ct_k, - k_autokey, + k_ct, + k_ksk, ); assert!( - (noise_have - noise_want).abs() <= 0.1, + (noise_have - noise_want).abs() <= 0.5, "{} {}", noise_have, noise_want ); } -fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, ct_k_in: usize, ct_k_out: usize, rank: usize, sigma: f64) { +fn test_external_product( + log_n: usize, + basek: usize, + k_out: usize, + k_in: usize, + k_ggsw: usize, + digits: usize, + rank: usize, + sigma: f64, +) { let module: Module = Module::::new(1 << log_n); - let rows: usize = (ct_k_in + basek - 1) / basek; + let rows: usize = div_ceil(k_in, digits * basek); - let mut ct_ggsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, rank); - let mut ct_glwe_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, ct_k_in, rank); - let mut ct_glwe_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, ct_k_out, rank); + let mut ct_ggsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank); + let mut ct_glwe_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_in, rank); + let mut ct_glwe_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_out, rank); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, ct_k_in); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, ct_k_out); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -611,6 +695,7 @@ fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, ct_k_in: usi ct_glwe_out.k(), ct_glwe_in.k(), ct_ggsw.k(), + digits, rank, ), ); @@ -657,7 +742,7 @@ fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, ct_k_in: usi let noise_want: f64 = noise_ggsw_product( module.n() as f64, - basek, + basek * digits, 0.5, var_msg, var_a0_err, @@ -665,27 +750,27 @@ fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, ct_k_in: usi var_gct_err_lhs, var_gct_err_rhs, rank as f64, - ct_k_in, + k_in, k_ggsw, ); assert!( - (noise_have - noise_want).abs() <= 0.1, + (noise_have - noise_want).abs() <= 0.5, "{} {}", noise_have, noise_want ); } -fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, ct_k: usize, rank: usize, sigma: f64) { +fn test_external_product_inplace(log_n: usize, basek: usize, k_ct: usize, k_ggsw: usize, digits: usize, rank: usize, sigma: f64) { let module: Module = Module::::new(1 << log_n); - let rows: usize = (ct_k + basek - 1) / basek; + let rows: usize = div_ceil(k_ct, digits * basek); - let mut ct_ggsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, rank); - let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, ct_k, rank); + let mut ct_ggsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank); + let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, ct_k); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, ct_k); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -706,7 +791,7 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, ct_k GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, ct_ggsw.k(), rank) | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_glwe.k()) | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_glwe.k()) - | GLWECiphertext::external_product_inplace_scratch_space(&module, basek, ct_glwe.k(), ct_ggsw.k(), rank), + | GLWECiphertext::external_product_inplace_scratch_space(&module, basek, ct_glwe.k(), ct_ggsw.k(), digits, rank), ); let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); @@ -751,7 +836,7 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, ct_k let noise_want: f64 = noise_ggsw_product( module.n() as f64, - basek, + basek * digits, 0.5, var_msg, var_a0_err, @@ -759,12 +844,12 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, ct_k var_gct_err_lhs, var_gct_err_rhs, rank as f64, - ct_k, + k_ct, k_ggsw, ); assert!( - (noise_have - noise_want).abs() <= 0.1, + (noise_have - noise_want).abs() <= 0.5, "{} {}", noise_have, noise_want diff --git a/core/src/test_fft64/glwe_fourier.rs b/core/src/test_fft64/glwe_fourier.rs index ea727ce..a1ff3a3 100644 --- a/core/src/test_fft64/glwe_fourier.rs +++ b/core/src/test_fft64/glwe_fourier.rs @@ -1,67 +1,101 @@ use crate::{ - GGSWCiphertext, GLWECiphertext, GLWECiphertextFourier, GLWEOps, GLWEPlaintext, GLWESecret, GLWESwitchingKey, Infos, - test_fft64::{gglwe::log2_std_noise_gglwe_product, ggsw::noise_ggsw_product}, + GGSWCiphertext, GLWECiphertext, GLWECiphertextFourier, GLWEOps, GLWEPlaintext, GLWESecret, GLWESwitchingKey, Infos, div_ceil, + test_fft64::{log2_std_noise_gglwe_product, noise_ggsw_product}, }; use backend::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, ZnxViewMut}; use sampling::source::Source; #[test] fn keyswitch() { + let log_n: usize = 8; + let basek: usize = 12; + let k_in: usize = 45; + let digits: usize = div_ceil(k_in, basek); (1..4).for_each(|rank_in| { (1..4).for_each(|rank_out| { - println!("test keyswitch rank_in: {} rank_out: {}", rank_in, rank_out); - test_keyswitch(12, 12, 60, 45, 60, rank_in, rank_out, 3.2); + (1..digits + 1).for_each(|di| { + let k_ksk: usize = k_in + basek * di; + println!( + "test keyswitch digits: {} rank_in: {} rank_out: {}", + di, rank_in, rank_out + ); + let k_out: usize = k_ksk; // Better capture noise. + test_keyswitch(log_n, basek, k_in, k_out, k_ksk, di, rank_in, rank_out, 3.2); + }) }); }); } #[test] fn keyswitch_inplace() { + let log_n: usize = 8; + let basek: usize = 12; + let k_ct: usize = 45; + let digits: usize = div_ceil(k_ct, basek); (1..4).for_each(|rank| { - println!("test keyswitch_inplace rank: {}", rank); - test_keyswitch_inplace(12, 12, 60, 45, rank, 3.2); + (1..digits + 1).for_each(|di| { + let k_ksk: usize = k_ct + basek * di; + println!("test keyswitch_inplace digits: {} rank: {}", di, rank); + test_keyswitch_inplace(log_n, basek, k_ct, k_ksk, di, rank, 3.2); + }); }); } #[test] fn external_product() { + let log_n: usize = 8; + let basek: usize = 12; + let k_in: usize = 45; + let digits: usize = div_ceil(k_in, basek); (1..4).for_each(|rank| { - println!("test external_product rank: {}", rank); - test_external_product(12, 12, 60, 45, 60, rank, 3.2); + (1..digits + 1).for_each(|di| { + let k_ggsw: usize = k_in + basek * di; + println!("test external_product digits: {} rank: {}", di, rank); + let k_out: usize = k_ggsw; // Better capture noise. + test_external_product(log_n, basek, k_out, k_in, k_ggsw, di, rank, 3.2); + }); }); } #[test] fn external_product_inplace() { + let log_n: usize = 8; + let basek: usize = 12; + let k_ct: usize = 60; + let digits: usize = div_ceil(k_ct, basek); (1..4).for_each(|rank| { - println!("test external_product rank: {}", rank); - test_external_product_inplace(12, 15, 60, 60, rank, 3.2); + (1..digits + 1).for_each(|di| { + let k_ggsw: usize = k_ct + basek * di; + println!("test external_product digits: {} rank: {}", di, rank); + test_external_product_inplace(log_n, basek, k_ct, k_ggsw, di, rank, 3.2); + }); }); } fn test_keyswitch( log_n: usize, basek: usize, + k_in: usize, + k_out: usize, k_ksk: usize, - k_ct_in: usize, - k_ct_out: usize, + digits: usize, rank_in: usize, rank_out: usize, sigma: f64, ) { let module: Module = Module::::new(1 << log_n); - let rows: usize = (k_ct_in + basek - 1) / basek; + let rows: usize = div_ceil(k_in, basek * digits); - let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, rank_in, rank_out); - let mut ct_glwe_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct_in, rank_in); - let mut ct_glwe_dft_in: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::alloc(&module, basek, k_ct_in, rank_in); - let mut ct_glwe_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct_out, rank_out); + let mut ksk: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank_in, rank_out); + let mut ct_glwe_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_in, rank_in); + let mut ct_glwe_dft_in: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_in, rank_in); + let mut ct_glwe_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_out, rank_out); let mut ct_glwe_dft_out: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::alloc(&module, basek, k_ct_out, rank_out); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct_in); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct_out); + GLWECiphertextFourier::alloc(&module, basek, k_out, rank_out); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -74,16 +108,17 @@ fn test_keyswitch( let mut scratch: ScratchOwned = ScratchOwned::new( GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank_out) - | GLWECiphertext::decrypt_scratch_space(&module, basek, k_ct_out) - | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, k_ct_in) + | GLWECiphertext::decrypt_scratch_space(&module, basek, k_out) + | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, k_in) | GLWECiphertextFourier::keyswitch_scratch_space( &module, basek, ct_glwe_out.k(), - rank_out, - ct_glwe_in.k(), - rank_in, ksk.k(), + ct_glwe_in.k(), + digits, + rank_in, + rank_out, ), ); @@ -124,30 +159,31 @@ fn test_keyswitch( let noise_have: f64 = pt_have.data.std(0, basek).log2(); let noise_want: f64 = log2_std_noise_gglwe_product( module.n() as f64, - basek, + basek * digits, 0.5, 0.5, 0f64, sigma * sigma, 0f64, rank_in as f64, - k_ct_in, + k_in, k_ksk, ); assert!( - (noise_have - noise_want).abs() <= 0.1, + (noise_have - noise_want).abs() <= 0.5, "{} {}", noise_have, noise_want ); } -fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ksk: usize, k_ct: usize, rank: usize, sigma: f64) { +fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ct: usize, k_ksk: usize, digits: usize, rank: usize, sigma: f64) { let module: Module = Module::::new(1 << log_n); - let rows: usize = (k_ct + basek - 1) / basek; - let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, rank, rank); + let rows: usize = div_ceil(k_ct, basek * digits); + + let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, digits, rank, rank); let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_ct, rank); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); @@ -166,7 +202,7 @@ fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ksk: usize, k_ct: usize, GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ksk.k(), rank) | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_glwe.k()) | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_glwe.k()) - | GLWECiphertextFourier::keyswitch_inplace_scratch_space(&module, basek, ct_rlwe_dft.k(), ksk.k(), rank), + | GLWECiphertextFourier::keyswitch_inplace_scratch_space(&module, basek, ct_rlwe_dft.k(), ksk.k(), digits, rank), ); let mut sk_in: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); @@ -206,7 +242,7 @@ fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ksk: usize, k_ct: usize, let noise_have: f64 = pt_have.data.std(0, basek).log2(); let noise_want: f64 = log2_std_noise_gglwe_product( module.n() as f64, - basek, + basek * digits, 0.5, 0.5, 0f64, @@ -218,26 +254,35 @@ fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ksk: usize, k_ct: usize, ); assert!( - (noise_have - noise_want).abs() <= 0.1, + (noise_have - noise_want).abs() <= 0.5, "{} {}", noise_have, noise_want ); } -fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, k_ct_in: usize, k_ct_out: usize, rank: usize, sigma: f64) { +fn test_external_product( + log_n: usize, + basek: usize, + k_out: usize, + k_in: usize, + k_ggsw: usize, + digits: usize, + rank: usize, + sigma: f64, +) { let module: Module = Module::::new(1 << log_n); - let rows: usize = (k_ct_in + basek - 1) / basek; + let rows: usize = div_ceil(k_in, digits * basek); - let mut ct_ggsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, rank); - let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct_in, rank); - let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct_out, rank); - let mut ct_in_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_ct_in, rank); - let mut ct_out_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_ct_out, rank); + let mut ct_ggsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank); + let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_in, rank); + let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_out, rank); + let mut ct_in_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_in, rank); + let mut ct_out_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_out, rank); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct_in); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct_out); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_out); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -259,7 +304,15 @@ fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, k_ct_in: usi GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, ct_ggsw.k(), rank) | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_out.k()) | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_in.k()) - | GLWECiphertextFourier::external_product_scratch_space(&module, basek, ct_out.k(), ct_in.k(), ct_ggsw.k(), rank), + | GLWECiphertextFourier::external_product_scratch_space( + &module, + basek, + ct_out.k(), + ct_in.k(), + ct_ggsw.k(), + digits, + rank, + ), ); let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); @@ -305,7 +358,7 @@ fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, k_ct_in: usi let noise_want: f64 = noise_ggsw_product( module.n() as f64, - basek, + basek * digits, 0.5, var_msg, var_a0_err, @@ -313,23 +366,23 @@ fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, k_ct_in: usi var_gct_err_lhs, var_gct_err_rhs, rank as f64, - k_ct_in, + k_in, k_ggsw, ); assert!( - (noise_have - noise_want).abs() <= 0.1, + (noise_have - noise_want).abs() <= 0.5, "{} {}", noise_have, noise_want ); } -fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, k_ct: usize, rank: usize, sigma: f64) { +fn test_external_product_inplace(log_n: usize, basek: usize, k_ct: usize, k_ggsw: usize, digits: usize, rank: usize, sigma: f64) { let module: Module = Module::::new(1 << log_n); - let rows: usize = (k_ct + basek - 1) / basek; + let rows: usize = div_ceil(k_ct, digits * basek); - let mut ct_ggsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, rank); + let mut ct_ggsw: GGSWCiphertext, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, digits, rank); let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_ct, rank); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); @@ -356,7 +409,7 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, k_ct GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, ct_ggsw.k(), rank) | GLWECiphertext::decrypt_scratch_space(&module, basek, ct.k()) | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k()) - | GLWECiphertextFourier::external_product_inplace_scratch_space(&module, basek, ct.k(), ct_ggsw.k(), rank), + | GLWECiphertextFourier::external_product_inplace_scratch_space(&module, basek, ct.k(), ct_ggsw.k(), digits, rank), ); let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); @@ -402,7 +455,7 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, k_ct let noise_want: f64 = noise_ggsw_product( module.n() as f64, - basek, + basek * digits, 0.5, var_msg, var_a0_err, @@ -415,9 +468,11 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, k_ct ); assert!( - (noise_have - noise_want).abs() <= 0.1, + (noise_have - noise_want).abs() <= 0.5, "{} {}", noise_have, noise_want ); + + println!("{} {}", noise_have, noise_want); } diff --git a/core/src/test_fft64/glwe_packing.rs b/core/src/test_fft64/glwe_packing.rs index 6ee38b2..5fa82ee 100644 --- a/core/src/test_fft64/glwe_packing.rs +++ b/core/src/test_fft64/glwe_packing.rs @@ -1,4 +1,4 @@ -use crate::{AutomorphismKey, GLWECiphertext, GLWEOps, GLWEPlaintext, GLWESecret, StreamPacker}; +use crate::{AutomorphismKey, GLWECiphertext, GLWEOps, GLWEPlaintext, GLWESecret, StreamPacker, div_ceil}; use std::collections::HashMap; use backend::{Encoding, FFT64, Module, ScratchOwned, Stats}; @@ -14,24 +14,26 @@ fn packing() { let mut source_xa: Source = Source::new([0u8; 32]); let basek: usize = 18; - let ct_k: usize = 36; - let atk_k: usize = ct_k + basek; + let k_ct: usize = 36; let pt_k: usize = 18; let rank: usize = 3; - let rows: usize = (ct_k + basek - 1) / basek; let sigma: f64 = 3.2; + let digits: usize = 1; + let k_ksk: usize = k_ct + basek * digits; + + let rows: usize = div_ceil(k_ct, basek * digits); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_k) - | GLWECiphertext::decrypt_scratch_space(&module, basek, ct_k) - | AutomorphismKey::generate_from_sk_scratch_space(&module, basek, atk_k, rank) - | StreamPacker::scratch_space(&module, basek, ct_k, atk_k, rank), + GLWECiphertext::encrypt_sk_scratch_space(&module, basek, k_ct) + | GLWECiphertext::decrypt_scratch_space(&module, basek, k_ct) + | AutomorphismKey::generate_from_sk_scratch_space(&module, basek, k_ksk, rank) + | StreamPacker::scratch_space(&module, basek, k_ct, k_ksk, digits, rank), ); let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); sk.fill_ternary_prob(&module, 0.5, &mut source_xs); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, ct_k); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); let mut data: Vec = vec![0i64; module.n()]; data.iter_mut().enumerate().for_each(|(i, x)| { *x = i as i64; @@ -42,7 +44,7 @@ fn packing() { let mut auto_keys: HashMap, FFT64>> = HashMap::new(); gal_els.iter().for_each(|gal_el| { - let mut key: AutomorphismKey, FFT64> = AutomorphismKey::alloc(&module, basek, atk_k, rows, rank); + let mut key: AutomorphismKey, FFT64> = AutomorphismKey::alloc(&module, basek, k_ksk, rows, digits, rank); key.generate_from_sk( &module, *gal_el, @@ -57,9 +59,9 @@ fn packing() { let log_batch: usize = 0; - let mut packer: StreamPacker = StreamPacker::new(&module, log_batch, basek, ct_k, rank); + let mut packer: StreamPacker = StreamPacker::new(&module, log_batch, basek, k_ct, rank); - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, ct_k, rank); + let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct, rank); ct.encrypt_sk( &module, @@ -102,7 +104,7 @@ fn packing() { packer.flush(&module, &mut res, &auto_keys, scratch.borrow()); packer.reset(); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, ct_k); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ct); res.iter().enumerate().for_each(|(i, res_i)| { let mut data: Vec = vec![0i64; module.n()]; @@ -124,7 +126,7 @@ fn packing() { let noise_have = pt.data.std(0, basek).log2(); // println!("noise_have: {}", noise_have); assert!( - noise_have < -((ct_k - basek) as f64), + noise_have < -((k_ct - basek) as f64), "noise: {}", noise_have ); diff --git a/core/src/test_fft64/mod.rs b/core/src/test_fft64/mod.rs index 0ccf371..6fcecb1 100644 --- a/core/src/test_fft64/mod.rs +++ b/core/src/test_fft64/mod.rs @@ -6,3 +6,143 @@ mod glwe_fourier; mod glwe_packing; mod tensor_key; mod trace; + +pub(crate) fn var_noise_gglwe_product( + n: f64, + basek: usize, + var_xs: f64, + var_msg: f64, + var_a_err: f64, + var_gct_err_lhs: f64, + var_gct_err_rhs: f64, + rank_in: f64, + a_logq: usize, + b_logq: usize, +) -> f64 { + let a_logq: usize = a_logq.min(b_logq); + let a_cols: usize = (a_logq + basek - 1) / basek; + + let b_scale: f64 = (b_logq as f64).exp2(); + let a_scale: f64 = ((b_logq - a_logq) as f64).exp2(); + + let base: f64 = (basek as f64).exp2(); + let var_base: f64 = base * base / 12f64; + + // lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2) + // rhs = a_cols * n * var_base * var_gct_err_rhs * var_xs + let mut noise: f64 = (a_cols as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs); + noise += var_msg * var_a_err * a_scale * a_scale * n; + noise *= rank_in; + noise /= b_scale * b_scale; + noise +} + +pub(crate) fn log2_std_noise_gglwe_product( + n: f64, + basek: usize, + var_xs: f64, + var_msg: f64, + var_a_err: f64, + var_gct_err_lhs: f64, + var_gct_err_rhs: f64, + rank_in: f64, + a_logq: usize, + b_logq: usize, +) -> f64 { + let mut noise: f64 = var_noise_gglwe_product( + n, + basek, + var_xs, + var_msg, + var_a_err, + var_gct_err_lhs, + var_gct_err_rhs, + rank_in, + a_logq, + b_logq, + ); + noise = noise.sqrt(); + noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}] +} + +pub(crate) fn noise_ggsw_product( + n: f64, + basek: usize, + var_xs: f64, + var_msg: f64, + var_a0_err: f64, + var_a1_err: f64, + var_gct_err_lhs: f64, + var_gct_err_rhs: f64, + rank: f64, + k_in: usize, + k_ggsw: usize, +) -> f64 { + let a_logq: usize = k_in.min(k_ggsw); + let a_cols: usize = (a_logq + basek - 1) / basek; + + let b_scale: f64 = (k_ggsw as f64).exp2(); + let a_scale: f64 = ((k_ggsw - a_logq) as f64).exp2(); + + let base: f64 = (basek as f64).exp2(); + let var_base: f64 = base * base / 12f64; + + // lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2) + // rhs = a_cols * n * var_base * var_gct_err_rhs * var_xs + let mut noise: f64 = (rank + 1.0) * (a_cols as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs); + noise += var_msg * var_a0_err * a_scale * a_scale * n; + noise += var_msg * var_a1_err * a_scale * a_scale * n * var_xs * rank; + noise = noise.sqrt(); + noise /= b_scale; + noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}] +} + +pub(crate) fn noise_ggsw_keyswitch( + n: f64, + basek: usize, + col: usize, + var_xs: f64, + var_a_err: f64, + var_gct_err_lhs: f64, + var_gct_err_rhs: f64, + rank: f64, + k_ct: usize, + k_ksk: usize, + k_tsk: usize, +) -> f64 { + let var_si_x_sj: f64 = n * var_xs * var_xs; + + // Initial KS for col = 0 + let mut noise: f64 = var_noise_gglwe_product( + n, + basek, + var_xs, + var_xs, + var_a_err, + var_gct_err_lhs, + var_gct_err_rhs, + rank, + k_ct, + k_ksk, + ); + + // Other GGSW reconstruction for col > 0 + if col > 0 { + noise += var_noise_gglwe_product( + n, + basek, + var_xs, + var_si_x_sj, + var_a_err + 1f64 / 12.0, + var_gct_err_lhs, + var_gct_err_rhs, + rank, + k_ct, + k_tsk, + ); + noise += n * noise * var_xs * 0.5; + } + + noise = noise.sqrt(); + noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}] +} diff --git a/core/src/test_fft64/tensor_key.rs b/core/src/test_fft64/tensor_key.rs index c6b9b26..579fec4 100644 --- a/core/src/test_fft64/tensor_key.rs +++ b/core/src/test_fft64/tensor_key.rs @@ -5,18 +5,19 @@ use crate::{GLWECiphertextFourier, GLWEPlaintext, GLWESecret, GetRow, Infos, Ten #[test] fn encrypt_sk() { + let log_n: usize = 8; (1..4).for_each(|rank| { println!("test encrypt_sk rank: {}", rank); - test_encrypt_sk(12, 16, 54, 3.2, rank); + test_encrypt_sk(log_n, 16, 54, 3.2, rank); }); } fn test_encrypt_sk(log_n: usize, basek: usize, k: usize, sigma: f64, rank: usize) { let module: Module = Module::::new(1 << log_n); - let rows: usize = (k + basek - 1) / basek; + let rows: usize = k / basek; - let mut tensor_key: TensorKey, FFT64> = TensorKey::alloc(&module, basek, k, rows, rank); + let mut tensor_key: TensorKey, FFT64> = TensorKey::alloc(&module, basek, k, rows, 1, rank); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -65,7 +66,7 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k: usize, sigma: f64, rank: usize ct_glwe_fourier.decrypt(&module, &mut pt, &sk, scratch.borrow()); module.vec_znx_sub_scalar_inplace(&mut pt.data, 0, row_i, &sk_ij.data, col_i); let std_pt: f64 = pt.data.std(0, basek) * (k as f64).exp2(); - assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); + assert!((sigma - std_pt).abs() <= 0.5, "{} {}", sigma, std_pt); }); }); }) diff --git a/core/src/test_fft64/trace.rs b/core/src/test_fft64/trace.rs index 5025ba6..cec4533 100644 --- a/core/src/test_fft64/trace.rs +++ b/core/src/test_fft64/trace.rs @@ -3,13 +3,14 @@ use std::collections::HashMap; use backend::{FFT64, FillUniform, Module, ScratchOwned, Stats, VecZnxOps, ZnxView, ZnxViewMut}; use sampling::source::Source; -use crate::{AutomorphismKey, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, test_fft64::gglwe::var_noise_gglwe_product}; +use crate::{AutomorphismKey, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, div_ceil, test_fft64::var_noise_gglwe_product}; #[test] fn trace_inplace() { + let log_n: usize = 8; (1..4).for_each(|rank| { println!("test trace_inplace rank: {}", rank); - test_trace_inplace(11, 8, 54, 3.2, rank); + test_trace_inplace(log_n, 8, 54, 3.2, rank); }); } @@ -18,7 +19,8 @@ fn test_trace_inplace(log_n: usize, basek: usize, k: usize, sigma: f64, rank: us let k_autokey: usize = k + basek; - let rows: usize = (k + basek - 1) / basek; + let digits: usize = 1; + let rows: usize = div_ceil(k, digits * basek); let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k, rank); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k); @@ -32,7 +34,7 @@ fn test_trace_inplace(log_n: usize, basek: usize, k: usize, sigma: f64, rank: us GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k()) | GLWECiphertext::decrypt_scratch_space(&module, basek, ct.k()) | AutomorphismKey::generate_from_sk_scratch_space(&module, basek, k_autokey, rank) - | GLWECiphertext::trace_inplace_scratch_space(&module, basek, ct.k(), k_autokey, rank), + | GLWECiphertext::trace_inplace_scratch_space(&module, basek, ct.k(), k_autokey, digits, rank), ); let mut sk: GLWESecret, FFT64> = GLWESecret::alloc(&module, rank); @@ -61,7 +63,7 @@ fn test_trace_inplace(log_n: usize, basek: usize, k: usize, sigma: f64, rank: us let mut auto_keys: HashMap, FFT64>> = HashMap::new(); let gal_els: Vec = GLWECiphertext::trace_galois_elements(&module); gal_els.iter().for_each(|gal_el| { - let mut key: AutomorphismKey, FFT64> = AutomorphismKey::alloc(&module, basek, k_autokey, rows, rank); + let mut key: AutomorphismKey, FFT64> = AutomorphismKey::alloc(&module, basek, k_autokey, rows, digits, rank); key.generate_from_sk( &module, *gal_el, @@ -102,5 +104,10 @@ fn test_trace_inplace(log_n: usize, basek: usize, k: usize, sigma: f64, rank: us noise_want += module.n() as f64 * 1.0 / 12.0 * 0.5 * rank as f64 * (-2.0 * (k) as f64).exp2(); noise_want = noise_want.sqrt().log2(); - assert!((noise_have - noise_want).abs() < 1.0); + assert!( + (noise_have - noise_want).abs() < 1.0, + "{} > {}", + noise_have, + noise_want + ); } diff --git a/core/src/trace.rs b/core/src/trace.rs index 9414fbe..3c6a5bb 100644 --- a/core/src/trace.rs +++ b/core/src/trace.rs @@ -22,14 +22,22 @@ impl GLWECiphertext> { basek: usize, out_k: usize, in_k: usize, - atk_k: usize, + ksk_k: usize, + digits: usize, rank: usize, ) -> usize { - Self::automorphism_inplace_scratch_space(module, basek, out_k.min(in_k), atk_k, rank) + Self::automorphism_inplace_scratch_space(module, basek, out_k.min(in_k), ksk_k, digits, rank) } - pub fn trace_inplace_scratch_space(module: &Module, basek: usize, out_k: usize, atk_k: usize, rank: usize) -> usize { - Self::automorphism_inplace_scratch_space(module, basek, out_k, atk_k, rank) + pub fn trace_inplace_scratch_space( + module: &Module, + basek: usize, + out_k: usize, + ksk_k: usize, + digits: usize, + rank: usize, + ) -> usize { + Self::automorphism_inplace_scratch_space(module, basek, out_k, ksk_k, digits, rank) } }