From fef2a2fc2736f24ab10c53547606a6c0afe4761a Mon Sep 17 00:00:00 2001 From: Pro7ech Date: Tue, 21 Oct 2025 10:47:46 +0200 Subject: [PATCH] fixed scratch API --- poulpy-core/src/conversion/gglwe_to_ggsw.rs | 2 +- poulpy-core/src/conversion/glwe_to_lwe.rs | 2 +- poulpy-core/src/conversion/lwe_to_glwe.rs | 17 ++--- .../src/encryption/compressed/gglwe.rs | 2 +- poulpy-core/src/encryption/compressed/ggsw.rs | 4 +- .../compressed/glwe_automorphism_key.rs | 4 +- .../compressed/glwe_switching_key.rs | 4 +- .../encryption/compressed/glwe_tensor_key.rs | 4 +- poulpy-core/src/encryption/gglwe.rs | 3 +- poulpy-core/src/encryption/ggsw.rs | 2 +- poulpy-core/src/encryption/glwe.rs | 6 +- .../src/encryption/glwe_automorphism_key.rs | 8 +-- .../src/encryption/glwe_switching_key.rs | 4 +- poulpy-core/src/encryption/glwe_tensor_key.rs | 10 +-- .../encryption/glwe_to_lwe_switching_key.rs | 2 +- .../src/encryption/lwe_switching_key.rs | 4 +- .../encryption/lwe_to_glwe_switching_key.rs | 2 +- poulpy-core/src/external_product/glwe.rs | 4 +- poulpy-core/src/glwe_packing.rs | 8 +-- poulpy-core/src/glwe_trace.rs | 15 ++--- poulpy-core/src/keyswitching/glwe.rs | 4 +- poulpy-core/src/keyswitching/lwe.rs | 30 ++++----- .../src/layouts/prepared/glwe_public_key.rs | 3 +- .../src/layouts/prepared/glwe_secret.rs | 5 +- poulpy-core/src/scratch.rs | 66 +++++++------------ .../tests/test_suite/encryption/gglwe_ct.rs | 3 +- poulpy-hal/src/api/scratch.rs | 34 ++++------ .../ciphertexts/block_prepared.rs | 13 ++-- 28 files changed, 112 insertions(+), 153 deletions(-) diff --git a/poulpy-core/src/conversion/gglwe_to_ggsw.rs b/poulpy-core/src/conversion/gglwe_to_ggsw.rs index 7ad73e6..87b6791 100644 --- a/poulpy-core/src/conversion/gglwe_to_ggsw.rs +++ b/poulpy-core/src/conversion/gglwe_to_ggsw.rs @@ -186,7 +186,7 @@ where self.vec_znx_dft_apply(1, 0, &mut ci_dft, i, a, i); } } else { - let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(self, 1, a_size); + let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(self.n(), 1, a_size); for i in 0..cols { self.vec_znx_normalize(basek_tsk, &mut a_conv, 0, basek_in, a, i, scratch_2); self.vec_znx_dft_apply(1, 0, &mut ci_dft, i, &a_conv, 0); diff --git a/poulpy-core/src/conversion/glwe_to_lwe.rs b/poulpy-core/src/conversion/glwe_to_lwe.rs index 461d07d..fbf5912 100644 --- a/poulpy-core/src/conversion/glwe_to_lwe.rs +++ b/poulpy-core/src/conversion/glwe_to_lwe.rs @@ -85,7 +85,7 @@ where rank: Rank(1), }; - let (mut tmp_glwe, scratch_1) = scratch.take_glwe(self, &glwe_layout); + let (mut tmp_glwe, scratch_1) = scratch.take_glwe(&glwe_layout); self.glwe_keyswitch(&mut tmp_glwe, a, key, scratch_1); self.lwe_sample_extract(res, &tmp_glwe); } diff --git a/poulpy-core/src/conversion/lwe_to_glwe.rs b/poulpy-core/src/conversion/lwe_to_glwe.rs index 90c71a2..c759ee5 100644 --- a/poulpy-core/src/conversion/lwe_to_glwe.rs +++ b/poulpy-core/src/conversion/lwe_to_glwe.rs @@ -50,15 +50,12 @@ where assert_eq!(ksk.n(), self.n() as u32); assert!(lwe.n() <= self.n() as u32); - let (mut glwe, scratch_1) = scratch.take_glwe( - self, - &GLWELayout { - n: ksk.n(), - base2k: ksk.base2k(), - k: lwe.k(), - rank: 1u32.into(), - }, - ); + let (mut glwe, scratch_1) = scratch.take_glwe(&GLWELayout { + n: ksk.n(), + base2k: ksk.base2k(), + k: lwe.k(), + rank: 1u32.into(), + }); glwe.data.zero(); let n_lwe: usize = lwe.n().into(); @@ -70,7 +67,7 @@ where glwe.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]); } } else { - let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(self, 1, lwe.size()); + let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(self.n(), 1, lwe.size()); a_conv.zero(); for j in 0..lwe.size() { let data_lwe: &[i64] = lwe.data.at(0, j); diff --git a/poulpy-core/src/encryption/compressed/gglwe.rs b/poulpy-core/src/encryption/compressed/gglwe.rs index 1f502c5..1dfbf58 100644 --- a/poulpy-core/src/encryption/compressed/gglwe.rs +++ b/poulpy-core/src/encryption/compressed/gglwe.rs @@ -142,7 +142,7 @@ where let mut source_xa = Source::new(seed); - let (mut tmp_pt, scrach_1) = scratch.take_glwe_plaintext(self, res); + let (mut tmp_pt, scrach_1) = scratch.take_glwe_plaintext(res); for col_i in 0..rank_in { for d_i in 0..dnum { // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt diff --git a/poulpy-core/src/encryption/compressed/ggsw.rs b/poulpy-core/src/encryption/compressed/ggsw.rs index e1695b8..14b0de5 100644 --- a/poulpy-core/src/encryption/compressed/ggsw.rs +++ b/poulpy-core/src/encryption/compressed/ggsw.rs @@ -8,7 +8,7 @@ use crate::{ ScratchTakeCore, encryption::{GGSWEncryptSk, GLWEEncryptSkInternal, SIGMA}, layouts::{ - GGSWCompressedSeedMut, GGSWInfos, GLWEInfos, LWEInfos, + GGSWCompressedSeedMut, GGSWInfos, LWEInfos, compressed::{GGSWCompressed, GGSWCompressedToMut}, prepared::{GLWESecretPrepared, GLWESecretPreparedToRef}, }, @@ -107,7 +107,7 @@ where println!("res.seed: {:?}", res.seed); - let (mut tmp_pt, scratch_1) = scratch.take_glwe_plaintext(self, &res.glwe_layout()); + let (mut tmp_pt, scratch_1) = scratch.take_glwe_plaintext(res); let mut source = Source::new(seed_xa); diff --git a/poulpy-core/src/encryption/compressed/glwe_automorphism_key.rs b/poulpy-core/src/encryption/compressed/glwe_automorphism_key.rs index 9c9046d..df57f0f 100644 --- a/poulpy-core/src/encryption/compressed/glwe_automorphism_key.rs +++ b/poulpy-core/src/encryption/compressed/glwe_automorphism_key.rs @@ -98,8 +98,8 @@ where let (mut sk_out_prepared, scratch_1) = scratch.take_glwe_secret_prepared(self, sk.rank()); { - let (mut sk_out, _) = scratch_1.take_glwe_secret(self, sk.rank()); - for i in 0..res.rank_out().into() { + let (mut sk_out, _) = scratch_1.take_glwe_secret(self.n().into(), sk.rank()); + for i in 0..sk.rank().into() { self.vec_znx_automorphism( self.galois_element_inv(p), &mut sk_out.data.as_vec_znx_mut(), diff --git a/poulpy-core/src/encryption/compressed/glwe_switching_key.rs b/poulpy-core/src/encryption/compressed/glwe_switching_key.rs index c492cfd..e1c40ea 100644 --- a/poulpy-core/src/encryption/compressed/glwe_switching_key.rs +++ b/poulpy-core/src/encryption/compressed/glwe_switching_key.rs @@ -102,7 +102,7 @@ where self.gglwe_compressed_encrypt_sk_tmp_bytes(res) ); - let (mut sk_in_tmp, scratch_1) = scratch.take_scalar_znx(self, sk_in.rank().into()); + let (mut sk_in_tmp, scratch_1) = scratch.take_scalar_znx(self.n(), sk_in.rank().into()); for i in 0..sk_in.rank().into() { self.vec_znx_switch_ring( &mut sk_in_tmp.as_vec_znx_mut(), @@ -114,7 +114,7 @@ where let (mut sk_out_tmp, scratch_2) = scratch_1.take_glwe_secret_prepared(self, sk_out.rank()); { - let (mut tmp, _) = scratch_2.take_scalar_znx(self, 1); + let (mut tmp, _) = scratch_2.take_scalar_znx(self.n(), 1); for i in 0..sk_out.rank().into() { self.vec_znx_switch_ring(&mut tmp.as_vec_znx_mut(), 0, &sk_out.data.as_vec_znx(), i); self.svp_prepare(&mut sk_out_tmp.data, i, &tmp, 0); diff --git a/poulpy-core/src/encryption/compressed/glwe_tensor_key.rs b/poulpy-core/src/encryption/compressed/glwe_tensor_key.rs index 9456e89..b50b47d 100644 --- a/poulpy-core/src/encryption/compressed/glwe_tensor_key.rs +++ b/poulpy-core/src/encryption/compressed/glwe_tensor_key.rs @@ -99,7 +99,7 @@ where R: GGLWEInfos + TensorKeyCompressedAtMut, S: GLWESecretToRef + GetDistribution, { - let (mut sk_dft_prep, scratch_1) = scratch.take_glwe_secret_prepared(self, res.rank_out()); + let (mut sk_dft_prep, scratch_1) = scratch.take_glwe_secret_prepared(self, res.rank()); sk_dft_prep.prepare(self, sk); let sk: &GLWESecret<&[u8]> = &sk.to_ref(); @@ -120,7 +120,7 @@ where } let (mut sk_ij_big, scratch_3) = scratch_2.take_vec_znx_big(self, 1, 1); - let (mut sk_ij, scratch_4) = scratch_3.take_glwe_secret(self, Rank(1)); + let (mut sk_ij, scratch_4) = scratch_3.take_glwe_secret(self.n().into(), Rank(1)); let (mut sk_ij_dft, scratch_5) = scratch_4.take_vec_znx_dft(self, 1, 1); let mut source_xa: Source = Source::new(seed_xa); diff --git a/poulpy-core/src/encryption/gglwe.rs b/poulpy-core/src/encryption/gglwe.rs index 01f1bc3..ba78cde 100644 --- a/poulpy-core/src/encryption/gglwe.rs +++ b/poulpy-core/src/encryption/gglwe.rs @@ -6,7 +6,6 @@ use poulpy_hal::{ use crate::{ GLWEEncryptSk, ScratchTakeCore, - layouts::GLWEInfos, layouts::{ GGLWE, GGLWEInfos, GGLWEToMut, GLWEPlaintext, LWEInfos, prepared::{GLWESecretPrepared, GLWESecretPreparedToRef}, @@ -143,7 +142,7 @@ where let base2k: usize = res.base2k().into(); let rank_in: usize = res.rank_in().into(); - let (mut tmp_pt, scrach_1) = scratch.take_glwe_plaintext(self, &res.glwe_layout()); + let (mut tmp_pt, scrach_1) = scratch.take_glwe_plaintext(res); // For each input column (i.e. rank) produces a GGLWE of rank_out+1 columns // // Example for ksk rank 2 to rank 3: diff --git a/poulpy-core/src/encryption/ggsw.rs b/poulpy-core/src/encryption/ggsw.rs index c3bef71..85b8be5 100644 --- a/poulpy-core/src/encryption/ggsw.rs +++ b/poulpy-core/src/encryption/ggsw.rs @@ -109,7 +109,7 @@ where let dsize: usize = res.dsize().into(); let cols: usize = (rank + 1).into(); - let (mut tmp_pt, scratch_1) = scratch.take_glwe_plaintext(self, res); + let (mut tmp_pt, scratch_1) = scratch.take_glwe_plaintext(res); for row_i in 0..res.dnum().into() { tmp_pt.data.zero(); diff --git a/poulpy-core/src/encryption/glwe.rs b/poulpy-core/src/encryption/glwe.rs index d4aca7d..fc59e4b 100644 --- a/poulpy-core/src/encryption/glwe.rs +++ b/poulpy-core/src/encryption/glwe.rs @@ -372,7 +372,7 @@ where let (mut u_dft, scratch_1) = scratch.take_svp_ppol(self, 1); { - let (mut u, _) = scratch_1.take_scalar_znx(self, 1); + let (mut u, _) = scratch_1.take_scalar_znx(self.n(), 1); match pk.dist() { Distribution::NONE => panic!( "invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through \ @@ -499,11 +499,11 @@ where let size: usize = ct.size(); - let (mut c0, scratch_1) = scratch.take_vec_znx(self, 1, size); + let (mut c0, scratch_1) = scratch.take_vec_znx(self.n(), 1, size); c0.zero(); { - let (mut ci, scratch_2) = scratch_1.take_vec_znx(self, 1, size); + let (mut ci, scratch_2) = scratch_1.take_vec_znx(self.n(), 1, size); // ct[i] = uniform // ct[0] -= c[i] * s[i], diff --git a/poulpy-core/src/encryption/glwe_automorphism_key.rs b/poulpy-core/src/encryption/glwe_automorphism_key.rs index 68782e3..cee3163 100644 --- a/poulpy-core/src/encryption/glwe_automorphism_key.rs +++ b/poulpy-core/src/encryption/glwe_automorphism_key.rs @@ -7,8 +7,8 @@ use poulpy_hal::{ use crate::{ GGLWEEncryptSk, ScratchTakeCore, layouts::{ - AutomorphismKey, GGLWEInfos, GGLWEToMut, GGLWEToRef, GLWEInfos, GLWESecret, GLWESecretPrepared, GLWESecretPreparedFactory, - GLWESecretToRef, LWEInfos, SetGaloisElement, + AutomorphismKey, GGLWEInfos, GGLWEToMut, GGLWEToRef, GLWEInfos, GLWESecret, GLWESecretPrepared, + GLWESecretPreparedFactory, GLWESecretToRef, LWEInfos, SetGaloisElement, }, }; @@ -115,8 +115,8 @@ where let (mut sk_out_prepared, scratch_1) = scratch.take_glwe_secret_prepared(self, sk.rank()); { - let (mut sk_out, _) = scratch_1.take_glwe_secret(self, sk.rank()); - for i in 0..res.rank().into() { + let (mut sk_out, _) = scratch_1.take_glwe_secret(sk.n(), sk.rank()); + for i in 0..sk.rank().into() { self.vec_znx_automorphism( self.galois_element_inv(p), &mut sk_out.data.as_vec_znx_mut(), diff --git a/poulpy-core/src/encryption/glwe_switching_key.rs b/poulpy-core/src/encryption/glwe_switching_key.rs index fc19a71..62d60c0 100644 --- a/poulpy-core/src/encryption/glwe_switching_key.rs +++ b/poulpy-core/src/encryption/glwe_switching_key.rs @@ -109,7 +109,7 @@ where self.glwe_switching_key_encrypt_sk_tmp_bytes(res) ); - let (mut sk_in_tmp, scratch_1) = scratch.take_scalar_znx(self, sk_in.rank().into()); + let (mut sk_in_tmp, scratch_1) = scratch.take_scalar_znx(self.n(), sk_in.rank().into()); for i in 0..sk_in.rank().into() { self.vec_znx_switch_ring( &mut sk_in_tmp.as_vec_znx_mut(), @@ -121,7 +121,7 @@ where let (mut sk_out_tmp, scratch_2) = scratch_1.take_glwe_secret_prepared(self, sk_out.rank()); { - let (mut tmp, _) = scratch_2.take_scalar_znx(self, 1); + let (mut tmp, _) = scratch_2.take_scalar_znx(self.n(), 1); for i in 0..sk_out.rank().into() { self.vec_znx_switch_ring(&mut tmp.as_vec_znx_mut(), 0, &sk_out.data.as_vec_znx(), i); self.svp_prepare(&mut sk_out_tmp.data, i, &tmp, 0); diff --git a/poulpy-core/src/encryption/glwe_tensor_key.rs b/poulpy-core/src/encryption/glwe_tensor_key.rs index 758afcc..e151b8e 100644 --- a/poulpy-core/src/encryption/glwe_tensor_key.rs +++ b/poulpy-core/src/encryption/glwe_tensor_key.rs @@ -35,7 +35,7 @@ impl TensorKey { scratch: &mut Scratch, ) where M: TensorKeyEncryptSk, - S: GLWESecretToRef + GetDistribution, + S: GLWESecretToRef + GetDistribution + GLWEInfos, Scratch: ScratchTakeCore, { module.tensor_key_encrypt_sk(self, sk, source_xa, source_xe, scratch); @@ -56,7 +56,7 @@ pub trait TensorKeyEncryptSk { scratch: &mut Scratch, ) where R: TensorKeyToMut, - S: GLWESecretToRef + GetDistribution; + S: GLWESecretToRef + GetDistribution + GLWEInfos; } impl TensorKeyEncryptSk for Module @@ -93,14 +93,14 @@ where scratch: &mut Scratch, ) where R: TensorKeyToMut, - S: GLWESecretToRef + GetDistribution, + S: GLWESecretToRef + GetDistribution + GLWEInfos, { let res: &mut TensorKey<&mut [u8]> = &mut res.to_mut(); // let n: RingDegree = sk.n(); let rank: Rank = res.rank_out(); - let (mut sk_prepared, scratch_1) = scratch.take_glwe_secret_prepared(self, rank); + let (mut sk_prepared, scratch_1) = scratch.take_glwe_secret_prepared(self, sk.rank()); sk_prepared.prepare(self, sk); let sk: &GLWESecret<&[u8]> = &sk.to_ref(); @@ -115,7 +115,7 @@ where }); let (mut sk_ij_big, scratch_3) = scratch_2.take_vec_znx_big(self, 1, 1); - let (mut sk_ij, scratch_4) = scratch_3.take_glwe_secret(self, Rank(1)); + let (mut sk_ij, scratch_4) = scratch_3.take_glwe_secret(self.n().into(), Rank(1)); let (mut sk_ij_dft, scratch_5) = scratch_4.take_vec_znx_dft(self, 1, 1); (0..rank.into()).for_each(|i| { diff --git a/poulpy-core/src/encryption/glwe_to_lwe_switching_key.rs b/poulpy-core/src/encryption/glwe_to_lwe_switching_key.rs index 36ecff2..30a46a8 100644 --- a/poulpy-core/src/encryption/glwe_to_lwe_switching_key.rs +++ b/poulpy-core/src/encryption/glwe_to_lwe_switching_key.rs @@ -100,7 +100,7 @@ where let (mut sk_lwe_as_glwe_prep, scratch_1) = scratch.take_glwe_secret_prepared(self, Rank(1)); { - let (mut sk_lwe_as_glwe, scratch_2) = scratch_1.take_glwe_secret(self, Rank(1)); + let (mut sk_lwe_as_glwe, scratch_2) = scratch_1.take_glwe_secret(self.n().into(), sk_lwe_as_glwe_prep.rank()); sk_lwe_as_glwe.data.zero(); sk_lwe_as_glwe.data.at_mut(0, 0)[..sk_lwe.n().into()].copy_from_slice(sk_lwe.data.at(0, 0)); self.vec_znx_automorphism_inplace(-1, &mut sk_lwe_as_glwe.data.as_vec_znx_mut(), 0, scratch_2); diff --git a/poulpy-core/src/encryption/lwe_switching_key.rs b/poulpy-core/src/encryption/lwe_switching_key.rs index daee764..859f0c9 100644 --- a/poulpy-core/src/encryption/lwe_switching_key.rs +++ b/poulpy-core/src/encryption/lwe_switching_key.rs @@ -111,8 +111,8 @@ where assert!(sk_lwe_out.n().0 <= res.n().0); assert!(res.n() <= self.n() as u32); - let (mut sk_in_glwe, scratch_1) = scratch.take_glwe_secret(self, Rank(1)); - let (mut sk_out_glwe, scratch_2) = scratch_1.take_glwe_secret(self, Rank(1)); + let (mut sk_in_glwe, scratch_1) = scratch.take_glwe_secret(self.n().into(), Rank(1)); + let (mut sk_out_glwe, scratch_2) = scratch_1.take_glwe_secret(self.n().into(), Rank(1)); sk_out_glwe.data.at_mut(0, 0)[..sk_lwe_out.n().into()].copy_from_slice(sk_lwe_out.data.at(0, 0)); sk_out_glwe.data.at_mut(0, 0)[sk_lwe_out.n().into()..].fill(0); diff --git a/poulpy-core/src/encryption/lwe_to_glwe_switching_key.rs b/poulpy-core/src/encryption/lwe_to_glwe_switching_key.rs index 588875d..524bb1a 100644 --- a/poulpy-core/src/encryption/lwe_to_glwe_switching_key.rs +++ b/poulpy-core/src/encryption/lwe_to_glwe_switching_key.rs @@ -99,7 +99,7 @@ where assert!(sk_lwe.n().0 <= self.n() as u32); - let (mut sk_lwe_as_glwe, scratch_1) = scratch.take_glwe_secret(self, Rank(1)); + let (mut sk_lwe_as_glwe, scratch_1) = scratch.take_glwe_secret(self.n().into(), Rank(1)); sk_lwe_as_glwe.data.at_mut(0, 0)[..sk_lwe.n().into()].copy_from_slice(sk_lwe.data.at(0, 0)); sk_lwe_as_glwe.data.at_mut(0, 0)[sk_lwe.n().into()..].fill(0); self.vec_znx_automorphism_inplace(-1, &mut sk_lwe_as_glwe.data.as_vec_znx_mut(), 0, scratch_1); diff --git a/poulpy-core/src/external_product/glwe.rs b/poulpy-core/src/external_product/glwe.rs index ab9968c..a1cd5ee 100644 --- a/poulpy-core/src/external_product/glwe.rs +++ b/poulpy-core/src/external_product/glwe.rs @@ -147,7 +147,7 @@ where } } } else { - let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(self, cols, a_size); + let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(self.n(), cols, a_size); for j in 0..cols { self.vec_znx_normalize( @@ -262,7 +262,7 @@ where } } } else { - let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(self, cols, a_size); + let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(self.n(), cols, a_size); for j in 0..cols { self.vec_znx_normalize( diff --git a/poulpy-core/src/glwe_packing.rs b/poulpy-core/src/glwe_packing.rs index 1982269..09540b2 100644 --- a/poulpy-core/src/glwe_packing.rs +++ b/poulpy-core/src/glwe_packing.rs @@ -331,7 +331,7 @@ fn combine( // since 2*(I(X) * Q/2) = I(X) * Q = 0 mod Q. if acc.value { if let Some(b) = b { - let (mut tmp_b, scratch_1) = scratch.take_glwe(module, a); + let (mut tmp_b, scratch_1) = scratch.take_glwe(a); // a = a * X^-t module.glwe_rotate_inplace(-t, a, scratch_1); @@ -371,7 +371,7 @@ fn combine( } } } else if let Some(b) = b { - let (mut tmp_b, scratch_1) = scratch.take_glwe(module, a); + let (mut tmp_b, scratch_1) = scratch.take_glwe(a); module.glwe_rotate(t, &mut tmp_b, b); module.glwe_rsh(1, &mut tmp_b, scratch_1); @@ -415,7 +415,7 @@ fn pack_internal( let t: i64 = 1 << (a.n().log2() - i - 1); if let Some(b) = b.as_deref_mut() { - let (mut tmp_b, scratch_1) = scratch.take_glwe(module, a); + let (mut tmp_b, scratch_1) = scratch.take_glwe(a); // a = a * X^-t module.glwe_rotate_inplace(-t, a, scratch_1); @@ -449,7 +449,7 @@ fn pack_internal( } else if let Some(b) = b.as_deref_mut() { let t: i64 = 1 << (b.n().log2() - i - 1); - let (mut tmp_b, scratch_1) = scratch.take_glwe(module, b); + let (mut tmp_b, scratch_1) = scratch.take_glwe(b); module.glwe_rotate(t, &mut tmp_b, b); module.glwe_rsh(1, &mut tmp_b, scratch_1); diff --git a/poulpy-core/src/glwe_trace.rs b/poulpy-core/src/glwe_trace.rs index 421692f..80cf84c 100644 --- a/poulpy-core/src/glwe_trace.rs +++ b/poulpy-core/src/glwe_trace.rs @@ -140,15 +140,12 @@ where } if res.base2k() != basek_ksk { - let (mut self_conv, scratch_1) = scratch.take_glwe( - self, - &GLWELayout { - n: self.n().into(), - base2k: basek_ksk, - k: res.k(), - rank: res.rank(), - }, - ); + let (mut self_conv, scratch_1) = scratch.take_glwe(&GLWELayout { + n: self.n().into(), + base2k: basek_ksk, + k: res.k(), + rank: res.rank(), + }); for j in 0..(res.rank() + 1).into() { self.vec_znx_normalize( diff --git a/poulpy-core/src/keyswitching/glwe.rs b/poulpy-core/src/keyswitching/glwe.rs index be81ebf..a021777 100644 --- a/poulpy-core/src/keyswitching/glwe.rs +++ b/poulpy-core/src/keyswitching/glwe.rs @@ -278,7 +278,7 @@ where module.vec_znx_dft_apply(1, 0, &mut ai_dft, col_i, a.data(), col_i + 1); }); } else { - let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(module, 1, a_size); + let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(module.n(), 1, a_size); (0..cols - 1).for_each(|col_i| { module.vec_znx_normalize( base2k_out, @@ -324,7 +324,7 @@ where } } } else { - let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(module, cols - 1, a_size); + let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(module.n(), cols - 1, a_size); for j in 0..cols - 1 { module.vec_znx_normalize( base2k_out, diff --git a/poulpy-core/src/keyswitching/lwe.rs b/poulpy-core/src/keyswitching/lwe.rs index 249ccca..bf5abf5 100644 --- a/poulpy-core/src/keyswitching/lwe.rs +++ b/poulpy-core/src/keyswitching/lwe.rs @@ -87,26 +87,20 @@ where let a_size: usize = a.k().div_ceil(ksk.base2k()) as usize; - let (mut glwe_in, scratch_1) = scratch.take_glwe( - self, - &GLWELayout { - n: ksk.n(), - base2k: a.base2k(), - k: max_k, - rank: Rank(1), - }, - ); + let (mut glwe_in, scratch_1) = scratch.take_glwe(&GLWELayout { + n: ksk.n(), + base2k: a.base2k(), + k: max_k, + rank: Rank(1), + }); glwe_in.data.zero(); - let (mut glwe_out, scratch_1) = scratch_1.take_glwe( - self, - &GLWELayout { - n: ksk.n(), - base2k: res.base2k(), - k: max_k, - rank: Rank(1), - }, - ); + let (mut glwe_out, scratch_1) = scratch_1.take_glwe(&GLWELayout { + n: ksk.n(), + base2k: res.base2k(), + k: max_k, + rank: Rank(1), + }); let n_lwe: usize = a.n().into(); diff --git a/poulpy-core/src/layouts/prepared/glwe_public_key.rs b/poulpy-core/src/layouts/prepared/glwe_public_key.rs index b20682d..fab30bb 100644 --- a/poulpy-core/src/layouts/prepared/glwe_public_key.rs +++ b/poulpy-core/src/layouts/prepared/glwe_public_key.rs @@ -93,7 +93,8 @@ where } } -impl GLWEPublicKeyPreparedFactory for Module where Self: VecZnxDftAlloc + VecZnxDftBytesOf + VecZnxDftApply {} +impl GLWEPublicKeyPreparedFactory for Module where Self: VecZnxDftAlloc + VecZnxDftBytesOf + VecZnxDftApply +{} impl GLWEPublicKeyPrepared, B> { pub fn alloc_from_infos(module: &M, infos: &A) -> Self diff --git a/poulpy-core/src/layouts/prepared/glwe_secret.rs b/poulpy-core/src/layouts/prepared/glwe_secret.rs index 412733c..0e38221 100644 --- a/poulpy-core/src/layouts/prepared/glwe_secret.rs +++ b/poulpy-core/src/layouts/prepared/glwe_secret.rs @@ -96,7 +96,10 @@ where } } -impl GLWESecretPreparedFactory for Module where Self: GetDegree + SvpPPolBytesOf + SvpPPolAlloc + SvpPrepare {} +impl GLWESecretPreparedFactory for Module where + Self: GetDegree + SvpPPolBytesOf + SvpPPolAlloc + SvpPrepare +{ +} impl GLWESecretPrepared, B> { pub fn alloc_from_infos(module: &M, infos: &A) -> Self diff --git a/poulpy-core/src/scratch.rs b/poulpy-core/src/scratch.rs index 7261410..880af1b 100644 --- a/poulpy-core/src/scratch.rs +++ b/poulpy-core/src/scratch.rs @@ -19,13 +19,11 @@ pub trait ScratchTakeCore where Self: ScratchTakeBasic + ScratchAvailable, { - fn take_glwe(&mut self, module: &M, infos: &A) -> (GLWE<&mut [u8]>, &mut Self) + fn take_glwe(&mut self, infos: &A) -> (GLWE<&mut [u8]>, &mut Self) where A: GLWEInfos, - M: ModuleN, { - assert_eq!(module.n() as u32, infos.n()); - let (data, scratch) = self.take_vec_znx(module, (infos.rank() + 1).into(), infos.size()); + let (data, scratch) = self.take_vec_znx(infos.n().into(), (infos.rank() + 1).into(), infos.size()); ( GLWE { k: infos.k(), @@ -36,28 +34,25 @@ where ) } - fn take_glwe_slice(&mut self, module: &M, size: usize, infos: &A) -> (Vec>, &mut Self) + fn take_glwe_slice(&mut self, size: usize, infos: &A) -> (Vec>, &mut Self) where A: GLWEInfos, - M: ModuleN, { let mut scratch: &mut Self = self; let mut cts: Vec> = Vec::with_capacity(size); for _ in 0..size { - let (ct, new_scratch) = scratch.take_glwe(module, infos); + let (ct, new_scratch) = scratch.take_glwe(infos); scratch = new_scratch; cts.push(ct); } (cts, scratch) } - fn take_glwe_plaintext(&mut self, module: &M, infos: &A) -> (GLWEPlaintext<&mut [u8]>, &mut Self) + fn take_glwe_plaintext(&mut self, infos: &A) -> (GLWEPlaintext<&mut [u8]>, &mut Self) where A: GLWEInfos, - M: ModuleN, { - assert_eq!(module.n() as u32, infos.n()); - let (data, scratch) = self.take_vec_znx(module, 1, infos.size()); + let (data, scratch) = self.take_vec_znx(infos.n().into(), 1, infos.size()); ( GLWEPlaintext { k: infos.k(), @@ -68,14 +63,12 @@ where ) } - fn take_gglwe(&mut self, module: &M, infos: &A) -> (GGLWE<&mut [u8]>, &mut Self) + fn take_gglwe(&mut self, infos: &A) -> (GGLWE<&mut [u8]>, &mut Self) where A: GGLWEInfos, - M: ModuleN, { - assert_eq!(module.n() as u32, infos.n()); let (data, scratch) = self.take_mat_znx( - module, + infos.n().into(), infos.dnum().0.div_ceil(infos.dsize().0) as usize, infos.rank_in().into(), (infos.rank_out() + 1).into(), @@ -116,14 +109,12 @@ where ) } - fn take_ggsw(&mut self, module: &M, infos: &A) -> (GGSW<&mut [u8]>, &mut Self) + fn take_ggsw(&mut self, infos: &A) -> (GGSW<&mut [u8]>, &mut Self) where A: GGSWInfos, - M: ModuleN, { - assert_eq!(module.n() as u32, infos.n()); let (data, scratch) = self.take_mat_znx( - module, + infos.n().into(), infos.dnum().into(), (infos.rank() + 1).into(), (infos.rank() + 1).into(), @@ -184,13 +175,11 @@ where (cts, scratch) } - fn take_glwe_public_key(&mut self, module: &M, infos: &A) -> (GLWEPublicKey<&mut [u8]>, &mut Self) + fn take_glwe_public_key(&mut self, infos: &A) -> (GLWEPublicKey<&mut [u8]>, &mut Self) where A: GLWEInfos, - M: ModuleN, { - assert_eq!(module.n() as u32, infos.n()); - let (data, scratch) = self.take_glwe(module, infos); + let (data, scratch) = self.take_glwe(infos); ( GLWEPublicKey { dist: Distribution::NONE, @@ -232,11 +221,8 @@ where ) } - fn take_glwe_secret(&mut self, module: &M, rank: Rank) -> (GLWESecret<&mut [u8]>, &mut Self) - where - M: ModuleN, - { - let (data, scratch) = self.take_scalar_znx(module, rank.into()); + fn take_glwe_secret(&mut self, n: Degree, rank: Rank) -> (GLWESecret<&mut [u8]>, &mut Self) { + let (data, scratch) = self.take_scalar_znx(n.into(), rank.into()); ( GLWESecret { data, @@ -260,13 +246,11 @@ where ) } - fn take_glwe_switching_key(&mut self, module: &M, infos: &A) -> (GLWESwitchingKey<&mut [u8]>, &mut Self) + fn take_glwe_switching_key(&mut self, infos: &A) -> (GLWESwitchingKey<&mut [u8]>, &mut Self) where A: GGLWEInfos, - M: ModuleN, { - assert_eq!(module.n() as u32, infos.n()); - let (data, scratch) = self.take_gglwe(module, infos); + let (data, scratch) = self.take_gglwe(infos); ( GLWESwitchingKey { key: data, @@ -298,17 +282,15 @@ where ) } - fn take_gglwe_automorphism_key(&mut self, module: &M, infos: &A) -> (AutomorphismKey<&mut [u8]>, &mut Self) + fn take_glwe_automorphism_key(&mut self, infos: &A) -> (AutomorphismKey<&mut [u8]>, &mut Self) where A: GGLWEInfos, - M: ModuleN, { - assert_eq!(module.n() as u32, infos.n()); - let (data, scratch) = self.take_gglwe(module, infos); + let (data, scratch) = self.take_gglwe(infos); (AutomorphismKey { key: data, p: 0 }, scratch) } - fn take_gglwe_automorphism_key_prepared( + fn take_glwe_automorphism_key_prepared( &mut self, module: &M, infos: &A, @@ -322,12 +304,10 @@ where (GLWEAutomorphismKeyPrepared { key: data, p: 0 }, scratch) } - fn take_tensor_key(&mut self, module: &M, infos: &A) -> (TensorKey<&mut [u8]>, &mut Self) + fn take_glwe_tensor_key(&mut self, infos: &A) -> (TensorKey<&mut [u8]>, &mut Self) where A: GGLWEInfos, - M: ModuleN, { - assert_eq!(module.n() as u32, infos.n()); assert_eq!( infos.rank_in(), infos.rank_out(), @@ -342,19 +322,19 @@ where ksk_infos.rank_in = Rank(1); if pairs != 0 { - let (gglwe, s) = scratch.take_gglwe(module, &ksk_infos); + let (gglwe, s) = scratch.take_gglwe(&ksk_infos); scratch = s; keys.push(gglwe); } for _ in 1..pairs { - let (gglwe, s) = scratch.take_gglwe(module, &ksk_infos); + let (gglwe, s) = scratch.take_gglwe(&ksk_infos); scratch = s; keys.push(gglwe); } (TensorKey { keys }, scratch) } - fn take_gglwe_tensor_key_prepared(&mut self, module: &M, infos: &A) -> (TensorKeyPrepared<&mut [u8], B>, &mut Self) + fn take_glwe_tensor_key_prepared(&mut self, module: &M, infos: &A) -> (TensorKeyPrepared<&mut [u8], B>, &mut Self) where A: GGLWEInfos, M: ModuleN + VmpPMatBytesOf, diff --git a/poulpy-core/src/tests/test_suite/encryption/gglwe_ct.rs b/poulpy-core/src/tests/test_suite/encryption/gglwe_ct.rs index 9e8aafe..2b64f02 100644 --- a/poulpy-core/src/tests/test_suite/encryption/gglwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/encryption/gglwe_ct.rs @@ -9,7 +9,8 @@ use crate::{ decryption::GLWEDecrypt, encryption::SIGMA, layouts::{ - GGLWELayout, GLWESecret, GLWESecretPreparedFactory, GLWESwitchingKey, GLWESwitchingKeyCompressed, GLWESwitchingKeyDecompress, + GGLWELayout, GLWESecret, GLWESecretPreparedFactory, GLWESwitchingKey, GLWESwitchingKeyCompressed, + GLWESwitchingKeyDecompress, prepared::{GGLWEPreparedFactory, GLWESecretPrepared}, }, noise::GGLWENoise, diff --git a/poulpy-hal/src/api/scratch.rs b/poulpy-hal/src/api/scratch.rs index fb17266..714a58d 100644 --- a/poulpy-hal/src/api/scratch.rs +++ b/poulpy-hal/src/api/scratch.rs @@ -34,13 +34,11 @@ pub trait ScratchTakeBasic where Self: TakeSlice, { - fn take_scalar_znx(&mut self, module: &M, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) - where - M: ModuleN, + fn take_scalar_znx(&mut self, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) { - let (take_slice, rem_slice) = self.take_slice(ScalarZnx::bytes_of(module.n(), cols)); + let (take_slice, rem_slice) = self.take_slice(ScalarZnx::bytes_of(n, cols)); ( - ScalarZnx::from_data(take_slice, module.n(), cols), + ScalarZnx::from_data(take_slice, n, cols), rem_slice, ) } @@ -53,13 +51,10 @@ where (SvpPPol::from_data(take_slice, module.n(), cols), rem_slice) } - fn take_vec_znx(&mut self, module: &M, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self) - where - M: ModuleN, - { - let (take_slice, rem_slice) = self.take_slice(VecZnx::bytes_of(module.n(), cols, size)); + fn take_vec_znx(&mut self, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self){ + let (take_slice, rem_slice) = self.take_slice(VecZnx::bytes_of(n, cols, size)); ( - VecZnx::from_data(take_slice, module.n(), cols, size), + VecZnx::from_data(take_slice, n, cols, size), rem_slice, ) } @@ -107,14 +102,11 @@ where (slice, scratch) } - fn take_vec_znx_slice(&mut self, module: &M, len: usize, cols: usize, size: usize) -> (Vec>, &mut Self) - where - M: ModuleN, - { + fn take_vec_znx_slice(&mut self, n: usize, len: usize, cols: usize, size: usize) -> (Vec>, &mut Self){ let mut scratch: &mut Self = self; let mut slice: Vec> = Vec::with_capacity(len); for _ in 0..len { - let (znx, new_scratch) = scratch.take_vec_znx(module, cols, size); + let (znx, new_scratch) = scratch.take_vec_znx(n, cols, size); scratch = new_scratch; slice.push(znx); } @@ -139,20 +131,18 @@ where ) } - fn take_mat_znx( + fn take_mat_znx( &mut self, - module: &M, + n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize, ) -> (MatZnx<&mut [u8]>, &mut Self) - where - M: ModuleN, { - let (take_slice, rem_slice) = self.take_slice(MatZnx::bytes_of(module.n(), rows, cols_in, cols_out, size)); + let (take_slice, rem_slice) = self.take_slice(MatZnx::bytes_of(n, rows, cols_in, cols_out, size)); ( - MatZnx::from_data(take_slice, module.n(), rows, cols_in, cols_out, size), + MatZnx::from_data(take_slice, n, rows, cols_in, cols_out, size), rem_slice, ) } diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/block_prepared.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/block_prepared.rs index c0cd09f..aa70910 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/block_prepared.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/block_prepared.rs @@ -3,8 +3,7 @@ use std::marker::PhantomData; use poulpy_core::layouts::{Base2K, Dnum, Dsize, GGSWInfos, GLWEInfos, LWEInfos, Rank, TorusPrecision, prepared::GGSWPrepared}; #[cfg(test)] use poulpy_core::{ - TakeGGSW, - layouts::{GGSW, prepared::GLWESecretPrepared}, + layouts::{prepared::GLWESecretPrepared, GGSW}, ScratchTakeCore, }; use poulpy_hal::{ api::VmpPMatAlloc, @@ -13,13 +12,12 @@ use poulpy_hal::{ #[cfg(test)] use poulpy_hal::{ api::{ - ScratchAvailable, SvpApplyDftToDftInplace, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, + SvpApplyDftToDftInplace, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAlloc, VecZnxDftApply, VecZnxDftBytesOf, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VmpPrepare, }, - oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl}, source::Source, }; @@ -137,7 +135,7 @@ impl FheUintBlocksPrep + VecZnxSub + VmpPrepare, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeGGSW + TakeScalarZnx, + Scratch: ScratchTakeCore, { #[cfg(debug_assertions)] { @@ -146,11 +144,11 @@ impl FheUintBlocksPrep FheUintBlocksPrepDebug { + VecZnxIdftApplyTmpA + VecZnxAddScalarInplace + VecZnxSubInplace, - BE: Backend + TakeVecZnxDftImpl + TakeVecZnxBigImpl + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, { for (i, ggsw) in self.blocks.iter().enumerate() { use poulpy_hal::layouts::{ScalarZnx, ZnxViewMut};