diff --git a/backend/src/vec_znx.rs b/backend/src/vec_znx.rs index b397adb..a4de563 100644 --- a/backend/src/vec_znx.rs +++ b/backend/src/vec_znx.rs @@ -71,7 +71,7 @@ impl> ZnxView for VecZnx { type Scalar = i64; } -impl> VecZnx { +impl VecZnx> { pub fn rsh_scratch_space(n: usize) -> usize { n * std::mem::size_of::() } diff --git a/core/src/glwe_ciphertext.rs b/core/src/glwe_ciphertext.rs index 4205637..9c3f32e 100644 --- a/core/src/glwe_ciphertext.rs +++ b/core/src/glwe_ciphertext.rs @@ -334,7 +334,7 @@ where VecZnx: VecZnxToRef, MatZnxDft: MatZnxDftToRef, { - Self::keyswitch_private(self, true, rhs.p(), module, lhs, &rhs.key, scratch); + Self::keyswitch_private::<_, _, 1>(self, rhs.p(), module, lhs, &rhs.key, scratch); } pub fn automorphism_add_inplace( @@ -347,7 +347,61 @@ where { unsafe { let self_ptr: *mut GLWECiphertext = self as *mut GLWECiphertext; - Self::keyswitch_private(self, true, rhs.p(), module, &*self_ptr, &rhs.key, scratch); + Self::keyswitch_private::<_, _, 1>(self, rhs.p(), module, &*self_ptr, &rhs.key, scratch); + } + } + + pub fn automorphism_sub_ab( + &mut self, + module: &Module, + lhs: &GLWECiphertext, + rhs: &AutomorphismKey, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToRef, + MatZnxDft: MatZnxDftToRef, + { + Self::keyswitch_private::<_, _, 2>(self, rhs.p(), module, lhs, &rhs.key, scratch); + } + + pub fn automorphism_sub_ab_inplace( + &mut self, + module: &Module, + rhs: &AutomorphismKey, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + unsafe { + let self_ptr: *mut GLWECiphertext = self as *mut GLWECiphertext; + Self::keyswitch_private::<_, _, 2>(self, rhs.p(), module, &*self_ptr, &rhs.key, scratch); + } + } + + pub fn automorphism_sub_ba( + &mut self, + module: &Module, + lhs: &GLWECiphertext, + rhs: &AutomorphismKey, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToRef, + MatZnxDft: MatZnxDftToRef, + { + Self::keyswitch_private::<_, _, 3>(self, rhs.p(), module, lhs, &rhs.key, scratch); + } + + pub fn automorphism_sub_ba_inplace( + &mut self, + module: &Module, + rhs: &AutomorphismKey, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + unsafe { + let self_ptr: *mut GLWECiphertext = self as *mut GLWECiphertext; + Self::keyswitch_private::<_, _, 3>(self, rhs.p(), module, &*self_ptr, &rhs.key, scratch); } } @@ -420,12 +474,11 @@ where VecZnx: VecZnxToRef, MatZnxDft: MatZnxDftToRef, { - Self::keyswitch_private(self, false, 0, module, lhs, rhs, scratch); + Self::keyswitch_private::<_, _, 0>(self, 0, module, lhs, rhs, scratch); } - pub(crate) fn keyswitch_private( + pub(crate) fn keyswitch_private( &mut self, - add_self: bool, apply_auto: i64, module: &Module, lhs: &GLWECiphertext, @@ -481,8 +534,11 @@ where module.vec_znx_big_automorphism_inplace(apply_auto, &mut res_big, i); } - if add_self { - module.vec_znx_big_add_small_inplace(&mut res_big, i, lhs, i); + match OP{ + 1=> module.vec_znx_big_add_small_inplace(&mut res_big, i, lhs, i), + 2=> module.vec_znx_big_sub_small_a_inplace(&mut res_big, i, lhs, i), + 3=> module.vec_znx_big_sub_small_b_inplace(&mut res_big, i, lhs, i), + _=>{}, } module.vec_znx_big_normalize(basek, self, i, &res_big, i, scratch1); }); diff --git a/core/src/glwe_ops.rs b/core/src/glwe_ops.rs index 95357ed..61f9ad2 100644 --- a/core/src/glwe_ops.rs +++ b/core/src/glwe_ops.rs @@ -175,10 +175,7 @@ where self.set_k(a.k()); } - pub fn rotate_inplace(&mut self, module: &Module, k: i64) - where - A: VecZnxToRef + Infos, - { + pub fn rotate_inplace(&mut self, module: &Module, k: i64){ #[cfg(debug_assertions)] { assert_eq!(self.n(), module.n()); @@ -242,3 +239,9 @@ where }); } } + +impl GLWECiphertext>{ + pub fn rsh_scratch_space(module: &Module) -> usize{ + VecZnx::rsh_scratch_space(module.n()) + } +} \ No newline at end of file diff --git a/core/src/tensor_key.rs b/core/src/tensor_key.rs index 985f90d..105dace 100644 --- a/core/src/tensor_key.rs +++ b/core/src/tensor_key.rs @@ -56,7 +56,7 @@ impl TensorKey { } impl TensorKey, FFT64> { - pub fn encrypt_sk_scratch_space(module: &Module, rank: usize, size: usize) -> usize { + pub fn generate_from_sk_scratch_space(module: &Module, rank: usize, size: usize) -> usize { module.bytes_of_scalar_znx_dft(1) + GLWESwitchingKey::encrypt_sk_scratch_space(module, rank, size) } } @@ -65,7 +65,7 @@ impl TensorKey where MatZnxDft: MatZnxDftToMut, { - pub fn encrypt_sk( + pub fn generate_from_sk( &mut self, module: &Module, sk_dft: &SecretKeyFourier, diff --git a/core/src/test_fft64/ggsw.rs b/core/src/test_fft64/ggsw.rs index 45855a5..5a244ba 100644 --- a/core/src/test_fft64/ggsw.rs +++ b/core/src/test_fft64/ggsw.rs @@ -160,7 +160,7 @@ fn test_keyswitch(log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_in.size()) | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_out.size()) | GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ksk.size()) - | TensorKey::encrypt_sk_scratch_space(&module, rank, ksk.size()) + | TensorKey::generate_from_sk_scratch_space(&module, rank, ksk.size()) | GGSWCiphertext::keyswitch_scratch_space( &module, ct_out.size(), @@ -194,7 +194,7 @@ fn test_keyswitch(log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) sigma, scratch.borrow(), ); - tsk.encrypt_sk( + tsk.generate_from_sk( &module, &sk_out_dft, &mut source_xa, @@ -286,7 +286,7 @@ fn test_keyswitch_inplace(log_n: usize, basek: usize, k: usize, rank: usize, sig GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct.size()) | GLWECiphertextFourier::decrypt_scratch_space(&module, ct.size()) | GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ksk.size()) - | TensorKey::encrypt_sk_scratch_space(&module, rank, ksk.size()) + | TensorKey::generate_from_sk_scratch_space(&module, rank, ksk.size()) | GGSWCiphertext::keyswitch_inplace_scratch_space(&module, ct.size(), ksk.size(), tsk.size(), rank), ); @@ -313,7 +313,7 @@ fn test_keyswitch_inplace(log_n: usize, basek: usize, k: usize, rank: usize, sig sigma, scratch.borrow(), ); - tsk.encrypt_sk( + tsk.generate_from_sk( &module, &sk_out_dft, &mut source_xa, @@ -455,7 +455,7 @@ fn test_automorphism(p: i64, log_n: usize, basek: usize, k: usize, rank: usize, GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_in.size()) | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_out.size()) | AutomorphismKey::generate_from_sk_scratch_space(&module, rank, auto_key.size()) - | TensorKey::encrypt_sk_scratch_space(&module, rank, tensor_key.size()) + | TensorKey::generate_from_sk_scratch_space(&module, rank, tensor_key.size()) | GGSWCiphertext::automorphism_scratch_space( &module, ct_out.size(), @@ -483,7 +483,7 @@ fn test_automorphism(p: i64, log_n: usize, basek: usize, k: usize, rank: usize, sigma, scratch.borrow(), ); - tensor_key.encrypt_sk( + tensor_key.generate_from_sk( &module, &sk_dft, &mut source_xa, @@ -575,7 +575,7 @@ fn test_automorphism_inplace(p: i64, log_n: usize, basek: usize, k: usize, rank: GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct.size()) | GLWECiphertextFourier::decrypt_scratch_space(&module, ct.size()) | AutomorphismKey::generate_from_sk_scratch_space(&module, rank, auto_key.size()) - | TensorKey::encrypt_sk_scratch_space(&module, rank, tensor_key.size()) + | TensorKey::generate_from_sk_scratch_space(&module, rank, tensor_key.size()) | GGSWCiphertext::automorphism_inplace_scratch_space(&module, ct.size(), auto_key.size(), tensor_key.size(), rank), ); @@ -596,7 +596,7 @@ fn test_automorphism_inplace(p: i64, log_n: usize, basek: usize, k: usize, rank: sigma, scratch.borrow(), ); - tensor_key.encrypt_sk( + tensor_key.generate_from_sk( &module, &sk_dft, &mut source_xa, diff --git a/core/src/test_fft64/tensor_key.rs b/core/src/test_fft64/tensor_key.rs index a897253..2a3d40d 100644 --- a/core/src/test_fft64/tensor_key.rs +++ b/core/src/test_fft64/tensor_key.rs @@ -30,7 +30,7 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k: usize, sigma: f64, rank: usize let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::new(TensorKey::encrypt_sk_scratch_space( + let mut scratch: ScratchOwned = ScratchOwned::new(TensorKey::generate_from_sk_scratch_space( &module, rank, tensor_key.size(), @@ -40,7 +40,7 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k: usize, sigma: f64, rank: usize sk.fill_ternary_prob(0.5, &mut source_xs); let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::alloc(&module, rank); - sk_dft.dft(&module, &sk); + sk_dft.dft(generate_from_sksk); tensor_key.encrypt_sk( &module,