diff --git a/core/src/ggsw_ciphertext.rs b/core/src/ggsw_ciphertext.rs index e8f913c..28c5961 100644 --- a/core/src/ggsw_ciphertext.rs +++ b/core/src/ggsw_ciphertext.rs @@ -82,11 +82,17 @@ impl GGSWCiphertext, FFT64> { + module.bytes_of_vec_znx_dft(rank + 1, size) } - pub(crate) fn expand_row_scratch_space(module: &Module, self_size: usize, tsk_size: usize, rank: usize) -> usize { - let tmp_dft_i: usize = module.bytes_of_vec_znx_dft(rank + 1, tsk_size); + pub(crate) fn expand_row_scratch_space( + module: &Module, + self_size: usize, + tensor_key_size: usize, + rank: usize, + ) -> usize { + let tmp_dft_i: usize = module.bytes_of_vec_znx_dft(rank + 1, tensor_key_size); let tmp_dft_col_data: usize = module.bytes_of_vec_znx_dft(1, self_size); - let vmp: usize = tmp_dft_col_data + module.vmp_apply_tmp_bytes(self_size, self_size, self_size, rank, rank, tsk_size); - let tmp_idft: usize = module.bytes_of_vec_znx_big(1, tsk_size); + let vmp: usize = + tmp_dft_col_data + module.vmp_apply_tmp_bytes(self_size, self_size, self_size, rank, rank, tensor_key_size); + let tmp_idft: usize = module.bytes_of_vec_znx_big(1, tensor_key_size); let norm: usize = module.vec_znx_big_normalize_tmp_bytes(); tmp_dft_i + ((tmp_dft_col_data + vmp) | (tmp_idft + norm)) } @@ -107,13 +113,13 @@ impl GGSWCiphertext, FFT64> { out_size: usize, in_size: usize, ksk_size: usize, - tsk_size: usize, + tensor_key_size: usize, rank: usize, ) -> usize { let res_znx: usize = module.bytes_of_vec_znx(rank + 1, out_size); let ci_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size); let ks: usize = GGSWCiphertext::keyswitch_internal_col0_scratch_space(module, out_size, in_size, ksk_size, rank); - let expand_rows: usize = GGSWCiphertext::expand_row_scratch_space(module, out_size, tsk_size, rank); + let expand_rows: usize = GGSWCiphertext::expand_row_scratch_space(module, out_size, tensor_key_size, rank); let res_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size); res_znx + ci_dft + (ks | expand_rows | res_dft) } @@ -123,13 +129,17 @@ impl GGSWCiphertext, FFT64> { out_size: usize, in_size: usize, auto_key_size: usize, + tensor_key_size: usize, rank: usize, ) -> usize { - let tmp_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, auto_key_size); - let tmp_idft: usize = module.bytes_of_vec_znx(rank + 1, out_size); - let vmp: usize = - GLWECiphertext::keyswitch_from_fourier_scratch_space(module, out_size, rank, in_size, rank, auto_key_size); - tmp_dft + tmp_idft + vmp + GGSWCiphertext::keyswitch_scratch_space( + module, + out_size, + in_size, + auto_key_size, + tensor_key_size, + rank, + ) } pub fn external_product_scratch_space( @@ -379,15 +389,17 @@ where }) } - pub fn automorphism( + pub fn automorphism( &mut self, module: &Module, lhs: &GGSWCiphertext, - rhs: &AutomorphismKey, + auto_key: &AutomorphismKey, + tensor_key: &TensorKey, scratch: &mut Scratch, ) where MatZnxDft: MatZnxDftToRef, - MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, { #[cfg(debug_assertions)] { @@ -400,48 +412,62 @@ where ); assert_eq!( self.rank(), - rhs.rank(), + auto_key.rank(), "ggsw_in rank: {} != auto_key rank: {}", self.rank(), - rhs.rank() + auto_key.rank() + ); + assert_eq!( + self.rank(), + tensor_key.rank(), + "ggsw_in rank: {} != tensor_key rank: {}", + self.rank(), + tensor_key.rank() ); }; + let cols: usize = self.rank() + 1; - let (tmp_dft_data, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); //TODO optimize - - let mut tmp_dft: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { - data: tmp_dft_data, - basek: lhs.basek(), - k: lhs.k(), - }; - - let (tmp_idft_data, scratch2) = scratch1.tmp_vec_znx(module, cols, self.size()); - - let mut tmp_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { - data: tmp_idft_data, + let (res_data, scratch1) = scratch.tmp_vec_znx(&module, cols, self.size()); + let mut res: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { + data: res_data, basek: self.basek(), k: self.k(), }; - (0..cols).for_each(|col_i| { - (0..self.rows()).for_each(|row_j| { - lhs.get_row(module, row_j, col_i, &mut tmp_dft); - tmp_idft.keyswitch_from_fourier(module, &tmp_dft, &rhs.key, scratch2); + let (mut ci_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, self.size()); + + // Keyswitch the j-th row of the col 0 + (0..lhs.rows()).for_each(|row_i| { + // Key-switch column 0, i.e. + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2) + lhs.keyswitch_internal_col0(module, row_i, &mut res, &auto_key.key, scratch2); + + // Isolates DFT(AUTO(a[i])) + (0..cols).for_each(|col_i| { + // (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2) -> (-(a0s0 + a1s1 + a2s2) + pi(M[i]), a0, a1, a2) + module.vec_znx_automorphism_inplace(auto_key.p(), &mut res, col_i); + module.vec_znx_dft(&mut ci_dft, col_i, &res, col_i); + }); + + self.set_row(module, row_i, 0, &ci_dft); + + // Generates + // + // col 1: (-(b0s0 + b1s1 + b2s2) , b0 + pi(M[i]), b1 , b2 ) + // col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + pi(M[i]), c2 ) + // col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + pi(M[i])) + (1..cols).for_each(|col_j| { + self.expand_row(module, col_j, &mut res, &ci_dft, tensor_key, scratch2); + + let (mut res_dft, _) = scratch2.tmp_vec_znx_dft(module, cols, self.size()); (0..cols).for_each(|i| { - module.vec_znx_automorphism_inplace(rhs.p(), &mut tmp_idft, i); + module.vec_znx_dft(&mut res_dft, i, &res, i); }); - self.set_row(module, row_j, col_i, &tmp_dft); - }); - }); - tmp_dft.data.zero(); - - (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| { - (0..self.rank() + 1).for_each(|col_j| { - self.set_row(module, row_i, col_j, &tmp_dft); - }); - }); + self.set_row(module, row_i, col_j, &res_dft); + }) + }) } pub fn external_product( diff --git a/core/src/test_fft64/ggsw.rs b/core/src/test_fft64/ggsw.rs index 2ee8708..cf5e3c7 100644 --- a/core/src/test_fft64/ggsw.rs +++ b/core/src/test_fft64/ggsw.rs @@ -1,6 +1,6 @@ use base2k::{ - FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, - VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, VecZnxToMut, ZnxViewMut, ZnxZero, + FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScalarZnxOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc, + VecZnxBigOps, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, VecZnxToMut, ZnxViewMut, ZnxZero, }; use sampling::source::Source; @@ -292,14 +292,22 @@ pub(crate) fn noise_ggsw_keyswitch( noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}] } +#[test] +fn automorphism() { + (1..4).for_each(|rank| { + println!("test automorphism rank: {}", rank); + test_automorphism(-5, 12, 15, 60, rank, 3.2); + }); +} + fn test_automorphism(p: i64, log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) { let module: Module = Module::::new(1 << log_n); let rows: usize = (k + basek - 1) / basek; let mut ct_in: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank); let mut ct_out: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank); - let mut tsk: TensorKey, FFT64> = TensorKey::new(&module, basek, k, rows, rank); - let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, k, rows, rank, rank); + let mut tensor_key: TensorKey, FFT64> = TensorKey::new(&module, basek, k, rows, rank); + let mut auto_key: AutomorphismKey, FFT64> = AutomorphismKey::new(&module, basek, k, rows, rank); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k); let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); @@ -311,44 +319,38 @@ fn test_automorphism(p: i64, log_n: usize, basek: usize, k: usize, rank: usize, let mut scratch: ScratchOwned = ScratchOwned::new( GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_in.size()) | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_out.size()) - | GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ksk.size()) - | TensorKey::encrypt_sk_scratch_space(&module, rank, ksk.size()) - | GGSWCiphertext::keyswitch_scratch_space( + | AutomorphismKey::encrypt_sk_scratch_space(&module, rank, auto_key.size()) + | TensorKey::encrypt_sk_scratch_space(&module, rank, tensor_key.size()) + | GGSWCiphertext::automorphism_scratch_space( &module, ct_out.size(), ct_in.size(), - ksk.size(), - tsk.size(), + auto_key.size(), + tensor_key.size(), rank, ), ); let var_xs: f64 = 0.5; - let mut sk_in: SecretKey> = SecretKey::new(&module, rank); - sk_in.fill_ternary_prob(var_xs, &mut source_xs); + let mut sk: SecretKey> = SecretKey::new(&module, rank); + sk.fill_ternary_prob(var_xs, &mut source_xs); - let mut sk_in_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk_in_dft.dft(&module, &sk_in); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_dft.dft(&module, &sk); - let mut sk_out: SecretKey> = SecretKey::new(&module, rank); - sk_out.fill_ternary_prob(var_xs, &mut source_xs); - - let mut sk_out_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); - sk_out_dft.dft(&module, &sk_out); - - ksk.encrypt_sk( + auto_key.encrypt_sk( &module, - &sk_in, - &sk_out_dft, + p, + &sk, &mut source_xa, &mut source_xe, sigma, scratch.borrow(), ); - tsk.encrypt_sk( + tensor_key.encrypt_sk( &module, - &sk_out_dft, + &sk_dft, &mut source_xa, &mut source_xe, sigma, @@ -360,14 +362,16 @@ fn test_automorphism(p: i64, log_n: usize, basek: usize, k: usize, rank: usize, ct_in.encrypt_sk( &module, &pt_scalar, - &sk_in_dft, + &sk_dft, &mut source_xa, &mut source_xe, sigma, scratch.borrow(), ); - ct_out.keyswitch(&module, &ct_in, &ksk, &tsk, scratch.borrow()); + ct_out.automorphism(&module, &ct_in, &auto_key, &tensor_key, scratch.borrow()); + + module.scalar_znx_automorphism_inplace(p, &mut pt_scalar, 0); let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k, rank); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_out.size()); @@ -380,14 +384,14 @@ fn test_automorphism(p: i64, log_n: usize, basek: usize, k: usize, rank: usize, // mul with sk[col_j-1] if col_j > 0 { module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); - module.svp_apply_inplace(&mut pt_dft, 0, &sk_out_dft, col_j - 1); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1); module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); } ct_out.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); - ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow()); + ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);