diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d2aa5ca..2e18c08 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -62,6 +62,7 @@ jobs: - name: Tests (AVX enabled) if: steps.avxcheck.outputs.supported == 'true' run: | + RUSTDOCFLAGS="-C target-feature=+avx2 -C target-feature=+fma" \ RUSTFLAGS="-C target-feature=+avx2,+fma" \ cargo test --workspace --features enable-avx diff --git a/CHANGELOG.md b/CHANGELOG.md index a38eb85..4026788 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,10 +1,101 @@ # CHANGELOG -## [0.4.0] - 2025-10-27 +## [0.4.1] - 2025-11-26 + +### Summary +- Updated the convolution API to match spqlios-arithmetic and removed the API for bivariate tensoring. + +### `poulpy-hal` +- Removed `Backend` generic from `VecZnxBigAllocBytesImpl`. +- Add `CnvPVecL` and `CnvPVecR` structs. +- Add `CnvPVecBytesOf` and `CnvPVecAlloc` traits. +- Add `Convolution` trait, which groups the following methods: + - `cnv_prepare_left_tmp_bytes` + - `cnv_prepare_left` + - `cnv_prepare_right_tmp_bytes` + - `cnv_prepare_right` + - `cnv_by_const_apply` + - `cnv_by_const_apply_tmp_bytes` + - `cnv_apply_dft_tmp_bytes` + - `cnv_apply_dft` + - `cnv_pairwise_apply_dft_tmp_bytes` + - `cnv_pairwise_apply_dft` +- Add the following Reim4 traits: + - `Reim4Convolution` + - `Reim4Convolution1Coeff` + - `Reim4Convolution2Coeffs` + - `Reim4Save1BlkContiguous` +- Add the following traits: + - `i64Save1BlkContiguous` + - `i64Extract1BlkContiguous` + - `i64ConvolutionByConst1Coeff` + - `i64ConvolutionByConst2Coeffs` +- Update signature `Reim4Extract1Blk` to `Reim4Extract1BlkContiguous`. +- Add fft64 backend reference code for: + - `reim4_save_1blk_to_reim_contiguous_ref` + - `reim4_convolution_1coeff_ref` + - `reim4_convolution_2coeffs_ref` + - `convolution_prepare_left` + - `convolution_prepare_right` + - `convolution_apply_dft_tmp_bytes` + - `convolution_apply_dft` + - `convolution_pairwise_apply_dft_tmp_bytes` + - `convolution_pairwise_apply_dft` + - `convolution_by_const_apply_tmp_bytes` + - `convolution_by_const_apply` +- Add `take_cnv_pvec_left` and `take_cnv_pvec_right` methods to `ScratchTakeBasic` trait. +- Add the following test methods for convolution: + - `test_convolution` + - `test_convolution_by_const` + - `test_convolution_pairwise` +- Add the following bench methods for convolution: + - `bench_cnv_prepare_left` + - `bench_cnv_prepare_right` + - `bench_cnv_apply_dft` + - `bench_cnv_pairwise_apply_dft` + - `bench_cnv_by_const` +- Update normalization API and OEP to take `res_offset: i64`. This allows the user to specify a bit-shift (positive or negative) applied during normalization. Behavior-wise, the bit-shift is applied before the normalization (i.e., before the mod-1 reduction). Since this is an API break, the opportunity was taken to also re-order inputs for better consistency (see the usage sketch below). + - `VecZnxNormalize` & `VecZnxNormalizeImpl` + - `VecZnxBigNormalize` & `VecZnxBigNormalizeImpl` + This change completes the road to full support for cross-base2k normalization with an arbitrary positive/negative offset. The code is not guaranteed to be optimal, but correctness is ensured. + +### `poulpy-cpu-ref` +- Implemented `ConvolutionImpl` OEP on `FFT64Ref` backend. +- Add benchmark for convolution. +- Add test for convolution. + +### `poulpy-cpu-avx` +- Implemented `ConvolutionImpl` OEP on `FFT64Avx` backend. +- Add benchmark for convolution. +- Add test for convolution.
+- Add fft64 AVX code for: + - `reim4_save_1blk_to_reim_contiguous_avx` + - `reim4_convolution_1coeff_avx` + - `reim4_convolution_2coeffs_avx` + +### `poulpy-core` +- Renamed `size` to `limbs`. +- Add `GLWEMulPlain` trait: + - `glwe_mul_plain_tmp_bytes` + - `glwe_mul_plain` + - `glwe_mul_plain_inplace` +- Add `GLWEMulConst` trait: + - `glwe_mul_const_tmp_bytes` + - `glwe_mul_const` + - `glwe_mul_const_inplace` +- Add `GLWETensoring` trait: + - `glwe_tensor_apply_tmp_bytes` + - `glwe_tensor_apply` + - `glwe_tensor_relinearize_tmp_bytes` + - `glwe_tensor_relinearize` +- Add the following test methods: + - `test_glwe_tensoring` + +## [0.4.0] - 2025-11-20 ### Summary - Full support for base2k operations. -- Many improvments to BDD arithmetic. +- Many improvements to BDD arithmetic. - Removal of **poulpy-backend** & spqlios backend. - Addition of individual crates for each specific backend. - Some minor bug fixes. @@ -28,7 +119,7 @@ - Improved Cmux speed ### `poulpy-cpu-ref` -- A new crate that provides the refernce CPU implementation of **poulpy-hal**. This replaces the previous **poulpy-backend/cpu_ref**. +- A new crate that provides the reference CPU implementation of **poulpy-hal**. This replaces the previous **poulpy-backend/cpu_ref**. ### `poulpy-cpu-avx` - A new crate that provides an AVX/FMA accelerated CPU implementation of **poulpy-hal**. This replaces the previous **poulpy-backend/cpu_avx**. @@ -76,7 +167,7 @@ - Added functionality-based traits, which removes the need to import the low-levels traits of `poulpy-hal` and makes backend agnostic code much cleaner. For example instead of having to import each individual traits required for the encryption of a GLWE, only the trait `GLWEEncryptSk` is needed. ### `poulpy-schemes` - - Added basic framework for binary decicion circuit (BDD) arithmetic along with some operations. + - Added basic framework for binary decision circuit (BDD) arithmetic along with some operations.
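> **Editor's illustration (not part of the patch):** the normalization bullets in the `poulpy-hal` notes above describe the re-ordered argument list and the new `res_offset: i64` bit-shift. The sketch below shows a call site under the new ordering. The helper name, import paths, and trait bounds are assumptions inferred from the call sites in this diff (e.g. `self.vec_znx_big_normalize(res.data_mut(), res_base2k, 0, i, &res_big, key_base2k, i, scratch_1)`), not the crate's documented API.

```rust
use poulpy_hal::{
    api::VecZnxBigNormalize, // assumed export path for the trait named in the notes above
    layouts::{Backend, Module, Scratch, VecZnx, VecZnxBig},
};

// Hypothetical helper: normalize one column of a VecZnxBig into a VecZnx
// using the re-ordered signature. Destination arguments come first
// (data, base2k, offset, column), then the source (data, base2k, column),
// then the scratch space.
fn normalize_with_offset<BE: Backend>(
    module: &Module<BE>,
    res: &mut VecZnx<&mut [u8]>,
    res_base2k: usize,
    res_offset: i64, // bit-shift applied before the mod-1 reduction
    res_col: usize,
    a: &VecZnxBig<&[u8], BE>,
    a_base2k: usize,
    a_col: usize,
    scratch: &mut Scratch<BE>,
) where
    Module<BE>: VecZnxBigNormalize<BE>, // assumed bound; generic parameter inferred
{
    module.vec_znx_big_normalize(res, res_base2k, res_offset, res_col, a, a_base2k, a_col, scratch);
}
```

Every call site in this diff passes `res_offset = 0`, which reproduces the previous behavior; a non-zero offset shifts the value before the mod-1 reduction, which is what enables cross-base2k normalization with an arbitrary offset.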
## [0.2.0] - 2025-09-15 diff --git a/Cargo.lock b/Cargo.lock index ce45514..ef59c5c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -370,6 +370,7 @@ dependencies = [ "poulpy-cpu-avx", "poulpy-cpu-ref", "poulpy-hal", + "rand", "rug", ] diff --git a/poulpy-core/Cargo.toml b/poulpy-core/Cargo.toml index 6ed82ab..9b2d23c 100644 --- a/poulpy-core/Cargo.toml +++ b/poulpy-core/Cargo.toml @@ -23,6 +23,7 @@ byteorder = {workspace = true} bytemuck = {workspace = true} once_cell = {workspace = true} paste = {workspace = true} +rand = {workspace = true} [[bench]] name = "external_product_glwe_fft64" diff --git a/poulpy-core/benches/external_product_glwe_fft64.rs b/poulpy-core/benches/external_product_glwe_fft64.rs index fd3804d..c2295f8 100644 --- a/poulpy-core/benches/external_product_glwe_fft64.rs +++ b/poulpy-core/benches/external_product_glwe_fft64.rs @@ -87,22 +87,9 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) { let mut sk_dft: GLWESecretPrepared, BackendImpl> = GLWESecretPrepared::alloc(&module, rank); sk_dft.prepare(&module, &sk); - ct_ggsw.encrypt_sk( - &module, - &pt_rgsw, - &sk_dft, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + ct_ggsw.encrypt_sk(&module, &pt_rgsw, &sk_dft, &mut source_xa, &mut source_xe, scratch.borrow()); - ct_glwe_in.encrypt_zero_sk( - &module, - &sk_dft, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + ct_glwe_in.encrypt_zero_sk(&module, &sk_dft, &mut source_xa, &mut source_xe, scratch.borrow()); let mut ggsw_prepared: GGSWPrepared, BackendImpl> = GGSWPrepared::alloc_from_infos(&module, &ct_ggsw); ggsw_prepared.prepare(&module, &ct_ggsw, scratch.borrow()); @@ -190,22 +177,9 @@ fn bench_external_product_glwe_inplace_fft64(c: &mut Criterion) { let mut sk_dft: GLWESecretPrepared, BackendImpl> = GLWESecretPrepared::alloc(&module, rank); sk_dft.prepare(&module, &sk); - ct_ggsw.encrypt_sk( - &module, - &pt_rgsw, - &sk_dft, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + ct_ggsw.encrypt_sk(&module, &pt_rgsw, &sk_dft, &mut source_xa, &mut source_xe, scratch.borrow()); - ct_glwe.encrypt_zero_sk( - &module, - &sk_dft, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + ct_glwe.encrypt_zero_sk(&module, &sk_dft, &mut source_xa, &mut source_xe, scratch.borrow()); let mut ggsw_prepared: GGSWPrepared, BackendImpl> = GGSWPrepared::alloc_from_infos(&module, &ct_ggsw); ggsw_prepared.prepare(&module, &ct_ggsw, scratch.borrow()); diff --git a/poulpy-core/benches/keyswitch_glwe_fft64.rs b/poulpy-core/benches/keyswitch_glwe_fft64.rs index afd64b8..1712bda 100644 --- a/poulpy-core/benches/keyswitch_glwe_fft64.rs +++ b/poulpy-core/benches/keyswitch_glwe_fft64.rs @@ -75,12 +75,7 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { let mut scratch: ScratchOwned = ScratchOwned::alloc( GLWESwitchingKey::encrypt_sk_tmp_bytes(&module, &gglwe_atk_layout) | GLWE::encrypt_sk_tmp_bytes(&module, &glwe_in_layout) - | GLWE::keyswitch_tmp_bytes( - &module, - &glwe_out_layout, - &glwe_in_layout, - &gglwe_atk_layout, - ), + | GLWE::keyswitch_tmp_bytes(&module, &glwe_out_layout, &glwe_in_layout, &gglwe_atk_layout), ); let mut source_xs: Source = Source::new([0u8; 32]); @@ -93,22 +88,9 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { let mut sk_in_dft: GLWESecretPrepared, BackendImpl> = GLWESecretPrepared::alloc(&module, rank); sk_in_dft.prepare(&module, &sk_in); - ksk.encrypt_sk( - &module, - -1, - &sk_in, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + ksk.encrypt_sk(&module, -1, &sk_in, &mut source_xa, &mut source_xe, 
scratch.borrow()); - ct_in.encrypt_zero_sk( - &module, - &sk_in_dft, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + ct_in.encrypt_zero_sk(&module, &sk_in_dft, &mut source_xa, &mut source_xe, scratch.borrow()); let mut ksk_prepared: GLWEAutomorphismKeyPrepared, _> = GLWEAutomorphismKeyPrepared::alloc_from_infos(&module, &ksk); @@ -206,22 +188,9 @@ fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) { let mut sk_out: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_layout); sk_out.fill_ternary_prob(0.5, &mut source_xs); - ksk.encrypt_sk( - &module, - &sk_in, - &sk_out, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + ksk.encrypt_sk(&module, &sk_in, &sk_out, &mut source_xa, &mut source_xe, scratch.borrow()); - ct.encrypt_zero_sk( - &module, - &sk_in_dft, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + ct.encrypt_zero_sk(&module, &sk_in_dft, &mut source_xa, &mut source_xe, scratch.borrow()); let mut ksk_prepared: GLWESwitchingKeyPrepared, _> = GLWESwitchingKeyPrepared::alloc_from_infos(&module, &ksk); ksk_prepared.prepare(&module, &ksk, scratch.borrow()); @@ -249,9 +218,5 @@ fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) { group.finish(); } -criterion_group!( - benches, - bench_keyswitch_glwe_fft64, - bench_keyswitch_glwe_inplace_fft64 -); +criterion_group!(benches, bench_keyswitch_glwe_fft64, bench_keyswitch_glwe_inplace_fft64); criterion_main!(benches); diff --git a/poulpy-core/src/automorphism/gglwe_atk.rs b/poulpy-core/src/automorphism/gglwe_atk.rs index 1e4dd92..19c126f 100644 --- a/poulpy-core/src/automorphism/gglwe_atk.rs +++ b/poulpy-core/src/automorphism/gglwe_atk.rs @@ -75,13 +75,7 @@ where a.dnum() ); - assert_eq!( - res.dsize(), - a.dsize(), - "res dnum: {} != a dnum: {}", - res.dsize(), - a.dsize() - ); + assert_eq!(res.dsize(), a.dsize(), "res dnum: {} != a dnum: {}", res.dsize(), a.dsize()); assert_eq!(res.base2k(), a.base2k()); @@ -139,13 +133,7 @@ where K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, Scratch: ScratchTakeCore, { - assert_eq!( - res.rank(), - key.rank(), - "key rank: {} != key rank: {}", - res.rank(), - key.rank() - ); + assert_eq!(res.rank(), key.rank(), "key rank: {} != key rank: {}", res.rank(), key.rank()); let cols_out: usize = (key.rank_out() + 1).into(); let cols_in: usize = key.rank_in().into(); diff --git a/poulpy-core/src/automorphism/glwe_ct.rs b/poulpy-core/src/automorphism/glwe_ct.rs index ff9dfcd..3f7d49a 100644 --- a/poulpy-core/src/automorphism/glwe_ct.rs +++ b/poulpy-core/src/automorphism/glwe_ct.rs @@ -218,13 +218,13 @@ where let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); let a: &GLWE<&[u8]> = &a.to_ref(); - let base2k_a: usize = a.base2k().into(); - let base2k_key: usize = key.base2k().into(); - let base2k_res: usize = res.base2k().into(); + let a_base2k: usize = a.base2k().into(); + let key_base2k: usize = key.base2k().into(); + let res_base2k: usize = res.base2k().into(); let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size - if base2k_a != base2k_key { + if a_base2k != key_base2k { let (mut a_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout { n: a.n(), base2k: key.base2k(), @@ -236,30 +236,14 @@ where for i in 0..res.rank().as_usize() + 1 { self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_2); self.vec_znx_big_add_small_inplace(&mut res_big, i, a_conv.data(), i); - self.vec_znx_big_normalize( - base2k_res, - res.data_mut(), - i, - base2k_key, - &res_big, - i, - scratch_2, - ); + 
self.vec_znx_big_normalize(res.data_mut(), res_base2k, 0, i, &res_big, key_base2k, i, scratch_2); } } else { let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, a, key, scratch_1); for i in 0..res.rank().as_usize() + 1 { self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); self.vec_znx_big_add_small_inplace(&mut res_big, i, a.data(), i); - self.vec_znx_big_normalize( - base2k_res, - res.data_mut(), - i, - base2k_key, - &res_big, - i, - scratch_1, - ); + self.vec_znx_big_normalize(res.data_mut(), res_base2k, 0, i, &res_big, key_base2k, i, scratch_1); } }; } @@ -272,12 +256,12 @@ where { let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); - let base2k_key: usize = key.base2k().into(); - let base2k_res: usize = res.base2k().into(); + let key_base2k: usize = key.base2k().into(); + let res_base2k: usize = res.base2k().into(); let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size - if base2k_res != base2k_key { + if res_base2k != key_base2k { let (mut res_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout { n: res.n(), base2k: key.base2k(), @@ -289,30 +273,14 @@ where for i in 0..res.rank().as_usize() + 1 { self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_2); self.vec_znx_big_add_small_inplace(&mut res_big, i, res_conv.data(), i); - self.vec_znx_big_normalize( - base2k_res, - res.data_mut(), - i, - base2k_key, - &res_big, - i, - scratch_2, - ); + self.vec_znx_big_normalize(res.data_mut(), res_base2k, 0, i, &res_big, key_base2k, i, scratch_2); } } else { let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, res, key, scratch_1); for i in 0..res.rank().as_usize() + 1 { self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); self.vec_znx_big_add_small_inplace(&mut res_big, i, res.data(), i); - self.vec_znx_big_normalize( - base2k_res, - res.data_mut(), - i, - base2k_key, - &res_big, - i, - scratch_1, - ); + self.vec_znx_big_normalize(res.data_mut(), res_base2k, 0, i, &res_big, key_base2k, i, scratch_1); } }; } @@ -327,13 +295,13 @@ where let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); let a: &GLWE<&[u8]> = &a.to_ref(); - let base2k_a: usize = a.base2k().into(); - let base2k_key: usize = key.base2k().into(); - let base2k_res: usize = res.base2k().into(); + let a_base2k: usize = a.base2k().into(); + let key_base2k: usize = key.base2k().into(); + let res_base2k: usize = res.base2k().into(); let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size - if base2k_a != base2k_key { + if a_base2k != key_base2k { let (mut a_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout { n: a.n(), base2k: key.base2k(), @@ -345,30 +313,14 @@ where for i in 0..res.rank().as_usize() + 1 { self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_2); self.vec_znx_big_sub_small_inplace(&mut res_big, i, a_conv.data(), i); - self.vec_znx_big_normalize( - base2k_res, - res.data_mut(), - i, - base2k_key, - &res_big, - i, - scratch_2, - ); + self.vec_znx_big_normalize(res.data_mut(), res_base2k, 0, i, &res_big, key_base2k, i, scratch_2); } } else { let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, a, key, scratch_1); for i in 0..res.rank().as_usize() + 1 { self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); self.vec_znx_big_sub_small_inplace(&mut res_big, i, a.data(), i); - self.vec_znx_big_normalize( - 
base2k_res, - res.data_mut(), - i, - base2k_key, - &res_big, - i, - scratch_1, - ); + self.vec_znx_big_normalize(res.data_mut(), res_base2k, 0, i, &res_big, key_base2k, i, scratch_1); } }; } @@ -383,13 +335,13 @@ where let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); let a: &GLWE<&[u8]> = &a.to_ref(); - let base2k_a: usize = a.base2k().into(); - let base2k_key: usize = key.base2k().into(); - let base2k_res: usize = res.base2k().into(); + let a_base2k: usize = a.base2k().into(); + let key_base2k: usize = key.base2k().into(); + let res_base2k: usize = res.base2k().into(); let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size - if base2k_a != base2k_key { + if a_base2k != key_base2k { let (mut a_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout { n: a.n(), base2k: key.base2k(), @@ -401,30 +353,14 @@ where for i in 0..res.rank().as_usize() + 1 { self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_2); self.vec_znx_big_sub_small_negate_inplace(&mut res_big, i, a_conv.data(), i); - self.vec_znx_big_normalize( - base2k_res, - res.data_mut(), - i, - base2k_key, - &res_big, - i, - scratch_2, - ); + self.vec_znx_big_normalize(res.data_mut(), res_base2k, 0, i, &res_big, key_base2k, i, scratch_2); } } else { let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, a, key, scratch_1); for i in 0..res.rank().as_usize() + 1 { self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); self.vec_znx_big_sub_small_negate_inplace(&mut res_big, i, a.data(), i); - self.vec_znx_big_normalize( - base2k_res, - res.data_mut(), - i, - base2k_key, - &res_big, - i, - scratch_1, - ); + self.vec_znx_big_normalize(res.data_mut(), res_base2k, 0, i, &res_big, key_base2k, i, scratch_1); } }; } @@ -437,12 +373,12 @@ where { let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); - let base2k_key: usize = key.base2k().into(); - let base2k_res: usize = res.base2k().into(); + let key_base2k: usize = key.base2k().into(); + let res_base2k: usize = res.base2k().into(); let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size - if base2k_res != base2k_key { + if res_base2k != key_base2k { let (mut res_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout { n: res.n(), base2k: key.base2k(), @@ -454,30 +390,14 @@ where for i in 0..res.rank().as_usize() + 1 { self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_2); self.vec_znx_big_sub_small_inplace(&mut res_big, i, res_conv.data(), i); - self.vec_znx_big_normalize( - base2k_res, - res.data_mut(), - i, - base2k_key, - &res_big, - i, - scratch_2, - ); + self.vec_znx_big_normalize(res.data_mut(), res_base2k, 0, i, &res_big, key_base2k, i, scratch_2); } } else { let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, res, key, scratch_1); for i in 0..res.rank().as_usize() + 1 { self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); self.vec_znx_big_sub_small_inplace(&mut res_big, i, res.data(), i); - self.vec_znx_big_normalize( - base2k_res, - res.data_mut(), - i, - base2k_key, - &res_big, - i, - scratch_1, - ); + self.vec_znx_big_normalize(res.data_mut(), res_base2k, 0, i, &res_big, key_base2k, i, scratch_1); } }; } @@ -490,12 +410,12 @@ where { let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); - let base2k_key: usize = key.base2k().into(); - let base2k_res: usize = res.base2k().into(); + let key_base2k: usize = key.base2k().into(); + let 
res_base2k: usize = res.base2k().into(); let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size - if base2k_res != base2k_key { + if res_base2k != key_base2k { let (mut res_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout { n: res.n(), base2k: key.base2k(), @@ -507,30 +427,14 @@ where for i in 0..res.rank().as_usize() + 1 { self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_2); self.vec_znx_big_sub_small_negate_inplace(&mut res_big, i, res_conv.data(), i); - self.vec_znx_big_normalize( - base2k_res, - res.data_mut(), - i, - base2k_key, - &res_big, - i, - scratch_2, - ); + self.vec_znx_big_normalize(res.data_mut(), res_base2k, 0, i, &res_big, key_base2k, i, scratch_2); } } else { let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, res, key, scratch_1); for i in 0..res.rank().as_usize() + 1 { self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); self.vec_znx_big_sub_small_negate_inplace(&mut res_big, i, res.data(), i); - self.vec_znx_big_normalize( - base2k_res, - res.data_mut(), - i, - base2k_key, - &res_big, - i, - scratch_1, - ); + self.vec_znx_big_normalize(res.data_mut(), res_base2k, 0, i, &res_big, key_base2k, i, scratch_1); } }; } diff --git a/poulpy-core/src/conversion/gglwe_to_ggsw.rs b/poulpy-core/src/conversion/gglwe_to_ggsw.rs index 9770f1d..896d964 100644 --- a/poulpy-core/src/conversion/gglwe_to_ggsw.rs +++ b/poulpy-core/src/conversion/gglwe_to_ggsw.rs @@ -120,13 +120,13 @@ where R: GGSWInfos, A: GGLWEInfos, { - let base2k_tsk: usize = tsk_infos.base2k().into(); + let tsk_base2k: usize = tsk_infos.base2k().into(); let rank: usize = res_infos.rank().into(); let cols: usize = rank + 1; let res_size: usize = res_infos.size(); - let a_size: usize = res_infos.max_k().as_usize().div_ceil(base2k_tsk); + let a_size: usize = res_infos.max_k().as_usize().div_ceil(tsk_base2k); let a_0: usize = VecZnx::bytes_of(self.n(), 1, a_size); let a_dft: usize = self.bytes_of_vec_znx_dft(cols - 1, a_size); @@ -146,15 +146,15 @@ where let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); let tsk: &GGLWEToGGSWKeyPrepared<&[u8], BE> = &tsk.to_ref(); - let base2k_res: usize = res.base2k().into(); - let base2k_tsk: usize = tsk.base2k().into(); + let res_base2k: usize = res.base2k().into(); + let tsk_base2k: usize = tsk.base2k().into(); assert!(scratch.available() >= self.ggsw_expand_rows_tmp_bytes(res, tsk)); let rank: usize = res.rank().into(); let cols: usize = rank + 1; - let res_conv_size: usize = res.max_k().as_usize().div_ceil(base2k_tsk); + let res_conv_size: usize = res.max_k().as_usize().div_ceil(tsk_base2k); let (mut a_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols - 1, res_conv_size); let (mut a_0, scratch_2) = scratch_1.take_vec_znx(self.n(), 1, res_conv_size); @@ -163,33 +163,17 @@ where for row in 0..res.dnum().as_usize() { let glwe_mi_1: &GLWE<&[u8]> = &res.at(row, 0); - if base2k_res == base2k_tsk { + if res_base2k == tsk_base2k { for col_i in 0..cols - 1 { self.vec_znx_dft_apply(1, 0, &mut a_dft, col_i, glwe_mi_1.data(), col_i + 1); } self.vec_znx_copy(&mut a_0, 0, glwe_mi_1.data(), 0); } else { for i in 0..cols - 1 { - self.vec_znx_normalize( - base2k_tsk, - &mut a_0, - 0, - base2k_res, - glwe_mi_1.data(), - i + 1, - scratch_2, - ); + self.vec_znx_normalize(&mut a_0, tsk_base2k, 0, 0, glwe_mi_1.data(), res_base2k, i + 1, scratch_2); self.vec_znx_dft_apply(1, 0, &mut a_dft, i, &a_0, 0); } - self.vec_znx_normalize( - base2k_tsk, - &mut a_0, - 0, - 
base2k_res, - glwe_mi_1.data(), - 0, - scratch_2, - ); + self.vec_znx_normalize(&mut a_0, tsk_base2k, 0, 0, glwe_mi_1.data(), res_base2k, 0, scratch_2); } ggsw_expand_rows_internal(self, row, res, &a_0, &a_dft, tsk, scratch_2) @@ -267,13 +251,16 @@ fn ggsw_expand_rows_internal( // (-(x0s0 + x1s1 + x2s2), x0 + M[i], x1, x2) module.vec_znx_big_add_small_inplace(&mut res_big, col, a_0, 0); + let res_base2k: usize = res.base2k().as_usize(); + for j in 0..cols { module.vec_znx_big_normalize( - res.base2k().as_usize(), res.at_mut(row, col).data_mut(), + res_base2k, + 0, j, - tsk.base2k().as_usize(), &res_big, + tsk.base2k().as_usize(), j, scratch_1, ); diff --git a/poulpy-core/src/conversion/glwe_to_lwe.rs b/poulpy-core/src/conversion/glwe_to_lwe.rs index 96dc65b..b97935d 100644 --- a/poulpy-core/src/conversion/glwe_to_lwe.rs +++ b/poulpy-core/src/conversion/glwe_to_lwe.rs @@ -56,12 +56,8 @@ where rank: Rank(1), }; - GLWE::bytes_of( - self.n().into(), - lwe_infos.base2k(), - lwe_infos.k(), - 1u32.into(), - ) + GLWE::bytes_of_from_infos(glwe_infos) + GLWE::bytes_of(self.n().into(), lwe_infos.base2k(), lwe_infos.k(), 1u32.into()) + + GLWE::bytes_of_from_infos(glwe_infos) + self.glwe_keyswitch_tmp_bytes(&res_infos, glwe_infos, key_infos) } diff --git a/poulpy-core/src/conversion/lwe_to_glwe.rs b/poulpy-core/src/conversion/lwe_to_glwe.rs index c9b40c3..b2a7020 100644 --- a/poulpy-core/src/conversion/lwe_to_glwe.rs +++ b/poulpy-core/src/conversion/lwe_to_glwe.rs @@ -73,11 +73,12 @@ where } self.vec_znx_normalize( - ksk.base2k().into(), &mut glwe.data, + ksk.base2k().into(), + 0, 0, - lwe.base2k().into(), &a_conv, + lwe.base2k().into(), 0, scratch_2, ); @@ -89,11 +90,12 @@ where } self.vec_znx_normalize( - ksk.base2k().into(), &mut glwe.data, + ksk.base2k().into(), + 0, 1, - lwe.base2k().into(), &a_conv, + lwe.base2k().into(), 0, scratch_2, ); diff --git a/poulpy-core/src/decryption/glwe.rs b/poulpy-core/src/decryption/glwe.rs index 6dc7f5a..f558fe4 100644 --- a/poulpy-core/src/decryption/glwe.rs +++ b/poulpy-core/src/decryption/glwe.rs @@ -33,10 +33,21 @@ impl GLWE { } } -pub trait GLWEDecrypt +pub trait GLWEDecrypt { + fn glwe_decrypt_tmp_bytes(&self, infos: &A) -> usize + where + A: GLWEInfos; + + fn glwe_decrypt(&self, res: &R, pt: &mut P, sk: &S, scratch: &mut Scratch) + where + R: GLWEToRef, + P: GLWEPlaintextToMut, + S: GLWESecretPreparedToRef; +} + +impl GLWEDecrypt for Module where - Self: Sized - + ModuleN + Self: ModuleN + VecZnxDftBytesOf + VecZnxNormalizeTmpBytes + VecZnxBigBytesOf @@ -46,6 +57,7 @@ where + VecZnxBigAddInplace + VecZnxBigAddSmallInplace + VecZnxBigNormalize, + Scratch: ScratchTakeBasic, { fn glwe_decrypt_tmp_bytes(&self, infos: &A) -> usize where @@ -60,7 +72,6 @@ where R: GLWEToRef, P: GLWEPlaintextToMut, S: GLWESecretPreparedToRef, - Scratch: ScratchTakeBasic, { let res: &GLWE<&[u8]> = &res.to_ref(); let pt: &mut GLWEPlaintext<&mut [u8]> = &mut pt.to_ref(); @@ -94,32 +105,12 @@ where // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e) self.vec_znx_big_add_small_inplace(&mut c0_big, 0, &res.data, 0); + let res_base2k: usize = res.base2k().into(); + // pt = norm(BIG(m + e)) - self.vec_znx_big_normalize( - res.base2k().into(), - &mut pt.data, - 0, - res.base2k().into(), - &c0_big, - 0, - scratch_1, - ); + self.vec_znx_big_normalize(&mut pt.data, res_base2k, 0, 0, &c0_big, res_base2k, 0, scratch_1); pt.base2k = res.base2k(); pt.k = pt.k().min(res.k()); } } - -impl GLWEDecrypt for Module where - Self: ModuleN - + VecZnxDftBytesOf - + VecZnxNormalizeTmpBytes - + 
VecZnxBigBytesOf - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize -{ -} diff --git a/poulpy-core/src/decryption/glwe_tensor.rs b/poulpy-core/src/decryption/glwe_tensor.rs new file mode 100644 index 0000000..2457287 --- /dev/null +++ b/poulpy-core/src/decryption/glwe_tensor.rs @@ -0,0 +1,97 @@ +use poulpy_hal::{ + api::{ScratchTakeBasic, SvpPPolBytesOf}, + layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut}, +}; + +use crate::{ + GLWEDecrypt, ScratchTakeCore, + layouts::{GLWEInfos, GLWEPlaintext, GLWESecretPrepared, GLWESecretTensor, GLWESecretTensorPrepared, GLWETensor}, +}; + +impl GLWETensor> { + pub fn decrypt_tmp_bytes(module: &M, a_infos: &A) -> usize + where + A: GLWEInfos, + M: GLWETensorDecrypt, + { + module.glwe_tensor_decrypt_tmp_bytes(a_infos) + } +} + +impl GLWETensor { + pub fn decrypt( + &self, + module: &M, + pt: &mut GLWEPlaintext
<P>
, + sk: &GLWESecretPrepared, + sk_tensor: &GLWESecretTensorPrepared, + scratch: &mut Scratch, + ) where + P: DataMut, + S0: DataRef, + S1: DataRef, + M: GLWETensorDecrypt, + Scratch: ScratchTakeBasic, + { + module.glwe_tensor_decrypt(self, pt, sk, sk_tensor, scratch); + } +} + +pub trait GLWETensorDecrypt { + fn glwe_tensor_decrypt_tmp_bytes(&self, infos: &A) -> usize + where + A: GLWEInfos; + fn glwe_tensor_decrypt( + &self, + res: &GLWETensor, + pt: &mut GLWEPlaintext
<P>
, + sk: &GLWESecretPrepared, + sk_tensor: &GLWESecretTensorPrepared, + scratch: &mut Scratch, + ) where + R: DataRef, + P: DataMut, + S0: DataRef, + S1: DataRef; +} + +impl GLWETensorDecrypt for Module +where + Self: GLWEDecrypt + SvpPPolBytesOf, + Scratch: ScratchTakeCore, +{ + fn glwe_tensor_decrypt_tmp_bytes(&self, infos: &A) -> usize + where + A: GLWEInfos, + { + self.glwe_decrypt_tmp_bytes(infos) + } + + fn glwe_tensor_decrypt( + &self, + res: &GLWETensor, + pt: &mut GLWEPlaintext
<P>
, + sk: &GLWESecretPrepared, + sk_tensor: &GLWESecretTensorPrepared, + scratch: &mut Scratch, + ) where + R: DataRef, + P: DataMut, + S0: DataRef, + S1: DataRef, + { + let rank: usize = sk.rank().as_usize(); + + let (mut sk_grouped, scratch_1) = scratch.take_glwe_secret_prepared(self, (GLWESecretTensor::pairs(rank) + rank).into()); + + for i in 0..rank { + sk_grouped.data.at_mut(i, 0).copy_from_slice(sk.data.at(i, 0)); + } + + for i in 0..sk_grouped.rank().as_usize() - rank { + sk_grouped.data.at_mut(i + rank, 0).copy_from_slice(sk_tensor.data.at(i, 0)); + } + + self.glwe_decrypt(res, pt, &sk_grouped, scratch_1); + } +} diff --git a/poulpy-core/src/decryption/mod.rs b/poulpy-core/src/decryption/mod.rs index 4266117..10a55ec 100644 --- a/poulpy-core/src/decryption/mod.rs +++ b/poulpy-core/src/decryption/mod.rs @@ -1,5 +1,7 @@ mod glwe; +mod glwe_tensor; mod lwe; pub use glwe::*; +pub use glwe_tensor::*; pub use lwe::*; diff --git a/poulpy-core/src/dist.rs b/poulpy-core/src/dist.rs index c754278..c64dc87 100644 --- a/poulpy-core/src/dist.rs +++ b/poulpy-core/src/dist.rs @@ -63,10 +63,7 @@ impl Distribution { TAG_ZERO => Distribution::ZERO, TAG_NONE => Distribution::NONE, _ => { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "Invalid tag", - )); + return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "Invalid tag")); } }; Ok(dist) diff --git a/poulpy-core/src/encryption/compressed/gglwe.rs b/poulpy-core/src/encryption/compressed/gglwe.rs index acf2eac..9e875da 100644 --- a/poulpy-core/src/encryption/compressed/gglwe.rs +++ b/poulpy-core/src/encryption/compressed/gglwe.rs @@ -77,9 +77,7 @@ where where A: GGLWEInfos, { - self.glwe_encrypt_sk_tmp_bytes(infos) - .max(self.vec_znx_normalize_tmp_bytes()) - + GLWEPlaintext::bytes_of_from_infos(infos) + self.glwe_encrypt_sk_tmp_bytes(infos).max(self.vec_znx_normalize_tmp_bytes()) + GLWEPlaintext::bytes_of_from_infos(infos) } fn gglwe_compressed_encrypt_sk( diff --git a/poulpy-core/src/encryption/compressed/gglwe_to_ggsw_key.rs b/poulpy-core/src/encryption/compressed/gglwe_to_ggsw_key.rs index 92f382c..49e884e 100644 --- a/poulpy-core/src/encryption/compressed/gglwe_to_ggsw_key.rs +++ b/poulpy-core/src/encryption/compressed/gglwe_to_ggsw_key.rs @@ -101,24 +101,13 @@ where for i in 0..rank { for j in 0..rank { - self.vec_znx_copy( - &mut sk_ij.as_vec_znx_mut(), - j, - &sk_tensor.at(i, j).as_vec_znx(), - 0, - ); + self.vec_znx_copy(&mut sk_ij.as_vec_znx_mut(), j, &sk_tensor.at(i, j).as_vec_znx(), 0); } let (seed_xa_tmp, _) = source_xa.branch(); - res.at_mut(i).encrypt_sk( - self, - &sk_ij, - &sk_prepared, - seed_xa_tmp, - source_xe, - scratch_3, - ); + res.at_mut(i) + .encrypt_sk(self, &sk_ij, &sk_prepared, seed_xa_tmp, source_xe, scratch_3); } } } diff --git a/poulpy-core/src/encryption/compressed/glwe_automorphism_key.rs b/poulpy-core/src/encryption/compressed/glwe_automorphism_key.rs index 0c15e96..1f0e9ec 100644 --- a/poulpy-core/src/encryption/compressed/glwe_automorphism_key.rs +++ b/poulpy-core/src/encryption/compressed/glwe_automorphism_key.rs @@ -112,14 +112,7 @@ where sk_out_prepared.prepare(self, &sk_out); } - self.gglwe_compressed_encrypt_sk( - res, - &sk.data, - &sk_out_prepared, - seed_xa, - source_xe, - scratch_1, - ); + self.gglwe_compressed_encrypt_sk(res, &sk.data, &sk_out_prepared, seed_xa, source_xe, scratch_1); res.set_p(p); } diff --git a/poulpy-core/src/encryption/compressed/glwe_switching_key.rs b/poulpy-core/src/encryption/compressed/glwe_switching_key.rs index b801927..ac62426 100644 --- 
a/poulpy-core/src/encryption/compressed/glwe_switching_key.rs +++ b/poulpy-core/src/encryption/compressed/glwe_switching_key.rs @@ -104,12 +104,7 @@ where let (mut sk_in_tmp, scratch_1) = scratch.take_scalar_znx(self.n(), sk_in.rank().into()); for i in 0..sk_in.rank().into() { - self.vec_znx_switch_ring( - &mut sk_in_tmp.as_vec_znx_mut(), - i, - &sk_in.data.as_vec_znx(), - i, - ); + self.vec_znx_switch_ring(&mut sk_in_tmp.as_vec_znx_mut(), i, &sk_in.data.as_vec_znx(), i); } let (mut sk_out_tmp, scratch_2) = scratch_1.take_glwe_secret_prepared(self, sk_out.rank()); diff --git a/poulpy-core/src/encryption/compressed/glwe_tensor_key.rs b/poulpy-core/src/encryption/compressed/glwe_tensor_key.rs index 12af7ee..0678cb3 100644 --- a/poulpy-core/src/encryption/compressed/glwe_tensor_key.rs +++ b/poulpy-core/src/encryption/compressed/glwe_tensor_key.rs @@ -102,13 +102,6 @@ where sk_prepared.prepare(self, sk); sk_tensor.prepare(self, sk, scratch_2); - self.gglwe_compressed_encrypt_sk( - res, - &sk_tensor.data, - &sk_prepared, - seed_xa, - source_xe, - scratch_2, - ); + self.gglwe_compressed_encrypt_sk(res, &sk_tensor.data, &sk_prepared, seed_xa, source_xe, scratch_2); } } diff --git a/poulpy-core/src/encryption/gglwe.rs b/poulpy-core/src/encryption/gglwe.rs index a50b565..5de1117 100644 --- a/poulpy-core/src/encryption/gglwe.rs +++ b/poulpy-core/src/encryption/gglwe.rs @@ -160,14 +160,7 @@ where tmp_pt.data.zero(); // zeroes for next iteration self.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (dsize - 1) + row_i * dsize, pt, col_i); self.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scrach_1); - self.glwe_encrypt_sk( - &mut res.at_mut(row_i, col_i), - &tmp_pt, - sk, - source_xa, - source_xe, - scrach_1, - ); + self.glwe_encrypt_sk(&mut res.at_mut(row_i, col_i), &tmp_pt, sk, source_xa, source_xe, scrach_1); } } } diff --git a/poulpy-core/src/encryption/gglwe_to_ggsw_key.rs b/poulpy-core/src/encryption/gglwe_to_ggsw_key.rs index 017455f..0f5b874 100644 --- a/poulpy-core/src/encryption/gglwe_to_ggsw_key.rs +++ b/poulpy-core/src/encryption/gglwe_to_ggsw_key.rs @@ -97,12 +97,7 @@ where for i in 0..rank { for j in 0..rank { - self.vec_znx_copy( - &mut sk_ij.as_vec_znx_mut(), - j, - &sk_tensor.at(i, j).as_vec_znx(), - 0, - ); + self.vec_znx_copy(&mut sk_ij.as_vec_znx_mut(), j, &sk_tensor.at(i, j).as_vec_znx(), 0); } res.at_mut(i) diff --git a/poulpy-core/src/encryption/ggsw.rs b/poulpy-core/src/encryption/ggsw.rs index 86b810c..1442f78 100644 --- a/poulpy-core/src/encryption/ggsw.rs +++ b/poulpy-core/src/encryption/ggsw.rs @@ -76,9 +76,7 @@ where where A: GGSWInfos, { - self.glwe_encrypt_sk_tmp_bytes(infos) - .max(self.vec_znx_normalize_tmp_bytes()) - + GLWEPlaintext::bytes_of_from_infos(infos) + self.glwe_encrypt_sk_tmp_bytes(infos).max(self.vec_znx_normalize_tmp_bytes()) + GLWEPlaintext::bytes_of_from_infos(infos) } fn ggsw_encrypt_sk( diff --git a/poulpy-core/src/encryption/glwe.rs b/poulpy-core/src/encryption/glwe.rs index e74403c..06b6426 100644 --- a/poulpy-core/src/encryption/glwe.rs +++ b/poulpy-core/src/encryption/glwe.rs @@ -402,15 +402,7 @@ where let mut ci_big = self.vec_znx_idft_apply_consume(ci_dft); // ci_big = u * pk[i] + e - self.vec_znx_big_add_normal( - base2k, - &mut ci_big, - 0, - pk.k().into(), - source_xe, - SIGMA, - SIGMA_BOUND, - ); + self.vec_znx_big_add_normal(base2k, &mut ci_big, 0, pk.k().into(), source_xe, SIGMA, SIGMA_BOUND); // ci_big = u * pk[i] + e + m (if col = i) if let Some((pt, col)) = pt @@ -420,7 +412,7 @@ where } // ct[i] = norm(ci_big) - 
self.vec_znx_big_normalize(base2k, &mut res.data, i, base2k, &ci_big, 0, scratch_2); + self.vec_znx_big_normalize(&mut res.data, base2k, 0, i, &ci_big, base2k, 0, scratch_2); } } } @@ -487,12 +479,7 @@ where let sk: GLWESecretPrepared<&[u8], BE> = sk.to_ref(); if compressed { - assert_eq!( - ct.cols(), - 1, - "invalid glwe: compressed tag=true but #cols={} != 1", - ct.cols() - ) + assert_eq!(ct.cols(), 1, "invalid glwe: compressed tag=true but #cols={} != 1", ct.cols()) } assert!( @@ -537,7 +524,7 @@ where let ci_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(ci_dft); // use c[0] as buffer, which is overwritten later by the normalization step - self.vec_znx_big_normalize(base2k, &mut ci, 0, base2k, &ci_big, 0, scratch_3); + self.vec_znx_big_normalize(&mut ci, base2k, 0, 0, &ci_big, base2k, 0, scratch_3); // c0_tmp = -c[i] * s[i] (use c[0] as buffer) self.vec_znx_sub_inplace(&mut c0, 0, &ci, 0); @@ -555,6 +542,6 @@ where } // c[0] = norm(c[0]) - self.vec_znx_normalize(base2k, ct, 0, base2k, &c0, 0, scratch_1); + self.vec_znx_normalize(ct, base2k, 0, 0, &c0, base2k, 0, scratch_1); } } diff --git a/poulpy-core/src/encryption/glwe_automorphism_key.rs b/poulpy-core/src/encryption/glwe_automorphism_key.rs index 4feaeb6..deba08b 100644 --- a/poulpy-core/src/encryption/glwe_automorphism_key.rs +++ b/poulpy-core/src/encryption/glwe_automorphism_key.rs @@ -130,14 +130,7 @@ where sk_out_prepared.prepare(self, &sk_out); } - self.gglwe_encrypt_sk( - res, - &sk.data, - &sk_out_prepared, - source_xa, - source_xe, - scratch_1, - ); + self.gglwe_encrypt_sk(res, &sk.data, &sk_out_prepared, source_xa, source_xe, scratch_1); res.set_p(p); } diff --git a/poulpy-core/src/encryption/glwe_switching_key.rs b/poulpy-core/src/encryption/glwe_switching_key.rs index e1467f3..3dd475d 100644 --- a/poulpy-core/src/encryption/glwe_switching_key.rs +++ b/poulpy-core/src/encryption/glwe_switching_key.rs @@ -78,8 +78,7 @@ where where A: GGLWEInfos, { - self.gglwe_encrypt_sk_tmp_bytes(infos) - .max(ScalarZnx::bytes_of(self.n(), 1)) + self.gglwe_encrypt_sk_tmp_bytes(infos).max(ScalarZnx::bytes_of(self.n(), 1)) + ScalarZnx::bytes_of(self.n(), infos.rank_in().into()) + self.bytes_of_glwe_secret_prepared_from_infos(infos) } @@ -111,12 +110,7 @@ where let (mut sk_in_tmp, scratch_1) = scratch.take_scalar_znx(self.n(), sk_in.rank().into()); for i in 0..sk_in.rank().into() { - self.vec_znx_switch_ring( - &mut sk_in_tmp.as_vec_znx_mut(), - i, - &sk_in.data.as_vec_znx(), - i, - ); + self.vec_znx_switch_ring(&mut sk_in_tmp.as_vec_znx_mut(), i, &sk_in.data.as_vec_znx(), i); } let (mut sk_out_tmp, scratch_2) = scratch_1.take_glwe_secret_prepared(self, sk_out.rank()); @@ -130,14 +124,7 @@ where sk_out_tmp.dist = sk_out.dist; - self.gglwe_encrypt_sk( - res, - &sk_in_tmp, - &sk_out_tmp, - source_xa, - source_xe, - scratch_2, - ); + self.gglwe_encrypt_sk(res, &sk_in_tmp, &sk_out_tmp, source_xa, source_xe, scratch_2); *res.input_degree() = sk_in.n(); *res.output_degree() = sk_out.n(); diff --git a/poulpy-core/src/encryption/glwe_tensor_key.rs b/poulpy-core/src/encryption/glwe_tensor_key.rs index 08df09b..344b139 100644 --- a/poulpy-core/src/encryption/glwe_tensor_key.rs +++ b/poulpy-core/src/encryption/glwe_tensor_key.rs @@ -103,13 +103,6 @@ where sk_prepared.prepare(self, sk); sk_tensor.prepare(self, sk, scratch_2); - self.gglwe_encrypt_sk( - res, - &sk_tensor.data, - &sk_prepared, - source_xa, - source_xe, - scratch_2, - ); + self.gglwe_encrypt_sk(res, &sk_tensor.data, &sk_prepared, source_xa, source_xe, scratch_2); } } diff 
--git a/poulpy-core/src/encryption/glwe_to_lwe_key.rs b/poulpy-core/src/encryption/glwe_to_lwe_key.rs index 0609fb6..4a1f39d 100644 --- a/poulpy-core/src/encryption/glwe_to_lwe_key.rs +++ b/poulpy-core/src/encryption/glwe_to_lwe_key.rs @@ -107,13 +107,6 @@ where sk_lwe_as_glwe_prep.prepare(self, &sk_lwe_as_glwe); } - self.gglwe_encrypt_sk( - res, - &sk_glwe.data, - &sk_lwe_as_glwe_prep, - source_xa, - source_xe, - scratch_1, - ); + self.gglwe_encrypt_sk(res, &sk_glwe.data, &sk_lwe_as_glwe_prep, source_xa, source_xe, scratch_1); } } diff --git a/poulpy-core/src/encryption/lwe_switching_key.rs b/poulpy-core/src/encryption/lwe_switching_key.rs index 431545b..ca0c0e9 100644 --- a/poulpy-core/src/encryption/lwe_switching_key.rs +++ b/poulpy-core/src/encryption/lwe_switching_key.rs @@ -70,21 +70,9 @@ where where A: GGLWEInfos, { - assert_eq!( - infos.dsize().0, - 1, - "dsize > 1 is not supported for LWESwitchingKey" - ); - assert_eq!( - infos.rank_in().0, - 1, - "rank_in > 1 is not supported for LWESwitchingKey" - ); - assert_eq!( - infos.rank_out().0, - 1, - "rank_out > 1 is not supported for LWESwitchingKey" - ); + assert_eq!(infos.dsize().0, 1, "dsize > 1 is not supported for LWESwitchingKey"); + assert_eq!(infos.rank_in().0, 1, "rank_in > 1 is not supported for LWESwitchingKey"); + assert_eq!(infos.rank_out().0, 1, "rank_out > 1 is not supported for LWESwitchingKey"); GLWESecret::bytes_of(self.n().into(), Rank(1)) + GLWESecretPrepared::bytes_of(self, Rank(1)) + GLWESwitchingKey::encrypt_sk_tmp_bytes(self, infos) @@ -125,13 +113,6 @@ where sk_glwe_in.data.at_mut(0, 0)[sk_lwe_in.n().into()..].fill(0); self.vec_znx_automorphism_inplace(-1, &mut sk_glwe_in.data.as_vec_znx_mut(), 0, scratch_2); - self.glwe_switching_key_encrypt_sk( - res, - &sk_glwe_in, - &sk_glwe_out, - source_xa, - source_xe, - scratch_2, - ); + self.glwe_switching_key_encrypt_sk(res, &sk_glwe_in, &sk_glwe_out, source_xa, source_xe, scratch_2); } } diff --git a/poulpy-core/src/encryption/lwe_to_glwe_key.rs b/poulpy-core/src/encryption/lwe_to_glwe_key.rs index c5fcd15..44387ac 100644 --- a/poulpy-core/src/encryption/lwe_to_glwe_key.rs +++ b/poulpy-core/src/encryption/lwe_to_glwe_key.rs @@ -106,13 +106,6 @@ where sk_lwe_as_glwe.data.at_mut(0, 0)[sk_lwe.n().into()..].fill(0); self.vec_znx_automorphism_inplace(-1, &mut sk_lwe_as_glwe.data.as_vec_znx_mut(), 0, scratch_1); - self.gglwe_encrypt_sk( - res, - &sk_lwe_as_glwe.data, - sk_glwe, - source_xa, - source_xe, - scratch_1, - ); + self.gglwe_encrypt_sk(res, &sk_lwe_as_glwe.data, sk_glwe, source_xa, source_xe, scratch_1); } } diff --git a/poulpy-core/src/external_product/ggsw.rs b/poulpy-core/src/external_product/ggsw.rs index 92ff84b..b5a474b 100644 --- a/poulpy-core/src/external_product/ggsw.rs +++ b/poulpy-core/src/external_product/ggsw.rs @@ -35,20 +35,8 @@ where let a: &GGSW<&[u8]> = &a.to_ref(); let b: &GGSWPrepared<&[u8], BE> = &b.to_ref(); - assert_eq!( - res.rank(), - a.rank(), - "res rank: {} != a rank: {}", - res.rank(), - a.rank() - ); - assert_eq!( - res.rank(), - b.rank(), - "res rank: {} != b rank: {}", - res.rank(), - b.rank() - ); + assert_eq!(res.rank(), a.rank(), "res rank: {} != a rank: {}", res.rank(), a.rank()); + assert_eq!(res.rank(), b.rank(), "res rank: {} != b rank: {}", res.rank(), b.rank()); assert_eq!(res.base2k(), a.base2k()); @@ -80,13 +68,7 @@ where assert_eq!(res.n(), self.n() as u32); assert_eq!(a.n(), self.n() as u32); - assert_eq!( - res.rank(), - a.rank(), - "res rank: {} != a rank: {}", - res.rank(), - a.rank() - ); + assert_eq!(res.rank(), 
a.rank(), "res rank: {} != a rank: {}", res.rank(), a.rank()); for row in 0..res.dnum().into() { for col in 0..(res.rank() + 1).into() { diff --git a/poulpy-core/src/external_product/glwe.rs b/poulpy-core/src/external_product/glwe.rs index ef85998..c144ca3 100644 --- a/poulpy-core/src/external_product/glwe.rs +++ b/poulpy-core/src/external_product/glwe.rs @@ -110,12 +110,12 @@ where assert_eq!(ggsw.n(), res.n()); assert!(scratch.available() >= self.glwe_external_product_tmp_bytes(res, res, ggsw)); - let base2k_res: usize = res.base2k().as_usize(); - let base2k_ggsw: usize = ggsw.base2k().as_usize(); + let res_base2k: usize = res.base2k().as_usize(); + let ggsw_base2k: usize = ggsw.base2k().as_usize(); let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), ggsw.size()); // Todo optimise - let res_big: VecZnxBig<&mut [u8], BE> = if base2k_res != base2k_ggsw { + let res_big: VecZnxBig<&mut [u8], BE> = if res_base2k != ggsw_base2k { let (mut res_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout { n: res.n(), base2k: ggsw.base2k(), @@ -130,15 +130,7 @@ where let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); for j in 0..(res.rank() + 1).into() { - self.vec_znx_big_normalize( - base2k_res, - res.data_mut(), - j, - base2k_ggsw, - &res_big, - j, - scratch_1, - ); + self.vec_znx_big_normalize(res.data_mut(), res_base2k, 0, j, &res_big, ggsw_base2k, j, scratch_1); } } @@ -155,13 +147,13 @@ where assert_eq!(a.n(), res.n()); assert!(scratch.available() >= self.glwe_external_product_tmp_bytes(res, a, ggsw)); - let base2k_a: usize = a.base2k().into(); - let base2k_ggsw: usize = ggsw.base2k().into(); - let base2k_res: usize = res.base2k().into(); + let a_base2k: usize = a.base2k().into(); + let ggsw_base2k: usize = ggsw.base2k().into(); + let res_base2k: usize = res.base2k().into(); let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), ggsw.size()); // Todo optimise - let res_big: VecZnxBig<&mut [u8], BE> = if base2k_a != base2k_ggsw { + let res_big: VecZnxBig<&mut [u8], BE> = if a_base2k != ggsw_base2k { let (mut a_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout { n: a.n(), base2k: ggsw.base2k(), @@ -176,15 +168,7 @@ where let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); for j in 0..(res.rank() + 1).into() { - self.vec_znx_big_normalize( - base2k_res, - res.data_mut(), - j, - base2k_ggsw, - &res_big, - j, - scratch_1, - ); + self.vec_znx_big_normalize(res.data_mut(), res_base2k, 0, j, &res_big, ggsw_base2k, j, scratch_1); } } } @@ -231,10 +215,7 @@ where A: GLWEInfos, B: GGSWInfos, { - let in_size: usize = a_infos - .k() - .div_ceil(b_infos.base2k()) - .div_ceil(b_infos.dsize().into()) as usize; + let in_size: usize = a_infos.k().div_ceil(b_infos.base2k()).div_ceil(b_infos.dsize().into()) as usize; let out_size: usize = res_infos.size(); let ggsw_size: usize = b_infos.size(); let a_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), in_size); diff --git a/poulpy-core/src/glwe_packer.rs b/poulpy-core/src/glwe_packer.rs index b0b9595..0826883 100644 --- a/poulpy-core/src/glwe_packer.rs +++ b/poulpy-core/src/glwe_packer.rs @@ -265,23 +265,9 @@ fn pack_core( // Propagates to next accumulator if acc_prev[0].value { - pack_core( - module, - Some(&acc_prev[0].data), - acc_next, - i + 1, - auto_keys, - scratch, - ); + pack_core(module, Some(&acc_prev[0].data), acc_next, i + 1, auto_keys, scratch); } else { - pack_core( - module, - None::<&GLWE>>, - acc_next, - i + 1, - auto_keys, - scratch, - ); + pack_core(module, None::<&GLWE>>, 
acc_next, i + 1, auto_keys, scratch); } } } @@ -319,11 +305,7 @@ fn combine( let log_n: usize = acc.data.n().log2(); let a: &mut GLWE> = &mut acc.data; - let gal_el: i64 = if i == 0 { - -1 - } else { - module.galois_element(1 << (i - 1)) - }; + let gal_el: i64 = if i == 0 { -1 } else { module.galois_element(1 << (i - 1)) }; let t: i64 = 1 << (log_n - i - 1); diff --git a/poulpy-core/src/glwe_packing.rs b/poulpy-core/src/glwe_packing.rs index 0626c55..f5100bf 100644 --- a/poulpy-core/src/glwe_packing.rs +++ b/poulpy-core/src/glwe_packing.rs @@ -88,8 +88,7 @@ where let key: &K = if i == 0 { keys.get_automorphism_key(-1).unwrap() } else { - keys.get_automorphism_key(self.galois_element(1 << (i - 1))) - .unwrap() + keys.get_automorphism_key(self.galois_element(1 << (i - 1))).unwrap() }; for j in 0..t { diff --git a/poulpy-core/src/glwe_trace.rs b/poulpy-core/src/glwe_trace.rs index 3a27428..9b3a08b 100644 --- a/poulpy-core/src/glwe_trace.rs +++ b/poulpy-core/src/glwe_trace.rs @@ -169,11 +169,7 @@ where for i in skip..log_n { self.glwe_rsh(1, res, scratch); - let p: i64 = if i == 0 { - -1 - } else { - self.galois_element(1 << (i - 1)) - }; + let p: i64 = if i == 0 { -1 } else { self.galois_element(1 << (i - 1)) }; if let Some(key) = keys.get_automorphism_key(p) { self.glwe_automorphism_add_inplace(res, key, scratch); diff --git a/poulpy-core/src/keyswitching/gglwe.rs b/poulpy-core/src/keyswitching/gglwe.rs index 779ff11..03be4f7 100644 --- a/poulpy-core/src/keyswitching/gglwe.rs +++ b/poulpy-core/src/keyswitching/gglwe.rs @@ -148,19 +148,8 @@ where res.rank_out(), b.rank_out() ); - assert!( - res.dnum() <= a.dnum(), - "res.dnum()={} > a.dnum()={}", - res.dnum(), - a.dnum() - ); - assert_eq!( - res.dsize(), - a.dsize(), - "res dsize: {} != a dsize: {}", - res.dsize(), - a.dsize() - ); + assert!(res.dnum() <= a.dnum(), "res.dnum()={} > a.dnum()={}", res.dnum(), a.dnum()); + assert_eq!(res.dsize(), a.dsize(), "res dsize: {} != a dsize: {}", res.dsize(), a.dsize()); assert_eq!(res.base2k(), a.base2k()); let res: &mut GGLWE<&mut [u8]> = &mut res.to_mut(); diff --git a/poulpy-core/src/keyswitching/glwe.rs b/poulpy-core/src/keyswitching/glwe.rs index 5fe298e..11c57ef 100644 --- a/poulpy-core/src/keyswitching/glwe.rs +++ b/poulpy-core/src/keyswitching/glwe.rs @@ -105,13 +105,13 @@ where scratch.available(), ); - let base2k_a: usize = a.base2k().into(); - let base2k_key: usize = key.base2k().into(); - let base2k_res: usize = res.base2k().into(); + let a_base2k: usize = a.base2k().into(); + let key_base2k: usize = key.base2k().into(); + let res_base2k: usize = res.base2k().into(); let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // Todo optimise - let res_big: VecZnxBig<&mut [u8], BE> = if base2k_a != base2k_key { + let res_big: VecZnxBig<&mut [u8], BE> = if a_base2k != key_base2k { let (mut a_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout { n: a.n(), base2k: key.base2k(), @@ -126,15 +126,7 @@ where let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); for i in 0..(res.rank() + 1).into() { - self.vec_znx_big_normalize( - base2k_res, - res.data_mut(), - i, - base2k_key, - &res_big, - i, - scratch_1, - ); + self.vec_znx_big_normalize(res.data_mut(), res_base2k, 0, i, &res_big, key_base2k, i, scratch_1); } } @@ -169,12 +161,12 @@ where scratch.available(), ); - let base2k_res: usize = res.base2k().as_usize(); - let base2k_key: usize = key.base2k().as_usize(); + let res_base2k: usize = res.base2k().as_usize(); + let key_base2k: usize = key.base2k().as_usize(); 
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // Todo optimise - let res_big: VecZnxBig<&mut [u8], BE> = if base2k_res != base2k_key { + let res_big: VecZnxBig<&mut [u8], BE> = if res_base2k != key_base2k { let (mut res_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout { n: res.n(), base2k: key.base2k(), @@ -190,15 +182,7 @@ where let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); for i in 0..(res.rank() + 1).into() { - self.vec_znx_big_normalize( - base2k_res, - res.data_mut(), - i, - base2k_key, - &res_big, - i, - scratch_1, - ); + self.vec_znx_big_normalize(res.data_mut(), res_base2k, 0, i, &res_big, key_base2k, i, scratch_1); } } } diff --git a/poulpy-core/src/layouts/compressed/gglwe.rs b/poulpy-core/src/layouts/compressed/gglwe.rs index bb5fe55..0b11d95 100644 --- a/poulpy-core/src/layouts/compressed/gglwe.rs +++ b/poulpy-core/src/layouts/compressed/gglwe.rs @@ -137,13 +137,7 @@ impl GGLWECompressed> { ); GGLWECompressed { - data: MatZnx::alloc( - n.into(), - dnum.into(), - rank_in.into(), - 1, - k.0.div_ceil(base2k.0) as usize, - ), + data: MatZnx::alloc(n.into(), dnum.into(), rank_in.into(), 1, k.0.div_ceil(base2k.0) as usize), k, base2k, dsize, @@ -181,13 +175,7 @@ impl GGLWECompressed> { dsize.0, ); - MatZnx::bytes_of( - n.into(), - dnum.into(), - rank_in.into(), - 1, - k.0.div_ceil(base2k.0) as usize, - ) + MatZnx::bytes_of(n.into(), dnum.into(), rank_in.into(), 1, k.0.div_ceil(base2k.0) as usize) } } diff --git a/poulpy-core/src/layouts/compressed/ggsw.rs b/poulpy-core/src/layouts/compressed/ggsw.rs index 14fdd7a..a25a2ae 100644 --- a/poulpy-core/src/layouts/compressed/ggsw.rs +++ b/poulpy-core/src/layouts/compressed/ggsw.rs @@ -127,13 +127,7 @@ impl GGSWCompressed> { ); GGSWCompressed { - data: MatZnx::alloc( - n.into(), - dnum.into(), - (rank + 1).into(), - 1, - k.0.div_ceil(base2k.0) as usize, - ), + data: MatZnx::alloc(n.into(), dnum.into(), (rank + 1).into(), 1, k.0.div_ceil(base2k.0) as usize), k, base2k, dsize, @@ -171,13 +165,7 @@ impl GGSWCompressed> { dsize.0, ); - MatZnx::bytes_of( - n.into(), - dnum.into(), - (rank + 1).into(), - 1, - k.0.div_ceil(base2k.0) as usize, - ) + MatZnx::bytes_of(n.into(), dnum.into(), (rank + 1).into(), 1, k.0.div_ceil(base2k.0) as usize) } } diff --git a/poulpy-core/src/layouts/compressed/glwe_tensor_key.rs b/poulpy-core/src/layouts/compressed/glwe_tensor_key.rs index c6e9297..773efde 100644 --- a/poulpy-core/src/layouts/compressed/glwe_tensor_key.rs +++ b/poulpy-core/src/layouts/compressed/glwe_tensor_key.rs @@ -95,15 +95,7 @@ impl GLWETensorKeyCompressed> { pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self { let pairs: u32 = (((rank.as_u32() + 1) * rank.as_u32()) >> 1).max(1); - GLWETensorKeyCompressed(GGLWECompressed::alloc( - n, - base2k, - k, - Rank(pairs), - rank, - dnum, - dsize, - )) + GLWETensorKeyCompressed(GGLWECompressed::alloc(n, base2k, k, Rank(pairs), rank, dnum, dsize)) } pub fn bytes_of_from_infos(infos: &A) -> usize diff --git a/poulpy-core/src/layouts/compressed/glwe_to_lwe_key.rs b/poulpy-core/src/layouts/compressed/glwe_to_lwe_key.rs index 5552d11..85b5a38 100644 --- a/poulpy-core/src/layouts/compressed/glwe_to_lwe_key.rs +++ b/poulpy-core/src/layouts/compressed/glwe_to_lwe_key.rs @@ -100,13 +100,7 @@ impl GLWEToLWESwitchingKeyCompressed> { 1, "dsize > 1 is unsupported for GLWEToLWESwitchingKeyCompressed" ); - Self::alloc( - infos.n(), - infos.base2k(), - infos.k(), - infos.rank_in(), - infos.dnum(), - ) + 
Self::alloc(infos.n(), infos.base2k(), infos.k(), infos.rank_in(), infos.dnum()) } pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> Self { diff --git a/poulpy-core/src/layouts/compressed/lwe_switching_key.rs b/poulpy-core/src/layouts/compressed/lwe_switching_key.rs index 764d423..c55aab5 100644 --- a/poulpy-core/src/layouts/compressed/lwe_switching_key.rs +++ b/poulpy-core/src/layouts/compressed/lwe_switching_key.rs @@ -88,11 +88,7 @@ impl LWESwitchingKeyCompressed> { where A: GGLWEInfos, { - assert_eq!( - infos.dsize().0, - 1, - "dsize > 1 is not supported for LWESwitchingKeyCompressed" - ); + assert_eq!(infos.dsize().0, 1, "dsize > 1 is not supported for LWESwitchingKeyCompressed"); assert_eq!( infos.rank_in().0, 1, @@ -122,11 +118,7 @@ impl LWESwitchingKeyCompressed> { where A: GGLWEInfos, { - assert_eq!( - infos.dsize().0, - 1, - "dsize > 1 is not supported for LWESwitchingKeyCompressed" - ); + assert_eq!(infos.dsize().0, 1, "dsize > 1 is not supported for LWESwitchingKeyCompressed"); assert_eq!( infos.rank_in().0, 1, diff --git a/poulpy-core/src/layouts/compressed/lwe_to_glwe_key.rs b/poulpy-core/src/layouts/compressed/lwe_to_glwe_key.rs index 984ed05..28c1b02 100644 --- a/poulpy-core/src/layouts/compressed/lwe_to_glwe_key.rs +++ b/poulpy-core/src/layouts/compressed/lwe_to_glwe_key.rs @@ -98,13 +98,7 @@ impl LWEToGLWEKeyCompressed> { 1, "rank_in > 1 is not supported for LWEToGLWESwitchingKeyCompressed" ); - Self::alloc( - infos.n(), - infos.base2k(), - infos.k(), - infos.rank_out(), - infos.dnum(), - ) + Self::alloc(infos.n(), infos.base2k(), infos.k(), infos.rank_out(), infos.dnum()) } pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> Self { diff --git a/poulpy-core/src/layouts/glwe.rs b/poulpy-core/src/layouts/glwe.rs index ba03b0e..def9600 100644 --- a/poulpy-core/src/layouts/glwe.rs +++ b/poulpy-core/src/layouts/glwe.rs @@ -129,13 +129,7 @@ impl fmt::Debug for GLWE { impl fmt::Display for GLWE { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "GLWE: base2k={} k={}: {}", - self.base2k().0, - self.k().0, - self.data - ) + write!(f, "GLWE: base2k={} k={}: {}", self.base2k().0, self.k().0, self.data) } } diff --git a/poulpy-core/src/layouts/glwe_plaintext.rs b/poulpy-core/src/layouts/glwe_plaintext.rs index 411617d..27aa1f9 100644 --- a/poulpy-core/src/layouts/glwe_plaintext.rs +++ b/poulpy-core/src/layouts/glwe_plaintext.rs @@ -73,13 +73,7 @@ impl GLWEInfos for GLWEPlaintext { impl fmt::Display for GLWEPlaintext { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "GLWEPlaintext: base2k={} k={}: {}", - self.base2k().0, - self.k().0, - self.data - ) + write!(f, "GLWEPlaintext: base2k={} k={}: {}", self.base2k().0, self.k().0, self.data) } } diff --git a/poulpy-core/src/layouts/glwe_secret_tensor.rs b/poulpy-core/src/layouts/glwe_secret_tensor.rs index b5eff84..5c736ed 100644 --- a/poulpy-core/src/layouts/glwe_secret_tensor.rs +++ b/poulpy-core/src/layouts/glwe_secret_tensor.rs @@ -10,7 +10,7 @@ use poulpy_hal::{ }; use crate::{ - ScratchTakeCore, + GetDistribution, ScratchTakeCore, dist::Distribution, layouts::{ Base2K, Degree, GLWEInfos, GLWESecret, GLWESecretPreparedFactory, GLWESecretToMut, GLWESecretToRef, LWEInfos, Rank, @@ -30,6 +30,12 @@ impl GLWESecretTensor> { } } +impl GetDistribution for GLWESecretTensor { + fn dist(&self) -> &Distribution { + &self.dist + } +} + impl LWEInfos for GLWESecretTensor { fn base2k(&self) -> Base2K { Base2K(0) @@ 
-204,12 +210,14 @@ where let idx: usize = i * rank + j - (i * (i + 1) / 2); self.svp_apply_dft_to_dft(&mut a_ij_dft, 0, &a_prepared.data, j, &a_dft, i); self.vec_znx_idft_apply_tmpa(&mut a_ij_big, 0, &mut a_ij_dft, 0); + self.vec_znx_big_normalize( - base2k, &mut res.data.as_vec_znx_mut(), - idx, base2k, + 0, + idx, &a_ij_big, + base2k, 0, scratch_4, ); diff --git a/poulpy-core/src/layouts/glwe_tensor.rs b/poulpy-core/src/layouts/glwe_tensor.rs index 8ff6428..5d93563 100644 --- a/poulpy-core/src/layouts/glwe_tensor.rs +++ b/poulpy-core/src/layouts/glwe_tensor.rs @@ -3,7 +3,7 @@ use poulpy_hal::{ source::Source, }; -use crate::layouts::{Base2K, Degree, GLWEInfos, LWEInfos, Rank, SetGLWEInfos, TorusPrecision}; +use crate::layouts::{Base2K, Degree, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, Rank, SetGLWEInfos, TorusPrecision}; use std::fmt; #[derive(PartialEq, Eq, Clone)] @@ -68,13 +68,7 @@ impl fmt::Debug for GLWETensor { impl fmt::Display for GLWETensor { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "GLWETensor: base2k={} k={}: {}", - self.base2k().0, - self.k().0, - self.data - ) + write!(f, "GLWETensor: base2k={} k={}: {}", self.base2k().0, self.k().0, self.data) } } @@ -93,9 +87,10 @@ impl GLWETensor> { } pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> Self { - let pairs: usize = (((rank + 1) * rank).as_usize() >> 1).max(1); + let cols: usize = rank.as_usize() + 1; + let pairs: usize = (((cols + 1) * cols) >> 1).max(1); GLWETensor { - data: VecZnx::alloc(n.into(), pairs + 1, k.0.div_ceil(base2k.0) as usize), + data: VecZnx::alloc(n.into(), pairs, k.0.div_ceil(base2k.0) as usize), base2k, k, rank, @@ -110,36 +105,27 @@ impl GLWETensor> { } pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> usize { - let pairs: usize = (((rank + 1) * rank).as_usize() >> 1).max(1); - VecZnx::bytes_of(n.into(), pairs + 1, k.0.div_ceil(base2k.0) as usize) + let cols: usize = rank.as_usize() + 1; + let pairs: usize = (((cols + 1) * cols) >> 1).max(1); + VecZnx::bytes_of(n.into(), pairs, k.0.div_ceil(base2k.0) as usize) } } -pub trait GLWETensorToRef { - fn to_ref(&self) -> GLWETensor<&[u8]>; -} - -impl GLWETensorToRef for GLWETensor { - fn to_ref(&self) -> GLWETensor<&[u8]> { - GLWETensor { +impl GLWEToRef for GLWETensor { + fn to_ref(&self) -> GLWE<&[u8]> { + GLWE { k: self.k, base2k: self.base2k, data: self.data.to_ref(), - rank: self.rank, } } } -pub trait GLWETensorToMut { - fn to_mut(&mut self) -> GLWETensor<&mut [u8]>; -} - -impl GLWETensorToMut for GLWETensor { - fn to_mut(&mut self) -> GLWETensor<&mut [u8]> { - GLWETensor { +impl GLWEToMut for GLWETensor { + fn to_mut(&mut self) -> GLWE<&mut [u8]> { + GLWE { k: self.k, base2k: self.base2k, - rank: self.rank, data: self.data.to_mut(), } } diff --git a/poulpy-core/src/layouts/glwe_tensor_key.rs b/poulpy-core/src/layouts/glwe_tensor_key.rs index 032a892..b466cac 100644 --- a/poulpy-core/src/layouts/glwe_tensor_key.rs +++ b/poulpy-core/src/layouts/glwe_tensor_key.rs @@ -48,7 +48,9 @@ impl GLWEInfos for GLWETensorKey { impl GGLWEInfos for GLWETensorKey { fn rank_in(&self) -> Rank { - self.rank_out() + let rank_out: usize = self.rank_out().as_usize(); + let pairs: usize = (((rank_out + 1) * rank_out) >> 1).max(1); + pairs.into() } fn rank_out(&self) -> Rank { @@ -86,7 +88,9 @@ impl GLWEInfos for GLWETensorKeyLayout { impl GGLWEInfos for GLWETensorKeyLayout { fn rank_in(&self) -> Rank { - self.rank + let rank_out: usize = self.rank_out().as_usize(); + let pairs: usize = 
(((rank_out + 1) * rank_out) >> 1).max(1); + pairs.into() } fn dsize(&self) -> Dsize { @@ -127,11 +131,6 @@ impl GLWETensorKey> { where A: GGLWEInfos, { - assert_eq!( - infos.rank_in(), - infos.rank_out(), - "rank_in != rank_out is not supported for GGLWETensorKey" - ); Self::alloc( infos.n(), infos.base2k(), @@ -151,11 +150,6 @@ impl GLWETensorKey> { where A: GGLWEInfos, { - assert_eq!( - infos.rank_in(), - infos.rank_out(), - "rank_in != rank_out is not supported for GGLWETensorKey" - ); Self::bytes_of( infos.n(), infos.base2k(), diff --git a/poulpy-core/src/layouts/glwe_to_lwe_key.rs b/poulpy-core/src/layouts/glwe_to_lwe_key.rs index 65ea1cf..19757bd 100644 --- a/poulpy-core/src/layouts/glwe_to_lwe_key.rs +++ b/poulpy-core/src/layouts/glwe_to_lwe_key.rs @@ -137,58 +137,22 @@ impl GLWEToLWEKey> { where A: GGLWEInfos, { - assert_eq!( - infos.rank_out().0, - 1, - "rank_out > 1 is not supported for GLWEToLWEKey" - ); - assert_eq!( - infos.dsize().0, - 1, - "dsize > 1 is not supported for GLWEToLWEKey" - ); - Self::alloc( - infos.n(), - infos.base2k(), - infos.k(), - infos.rank_in(), - infos.dnum(), - ) + assert_eq!(infos.rank_out().0, 1, "rank_out > 1 is not supported for GLWEToLWEKey"); + assert_eq!(infos.dsize().0, 1, "dsize > 1 is not supported for GLWEToLWEKey"); + Self::alloc(infos.n(), infos.base2k(), infos.k(), infos.rank_in(), infos.dnum()) } pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> Self { - GLWEToLWEKey(GLWESwitchingKey::alloc( - n, - base2k, - k, - rank_in, - Rank(1), - dnum, - Dsize(1), - )) + GLWEToLWEKey(GLWESwitchingKey::alloc(n, base2k, k, rank_in, Rank(1), dnum, Dsize(1))) } pub fn bytes_of_from_infos(infos: &A) -> usize where A: GGLWEInfos, { - assert_eq!( - infos.rank_out().0, - 1, - "rank_out > 1 is not supported for GLWEToLWEKey" - ); - assert_eq!( - infos.dsize().0, - 1, - "dsize > 1 is not supported for GLWEToLWEKey" - ); - Self::bytes_of( - infos.n(), - infos.base2k(), - infos.k(), - infos.rank_in(), - infos.dnum(), - ) + assert_eq!(infos.rank_out().0, 1, "rank_out > 1 is not supported for GLWEToLWEKey"); + assert_eq!(infos.dsize().0, 1, "dsize > 1 is not supported for GLWEToLWEKey"); + Self::bytes_of(infos.n(), infos.base2k(), infos.k(), infos.rank_in(), infos.dnum()) } pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> usize { diff --git a/poulpy-core/src/layouts/lwe.rs b/poulpy-core/src/layouts/lwe.rs index ce50d54..bd2bd2e 100644 --- a/poulpy-core/src/layouts/lwe.rs +++ b/poulpy-core/src/layouts/lwe.rs @@ -96,7 +96,7 @@ impl LWE { } impl LWE { - pub fn data_mut(&mut self) -> &VecZnx { + pub fn data_mut(&mut self) -> &mut VecZnx { &mut self.data } } @@ -109,13 +109,7 @@ impl fmt::Debug for LWE { impl fmt::Display for LWE { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "LWE: base2k={} k={}: {}", - self.base2k().0, - self.k().0, - self.data - ) + write!(f, "LWE: base2k={} k={}: {}", self.base2k().0, self.k().0, self.data) } } diff --git a/poulpy-core/src/layouts/lwe_plaintext.rs b/poulpy-core/src/layouts/lwe_plaintext.rs index 7c0d39a..f568431 100644 --- a/poulpy-core/src/layouts/lwe_plaintext.rs +++ b/poulpy-core/src/layouts/lwe_plaintext.rs @@ -71,13 +71,7 @@ impl LWEPlaintext> { impl fmt::Display for LWEPlaintext { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "LWEPlaintext: base2k={} k={}: {}", - self.base2k().0, - self.k().0, - self.data - ) + write!(f, "LWEPlaintext: base2k={} k={}: {}", self.base2k().0, 
self.k().0, self.data) } } diff --git a/poulpy-core/src/layouts/lwe_switching_key.rs b/poulpy-core/src/layouts/lwe_switching_key.rs index 2ae6032..391a341 100644 --- a/poulpy-core/src/layouts/lwe_switching_key.rs +++ b/poulpy-core/src/layouts/lwe_switching_key.rs @@ -106,55 +106,23 @@ impl LWESwitchingKey> { where A: GGLWEInfos, { - assert_eq!( - infos.dsize().0, - 1, - "dsize > 1 is not supported for LWESwitchingKey" - ); - assert_eq!( - infos.rank_in().0, - 1, - "rank_in > 1 is not supported for LWESwitchingKey" - ); - assert_eq!( - infos.rank_out().0, - 1, - "rank_out > 1 is not supported for LWESwitchingKey" - ); + assert_eq!(infos.dsize().0, 1, "dsize > 1 is not supported for LWESwitchingKey"); + assert_eq!(infos.rank_in().0, 1, "rank_in > 1 is not supported for LWESwitchingKey"); + assert_eq!(infos.rank_out().0, 1, "rank_out > 1 is not supported for LWESwitchingKey"); Self::alloc(infos.n(), infos.base2k(), infos.k(), infos.dnum()) } pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> Self { - LWESwitchingKey(GLWESwitchingKey::alloc( - n, - base2k, - k, - Rank(1), - Rank(1), - dnum, - Dsize(1), - )) + LWESwitchingKey(GLWESwitchingKey::alloc(n, base2k, k, Rank(1), Rank(1), dnum, Dsize(1))) } pub fn bytes_of_from_infos(infos: &A) -> usize where A: GGLWEInfos, { - assert_eq!( - infos.dsize().0, - 1, - "dsize > 1 is not supported for LWESwitchingKey" - ); - assert_eq!( - infos.rank_in().0, - 1, - "rank_in > 1 is not supported for LWESwitchingKey" - ); - assert_eq!( - infos.rank_out().0, - 1, - "rank_out > 1 is not supported for LWESwitchingKey" - ); + assert_eq!(infos.dsize().0, 1, "dsize > 1 is not supported for LWESwitchingKey"); + assert_eq!(infos.rank_in().0, 1, "rank_in > 1 is not supported for LWESwitchingKey"); + assert_eq!(infos.rank_out().0, 1, "rank_out > 1 is not supported for LWESwitchingKey"); Self::bytes_of(infos.n(), infos.base2k(), infos.k(), infos.dnum()) } diff --git a/poulpy-core/src/layouts/lwe_to_glwe_key.rs b/poulpy-core/src/layouts/lwe_to_glwe_key.rs index 5a44f61..932a818 100644 --- a/poulpy-core/src/layouts/lwe_to_glwe_key.rs +++ b/poulpy-core/src/layouts/lwe_to_glwe_key.rs @@ -136,59 +136,23 @@ impl LWEToGLWEKey> { where A: GGLWEInfos, { - assert_eq!( - infos.rank_in().0, - 1, - "rank_in > 1 is not supported for LWEToGLWEKey" - ); - assert_eq!( - infos.dsize().0, - 1, - "dsize > 1 is not supported for LWEToGLWEKey" - ); + assert_eq!(infos.rank_in().0, 1, "rank_in > 1 is not supported for LWEToGLWEKey"); + assert_eq!(infos.dsize().0, 1, "dsize > 1 is not supported for LWEToGLWEKey"); - Self::alloc( - infos.n(), - infos.base2k(), - infos.k(), - infos.rank_out(), - infos.dnum(), - ) + Self::alloc(infos.n(), infos.base2k(), infos.k(), infos.rank_out(), infos.dnum()) } pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> Self { - LWEToGLWEKey(GLWESwitchingKey::alloc( - n, - base2k, - k, - Rank(1), - rank_out, - dnum, - Dsize(1), - )) + LWEToGLWEKey(GLWESwitchingKey::alloc(n, base2k, k, Rank(1), rank_out, dnum, Dsize(1))) } pub fn bytes_of_from_infos(infos: &A) -> usize where A: GGLWEInfos, { - assert_eq!( - infos.rank_in().0, - 1, - "rank_in > 1 is not supported for LWEToGLWEKey" - ); - assert_eq!( - infos.dsize().0, - 1, - "dsize > 1 is not supported for LWEToGLWEKey" - ); - Self::bytes_of( - infos.n(), - infos.base2k(), - infos.k(), - infos.rank_out(), - infos.dnum(), - ) + assert_eq!(infos.rank_in().0, 1, "rank_in > 1 is not supported for LWEToGLWEKey"); + assert_eq!(infos.dsize().0, 1, "dsize > 1 is not 
supported for LWEToGLWEKey"); + Self::bytes_of(infos.n(), infos.base2k(), infos.k(), infos.rank_out(), infos.dnum()) } pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> usize { diff --git a/poulpy-core/src/layouts/prepared/gglwe_to_ggsw_key.rs b/poulpy-core/src/layouts/prepared/gglwe_to_ggsw_key.rs index d63dca6..fb88aa0 100644 --- a/poulpy-core/src/layouts/prepared/gglwe_to_ggsw_key.rs +++ b/poulpy-core/src/layouts/prepared/gglwe_to_ggsw_key.rs @@ -94,13 +94,7 @@ where infos.rank_out(), "rank_in != rank_out is not supported for GGLWEToGGSWKeyPrepared" ); - self.alloc_gglwe_to_ggsw_key_prepared( - infos.base2k(), - infos.k(), - infos.rank(), - infos.dnum(), - infos.dsize(), - ) + self.alloc_gglwe_to_ggsw_key_prepared(infos.base2k(), infos.k(), infos.rank(), infos.dnum(), infos.dsize()) } fn alloc_gglwe_to_ggsw_key_prepared( @@ -127,13 +121,7 @@ where infos.rank_out(), "rank_in != rank_out is not supported for GGLWEToGGSWKeyPrepared" ); - self.bytes_of_gglwe_to_ggsw( - infos.base2k(), - infos.k(), - infos.rank(), - infos.dnum(), - infos.dsize(), - ) + self.bytes_of_gglwe_to_ggsw(infos.base2k(), infos.k(), infos.rank(), infos.dnum(), infos.dsize()) } fn bytes_of_gglwe_to_ggsw(&self, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { diff --git a/poulpy-core/src/layouts/prepared/ggsw.rs b/poulpy-core/src/layouts/prepared/ggsw.rs index 39bb726..5174966 100644 --- a/poulpy-core/src/layouts/prepared/ggsw.rs +++ b/poulpy-core/src/layouts/prepared/ggsw.rs @@ -93,13 +93,7 @@ where A: GGSWInfos, { assert_eq!(self.ring_degree(), infos.n()); - self.alloc_ggsw_prepared( - infos.base2k(), - infos.k(), - infos.dnum(), - infos.dsize(), - infos.rank(), - ) + self.alloc_ggsw_prepared(infos.base2k(), infos.k(), infos.dnum(), infos.dsize(), infos.rank()) } fn bytes_of_ggsw_prepared(&self, base2k: Base2K, k: TorusPrecision, dnum: Dnum, dsize: Dsize, rank: Rank) -> usize { @@ -125,13 +119,7 @@ where A: GGSWInfos, { assert_eq!(self.ring_degree(), infos.n()); - self.bytes_of_ggsw_prepared( - infos.base2k(), - infos.k(), - infos.dnum(), - infos.dsize(), - infos.rank(), - ) + self.bytes_of_ggsw_prepared(infos.base2k(), infos.k(), infos.dnum(), infos.dsize(), infos.rank()) } fn ggsw_prepare_tmp_bytes(&self, infos: &A) -> usize diff --git a/poulpy-core/src/layouts/prepared/glwe_automorphism_key.rs b/poulpy-core/src/layouts/prepared/glwe_automorphism_key.rs index 7cc5b6e..c60970e 100644 --- a/poulpy-core/src/layouts/prepared/glwe_automorphism_key.rs +++ b/poulpy-core/src/layouts/prepared/glwe_automorphism_key.rs @@ -17,9 +17,7 @@ where } fn automorphism_key_infos(&self) -> GGLWELayout { - self.get(self.keys().next().unwrap()) - .unwrap() - .gglwe_layout() + self.get(self.keys().next().unwrap()).unwrap().gglwe_layout() } } @@ -110,13 +108,7 @@ where infos.rank_out(), "rank_in != rank_out is not supported for AutomorphismKeyPrepared" ); - self.alloc_glwe_automorphism_key_prepared( - infos.base2k(), - infos.k(), - infos.rank(), - infos.dnum(), - infos.dsize(), - ) + self.alloc_glwe_automorphism_key_prepared(infos.base2k(), infos.k(), infos.rank(), infos.dnum(), infos.dsize()) } fn bytes_of_glwe_automorphism_key_prepared( @@ -139,13 +131,7 @@ where infos.rank_out(), "rank_in != rank_out is not supported for AutomorphismKeyPrepared" ); - self.bytes_of_glwe_automorphism_key_prepared( - infos.base2k(), - infos.k(), - infos.rank(), - infos.dnum(), - infos.dsize(), - ) + self.bytes_of_glwe_automorphism_key_prepared(infos.base2k(), infos.k(), 
infos.rank(), infos.dnum(), infos.dsize()) } fn prepare_glwe_automorphism_key_tmp_bytes(&self, infos: &A) -> usize diff --git a/poulpy-core/src/layouts/prepared/glwe_secret.rs b/poulpy-core/src/layouts/prepared/glwe_secret.rs index 8f4917d..e3a8c12 100644 --- a/poulpy-core/src/layouts/prepared/glwe_secret.rs +++ b/poulpy-core/src/layouts/prepared/glwe_secret.rs @@ -86,7 +86,6 @@ where { let mut res: GLWESecretPrepared<&mut [u8], _> = res.to_mut(); let other: GLWESecret<&[u8]> = other.to_ref(); - for i in 0..res.rank().into() { self.svp_prepare(&mut res.data, i, &other.data, i); } diff --git a/poulpy-core/src/layouts/prepared/glwe_secret_tensor.rs b/poulpy-core/src/layouts/prepared/glwe_secret_tensor.rs new file mode 100644 index 0000000..a27d52d --- /dev/null +++ b/poulpy-core/src/layouts/prepared/glwe_secret_tensor.rs @@ -0,0 +1,180 @@ +use poulpy_hal::{ + api::{SvpPPolAlloc, SvpPPolBytesOf}, + layouts::{Backend, Data, DataMut, DataRef, Module, SvpPPol, SvpPPolToMut, SvpPPolToRef, ZnxInfos}, +}; + +use crate::{ + GetDistribution, GetDistributionMut, + dist::Distribution, + layouts::{ + Base2K, Degree, GLWEInfos, GLWESecretPrepared, GLWESecretPreparedFactory, GLWESecretPreparedToMut, + GLWESecretPreparedToRef, GLWESecretTensor, GLWESecretToRef, GetDegree, LWEInfos, Rank, TorusPrecision, + }, +}; + +pub struct GLWESecretTensorPrepared { + pub(crate) data: SvpPPol, + pub(crate) rank: Rank, + pub(crate) dist: Distribution, +} + +impl GetDistribution for GLWESecretTensorPrepared { + fn dist(&self) -> &Distribution { + &self.dist + } +} + +impl GetDistributionMut for GLWESecretTensorPrepared { + fn dist_mut(&mut self) -> &mut Distribution { + &mut self.dist + } +} + +impl LWEInfos for GLWESecretTensorPrepared { + fn base2k(&self) -> Base2K { + Base2K(0) + } + + fn k(&self) -> TorusPrecision { + TorusPrecision(0) + } + + fn n(&self) -> Degree { + Degree(self.data.n() as u32) + } + + fn size(&self) -> usize { + self.data.size() + } +} +impl GLWEInfos for GLWESecretTensorPrepared { + fn rank(&self) -> Rank { + self.rank + } +} + +pub trait GLWESecretTensorPreparedFactory { + fn alloc_glwe_secret_tensor_prepared(&self, rank: Rank) -> GLWESecretTensorPrepared, B>; + fn alloc_glwe_secret_tensor_prepared_from_infos(&self, infos: &A) -> GLWESecretTensorPrepared, B> + where + A: GLWEInfos; + + fn bytes_of_glwe_secret_tensor_prepared(&self, rank: Rank) -> usize; + fn bytes_of_glwe_secret_tensor_prepared_from_infos(&self, infos: &A) -> usize + where + A: GLWEInfos; + + fn prepare_glwe_secret_tensor(&self, res: &mut R, other: &O) + where + R: GLWESecretPreparedToMut + GetDistributionMut, + O: GLWESecretToRef + GetDistribution; +} + +impl GLWESecretTensorPreparedFactory for Module +where + Self: GLWESecretPreparedFactory, +{ + fn alloc_glwe_secret_tensor_prepared(&self, rank: Rank) -> GLWESecretTensorPrepared, B> { + GLWESecretTensorPrepared { + data: self.svp_ppol_alloc(GLWESecretTensor::pairs(rank.into())), + rank, + dist: Distribution::NONE, + } + } + fn alloc_glwe_secret_tensor_prepared_from_infos(&self, infos: &A) -> GLWESecretTensorPrepared, B> + where + A: GLWEInfos, + { + assert_eq!(self.ring_degree(), infos.n()); + self.alloc_glwe_secret_tensor_prepared(infos.rank()) + } + + fn bytes_of_glwe_secret_tensor_prepared(&self, rank: Rank) -> usize { + self.bytes_of_svp_ppol(GLWESecretTensor::pairs(rank.into())) + } + fn bytes_of_glwe_secret_tensor_prepared_from_infos(&self, infos: &A) -> usize + where + A: GLWEInfos, + { + assert_eq!(self.ring_degree(), infos.n()); + 
self.bytes_of_glwe_secret_tensor_prepared(infos.rank()) + } + + fn prepare_glwe_secret_tensor(&self, res: &mut R, other: &O) + where + R: GLWESecretPreparedToMut + GetDistributionMut, + O: GLWESecretToRef + GetDistribution, + { + self.prepare_glwe_secret(res, other); + } +} + +impl GLWESecretTensorPrepared, B> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + A: GLWEInfos, + M: GLWESecretTensorPreparedFactory, + { + module.alloc_glwe_secret_tensor_prepared_from_infos(infos) + } + + pub fn alloc(module: &M, rank: Rank) -> Self + where + M: GLWESecretTensorPreparedFactory, + { + module.alloc_glwe_secret_tensor_prepared(rank) + } + + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize + where + A: GLWEInfos, + M: GLWESecretTensorPreparedFactory, + { + module.bytes_of_glwe_secret_tensor_prepared_from_infos(infos) + } + + pub fn bytes_of(module: &M, rank: Rank) -> usize + where + M: GLWESecretTensorPreparedFactory, + { + module.bytes_of_glwe_secret_tensor_prepared(rank) + } +} + +impl GLWESecretTensorPrepared { + pub fn n(&self) -> Degree { + Degree(self.data.n() as u32) + } + + pub fn rank(&self) -> Rank { + Rank(self.data.cols() as u32) + } +} + +impl GLWESecretTensorPrepared { + pub fn prepare(&mut self, module: &M, other: &O) + where + M: GLWESecretTensorPreparedFactory, + O: GLWESecretToRef + GetDistribution, + { + module.prepare_glwe_secret_tensor(self, other); + } +} + +impl GLWESecretPreparedToRef for GLWESecretTensorPrepared { + fn to_ref(&self) -> GLWESecretPrepared<&[u8], B> { + GLWESecretPrepared { + data: self.data.to_ref(), + dist: self.dist, + } + } +} + +impl GLWESecretPreparedToMut for GLWESecretTensorPrepared { + fn to_mut(&mut self) -> GLWESecretPrepared<&mut [u8], B> { + GLWESecretPrepared { + dist: self.dist, + data: self.data.to_mut(), + } + } +} diff --git a/poulpy-core/src/layouts/prepared/glwe_tensor_key.rs b/poulpy-core/src/layouts/prepared/glwe_tensor_key.rs index 0304b37..b298a2a 100644 --- a/poulpy-core/src/layouts/prepared/glwe_tensor_key.rs +++ b/poulpy-core/src/layouts/prepared/glwe_tensor_key.rs @@ -34,7 +34,7 @@ impl GLWEInfos for GLWETensorKeyPrepared { impl GGLWEInfos for GLWETensorKeyPrepared { fn rank_in(&self) -> Rank { - self.rank_out() + self.0.rank_in() } fn rank_out(&self) -> Rank { @@ -70,18 +70,7 @@ where where A: GGLWEInfos, { - assert_eq!( - infos.rank_in(), - infos.rank_out(), - "rank_in != rank_out is not supported for TensorKeyPrepared" - ); - self.alloc_tensor_key_prepared( - infos.base2k(), - infos.k(), - infos.dnum(), - infos.dsize(), - infos.rank_out(), - ) + self.alloc_tensor_key_prepared(infos.base2k(), infos.k(), infos.dnum(), infos.dsize(), infos.rank_out()) } fn bytes_of_tensor_key_prepared(&self, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { @@ -93,13 +82,7 @@ where where A: GGLWEInfos, { - self.bytes_of_tensor_key_prepared( - infos.base2k(), - infos.k(), - infos.rank(), - infos.dnum(), - infos.dsize(), - ) + self.bytes_of_tensor_key_prepared(infos.base2k(), infos.k(), infos.rank(), infos.dnum(), infos.dsize()) } fn prepare_tensor_key_tmp_bytes(&self, infos: &A) -> usize diff --git a/poulpy-core/src/layouts/prepared/glwe_to_lwe_key.rs b/poulpy-core/src/layouts/prepared/glwe_to_lwe_key.rs index 675a73f..cd87cea 100644 --- a/poulpy-core/src/layouts/prepared/glwe_to_lwe_key.rs +++ b/poulpy-core/src/layouts/prepared/glwe_to_lwe_key.rs @@ -73,11 +73,7 @@ where 1, "rank_out > 1 is not supported for GLWEToLWEKeyPrepared" ); - debug_assert_eq!( - infos.dsize().0, - 1, - "dsize > 1 is not 
supported for GLWEToLWEKeyPrepared" - ); + debug_assert_eq!(infos.dsize().0, 1, "dsize > 1 is not supported for GLWEToLWEKeyPrepared"); self.alloc_glwe_to_lwe_key_prepared(infos.base2k(), infos.k(), infos.rank_in(), infos.dnum()) } @@ -94,11 +90,7 @@ where 1, "rank_out > 1 is not supported for GLWEToLWEKeyPrepared" ); - debug_assert_eq!( - infos.dsize().0, - 1, - "dsize > 1 is not supported for GLWEToLWEKeyPrepared" - ); + debug_assert_eq!(infos.dsize().0, 1, "dsize > 1 is not supported for GLWEToLWEKeyPrepared"); self.bytes_of_glwe_to_lwe_key_prepared(infos.base2k(), infos.k(), infos.rank_in(), infos.dnum()) } diff --git a/poulpy-core/src/layouts/prepared/lwe_switching_key.rs b/poulpy-core/src/layouts/prepared/lwe_switching_key.rs index 16f77eb..ba1974c 100644 --- a/poulpy-core/src/layouts/prepared/lwe_switching_key.rs +++ b/poulpy-core/src/layouts/prepared/lwe_switching_key.rs @@ -67,21 +67,9 @@ where where A: GGLWEInfos, { - debug_assert_eq!( - infos.dsize().0, - 1, - "dsize > 1 is not supported for LWESwitchingKey" - ); - debug_assert_eq!( - infos.rank_in().0, - 1, - "rank_in > 1 is not supported for LWESwitchingKey" - ); - debug_assert_eq!( - infos.rank_out().0, - 1, - "rank_out > 1 is not supported for LWESwitchingKey" - ); + debug_assert_eq!(infos.dsize().0, 1, "dsize > 1 is not supported for LWESwitchingKey"); + debug_assert_eq!(infos.rank_in().0, 1, "rank_in > 1 is not supported for LWESwitchingKey"); + debug_assert_eq!(infos.rank_out().0, 1, "rank_out > 1 is not supported for LWESwitchingKey"); self.alloc_lwe_switching_key_prepared(infos.base2k(), infos.k(), infos.dnum()) } @@ -93,21 +81,9 @@ where where A: GGLWEInfos, { - debug_assert_eq!( - infos.dsize().0, - 1, - "dsize > 1 is not supported for LWESwitchingKey" - ); - debug_assert_eq!( - infos.rank_in().0, - 1, - "rank_in > 1 is not supported for LWESwitchingKey" - ); - debug_assert_eq!( - infos.rank_out().0, - 1, - "rank_out > 1 is not supported for LWESwitchingKey" - ); + debug_assert_eq!(infos.dsize().0, 1, "dsize > 1 is not supported for LWESwitchingKey"); + debug_assert_eq!(infos.rank_in().0, 1, "rank_in > 1 is not supported for LWESwitchingKey"); + debug_assert_eq!(infos.rank_out().0, 1, "rank_out > 1 is not supported for LWESwitchingKey"); self.bytes_of_lwe_switching_key_prepared(infos.base2k(), infos.k(), infos.dnum()) } diff --git a/poulpy-core/src/layouts/prepared/lwe_to_glwe_key.rs b/poulpy-core/src/layouts/prepared/lwe_to_glwe_key.rs index 25f08f8..c715c2f 100644 --- a/poulpy-core/src/layouts/prepared/lwe_to_glwe_key.rs +++ b/poulpy-core/src/layouts/prepared/lwe_to_glwe_key.rs @@ -69,16 +69,8 @@ where where A: GGLWEInfos, { - debug_assert_eq!( - infos.rank_in().0, - 1, - "rank_in > 1 is not supported for LWEToGLWEKey" - ); - debug_assert_eq!( - infos.dsize().0, - 1, - "dsize > 1 is not supported for LWEToGLWEKey" - ); + debug_assert_eq!(infos.rank_in().0, 1, "rank_in > 1 is not supported for LWEToGLWEKey"); + debug_assert_eq!(infos.dsize().0, 1, "dsize > 1 is not supported for LWEToGLWEKey"); self.alloc_lwe_to_glwe_key_prepared(infos.base2k(), infos.k(), infos.rank_out(), infos.dnum()) } @@ -90,16 +82,8 @@ where where A: GGLWEInfos, { - debug_assert_eq!( - infos.rank_in().0, - 1, - "rank_in > 1 is not supported for LWEToGLWEKey" - ); - debug_assert_eq!( - infos.dsize().0, - 1, - "dsize > 1 is not supported for LWEToGLWEKey" - ); + debug_assert_eq!(infos.rank_in().0, 1, "rank_in > 1 is not supported for LWEToGLWEKey"); + debug_assert_eq!(infos.dsize().0, 1, "dsize > 1 is not supported for LWEToGLWEKey"); 
self.bytes_of_lwe_to_glwe_key_prepared(infos.base2k(), infos.k(), infos.rank_out(), infos.dnum()) } diff --git a/poulpy-core/src/layouts/prepared/mod.rs b/poulpy-core/src/layouts/prepared/mod.rs index 4d76cfb..6bbb390 100644 --- a/poulpy-core/src/layouts/prepared/mod.rs +++ b/poulpy-core/src/layouts/prepared/mod.rs @@ -5,6 +5,7 @@ mod glwe; mod glwe_automorphism_key; mod glwe_public_key; mod glwe_secret; +mod glwe_secret_tensor; mod glwe_switching_key; mod glwe_tensor_key; mod glwe_to_lwe_key; @@ -18,6 +19,7 @@ pub use glwe::*; pub use glwe_automorphism_key::*; pub use glwe_public_key::*; pub use glwe_secret::*; +pub use glwe_secret_tensor::*; pub use glwe_switching_key::*; pub use glwe_tensor_key::*; pub use glwe_to_lwe_key::*; diff --git a/poulpy-core/src/noise/gglwe.rs b/poulpy-core/src/noise/gglwe.rs index 12c907f..187c5e4 100644 --- a/poulpy-core/src/noise/gglwe.rs +++ b/poulpy-core/src/noise/gglwe.rs @@ -78,13 +78,7 @@ where let dsize: usize = res.dsize().into(); let (mut pt, scratch_1) = scratch.take_glwe_plaintext(res); pt.data_mut().zero(); - self.vec_znx_add_scalar_inplace( - &mut pt.data, - 0, - (dsize - 1) + res_row * dsize, - pt_want, - res_col, - ); + self.vec_znx_add_scalar_inplace(&mut pt.data, 0, (dsize - 1) + res_row * dsize, pt_want, res_col); self.glwe_noise(&res.at(res_row, res_col), &pt, sk_prepared, scratch_1) } } diff --git a/poulpy-core/src/noise/ggsw.rs b/poulpy-core/src/noise/ggsw.rs index 069d4f6..372ab3b 100644 --- a/poulpy-core/src/noise/ggsw.rs +++ b/poulpy-core/src/noise/ggsw.rs @@ -102,7 +102,7 @@ where self.vec_znx_dft_apply(1, 0, &mut pt_dft, 0, &pt.data, 0); self.svp_apply_dft_to_dft_inplace(&mut pt_dft, 0, &sk_prepared.data, res_col - 1); let pt_big = self.vec_znx_idft_apply_consume(pt_dft); - self.vec_znx_big_normalize(base2k, &mut pt.data, 0, base2k, &pt_big, 0, scratch_2); + self.vec_znx_big_normalize(&mut pt.data, base2k, 0, 0, &pt_big, base2k, 0, scratch_2); } self.glwe_noise(&res.at(res_row, res_col), &pt, sk_prepared, scratch_1) diff --git a/poulpy-core/src/noise/glwe.rs b/poulpy-core/src/noise/glwe.rs index 5cfb512..75bcbdd 100644 --- a/poulpy-core/src/noise/glwe.rs +++ b/poulpy-core/src/noise/glwe.rs @@ -38,10 +38,7 @@ where where A: GLWEInfos, { - GLWEPlaintext::bytes_of_from_infos(infos) - + self - .glwe_normalize_tmp_bytes() - .max(self.glwe_decrypt_tmp_bytes(infos)) + GLWEPlaintext::bytes_of_from_infos(infos) + self.glwe_normalize_tmp_bytes().max(self.glwe_decrypt_tmp_bytes(infos)) } fn glwe_noise(&self, res: &R, pt_want: &P, sk_prepared: &S, scratch: &mut Scratch) -> Stats diff --git a/poulpy-core/src/operations/glwe.rs b/poulpy-core/src/operations/glwe.rs index 3d839b8..b7d0a0e 100644 --- a/poulpy-core/src/operations/glwe.rs +++ b/poulpy-core/src/operations/glwe.rs @@ -1,80 +1,568 @@ use poulpy_hal::{ api::{ - BivariateTensoring, ModuleN, ScratchTakeBasic, VecZnxAdd, VecZnxAddInplace, VecZnxBigNormalize, VecZnxCopy, + CnvPVecBytesOf, Convolution, ModuleN, ScratchTakeBasic, VecZnxAdd, VecZnxAddInplace, VecZnxBigAddSmallInplace, + VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegate, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubInplace, VecZnxSubNegateInplace, VecZnxZero, }, - layouts::{Backend, Module, Scratch, VecZnx, VecZnxBig, ZnxInfos}, + layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx, 
VecZnxBig}, reference::vec_znx::vec_znx_rotate_inplace_tmp_bytes, }; use crate::{ - ScratchTakeCore, + GGLWEProduct, ScratchTakeCore, layouts::{ - GLWE, GLWEInfos, GLWEPrepared, GLWEPreparedToRef, GLWETensor, GLWETensorToMut, GLWEToMut, GLWEToRef, LWEInfos, + Base2K, GGLWEInfos, GLWE, GLWEInfos, GLWEPlaintext, GLWETensor, GLWETensorKeyPrepared, GLWEToMut, GLWEToRef, LWEInfos, TorusPrecision, }, }; -pub trait GLWETensoring +pub trait GLWEMulConst { + fn glwe_mul_const_tmp_bytes(&self, res: &R, res_offset: usize, a: &A, b_size: usize) -> usize + where + R: GLWEInfos, + A: GLWEInfos; + + fn glwe_mul_const(&self, res: &mut GLWE, res_offset: usize, a: &GLWE, b: &[i64], scratch: &mut Scratch) + where + R: DataMut, + A: DataRef; + + fn glwe_mul_const_inplace(&self, res: &mut GLWE, res_offset: usize, b: &[i64], scratch: &mut Scratch) + where + R: DataMut; +} + +impl GLWEMulConst for Module where - Self: BivariateTensoring + VecZnxIdftApplyConsume + VecZnxBigNormalize, + Self: Convolution + VecZnxBigBytesOf + VecZnxBigNormalize + VecZnxBigNormalizeTmpBytes, Scratch: ScratchTakeCore, { - /// res = (a (x) b) * 2^{k * a_base2k} - /// - /// # Requires - /// * a.base2k() == b.base2k() - /// * res.cols() >= a.cols() + b.cols() - 1 - /// - /// # Behavior - /// * res precision is truncated to res.max_k().min(a.max_k() + b.max_k() + k * a_base2k) - fn glwe_tensor(&self, k: i64, res: &mut R, a: &A, b: &B, scratch: &mut Scratch) + fn glwe_mul_const_tmp_bytes(&self, res: &R, res_offset: usize, a: &A, b_size: usize) -> usize where - R: GLWETensorToMut, - A: GLWEToRef, - B: GLWEPreparedToRef, + R: GLWEInfos, + A: GLWEInfos, { - let res: &mut GLWETensor<&mut [u8]> = &mut res.to_mut(); - let a: &GLWE<&[u8]> = &a.to_ref(); - let b: &GLWEPrepared<&[u8], BE> = &b.to_ref(); + let a_base2k: usize = a.base2k().as_usize(); + let res_base2k: usize = res.base2k().as_usize(); - assert_eq!(a.base2k(), b.base2k()); - assert_eq!(a.rank(), res.rank()); + let res_size: usize = (res.size() * res_base2k).div_ceil(a_base2k); + let res_big: usize = self.bytes_of_vec_znx_big(1, res_size); + let cnv: usize = self.cnv_by_const_apply_tmp_bytes(res_size, res_offset, a.size(), b_size); + let normalize: usize = self.vec_znx_big_normalize_tmp_bytes(); - let res_cols: usize = res.data.cols(); + res_big + cnv.max(normalize) + } - // Get tmp buffer of min precision between a_prec * b_prec and res_prec - let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self, res_cols, res.max_k().div_ceil(a.base2k()) as usize); + fn glwe_mul_const(&self, res: &mut GLWE, res_offset: usize, a: &GLWE, b: &[i64], scratch: &mut Scratch) + where + R: DataMut, + A: DataRef, + { + assert_eq!(res.rank(), a.rank()); - // DFT(res) = DFT(a) (x) DFT(b) - self.bivariate_tensoring(k, &mut res_dft, &a.data, &b.data, scratch_1); + let cols: usize = res.rank().as_usize() + 1; + let a_base2k: usize = a.base2k().as_usize(); + let res_base2k: usize = res.base2k().as_usize(); - // res = IDFT(res) - let res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft); + let (res_offset_hi, res_offset_lo) = if res_offset < a_base2k { + (0, -((a_base2k - (res_offset % a_base2k)) as i64)) + } else { + ((res_offset / a_base2k).saturating_sub(1), (res_offset % a_base2k) as i64) + }; - // Normalize and switches basis if required - for res_col in 0..res_cols { + let res_dft_size = res + .k() + .as_usize() + .div_ceil(a.base2k().as_usize()) + .min(a.size() + b.len() - res_offset_hi); + + let (mut res_big, scratch_1) = scratch.take_vec_znx_big(self, 1, res_dft_size); + for 
i in 0..cols { + self.cnv_by_const_apply(&mut res_big, res_offset_hi, 0, a.data(), i, b, scratch_1); + self.vec_znx_big_normalize(res.data_mut(), res_base2k, res_offset_lo, i, &res_big, a_base2k, 0, scratch_1); + } + } + + fn glwe_mul_const_inplace(&self, res: &mut GLWE, res_offset: usize, b: &[i64], scratch: &mut Scratch) + where + R: DataMut, + { + let cols: usize = res.rank().as_usize() + 1; + let res_base2k: usize = res.base2k().as_usize(); + + let (res_offset_hi, res_offset_lo) = if res_offset < res_base2k { + (0, -((res_base2k - (res_offset % res_base2k)) as i64)) + } else { + ((res_offset / res_base2k).saturating_sub(1), (res_offset % res_base2k) as i64) + }; + + let (mut res_big, scratch_1) = scratch.take_vec_znx_big(self, 1, res.size()); + for i in 0..cols { + self.cnv_by_const_apply(&mut res_big, res_offset_hi, 0, res.data(), i, b, scratch_1); self.vec_znx_big_normalize( - res.base2k().into(), - &mut res.data, - res_col, - a.base2k().into(), + res.data_mut(), + res_base2k, + res_offset_lo, + i, &res_big, - res_col, + res_base2k, + 0, scratch_1, ); } } +} - // fn glwe_relinearize(&self, res: &mut R, a: &A, tsk: &T, scratch: &mut Scratch) - // where - // R: GLWEToRef, - // A: GLWETensorToRef, - // T: GLWETensorKeyPreparedToRef, - // { - // } +impl GLWEMulPlain for Module +where + Self: Sized + + ModuleN + + CnvPVecBytesOf + + VecZnxDftBytesOf + + VecZnxIdftApplyConsume + + VecZnxBigNormalize + + Convolution + + VecZnxBigNormalizeTmpBytes, + Scratch: ScratchTakeCore, +{ + fn glwe_mul_plain_tmp_bytes(&self, res: &R, res_offset: usize, a: &A, b: &B) -> usize + where + R: GLWEInfos, + A: GLWEInfos, + B: GLWEInfos, + { + let ab_base2k: Base2K = a.base2k(); + assert_eq!(b.base2k(), ab_base2k); + + let cols: usize = res.rank().as_usize() + 1; + + let a_size: usize = a.size(); + let b_size: usize = b.size(); + let res_size: usize = res.size(); + + let cnv_pvec: usize = self.bytes_of_cnv_pvec_left(cols, a_size) + self.bytes_of_cnv_pvec_right(1, b_size); + let cnv_prep: usize = self + .cnv_prepare_left_tmp_bytes(a_size, a_size) + .max(self.cnv_prepare_right_tmp_bytes(a_size, a_size)); + let cnv_apply: usize = self.cnv_apply_dft_tmp_bytes(res_size, res_offset, a_size, b_size); + + let res_dft_size = res + .k() + .as_usize() + .div_ceil(a.base2k().as_usize()) + .min(a_size + b_size - res_offset / ab_base2k.as_usize()); + + let res_dft: usize = self.bytes_of_vec_znx_dft(1, res_dft_size); + let norm: usize = self.vec_znx_big_normalize_tmp_bytes(); + + cnv_pvec + cnv_prep + res_dft + cnv_apply.max(norm) + } + + fn glwe_mul_plain( + &self, + res: &mut GLWE, + res_offset: usize, + a: &GLWE, + b: &GLWEPlaintext, + scratch: &mut Scratch, + ) where + R: DataMut, + A: DataRef, + B: DataRef, + { + assert_eq!(res.rank(), a.rank()); + + let a_base2k: usize = a.base2k().as_usize(); + assert_eq!(b.base2k().as_usize(), a_base2k); + let res_base2k: usize = res.base2k().as_usize(); + + let cols: usize = res.rank().as_usize() + 1; + + let (mut a_prep, scratch_1) = scratch.take_cnv_pvec_left(self, cols, a.size()); + let (mut b_prep, scratch_2) = scratch_1.take_cnv_pvec_right(self, 1, b.size()); + + self.cnv_prepare_left(&mut a_prep, a.data(), scratch_2); + self.cnv_prepare_right(&mut b_prep, b.data(), scratch_2); + + let (res_offset_hi, res_offset_lo) = if res_offset < a_base2k { + (0, -((a_base2k - (res_offset % a_base2k)) as i64)) + } else { + ((res_offset / a_base2k).saturating_sub(1), (res_offset % a_base2k) as i64) + }; + + let res_dft_size = res + .k() + .as_usize() + .div_ceil(a.base2k().as_usize()) + 
.min(a.size() + b.size() - res_offset_hi); + + for i in 0..cols { + let (mut res_dft, scratch_3) = scratch_2.take_vec_znx_dft(self, 1, res_dft_size); + self.cnv_apply_dft(&mut res_dft, res_offset_hi, 0, &a_prep, i, &b_prep, 0, scratch_3); + let res_big = self.vec_znx_idft_apply_consume(res_dft); + + self.vec_znx_big_normalize(res.data_mut(), res_base2k, res_offset_lo, i, &res_big, a_base2k, 0, scratch_3); + } + } + + fn glwe_mul_plain_inplace(&self, res: &mut GLWE, res_offset: usize, a: &GLWEPlaintext, scratch: &mut Scratch) + where + R: DataMut, + A: DataRef, + { + assert_eq!(res.rank(), a.rank()); + + let a_base2k: usize = a.base2k().as_usize(); + let res_base2k: usize = res.base2k().as_usize(); + assert_eq!(res_base2k, a_base2k); + + let cols: usize = res.rank().as_usize() + 1; + + let (mut res_prep, scratch_1) = scratch.take_cnv_pvec_left(self, cols, res.size()); + let (mut a_prep, scratch_2) = scratch_1.take_cnv_pvec_right(self, 1, a.size()); + + self.cnv_prepare_left(&mut res_prep, res.data(), scratch_2); + self.cnv_prepare_right(&mut a_prep, a.data(), scratch_2); + + let (res_offset_hi, res_offset_lo) = if res_offset < a_base2k { + (0, -((a_base2k - (res_offset % a_base2k)) as i64)) + } else { + ((res_offset / a_base2k).saturating_sub(1), (res_offset % a_base2k) as i64) + }; + + let res_dft_size = res + .k() + .as_usize() + .div_ceil(a.base2k().as_usize()) + .min(a.size() + res.size() - res_offset_hi); + + for i in 0..cols { + let (mut res_dft, scratch_3) = scratch_2.take_vec_znx_dft(self, 1, res_dft_size); + self.cnv_apply_dft(&mut res_dft, res_offset_hi, 0, &res_prep, i, &a_prep, 0, scratch_3); + let res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft); + self.vec_znx_big_normalize(res.data_mut(), res_base2k, res_offset_lo, i, &res_big, a_base2k, 0, scratch_3); + } + } +} + +pub trait GLWEMulPlain { + fn glwe_mul_plain_tmp_bytes(&self, res: &R, res_offset: usize, a: &A, b: &B) -> usize + where + R: GLWEInfos, + A: GLWEInfos, + B: GLWEInfos; + + fn glwe_mul_plain( + &self, + res: &mut GLWE, + res_offset: usize, + a: &GLWE, + b: &GLWEPlaintext, + scratch: &mut Scratch, + ) where + R: DataMut, + A: DataRef, + B: DataRef; + + fn glwe_mul_plain_inplace(&self, res: &mut GLWE, res_offset: usize, a: &GLWEPlaintext, scratch: &mut Scratch) + where + R: DataMut, + A: DataRef; +} + +pub trait GLWETensoring { + fn glwe_tensor_apply_tmp_bytes(&self, res: &R, res_offset: usize, a: &A, b: &B) -> usize + where + R: GLWEInfos, + A: GLWEInfos, + B: GLWEInfos; + + /// res = (a (x) b) * 2^{res_offset * a_base2k} + /// + /// # Requires + /// * a.base2k() == b.base2k() + /// * a.rank() == b.rank() + /// + /// # Behavior + /// * res precision is truncated to res.max_k().min(a.max_k() + b.max_k() + res_offset * a_base2k) + fn glwe_tensor_apply( + &self, + res: &mut GLWETensor, + res_offset: usize, + a: &GLWE, + b: &GLWE, + scratch: &mut Scratch, + ) where + R: DataMut, + A: DataRef, + B: DataRef; + + fn glwe_tensor_relinearize( + &self, + res: &mut GLWE, + a: &GLWETensor, + tsk: &GLWETensorKeyPrepared, + tsk_size: usize, + scratch: &mut Scratch, + ) where + R: DataMut, + A: DataRef, + B: DataRef; + + fn glwe_tensor_relinearize_tmp_bytes(&self, res: &R, a: &A, tsk: &B) -> usize + where + R: GLWEInfos, + A: GLWEInfos, + B: GGLWEInfos; +}
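+ +// Illustrative usage (editorial sketch; `ct_a`, `ct_b`, `tensor`, `ct_out`, `tsk_prepared` and +// `tsk_size` are hypothetical bindings): two rank-r GLWE ciphertexts sharing the same base2k are +// tensored into a GLWETensor with (r + 2) * (r + 1) / 2 columns, then folded back onto a rank-r +// GLWE with the prepared tensor key: +// +// module.glwe_tensor_apply(&mut tensor, 0, &ct_a, &ct_b, scratch); +// module.glwe_tensor_relinearize(&mut ct_out, &tensor, &tsk_prepared, tsk_size, scratch);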
+ +impl GLWETensoring for Module +where + Self: Sized + + ModuleN + + CnvPVecBytesOf + + VecZnxDftBytesOf + + VecZnxIdftApplyConsume + + VecZnxBigNormalize + + Convolution + + VecZnxSubInplace + + VecZnxNegate + + VecZnxAddInplace + + VecZnxBigNormalizeTmpBytes + + VecZnxCopy + + VecZnxNormalize + + VecZnxDftApply + + GGLWEProduct + + VecZnxBigAddSmallInplace + + VecZnxNormalizeTmpBytes, + Scratch: ScratchTakeCore, +{ + fn glwe_tensor_apply_tmp_bytes(&self, res: &R, res_offset: usize, a: &A, b: &B) -> usize + where + R: GLWEInfos, + A: GLWEInfos, + B: GLWEInfos, + { + let ab_base2k: Base2K = a.base2k(); + assert_eq!(b.base2k(), ab_base2k); + + let cols: usize = res.rank().as_usize() + 1; + + let a_size: usize = a.size(); + let b_size: usize = b.size(); + let res_size: usize = res.size(); + + let cnv_pvec: usize = self.bytes_of_cnv_pvec_left(cols, a_size) + self.bytes_of_cnv_pvec_right(cols, b_size); + let cnv_prep: usize = self + .cnv_prepare_left_tmp_bytes(a_size, a_size) + .max(self.cnv_prepare_right_tmp_bytes(a_size, a_size)); + let cnv_apply: usize = self + .cnv_apply_dft_tmp_bytes(res_size, res_offset, a_size, b_size) + .max(self.cnv_pairwise_apply_dft_tmp_bytes(res_size, res_offset, a_size, b_size)); + + let res_dft_size = res + .k() + .as_usize() + .div_ceil(a.base2k().as_usize()) + .min(a_size + b_size - res_offset / ab_base2k.as_usize()); + + let res_dft: usize = self.bytes_of_vec_znx_dft(1, res_dft_size); + let tmp: usize = VecZnx::bytes_of(self.n(), 1, res.size()); + let norm: usize = self.vec_znx_big_normalize_tmp_bytes(); + + cnv_pvec + cnv_prep + res_dft + cnv_apply.max(tmp + norm) + } + + fn glwe_tensor_relinearize_tmp_bytes(&self, res: &R, a: &A, tsk: &B) -> usize + where + R: GLWEInfos, + A: GLWEInfos, + B: GGLWEInfos, + { + let a_base2k: usize = a.base2k().into(); + let key_base2k: usize = tsk.base2k().into(); + let res_base2k: usize = res.base2k().into(); + + let cols: usize = tsk.rank_out().as_usize() + 1; + let pairs: usize = tsk.rank_in().as_usize(); + + let a_dft_size: usize = (a.size() * a_base2k).div_ceil(key_base2k); + + let a_dft = self.bytes_of_vec_znx_dft(pairs, a_dft_size); + + let a_conv: usize = if a_base2k != key_base2k || res_base2k != key_base2k { + VecZnx::bytes_of(self.n(), 1, a_dft_size) + self.vec_znx_normalize_tmp_bytes() + } else { + 0 + }; + + let res_dft: usize = self.bytes_of_vec_znx_dft(cols, tsk.size()); + + let gglwe_product: usize = self.gglwe_product_dft_tmp_bytes(res.size(), a_dft_size, tsk); + + let big_normalize: usize = self.vec_znx_big_normalize_tmp_bytes(); + + a_dft.max(a_conv + big_normalize) + res_dft + gglwe_product.max(a_conv).max(big_normalize) + } + + fn glwe_tensor_relinearize( + &self, + res: &mut GLWE, + a: &GLWETensor, + tsk: &GLWETensorKeyPrepared, + tsk_size: usize, + scratch: &mut Scratch, + ) where + R: DataMut, + A: DataRef, + B: DataRef, + { + let a_base2k: usize = a.base2k().into(); + let key_base2k: usize = tsk.base2k().into(); + let res_base2k: usize = res.base2k().into(); + + assert_eq!(res.rank(), tsk.rank_out()); + assert_eq!(a.rank(), tsk.rank_out()); + + let cols: usize = tsk.rank_out().as_usize() + 1; + let pairs: usize = tsk.rank_in().as_usize(); + + let a_dft_size: usize = (a.size() * a_base2k).div_ceil(key_base2k); + + let (mut a_dft, scratch_1) = scratch.take_vec_znx_dft(self, pairs, a_dft_size); + + if a_base2k != key_base2k { + let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(self.n(), 1, a_dft_size); + for i in 0..pairs { + self.vec_znx_normalize(&mut a_conv, key_base2k, 0, 0, a.data(), a_base2k, cols + i, scratch_2); + self.vec_znx_dft_apply(1, 0, &mut a_dft, i, &a_conv, 0); + } + } else { + for i in 0..pairs { + self.vec_znx_dft_apply(1, 0, &mut a_dft, i, a.data(), cols + i); + } + }
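+ + // Layout note: the tensor stores its linear part (1, s_1, ..., s_r) in the first `cols` + // columns and the quadratic s_i * s_j part in the remaining `pairs = r * (r + 1) / 2` + // columns; `a_dft` above holds the DFT of that quadratic part, and the GGLWE product + // below key-switches it back onto the (1, s_1, ..., s_r) basis before the linear part + // is re-added.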
+ + let (mut res_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, cols, tsk_size); // Todo optimise + + self.gglwe_product_dft(&mut res_dft, &a_dft, &tsk.0, scratch_2); + let mut res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft); + + if res_base2k == key_base2k { + for i in 0..cols { + self.vec_znx_big_add_small_inplace(&mut res_big, i, a.data(), i); + } + } else { + let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(self.n(), 1, a_dft_size); + for i in 0..cols { + self.vec_znx_normalize(&mut a_conv, key_base2k, 0, 0, a.data(), a_base2k, i, scratch_3); + self.vec_znx_big_add_small_inplace(&mut res_big, i, &a_conv, 0); + } + } + + for i in 0..(res.rank() + 1).into() { + self.vec_znx_big_normalize(res.data_mut(), res_base2k, 0, i, &res_big, key_base2k, i, scratch_2); + } + } + + fn glwe_tensor_apply( + &self, + res: &mut GLWETensor, + res_offset: usize, + a: &GLWE, + b: &GLWE, + scratch: &mut Scratch, + ) where + R: DataMut, + A: DataRef, + B: DataRef, + { + let a_base2k: usize = a.base2k().as_usize(); + assert_eq!(b.base2k().as_usize(), a_base2k); + let res_base2k: usize = res.base2k().as_usize(); + + let cols: usize = res.rank().as_usize() + 1; + + let (mut a_prep, scratch_1) = scratch.take_cnv_pvec_left(self, cols, a.size()); + let (mut b_prep, scratch_2) = scratch_1.take_cnv_pvec_right(self, cols, b.size()); + + self.cnv_prepare_left(&mut a_prep, a.data(), scratch_2); + self.cnv_prepare_right(&mut b_prep, b.data(), scratch_2); + + // Example for rank=3 + // + // (a0, a1, a2, a3) x (b0, b1, b2, b3) + // L L L L R R R R + // + // c(1) = a0 * b0 <- (L(a0) * R(b0)) + // c(s1) = a0 * b1 + a1 * b0 <- (L(a0) + L(a1)) * (R(b0) + R(b1)) + NEG(L(a0) * R(b0)) + SUB(L(a1) * R(b1)) + // c(s2) = a0 * b2 + a2 * b0 <- (L(a0) + L(a2)) * (R(b0) + R(b2)) + NEG(L(a0) * R(b0)) + SUB(L(a2) * R(b2)) + // c(s3) = a0 * b3 + a3 * b0 <- (L(a0) + L(a3)) * (R(b0) + R(b3)) + NEG(L(a0) * R(b0)) + SUB(L(a3) * R(b3)) + // c(s1^2) = a1 * b1 <- (L(a1) * R(b1)) + // c(s1s2) = a1 * b2 + a2 * b1 <- (L(a1) + L(a2)) * (R(b1) + R(b2)) + NEG(L(a1) * R(b1)) + SUB(L(a2) * R(b2)) + // c(s1s3) = a1 * b3 + a3 * b1 <- (L(a1) + L(a3)) * (R(b1) + R(b3)) + NEG(L(a1) * R(b1)) + SUB(L(a3) * R(b3)) + // c(s2^2) = a2 * b2 <- (L(a2) * R(b2)) + // c(s2s3) = a2 * b3 + a3 * b2 <- (L(a2) + L(a3)) * (R(b2) + R(b3)) + NEG(L(a2) * R(b2)) + SUB(L(a3) * R(b3)) + // c(s3^2) = a3 * b3 <- (L(a3) * R(b3)) + + // Derive the offset split. If res_offset < a_base2k, we fold it into a negative + // normalization offset, since the convolution does not support negative offsets (yet). 
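+ // For example, with a_base2k = 13 (illustrative values): res_offset = 5 gives + // (res_offset_hi, res_offset_lo) = (0, -8), while res_offset = 30 gives (1, 4).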
+ let (res_offset_hi, res_offset_lo) = if res_offset < a_base2k { + (0, -((a_base2k - (res_offset % a_base2k)) as i64)) + } else { + ((res_offset / a_base2k).saturating_sub(1), (res_offset % a_base2k) as i64) + }; + + let res_dft_size = res + .k() + .as_usize() + .div_ceil(a.base2k().as_usize()) + .min(a.size() + b.size() - res_offset_hi); + + for i in 0..cols { + let col_i: usize = i * cols - (i * (i + 1) / 2); + + let (mut res_dft, scratch_3) = scratch_2.take_vec_znx_dft(self, 1, res_dft_size); + self.cnv_apply_dft(&mut res_dft, res_offset_hi, 0, &a_prep, i, &b_prep, i, scratch_3); + let res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft); + let (mut tmp, scratch_4) = scratch_3.take_vec_znx(self.n(), 1, res_dft_size); + self.vec_znx_big_normalize(&mut tmp, res_base2k, res_offset_lo, 0, &res_big, a_base2k, 0, scratch_4); + + self.vec_znx_copy(res.data_mut(), col_i + i, &tmp, 0); + + // Pre-subtracts + // res[i!=j] = NEG(a[i] * b[i]) + SUB(a[j] * b[j]) + for j in 0..cols { + if j != i { + if j < i { + let col_j = j * cols - (j * (j + 1) / 2); + self.vec_znx_sub_inplace(res.data_mut(), col_j + i, &tmp, 0); + } else { + self.vec_znx_negate(res.data_mut(), col_i + j, &tmp, 0); + } + } + } + } + + for i in 0..cols { + let col_i: usize = i * cols - (i * (i + 1) / 2); + + for j in i..cols { + if j != i { + // res_dft = (a[i] + a[j]) * (b[i] + b[j]) + let (mut res_dft, scratch_3) = scratch_2.take_vec_znx_dft(self, 1, res_dft_size); + self.cnv_pairwise_apply_dft(&mut res_dft, res_offset_hi, 0, &a_prep, &b_prep, i, j, scratch_3); + let res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft); + let (mut tmp, scratch_3) = scratch_3.take_vec_znx(self.n(), 1, res_dft_size); + self.vec_znx_big_normalize(&mut tmp, res_base2k, res_offset_lo, 0, &res_big, a_base2k, 0, scratch_3); + + self.vec_znx_add_inplace(res.data_mut(), col_i + j, &tmp, 0); + } + } + } + } } pub trait GLWEAdd @@ -431,9 +919,7 @@ where *tmp_slot = Some(tmp); // Get a mutable handle to the temp and normalize into it - let tmp_ref: &mut GLWE<&mut [u8]> = tmp_slot - .as_mut() - .expect("tmp_slot just set to Some, but found None"); + let tmp_ref: &mut GLWE<&mut [u8]> = tmp_slot.as_mut().expect("tmp_slot just set to Some, but found None"); self.glwe_normalize(tmp_ref, glwe, scratch2); @@ -470,9 +956,7 @@ where *tmp_slot = Some(tmp); // Get a mutable handle to the temp and normalize into it - let tmp_ref: &mut GLWE<&mut [u8]> = tmp_slot - .as_mut() - .expect("tmp_slot just set to Some, but found None"); + let tmp_ref: &mut GLWE<&mut [u8]> = tmp_slot.as_mut().expect("tmp_slot just set to Some, but found None"); self.glwe_normalize(tmp_ref, glwe, scratch2); @@ -493,16 +977,10 @@ where assert_eq!(a.n(), self.n() as u32); assert_eq!(res.rank(), a.rank()); + let res_base2k = res.base2k().into(); + for i in 0..res.rank().as_usize() + 1 { - self.vec_znx_normalize( - res.base2k().into(), - res.data_mut(), - i, - a.base2k().into(), - a.data(), - i, - scratch, - ); + self.vec_znx_normalize(res.data_mut(), res_base2k, 0, i, a.data(), a.base2k().into(), i, scratch); } } diff --git a/poulpy-core/src/tests/mod.rs b/poulpy-core/src/tests/mod.rs index 9b69e20..e503522 100644 --- a/poulpy-core/src/tests/mod.rs +++ b/poulpy-core/src/tests/mod.rs @@ -20,6 +20,10 @@ mod poulpy_core { glwe_encrypt_pk => crate::tests::test_suite::encryption::test_glwe_encrypt_pk, // GLWE Base2k Conversion glwe_base2k_conv => crate::tests::test_suite::test_glwe_base2k_conversion, + // GLWE Tensoring + test_glwe_tensoring => 
crate::tests::test_suite::glwe_tensor::test_glwe_tensoring, + test_glwe_mul_plain => crate::tests::test_suite::glwe_tensor::test_glwe_mul_plain, + test_glwe_mul_const => crate::tests::test_suite::glwe_tensor::test_glwe_mul_const, // GLWE Keyswitch glwe_keyswitch => crate::tests::test_suite::keyswitch::test_glwe_keyswitch, glwe_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_glwe_keyswitch_inplace, @@ -88,6 +92,10 @@ mod poulpy_core { glwe_encrypt_pk => crate::tests::test_suite::encryption::test_glwe_encrypt_pk, // GLWE Base2k Conversion glwe_base2k_conv => crate::tests::test_suite::test_glwe_base2k_conversion, + // GLWE Tensoring + test_glwe_tensoring => crate::tests::test_suite::glwe_tensor::test_glwe_tensoring, + test_glwe_mul_plain => crate::tests::test_suite::glwe_tensor::test_glwe_mul_plain, + test_glwe_mul_const => crate::tests::test_suite::glwe_tensor::test_glwe_mul_const, // GLWE Keyswitch glwe_keyswitch => crate::tests::test_suite::keyswitch::test_glwe_keyswitch, glwe_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_glwe_keyswitch_inplace, diff --git a/poulpy-core/src/tests/test_suite/automorphism/gglwe_atk.rs b/poulpy-core/src/tests/test_suite/automorphism/gglwe_atk.rs index 50d2d1b..ab3cce5 100644 --- a/poulpy-core/src/tests/test_suite/automorphism/gglwe_atk.rs +++ b/poulpy-core/src/tests/test_suite/automorphism/gglwe_atk.rs @@ -29,27 +29,27 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k_in: usize = 17; - let base2k_key: usize = 13; - let base2k_out: usize = base2k_in; // MUST BE SAME + let in_base2k: usize = 17; + let key_base2k: usize = 13; + let out_base2k: usize = in_base2k; // MUST BE SAME let k_in: usize = 102; - let max_dsize: usize = k_in.div_ceil(base2k_key); + let max_dsize: usize = k_in.div_ceil(key_base2k); let p0: i64 = -1; let p1: i64 = -5; for rank in 1_usize..3 { for dsize in 1..max_dsize + 1 { - let k_ksk: usize = k_in + base2k_key * dsize; + let k_ksk: usize = k_in + key_base2k * dsize; let k_out: usize = k_ksk; // Better capture noise. 
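// Illustrative: with k_in = 102, key_base2k = 13 and dsize = 2, this gives k_ksk = 102 + 26 = 128, and dnum_ksk below is ceil(102 / 26) = 4. 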
let n: usize = module.n(); let dsize_in: usize = 1; - let dnum_in: usize = k_in / base2k_in; - let dnum_ksk: usize = k_in.div_ceil(base2k_key * dsize); + let dnum_in: usize = k_in / in_base2k; + let dnum_ksk: usize = k_in.div_ceil(key_base2k * dsize); let auto_key_in_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { n: n.into(), - base2k: base2k_in.into(), + base2k: in_base2k.into(), k: k_in.into(), dnum: dnum_in.into(), dsize: dsize_in.into(), @@ -58,7 +58,7 @@ where let auto_key_out_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { n: n.into(), - base2k: base2k_out.into(), + base2k: out_base2k.into(), k: k_out.into(), dnum: dnum_in.into(), dsize: dsize_in.into(), @@ -67,7 +67,7 @@ where let auto_key_apply_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { n: n.into(), - base2k: base2k_key.into(), + base2k: key_base2k.into(), k: k_ksk.into(), dnum: dnum_ksk.into(), dsize: dsize.into(), @@ -84,10 +84,7 @@ where let mut scratch: ScratchOwned = ScratchOwned::alloc( GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &auto_key_in_infos) - .max(GLWEAutomorphismKey::encrypt_sk_tmp_bytes( - module, - &auto_key_apply_infos, - )) + .max(GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &auto_key_apply_infos)) .max(GLWEAutomorphismKey::automorphism_tmp_bytes( module, &auto_key_out_infos, @@ -100,24 +97,10 @@ where sk.fill_ternary_prob(0.5, &mut source_xs); // gglwe_{s1}(s0) = s0 -> s1 - auto_key_in.encrypt_sk( - module, - p0, - &sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + auto_key_in.encrypt_sk(module, p0, &sk, &mut source_xa, &mut source_xe, scratch.borrow()); // gglwe_{s2}(s1) -> s1 -> s2 - auto_key_apply.encrypt_sk( - module, - p1, - &sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + auto_key_apply.encrypt_sk(module, p1, &sk, &mut source_xa, &mut source_xe, scratch.borrow()); let mut auto_key_apply_prepared: GLWEAutomorphismKeyPrepared, BE> = GLWEAutomorphismKeyPrepared::alloc_from_infos(module, &auto_key_apply_infos); @@ -125,12 +108,7 @@ where auto_key_apply_prepared.prepare(module, &auto_key_apply, scratch.borrow()); // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) - auto_key_out.automorphism( - module, - &auto_key_in, - &auto_key_apply_prepared, - scratch.borrow(), - ); + auto_key_out.automorphism(module, &auto_key_in, &auto_key_apply_prepared, scratch.borrow()); let mut sk_auto: GLWESecret> = GLWESecret::alloc_from_infos(&auto_key_out_infos); sk_auto.fill_zero(); // Necessary to avoid panic of unfilled sk @@ -152,7 +130,7 @@ where k_ksk, dnum_ksk, dsize, - base2k_key, + key_base2k, 0.5, 0.5, 0f64, @@ -171,11 +149,7 @@ where .std() .log2(); - assert!( - noise_have < max_noise + 0.5, - "{noise_have} > {}", - max_noise + 0.5 - ); + assert!(noise_have < max_noise + 0.5, "{noise_have} > {}", max_noise + 0.5); } } } @@ -196,26 +170,26 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k_out: usize = 17; - let base2k_key: usize = 13; + let out_base2k: usize = 17; + let key_base2k: usize = 13; let k_out: usize = 102; - let max_dsize: usize = k_out.div_ceil(base2k_key); + let max_dsize: usize = k_out.div_ceil(key_base2k); let p0: i64 = -1; let p1: i64 = -5; for rank in 1_usize..3 { for dsize in 1..max_dsize + 1 { - let k_ksk: usize = k_out + base2k_key * dsize; + let k_ksk: usize = k_out + key_base2k * dsize; let n: usize = module.n(); let dsize_in: usize = 1; - let dnum_in: usize = k_out / base2k_out; - let dnum_ksk: usize = k_out.div_ceil(base2k_key 
* dsize); + let dnum_in: usize = k_out / out_base2k; + let dnum_ksk: usize = k_out.div_ceil(key_base2k * dsize); let auto_key_layout: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { n: n.into(), - base2k: base2k_out.into(), + base2k: out_base2k.into(), k: k_out.into(), dnum: dnum_in.into(), dsize: dsize_in.into(), @@ -224,7 +198,7 @@ where let auto_key_apply_layout: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { n: n.into(), - base2k: base2k_key.into(), + base2k: key_base2k.into(), k: k_ksk.into(), dnum: dnum_ksk.into(), dsize: dsize.into(), @@ -248,24 +222,10 @@ where sk.fill_ternary_prob(0.5, &mut source_xs); // gglwe_{s1}(s0) = s0 -> s1 - auto_key.encrypt_sk( - module, - p0, - &sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + auto_key.encrypt_sk(module, p0, &sk, &mut source_xa, &mut source_xe, scratch.borrow()); // gglwe_{s2}(s1) -> s1 -> s2 - auto_key_apply.encrypt_sk( - module, - p1, - &sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + auto_key_apply.encrypt_sk(module, p1, &sk, &mut source_xa, &mut source_xe, scratch.borrow()); let mut auto_key_apply_prepared: GLWEAutomorphismKeyPrepared, BE> = GLWEAutomorphismKeyPrepared::alloc_from_infos(module, &auto_key_apply_layout); @@ -296,7 +256,7 @@ where k_ksk, dnum_ksk, dsize, - base2k_key, + key_base2k, 0.5, 0.5, 0f64, @@ -315,11 +275,7 @@ where .std() .log2(); - assert!( - noise_have < max_noise + 0.5, - "{noise_have} {}", - max_noise + 0.5 - ); + assert!(noise_have < max_noise + 0.5, "{noise_have} {}", max_noise + 0.5); } } } diff --git a/poulpy-core/src/tests/test_suite/automorphism/ggsw_ct.rs b/poulpy-core/src/tests/test_suite/automorphism/ggsw_ct.rs index 8b3679b..7de93d6 100644 --- a/poulpy-core/src/tests/test_suite/automorphism/ggsw_ct.rs +++ b/poulpy-core/src/tests/test_suite/automorphism/ggsw_ct.rs @@ -29,28 +29,28 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k_in: usize = 17; - let base2k_key: usize = 13; - let base2k_out: usize = base2k_in; // MUST BE SAME + let in_base2k: usize = 17; + let key_base2k: usize = 13; + let out_base2k: usize = in_base2k; // MUST BE SAME let k_in: usize = 102; - let max_dsize: usize = k_in.div_ceil(base2k_key); + let max_dsize: usize = k_in.div_ceil(key_base2k); let p: i64 = -5; for rank in 1_usize..3 { for dsize in 1..max_dsize + 1 { - let k_ksk: usize = k_in + base2k_key * dsize; + let k_ksk: usize = k_in + key_base2k * dsize; let k_tsk: usize = k_ksk; let k_out: usize = k_ksk; // Better capture noise. 
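// Illustrative: max_dsize = ceil(102 / 13) = 8, so dsize sweeps 1..=8 and k_ksk grows by key_base2k bits per step. 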
let n: usize = module.n(); - let dnum_in: usize = k_in / base2k_in; - let dnum_ksk: usize = k_in.div_ceil(base2k_key * dsize); + let dnum_in: usize = k_in / in_base2k; + let dnum_ksk: usize = k_in.div_ceil(key_base2k * dsize); let dsize_in: usize = 1; let ggsw_in_layout: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k_in.into(), + base2k: in_base2k.into(), k: k_in.into(), dnum: dnum_in.into(), dsize: dsize_in.into(), @@ -59,7 +59,7 @@ where let ggsw_out_layout: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k_out.into(), + base2k: out_base2k.into(), k: k_out.into(), dnum: dnum_in.into(), dsize: dsize_in.into(), @@ -68,7 +68,7 @@ where let tsk_layout: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout { n: n.into(), - base2k: base2k_key.into(), + base2k: key_base2k.into(), k: k_tsk.into(), dnum: dnum_ksk.into(), dsize: dsize.into(), @@ -77,7 +77,7 @@ where let auto_key_layout: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout { n: n.into(), - base2k: base2k_key.into(), + base2k: key_base2k.into(), k: k_ksk.into(), dnum: dnum_ksk.into(), dsize: dsize.into(), @@ -109,21 +109,8 @@ where let mut sk_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc_from_infos(module, &sk); sk_prepared.prepare(module, &sk); - auto_key.encrypt_sk( - module, - p, - &sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - tsk.encrypt_sk( - module, - &sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + auto_key.encrypt_sk(module, p, &sk, &mut source_xa, &mut source_xe, scratch.borrow()); + tsk.encrypt_sk(module, &sk, &mut source_xa, &mut source_xe, scratch.borrow()); pt_scalar.fill_ternary_hw(0, n, &mut source_xs); @@ -143,20 +130,14 @@ where let mut tsk_prepared: GGLWEToGGSWKeyPrepared, BE> = GGLWEToGGSWKeyPrepared::alloc_from_infos(module, &tsk); tsk_prepared.prepare(module, &tsk, scratch.borrow()); - ct_out.automorphism( - module, - &ct_in, - &auto_key_prepared, - &tsk_prepared, - scratch.borrow(), - ); + ct_out.automorphism(module, &ct_in, &auto_key_prepared, &tsk_prepared, scratch.borrow()); module.vec_znx_automorphism_inplace(p, &mut pt_scalar.as_vec_znx_mut(), 0, scratch.borrow()); let max_noise = |col_j: usize| -> f64 { noise_ggsw_keyswitch( n as f64, - base2k_key * dsize, + key_base2k * dsize, col_j, var_xs, 0f64, @@ -199,25 +180,25 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k_out: usize = 17; - let base2k_key: usize = 13; + let out_base2k: usize = 17; + let key_base2k: usize = 13; let k_out: usize = 102; - let max_dsize: usize = k_out.div_ceil(base2k_key); + let max_dsize: usize = k_out.div_ceil(key_base2k); let p: i64 = -1; for rank in 1_usize..3 { for dsize in 1..max_dsize + 1 { - let k_ksk: usize = k_out + base2k_key * dsize; + let k_ksk: usize = k_out + key_base2k * dsize; let k_tsk: usize = k_ksk; let n: usize = module.n(); - let dnum_in: usize = k_out / base2k_out; - let dnum_ksk: usize = k_out.div_ceil(base2k_key * dsize); + let dnum_in: usize = k_out / out_base2k; + let dnum_ksk: usize = k_out.div_ceil(key_base2k * dsize); let dsize_in: usize = 1; let ggsw_out_layout: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k_out.into(), + base2k: out_base2k.into(), k: k_out.into(), dnum: dnum_in.into(), dsize: dsize_in.into(), @@ -226,7 +207,7 @@ where let tsk_layout: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout { n: n.into(), - base2k: base2k_key.into(), + base2k: key_base2k.into(), k: k_tsk.into(), dnum: dnum_ksk.into(), dsize: dsize.into(), @@ -235,7 +216,7 @@ where let 
auto_key_layout: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout { n: n.into(), - base2k: base2k_key.into(), + base2k: key_base2k.into(), k: k_ksk.into(), dnum: dnum_ksk.into(), dsize: dsize.into(), @@ -266,21 +247,8 @@ where let mut sk_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc_from_infos(module, &sk); sk_prepared.prepare(module, &sk); - auto_key.encrypt_sk( - module, - p, - &sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - tsk.encrypt_sk( - module, - &sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + auto_key.encrypt_sk(module, p, &sk, &mut source_xa, &mut source_xe, scratch.borrow()); + tsk.encrypt_sk(module, &sk, &mut source_xa, &mut source_xe, scratch.borrow()); pt_scalar.fill_ternary_hw(0, n, &mut source_xs); @@ -307,7 +275,7 @@ where let max_noise = |col_j: usize| -> f64 { noise_ggsw_keyswitch( n as f64, - base2k_key * dsize, + key_base2k * dsize, col_j, var_xs, 0f64, @@ -327,10 +295,7 @@ where .std() .log2(); let noise_max: f64 = max_noise(col); - assert!( - noise_have <= noise_max, - "noise_have:{noise_have} > noise_max:{noise_max}", - ) + assert!(noise_have <= noise_max, "noise_have:{noise_have} > noise_max:{noise_max}",) } } } diff --git a/poulpy-core/src/tests/test_suite/automorphism/glwe_ct.rs b/poulpy-core/src/tests/test_suite/automorphism/glwe_ct.rs index aa881c2..1eef9e2 100644 --- a/poulpy-core/src/tests/test_suite/automorphism/glwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/automorphism/glwe_ct.rs @@ -30,37 +30,37 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k_in: usize = 17; - let base2k_key: usize = 13; - let base2k_out: usize = 15; + let in_base2k: usize = 17; + let key_base2k: usize = 13; + let out_base2k: usize = 15; let k_in: usize = 102; - let max_dsize: usize = k_in.div_ceil(base2k_key); + let max_dsize: usize = k_in.div_ceil(key_base2k); let p: i64 = -5; for rank in 1_usize..3 { for dsize in 1..max_dsize + 1 { - let k_ksk: usize = k_in + base2k_key * dsize; + let k_ksk: usize = k_in + key_base2k * dsize; let k_out: usize = k_ksk; // Better capture noise. 
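                // Note that the three base2k values are deliberately distinct in this test
                // (in: 17, key: 13, out: 15): GLWE automorphism supports cross-base2k
                // inputs and outputs, unlike the GGSW tests above where the input and
                // output base2k must be the same.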
let n: usize = module.n(); - let dnum: usize = k_in.div_ceil(base2k_key * dsize); + let dnum: usize = k_in.div_ceil(key_base2k * dsize); let ct_in_infos: GLWELayout = GLWELayout { n: n.into(), - base2k: base2k_in.into(), + base2k: in_base2k.into(), k: k_in.into(), rank: rank.into(), }; let ct_out_infos: GLWELayout = GLWELayout { n: n.into(), - base2k: base2k_out.into(), + base2k: out_base2k.into(), k: k_out.into(), rank: rank.into(), }; let autokey_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { n: n.into(), - base2k: base2k_key.into(), + base2k: key_base2k.into(), k: k_out.into(), rank: rank.into(), dnum: dnum.into(), @@ -77,7 +77,7 @@ where let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - module.vec_znx_fill_uniform(base2k_in, &mut pt_in.data, 0, &mut source_xa); + module.vec_znx_fill_uniform(in_base2k, &mut pt_in.data, 0, &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::alloc( GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &autokey) @@ -92,23 +92,9 @@ where let mut sk_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc_from_infos(module, &sk); sk_prepared.prepare(module, &sk); - autokey.encrypt_sk( - module, - p, - &sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + autokey.encrypt_sk(module, p, &sk, &mut source_xa, &mut source_xe, scratch.borrow()); - ct_in.encrypt_sk( - module, - &pt_in, - &sk_prepared, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + ct_in.encrypt_sk(module, &pt_in, &sk_prepared, &mut source_xa, &mut source_xe, scratch.borrow()); let mut autokey_prepared: GLWEAutomorphismKeyPrepared, BE> = GLWEAutomorphismKeyPrepared::alloc_from_infos(module, &autokey_infos); @@ -121,7 +107,7 @@ where k_ksk, dnum, max_dsize, - base2k_key, + key_base2k, 0.5, 0.5, 0f64, @@ -135,13 +121,7 @@ where module.glwe_normalize(&mut pt_out, &pt_in, scratch.borrow()); module.vec_znx_automorphism_inplace(p, &mut pt_out.data, 0, scratch.borrow()); - assert!( - ct_out - .noise(module, &pt_out, &sk_prepared, scratch.borrow()) - .std() - .log2() - <= max_noise + 1.0 - ) + assert!(ct_out.noise(module, &pt_out, &sk_prepared, scratch.borrow()).std().log2() <= max_noise + 1.0) } } } @@ -161,29 +141,29 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k_out: usize = 17; - let base2k_key: usize = 13; + let out_base2k: usize = 17; + let key_base2k: usize = 13; let k_out: usize = 102; - let max_dsize: usize = k_out.div_ceil(base2k_key); + let max_dsize: usize = k_out.div_ceil(key_base2k); let p: i64 = -5; for rank in 1_usize..3 { for dsize in 1..max_dsize + 1 { - let k_ksk: usize = k_out + base2k_key * dsize; + let k_ksk: usize = k_out + key_base2k * dsize; let n: usize = module.n(); - let dnum: usize = k_out.div_ceil(base2k_key * dsize); + let dnum: usize = k_out.div_ceil(key_base2k * dsize); let ct_out_infos: GLWELayout = GLWELayout { n: n.into(), - base2k: base2k_out.into(), + base2k: out_base2k.into(), k: k_out.into(), rank: rank.into(), }; let autokey_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { n: n.into(), - base2k: base2k_key.into(), + base2k: key_base2k.into(), k: k_ksk.into(), rank: rank.into(), dnum: dnum.into(), @@ -198,7 +178,7 @@ where let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - module.vec_znx_fill_uniform(base2k_out, &mut pt_want.data, 0, &mut source_xa); + module.vec_znx_fill_uniform(out_base2k, &mut pt_want.data, 0, 
&mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::alloc( GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &autokey) @@ -213,14 +193,7 @@ where let mut sk_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc_from_infos(module, &sk); sk_prepared.prepare(module, &sk); - autokey.encrypt_sk( - module, - p, - &sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + autokey.encrypt_sk(module, p, &sk, &mut source_xa, &mut source_xe, scratch.borrow()); ct.encrypt_sk( module, @@ -242,7 +215,7 @@ where k_ksk, dnum, dsize, - base2k_key, + key_base2k, 0.5, 0.5, 0f64, @@ -255,12 +228,7 @@ where module.vec_znx_automorphism_inplace(p, &mut pt_want.data, 0, scratch.borrow()); - assert!( - ct.noise(module, &pt_want, &sk_prepared, scratch.borrow()) - .std() - .log2() - <= max_noise + 1.0 - ) + assert!(ct.noise(module, &pt_want, &sk_prepared, scratch.borrow()).std().log2() <= max_noise + 1.0) } } } diff --git a/poulpy-core/src/tests/test_suite/conversion.rs b/poulpy-core/src/tests/test_suite/conversion.rs index 9ee70e2..fb04175 100644 --- a/poulpy-core/src/tests/test_suite/conversion.rs +++ b/poulpy-core/src/tests/test_suite/conversion.rs @@ -65,33 +65,19 @@ where let pt_in: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_infos_in); let pt_out: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_infos_out); - ct_in.encrypt_sk( - module, - &pt_in, - &sk_prep, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + ct_in.encrypt_sk(module, &pt_in, &sk_prep, &mut source_xa, &mut source_xe, scratch.borrow()); let mut data: Vec = (0..module.n()).map(|_| Float::with_val(128, 0)).collect(); - ct_in - .data() - .decode_vec_float(ct_in.base2k().into(), 0, &mut data); + ct_in.data().decode_vec_float(ct_in.base2k().into(), 0, &mut data); ct_out.fill_uniform(ct_out.base2k().into(), &mut source_xa); module.glwe_normalize(&mut ct_out, &ct_in, scratch.borrow()); let mut data_conv: Vec = (0..module.n()).map(|_| Float::with_val(128, 0)).collect(); - ct_out - .data() - .decode_vec_float(ct_out.base2k().into(), 0, &mut data_conv); + ct_out.data().decode_vec_float(ct_out.base2k().into(), 0, &mut data_conv); assert!( - ct_out - .noise(module, &pt_out, &sk_prep, scratch.borrow()) - .std() - .log2() + ct_out.noise(module, &pt_out, &sk_prep, scratch.borrow()).std().log2() <= -(ct_out.k().as_u32() as f64) + SIGMA.log2() + 0.50 ) } @@ -162,14 +148,7 @@ where lwe_pt.encode_i64(data, k_lwe_pt); let mut lwe_ct: LWE> = LWE::alloc_from_infos(&lwe_infos); - lwe_ct.encrypt_sk( - module, - &lwe_pt, - &sk_lwe, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + lwe_ct.encrypt_sk(module, &lwe_pt, &sk_lwe, &mut source_xa, &mut source_xe, scratch.borrow()); let mut ksk: LWEToGLWEKey> = LWEToGLWEKey::alloc_from_infos(&lwe_to_glwe_infos); @@ -195,11 +174,12 @@ where let mut lwe_pt_conv = LWEPlaintext::alloc(glwe_pt.base2k(), lwe_pt.k()); module.vec_znx_normalize( - glwe_pt.base2k().as_usize(), lwe_pt_conv.data_mut(), + glwe_pt.base2k().as_usize(), + 0, 0, - lwe_pt.base2k().as_usize(), lwe_pt.data(), + lwe_pt.base2k().as_usize(), 0, scratch.borrow(), ); @@ -287,14 +267,7 @@ where let mut ksk: GLWEToLWEKey> = GLWEToLWEKey::alloc_from_infos(&glwe_to_lwe_infos); - ksk.encrypt_sk( - module, - &sk_lwe, - &sk_glwe, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + ksk.encrypt_sk(module, &sk_lwe, &sk_glwe, &mut source_xa, &mut source_xe, scratch.borrow()); let mut lwe_ct: LWE> = LWE::alloc_from_infos(&lwe_infos); @@ -309,11 +282,12 @@ where let mut glwe_pt_conv = 
GLWEPlaintext::alloc(glwe_ct.n(), lwe_pt.base2k(), lwe_pt.k()); module.vec_znx_normalize( - lwe_pt.base2k().as_usize(), glwe_pt_conv.data_mut(), + lwe_pt.base2k().as_usize(), + 0, 0, - glwe_ct.base2k().as_usize(), glwe_pt.data(), + glwe_ct.base2k().as_usize(), 0, scratch.borrow(), ); diff --git a/poulpy-core/src/tests/test_suite/encryption/gglwe_atk.rs b/poulpy-core/src/tests/test_suite/encryption/gglwe_atk.rs index 9f917af..b22b3ae 100644 --- a/poulpy-core/src/tests/test_suite/encryption/gglwe_atk.rs +++ b/poulpy-core/src/tests/test_suite/encryption/gglwe_atk.rs @@ -53,23 +53,15 @@ where let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWEAutomorphismKey::encrypt_sk_tmp_bytes( - module, &atk_infos, - )); + let mut scratch: ScratchOwned = + ScratchOwned::alloc(GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &atk_infos)); let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&atk_infos); sk.fill_ternary_prob(0.5, &mut source_xs); let p = -5; - atk.encrypt_sk( - module, - p, - &sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + atk.encrypt_sk(module, p, &sk, &mut source_xa, &mut source_xe, scratch.borrow()); let mut sk_out: GLWESecret> = sk.clone(); (0..atk.rank().into()).for_each(|i| { @@ -90,14 +82,7 @@ where for col in 0..atk.rank().as_usize() { assert!( atk.key - .noise( - module, - row, - col, - &sk.data, - &sk_out_prepared, - scratch.borrow() - ) + .noise(module, row, col, &sk.data, &sk_out_prepared, scratch.borrow()) .std() .log2() <= max_noise @@ -145,9 +130,8 @@ where let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWEAutomorphismKeyCompressed::encrypt_sk_tmp_bytes( - module, &atk_infos, - )); + let mut scratch: ScratchOwned = + ScratchOwned::alloc(GLWEAutomorphismKeyCompressed::encrypt_sk_tmp_bytes(module, &atk_infos)); let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&atk_infos); sk.fill_ternary_prob(0.5, &mut source_xs); @@ -180,14 +164,7 @@ where for col in 0..atk.rank().as_usize() { let noise_have = atk .key - .noise( - module, - row, - col, - &sk.data, - &sk_out_prepared, - scratch.borrow(), - ) + .noise(module, row, col, &sk.data, &sk_out_prepared, scratch.borrow()) .std() .log2(); diff --git a/poulpy-core/src/tests/test_suite/encryption/gglwe_ct.rs b/poulpy-core/src/tests/test_suite/encryption/gglwe_ct.rs index 36d0a72..147e1cd 100644 --- a/poulpy-core/src/tests/test_suite/encryption/gglwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/encryption/gglwe_ct.rs @@ -66,14 +66,7 @@ where let mut sk_out_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, rank_out.into()); sk_out_prepared.prepare(module, &sk_out); - ksk.encrypt_sk( - module, - &sk_in, - &sk_out, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + ksk.encrypt_sk(module, &sk_in, &sk_out, &mut source_xa, &mut source_xe, scratch.borrow()); let max_noise: f64 = SIGMA.log2() - (ksk.k().as_usize() as f64) + 0.5; @@ -81,14 +74,7 @@ where for col in 0..ksk.rank_in().as_usize() { let noise_have = ksk .key - .noise( - module, - row, - col, - &sk_in.data, - &sk_out_prepared, - scratch.borrow(), - ) + .noise(module, row, col, &sk_in.data, &sk_out_prepared, scratch.borrow()) .std() .log2(); @@ -144,10 +130,8 @@ where let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = 
ScratchOwned::alloc(GLWESwitchingKeyCompressed::encrypt_sk_tmp_bytes( - module, - &gglwe_infos, - )); + let mut scratch: ScratchOwned = + ScratchOwned::alloc(GLWESwitchingKeyCompressed::encrypt_sk_tmp_bytes(module, &gglwe_infos)); let mut sk_in: GLWESecret> = GLWESecret::alloc(n.into(), rank_in.into()); sk_in.fill_ternary_prob(0.5, &mut source_xs); @@ -159,14 +143,7 @@ where let seed_xa = [1u8; 32]; - ksk_compressed.encrypt_sk( - module, - &sk_in, - &sk_out, - seed_xa, - &mut source_xe, - scratch.borrow(), - ); + ksk_compressed.encrypt_sk(module, &sk_in, &sk_out, seed_xa, &mut source_xe, scratch.borrow()); let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc_from_infos(&gglwe_infos); ksk.decompress(module, &ksk_compressed); @@ -177,14 +154,7 @@ where for col in 0..ksk.rank_in().as_usize() { let noise_have = ksk .key - .noise( - module, - row, - col, - &sk_in.data, - &sk_out_prepared, - scratch.borrow(), - ) + .noise(module, row, col, &sk_in.data, &sk_out_prepared, scratch.borrow()) .std() .log2(); @@ -269,14 +239,7 @@ where for row in 0..ksk.dnum().as_usize() { for col in 0..ksk.rank_in().as_usize() { let noise_have = ksk - .noise( - module, - row, - col, - &sk_in.data, - &sk_out_prepared, - scratch.borrow(), - ) + .noise(module, row, col, &sk_in.data, &sk_out_prepared, scratch.borrow()) .std() .log2(); diff --git a/poulpy-core/src/tests/test_suite/encryption/gglwe_to_ggsw_key.rs b/poulpy-core/src/tests/test_suite/encryption/gglwe_to_ggsw_key.rs index c508bf4..8e5f25c 100644 --- a/poulpy-core/src/tests/test_suite/encryption/gglwe_to_ggsw_key.rs +++ b/poulpy-core/src/tests/test_suite/encryption/gglwe_to_ggsw_key.rs @@ -55,13 +55,7 @@ where let mut sk_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, rank.into()); sk_prepared.prepare(module, &sk); - key.encrypt_sk( - module, - &sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + key.encrypt_sk(module, &sk, &mut source_xa, &mut source_xe, scratch.borrow()); let mut sk_tensor: GLWESecretTensor> = GLWESecretTensor::alloc_from_infos(&sk); sk_tensor.prepare(module, &sk, scratch.borrow()); @@ -72,12 +66,7 @@ where for i in 0..rank { for j in 0..rank { - module.vec_znx_copy( - &mut pt_want.as_vec_znx_mut(), - j, - &sk_tensor.at(i, j).as_vec_znx(), - 0, - ); + module.vec_znx_copy(&mut pt_want.as_vec_znx_mut(), j, &sk_tensor.at(i, j).as_vec_znx(), 0); } let ksk: &GGLWE> = key.at(i); @@ -127,9 +116,8 @@ where let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWEToGGSWKeyCompressed::encrypt_sk_tmp_bytes( - module, &key_infos, - )); + let mut scratch: ScratchOwned = + ScratchOwned::alloc(GGLWEToGGSWKeyCompressed::encrypt_sk_tmp_bytes(module, &key_infos)); let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&key_infos); sk.fill_ternary_prob(0.5, &mut source_xs); @@ -152,12 +140,7 @@ where for i in 0..rank { for j in 0..rank { - module.vec_znx_copy( - &mut pt_want.as_vec_znx_mut(), - j, - &sk_tensor.at(i, j).as_vec_znx(), - 0, - ); + module.vec_znx_copy(&mut pt_want.as_vec_znx_mut(), j, &sk_tensor.at(i, j).as_vec_znx(), 0); } let ksk: &GGLWE> = key.at(i); diff --git a/poulpy-core/src/tests/test_suite/encryption/ggsw_ct.rs b/poulpy-core/src/tests/test_suite/encryption/ggsw_ct.rs index 8fd0b7a..40194f1 100644 --- a/poulpy-core/src/tests/test_suite/encryption/ggsw_ct.rs +++ b/poulpy-core/src/tests/test_suite/encryption/ggsw_ct.rs @@ -121,14 +121,7 @@ where let seed_xa: [u8; 32] = [1u8; 32]; - 
ct_compressed.encrypt_sk( - module, - &pt_scalar, - &sk_prepared, - seed_xa, - &mut source_xe, - scratch.borrow(), - ); + ct_compressed.encrypt_sk(module, &pt_scalar, &sk_prepared, seed_xa, &mut source_xe, scratch.borrow()); let noise_f = |_col_i: usize| -(k as f64) + SIGMA.log2() + 0.5; diff --git a/poulpy-core/src/tests/test_suite/encryption/glwe_ct.rs b/poulpy-core/src/tests/test_suite/encryption/glwe_ct.rs index 81901ad..60f82e4 100644 --- a/poulpy-core/src/tests/test_suite/encryption/glwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/encryption/glwe_ct.rs @@ -68,10 +68,7 @@ where scratch.borrow(), ); - let noise_have: f64 = ct - .noise(module, &pt_want, &sk_prepared, scratch.borrow()) - .std() - .log2(); + let noise_have: f64 = ct.noise(module, &pt_want, &sk_prepared, scratch.borrow()).std().log2(); let noise_want: f64 = SIGMA.log2() - (ct.k().as_usize() as f64) + 0.5; assert!(noise_have <= noise_want + 0.2); @@ -126,22 +123,12 @@ where let seed_xa: [u8; 32] = [1u8; 32]; - ct_compressed.encrypt_sk( - module, - &pt_want, - &sk_prepared, - seed_xa, - &mut source_xe, - scratch.borrow(), - ); + ct_compressed.encrypt_sk(module, &pt_want, &sk_prepared, seed_xa, &mut source_xe, scratch.borrow()); let mut ct: GLWE> = GLWE::alloc_from_infos(&glwe_infos); ct.decompress(module, &ct_compressed); - let noise_have: f64 = ct - .noise(module, &pt_want, &sk_prepared, scratch.borrow()) - .std() - .log2(); + let noise_have: f64 = ct.noise(module, &pt_want, &sk_prepared, scratch.borrow()).std().log2(); let noise_want: f64 = SIGMA.log2() - (ct.k().as_usize() as f64) + 0.5; assert!(noise_have <= noise_want + 0.2); } @@ -186,18 +173,9 @@ where let mut ct: GLWE> = GLWE::alloc_from_infos(&glwe_infos); - ct.encrypt_zero_sk( - module, - &sk_prepared, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + ct.encrypt_zero_sk(module, &sk_prepared, &mut source_xa, &mut source_xe, scratch.borrow()); - let noise_have: f64 = ct - .noise(module, &pt, &sk_prepared, scratch.borrow()) - .std() - .log2(); + let noise_have: f64 = ct.noise(module, &pt, &sk_prepared, scratch.borrow()).std().log2(); let noise_want: f64 = SIGMA.log2() - (ct.k().as_usize() as f64) + 0.5; assert!(noise_have <= noise_want + 0.2); } @@ -265,10 +243,7 @@ where scratch.borrow(), ); - let noise_have: f64 = ct - .noise(module, &pt_want, &sk_prepared, scratch.borrow()) - .std() - .log2(); + let noise_have: f64 = ct.noise(module, &pt_want, &sk_prepared, scratch.borrow()).std().log2(); let noise_want: f64 = ((((rank as f64) + 1.0) * n as f64 * 0.5 * SIGMA * SIGMA).sqrt()).log2() - (k_ct as f64); assert!(noise_have <= noise_want + 0.2); } diff --git a/poulpy-core/src/tests/test_suite/encryption/glwe_tsk.rs b/poulpy-core/src/tests/test_suite/encryption/glwe_tsk.rs index a0dcbf5..9e9cb21 100644 --- a/poulpy-core/src/tests/test_suite/encryption/glwe_tsk.rs +++ b/poulpy-core/src/tests/test_suite/encryption/glwe_tsk.rs @@ -46,23 +46,14 @@ where let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWETensorKey::encrypt_sk_tmp_bytes( - module, - &tensor_key_infos, - )); + let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWETensorKey::encrypt_sk_tmp_bytes(module, &tensor_key_infos)); let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&tensor_key_infos); sk.fill_ternary_prob(0.5, &mut source_xs); let mut sk_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, rank.into()); sk_prepared.prepare(module, &sk); - 
tensor_key.encrypt_sk( - module, - &sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + tensor_key.encrypt_sk(module, &sk, &mut source_xa, &mut source_xe, scratch.borrow()); let mut sk_tensor: GLWESecretTensor> = GLWESecretTensor::alloc_from_infos(&sk); sk_tensor.prepare(module, &sk, scratch.borrow()); @@ -74,14 +65,7 @@ where assert!( tensor_key .0 - .noise( - module, - row, - col, - &sk_tensor.data, - &sk_prepared, - scratch.borrow() - ) + .noise(module, row, col, &sk_tensor.data, &sk_prepared, scratch.borrow()) .std() .log2() <= max_noise @@ -124,10 +108,8 @@ where let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWETensorKeyCompressed::encrypt_sk_tmp_bytes( - module, - &tensor_key_infos, - )); + let mut scratch: ScratchOwned = + ScratchOwned::alloc(GLWETensorKeyCompressed::encrypt_sk_tmp_bytes(module, &tensor_key_infos)); let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&tensor_key_infos); sk.fill_ternary_prob(0.5, &mut source_xs); @@ -151,14 +133,7 @@ where assert!( tensor_key .0 - .noise( - module, - row, - col, - &sk_tensor.data, - &sk_prepared, - scratch.borrow() - ) + .noise(module, row, col, &sk_tensor.data, &sk_prepared, scratch.borrow()) .std() .log2() <= max_noise diff --git a/poulpy-core/src/tests/test_suite/external_product/gglwe_ksk.rs b/poulpy-core/src/tests/test_suite/external_product/gglwe_ksk.rs index 7161546..a75476c 100644 --- a/poulpy-core/src/tests/test_suite/external_product/gglwe_ksk.rs +++ b/poulpy-core/src/tests/test_suite/external_product/gglwe_ksk.rs @@ -28,26 +28,26 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k_in: usize = 17; - let base2k_key: usize = 13; - let base2k_out: usize = base2k_in; // MUST BE SAME + let in_base2k: usize = 17; + let key_base2k: usize = 13; + let out_base2k: usize = in_base2k; // MUST BE SAME let k_in: usize = 102; - let max_dsize: usize = k_in.div_ceil(base2k_key); + let max_dsize: usize = k_in.div_ceil(key_base2k); for rank_in in 1_usize..3 { for rank_out in 1_usize..3 { for dsize in 1_usize..max_dsize + 1 { - let k_ggsw: usize = k_in + base2k_key * dsize; + let k_ggsw: usize = k_in + key_base2k * dsize; let k_out: usize = k_in; // Better capture noise. 
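                    // The GGSW applied below encrypts the monomial X^r, so the external
                    // product maps gglwe_{s1}(s0) to gglwe_{s1}(s0 * X^r); accordingly, each
                    // column of sk_in is rotated by r (the `* X^{r}` rotation below) before
                    // the noise of the output key is measured.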
let n: usize = module.n(); - let dnum_in: usize = k_in / base2k_in; - let dnum: usize = k_in.div_ceil(base2k_key * dsize); + let dnum_in: usize = k_in / in_base2k; + let dnum: usize = k_in.div_ceil(key_base2k * dsize); let dsize_in: usize = 1; let gglwe_in_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), - base2k: base2k_in.into(), + base2k: in_base2k.into(), k: k_in.into(), dnum: dnum_in.into(), dsize: dsize_in.into(), @@ -57,7 +57,7 @@ where let gglwe_out_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), - base2k: base2k_out.into(), + base2k: out_base2k.into(), k: k_out.into(), dnum: dnum_in.into(), dsize: dsize_in.into(), @@ -67,7 +67,7 @@ where let ggsw_infos: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k_key.into(), + base2k: key_base2k.into(), k: k_ggsw.into(), dnum: dnum.into(), dsize: dsize.into(), @@ -106,14 +106,7 @@ where sk_out_prepared.prepare(module, &sk_out); // gglwe_{s1}(s0) = s0 -> s1 - ct_gglwe_in.encrypt_sk( - module, - &sk_in, - &sk_out, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + ct_gglwe_in.encrypt_sk(module, &sk_in, &sk_out, &mut source_xa, &mut source_xe, scratch.borrow()); ct_rgsw.encrypt_sk( module, @@ -131,12 +124,7 @@ where ct_gglwe_out.external_product(module, &ct_gglwe_in, &ct_rgsw_prepared, scratch.borrow()); (0..rank_in).for_each(|i| { - module.vec_znx_rotate_inplace( - r as i64, - &mut sk_in.data.as_vec_znx_mut(), - i, - scratch.borrow(), - ); // * X^{r} + module.vec_znx_rotate_inplace(r as i64, &mut sk_in.data.as_vec_znx_mut(), i, scratch.borrow()); // * X^{r} }); let var_gct_err_lhs: f64 = SIGMA * SIGMA; @@ -148,7 +136,7 @@ where let max_noise: f64 = noise_ggsw_product( n as f64, - base2k_key * dsize, + key_base2k * dsize, var_xs, var_msg, var_a0_err, @@ -165,14 +153,7 @@ where assert!( ct_gglwe_out .key - .noise( - module, - row, - col, - &sk_in.data, - &sk_out_prepared, - scratch.borrow() - ) + .noise(module, row, col, &sk_in.data, &sk_out_prepared, scratch.borrow()) .std() .log2() <= max_noise + 0.5 @@ -197,25 +178,25 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k_out: usize = 17; - let base2k_key: usize = 13; + let out_base2k: usize = 17; + let key_base2k: usize = 13; let k_out: usize = 102; - let max_dsize: usize = k_out.div_ceil(base2k_key); + let max_dsize: usize = k_out.div_ceil(key_base2k); for rank_in in 1_usize..3 { for rank_out in 1_usize..3 { for dsize in 1_usize..max_dsize + 1 { - let k_ggsw: usize = k_out + base2k_key * dsize; + let k_ggsw: usize = k_out + key_base2k * dsize; let n: usize = module.n(); - let dnum_in: usize = k_out / base2k_out; - let dnum: usize = k_out.div_ceil(base2k_key * dsize); + let dnum_in: usize = k_out / out_base2k; + let dnum: usize = k_out.div_ceil(key_base2k * dsize); let dsize_in: usize = 1; let gglwe_out_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), - base2k: base2k_out.into(), + base2k: out_base2k.into(), k: k_out.into(), dnum: dnum_in.into(), dsize: dsize_in.into(), @@ -225,7 +206,7 @@ where let ggsw_infos: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k_key.into(), + base2k: key_base2k.into(), k: k_ggsw.into(), dnum: dnum.into(), dsize: dsize.into(), @@ -263,14 +244,7 @@ where sk_out_prepared.prepare(module, &sk_out); // gglwe_{s1}(s0) = s0 -> s1 - ct_gglwe.encrypt_sk( - module, - &sk_in, - &sk_out, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + ct_gglwe.encrypt_sk(module, &sk_in, &sk_out, &mut source_xa, &mut 
source_xe, scratch.borrow()); ct_rgsw.encrypt_sk( module, @@ -288,12 +262,7 @@ where ct_gglwe.external_product_inplace(module, &ct_rgsw_prepared, scratch.borrow()); (0..rank_in).for_each(|i| { - module.vec_znx_rotate_inplace( - r as i64, - &mut sk_in.data.as_vec_znx_mut(), - i, - scratch.borrow(), - ); // * X^{r} + module.vec_znx_rotate_inplace(r as i64, &mut sk_in.data.as_vec_znx_mut(), i, scratch.borrow()); // * X^{r} }); let var_gct_err_lhs: f64 = SIGMA * SIGMA; @@ -305,7 +274,7 @@ where let max_noise: f64 = noise_ggsw_product( n as f64, - base2k_key * dsize, + key_base2k * dsize, var_xs, var_msg, var_a0_err, @@ -322,14 +291,7 @@ where assert!( ct_gglwe .key - .noise( - module, - row, - col, - &sk_in.data, - &sk_out_prepared, - scratch.borrow() - ) + .noise(module, row, col, &sk_in.data, &sk_out_prepared, scratch.borrow()) .std() .log2() <= max_noise + 0.5 diff --git a/poulpy-core/src/tests/test_suite/external_product/ggsw_ct.rs b/poulpy-core/src/tests/test_suite/external_product/ggsw_ct.rs index 501a161..f8aec82 100644 --- a/poulpy-core/src/tests/test_suite/external_product/ggsw_ct.rs +++ b/poulpy-core/src/tests/test_suite/external_product/ggsw_ct.rs @@ -26,26 +26,26 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k_in: usize = 17; - let base2k_key: usize = 13; - let base2k_out: usize = base2k_in; // MUST BE SAME + let in_base2k: usize = 17; + let key_base2k: usize = 13; + let out_base2k: usize = in_base2k; // MUST BE SAME let k_in: usize = 102; - let max_dsize: usize = k_in.div_ceil(base2k_key); + let max_dsize: usize = k_in.div_ceil(key_base2k); for rank in 1_usize..3 { for dsize in 1..max_dsize + 1 { - let k_apply: usize = k_in + base2k_key * dsize; + let k_apply: usize = k_in + key_base2k * dsize; let k_out: usize = k_in; // Better capture noise. 
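                // Parameter sweep for this test: max_dsize = ceil(102 / 13) = 8, so
                // k_apply ranges from 115 bits (dsize = 1, dnum = 8) up to 206 bits
                // (dsize = 8, dnum = 1, i.e. a single 104-bit digit covers all of k_in).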
let n: usize = module.n(); - let dnum: usize = k_in.div_ceil(base2k_key * dsize); - let dnum_in: usize = k_in / base2k_in; + let dnum: usize = k_in.div_ceil(key_base2k * dsize); + let dnum_in: usize = k_in / in_base2k; let dsize_in: usize = 1; let ggsw_in_infos: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k_in.into(), + base2k: in_base2k.into(), k: k_in.into(), dnum: dnum_in.into(), dsize: dsize_in.into(), @@ -54,7 +54,7 @@ where let ggsw_out_infos: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k_out.into(), + base2k: out_base2k.into(), k: k_out.into(), dnum: dnum_in.into(), dsize: dsize_in.into(), @@ -63,7 +63,7 @@ where let ggsw_apply_infos: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k_key.into(), + base2k: key_base2k.into(), k: k_apply.into(), dnum: dnum.into(), dsize: dsize.into(), @@ -107,14 +107,7 @@ where scratch.borrow(), ); - ggsw_in.encrypt_sk( - module, - &pt_in, - &sk_prepared, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + ggsw_in.encrypt_sk(module, &pt_in, &sk_prepared, &mut source_xa, &mut source_xe, scratch.borrow()); let mut ct_rhs_prepared: GGSWPrepared, BE> = GGSWPrepared::alloc_from_infos(module, &ggsw_apply); ct_rhs_prepared.prepare(module, &ggsw_apply, scratch.borrow()); @@ -133,7 +126,7 @@ where let max_noise = |_col_j: usize| -> f64 { noise_ggsw_product( n as f64, - base2k_key * dsize, + key_base2k * dsize, 0.5, var_msg, var_a0_err, @@ -173,23 +166,23 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k_out: usize = 17; - let base2k_key: usize = 13; + let out_base2k: usize = 17; + let key_base2k: usize = 13; let k_out: usize = 102; - let max_dsize: usize = k_out.div_ceil(base2k_key); + let max_dsize: usize = k_out.div_ceil(key_base2k); for rank in 1_usize..3 { for dsize in 1..max_dsize + 1 { - let k_apply: usize = k_out + base2k_key * dsize; + let k_apply: usize = k_out + key_base2k * dsize; let n: usize = module.n(); - let dnum: usize = k_out.div_ceil(dsize * base2k_key); - let dnum_in: usize = k_out / base2k_out; + let dnum: usize = k_out.div_ceil(dsize * key_base2k); + let dnum_in: usize = k_out / out_base2k; let dsize_in: usize = 1; let ggsw_out_infos: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k_out.into(), + base2k: out_base2k.into(), k: k_out.into(), dnum: dnum_in.into(), dsize: dsize_in.into(), @@ -198,7 +191,7 @@ where let ggsw_apply_infos: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k_key.into(), + base2k: key_base2k.into(), k: k_apply.into(), dnum: dnum.into(), dsize: dsize.into(), @@ -242,14 +235,7 @@ where scratch.borrow(), ); - ggsw_out.encrypt_sk( - module, - &pt_in, - &sk_prepared, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + ggsw_out.encrypt_sk(module, &pt_in, &sk_prepared, &mut source_xa, &mut source_xe, scratch.borrow()); let mut ct_rhs_prepared: GGSWPrepared, BE> = GGSWPrepared::alloc_from_infos(module, &ggsw_apply); ct_rhs_prepared.prepare(module, &ggsw_apply, scratch.borrow()); @@ -268,7 +254,7 @@ where let max_noise = |_col_j: usize| -> f64 { noise_ggsw_product( n as f64, - base2k_key * dsize, + key_base2k * dsize, 0.5, var_msg, var_a0_err, diff --git a/poulpy-core/src/tests/test_suite/external_product/glwe_ct.rs b/poulpy-core/src/tests/test_suite/external_product/glwe_ct.rs index 49ef81f..aeca55a 100644 --- a/poulpy-core/src/tests/test_suite/external_product/glwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/external_product/glwe_ct.rs @@ -29,14 +29,14 @@ where ScratchOwned: ScratchOwnedAlloc 
+ ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k_in: usize = 17; - let base2k_key: usize = 13; - let base2k_out: usize = 15; + let in_base2k: usize = 17; + let key_base2k: usize = 13; + let out_base2k: usize = 15; let k_in: usize = 102; - let max_dsize: usize = k_in.div_ceil(base2k_key); + let max_dsize: usize = k_in.div_ceil(key_base2k); for rank in 1_usize..3 { for dsize in 1..max_dsize + 1 { - let k_ggsw: usize = k_in + base2k_key * dsize; + let k_ggsw: usize = k_in + key_base2k * dsize; let k_out: usize = k_ggsw; // Better capture noise let n: usize = module.n(); @@ -44,21 +44,21 @@ where let glwe_in_infos: GLWELayout = GLWELayout { n: n.into(), - base2k: base2k_in.into(), + base2k: in_base2k.into(), k: k_in.into(), rank: rank.into(), }; let glwe_out_infos: GLWELayout = GLWELayout { n: n.into(), - base2k: base2k_out.into(), + base2k: out_base2k.into(), k: k_out.into(), rank: rank.into(), }; let ggsw_apply_infos: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k_key.into(), + base2k: key_base2k.into(), k: k_ggsw.into(), dnum: dnum.into(), dsize: dsize.into(), @@ -77,7 +77,7 @@ where let mut source_xa: Source = Source::new([0u8; 32]); // Random input plaintext - module.vec_znx_fill_uniform(base2k_in, &mut pt_in.data, 0, &mut source_xa); + module.vec_znx_fill_uniform(in_base2k, &mut pt_in.data, 0, &mut source_xa); pt_in.data.at_mut(0, 0)[1] = 1; @@ -106,14 +106,7 @@ where scratch.borrow(), ); - glwe_in.encrypt_sk( - module, - &pt_in, - &sk_prepared, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + glwe_in.encrypt_sk(module, &pt_in, &sk_prepared, &mut source_xa, &mut source_xe, scratch.borrow()); let mut ct_ggsw_prepared: GGSWPrepared, BE> = GGSWPrepared::alloc_from_infos(module, &ggsw_apply); ct_ggsw_prepared.prepare(module, &ggsw_apply, scratch.borrow()); @@ -133,7 +126,7 @@ where let max_noise: f64 = noise_ggsw_product( n as f64, - base2k_key * max_dsize, + key_base2k * max_dsize, 0.5, var_msg, var_a0_err, @@ -145,13 +138,7 @@ where k_ggsw, ); - assert!( - glwe_out - .noise(module, &pt_out, &sk_prepared, scratch.borrow()) - .std() - .log2() - <= max_noise + 1.0 - ) + assert!(glwe_out.noise(module, &pt_out, &sk_prepared, scratch.borrow()).std().log2() <= max_noise + 1.0) } } } @@ -170,28 +157,28 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k_out: usize = 17; - let base2k_key: usize = 13; + let out_base2k: usize = 17; + let key_base2k: usize = 13; let k_out: usize = 102; - let max_dsize: usize = k_out.div_ceil(base2k_key); + let max_dsize: usize = k_out.div_ceil(key_base2k); for rank in 1_usize..3 { for dsize in 1..max_dsize + 1 { - let k_ggsw: usize = k_out + base2k_key * dsize; + let k_ggsw: usize = k_out + key_base2k * dsize; let n: usize = module.n(); - let dnum: usize = k_out.div_ceil(base2k_out * max_dsize); + let dnum: usize = k_out.div_ceil(out_base2k * max_dsize); let glwe_out_infos: GLWELayout = GLWELayout { n: n.into(), - base2k: base2k_out.into(), + base2k: out_base2k.into(), k: k_out.into(), rank: rank.into(), }; let ggsw_apply_infos: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k_key.into(), + base2k: key_base2k.into(), k: k_ggsw.into(), dnum: dnum.into(), dsize: dsize.into(), @@ -208,7 +195,7 @@ where let mut source_xa: Source = Source::new([0u8; 32]); // Random input plaintext - module.vec_znx_fill_uniform(base2k_out, &mut pt_want.data, 0, &mut source_xa); + module.vec_znx_fill_uniform(out_base2k, &mut pt_want.data, 0, &mut source_xa); 
pt_want.data.at_mut(0, 0)[1] = 1; @@ -262,7 +249,7 @@ where let max_noise: f64 = noise_ggsw_product( n as f64, - base2k_key * max_dsize, + key_base2k * max_dsize, 0.5, var_msg, var_a0_err, @@ -274,13 +261,7 @@ where k_ggsw, ); - assert!( - glwe_out - .noise(module, &pt_want, &sk_prepared, scratch.borrow()) - .std() - .log2() - <= max_noise + 1.0 - ) + assert!(glwe_out.noise(module, &pt_want, &sk_prepared, scratch.borrow()).std().log2() <= max_noise + 1.0) } } } diff --git a/poulpy-core/src/tests/test_suite/glwe_packer.rs b/poulpy-core/src/tests/test_suite/glwe_packer.rs index e663836..ebf3247 100644 --- a/poulpy-core/src/tests/test_suite/glwe_packer.rs +++ b/poulpy-core/src/tests/test_suite/glwe_packer.rs @@ -33,26 +33,26 @@ where let mut source_xa: Source = Source::new([0u8; 32]); let n: usize = module.n(); - let base2k_out: usize = 15; - let base2k_key: usize = 10; + let out_base2k: usize = 15; + let key_base2k: usize = 10; let k_ct: usize = 36; let pt_k: usize = 18; let rank: usize = 3; let dsize: usize = 1; - let k_ksk: usize = k_ct + base2k_key * dsize; + let k_ksk: usize = k_ct + key_base2k * dsize; - let dnum: usize = k_ct.div_ceil(base2k_key * dsize); + let dnum: usize = k_ct.div_ceil(key_base2k * dsize); let glwe_out_infos: GLWELayout = GLWELayout { n: n.into(), - base2k: base2k_out.into(), + base2k: out_base2k.into(), k: k_ct.into(), rank: rank.into(), }; let key_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { n: n.into(), - base2k: base2k_key.into(), + base2k: key_base2k.into(), k: k_ksk.into(), rank: rank.into(), dsize: dsize.into(), @@ -84,14 +84,7 @@ where let mut auto_keys: HashMap, BE>> = HashMap::new(); let mut tmp: GLWEAutomorphismKey> = GLWEAutomorphismKey::alloc_from_infos(&key_infos); gal_els.iter().for_each(|gal_el| { - tmp.encrypt_sk( - module, - *gal_el, - &sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + tmp.encrypt_sk(module, *gal_el, &sk, &mut source_xa, &mut source_xe, scratch.borrow()); let mut atk_prepared: GLWEAutomorphismKeyPrepared, BE> = GLWEAutomorphismKeyPrepared::alloc_from_infos(module, &tmp); atk_prepared.prepare(module, &tmp, scratch.borrow()); @@ -104,26 +97,12 @@ where let mut ct: GLWE> = GLWE::alloc_from_infos(&glwe_out_infos); - ct.encrypt_sk( - module, - &pt, - &sk_dft, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + ct.encrypt_sk(module, &pt, &sk_dft, &mut source_xa, &mut source_xe, scratch.borrow()); let log_n: usize = module.log_n(); (0..n >> log_batch).for_each(|i| { - ct.encrypt_sk( - module, - &pt, - &sk_dft, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + ct.encrypt_sk(module, &pt, &sk_dft, &mut source_xa, &mut source_xe, scratch.borrow()); module.glwe_rotate_inplace(-(1 << log_batch), &mut pt, scratch.borrow()); // X^-batch * pt @@ -153,10 +132,7 @@ where let noise_have: f64 = pt.stats().std().log2(); - assert!( - noise_have < -((k_ct - base2k_out) as f64), - "noise: {noise_have}" - ); + assert!(noise_have < -((k_ct - out_base2k) as f64), "noise: {noise_have}"); } #[inline(always)] diff --git a/poulpy-core/src/tests/test_suite/glwe_packing.rs b/poulpy-core/src/tests/test_suite/glwe_packing.rs index 666e4d4..8c01dc1 100644 --- a/poulpy-core/src/tests/test_suite/glwe_packing.rs +++ b/poulpy-core/src/tests/test_suite/glwe_packing.rs @@ -35,26 +35,26 @@ where let mut source_xa: Source = Source::new([0u8; 32]); let n: usize = module.n(); - let base2k_out: usize = 15; - let base2k_key: usize = 10; + let out_base2k: usize = 15; + let key_base2k: usize = 10; let k_ct: usize = 36; - 
let pt_k: usize = base2k_out; + let pt_k: usize = out_base2k; let rank: usize = 3; let dsize: usize = 1; - let k_ksk: usize = k_ct + base2k_key * dsize; + let k_ksk: usize = k_ct + key_base2k * dsize; - let dnum: usize = k_ct.div_ceil(base2k_key * dsize); + let dnum: usize = k_ct.div_ceil(key_base2k * dsize); let glwe_out_infos: GLWELayout = GLWELayout { n: n.into(), - base2k: base2k_out.into(), + base2k: out_base2k.into(), k: k_ct.into(), rank: rank.into(), }; let key_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { n: n.into(), - base2k: base2k_key.into(), + base2k: key_base2k.into(), k: k_ksk.into(), rank: rank.into(), dsize: dsize.into(), @@ -63,9 +63,7 @@ where let mut scratch: ScratchOwned = ScratchOwned::alloc( GLWE::encrypt_sk_tmp_bytes(module, &glwe_out_infos) - .max(GLWEAutomorphismKey::encrypt_sk_tmp_bytes( - module, &key_infos, - )) + .max(GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &key_infos)) .max(module.glwe_pack_tmp_bytes(&glwe_out_infos, &key_infos)), ); @@ -88,14 +86,7 @@ where let mut auto_keys: HashMap, BE>> = HashMap::new(); let mut tmp: GLWEAutomorphismKey> = GLWEAutomorphismKey::alloc_from_infos(&key_infos); gal_els.iter().for_each(|gal_el| { - tmp.encrypt_sk( - module, - *gal_el, - &sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + tmp.encrypt_sk(module, *gal_el, &sk, &mut source_xa, &mut source_xe, scratch.borrow()); let mut atk_prepared: GLWEAutomorphismKeyPrepared, BE> = GLWEAutomorphismKeyPrepared::alloc_from_infos(module, &tmp); atk_prepared.prepare(module, &tmp, scratch.borrow()); @@ -106,14 +97,7 @@ where .step_by(5) .map(|_| { let mut ct = GLWE::alloc_from_infos(&glwe_out_infos); - ct.encrypt_sk( - module, - &pt, - &sk_prep, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + ct.encrypt_sk(module, &pt, &sk_prep, &mut source_xa, &mut source_xe, scratch.borrow()); module.glwe_rotate_inplace(-5, &mut pt, scratch.borrow()); // X^-batch * pt ct }) @@ -139,10 +123,5 @@ where pt_want.encode_vec_i64(&data, pt_k.into()); - assert!( - res.noise(module, &pt_want, &sk_prep, scratch.borrow()) - .std() - .log2() - <= ((k_ct - base2k_out) as f64) - ); + assert!(res.noise(module, &pt_want, &sk_prep, scratch.borrow()).std().log2() <= ((k_ct - out_base2k) as f64)); } diff --git a/poulpy-core/src/tests/test_suite/glwe_tensor.rs b/poulpy-core/src/tests/test_suite/glwe_tensor.rs new file mode 100644 index 0000000..7640e3c --- /dev/null +++ b/poulpy-core/src/tests/test_suite/glwe_tensor.rs @@ -0,0 +1,390 @@ +use poulpy_hal::{ + api::{ + ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxCopy, VecZnxFillUniform, VecZnxNormalize, + VecZnxNormalizeInplace, + }, + layouts::{Backend, FillUniform, Module, Scratch, ScratchOwned, VecZnx, ZnxViewMut}, + source::Source, + test_suite::convolution::bivariate_convolution_naive, +}; +use rand::RngCore; +use std::f64::consts::SQRT_2; + +use crate::{ + GLWEDecrypt, GLWEEncryptSk, GLWEMulConst, GLWEMulPlain, GLWESub, GLWETensorKeyEncryptSk, GLWETensoring, ScratchTakeCore, + layouts::{ + Dsize, GLWE, GLWELayout, GLWEPlaintext, GLWESecret, GLWESecretPreparedFactory, GLWESecretTensor, GLWESecretTensorFactory, + GLWESecretTensorPrepared, GLWETensor, GLWETensorKey, GLWETensorKeyLayout, GLWETensorKeyPrepared, + GLWETensorKeyPreparedFactory, LWEInfos, TorusPrecision, prepared::GLWESecretPrepared, + }, +}; + +pub fn test_glwe_tensoring(module: &Module) +where + Module: GLWETensoring + + GLWEEncryptSk + + GLWEDecrypt + + VecZnxFillUniform + + GLWESecretPreparedFactory + + GLWESub + + 
VecZnxNormalizeInplace + + GLWESecretTensorFactory + + VecZnxCopy + + VecZnxNormalize + + GLWETensorKeyEncryptSk + + GLWETensorKeyPreparedFactory, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, +{ + let in_base2k: usize = 16; + let out_base2k: usize = 13; + let tsk_base2k: usize = 15; + let k: usize = 128; + + for rank in 1_usize..=3 { + let n: usize = module.n(); + + let glwe_in_infos: GLWELayout = GLWELayout { + n: n.into(), + base2k: in_base2k.into(), + k: k.into(), + rank: rank.into(), + }; + + let glwe_out_infos: GLWELayout = GLWELayout { + n: n.into(), + base2k: out_base2k.into(), + k: k.into(), + rank: rank.into(), + }; + + let tsk_infos: GLWETensorKeyLayout = GLWETensorKeyLayout { + n: n.into(), + base2k: tsk_base2k.into(), + k: (k + tsk_base2k).into(), + rank: rank.into(), + dnum: k.div_ceil(tsk_base2k).into(), + dsize: Dsize(1), + }; + + let mut a: GLWE> = GLWE::alloc_from_infos(&glwe_in_infos); + let mut b: GLWE> = GLWE::alloc_from_infos(&glwe_in_infos); + let mut res_tensor: GLWETensor> = GLWETensor::alloc_from_infos(&glwe_out_infos); + let mut res_relin: GLWE> = GLWE::alloc_from_infos(&glwe_out_infos); + let mut pt_in: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_in_infos); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_out_infos); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_out_infos); + let mut pt_tmp: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_out_infos); + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GLWE::encrypt_sk_tmp_bytes(module, &glwe_in_infos) + .max(GLWE::decrypt_tmp_bytes(module, &glwe_out_infos)) + .max(module.glwe_tensor_apply_tmp_bytes(&res_tensor, 0, &a, &b)) + .max(module.glwe_secret_tensor_prepare_tmp_bytes(rank.into())) + .max(module.glwe_tensor_relinearize_tmp_bytes(&res_relin, &res_tensor, &tsk_infos)), + ); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut sk: GLWESecret> = GLWESecret::alloc(module.n().into(), rank.into()); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc_from_infos(module, &sk); + sk_dft.prepare(module, &sk); + + let mut sk_tensor: GLWESecretTensor> = GLWESecretTensor::alloc(module.n().into(), rank.into()); + sk_tensor.prepare(module, &sk, scratch.borrow()); + + let mut sk_tensor_prep: GLWESecretTensorPrepared, BE> = GLWESecretTensorPrepared::alloc(module, rank.into()); + sk_tensor_prep.prepare(module, &sk_tensor); + + let mut tsk: GLWETensorKey> = GLWETensorKey::alloc_from_infos(&tsk_infos); + tsk.encrypt_sk(module, &sk, &mut source_xa, &mut source_xe, scratch.borrow()); + + let mut tsk_prep: GLWETensorKeyPrepared, BE> = GLWETensorKeyPrepared::alloc_from_infos(module, &tsk_infos); + tsk_prep.prepare(module, &tsk, scratch.borrow()); + + let scale: usize = 2 * in_base2k; + + let mut data = vec![0i64; n]; + for i in data.iter_mut() { + *i = (source_xa.next_i64() & 7) - 4; + } + + pt_in.encode_vec_i64(&data, TorusPrecision(scale as u32)); + + let mut pt_want_base2k_in = VecZnx::alloc(n, 1, pt_in.size()); + bivariate_convolution_naive( + module, + in_base2k, + 2, + &mut pt_want_base2k_in, + 0, + pt_in.data(), + 0, + pt_in.data(), + 0, + scratch.borrow(), + ); + + a.encrypt_sk(module, &pt_in, &sk_dft, &mut source_xa, &mut source_xe, scratch.borrow()); + b.encrypt_sk(module, &pt_in, &sk_dft, &mut source_xa, 
&mut source_xe, scratch.borrow()); + + for res_offset in 0..scale { + module.glwe_tensor_apply(&mut res_tensor, scale + res_offset, &a, &b, scratch.borrow()); + + res_tensor.decrypt(module, &mut pt_have, &sk_dft, &sk_tensor_prep, scratch.borrow()); + module.vec_znx_normalize( + pt_want.data_mut(), + out_base2k, + res_offset as i64, + 0, + &pt_want_base2k_in, + in_base2k, + 0, + scratch.borrow(), + ); + + module.glwe_sub(&mut pt_tmp, &pt_have, &pt_want); + module.vec_znx_normalize_inplace(pt_tmp.base2k().as_usize(), &mut pt_tmp.data, 0, scratch.borrow()); + + let noise_have: f64 = pt_tmp.stats().std().log2(); + let noise_want = -((k - scale - res_offset - module.log_n()) as f64 - ((rank - 1) as f64) / SQRT_2); + + assert!(noise_have - noise_want <= 0.5, "{} > {}", noise_have, noise_want); + + module.glwe_tensor_relinearize(&mut res_relin, &res_tensor, &tsk_prep, tsk_prep.size(), scratch.borrow()); + res_relin.decrypt(module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.glwe_sub(&mut pt_tmp, &pt_have, &pt_want); + module.vec_znx_normalize_inplace(pt_tmp.base2k().as_usize(), &mut pt_tmp.data, 0, scratch.borrow()); + + // We can reuse the same noise bound because the relinearization noise (which is additive) + // is much smaller than the tensoring noise (which is multiplicative) + let noise_have: f64 = pt_tmp.stats().std().log2(); + assert!(noise_have - noise_want <= 0.5, "{} > {}", noise_have, noise_want); + } + } +} + +pub fn test_glwe_mul_plain(module: &Module) +where + Module: GLWEEncryptSk + + GLWEDecrypt + + VecZnxFillUniform + + GLWESecretPreparedFactory + + GLWESub + + VecZnxNormalizeInplace + + VecZnxCopy + + VecZnxNormalize + + GLWEMulPlain, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, +{ + let in_base2k: usize = 16; + let out_base2k: usize = 13; + let k: usize = 128; + + for rank in 1_usize..=3 { + let n: usize = module.n(); + + let glwe_in_infos: GLWELayout = GLWELayout { + n: n.into(), + base2k: in_base2k.into(), + k: k.into(), + rank: rank.into(), + }; + + let glwe_out_infos: GLWELayout = GLWELayout { + n: n.into(), + base2k: out_base2k.into(), + k: k.into(), + rank: rank.into(), + }; + + let mut a: GLWE> = GLWE::alloc_from_infos(&glwe_in_infos); + let mut res: GLWE> = GLWE::alloc_from_infos(&glwe_out_infos); + let mut pt_a: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_in_infos); + let mut pt_b: GLWEPlaintext> = GLWEPlaintext::alloc(module.n().into(), in_base2k.into(), (2 * in_base2k).into()); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_out_infos); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_out_infos); + let mut pt_tmp: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_out_infos); + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GLWE::encrypt_sk_tmp_bytes(module, &glwe_in_infos) + .max(GLWE::decrypt_tmp_bytes(module, &glwe_out_infos)) + .max(module.glwe_mul_plain_tmp_bytes(&res, 0, &a, &pt_b)), + ); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut sk: GLWESecret> = GLWESecret::alloc(module.n().into(), rank.into()); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc_from_infos(module, &sk); + sk_dft.prepare(module, &sk); + + let scale: usize = 2 * in_base2k; + + pt_b.data_mut().fill_uniform(in_base2k, &mut source_xa); + 
pt_a.data_mut().fill_uniform(in_base2k, &mut source_xa); + + let mut pt_want_base2k_in = VecZnx::alloc(n, 1, pt_a.size() + pt_b.size()); + bivariate_convolution_naive( + module, + in_base2k, + 2, + &mut pt_want_base2k_in, + 0, + pt_a.data(), + 0, + pt_b.data(), + 0, + scratch.borrow(), + ); + + a.encrypt_sk(module, &pt_a, &sk_dft, &mut source_xa, &mut source_xe, scratch.borrow()); + + for res_offset in 0..scale { + module.glwe_mul_plain(&mut res, scale + res_offset, &a, &pt_b, scratch.borrow()); + + res.decrypt(module, &mut pt_have, &sk_dft, scratch.borrow()); + module.vec_znx_normalize( + pt_want.data_mut(), + out_base2k, + res_offset as i64, + 0, + &pt_want_base2k_in, + in_base2k, + 0, + scratch.borrow(), + ); + + module.glwe_sub(&mut pt_tmp, &pt_have, &pt_want); + module.vec_znx_normalize_inplace(pt_tmp.base2k().as_usize(), &mut pt_tmp.data, 0, scratch.borrow()); + + let noise_have: f64 = pt_tmp.stats().std().log2(); + let noise_want = -((k - scale - res_offset - module.log_n()) as f64 - ((rank - 1) as f64) / SQRT_2); + + assert!(noise_have - noise_want <= 0.5, "{} > {}", noise_have, noise_want); + } + } +} + +pub fn test_glwe_mul_const(module: &Module) +where + Module: GLWEEncryptSk + + GLWEDecrypt + + VecZnxFillUniform + + GLWESecretPreparedFactory + + GLWESub + + VecZnxNormalizeInplace + + VecZnxCopy + + VecZnxNormalize + + GLWEMulConst, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, +{ + let in_base2k: usize = 16; + let out_base2k: usize = 13; + let k: usize = 128; + let b_size = 3; + + for rank in 1_usize..=3 { + let n: usize = module.n(); + + let glwe_in_infos: GLWELayout = GLWELayout { + n: n.into(), + base2k: in_base2k.into(), + k: k.into(), + rank: rank.into(), + }; + + let glwe_out_infos: GLWELayout = GLWELayout { + n: n.into(), + base2k: out_base2k.into(), + k: k.into(), + rank: rank.into(), + }; + + let mut a: GLWE> = GLWE::alloc_from_infos(&glwe_in_infos); + let mut res: GLWE> = GLWE::alloc_from_infos(&glwe_out_infos); + let mut pt_a: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_in_infos); + let mut pt_b: GLWEPlaintext> = GLWEPlaintext::alloc(module.n().into(), in_base2k.into(), (2 * in_base2k).into()); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_out_infos); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_out_infos); + let mut pt_tmp: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_out_infos); + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GLWE::encrypt_sk_tmp_bytes(module, &glwe_in_infos) + .max(GLWE::decrypt_tmp_bytes(module, &glwe_out_infos)) + .max(module.glwe_mul_const_tmp_bytes(&res, 0, &a, b_size)), + ); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut sk: GLWESecret> = GLWESecret::alloc(module.n().into(), rank.into()); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc_from_infos(module, &sk); + sk_dft.prepare(module, &sk); + + let scale: usize = 2 * in_base2k; + + pt_a.data_mut().fill_uniform(in_base2k, &mut source_xa); + + let mut b_const = vec![0i64; b_size]; + let mask = (1 << in_base2k) - 1; + for (j, x) in b_const[..1].iter_mut().enumerate() { + let r = source_xa.next_u64() & mask; + *x = ((r << (64 - in_base2k)) as i64) >> (64 - in_base2k); + pt_b.data_mut().at_mut(0, j)[0] = *x + } + + let mut pt_want_base2k_in = VecZnx::alloc(n, 
1, pt_a.size() + pt_b.size()); + bivariate_convolution_naive( + module, + in_base2k, + 2, + &mut pt_want_base2k_in, + 0, + pt_a.data(), + 0, + pt_b.data(), + 0, + scratch.borrow(), + ); + + a.encrypt_sk(module, &pt_a, &sk_dft, &mut source_xa, &mut source_xe, scratch.borrow()); + + for res_offset in 0..scale { + module.glwe_mul_const(&mut res, scale + res_offset, &a, &b_const, scratch.borrow()); + + res.decrypt(module, &mut pt_have, &sk_dft, scratch.borrow()); + module.vec_znx_normalize( + pt_want.data_mut(), + out_base2k, + res_offset as i64, + 0, + &pt_want_base2k_in, + in_base2k, + 0, + scratch.borrow(), + ); + + module.glwe_sub(&mut pt_tmp, &pt_have, &pt_want); + module.vec_znx_normalize_inplace(pt_tmp.base2k().as_usize(), &mut pt_tmp.data, 0, scratch.borrow()); + + let noise_have: f64 = pt_tmp.stats().std().log2(); + let noise_want = -((k - scale - res_offset - module.log_n()) as f64 - ((rank - 1) as f64) / SQRT_2); + + assert!(noise_have - noise_want <= 0.5, "{} > {}", noise_have, noise_want); + } + } +} diff --git a/poulpy-core/src/tests/test_suite/keyswitch/gglwe_ct.rs b/poulpy-core/src/tests/test_suite/keyswitch/gglwe_ct.rs index 59310bf..e7510c6 100644 --- a/poulpy-core/src/tests/test_suite/keyswitch/gglwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/keyswitch/gglwe_ct.rs @@ -26,27 +26,27 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k_in: usize = 17; - let base2k_key: usize = 13; - let base2k_out: usize = base2k_in; // MUST BE SAME + let in_base2k: usize = 17; + let key_base2k: usize = 13; + let out_base2k: usize = in_base2k; // MUST BE SAME let k_in: usize = 102; - let max_dsize: usize = k_in.div_ceil(base2k_key); + let max_dsize: usize = k_in.div_ceil(key_base2k); for rank_in_s0s1 in 1_usize..2 { for rank_out_s0s1 in 1_usize..3 { for rank_out_s1s2 in 1_usize..3 { for dsize in 1_usize..max_dsize + 1 { - let k_ksk: usize = k_in + base2k_key * dsize; + let k_ksk: usize = k_in + key_base2k * dsize; let k_out: usize = k_ksk; // Better capture noise. 
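                    // Three secrets are chained below: gglwe_{s1}(s0) keyswitched with
                    // gglwe_{s2}(s1) must yield gglwe_{s2}(s0), i.e. composing s0 -> s1 with
                    // s1 -> s2 gives s0 -> s2, and the resulting per-row/column noise is
                    // bounded by var_noise_gglwe_product_v2.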
let n: usize = module.n(); let dsize_in: usize = 1; - let dnum_in: usize = k_in / base2k_in; - let dnum_ksk: usize = k_in.div_ceil(base2k_key * dsize); + let dnum_in: usize = k_in / in_base2k; + let dnum_ksk: usize = k_in.div_ceil(key_base2k * dsize); let gglwe_s0s1_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), - base2k: base2k_in.into(), + base2k: in_base2k.into(), k: k_in.into(), dnum: dnum_in.into(), dsize: dsize_in.into(), @@ -56,7 +56,7 @@ where let gglwe_s1s2_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), - base2k: base2k_key.into(), + base2k: key_base2k.into(), k: k_ksk.into(), dnum: dnum_ksk.into(), dsize: dsize.into(), @@ -66,7 +66,7 @@ where let gglwe_s0s2_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), - base2k: base2k_out.into(), + base2k: out_base2k.into(), k: k_out.into(), dnum: dnum_in.into(), dsize: dsize_in.into(), @@ -108,43 +108,24 @@ where sk2_prepared.prepare(module, &sk2); // gglwe_{s1}(s0) = s0 -> s1 - gglwe_s0s1.encrypt_sk( - module, - &sk0, - &sk1, - &mut source_xa, - &mut source_xe, - scratch_enc.borrow(), - ); + gglwe_s0s1.encrypt_sk(module, &sk0, &sk1, &mut source_xa, &mut source_xe, scratch_enc.borrow()); // gglwe_{s2}(s1) -> s1 -> s2 - gglwe_s1s2.encrypt_sk( - module, - &sk1, - &sk2, - &mut source_xa, - &mut source_xe, - scratch_enc.borrow(), - ); + gglwe_s1s2.encrypt_sk(module, &sk1, &sk2, &mut source_xa, &mut source_xe, scratch_enc.borrow()); let mut gglwe_s1s2_prepared: GLWESwitchingKeyPrepared, BE> = GLWESwitchingKeyPrepared::alloc_from_infos(module, &gglwe_s1s2); gglwe_s1s2_prepared.prepare(module, &gglwe_s1s2, scratch_apply.borrow()); // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) - gglwe_s0s2.keyswitch( - module, - &gglwe_s0s1, - &gglwe_s1s2_prepared, - scratch_apply.borrow(), - ); + gglwe_s0s2.keyswitch(module, &gglwe_s0s1, &gglwe_s1s2_prepared, scratch_apply.borrow()); let max_noise: f64 = var_noise_gglwe_product_v2( module.n() as f64, k_ksk, dnum_ksk, dsize, - base2k_key, + key_base2k, 0.5, 0.5, 0f64, @@ -157,21 +138,12 @@ where for row in 0..gglwe_s0s2.dnum().as_usize() { for col in 0..gglwe_s0s2.rank_in().as_usize() { - assert!( - gglwe_s0s2 - .key - .noise( - module, - row, - col, - &sk0.data, - &sk2_prepared, - scratch_apply.borrow() - ) - .std() - .log2() - <= max_noise + 0.5 - ) + let noise: f64 = gglwe_s0s2 + .key + .noise(module, row, col, &sk0.data, &sk2_prepared, scratch_apply.borrow()) + .std() + .log2(); + assert!(noise <= max_noise + 0.5, "{noise} > {max_noise}",) } } } @@ -191,25 +163,25 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k_out: usize = 17; - let base2k_key: usize = 13; + let out_base2k: usize = 17; + let key_base2k: usize = 13; let k_out: usize = 102; - let max_dsize: usize = k_out.div_ceil(base2k_key); + let max_dsize: usize = k_out.div_ceil(key_base2k); for rank_in in 1_usize..3 { for rank_out in 1_usize..3 { for dsize in 1_usize..max_dsize + 1 { - let k_ksk: usize = k_out + base2k_key * dsize; + let k_ksk: usize = k_out + key_base2k * dsize; let n: usize = module.n(); let dsize_in: usize = 1; - let dnum_in: usize = k_out / base2k_out; - let dnum_ksk: usize = k_out.div_ceil(base2k_key * dsize); + let dnum_in: usize = k_out / out_base2k; + let dnum_ksk: usize = k_out.div_ceil(key_base2k * dsize); let gglwe_s0s1_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), - base2k: base2k_out.into(), + base2k: out_base2k.into(), k: k_out.into(), dnum: dnum_in.into(), 
dsize: dsize_in.into(), @@ -219,7 +191,7 @@ where let gglwe_s1s2_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), - base2k: base2k_key.into(), + base2k: key_base2k.into(), k: k_ksk.into(), dnum: dnum_ksk.into(), dsize: dsize.into(), @@ -260,24 +232,10 @@ where sk2_prepared.prepare(module, &sk2); // gglwe_{s1}(s0) = s0 -> s1 - gglwe_s0s1.encrypt_sk( - module, - &sk0, - &sk1, - &mut source_xa, - &mut source_xe, - scratch_enc.borrow(), - ); + gglwe_s0s1.encrypt_sk(module, &sk0, &sk1, &mut source_xa, &mut source_xe, scratch_enc.borrow()); // gglwe_{s2}(s1) -> s1 -> s2 - gglwe_s1s2.encrypt_sk( - module, - &sk1, - &sk2, - &mut source_xa, - &mut source_xe, - scratch_enc.borrow(), - ); + gglwe_s1s2.encrypt_sk(module, &sk1, &sk2, &mut source_xa, &mut source_xe, scratch_enc.borrow()); let mut gglwe_s1s2_prepared: GLWESwitchingKeyPrepared, BE> = GLWESwitchingKeyPrepared::alloc_from_infos(module, &gglwe_s1s2); @@ -290,7 +248,7 @@ where let max_noise: f64 = log2_std_noise_gglwe_product( n as f64, - base2k_key * dsize, + key_base2k * dsize, var_xs, var_xs, 0f64, @@ -303,21 +261,12 @@ where for row in 0..gglwe_s0s2.dnum().as_usize() { for col in 0..gglwe_s0s2.rank_in().as_usize() { - assert!( - gglwe_s0s2 - .key - .noise( - module, - row, - col, - &sk0.data, - &sk2_prepared, - scratch_apply.borrow() - ) - .std() - .log2() - <= max_noise + 0.5 - ) + let noise = gglwe_s0s2 + .key + .noise(module, row, col, &sk0.data, &sk2_prepared, scratch_apply.borrow()) + .std() + .log2(); + assert!(noise <= max_noise + 0.5, "{noise} > {max_noise}") } } } diff --git a/poulpy-core/src/tests/test_suite/keyswitch/ggsw_ct.rs b/poulpy-core/src/tests/test_suite/keyswitch/ggsw_ct.rs index f1e33d5..f8c1677 100644 --- a/poulpy-core/src/tests/test_suite/keyswitch/ggsw_ct.rs +++ b/poulpy-core/src/tests/test_suite/keyswitch/ggsw_ct.rs @@ -8,9 +8,9 @@ use crate::{ GGLWEToGGSWKeyEncryptSk, GGSWEncryptSk, GGSWKeyswitch, GGSWNoise, GLWESwitchingKeyEncryptSk, ScratchTakeCore, encryption::SIGMA, layouts::{ - GGLWEToGGSWKey, GGLWEToGGSWKeyPrepared, GGLWEToGGSWKeyPreparedFactory, GGSW, GGSWInfos, GGSWLayout, GLWEInfos, - GLWESecret, GLWESecretPreparedFactory, GLWESwitchingKey, GLWESwitchingKeyLayout, GLWESwitchingKeyPreparedFactory, - GLWETensorKeyLayout, + GGLWEToGGSWKey, GGLWEToGGSWKeyLayout, GGLWEToGGSWKeyPrepared, GGLWEToGGSWKeyPreparedFactory, GGSW, GGSWInfos, GGSWLayout, + GLWEInfos, GLWESecret, GLWESecretPreparedFactory, GLWESwitchingKey, GLWESwitchingKeyLayout, + GLWESwitchingKeyPreparedFactory, prepared::{GLWESecretPrepared, GLWESwitchingKeyPrepared}, }, noise::noise_ggsw_keyswitch, @@ -30,27 +30,27 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k_in: usize = 17; - let base2k_key: usize = 13; - let base2k_out: usize = base2k_in; // MUST BE SAME + let in_base2k: usize = 17; + let key_base2k: usize = 13; + let out_base2k: usize = in_base2k; // MUST BE SAME let k_in: usize = 102; - let max_dsize: usize = k_in.div_ceil(base2k_key); + let max_dsize: usize = k_in.div_ceil(key_base2k); for rank in 1_usize..3 { for dsize in 1..max_dsize + 1 { - let k_ksk: usize = k_in + base2k_key * dsize; + let k_ksk: usize = k_in + key_base2k * dsize; let k_tsk: usize = k_ksk; let k_out: usize = k_ksk; // Better capture noise. 
let n: usize = module.n(); - let dnum_in: usize = k_in / base2k_in; - let dnum_ksk: usize = k_in.div_ceil(base2k_key * dsize); + let dnum_in: usize = k_in / in_base2k; + let dnum_ksk: usize = k_in.div_ceil(key_base2k * dsize); let dsize_in: usize = 1; let ggsw_in_infos: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k_in.into(), + base2k: in_base2k.into(), k: k_in.into(), dnum: dnum_in.into(), dsize: dsize_in.into(), @@ -59,16 +59,16 @@ where let ggsw_out_infos: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k_out.into(), + base2k: out_base2k.into(), k: k_out.into(), dnum: dnum_in.into(), dsize: dsize_in.into(), rank: rank.into(), }; - let tsk_infos: GLWETensorKeyLayout = GLWETensorKeyLayout { + let tsk_infos: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout { n: n.into(), - base2k: base2k_key.into(), + base2k: key_base2k.into(), k: k_tsk.into(), dnum: dnum_ksk.into(), dsize: dsize.into(), @@ -77,7 +77,7 @@ where let ksk_apply_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), - base2k: base2k_key.into(), + base2k: key_base2k.into(), k: k_ksk.into(), dnum: dnum_ksk.into(), dsize: dsize.into(), @@ -99,13 +99,7 @@ where GGSW::encrypt_sk_tmp_bytes(module, &ggsw_in_infos) | GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &ksk_apply_infos) | GGLWEToGGSWKey::encrypt_sk_tmp_bytes(module, &tsk_infos) - | GGSW::keyswitch_tmp_bytes( - module, - &ggsw_out_infos, - &ggsw_in_infos, - &ksk_apply_infos, - &tsk_infos, - ), + | GGSW::keyswitch_tmp_bytes(module, &ggsw_out_infos, &ggsw_in_infos, &ksk_apply_infos, &tsk_infos), ); let var_xs: f64 = 0.5; @@ -122,21 +116,8 @@ where let mut sk_out_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, rank.into()); sk_out_prepared.prepare(module, &sk_out); - ksk.encrypt_sk( - module, - &sk_in, - &sk_out, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - tsk.encrypt_sk( - module, - &sk_out, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + ksk.encrypt_sk(module, &sk_in, &sk_out, &mut source_xa, &mut source_xe, scratch.borrow()); + tsk.encrypt_sk(module, &sk_out, &mut source_xa, &mut source_xe, scratch.borrow()); pt_scalar.fill_ternary_hw(0, n, &mut source_xs); @@ -156,18 +137,12 @@ where let mut tsk_prepared: GGLWEToGGSWKeyPrepared, BE> = GGLWEToGGSWKeyPrepared::alloc_from_infos(module, &tsk); tsk_prepared.prepare(module, &tsk, scratch.borrow()); - ggsw_out.keyswitch( - module, - &ggsw_in, - &ksk_prepared, - &tsk_prepared, - scratch.borrow(), - ); + ggsw_out.keyswitch(module, &ggsw_in, &ksk_prepared, &tsk_prepared, scratch.borrow()); let max_noise = |col_j: usize| -> f64 { noise_ggsw_keyswitch( n as f64, - base2k_key * dsize, + key_base2k * dsize, col_j, var_xs, 0f64, @@ -184,14 +159,7 @@ where for col in 0..ggsw_out.rank().as_usize() + 1 { assert!( ggsw_out - .noise( - module, - row, - col, - &pt_scalar, - &sk_out_prepared, - scratch.borrow() - ) + .noise(module, row, col, &pt_scalar, &sk_out_prepared, scratch.borrow()) .std() .log2() <= max_noise(col) @@ -216,33 +184,33 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k_out: usize = 17; - let base2k_key: usize = 13; + let out_base2k: usize = 17; + let key_base2k: usize = 13; let k_out: usize = 102; - let max_dsize: usize = k_out.div_ceil(base2k_key); + let max_dsize: usize = k_out.div_ceil(key_base2k); for rank in 1_usize..3 { for dsize in 1..max_dsize + 1 { - let k_ksk: usize = k_out + base2k_key * dsize; + let k_ksk: usize = k_out + key_base2k * dsize; let k_tsk: usize = 
k_ksk; let n: usize = module.n(); - let dnum_in: usize = k_out / base2k_out; - let dnum_ksk: usize = k_out.div_ceil(base2k_key * dsize); + let dnum_in: usize = k_out / out_base2k; + let dnum_ksk: usize = k_out.div_ceil(key_base2k * dsize); let dsize_in: usize = 1; let ggsw_out_infos: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k_out.into(), + base2k: out_base2k.into(), k: k_out.into(), dnum: dnum_in.into(), dsize: dsize_in.into(), rank: rank.into(), }; - let tsk_infos: GLWETensorKeyLayout = GLWETensorKeyLayout { + let tsk_infos: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout { n: n.into(), - base2k: base2k_key.into(), + base2k: key_base2k.into(), k: k_tsk.into(), dnum: dnum_ksk.into(), dsize: dsize.into(), @@ -251,7 +219,7 @@ where let ksk_apply_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), - base2k: base2k_key.into(), + base2k: key_base2k.into(), k: k_ksk.into(), dnum: dnum_ksk.into(), dsize: dsize.into(), @@ -272,13 +240,7 @@ where GGSW::encrypt_sk_tmp_bytes(module, &ggsw_out_infos) | GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &ksk_apply_infos) | GGLWEToGGSWKey::encrypt_sk_tmp_bytes(module, &tsk_infos) - | GGSW::keyswitch_tmp_bytes( - module, - &ggsw_out_infos, - &ggsw_out_infos, - &ksk_apply_infos, - &tsk_infos, - ), + | GGSW::keyswitch_tmp_bytes(module, &ggsw_out_infos, &ggsw_out_infos, &ksk_apply_infos, &tsk_infos), ); let var_xs: f64 = 0.5; @@ -295,21 +257,8 @@ where let mut sk_out_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, rank.into()); sk_out_prepared.prepare(module, &sk_out); - ksk.encrypt_sk( - module, - &sk_in, - &sk_out, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - tsk.encrypt_sk( - module, - &sk_out, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + ksk.encrypt_sk(module, &sk_in, &sk_out, &mut source_xa, &mut source_xe, scratch.borrow()); + tsk.encrypt_sk(module, &sk_out, &mut source_xa, &mut source_xe, scratch.borrow()); pt_scalar.fill_ternary_hw(0, n, &mut source_xs); @@ -334,7 +283,7 @@ where let max_noise = |col_j: usize| -> f64 { noise_ggsw_keyswitch( n as f64, - base2k_key * dsize, + key_base2k * dsize, col_j, var_xs, 0f64, @@ -351,14 +300,7 @@ where for col in 0..ggsw_out.rank().as_usize() + 1 { assert!( ggsw_out - .noise( - module, - row, - col, - &pt_scalar, - &sk_out_prepared, - scratch.borrow() - ) + .noise(module, row, col, &pt_scalar, &sk_out_prepared, scratch.borrow()) .std() .log2() <= max_noise(col) diff --git a/poulpy-core/src/tests/test_suite/keyswitch/glwe_ct.rs b/poulpy-core/src/tests/test_suite/keyswitch/glwe_ct.rs index c4ac553..e8158ae 100644 --- a/poulpy-core/src/tests/test_suite/keyswitch/glwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/keyswitch/glwe_ct.rs @@ -29,38 +29,38 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k_in: usize = 17; - let base2k_key: usize = 13; - let base2k_out: usize = 15; + let in_base2k: usize = 17; + let key_base2k: usize = 13; + let out_base2k: usize = 15; let k_in: usize = 102; - let max_dsize: usize = k_in.div_ceil(base2k_key); + let max_dsize: usize = k_in.div_ceil(key_base2k); for rank_in in 1_usize..3 { for rank_out in 1_usize..3 { for dsize in 1_usize..max_dsize + 1 { - let k_ksk: usize = k_in + base2k_key * dsize; + let k_ksk: usize = k_in + key_base2k * dsize; let k_out: usize = k_ksk; // better capture noise let n: usize = module.n(); - let dnum: usize = k_in.div_ceil(base2k_key * dsize); + let dnum: usize = k_in.div_ceil(key_base2k * dsize); let 
glwe_in_infos: GLWELayout = GLWELayout { n: n.into(), - base2k: base2k_in.into(), + base2k: in_base2k.into(), k: k_in.into(), rank: rank_in.into(), }; let glwe_out_infos: GLWELayout = GLWELayout { n: n.into(), - base2k: base2k_out.into(), + base2k: out_base2k.into(), k: k_out.into(), rank: rank_out.into(), }; let ksk: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), - base2k: base2k_key.into(), + base2k: key_base2k.into(), k: k_ksk.into(), dnum: dnum.into(), dsize: dsize.into(), @@ -98,14 +98,7 @@ where let mut sk_out_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, rank_out.into()); sk_out_prepared.prepare(module, &sk_out); - ksk.encrypt_sk( - module, - &sk_in, - &sk_out, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + ksk.encrypt_sk(module, &sk_in, &sk_out, &mut source_xa, &mut source_xe, scratch.borrow()); glwe_in.encrypt_sk( module, @@ -127,7 +120,7 @@ where k_ksk, dnum, dsize, - base2k_key, + key_base2k, 0.5, 0.5, 0f64, @@ -164,28 +157,28 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k_out: usize = 17; - let base2k_key: usize = 13; + let out_base2k: usize = 17; + let key_base2k: usize = 13; let k_out: usize = 102; - let max_dsize: usize = k_out.div_ceil(base2k_key); + let max_dsize: usize = k_out.div_ceil(key_base2k); for rank in 1_usize..3 { for dsize in 1..max_dsize + 1 { - let k_ksk: usize = k_out + base2k_key * dsize; + let k_ksk: usize = k_out + key_base2k * dsize; let n: usize = module.n(); - let dnum: usize = k_out.div_ceil(base2k_key * dsize); + let dnum: usize = k_out.div_ceil(key_base2k * dsize); let glwe_out_infos: GLWELayout = GLWELayout { n: n.into(), - base2k: base2k_out.into(), + base2k: out_base2k.into(), k: k_out.into(), rank: rank.into(), }; let ksk_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), - base2k: base2k_key.into(), + base2k: key_base2k.into(), k: k_ksk.into(), dnum: dnum.into(), dsize: dsize.into(), @@ -201,12 +194,7 @@ where let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - module.vec_znx_fill_uniform( - pt_want.base2k().into(), - &mut pt_want.data, - 0, - &mut source_xa, - ); + module.vec_znx_fill_uniform(pt_want.base2k().into(), &mut pt_want.data, 0, &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::alloc( GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &ksk_infos) @@ -226,14 +214,7 @@ where let mut sk_out_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, rank.into()); sk_out_prepared.prepare(module, &sk_out); - ksk.encrypt_sk( - module, - &sk_in, - &sk_out, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + ksk.encrypt_sk(module, &sk_in, &sk_out, &mut source_xa, &mut source_xe, scratch.borrow()); glwe_out.encrypt_sk( module, @@ -255,7 +236,7 @@ where k_ksk, dnum, dsize, - base2k_key, + key_base2k, 0.5, 0.5, 0f64, diff --git a/poulpy-core/src/tests/test_suite/keyswitch/lwe_ct.rs b/poulpy-core/src/tests/test_suite/keyswitch/lwe_ct.rs index cd06749..347e7a8 100644 --- a/poulpy-core/src/tests/test_suite/keyswitch/lwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/keyswitch/lwe_ct.rs @@ -24,17 +24,17 @@ where Scratch: ScratchAvailable + ScratchTakeCore, { let n: usize = module.n(); - let base2k_in: usize = 17; - let base2k_out: usize = 15; - let base2k_key: usize = 13; + let in_base2k: usize = 17; + let out_base2k: usize = 15; + let key_base2k: usize = 13; let n_lwe_in: usize = module.n() >> 1; let n_lwe_out: usize = 
module.n() >> 1; let k_lwe_ct: usize = 102; let k_lwe_pt: usize = 8; - let k_ksk: usize = k_lwe_ct + base2k_key; - let dnum: usize = k_lwe_ct.div_ceil(base2k_key); + let k_ksk: usize = k_lwe_ct + key_base2k; + let dnum: usize = k_lwe_ct.div_ceil(key_base2k); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); @@ -42,21 +42,21 @@ where let key_apply_infos: LWESwitchingKeyLayout = LWESwitchingKeyLayout { n: n.into(), - base2k: base2k_key.into(), + base2k: key_base2k.into(), k: k_ksk.into(), dnum: dnum.into(), }; let lwe_in_infos: LWELayout = LWELayout { n: n_lwe_in.into(), - base2k: base2k_in.into(), + base2k: in_base2k.into(), k: k_lwe_ct.into(), }; let lwe_out_infos: LWELayout = LWELayout { n: n_lwe_out.into(), k: k_lwe_ct.into(), - base2k: base2k_out.into(), + base2k: out_base2k.into(), }; let mut scratch: ScratchOwned = ScratchOwned::alloc( @@ -72,7 +72,7 @@ where let data: i64 = 17; - let mut lwe_pt_in: LWEPlaintext> = LWEPlaintext::alloc(base2k_in.into(), k_lwe_pt.into()); + let mut lwe_pt_in: LWEPlaintext> = LWEPlaintext::alloc(in_base2k.into(), k_lwe_pt.into()); lwe_pt_in.encode_i64(data, k_lwe_pt.into()); let mut lwe_ct_in: LWE> = LWE::alloc_from_infos(&lwe_in_infos); @@ -108,11 +108,12 @@ where let mut lwe_pt_want: LWEPlaintext> = LWEPlaintext::alloc_from_infos(&lwe_out_infos); module.vec_znx_normalize( - base2k_out, lwe_pt_want.data_mut(), + out_base2k, + 0, 0, - base2k_in, lwe_pt_in.data(), + in_base2k, 0, scratch.borrow(), ); diff --git a/poulpy-core/src/tests/test_suite/mod.rs b/poulpy-core/src/tests/test_suite/mod.rs index 6b086e7..33b5688 100644 --- a/poulpy-core/src/tests/test_suite/mod.rs +++ b/poulpy-core/src/tests/test_suite/mod.rs @@ -1,6 +1,7 @@ pub mod automorphism; pub mod encryption; pub mod external_product; +pub mod glwe_tensor; pub mod keyswitch; mod conversion; diff --git a/poulpy-core/src/tests/test_suite/trace.rs b/poulpy-core/src/tests/test_suite/trace.rs index b57dee7..5ab1b69 100644 --- a/poulpy-core/src/tests/test_suite/trace.rs +++ b/poulpy-core/src/tests/test_suite/trace.rs @@ -32,27 +32,27 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k_out: usize = 15; - let base2k_key: usize = 10; + let out_base2k: usize = 15; + let key_base2k: usize = 10; let k: usize = 54; for rank in 1_usize..3 { let n: usize = module.n(); - let k_autokey: usize = k + base2k_key; + let k_autokey: usize = k + key_base2k; let dsize: usize = 1; - let dnum: usize = k.div_ceil(base2k_key * dsize); + let dnum: usize = k.div_ceil(key_base2k * dsize); let glwe_out_infos: GLWELayout = GLWELayout { n: n.into(), - base2k: base2k_out.into(), + base2k: out_base2k.into(), k: k.into(), rank: rank.into(), }; let key_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { n: n.into(), - base2k: base2k_key.into(), + base2k: key_base2k.into(), k: k_autokey.into(), rank: rank.into(), dsize: dsize.into(), @@ -82,33 +82,17 @@ where let mut data_want: Vec = vec![0i64; n]; - data_want - .iter_mut() - .for_each(|x| *x = source_xa.next_i64() & 0xFF); + data_want.iter_mut().for_each(|x| *x = source_xa.next_i64() & 0xFF); - module.vec_znx_fill_uniform(base2k_out, &mut pt_have.data, 0, &mut source_xa); + module.vec_znx_fill_uniform(out_base2k, &mut pt_have.data, 0, &mut source_xa); - glwe_out.encrypt_sk( - module, - &pt_have, - &sk_dft, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + glwe_out.encrypt_sk(module, &pt_have, &sk_dft, &mut source_xa, &mut source_xe, 
scratch.borrow()); let mut auto_keys: HashMap<i64, GLWEAutomorphismKeyPrepared<Vec<u8>, BE>> = HashMap::new(); let gal_els: Vec<i64> = GLWE::trace_galois_elements(module); let mut tmp: GLWEAutomorphismKey<Vec<u8>> = GLWEAutomorphismKey::alloc_from_infos(&key_infos); gal_els.iter().for_each(|gal_el| { - tmp.encrypt_sk( - module, - *gal_el, - &sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + tmp.encrypt_sk(module, *gal_el, &sk, &mut source_xa, &mut source_xe, scratch.borrow()); let mut atk_prepared: GLWEAutomorphismKeyPrepared<Vec<u8>, BE> = GLWEAutomorphismKeyPrepared::alloc_from_infos(module, &tmp); atk_prepared.prepare(module, &tmp, scratch.borrow()); @@ -122,18 +106,13 @@ where glwe_out.decrypt(module, &mut pt_have, &sk_dft, scratch.borrow()); module.vec_znx_sub_inplace(&mut pt_want.data, 0, &pt_have.data, 0); - module.vec_znx_normalize_inplace( - pt_want.base2k().as_usize(), - &mut pt_want.data, - 0, - scratch.borrow(), - ); + module.vec_znx_normalize_inplace(pt_want.base2k().as_usize(), &mut pt_want.data, 0, scratch.borrow()); let noise_have: f64 = pt_want.stats().std().log2(); let mut noise_want: f64 = var_noise_gglwe_product( n as f64, - base2k_key * dsize, + key_base2k * dsize, 0.5, 0.5, 1.0 / 12.0, @@ -147,9 +126,6 @@ where noise_want += n as f64 * 1.0 / 12.0 * 0.5 * rank as f64 * (-2.0 * (k) as f64).exp2(); noise_want = noise_want.sqrt().log2(); - assert!( - (noise_have - noise_want).abs() < 1.0, - "{noise_have} > {noise_want}" - ); + assert!((noise_have - noise_want).abs() < 1.0, "{noise_have} > {noise_want}"); } } diff --git a/poulpy-core/src/utils.rs b/poulpy-core/src/utils.rs index 7484fcd..ce81386 100644 --- a/poulpy-core/src/utils.rs +++ b/poulpy-core/src/utils.rs @@ -16,13 +16,11 @@ impl GLWEPlaintext { pub fn decode_vec_i64(&self, data: &mut [i64], k: TorusPrecision) { - self.data - .decode_vec_i64(self.base2k().into(), 0, k.into(), data); + self.data.decode_vec_i64(self.base2k().into(), 0, k.into(), data); } pub fn decode_coeff_i64(&self, k: TorusPrecision, idx: usize) -> i64 { - self.data - .decode_coeff_i64(self.base2k().into(), 0, k.into(), idx) + self.data.decode_coeff_i64(self.base2k().into(), 0, k.into(), idx) } pub fn decode_vec_float(&self, data: &mut [Float]) { @@ -43,14 +41,12 @@ impl LWEPlaintext { pub fn decode_i64(&self, k: TorusPrecision) -> i64 { - self.data - .decode_coeff_i64(self.base2k().into(), 0, k.into(), 0) + self.data.decode_coeff_i64(self.base2k().into(), 0, k.into(), 0) } pub fn decode_float(&self) -> Float { let mut out: [Float; 1] = [Float::new(self.k().as_u32())]; - self.data - .decode_vec_float(self.base2k().into(), 0, &mut out); + self.data.decode_vec_float(self.base2k().into(), 0, &mut out); out[0].clone() } } diff --git a/poulpy-cpu-avx/Cargo.toml b/poulpy-cpu-avx/Cargo.toml index 5704b1e..6f6aab2 100644 --- a/poulpy-cpu-avx/Cargo.toml +++ b/poulpy-cpu-avx/Cargo.toml @@ -32,5 +32,5 @@ rustdoc-args = ["--cfg", "docsrs"] [[bench]] -name = "vmp" +name = "convolution" harness = false \ No newline at end of file diff --git a/poulpy-cpu-avx/benches/convolution.rs b/poulpy-cpu-avx/benches/convolution.rs new file mode 100644 index 0000000..67e61ff --- /dev/null +++ b/poulpy-cpu-avx/benches/convolution.rs @@ -0,0 +1,116 @@ +use criterion::{Criterion, criterion_group, criterion_main}; + +#[cfg(not(all( + feature = "enable-avx", + target_arch = "x86_64", + target_feature = "avx2", + target_feature = "fma" +)))] +fn bench_cnv_prepare_left_cpu_avx_fft64(_c: &mut Criterion) { + eprintln!("Skipping: AVX convolution benchmark requires x86_64 + AVX2 + FMA"); +} +
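Each of the five benches in this new file carries the same cfg-gated stub/impl pair with the same four-line predicate (the remaining four pairs follow below). A local macro could generate both halves of each pair; `avx_bench!` below is a hypothetical sketch, not part of the patch:

```rust
use criterion::Criterion;

// Hypothetical helper (not in the patch): expands to the cfg-gated stub and
// the AVX-enabled bench, writing the shared cfg predicate only once.
macro_rules! avx_bench {
    ($name:ident, $suite_fn:ident) => {
        #[cfg(not(all(feature = "enable-avx", target_arch = "x86_64",
                      target_feature = "avx2", target_feature = "fma")))]
        fn $name(_c: &mut Criterion) {
            eprintln!("Skipping: AVX convolution benchmark requires x86_64 + AVX2 + FMA");
        }

        #[cfg(all(feature = "enable-avx", target_arch = "x86_64",
                  target_feature = "avx2", target_feature = "fma"))]
        fn $name(c: &mut Criterion) {
            use poulpy_cpu_avx::FFT64Avx;
            poulpy_hal::bench_suite::convolution::$suite_fn::<FFT64Avx>(c, "cpu_avx::fft64");
        }
    };
}

// Would replace the hand-written pairs in this file, e.g.:
// avx_bench!(bench_cnv_prepare_left_cpu_avx_fft64, bench_cnv_prepare_left);
```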
+#[cfg(all( + feature = "enable-avx", + target_arch = "x86_64", + target_feature = "avx2", + target_feature = "fma" +))] +fn bench_cnv_prepare_left_cpu_avx_fft64(c: &mut Criterion) { + use poulpy_cpu_avx::FFT64Avx; + poulpy_hal::bench_suite::convolution::bench_cnv_prepare_left::<FFT64Avx>(c, "cpu_avx::fft64"); +} + +#[cfg(not(all( + feature = "enable-avx", + target_arch = "x86_64", + target_feature = "avx2", + target_feature = "fma" +)))] +fn bench_cnv_prepare_right_cpu_avx_fft64(_c: &mut Criterion) { + eprintln!("Skipping: AVX convolution benchmark requires x86_64 + AVX2 + FMA"); +} + +#[cfg(all( + feature = "enable-avx", + target_arch = "x86_64", + target_feature = "avx2", + target_feature = "fma" +))] +fn bench_cnv_prepare_right_cpu_avx_fft64(c: &mut Criterion) { + use poulpy_cpu_avx::FFT64Avx; + poulpy_hal::bench_suite::convolution::bench_cnv_prepare_right::<FFT64Avx>(c, "cpu_avx::fft64"); +} + +#[cfg(not(all( + feature = "enable-avx", + target_arch = "x86_64", + target_feature = "avx2", + target_feature = "fma" +)))] +fn bench_cnv_apply_dft_cpu_avx_fft64(_c: &mut Criterion) { + eprintln!("Skipping: AVX convolution benchmark requires x86_64 + AVX2 + FMA"); +} + +#[cfg(all( + feature = "enable-avx", + target_arch = "x86_64", + target_feature = "avx2", + target_feature = "fma" +))] +fn bench_cnv_apply_dft_cpu_avx_fft64(c: &mut Criterion) { + use poulpy_cpu_avx::FFT64Avx; + poulpy_hal::bench_suite::convolution::bench_cnv_apply_dft::<FFT64Avx>(c, "cpu_avx::fft64"); +} + +#[cfg(not(all( + feature = "enable-avx", + target_arch = "x86_64", + target_feature = "avx2", + target_feature = "fma" +)))] +fn bench_cnv_pairwise_apply_dft_cpu_avx_fft64(_c: &mut Criterion) { + eprintln!("Skipping: AVX convolution benchmark requires x86_64 + AVX2 + FMA"); +} + +#[cfg(all( + feature = "enable-avx", + target_arch = "x86_64", + target_feature = "avx2", + target_feature = "fma" +))] +fn bench_cnv_pairwise_apply_dft_cpu_avx_fft64(c: &mut Criterion) { + use poulpy_cpu_avx::FFT64Avx; + poulpy_hal::bench_suite::convolution::bench_cnv_pairwise_apply_dft::<FFT64Avx>(c, "cpu_avx::fft64"); +} + +#[cfg(not(all( + feature = "enable-avx", + target_arch = "x86_64", + target_feature = "avx2", + target_feature = "fma" +)))] +fn bench_cnv_by_const_apply_cpu_avx_fft64(_c: &mut Criterion) { + eprintln!("Skipping: AVX convolution benchmark requires x86_64 + AVX2 + FMA"); +} + +#[cfg(all( + feature = "enable-avx", + target_arch = "x86_64", + target_feature = "avx2", + target_feature = "fma" +))] +fn bench_cnv_by_const_apply_cpu_avx_fft64(c: &mut Criterion) { + use poulpy_cpu_avx::FFT64Avx; + poulpy_hal::bench_suite::convolution::bench_cnv_by_const_apply::<FFT64Avx>(c, "cpu_avx::fft64"); +} + +criterion_group!( + benches, + bench_cnv_prepare_left_cpu_avx_fft64, + bench_cnv_prepare_right_cpu_avx_fft64, + bench_cnv_apply_dft_cpu_avx_fft64, + bench_cnv_pairwise_apply_dft_cpu_avx_fft64, + bench_cnv_by_const_apply_cpu_avx_fft64, +); +criterion_main!(benches); diff --git a/poulpy-cpu-avx/benches/fft.rs b/poulpy-cpu-avx/benches/fft.rs index 2e8e542..5985941 100644 --- a/poulpy-cpu-avx/benches/fft.rs +++ b/poulpy-cpu-avx/benches/fft.rs @@ -1,11 +1,21 @@ use criterion::{Criterion, criterion_group, criterion_main}; -#[cfg(not(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma")))] +#[cfg(not(all( + feature = "enable-avx", + target_arch = "x86_64", + target_feature = "avx2", + target_feature = "fma" +)))] fn bench_ifft_avx2_fma(_c: &mut Criterion) { eprintln!("Skipping: AVX IFft benchmark requires x86_64 + AVX2 + FMA"); } -#[cfg(all(feature = "enable-avx", target_arch
= "x86_64", target_feature = "avx2", target_feature = "fma"))] +#[cfg(all( + feature = "enable-avx", + target_arch = "x86_64", + target_feature = "avx2", + target_feature = "fma" +))] pub fn bench_ifft_avx2_fma(c: &mut Criterion) { use criterion::BenchmarkId; use poulpy_cpu_avx::ReimIFFTAvx; @@ -21,10 +31,7 @@ pub fn bench_ifft_avx2_fma(c: &mut Criterion) { let mut values: Vec = vec![0f64; m << 1]; let scale = 1.0f64 / (2 * m) as f64; - values - .iter_mut() - .enumerate() - .for_each(|(i, x)| *x = (i + 1) as f64 * scale); + values.iter_mut().enumerate().for_each(|(i, x)| *x = (i + 1) as f64 * scale); let table: ReimIFFTTable = ReimIFFTTable::::new(m); move || { @@ -47,12 +54,22 @@ pub fn bench_ifft_avx2_fma(c: &mut Criterion) { group.finish(); } -#[cfg(not(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma")))] +#[cfg(not(all( + feature = "enable-avx", + target_arch = "x86_64", + target_feature = "avx2", + target_feature = "fma" +)))] fn bench_fft_avx2_fma(_c: &mut Criterion) { eprintln!("Skipping: AVX FFT benchmark requires x86_64 + AVX2 + FMA"); } -#[cfg(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))] +#[cfg(all( + feature = "enable-avx", + target_arch = "x86_64", + target_feature = "avx2", + target_feature = "fma" +))] pub fn bench_fft_avx2_fma(c: &mut Criterion) { use criterion::BenchmarkId; use poulpy_cpu_avx::ReimFFTAvx; @@ -68,10 +85,7 @@ pub fn bench_fft_avx2_fma(c: &mut Criterion) { let mut values: Vec = vec![0f64; m << 1]; let scale = 1.0f64 / (2 * m) as f64; - values - .iter_mut() - .enumerate() - .for_each(|(i, x)| *x = (i + 1) as f64 * scale); + values.iter_mut().enumerate().for_each(|(i, x)| *x = (i + 1) as f64 * scale); let table: ReimFFTTable = ReimFFTTable::::new(m); move || { diff --git a/poulpy-cpu-avx/benches/vec_znx.rs b/poulpy-cpu-avx/benches/vec_znx.rs index ef975ab..5182c7e 100644 --- a/poulpy-cpu-avx/benches/vec_znx.rs +++ b/poulpy-cpu-avx/benches/vec_znx.rs @@ -1,33 +1,63 @@ use criterion::{Criterion, criterion_group, criterion_main}; -#[cfg(not(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma")))] +#[cfg(not(all( + feature = "enable-avx", + target_arch = "x86_64", + target_feature = "avx2", + target_feature = "fma" +)))] fn bench_vec_znx_add_cpu_avx_fft64(_c: &mut Criterion) { eprintln!("Skipping: AVX IFft benchmark requires x86_64 + AVX2 + FMA"); } -#[cfg(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))] +#[cfg(all( + feature = "enable-avx", + target_arch = "x86_64", + target_feature = "avx2", + target_feature = "fma" +))] fn bench_vec_znx_add_cpu_avx_fft64(c: &mut Criterion) { use poulpy_cpu_avx::FFT64Avx; poulpy_hal::reference::vec_znx::bench_vec_znx_add::(c, "FFT64Avx"); } -#[cfg(not(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma")))] +#[cfg(not(all( + feature = "enable-avx", + target_arch = "x86_64", + target_feature = "avx2", + target_feature = "fma" +)))] fn bench_vec_znx_normalize_inplace_cpu_avx_fft64(_c: &mut Criterion) { eprintln!("Skipping: AVX IFft benchmark requires x86_64 + AVX2 + FMA"); } -#[cfg(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))] +#[cfg(all( + feature = "enable-avx", + target_arch = "x86_64", + target_feature = "avx2", + target_feature = "fma" +))] fn bench_vec_znx_normalize_inplace_cpu_avx_fft64(c: &mut Criterion) { use 
poulpy_cpu_avx::FFT64Avx; poulpy_hal::reference::vec_znx::bench_vec_znx_normalize_inplace::(c, "FFT64Avx"); } -#[cfg(not(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma")))] +#[cfg(not(all( + feature = "enable-avx", + target_arch = "x86_64", + target_feature = "avx2", + target_feature = "fma" +)))] fn bench_vec_znx_automorphism_cpu_avx_fft64(_c: &mut Criterion) { eprintln!("Skipping: AVX IFft benchmark requires x86_64 + AVX2 + FMA"); } -#[cfg(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))] +#[cfg(all( + feature = "enable-avx", + target_arch = "x86_64", + target_feature = "avx2", + target_feature = "fma" +))] fn bench_vec_znx_automorphism_cpu_avx_fft64(c: &mut Criterion) { use poulpy_cpu_avx::FFT64Avx; poulpy_hal::reference::vec_znx::bench_vec_znx_automorphism::(c, "FFT64Avx"); diff --git a/poulpy-cpu-avx/benches/vmp.rs b/poulpy-cpu-avx/benches/vmp.rs index e1d6a65..61b4664 100644 --- a/poulpy-cpu-avx/benches/vmp.rs +++ b/poulpy-cpu-avx/benches/vmp.rs @@ -1,11 +1,21 @@ use criterion::{Criterion, criterion_group, criterion_main}; -#[cfg(not(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma")))] +#[cfg(not(all( + feature = "enable-avx", + target_arch = "x86_64", + target_feature = "avx2", + target_feature = "fma" +)))] fn bench_vmp_apply_dft_to_dft_cpu_avx_fft64(_c: &mut Criterion) { eprintln!("Skipping: AVX IFft benchmark requires x86_64 + AVX2 + FMA"); } -#[cfg(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))] +#[cfg(all( + feature = "enable-avx", + target_arch = "x86_64", + target_feature = "avx2", + target_feature = "fma" +))] fn bench_vmp_apply_dft_to_dft_cpu_avx_fft64(c: &mut Criterion) { use poulpy_cpu_avx::FFT64Avx; poulpy_hal::bench_suite::vmp::bench_vmp_apply_dft_to_dft::(c, "FFT64Avx"); diff --git a/poulpy-cpu-avx/examples/rlwe_encrypt.rs b/poulpy-cpu-avx/examples/rlwe_encrypt.rs index 2cc51a9..d7dcd2c 100644 --- a/poulpy-cpu-avx/examples/rlwe_encrypt.rs +++ b/poulpy-cpu-avx/examples/rlwe_encrypt.rs @@ -1,8 +1,18 @@ use itertools::izip; -#[cfg(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))] +#[cfg(all( + feature = "enable-avx", + target_arch = "x86_64", + target_feature = "avx2", + target_feature = "fma" +))] use poulpy_cpu_avx::FFT64Avx as BackendImpl; -#[cfg(not(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma")))] +#[cfg(not(all( + feature = "enable-avx", + target_arch = "x86_64", + target_feature = "avx2", + target_feature = "fma" +)))] use poulpy_cpu_ref::FFT64Ref as BackendImpl; use poulpy_hal::{ @@ -73,8 +83,7 @@ fn main() { msg_size, // Number of small polynomials ); let mut want: Vec = vec![0; n]; - want.iter_mut() - .for_each(|x| *x = source.next_u64n(16, 15) as i64); + want.iter_mut().for_each(|x| *x = source.next_u64n(16, 15) as i64); m.encode_vec_i64(base2k, 0, log_scale, &want); module.vec_znx_normalize_inplace(base2k, &mut m, 0, scratch.borrow()); @@ -89,11 +98,12 @@ fn main() { // Normalizes back to VecZnx // ct[0] <- m - BIG(c1 * s) module.vec_znx_big_normalize( - base2k, &mut ct, - 0, // Selects the first column of ct (ct[0]) base2k, + 0, + 0, // Selects the first column of ct (ct[0]) &buf_big, + base2k, 0, // Selects the first column of buf_big scratch.borrow(), ); @@ -131,15 +141,13 @@ fn main() { // m + e <- BIG(ct[1] * s + ct[0]) let mut res = 
VecZnx::alloc(module.n(), 1, ct_size); - module.vec_znx_big_normalize(base2k, &mut res, 0, base2k, &buf_big, 0, scratch.borrow()); + module.vec_znx_big_normalize(&mut res, base2k, 0, 0, &buf_big, base2k, 0, scratch.borrow()); // have = m * 2^{log_scale} + e let mut have: Vec = vec![i64::default(); n]; res.decode_vec_i64(base2k, 0, ct_size * base2k, &mut have); let scale: f64 = (1 << (res.size() * base2k - log_scale)) as f64; - izip!(want.iter(), have.iter()) - .enumerate() - .for_each(|(i, (a, b))| { - println!("{}: {} {}", i, a, (*b as f64) / scale); - }); + izip!(want.iter(), have.iter()).enumerate().for_each(|(i, (a, b))| { + println!("{}: {} {}", i, a, (*b as f64) / scale); + }); } diff --git a/poulpy-cpu-avx/src/convolution.rs b/poulpy-cpu-avx/src/convolution.rs new file mode 100644 index 0000000..9dddd28 --- /dev/null +++ b/poulpy-cpu-avx/src/convolution.rs @@ -0,0 +1,401 @@ +use poulpy_hal::{ + api::{Convolution, ModuleN, ScratchTakeBasic, TakeSlice, VecZnxDftApply, VecZnxDftBytesOf}, + layouts::{ + Backend, CnvPVecL, CnvPVecLToMut, CnvPVecLToRef, CnvPVecR, CnvPVecRToMut, CnvPVecRToRef, Module, Scratch, VecZnx, + VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftToMut, VecZnxToRef, ZnxInfos, + }, + oep::{CnvPVecBytesOfImpl, CnvPVecLAllocImpl, ConvolutionImpl}, + reference::fft64::convolution::{ + convolution_apply_dft, convolution_apply_dft_tmp_bytes, convolution_by_const_apply, convolution_by_const_apply_tmp_bytes, + convolution_pairwise_apply_dft, convolution_pairwise_apply_dft_tmp_bytes, convolution_prepare_left, + convolution_prepare_right, + }, +}; + +use crate::{FFT64Avx, module::FFT64ModuleHandle}; + +unsafe impl CnvPVecLAllocImpl for FFT64Avx { + fn cnv_pvec_left_alloc_impl(n: usize, cols: usize, size: usize) -> CnvPVecL, Self> { + CnvPVecL::alloc(n, cols, size) + } + + fn cnv_pvec_right_alloc_impl(n: usize, cols: usize, size: usize) -> CnvPVecR, Self> { + CnvPVecR::alloc(n, cols, size) + } +} + +unsafe impl CnvPVecBytesOfImpl for FFT64Avx { + fn bytes_of_cnv_pvec_left_impl(n: usize, cols: usize, size: usize) -> usize { + Self::layout_prep_word_count() * n * cols * size * size_of::<::ScalarPrep>() + } + + fn bytes_of_cnv_pvec_right_impl(n: usize, cols: usize, size: usize) -> usize { + Self::layout_prep_word_count() * n * cols * size * size_of::<::ScalarPrep>() + } +} + +unsafe impl ConvolutionImpl for FFT64Avx +where + Module: ModuleN + VecZnxDftBytesOf + VecZnxDftApply, +{ + fn cnv_prepare_left_tmp_bytes_impl(module: &Module, res_size: usize, a_size: usize) -> usize { + module.bytes_of_vec_znx_dft(1, res_size.min(a_size)) + } + + fn cnv_prepare_left_impl(module: &Module, res: &mut R, a: &A, scratch: &mut Scratch) + where + R: CnvPVecLToMut, + A: VecZnxToRef, + { + let res: &mut CnvPVecL<&mut [u8], FFT64Avx> = &mut res.to_mut(); + let a: &VecZnx<&[u8]> = &a.to_ref(); + let (mut tmp, _) = scratch.take_vec_znx_dft(module, 1, res.size().min(a.size())); + convolution_prepare_left(module.get_fft_table(), res, a, &mut tmp); + } + + fn cnv_prepare_right_tmp_bytes_impl(module: &Module, res_size: usize, a_size: usize) -> usize { + module.bytes_of_vec_znx_dft(1, res_size.min(a_size)) + } + + fn cnv_prepare_right_impl(module: &Module, res: &mut R, a: &A, scratch: &mut Scratch) + where + R: CnvPVecRToMut, + A: VecZnxToRef, + { + let res: &mut CnvPVecR<&mut [u8], FFT64Avx> = &mut res.to_mut(); + let a: &VecZnx<&[u8]> = &a.to_ref(); + let (mut tmp, _) = scratch.take_vec_znx_dft(module, 1, res.size().min(a.size())); + convolution_prepare_right(module.get_fft_table(), res, a, &mut tmp); + } + 
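Before the apply-stage methods that follow, a quick orientation on how these pieces fit together end-to-end. The sketch below is illustrative and not part of the patch: it assumes the `Convolution` trait on `Module<FFT64Avx>` exposes the same methods as the `*_impl` functions in this file (minus the `_impl` suffix, as the changelog's trait listing suggests), and import paths are approximate. It reuses `CnvPVecL::alloc`, `take_vec_znx_dft`, and the `|` tmp-bytes upper-bound convention exactly as they appear in the surrounding code.

```rust
use poulpy_cpu_avx::FFT64Avx;
use poulpy_hal::{
    api::{Convolution, ModuleN, ScratchOwnedAlloc, ScratchOwnedBorrow, ScratchTakeBasic, VecZnxDftBytesOf},
    layouts::{CnvPVecL, CnvPVecR, Module, ScratchOwned, VecZnx},
};

// Sketch: prepare both operands once, then convolve column 0 with column 0
// in the DFT domain. res_size = a.size() + b.size() is an upper bound on
// the number of nonzero output limbs.
fn convolve_once(module: &Module<FFT64Avx>, a: &VecZnx<Vec<u8>>, b: &VecZnx<Vec<u8>>) {
    let res_size: usize = a.size() + b.size();

    // Left and right operands use distinct prepared layouts.
    let mut a_prep: CnvPVecL<Vec<u8>, FFT64Avx> = CnvPVecL::alloc(module.n(), 1, a.size());
    let mut b_prep: CnvPVecR<Vec<u8>, FFT64Avx> = CnvPVecR::alloc(module.n(), 1, b.size());

    // Upper-bound the scratch with the `|` convention the tests use, plus
    // room for the DFT-domain result taken from scratch below.
    let mut scratch: ScratchOwned<FFT64Avx> = ScratchOwned::alloc(
        module.cnv_prepare_left_tmp_bytes(a_prep.size(), a.size())
            | module.cnv_prepare_right_tmp_bytes(b_prep.size(), b.size())
            | (module.bytes_of_vec_znx_dft(1, res_size)
                + module.cnv_apply_dft_tmp_bytes(res_size, 0, a_prep.size(), b_prep.size())),
    );

    module.cnv_prepare_left(&mut a_prep, a, scratch.borrow());
    module.cnv_prepare_right(&mut b_prep, b, scratch.borrow());

    // res[0] <- a_prep[0] (x) b_prep[0], with no output limb offset.
    let (mut res, scratch2) = scratch.borrow().take_vec_znx_dft(module, 1, res_size);
    module.cnv_apply_dft(&mut res, 0, 0, &a_prep, 0, &b_prep, 0, scratch2);
    let _ = res; // downstream: IDFT and normalization via the big/normalize APIs
}
```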
fn cnv_apply_dft_tmp_bytes_impl( + _module: &Module, + res_size: usize, + _res_offset: usize, + a_size: usize, + b_size: usize, + ) -> usize { + convolution_apply_dft_tmp_bytes(res_size, a_size, b_size) + } + + fn cnv_by_const_apply_tmp_bytes_impl( + _module: &Module, + res_size: usize, + _res_offset: usize, + a_size: usize, + b_size: usize, + ) -> usize { + convolution_by_const_apply_tmp_bytes(res_size, a_size, b_size) + } + + fn cnv_by_const_apply_impl( + module: &Module, + res: &mut R, + res_offset: usize, + res_col: usize, + a: &A, + a_col: usize, + b: &[i64], + scratch: &mut Scratch, + ) where + R: VecZnxBigToMut, + A: VecZnxToRef, + { + let res: &mut VecZnxBig<&mut [u8], Self> = &mut res.to_mut(); + let a: &VecZnx<&[u8]> = &a.to_ref(); + let (tmp, _) = + scratch.take_slice(module.cnv_by_const_apply_tmp_bytes(res.size(), res_offset, a.size(), b.len()) / size_of::()); + convolution_by_const_apply(res, res_offset, res_col, a, a_col, b, tmp); + } + + fn cnv_apply_dft_impl( + module: &Module, + res: &mut R, + res_offset: usize, + res_col: usize, + a: &A, + a_col: usize, + b: &B, + b_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxDftToMut, + A: CnvPVecLToRef, + B: CnvPVecRToRef, + { + let res: &mut VecZnxDft<&mut [u8], FFT64Avx> = &mut res.to_mut(); + let a: &CnvPVecL<&[u8], FFT64Avx> = &a.to_ref(); + let b: &CnvPVecR<&[u8], FFT64Avx> = &b.to_ref(); + let (tmp, _) = + scratch.take_slice(module.cnv_apply_dft_tmp_bytes(res.size(), res_offset, a.size(), b.size()) / size_of::()); + convolution_apply_dft(res, res_offset, res_col, a, a_col, b, b_col, tmp); + } + + fn cnv_pairwise_apply_dft_tmp_bytes( + _module: &Module, + res_size: usize, + _res_offset: usize, + a_size: usize, + b_size: usize, + ) -> usize { + convolution_pairwise_apply_dft_tmp_bytes(res_size, a_size, b_size) + } + + fn cnv_pairwise_apply_dft_impl( + module: &Module, + res: &mut R, + res_offset: usize, + res_col: usize, + a: &A, + b: &B, + i: usize, + j: usize, + scratch: &mut Scratch, + ) where + R: VecZnxDftToMut, + A: CnvPVecLToRef, + B: CnvPVecRToRef, + { + let res: &mut VecZnxDft<&mut [u8], FFT64Avx> = &mut res.to_mut(); + let a: &CnvPVecL<&[u8], FFT64Avx> = &a.to_ref(); + let b: &CnvPVecR<&[u8], FFT64Avx> = &b.to_ref(); + let (tmp, _) = scratch + .take_slice(module.cnv_pairwise_apply_dft_tmp_bytes(res.size(), res_offset, a.size(), b.size()) / size_of::()); + convolution_pairwise_apply_dft(res, res_offset, res_col, a, b, i, j, tmp); + } +} + +/// # Safety +/// Caller must ensure the CPU supports AVX2. +/// Assumes all inputs fit in i32 (so i32×i32→i64 is exact). 
+#[target_feature(enable = "avx2")] +pub unsafe fn i64_convolution_by_const_1coeff_avx(k: usize, dst: &mut [i64; 8], a: &[i64], a_size: usize, b: &[i64]) { + use core::arch::x86_64::{ + __m256i, _mm256_add_epi64, _mm256_loadu_si256, _mm256_mul_epi32, _mm256_set1_epi32, _mm256_setzero_si256, + _mm256_storeu_si256, + }; + + dst.fill(0); + + let b_size = b.len(); + if k >= a_size + b_size { + return; + } + + let j_min = k.saturating_sub(a_size - 1); + let j_max = (k + 1).min(b_size); + + unsafe { + // Two accumulators = 8 outputs total + let mut acc_lo: __m256i = _mm256_setzero_si256(); // dst[0..4) + let mut acc_hi: __m256i = _mm256_setzero_si256(); // dst[4..8) + + let mut a_ptr: *const i64 = a.as_ptr().add(8 * (k - j_min)); + let mut b_ptr: *const i64 = b.as_ptr().add(j_min); + + for _ in 0..(j_max - j_min) { + // Broadcast scalar b[j] as i32 + let br: __m256i = _mm256_set1_epi32(*b_ptr as i32); + + // ---- lower half: a[0..4) ---- + let a_lo: __m256i = _mm256_loadu_si256(a_ptr as *const __m256i); + + let prod_lo: __m256i = _mm256_mul_epi32(a_lo, br); + + acc_lo = _mm256_add_epi64(acc_lo, prod_lo); + + // ---- upper half: a[4..8) ---- + let a_hi: __m256i = _mm256_loadu_si256(a_ptr.add(4) as *const __m256i); + + let prod_hi: __m256i = _mm256_mul_epi32(a_hi, br); + + acc_hi = _mm256_add_epi64(acc_hi, prod_hi); + + a_ptr = a_ptr.sub(8); + b_ptr = b_ptr.add(1); + } + + // Store final result + _mm256_storeu_si256(dst.as_mut_ptr() as *mut __m256i, acc_lo); + _mm256_storeu_si256(dst.as_mut_ptr().add(4) as *mut __m256i, acc_hi); + } +} + +/// # Safety +/// Caller must ensure the CPU supports AVX2. +/// Assumes all values in `a` and `b` fit in i32 (so i32×i32→i64 is exact). +#[target_feature(enable = "avx2")] +pub unsafe fn i64_convolution_by_real_const_2coeffs_avx( + k: usize, + dst: &mut [i64; 16], + a: &[i64], + a_size: usize, + b: &[i64], // real scalars, stride-1 +) { + use core::arch::x86_64::{ + __m256i, _mm256_add_epi64, _mm256_loadu_si256, _mm256_mul_epi32, _mm256_set1_epi32, _mm256_setzero_si256, + _mm256_storeu_si256, + }; + + let b_size: usize = b.len(); + + debug_assert!(a.len() >= 8 * a_size); + + let k0: usize = k; + let k1: usize = k + 1; + let bound: usize = a_size + b_size; + + if k0 >= bound { + unsafe { + let zero: __m256i = _mm256_setzero_si256(); + let dst_ptr: *mut i64 = dst.as_mut_ptr(); + _mm256_storeu_si256(dst_ptr as *mut __m256i, zero); + _mm256_storeu_si256(dst_ptr.add(4) as *mut __m256i, zero); + _mm256_storeu_si256(dst_ptr.add(8) as *mut __m256i, zero); + _mm256_storeu_si256(dst_ptr.add(12) as *mut __m256i, zero); + } + return; + } + + unsafe { + let mut acc_lo_k0: __m256i = _mm256_setzero_si256(); + let mut acc_hi_k0: __m256i = _mm256_setzero_si256(); + let mut acc_lo_k1: __m256i = _mm256_setzero_si256(); + let mut acc_hi_k1: __m256i = _mm256_setzero_si256(); + + let j0_min: usize = (k0 + 1).saturating_sub(a_size); + let j0_max: usize = (k0 + 1).min(b_size); + + if k1 >= bound { + let mut a_k0_ptr: *const i64 = a.as_ptr().add(8 * (k0 - j0_min)); + let mut b_ptr: *const i64 = b.as_ptr().add(j0_min); + + // Contributions to k0 only + for _ in 0..j0_max - j0_min { + // Broadcast b[j] as i32 + let br: __m256i = _mm256_set1_epi32(*b_ptr as i32); + + // Load 4×i64 (low half) and 4×i64 (high half) + let a_lo_k0: __m256i = _mm256_loadu_si256(a_k0_ptr as *const __m256i); + let a_hi_k0: __m256i = _mm256_loadu_si256(a_k0_ptr.add(4) as *const __m256i); + + acc_lo_k0 = _mm256_add_epi64(acc_lo_k0, _mm256_mul_epi32(a_lo_k0, br)); + acc_hi_k0 = _mm256_add_epi64(acc_hi_k0, 
_mm256_mul_epi32(a_hi_k0, br)); + + a_k0_ptr = a_k0_ptr.sub(8); + b_ptr = b_ptr.add(1); + } + } else { + let j1_min: usize = (k1 + 1).saturating_sub(a_size); + let j1_max: usize = (k1 + 1).min(b_size); + + let mut a_k0_ptr: *const i64 = a.as_ptr().add(8 * (k0 - j0_min)); + let mut a_k1_ptr: *const i64 = a.as_ptr().add(8 * (k1 - j1_min)); + let mut b_ptr: *const i64 = b.as_ptr().add(j0_min); + + // Region 1: k0 only, j ∈ [j0_min, j1_min) + for _ in 0..j1_min - j0_min { + let br: __m256i = _mm256_set1_epi32(*b_ptr as i32); + + let a_k0_lo: __m256i = _mm256_loadu_si256(a_k0_ptr as *const __m256i); + let a_k0_hi: __m256i = _mm256_loadu_si256(a_k0_ptr.add(4) as *const __m256i); + + acc_lo_k0 = _mm256_add_epi64(acc_lo_k0, _mm256_mul_epi32(a_k0_lo, br)); + acc_hi_k0 = _mm256_add_epi64(acc_hi_k0, _mm256_mul_epi32(a_k0_hi, br)); + + a_k0_ptr = a_k0_ptr.sub(8); + b_ptr = b_ptr.add(1); + } + + // Region 2: overlap, contributions to both k0 and k1, j ∈ [j1_min, j0_max) + // Save one load on b: broadcast once and reuse. + for _ in 0..j0_max - j1_min { + let br: __m256i = _mm256_set1_epi32(*b_ptr as i32); + + let a_lo_k0: __m256i = _mm256_loadu_si256(a_k0_ptr as *const __m256i); + let a_hi_k0: __m256i = _mm256_loadu_si256(a_k0_ptr.add(4) as *const __m256i); + let a_lo_k1: __m256i = _mm256_loadu_si256(a_k1_ptr as *const __m256i); + let a_hi_k1: __m256i = _mm256_loadu_si256(a_k1_ptr.add(4) as *const __m256i); + + // k0 + acc_lo_k0 = _mm256_add_epi64(acc_lo_k0, _mm256_mul_epi32(a_lo_k0, br)); + acc_hi_k0 = _mm256_add_epi64(acc_hi_k0, _mm256_mul_epi32(a_hi_k0, br)); + + // k1 + acc_lo_k1 = _mm256_add_epi64(acc_lo_k1, _mm256_mul_epi32(a_lo_k1, br)); + acc_hi_k1 = _mm256_add_epi64(acc_hi_k1, _mm256_mul_epi32(a_hi_k1, br)); + + a_k0_ptr = a_k0_ptr.sub(8); + a_k1_ptr = a_k1_ptr.sub(8); + b_ptr = b_ptr.add(1); + } + + // Region 3: k1 only, j ∈ [j0_max, j1_max) + for _ in 0..j1_max - j0_max { + let br: __m256i = _mm256_set1_epi32(*b_ptr as i32); + + let a_lo_k1: __m256i = _mm256_loadu_si256(a_k1_ptr as *const __m256i); + let a_hi_k1: __m256i = _mm256_loadu_si256(a_k1_ptr.add(4) as *const __m256i); + + acc_lo_k1 = _mm256_add_epi64(acc_lo_k1, _mm256_mul_epi32(a_lo_k1, br)); + acc_hi_k1 = _mm256_add_epi64(acc_hi_k1, _mm256_mul_epi32(a_hi_k1, br)); + + a_k1_ptr = a_k1_ptr.sub(8); + b_ptr = b_ptr.add(1); + } + } + + let dst_ptr: *mut i64 = dst.as_mut_ptr(); + _mm256_storeu_si256(dst_ptr as *mut __m256i, acc_lo_k0); + _mm256_storeu_si256(dst_ptr.add(4) as *mut __m256i, acc_hi_k0); + _mm256_storeu_si256(dst_ptr.add(8) as *mut __m256i, acc_lo_k1); + _mm256_storeu_si256(dst_ptr.add(12) as *mut __m256i, acc_hi_k1); + } +} + +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +#[target_feature(enable = "avx")] +pub fn i64_extract_1blk_contiguous_avx(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) { + use core::arch::x86_64::{__m256i, _mm256_loadu_si256, _mm256_storeu_si256}; + + unsafe { + let mut src_ptr: *const __m256i = src.as_ptr().add(offset + (blk << 3)) as *const __m256i; // src + 8*blk + let mut dst_ptr: *mut __m256i = dst.as_mut_ptr() as *mut __m256i; + + let step: usize = n >> 2; + + // Each iteration copies 8 i64; advance src by n i64 each row + for _ in 0..rows { + let v: __m256i = _mm256_loadu_si256(src_ptr); + _mm256_storeu_si256(dst_ptr, v); + let v: __m256i = _mm256_loadu_si256(src_ptr.add(1)); + _mm256_storeu_si256(dst_ptr.add(1), v); + dst_ptr = dst_ptr.add(2); + src_ptr = src_ptr.add(step); + } + } +} + +/// # Safety 
+/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +#[target_feature(enable = "avx")] +pub fn i64_save_1blk_contiguous_avx(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) { + use core::arch::x86_64::{__m256i, _mm256_loadu_si256, _mm256_storeu_si256}; + + unsafe { + let mut src_ptr: *const __m256i = src.as_ptr() as *const __m256i; + let mut dst_ptr: *mut __m256i = dst.as_mut_ptr().add(offset + (blk << 3)) as *mut __m256i; // dst + 8*blk + + let step: usize = n >> 2; + + // Each iteration copies 8 i64; advance dst by n i64 each row + for _ in 0..rows { + let v: __m256i = _mm256_loadu_si256(src_ptr); + _mm256_storeu_si256(dst_ptr, v); + let v: __m256i = _mm256_loadu_si256(src_ptr.add(1)); + _mm256_storeu_si256(dst_ptr.add(1), v); + dst_ptr = dst_ptr.add(step); + src_ptr = src_ptr.add(2); + } + } +} diff --git a/poulpy-cpu-avx/src/lib.rs b/poulpy-cpu-avx/src/lib.rs index 9a139f6..9dde9b8 100644 --- a/poulpy-cpu-avx/src/lib.rs +++ b/poulpy-cpu-avx/src/lib.rs @@ -1,7 +1,7 @@ // ───────────────────────────────────────────────────────────── // Build the backend **only when ALL conditions are satisfied** // ───────────────────────────────────────────────────────────── -#![cfg(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))] +//#![cfg(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))] // If the user enables this backend but targets a non-x86_64 CPU → abort #[cfg(all(feature = "enable-avx", not(target_arch = "x86_64")))] @@ -15,6 +15,7 @@ compile_error!("feature `enable-avx` requires AVX2. Build with RUSTFLAGS=\"-C ta #[cfg(all(feature = "enable-avx", target_arch = "x86_64", not(target_feature = "fma")))] compile_error!("feature `enable-avx` requires FMA. 
Build with RUSTFLAGS=\"-C target-feature=+fma\"."); +mod convolution; mod module; mod reim; mod reim4; diff --git a/poulpy-cpu-avx/src/module.rs b/poulpy-cpu-avx/src/module.rs index 9495257..7f0a238 100644 --- a/poulpy-cpu-avx/src/module.rs +++ b/poulpy-cpu-avx/src/module.rs @@ -5,13 +5,18 @@ use poulpy_hal::{ oep::ModuleNewImpl, reference::{ fft64::{ + convolution::{ + I64ConvolutionByConst1Coeff, I64ConvolutionByConst2Coeffs, I64Extract1BlkContiguous, I64Save1BlkContiguous, + }, reim::{ ReimAdd, ReimAddInplace, ReimAddMul, ReimCopy, ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimIFFTTable, ReimMul, ReimMulInplace, ReimNegate, ReimNegateInplace, ReimSub, ReimSubInplace, ReimSubNegateInplace, ReimToZnx, ReimToZnxInplace, ReimZero, reim_copy_ref, reim_zero_ref, }, reim4::{ - Reim4Extract1Blk, Reim4Mat1ColProd, Reim4Mat2Cols2ndColProd, Reim4Mat2ColsProd, Reim4Save1Blk, Reim4Save2Blks, + Reim4Convolution1Coeff, Reim4Convolution2Coeffs, Reim4ConvolutionByRealConst1Coeff, + Reim4ConvolutionByRealConst2Coeffs, Reim4Extract1BlkContiguous, Reim4Mat1ColProd, Reim4Mat2Cols2ndColProd, + Reim4Mat2ColsProd, Reim4Save1Blk, Reim4Save1BlkContiguous, Reim4Save2Blks, }, }, znx::{ @@ -26,6 +31,10 @@ use poulpy_hal::{ use crate::{ FFT64Avx, + convolution::{ + i64_convolution_by_const_1coeff_avx, i64_convolution_by_real_const_2coeffs_avx, i64_extract_1blk_contiguous_avx, + i64_save_1blk_contiguous_avx, + }, reim::{ ReimFFTAvx, ReimIFFTAvx, reim_add_avx2_fma, reim_add_inplace_avx2_fma, reim_addmul_avx2_fma, reim_from_znx_i64_bnd50_fma, reim_mul_avx2_fma, reim_mul_inplace_avx2_fma, reim_negate_avx2_fma, reim_negate_inplace_avx2_fma, reim_sub_avx2_fma, @@ -33,8 +42,10 @@ use crate::{ }, reim_to_znx_i64_bnd63_avx2_fma, reim4::{ - reim4_extract_1blk_from_reim_avx, reim4_save_1blk_to_reim_avx, reim4_save_2blk_to_reim_avx, - reim4_vec_mat1col_product_avx, reim4_vec_mat2cols_2ndcol_product_avx, reim4_vec_mat2cols_product_avx, + reim4_convolution_1coeff_avx, reim4_convolution_2coeffs_avx, reim4_convolution_by_real_const_1coeff_avx, + reim4_convolution_by_real_const_2coeffs_avx, reim4_extract_1blk_from_reim_contiguous_avx, reim4_save_1blk_to_reim_avx, + reim4_save_1blk_to_reim_contiguous_avx, reim4_save_2blk_to_reim_avx, reim4_vec_mat1col_product_avx, + reim4_vec_mat2cols_2ndcol_product_avx, reim4_vec_mat2cols_product_avx, }, znx_avx::{ znx_add_avx, znx_add_inplace_avx, znx_automorphism_avx, znx_extract_digit_addmul_avx, znx_mul_add_power_of_two_avx, @@ -470,11 +481,55 @@ impl ReimZero for FFT64Avx { } } -impl Reim4Extract1Blk for FFT64Avx { +impl Reim4Convolution1Coeff for FFT64Avx { #[inline(always)] - fn reim4_extract_1blk(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) { + fn reim4_convolution_1coeff(k: usize, dst: &mut [f64; 8], a: &[f64], a_size: usize, b: &[f64], b_size: usize) { unsafe { - reim4_extract_1blk_from_reim_avx(m, rows, blk, dst, src); + reim4_convolution_1coeff_avx(k, dst, a, a_size, b, b_size); + } + } +} + +impl Reim4Convolution2Coeffs for FFT64Avx { + #[inline(always)] + fn reim4_convolution_2coeffs(k: usize, dst: &mut [f64; 16], a: &[f64], a_size: usize, b: &[f64], b_size: usize) { + unsafe { + reim4_convolution_2coeffs_avx(k, dst, a, a_size, b, b_size); + } + } +} + +impl Reim4ConvolutionByRealConst1Coeff for FFT64Avx { + #[inline(always)] + fn reim4_convolution_by_real_const_1coeff(k: usize, dst: &mut [f64; 8], a: &[f64], a_size: usize, b: &[f64]) { + unsafe { + reim4_convolution_by_real_const_1coeff_avx(k, dst, a, a_size, b); + } + } +} + +impl 
Reim4ConvolutionByRealConst2Coeffs for FFT64Avx { + #[inline(always)] + fn reim4_convolution_by_real_const_2coeffs(k: usize, dst: &mut [f64; 16], a: &[f64], a_size: usize, b: &[f64]) { + unsafe { + reim4_convolution_by_real_const_2coeffs_avx(k, dst, a, a_size, b); + } + } +} + +impl Reim4Extract1BlkContiguous for FFT64Avx { + #[inline(always)] + fn reim4_extract_1blk_contiguous(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) { + unsafe { + reim4_extract_1blk_from_reim_contiguous_avx(m, rows, blk, dst, src); + } + } +} + +impl Reim4Save1BlkContiguous for FFT64Avx { + fn reim4_save_1blk_contiguous(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) { + unsafe { + reim4_save_1blk_to_reim_contiguous_avx(m, rows, blk, dst, src); } } } @@ -523,3 +578,39 @@ impl Reim4Mat2Cols2ndColProd for FFT64Avx { } } } + +impl I64ConvolutionByConst1Coeff for FFT64Avx { + #[inline(always)] + fn i64_convolution_by_const_1coeff(k: usize, dst: &mut [i64; 8], a: &[i64], a_size: usize, b: &[i64]) { + unsafe { + i64_convolution_by_const_1coeff_avx(k, dst, a, a_size, b); + } + } +} + +impl I64ConvolutionByConst2Coeffs for FFT64Avx { + #[inline(always)] + fn i64_convolution_by_const_2coeffs(k: usize, dst: &mut [i64; 16], a: &[i64], a_size: usize, b: &[i64]) { + unsafe { + i64_convolution_by_real_const_2coeffs_avx(k, dst, a, a_size, b); + } + } +} + +impl I64Save1BlkContiguous for FFT64Avx { + #[inline(always)] + fn i64_save_1blk_contiguous(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) { + unsafe { + i64_save_1blk_contiguous_avx(n, offset, rows, blk, dst, src); + } + } +} + +impl I64Extract1BlkContiguous for FFT64Avx { + #[inline(always)] + fn i64_extract_1blk_contiguous(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) { + unsafe { + i64_extract_1blk_contiguous_avx(n, offset, rows, blk, dst, src); + } + } +} diff --git a/poulpy-cpu-avx/src/reim/fft_avx2_fma.rs b/poulpy-cpu-avx/src/reim/fft_avx2_fma.rs index ce2446f..f6d0949 100644 --- a/poulpy-cpu-avx/src/reim/fft_avx2_fma.rs +++ b/poulpy-cpu-avx/src/reim/fft_avx2_fma.rs @@ -18,11 +18,7 @@ pub(crate) fn fft_avx2_fma(m: usize, omg: &[f64], data: &mut [f64]) { let (re, im) = data.split_at_mut(m); if m == 16 { - fft16_avx2_fma( - as_arr_mut::<16, f64>(re), - as_arr_mut::<16, f64>(im), - as_arr::<16, f64>(omg), - ) + fft16_avx2_fma(as_arr_mut::<16, f64>(re), as_arr_mut::<16, f64>(im), as_arr::<16, f64>(omg)) } else if m <= 2048 { fft_bfs_16_avx2_fma(m, re, im, omg, 0); } else { @@ -70,12 +66,7 @@ fn fft_bfs_16_avx2_fma(m: usize, re: &mut [f64], im: &mut [f64], omg: &[f64], mu while mm > 16 { let h: usize = mm >> 2; for off in (0..m).step_by(mm) { - bitwiddle_fft_avx2_fma( - h, - &mut re[off..], - &mut im[off..], - as_arr::<4, f64>(&omg[pos..]), - ); + bitwiddle_fft_avx2_fma(h, &mut re[off..], &mut im[off..], as_arr::<4, f64>(&omg[pos..])); pos += 4; } @@ -232,16 +223,10 @@ fn test_fft_avx2_fma() { let mut values_0: Vec = vec![0f64; m << 1]; let scale: f64 = 1.0f64 / m as f64; - values_0 - .iter_mut() - .enumerate() - .for_each(|(i, x)| *x = (i + 1) as f64 * scale); + values_0.iter_mut().enumerate().for_each(|(i, x)| *x = (i + 1) as f64 * scale); let mut values_1: Vec = vec![0f64; m << 1]; - values_1 - .iter_mut() - .zip(values_0.iter()) - .for_each(|(y, x)| *y = *x); + values_1.iter_mut().zip(values_0.iter()).for_each(|(y, x)| *y = *x); ReimFFTAvx::reim_dft_execute(&table, &mut values_0); ReimFFTRef::reim_dft_execute(&table, &mut values_1); @@ -250,14 +235,7 @@ fn 
test_fft_avx2_fma() { for i in 0..m * 2 { let diff: f64 = (values_0[i] - values_1[i]).abs(); - assert!( - diff <= max_diff, - "{} -> {}-{} = {}", - i, - values_0[i], - values_1[i], - diff - ) + assert!(diff <= max_diff, "{} -> {}-{} = {}", i, values_0[i], values_1[i], diff) } } diff --git a/poulpy-cpu-avx/src/reim/ifft_avx2_fma.rs b/poulpy-cpu-avx/src/reim/ifft_avx2_fma.rs index 776396c..726d683 100644 --- a/poulpy-cpu-avx/src/reim/ifft_avx2_fma.rs +++ b/poulpy-cpu-avx/src/reim/ifft_avx2_fma.rs @@ -17,11 +17,7 @@ pub(crate) fn ifft_avx2_fma(m: usize, omg: &[f64], data: &mut [f64]) { let (re, im) = data.split_at_mut(m); if m == 16 { - ifft16_avx2_fma( - as_arr_mut::<16, f64>(re), - as_arr_mut::<16, f64>(im), - as_arr::<16, f64>(omg), - ) + ifft16_avx2_fma(as_arr_mut::<16, f64>(re), as_arr_mut::<16, f64>(im), as_arr::<16, f64>(omg)) } else if m <= 2048 { ifft_bfs_16_avx2_fma(m, re, im, omg, 0); } else { @@ -72,12 +68,7 @@ fn ifft_bfs_16_avx2_fma(m: usize, re: &mut [f64], im: &mut [f64], omg: &[f64], m while h < m_half { let mm: usize = h << 2; for off in (0..m).step_by(mm) { - inv_bitwiddle_ifft_avx2_fma( - h, - &mut re[off..], - &mut im[off..], - as_arr::<4, f64>(&omg[pos..]), - ); + inv_bitwiddle_ifft_avx2_fma(h, &mut re[off..], &mut im[off..], as_arr::<4, f64>(&omg[pos..])); pos += 4; } h = mm; @@ -225,16 +216,10 @@ fn test_ifft_avx2_fma() { let mut values_0: Vec = vec![0f64; m << 1]; let scale: f64 = 1.0f64 / m as f64; - values_0 - .iter_mut() - .enumerate() - .for_each(|(i, x)| *x = (i + 1) as f64 * scale); + values_0.iter_mut().enumerate().for_each(|(i, x)| *x = (i + 1) as f64 * scale); let mut values_1: Vec = vec![0f64; m << 1]; - values_1 - .iter_mut() - .zip(values_0.iter()) - .for_each(|(y, x)| *y = *x); + values_1.iter_mut().zip(values_0.iter()).for_each(|(y, x)| *y = *x); ReimIFFTAvx::reim_dft_execute(&table, &mut values_0); ReimIFFTRef::reim_dft_execute(&table, &mut values_1); @@ -243,14 +228,7 @@ fn test_ifft_avx2_fma() { for i in 0..m * 2 { let diff: f64 = (values_0[i] - values_1[i]).abs(); - assert!( - diff <= max_diff, - "{} -> {}-{} = {}", - i, - values_0[i], - values_1[i], - diff - ) + assert!(diff <= max_diff, "{} -> {}-{} = {}", i, values_0[i], values_1[i], diff) } } diff --git a/poulpy-cpu-avx/src/reim/mod.rs b/poulpy-cpu-avx/src/reim/mod.rs index ae2d67a..e488f56 100644 --- a/poulpy-cpu-avx/src/reim/mod.rs +++ b/poulpy-cpu-avx/src/reim/mod.rs @@ -32,10 +32,7 @@ use rand_distr::num_traits::{Float, FloatConst}; use crate::reim::{fft_avx2_fma::fft_avx2_fma, ifft_avx2_fma::ifft_avx2_fma}; -global_asm!( - include_str!("fft16_avx2_fma.s"), - include_str!("ifft16_avx2_fma.s") -); +global_asm!(include_str!("fft16_avx2_fma.s"), include_str!("ifft16_avx2_fma.s")); #[inline(always)] pub(crate) fn as_arr(x: &[R]) -> &[R; SIZE] { diff --git a/poulpy-cpu-avx/src/reim4/arithmetic_avx.rs b/poulpy-cpu-avx/src/reim4/arithmetic_avx.rs index 8882794..f46ced8 100644 --- a/poulpy-cpu-avx/src/reim4/arithmetic_avx.rs +++ b/poulpy-cpu-avx/src/reim4/arithmetic_avx.rs @@ -1,7 +1,7 @@ /// # Safety /// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); #[target_feature(enable = "avx")] -pub fn reim4_extract_1blk_from_reim_avx(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) { +pub fn reim4_extract_1blk_from_reim_contiguous_avx(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) { use core::arch::x86_64::{__m256d, _mm256_loadu_pd, _mm256_storeu_pd}; unsafe { @@ -20,6 +20,28 @@ pub fn reim4_extract_1blk_from_reim_avx(m: usize, rows: 
usize, blk: usize, dst: } } +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +#[target_feature(enable = "avx")] +pub fn reim4_save_1blk_to_reim_contiguous_avx(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) { + use core::arch::x86_64::{__m256d, _mm256_loadu_pd, _mm256_storeu_pd}; + + unsafe { + let mut src_ptr: *const __m256d = src.as_ptr() as *const __m256d; + let mut dst_ptr: *mut __m256d = dst.as_mut_ptr().add(blk << 2) as *mut __m256d; // dst + 4*blk + + let step: usize = m >> 2; + + // Each iteration copies 4 doubles; advance dst by m doubles each row + for _ in 0..2 * rows { + let v: __m256d = _mm256_loadu_pd(src_ptr as *const f64); + _mm256_storeu_pd(dst_ptr as *mut f64, v); + dst_ptr = dst_ptr.add(step); + src_ptr = src_ptr.add(1); + } + } +} + /// # Safety /// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); #[target_feature(enable = "avx2,fma")] @@ -148,11 +170,7 @@ pub fn reim4_vec_mat2cols_product_avx(nrows: usize, dst: &mut [f64], u: &[f64], #[cfg(debug_assertions)] { - assert!( - dst.len() >= 8, - "dst must be at least 8 doubles but is {}", - dst.len() - ); + assert!(dst.len() >= 8, "dst must be at least 8 doubles but is {}", dst.len()); assert!( u.len() >= nrows * 8, "u must be at least nrows={} * 8 doubles but is {}", @@ -185,16 +203,16 @@ pub fn reim4_vec_mat2cols_product_avx(nrows: usize, dst: &mut [f64], u: &[f64], let br: __m256d = _mm256_loadu_pd(v_ptr.add(8)); let bi: __m256d = _mm256_loadu_pd(v_ptr.add(12)); - // re1 = re1 - ui*ai; re2 = re2 - ui*bi; + // re1 = ui*ai - re1; re2 = ui*bi - re2; re1 = _mm256_fmsub_pd(ui, ai, re1); re2 = _mm256_fmsub_pd(ui, bi, re2); - // im1 = im1 + ur*ai; im2 = im2 + ur*bi; + // im1 = ur*ai + im1; im2 = ur*bi + im2; im1 = _mm256_fmadd_pd(ur, ai, im1); im2 = _mm256_fmadd_pd(ur, bi, im2); - // re1 = re1 - ur*ar; re2 = re2 - ur*br; + // re1 = ur*ar - re1; re2 = ur*br - re2; re1 = _mm256_fmsub_pd(ur, ar, re1); re2 = _mm256_fmsub_pd(ur, br, re2); - // im1 = im1 + ui*ar; im2 = im2 + ui*br; + // im1 = ui*ar + im1; im2 = ui*br + im2; im1 = _mm256_fmadd_pd(ui, ar, im1); im2 = _mm256_fmadd_pd(ui, br, im2); @@ -219,10 +237,7 @@ pub fn reim4_vec_mat2cols_2ndcol_product_avx(nrows: usize, dst: &mut [f64], u: & { assert_eq!(dst.len(), 16, "dst must have 16 doubles"); assert!(u.len() >= nrows * 8, "u must be at least nrows * 8 doubles"); - assert!( - v.len() >= nrows * 16, - "v must be at least nrows * 16 doubles" - ); + assert!(v.len() >= nrows * 16, "v must be at least nrows * 16 doubles"); } unsafe { @@ -239,13 +254,13 @@ pub fn reim4_vec_mat2cols_2ndcol_product_avx(nrows: usize, dst: &mut [f64], u: & let ar: __m256d = _mm256_loadu_pd(v_ptr); let ai: __m256d = _mm256_loadu_pd(v_ptr.add(4)); - // re1 = re1 - ui*ai; re2 = re2 - ui*bi; + // re1 = ui*ai - re1; re1 = _mm256_fmsub_pd(ui, ai, re1); - // im1 = im1 + ur*ai; im2 = im2 + ur*bi; + // im1 = im1 + ur*ai; im1 = _mm256_fmadd_pd(ur, ai, im1); - // re1 = re1 - ur*ar; re2 = re2 - ur*br; + // re1 = ur*ar - re1; re1 = _mm256_fmsub_pd(ur, ar, re1); - // im1 = im1 + ui*ar; im2 = im2 + ui*br; + // im1 = im1 + ui*ar; im1 = _mm256_fmadd_pd(ui, ar, im1); u_ptr = u_ptr.add(8); @@ -256,3 +271,360 @@ pub fn reim4_vec_mat2cols_2ndcol_product_avx(nrows: usize, dst: &mut [f64], u: & _mm256_storeu_pd(dst.as_mut_ptr().add(4), im1); } } + +/// # Safety +/// Caller must ensure the CPU supports AVX2 and FMA (e.g. `is_x86_feature_detected!("avx2")`). 
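+/// Computes one output coefficient of the convolution, dst = sum over j of a[k - j] * b[j]
+/// (over the indices where both factors exist), where every coefficient is a reim4 block of
+/// 4 complex values stored as 4 re followed by 4 im doubles.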
+#[target_feature(enable = "avx2", enable = "fma")]
+pub unsafe fn reim4_convolution_1coeff_avx(k: usize, dst: &mut [f64; 8], a: &[f64], a_size: usize, b: &[f64], b_size: usize) {
+    use core::arch::x86_64::{__m256d, _mm256_fmadd_pd, _mm256_fmsub_pd, _mm256_loadu_pd, _mm256_setzero_pd, _mm256_storeu_pd};
+
+    unsafe {
+        // Out-of-range guard, same semantics as the reference implementation:
+        // past the last convolution coefficient the output block is zero.
+        if k >= a_size + b_size {
+            let zero: __m256d = _mm256_setzero_pd();
+            let dst_ptr: *mut f64 = dst.as_mut_ptr();
+            _mm256_storeu_pd(dst_ptr, zero);
+            _mm256_storeu_pd(dst_ptr.add(4), zero);
+            return;
+        }
+
+        let j_min: usize = k.saturating_sub(a_size - 1);
+        let j_max: usize = (k + 1).min(b_size);
+
+        // acc_re = dst[0..4], acc_im = dst[4..8]
+        let mut acc_re: __m256d = _mm256_setzero_pd();
+        let mut acc_im: __m256d = _mm256_setzero_pd();
+
+        let mut a_ptr: *const f64 = a.as_ptr().add(8 * (k - j_min));
+        let mut b_ptr: *const f64 = b.as_ptr().add(8 * j_min);
+
+        for _ in 0..j_max - j_min {
+            // Load a[(k - j)]
+            let ar: __m256d = _mm256_loadu_pd(a_ptr);
+            let ai: __m256d = _mm256_loadu_pd(a_ptr.add(4));
+
+            // Load b[j]
+            let br: __m256d = _mm256_loadu_pd(b_ptr);
+            let bi: __m256d = _mm256_loadu_pd(b_ptr.add(4));
+
+            // Complex multiply-accumulate; the two fmsub negations of acc_re cancel,
+            // so the net effect is acc_re += ar*br - ai*bi and acc_im += ar*bi + ai*br.
+            // acc_re = ai*bi - acc_re
+            acc_re = _mm256_fmsub_pd(ai, bi, acc_re);
+            // acc_im = ar*bi + acc_im
+            acc_im = _mm256_fmadd_pd(ar, bi, acc_im);
+            // acc_re = ar*br - acc_re
+            acc_re = _mm256_fmsub_pd(ar, br, acc_re);
+            // acc_im = acc_im + ai*br
+            acc_im = _mm256_fmadd_pd(ai, br, acc_im);
+
+            a_ptr = a_ptr.sub(8);
+            b_ptr = b_ptr.add(8);
+        }
+
+        // Store accumulators into dst
+        _mm256_storeu_pd(dst.as_mut_ptr(), acc_re);
+        _mm256_storeu_pd(dst.as_mut_ptr().add(4), acc_im);
+    }
+}
+
+/// # Safety
+/// Caller must ensure the CPU supports AVX2 and FMA (e.g. `is_x86_feature_detected!("avx2")`).
+#[target_feature(enable = "avx2", enable = "fma")]
+pub unsafe fn reim4_convolution_2coeffs_avx(k: usize, dst: &mut [f64; 16], a: &[f64], a_size: usize, b: &[f64], b_size: usize) {
+    use core::arch::x86_64::{__m256d, _mm256_fmadd_pd, _mm256_fnmadd_pd, _mm256_loadu_pd, _mm256_setzero_pd, _mm256_storeu_pd};
+
+    debug_assert!(a.len() >= 8 * a_size);
+    debug_assert!(b.len() >= 8 * b_size);
+
+    let k0: usize = k;
+    let k1: usize = k + 1;
+    let bound: usize = a_size + b_size;
+
+    // k is even, so if k0 is out of range then k1 > k0 is as well and both
+    // coefficients are zero. The case where only k1 is out of range is handled below.
+    if k0 >= bound {
+        unsafe {
+            let zero: __m256d = _mm256_setzero_pd();
+            let dst_ptr: *mut f64 = dst.as_mut_ptr();
+            _mm256_storeu_pd(dst_ptr, zero);
+            _mm256_storeu_pd(dst_ptr.add(4), zero);
+            _mm256_storeu_pd(dst_ptr.add(8), zero);
+            _mm256_storeu_pd(dst_ptr.add(12), zero);
+        }
+        return;
+    }
+
+    unsafe {
+        let mut acc_re_k0: __m256d = _mm256_setzero_pd();
+        let mut acc_im_k0: __m256d = _mm256_setzero_pd();
+        let mut acc_re_k1: __m256d = _mm256_setzero_pd();
+        let mut acc_im_k1: __m256d = _mm256_setzero_pd();
+
+        let j0_min: usize = (k0 + 1).saturating_sub(a_size);
+        let j0_max: usize = (k0 + 1).min(b_size);
+
+        if k1 >= bound {
+            let mut a_k0_ptr: *const f64 = a.as_ptr().add(8 * (k0 - j0_min));
+            let mut b_ptr: *const f64 = b.as_ptr().add(8 * j0_min);
+
+            // Only k0 is in range: accumulate its full range j ∈ [j0_min, j0_max);
+            // the k1 accumulators stay zero and are stored as such.
+            for _ in 0..j0_max - j0_min {
+                let ar: __m256d = _mm256_loadu_pd(a_k0_ptr);
+                let ai: __m256d = _mm256_loadu_pd(a_k0_ptr.add(4));
+                let br: __m256d = _mm256_loadu_pd(b_ptr);
+                let bi: __m256d = _mm256_loadu_pd(b_ptr.add(4));
+
+                acc_re_k0 = _mm256_fmadd_pd(ar, br, acc_re_k0);
+                acc_re_k0 = _mm256_fnmadd_pd(ai, bi, acc_re_k0);
+                acc_im_k0 = _mm256_fmadd_pd(ar, bi, acc_im_k0);
+                acc_im_k0 = _mm256_fmadd_pd(ai, br, acc_im_k0);
+
+                a_k0_ptr = a_k0_ptr.sub(8);
+                b_ptr = b_ptr.add(8);
+            }
+        } else {
+            let j1_min: usize = (k1 + 1).saturating_sub(a_size);
+            let j1_max: usize = (k1 + 1).min(b_size);
+
+            let mut a_k0_ptr: *const f64 = a.as_ptr().add(8 * (k0 - j0_min));
+            let mut a_k1_ptr: *const f64 = a.as_ptr().add(8 * (k1 - j1_min));
+            let mut b_ptr: *const f64 = b.as_ptr().add(8 * j0_min);
+
+            // Region 1: contributions to k0 only, j ∈ [j0_min, j1_min)
+            for _ in 0..j1_min - j0_min {
+                let ar: __m256d = _mm256_loadu_pd(a_k0_ptr);
+                let ai: __m256d = _mm256_loadu_pd(a_k0_ptr.add(4));
+                let br: __m256d = _mm256_loadu_pd(b_ptr);
+                let bi: __m256d = _mm256_loadu_pd(b_ptr.add(4));
+
+                acc_re_k0 = _mm256_fmadd_pd(ar, br, acc_re_k0);
+                acc_re_k0 = _mm256_fnmadd_pd(ai, bi, acc_re_k0);
+                acc_im_k0 = _mm256_fmadd_pd(ar, bi, acc_im_k0);
+                acc_im_k0 = _mm256_fmadd_pd(ai, br, acc_im_k0);
+
+                a_k0_ptr = a_k0_ptr.sub(8);
+                b_ptr = b_ptr.add(8);
+            }
+
+            // Region 2: overlap, contributions to both k0 and k1, j ∈ [j1_min, j0_max)
+            // We can save one load on b.
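+            // When neither index clamps, j1_min = j0_min + 1 and j1_max = j0_max + 1,
+            // so regions 1 and 3 run for at most one iteration each and the bulk of
+            // the work happens in this overlap region.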
+ for _ in 0..j0_max - j1_min { + let ar0: __m256d = _mm256_loadu_pd(a_k0_ptr); + let ai0: __m256d = _mm256_loadu_pd(a_k0_ptr.add(4)); + let ar1: __m256d = _mm256_loadu_pd(a_k1_ptr); + let ai1: __m256d = _mm256_loadu_pd(a_k1_ptr.add(4)); + let br: __m256d = _mm256_loadu_pd(b_ptr); + let bi: __m256d = _mm256_loadu_pd(b_ptr.add(4)); + + // k0 + acc_re_k0 = _mm256_fmadd_pd(ar0, br, acc_re_k0); + acc_re_k0 = _mm256_fnmadd_pd(ai0, bi, acc_re_k0); + acc_im_k0 = _mm256_fmadd_pd(ar0, bi, acc_im_k0); + acc_im_k0 = _mm256_fmadd_pd(ai0, br, acc_im_k0); + + // k1 + acc_re_k1 = _mm256_fmadd_pd(ar1, br, acc_re_k1); + acc_re_k1 = _mm256_fnmadd_pd(ai1, bi, acc_re_k1); + acc_im_k1 = _mm256_fmadd_pd(ar1, bi, acc_im_k1); + acc_im_k1 = _mm256_fmadd_pd(ai1, br, acc_im_k1); + + a_k0_ptr = a_k0_ptr.sub(8); + a_k1_ptr = a_k1_ptr.sub(8); + b_ptr = b_ptr.add(8); + } + + // Region 3: contributions to k1 only, j ∈ [j0_max, j1_max) + for _ in 0..j1_max - j0_max { + let ar1: __m256d = _mm256_loadu_pd(a_k1_ptr); + let ai1: __m256d = _mm256_loadu_pd(a_k1_ptr.add(4)); + let br: __m256d = _mm256_loadu_pd(b_ptr); + let bi: __m256d = _mm256_loadu_pd(b_ptr.add(4)); + + acc_re_k1 = _mm256_fmadd_pd(ar1, br, acc_re_k1); + acc_re_k1 = _mm256_fnmadd_pd(ai1, bi, acc_re_k1); + acc_im_k1 = _mm256_fmadd_pd(ar1, bi, acc_im_k1); + acc_im_k1 = _mm256_fmadd_pd(ai1, br, acc_im_k1); + + a_k1_ptr = a_k1_ptr.sub(8); + b_ptr = b_ptr.add(8); + } + } + + // Store both coefficients + let dst_ptr = dst.as_mut_ptr(); + _mm256_storeu_pd(dst_ptr, acc_re_k0); + _mm256_storeu_pd(dst_ptr.add(4), acc_im_k0); + _mm256_storeu_pd(dst_ptr.add(8), acc_re_k1); + _mm256_storeu_pd(dst_ptr.add(12), acc_im_k1); + } +} + +/// # Safety +/// Caller must ensure the CPU supports AVX2 and FMA (e.g. `is_x86_feature_detected!("avx2")`). +#[target_feature(enable = "avx2", enable = "fma")] +pub unsafe fn reim4_convolution_by_real_const_1coeff_avx(k: usize, dst: &mut [f64; 8], a: &[f64], a_size: usize, b: &[f64]) { + use core::arch::x86_64::{__m256d, _mm256_fmadd_pd, _mm256_loadu_pd, _mm256_set1_pd, _mm256_setzero_pd, _mm256_storeu_pd}; + + unsafe { + let b_size: usize = b.len(); + + if k >= a_size + b_size { + let zero: __m256d = _mm256_setzero_pd(); + let dst_ptr: *mut f64 = dst.as_mut_ptr(); + _mm256_storeu_pd(dst_ptr, zero); + _mm256_storeu_pd(dst_ptr.add(4), zero); + return; + } + + let j_min: usize = k.saturating_sub(a_size - 1); + let j_max: usize = (k + 1).min(b_size); + + // acc_re = dst[0..4], acc_im = dst[4..8] + let mut acc_re: __m256d = _mm256_setzero_pd(); + let mut acc_im: __m256d = _mm256_setzero_pd(); + + let mut a_ptr: *const f64 = a.as_ptr().add(8 * (k - j_min)); + let mut b_ptr: *const f64 = b.as_ptr().add(j_min); + + for _ in 0..j_max - j_min { + // Load a[(k - j)] + let ar: __m256d = _mm256_loadu_pd(a_ptr); + let ai: __m256d = _mm256_loadu_pd(a_ptr.add(4)); + + // Load scalar b[j] and broadcast + let br: __m256d = _mm256_set1_pd(*b_ptr); + + // Complex * real: + // re += ar * br + // im += ai * br + acc_re = _mm256_fmadd_pd(ar, br, acc_re); + acc_im = _mm256_fmadd_pd(ai, br, acc_im); + + a_ptr = a_ptr.sub(8); + b_ptr = b_ptr.add(1); + } + + // Store accumulators into dst + _mm256_storeu_pd(dst.as_mut_ptr(), acc_re); + _mm256_storeu_pd(dst.as_mut_ptr().add(4), acc_im); + } +} + +/// # Safety +/// Caller must ensure the CPU supports AVX2 and FMA (e.g. `is_x86_feature_detected!("avx2")`). 
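+/// Two-coefficient variant of `reim4_convolution_by_real_const_1coeff_avx`: computes the
+/// pair (k, k + 1) in a single pass over `b`, mirroring `reim4_convolution_2coeffs_avx`.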
+#[target_feature(enable = "avx2", enable = "fma")] +pub unsafe fn reim4_convolution_by_real_const_2coeffs_avx(k: usize, dst: &mut [f64; 16], a: &[f64], a_size: usize, b: &[f64]) { + use core::arch::x86_64::{__m256d, _mm256_fmadd_pd, _mm256_loadu_pd, _mm256_set1_pd, _mm256_setzero_pd, _mm256_storeu_pd}; + + let b_size: usize = b.len(); + + debug_assert!(a.len() >= 8 * a_size); + + let k0: usize = k; + let k1: usize = k + 1; + let bound: usize = a_size + b_size; + + // Since k is a multiple of two, if either k0 or k1 are out of range, + // both are. + if k0 >= bound { + unsafe { + let zero: __m256d = _mm256_setzero_pd(); + let dst_ptr: *mut f64 = dst.as_mut_ptr(); + _mm256_storeu_pd(dst_ptr, zero); + _mm256_storeu_pd(dst_ptr.add(4), zero); + _mm256_storeu_pd(dst_ptr.add(8), zero); + _mm256_storeu_pd(dst_ptr.add(12), zero); + } + return; + } + + unsafe { + let mut acc_re_k0: __m256d = _mm256_setzero_pd(); + let mut acc_im_k0: __m256d = _mm256_setzero_pd(); + let mut acc_re_k1: __m256d = _mm256_setzero_pd(); + let mut acc_im_k1: __m256d = _mm256_setzero_pd(); + + let j0_min: usize = (k0 + 1).saturating_sub(a_size); + let j0_max: usize = (k0 + 1).min(b_size); + + if k1 >= bound { + let mut a_k0_ptr: *const f64 = a.as_ptr().add(8 * (k0 - j0_min)); + let mut b_ptr: *const f64 = b.as_ptr().add(j0_min); + + // Contributions to k0 only + for _ in 0..j0_max - j0_min { + let ar: __m256d = _mm256_loadu_pd(a_k0_ptr); + let ai: __m256d = _mm256_loadu_pd(a_k0_ptr.add(4)); + let br: __m256d = _mm256_set1_pd(*b_ptr); + + // complex * real + acc_re_k0 = _mm256_fmadd_pd(ar, br, acc_re_k0); + acc_im_k0 = _mm256_fmadd_pd(ai, br, acc_im_k0); + + a_k0_ptr = a_k0_ptr.sub(8); + b_ptr = b_ptr.add(1); + } + } else { + let j1_min: usize = (k1 + 1).saturating_sub(a_size); + let j1_max: usize = (k1 + 1).min(b_size); + + let mut a_k0_ptr: *const f64 = a.as_ptr().add(8 * (k0 - j0_min)); + let mut a_k1_ptr: *const f64 = a.as_ptr().add(8 * (k1 - j1_min)); + let mut b_ptr: *const f64 = b.as_ptr().add(j0_min); + + // Region 1: k0 only, j ∈ [j0_min, j1_min) + for _ in 0..j1_min - j0_min { + let ar0: __m256d = _mm256_loadu_pd(a_k0_ptr); + let ai0: __m256d = _mm256_loadu_pd(a_k0_ptr.add(4)); + let br: __m256d = _mm256_set1_pd(*b_ptr); + + acc_re_k0 = _mm256_fmadd_pd(ar0, br, acc_re_k0); + acc_im_k0 = _mm256_fmadd_pd(ai0, br, acc_im_k0); + + a_k0_ptr = a_k0_ptr.sub(8); + b_ptr = b_ptr.add(1); + } + + // Region 2: overlap, contributions to both k0 and k1, j ∈ [j1_min, j0_max) + // Still “save one load on b”: we broadcast once and reuse. 
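+            // Unlike the complex-by-complex kernel there is no imaginary b lane, so a
+            // single _mm256_set1_pd broadcast of b[j] feeds both the k0 and k1 updates.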
+ for _ in 0..j0_max - j1_min { + let ar0: __m256d = _mm256_loadu_pd(a_k0_ptr); + let ai0: __m256d = _mm256_loadu_pd(a_k0_ptr.add(4)); + let ar1: __m256d = _mm256_loadu_pd(a_k1_ptr); + let ai1: __m256d = _mm256_loadu_pd(a_k1_ptr.add(4)); + let br: __m256d = _mm256_set1_pd(*b_ptr); + + // k0 + acc_re_k0 = _mm256_fmadd_pd(ar0, br, acc_re_k0); + acc_im_k0 = _mm256_fmadd_pd(ai0, br, acc_im_k0); + + // k1 + acc_re_k1 = _mm256_fmadd_pd(ar1, br, acc_re_k1); + acc_im_k1 = _mm256_fmadd_pd(ai1, br, acc_im_k1); + + a_k0_ptr = a_k0_ptr.sub(8); + a_k1_ptr = a_k1_ptr.sub(8); + b_ptr = b_ptr.add(1); + } + + // Region 3: k1 only, j ∈ [j0_max, j1_max) + for _ in 0..j1_max - j0_max { + let ar1: __m256d = _mm256_loadu_pd(a_k1_ptr); + let ai1: __m256d = _mm256_loadu_pd(a_k1_ptr.add(4)); + let br: __m256d = _mm256_set1_pd(*b_ptr); + + acc_re_k1 = _mm256_fmadd_pd(ar1, br, acc_re_k1); + acc_im_k1 = _mm256_fmadd_pd(ai1, br, acc_im_k1); + + a_k1_ptr = a_k1_ptr.sub(8); + b_ptr = b_ptr.add(1); + } + } + + // Store both coefficients + let dst_ptr = dst.as_mut_ptr(); + _mm256_storeu_pd(dst_ptr, acc_re_k0); + _mm256_storeu_pd(dst_ptr.add(4), acc_im_k0); + _mm256_storeu_pd(dst_ptr.add(8), acc_re_k1); + _mm256_storeu_pd(dst_ptr.add(12), acc_im_k1); + } +} diff --git a/poulpy-cpu-avx/src/tests.rs b/poulpy-cpu-avx/src/tests.rs index 99cefa6..9487a6d 100644 --- a/poulpy-cpu-avx/src/tests.rs +++ b/poulpy-cpu-avx/src/tests.rs @@ -1,4 +1,8 @@ -use poulpy_hal::{api::ModuleNew, layouts::Module, test_suite::convolution::test_bivariate_tensoring}; +use poulpy_hal::{ + api::ModuleNew, + layouts::Module, + test_suite::convolution::{test_convolution, test_convolution_by_const, test_convolution_pairwise}, +}; use crate::FFT64Avx; @@ -119,7 +123,19 @@ mod poulpy_cpu_avx { } #[test] -fn test_convolution_fft64_avx() { - let module: Module = Module::::new(64); - test_bivariate_tensoring(&module); +fn test_convolution_by_const_fft64_avx() { + let module: Module = Module::::new(8); + test_convolution_by_const(&module); +} + +#[test] +fn test_convolution_fft64_avx() { + let module: Module = Module::::new(8); + test_convolution(&module); +} + +#[test] +fn test_convolution_pairwise_fft64_avx() { + let module: Module = Module::::new(8); + test_convolution_pairwise(&module); } diff --git a/poulpy-cpu-avx/src/vec_znx.rs b/poulpy-cpu-avx/src/vec_znx.rs index e42b595..1fad860 100644 --- a/poulpy-cpu-avx/src/vec_znx.rs +++ b/poulpy-cpu-avx/src/vec_znx.rs @@ -53,11 +53,12 @@ where { fn vec_znx_normalize_impl( module: &Module, - res_base2k: usize, res: &mut R, + res_base2k: usize, + res_offset: i64, res_col: usize, - a_base2k: usize, a: &A, + a_base2k: usize, a_col: usize, scratch: &mut Scratch, ) where @@ -65,7 +66,7 @@ where A: VecZnxToRef, { let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::()); - vec_znx_normalize::(res_base2k, res, res_col, a_base2k, a, a_col, carry); + vec_znx_normalize::(res, res_base2k, res_offset, res_col, a, a_base2k, a_col, carry); } } diff --git a/poulpy-cpu-avx/src/vec_znx_big.rs b/poulpy-cpu-avx/src/vec_znx_big.rs index e1e4580..72dbb79 100644 --- a/poulpy-cpu-avx/src/vec_znx_big.rs +++ b/poulpy-cpu-avx/src/vec_znx_big.rs @@ -26,7 +26,7 @@ use poulpy_hal::{ source::Source, }; -unsafe impl VecZnxBigAllocBytesImpl for FFT64Avx { +unsafe impl VecZnxBigAllocBytesImpl for FFT64Avx { fn vec_znx_big_bytes_of_impl(n: usize, cols: usize, size: usize) -> usize { Self::layout_big_word_count() * n * cols * size * size_of::() } @@ -280,11 +280,12 @@ where { fn vec_znx_big_normalize_impl( module: 
&Module, - res_basek: usize, res: &mut R, + res_base2k: usize, + res_offset: i64, res_col: usize, - a_basek: usize, a: &A, + a_base2k: usize, a_col: usize, scratch: &mut Scratch, ) where @@ -292,7 +293,7 @@ where A: VecZnxBigToRef, { let (carry, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes() / size_of::()); - vec_znx_big_normalize(res_basek, res, res_col, a_basek, a, a_col, carry); + vec_znx_big_normalize(res, res_base2k, res_offset, res_col, a, a_base2k, a_col, carry); } } diff --git a/poulpy-cpu-avx/src/znx_avx/automorphism.rs b/poulpy-cpu-avx/src/znx_avx/automorphism.rs index bfeff3c..ab921e3 100644 --- a/poulpy-cpu-avx/src/znx_avx/automorphism.rs +++ b/poulpy-cpu-avx/src/znx_avx/automorphism.rs @@ -53,12 +53,8 @@ pub fn znx_automorphism_avx(p: i64, res: &mut [i64], a: &[i64]) { let mask_1n_vec: __m256i = _mm256_set1_epi64x(mask_1n as i64); // Lane offsets [0, inv, 2*inv, 3*inv] (mod 2n) - let lane_offsets: __m256i = _mm256_set_epi64x( - ((inv * 3) & mask_2n) as i64, - ((inv * 2) & mask_2n) as i64, - inv as i64, - 0i64, - ); + let lane_offsets: __m256i = + _mm256_set_epi64x(((inv * 3) & mask_2n) as i64, ((inv * 2) & mask_2n) as i64, inv as i64, 0i64); // t_base = (j * inv) mod 2n. let mut t_base: usize = 0; diff --git a/poulpy-cpu-avx/src/znx_avx/normalization.rs b/poulpy-cpu-avx/src/znx_avx/normalization.rs index 9a53d02..16a3d01 100644 --- a/poulpy-cpu-avx/src/znx_avx/normalization.rs +++ b/poulpy-cpu-avx/src/znx_avx/normalization.rs @@ -82,14 +82,14 @@ pub fn znx_extract_digit_addmul_avx(base2k: usize, lsh: usize, res: &mut [i64], let mut ss: *mut __m256i = src.as_mut_ptr() as *mut __m256i; // constants for digit/carry extraction - let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k); + let (mask, sign, base2k_vec, top_mask) = normalize_consts_avx(base2k); let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64); for _ in 0..span { // load source & extract digit/carry let sv: __m256i = _mm256_loadu_si256(ss); let digit_256: __m256i = get_digit_avx(sv, mask, sign); - let carry_256: __m256i = get_carry_avx(sv, digit_256, basek_vec, top_mask); + let carry_256: __m256i = get_carry_avx(sv, digit_256, base2k_vec, top_mask); // res += (digit << lsh) let rv: __m256i = _mm256_loadu_si256(rr); @@ -135,7 +135,7 @@ pub fn znx_normalize_digit_avx(base2k: usize, res: &mut [i64], src: &mut [i64]) let mut ss: *mut __m256i = src.as_mut_ptr() as *mut __m256i; // Constants for digit/carry extraction - let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k); + let (mask, sign, base2k_vec, top_mask) = normalize_consts_avx(base2k); for _ in 0..span { // Load res lane @@ -143,7 +143,7 @@ pub fn znx_normalize_digit_avx(base2k: usize, res: &mut [i64], src: &mut [i64]) // Extract digit and carry from res let digit_256: __m256i = get_digit_avx(rv, mask, sign); - let carry_256: __m256i = get_carry_avx(rv, digit_256, basek_vec, top_mask); + let carry_256: __m256i = get_carry_avx(rv, digit_256, base2k_vec, top_mask); // src += carry let sv: __m256i = _mm256_loadu_si256(ss); @@ -187,7 +187,7 @@ pub fn znx_normalize_first_step_carry_only_avx(base2k: usize, lsh: usize, x: &[i let mut xx: *const __m256i = x.as_ptr() as *const __m256i; let mut cc: *mut __m256i = carry.as_ptr() as *mut __m256i; - let (mask, sign, basek_vec, top_mask) = if lsh == 0 { + let (mask, sign, base2k_vec, top_mask) = if lsh == 0 { normalize_consts_avx(base2k) } else { normalize_consts_avx(base2k - lsh) @@ -200,7 +200,7 @@ pub fn znx_normalize_first_step_carry_only_avx(base2k: usize, lsh: usize, x: &[i 
let digit_256: __m256i = get_digit_avx(xv, mask, sign); // (x - digit) >> base2k - let carry_256: __m256i = get_carry_avx(xv, digit_256, basek_vec, top_mask); + let carry_256: __m256i = get_carry_avx(xv, digit_256, base2k_vec, top_mask); _mm256_storeu_si256(cc, carry_256); @@ -239,7 +239,7 @@ pub fn znx_normalize_first_step_inplace_avx(base2k: usize, lsh: usize, x: &mut [ let mut cc: *mut __m256i = carry.as_ptr() as *mut __m256i; if lsh == 0 { - let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k); + let (mask, sign, base2k_vec, top_mask) = normalize_consts_avx(base2k); for _ in 0..span { let xv: __m256i = _mm256_loadu_si256(xx); @@ -248,7 +248,7 @@ pub fn znx_normalize_first_step_inplace_avx(base2k: usize, lsh: usize, x: &mut [ let digit_256: __m256i = get_digit_avx(xv, mask, sign); // (x - digit) >> base2k - let carry_256: __m256i = get_carry_avx(xv, digit_256, basek_vec, top_mask); + let carry_256: __m256i = get_carry_avx(xv, digit_256, base2k_vec, top_mask); _mm256_storeu_si256(xx, digit_256); _mm256_storeu_si256(cc, carry_256); @@ -257,7 +257,7 @@ pub fn znx_normalize_first_step_inplace_avx(base2k: usize, lsh: usize, x: &mut [ cc = cc.add(1); } } else { - let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k - lsh); + let (mask, sign, base2k_vec, top_mask) = normalize_consts_avx(base2k - lsh); let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64); @@ -268,7 +268,7 @@ pub fn znx_normalize_first_step_inplace_avx(base2k: usize, lsh: usize, x: &mut [ let digit_256: __m256i = get_digit_avx(xv, mask, sign); // (x - digit) >> base2k - let carry_256: __m256i = get_carry_avx(xv, digit_256, basek_vec, top_mask); + let carry_256: __m256i = get_carry_avx(xv, digit_256, base2k_vec, top_mask); _mm256_storeu_si256(xx, _mm256_sllv_epi64(digit_256, lsh_v)); _mm256_storeu_si256(cc, carry_256); @@ -311,7 +311,7 @@ pub fn znx_normalize_first_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a: let mut cc: *mut __m256i = carry.as_ptr() as *mut __m256i; if lsh == 0 { - let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k); + let (mask, sign, base2k_vec, top_mask) = normalize_consts_avx(base2k); for _ in 0..span { let av: __m256i = _mm256_loadu_si256(aa); @@ -320,7 +320,7 @@ pub fn znx_normalize_first_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a: let digit_256: __m256i = get_digit_avx(av, mask, sign); // (x - digit) >> base2k - let carry_256: __m256i = get_carry_avx(av, digit_256, basek_vec, top_mask); + let carry_256: __m256i = get_carry_avx(av, digit_256, base2k_vec, top_mask); _mm256_storeu_si256(xx, digit_256); _mm256_storeu_si256(cc, carry_256); @@ -332,7 +332,7 @@ pub fn znx_normalize_first_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a: } else { use std::arch::x86_64::_mm256_set1_epi64x; - let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k - lsh); + let (mask, sign, base2k_vec, top_mask) = normalize_consts_avx(base2k - lsh); let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64); @@ -343,7 +343,7 @@ pub fn znx_normalize_first_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a: let digit_256: __m256i = get_digit_avx(av, mask, sign); // (x - digit) >> base2k - let carry_256: __m256i = get_carry_avx(av, digit_256, basek_vec, top_mask); + let carry_256: __m256i = get_carry_avx(av, digit_256, base2k_vec, top_mask); _mm256_storeu_si256(xx, _mm256_sllv_epi64(digit_256, lsh_v)); _mm256_storeu_si256(cc, carry_256); @@ -359,13 +359,7 @@ pub fn znx_normalize_first_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a: if 
!x.len().is_multiple_of(4) { use poulpy_hal::reference::znx::znx_normalize_first_step_ref; - znx_normalize_first_step_ref( - base2k, - lsh, - &mut x[span << 2..], - &a[span << 2..], - &mut carry[span << 2..], - ); + znx_normalize_first_step_ref(base2k, lsh, &mut x[span << 2..], &a[span << 2..], &mut carry[span << 2..]); } } @@ -386,7 +380,7 @@ pub fn znx_normalize_middle_step_inplace_avx(base2k: usize, lsh: usize, x: &mut let span: usize = n >> 2; - let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k); + let (mask, sign, base2k_vec, top_mask) = normalize_consts_avx(base2k); unsafe { let mut xx: *mut __m256i = x.as_mut_ptr() as *mut __m256i; @@ -398,11 +392,11 @@ pub fn znx_normalize_middle_step_inplace_avx(base2k: usize, lsh: usize, x: &mut let cv: __m256i = _mm256_loadu_si256(cc); let d0: __m256i = get_digit_avx(xv, mask, sign); - let c0: __m256i = get_carry_avx(xv, d0, basek_vec, top_mask); + let c0: __m256i = get_carry_avx(xv, d0, base2k_vec, top_mask); let s: __m256i = _mm256_add_epi64(d0, cv); let x1: __m256i = get_digit_avx(s, mask, sign); - let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask); + let c1: __m256i = get_carry_avx(s, x1, base2k_vec, top_mask); let cout: __m256i = _mm256_add_epi64(c0, c1); _mm256_storeu_si256(xx, x1); @@ -414,7 +408,7 @@ pub fn znx_normalize_middle_step_inplace_avx(base2k: usize, lsh: usize, x: &mut } else { use std::arch::x86_64::_mm256_set1_epi64x; - let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(base2k - lsh); + let (mask_lsh, sign_lsh, base2k_vec_lsh, top_mask_lsh) = normalize_consts_avx(base2k - lsh); let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64); @@ -423,13 +417,13 @@ pub fn znx_normalize_middle_step_inplace_avx(base2k: usize, lsh: usize, x: &mut let cv: __m256i = _mm256_loadu_si256(cc); let d0: __m256i = get_digit_avx(xv, mask_lsh, sign_lsh); - let c0: __m256i = get_carry_avx(xv, d0, basek_vec_lsh, top_mask_lsh); + let c0: __m256i = get_carry_avx(xv, d0, base2k_vec_lsh, top_mask_lsh); let d0_lsh: __m256i = _mm256_sllv_epi64(d0, lsh_v); let s: __m256i = _mm256_add_epi64(d0_lsh, cv); let x1: __m256i = get_digit_avx(s, mask, sign); - let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask); + let c1: __m256i = get_carry_avx(s, x1, base2k_vec, top_mask); let cout: __m256i = _mm256_add_epi64(c0, c1); _mm256_storeu_si256(xx, x1); @@ -465,7 +459,7 @@ pub fn znx_normalize_middle_step_carry_only_avx(base2k: usize, lsh: usize, x: &[ let span: usize = n >> 2; - let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k); + let (mask, sign, base2k_vec, top_mask) = normalize_consts_avx(base2k); unsafe { let mut xx: *const __m256i = x.as_ptr() as *const __m256i; @@ -477,11 +471,11 @@ pub fn znx_normalize_middle_step_carry_only_avx(base2k: usize, lsh: usize, x: &[ let cv: __m256i = _mm256_loadu_si256(cc); let d0: __m256i = get_digit_avx(xv, mask, sign); - let c0: __m256i = get_carry_avx(xv, d0, basek_vec, top_mask); + let c0: __m256i = get_carry_avx(xv, d0, base2k_vec, top_mask); let s: __m256i = _mm256_add_epi64(d0, cv); let x1: __m256i = get_digit_avx(s, mask, sign); - let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask); + let c1: __m256i = get_carry_avx(s, x1, base2k_vec, top_mask); let cout: __m256i = _mm256_add_epi64(c0, c1); _mm256_storeu_si256(cc, cout); @@ -492,7 +486,7 @@ pub fn znx_normalize_middle_step_carry_only_avx(base2k: usize, lsh: usize, x: &[ } else { use std::arch::x86_64::_mm256_set1_epi64x; - let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = 
normalize_consts_avx(base2k - lsh); + let (mask_lsh, sign_lsh, base2k_vec_lsh, top_mask_lsh) = normalize_consts_avx(base2k - lsh); let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64); @@ -501,13 +495,13 @@ pub fn znx_normalize_middle_step_carry_only_avx(base2k: usize, lsh: usize, x: &[ let cv: __m256i = _mm256_loadu_si256(cc); let d0: __m256i = get_digit_avx(xv, mask_lsh, sign_lsh); - let c0: __m256i = get_carry_avx(xv, d0, basek_vec_lsh, top_mask_lsh); + let c0: __m256i = get_carry_avx(xv, d0, base2k_vec_lsh, top_mask_lsh); let d0_lsh: __m256i = _mm256_sllv_epi64(d0, lsh_v); let s: __m256i = _mm256_add_epi64(d0_lsh, cv); let x1: __m256i = get_digit_avx(s, mask, sign); - let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask); + let c1: __m256i = get_carry_avx(s, x1, base2k_vec, top_mask); let cout: __m256i = _mm256_add_epi64(c0, c1); _mm256_storeu_si256(cc, cout); @@ -543,7 +537,7 @@ pub fn znx_normalize_middle_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a let span: usize = n >> 2; - let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k); + let (mask, sign, base2k_vec, top_mask) = normalize_consts_avx(base2k); unsafe { let mut xx: *mut __m256i = x.as_mut_ptr() as *mut __m256i; @@ -556,11 +550,11 @@ pub fn znx_normalize_middle_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a let cv: __m256i = _mm256_loadu_si256(cc); let d0: __m256i = get_digit_avx(av, mask, sign); - let c0: __m256i = get_carry_avx(av, d0, basek_vec, top_mask); + let c0: __m256i = get_carry_avx(av, d0, base2k_vec, top_mask); let s: __m256i = _mm256_add_epi64(d0, cv); let x1: __m256i = get_digit_avx(s, mask, sign); - let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask); + let c1: __m256i = get_carry_avx(s, x1, base2k_vec, top_mask); let cout: __m256i = _mm256_add_epi64(c0, c1); _mm256_storeu_si256(xx, x1); @@ -573,7 +567,7 @@ pub fn znx_normalize_middle_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a } else { use std::arch::x86_64::_mm256_set1_epi64x; - let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(base2k - lsh); + let (mask_lsh, sign_lsh, base2k_vec_lsh, top_mask_lsh) = normalize_consts_avx(base2k - lsh); let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64); @@ -582,13 +576,13 @@ pub fn znx_normalize_middle_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a let cv: __m256i = _mm256_loadu_si256(cc); let d0: __m256i = get_digit_avx(av, mask_lsh, sign_lsh); - let c0: __m256i = get_carry_avx(av, d0, basek_vec_lsh, top_mask_lsh); + let c0: __m256i = get_carry_avx(av, d0, base2k_vec_lsh, top_mask_lsh); let d0_lsh: __m256i = _mm256_sllv_epi64(d0, lsh_v); let s: __m256i = _mm256_add_epi64(d0_lsh, cv); let x1: __m256i = get_digit_avx(s, mask, sign); - let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask); + let c1: __m256i = get_carry_avx(s, x1, base2k_vec, top_mask); let cout: __m256i = _mm256_add_epi64(c0, c1); _mm256_storeu_si256(xx, x1); @@ -604,13 +598,7 @@ pub fn znx_normalize_middle_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a if !x.len().is_multiple_of(4) { use poulpy_hal::reference::znx::znx_normalize_middle_step_ref; - znx_normalize_middle_step_ref( - base2k, - lsh, - &mut x[span << 2..], - &a[span << 2..], - &mut carry[span << 2..], - ); + znx_normalize_middle_step_ref(base2k, lsh, &mut x[span << 2..], &a[span << 2..], &mut carry[span << 2..]); } } @@ -753,13 +741,7 @@ pub fn znx_normalize_final_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a: if !x.len().is_multiple_of(4) { use 
poulpy_hal::reference::znx::znx_normalize_final_step_ref; - znx_normalize_final_step_ref( - base2k, - lsh, - &mut x[span << 2..], - &a[span << 2..], - &mut carry[span << 2..], - ); + znx_normalize_final_step_ref(base2k, lsh, &mut x[span << 2..], &a[span << 2..], &mut carry[span << 2..]); } } @@ -832,8 +814,8 @@ mod tests { unsafe { let x_256: __m256i = _mm256_loadu_si256(x.as_ptr() as *const __m256i); let d_256: __m256i = _mm256_loadu_si256(carry.as_ptr() as *const __m256i); - let (_, _, basek_vec, top_mask) = normalize_consts_avx(base2k); - let digit: __m256i = get_carry_avx(x_256, d_256, basek_vec, top_mask); + let (_, _, base2k_vec, top_mask) = normalize_consts_avx(base2k); + let digit: __m256i = get_carry_avx(x_256, d_256, base2k_vec, top_mask); _mm256_storeu_si256(y1.as_mut_ptr() as *mut __m256i, digit); } assert_eq!(y0, y1); diff --git a/poulpy-cpu-ref/Cargo.toml b/poulpy-cpu-ref/Cargo.toml index 39c7284..7c88b3b 100644 --- a/poulpy-cpu-ref/Cargo.toml +++ b/poulpy-cpu-ref/Cargo.toml @@ -27,5 +27,5 @@ rustdoc-args = ["--cfg", "docsrs"] [[bench]] -name = "vmp" +name = "convolution" harness = false \ No newline at end of file diff --git a/poulpy-cpu-ref/benches/convolution.rs b/poulpy-cpu-ref/benches/convolution.rs new file mode 100644 index 0000000..65edae4 --- /dev/null +++ b/poulpy-cpu-ref/benches/convolution.rs @@ -0,0 +1,35 @@ +use criterion::{Criterion, criterion_group, criterion_main}; +use poulpy_cpu_ref::FFT64Ref; +use poulpy_hal::bench_suite::convolution::{ + bench_cnv_apply_dft, bench_cnv_by_const_apply, bench_cnv_pairwise_apply_dft, bench_cnv_prepare_left, bench_cnv_prepare_right, +}; + +fn bench_cnv_prepare_left_cpu_ref_fft64(c: &mut Criterion) { + bench_cnv_prepare_left::(c, "cpu_ref::fft64"); +} + +fn bench_cnv_prepare_right_cpu_ref_fft64(c: &mut Criterion) { + bench_cnv_prepare_right::(c, "cpu_ref::fft64"); +} + +fn bench_bench_cnv_apply_dft_cpu_ref_fft64(c: &mut Criterion) { + bench_cnv_apply_dft::(c, "cpu_ref::fft64"); +} + +fn bench_bench_bench_cnv_pairwise_apply_dft_cpu_ref_fft64(c: &mut Criterion) { + bench_cnv_pairwise_apply_dft::(c, "cpu_ref::fft64"); +} + +fn bench_cnv_by_const_apply_cpu_ref_fft64(c: &mut Criterion) { + bench_cnv_by_const_apply::(c, "cpu_ref::fft64"); +} + +criterion_group!( + benches, + bench_cnv_prepare_left_cpu_ref_fft64, + bench_cnv_prepare_right_cpu_ref_fft64, + bench_bench_cnv_apply_dft_cpu_ref_fft64, + bench_bench_bench_cnv_pairwise_apply_dft_cpu_ref_fft64, + bench_cnv_by_const_apply_cpu_ref_fft64, +); +criterion_main!(benches); diff --git a/poulpy-cpu-ref/benches/fft.rs b/poulpy-cpu-ref/benches/fft.rs index 47be0f3..4d552d6 100644 --- a/poulpy-cpu-ref/benches/fft.rs +++ b/poulpy-cpu-ref/benches/fft.rs @@ -11,10 +11,7 @@ pub fn bench_fft_ref(c: &mut Criterion) { fn runner(m: usize) -> impl FnMut() { let mut values: Vec = vec![0f64; m << 1]; let scale: f64 = 1.0f64 / (2 * m) as f64; - values - .iter_mut() - .enumerate() - .for_each(|(i, x)| *x = (i + 1) as f64 * scale); + values.iter_mut().enumerate().for_each(|(i, x)| *x = (i + 1) as f64 * scale); let table: ReimFFTTable = ReimFFTTable::::new(m); move || { ReimFFTRef::reim_dft_execute(&table, &mut values); @@ -39,10 +36,7 @@ pub fn bench_ifft_ref(c: &mut Criterion) { fn runner(m: usize) -> impl FnMut() { let mut values: Vec = vec![0f64; m << 1]; let scale: f64 = 1.0f64 / (2 * m) as f64; - values - .iter_mut() - .enumerate() - .for_each(|(i, x)| *x = (i + 1) as f64 * scale); + values.iter_mut().enumerate().for_each(|(i, x)| *x = (i + 1) as f64 * scale); let table: ReimIFFTTable = 
ReimIFFTTable::::new(m); move || { ReimIFFTRef::reim_dft_execute(&table, &mut values); diff --git a/poulpy-cpu-ref/benches/vec_znx.rs b/poulpy-cpu-ref/benches/vec_znx.rs index 9c98649..8ad6a60 100644 --- a/poulpy-cpu-ref/benches/vec_znx.rs +++ b/poulpy-cpu-ref/benches/vec_znx.rs @@ -5,7 +5,7 @@ use poulpy_hal::reference::vec_znx::{bench_vec_znx_add, bench_vec_znx_automorphi #[allow(dead_code)] fn bench_vec_znx_add_cpu_ref_fft64(c: &mut Criterion) { - bench_vec_znx_add::(c, "cpu_spqlios::fft64"); + bench_vec_znx_add::(c, "cpu_ref::fft64"); } #[allow(dead_code)] diff --git a/poulpy-cpu-ref/src/convolution.rs b/poulpy-cpu-ref/src/convolution.rs new file mode 100644 index 0000000..ffe5454 --- /dev/null +++ b/poulpy-cpu-ref/src/convolution.rs @@ -0,0 +1,166 @@ +use poulpy_hal::{ + api::{Convolution, ModuleN, ScratchTakeBasic, TakeSlice, VecZnxDftApply, VecZnxDftBytesOf}, + layouts::{ + Backend, CnvPVecL, CnvPVecLToMut, CnvPVecLToRef, CnvPVecR, CnvPVecRToMut, CnvPVecRToRef, Module, Scratch, VecZnx, + VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftToMut, VecZnxToRef, ZnxInfos, + }, + oep::{CnvPVecBytesOfImpl, CnvPVecLAllocImpl, ConvolutionImpl}, + reference::fft64::convolution::{ + convolution_apply_dft, convolution_apply_dft_tmp_bytes, convolution_by_const_apply, convolution_by_const_apply_tmp_bytes, + convolution_pairwise_apply_dft, convolution_pairwise_apply_dft_tmp_bytes, convolution_prepare_left, + convolution_prepare_right, + }, +}; + +use crate::{FFT64Ref, module::FFT64ModuleHandle}; + +unsafe impl CnvPVecLAllocImpl for FFT64Ref { + fn cnv_pvec_left_alloc_impl(n: usize, cols: usize, size: usize) -> CnvPVecL, Self> { + CnvPVecL::alloc(n, cols, size) + } + + fn cnv_pvec_right_alloc_impl(n: usize, cols: usize, size: usize) -> CnvPVecR, Self> { + CnvPVecR::alloc(n, cols, size) + } +} + +unsafe impl CnvPVecBytesOfImpl for FFT64Ref { + fn bytes_of_cnv_pvec_left_impl(n: usize, cols: usize, size: usize) -> usize { + Self::layout_prep_word_count() * n * cols * size * size_of::<::ScalarPrep>() + } + + fn bytes_of_cnv_pvec_right_impl(n: usize, cols: usize, size: usize) -> usize { + Self::layout_prep_word_count() * n * cols * size * size_of::<::ScalarPrep>() + } +} + +unsafe impl ConvolutionImpl for FFT64Ref +where + Module: ModuleN + VecZnxDftBytesOf + VecZnxDftApply, +{ + fn cnv_prepare_left_tmp_bytes_impl(module: &Module, res_size: usize, a_size: usize) -> usize { + module.bytes_of_vec_znx_dft(1, res_size.min(a_size)) + } + + fn cnv_prepare_left_impl(module: &Module, res: &mut R, a: &A, scratch: &mut Scratch) + where + R: CnvPVecLToMut, + A: VecZnxToRef, + { + let res: &mut CnvPVecL<&mut [u8], FFT64Ref> = &mut res.to_mut(); + let a: &VecZnx<&[u8]> = &a.to_ref(); + let (mut tmp, _) = scratch.take_vec_znx_dft(module, 1, res.size().min(a.size())); + convolution_prepare_left(module.get_fft_table(), res, a, &mut tmp); + } + + fn cnv_prepare_right_tmp_bytes_impl(module: &Module, res_size: usize, a_size: usize) -> usize { + module.bytes_of_vec_znx_dft(1, res_size.min(a_size)) + } + + fn cnv_prepare_right_impl(module: &Module, res: &mut R, a: &A, scratch: &mut Scratch) + where + R: CnvPVecRToMut, + A: VecZnxToRef, + { + let res: &mut CnvPVecR<&mut [u8], FFT64Ref> = &mut res.to_mut(); + let a: &VecZnx<&[u8]> = &a.to_ref(); + let (mut tmp, _) = scratch.take_vec_znx_dft(module, 1, res.size().min(a.size())); + convolution_prepare_right(module.get_fft_table(), res, a, &mut tmp); + } + + fn cnv_apply_dft_tmp_bytes_impl( + _module: &Module, + res_size: usize, + _res_offset: usize, + a_size: usize, + b_size: 
usize, + ) -> usize { + convolution_apply_dft_tmp_bytes(res_size, a_size, b_size) + } + + fn cnv_by_const_apply_tmp_bytes_impl( + _module: &Module, + res_size: usize, + _res_offset: usize, + a_size: usize, + b_size: usize, + ) -> usize { + convolution_by_const_apply_tmp_bytes(res_size, a_size, b_size) + } + + fn cnv_by_const_apply_impl( + module: &Module, + res: &mut R, + res_offset: usize, + res_col: usize, + a: &A, + a_col: usize, + b: &[i64], + scratch: &mut Scratch, + ) where + R: VecZnxBigToMut, + A: VecZnxToRef, + { + let res: &mut VecZnxBig<&mut [u8], Self> = &mut res.to_mut(); + let a: &VecZnx<&[u8]> = &a.to_ref(); + let (tmp, _) = + scratch.take_slice(module.cnv_by_const_apply_tmp_bytes(res.size(), res_offset, a.size(), b.len()) / size_of::()); + convolution_by_const_apply(res, res_offset, res_col, a, a_col, b, tmp); + } + + fn cnv_apply_dft_impl( + module: &Module, + res: &mut R, + res_offset: usize, + res_col: usize, + a: &A, + a_col: usize, + b: &B, + b_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxDftToMut, + A: CnvPVecLToRef, + B: CnvPVecRToRef, + { + let res: &mut VecZnxDft<&mut [u8], FFT64Ref> = &mut res.to_mut(); + let a: &CnvPVecL<&[u8], FFT64Ref> = &a.to_ref(); + let b: &CnvPVecR<&[u8], FFT64Ref> = &b.to_ref(); + let (tmp, _) = + scratch.take_slice(module.cnv_apply_dft_tmp_bytes(res.size(), res_offset, a.size(), b.size()) / size_of::()); + convolution_apply_dft(res, res_offset, res_col, a, a_col, b, b_col, tmp); + } + + fn cnv_pairwise_apply_dft_tmp_bytes( + _module: &Module, + res_size: usize, + _res_offset: usize, + a_size: usize, + b_size: usize, + ) -> usize { + convolution_pairwise_apply_dft_tmp_bytes(res_size, a_size, b_size) + } + + fn cnv_pairwise_apply_dft_impl( + module: &Module, + res: &mut R, + res_offset: usize, + res_col: usize, + a: &A, + b: &B, + col_0: usize, + col_1: usize, + scratch: &mut Scratch, + ) where + R: VecZnxDftToMut, + A: CnvPVecLToRef, + B: CnvPVecRToRef, + { + let res: &mut VecZnxDft<&mut [u8], FFT64Ref> = &mut res.to_mut(); + let a: &CnvPVecL<&[u8], FFT64Ref> = &a.to_ref(); + let b: &CnvPVecR<&[u8], FFT64Ref> = &b.to_ref(); + let (tmp, _) = scratch + .take_slice(module.cnv_pairwise_apply_dft_tmp_bytes(res.size(), res_offset, a.size(), b.size()) / size_of::()); + convolution_pairwise_apply_dft(res, res_offset, res_col, a, b, col_0, col_1, tmp); + } +} diff --git a/poulpy-cpu-ref/src/lib.rs b/poulpy-cpu-ref/src/lib.rs index e0110a4..6291b2e 100644 --- a/poulpy-cpu-ref/src/lib.rs +++ b/poulpy-cpu-ref/src/lib.rs @@ -1,3 +1,4 @@ +mod convolution; mod module; mod reim; mod scratch; diff --git a/poulpy-cpu-ref/src/main.rs b/poulpy-cpu-ref/src/main.rs deleted file mode 100644 index e7a11a9..0000000 --- a/poulpy-cpu-ref/src/main.rs +++ /dev/null @@ -1,3 +0,0 @@ -fn main() { - println!("Hello, world!"); -} diff --git a/poulpy-cpu-ref/src/reim.rs b/poulpy-cpu-ref/src/reim.rs index 9ce2164..26fd4fd 100644 --- a/poulpy-cpu-ref/src/reim.rs +++ b/poulpy-cpu-ref/src/reim.rs @@ -1,4 +1,9 @@ use poulpy_hal::reference::fft64::{ + convolution::{ + I64ConvolutionByConst1Coeff, I64ConvolutionByConst2Coeffs, I64Extract1BlkContiguous, I64Save1BlkContiguous, + i64_convolution_by_const_1coeff_ref, i64_convolution_by_const_2coeffs_ref, i64_extract_1blk_contiguous_ref, + i64_save_1blk_contiguous_ref, + }, reim::{ ReimAdd, ReimAddInplace, ReimAddMul, ReimCopy, ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimIFFTTable, ReimMul, ReimMulInplace, ReimNegate, ReimNegateInplace, ReimSub, ReimSubInplace, ReimSubNegateInplace, ReimToZnx, @@ -8,9 +13,13 @@ use 
poulpy_hal::reference::fft64::{ reim_zero_ref, }, reim4::{ - Reim4Extract1Blk, Reim4Mat1ColProd, Reim4Mat2Cols2ndColProd, Reim4Mat2ColsProd, Reim4Save1Blk, Reim4Save2Blks, - reim4_extract_1blk_from_reim_ref, reim4_save_1blk_to_reim_ref, reim4_save_2blk_to_reim_ref, - reim4_vec_mat1col_product_ref, reim4_vec_mat2cols_2ndcol_product_ref, reim4_vec_mat2cols_product_ref, + Reim4Convolution1Coeff, Reim4Convolution2Coeffs, Reim4ConvolutionByRealConst1Coeff, Reim4ConvolutionByRealConst2Coeffs, + Reim4Extract1BlkContiguous, Reim4Mat1ColProd, Reim4Mat2Cols2ndColProd, Reim4Mat2ColsProd, Reim4Save1Blk, + Reim4Save1BlkContiguous, Reim4Save2Blks, reim4_convolution_1coeff_ref, reim4_convolution_2coeffs_ref, + reim4_convolution_by_real_const_1coeff_ref, reim4_convolution_by_real_const_2coeffs_ref, + reim4_extract_1blk_from_reim_contiguous_ref, reim4_save_1blk_to_reim_contiguous_ref, reim4_save_1blk_to_reim_ref, + reim4_save_2blk_to_reim_ref, reim4_vec_mat1col_product_ref, reim4_vec_mat2cols_2ndcol_product_ref, + reim4_vec_mat2cols_product_ref, }, }; @@ -133,10 +142,29 @@ impl ReimZero for FFT64Ref { } } -impl Reim4Extract1Blk for FFT64Ref { +impl Reim4Convolution1Coeff for FFT64Ref { + fn reim4_convolution_1coeff(k: usize, dst: &mut [f64; 8], a: &[f64], a_size: usize, b: &[f64], b_size: usize) { + reim4_convolution_1coeff_ref(k, dst, a, a_size, b, b_size); + } +} + +impl Reim4Convolution2Coeffs for FFT64Ref { + fn reim4_convolution_2coeffs(k: usize, dst: &mut [f64; 16], a: &[f64], a_size: usize, b: &[f64], b_size: usize) { + reim4_convolution_2coeffs_ref(k, dst, a, a_size, b, b_size); + } +} + +impl Reim4Extract1BlkContiguous for FFT64Ref { #[inline(always)] - fn reim4_extract_1blk(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) { - reim4_extract_1blk_from_reim_ref(m, rows, blk, dst, src); + fn reim4_extract_1blk_contiguous(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) { + reim4_extract_1blk_from_reim_contiguous_ref(m, rows, blk, dst, src); + } +} + +impl Reim4Save1BlkContiguous for FFT64Ref { + #[inline(always)] + fn reim4_save_1blk_contiguous(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) { + reim4_save_1blk_to_reim_contiguous_ref(m, rows, blk, dst, src); } } @@ -174,3 +202,45 @@ impl Reim4Mat2Cols2ndColProd for FFT64Ref { reim4_vec_mat2cols_2ndcol_product_ref(nrows, dst, u, v); } } + +impl Reim4ConvolutionByRealConst1Coeff for FFT64Ref { + #[inline(always)] + fn reim4_convolution_by_real_const_1coeff(k: usize, dst: &mut [f64; 8], a: &[f64], a_size: usize, b: &[f64]) { + reim4_convolution_by_real_const_1coeff_ref(k, dst, a, a_size, b); + } +} + +impl Reim4ConvolutionByRealConst2Coeffs for FFT64Ref { + #[inline(always)] + fn reim4_convolution_by_real_const_2coeffs(k: usize, dst: &mut [f64; 16], a: &[f64], a_size: usize, b: &[f64]) { + reim4_convolution_by_real_const_2coeffs_ref(k, dst, a, a_size, b); + } +} + +impl I64ConvolutionByConst1Coeff for FFT64Ref { + #[inline(always)] + fn i64_convolution_by_const_1coeff(k: usize, dst: &mut [i64; 8], a: &[i64], a_size: usize, b: &[i64]) { + i64_convolution_by_const_1coeff_ref(k, dst, a, a_size, b); + } +} + +impl I64ConvolutionByConst2Coeffs for FFT64Ref { + #[inline(always)] + fn i64_convolution_by_const_2coeffs(k: usize, dst: &mut [i64; 16], a: &[i64], a_size: usize, b: &[i64]) { + i64_convolution_by_const_2coeffs_ref(k, dst, a, a_size, b); + } +} + +impl I64Save1BlkContiguous for FFT64Ref { + #[inline(always)] + fn i64_save_1blk_contiguous(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], 
src: &[i64]) { + i64_save_1blk_contiguous_ref(n, offset, rows, blk, dst, src); + } +} + +impl I64Extract1BlkContiguous for FFT64Ref { + #[inline(always)] + fn i64_extract_1blk_contiguous(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) { + i64_extract_1blk_contiguous_ref(n, offset, rows, blk, dst, src); + } +} diff --git a/poulpy-cpu-ref/src/tests.rs b/poulpy-cpu-ref/src/tests.rs index 177dfb8..cf66929 100644 --- a/poulpy-cpu-ref/src/tests.rs +++ b/poulpy-cpu-ref/src/tests.rs @@ -1,9 +1,25 @@ -use poulpy_hal::{api::ModuleNew, layouts::Module, test_suite::convolution::test_bivariate_tensoring}; +use poulpy_hal::{ + api::ModuleNew, + layouts::Module, + test_suite::convolution::{test_convolution, test_convolution_by_const, test_convolution_pairwise}, +}; use crate::FFT64Ref; +#[test] +fn test_convolution_by_const_fft64_ref() { + let module: Module = Module::::new(8); + test_convolution_by_const(&module); +} + #[test] fn test_convolution_fft64_ref() { let module: Module = Module::::new(8); - test_bivariate_tensoring(&module); + test_convolution(&module); +} + +#[test] +fn test_convolution_pairwise_fft64_ref() { + let module: Module = Module::::new(8); + test_convolution_pairwise(&module); } diff --git a/poulpy-cpu-ref/src/vec_znx.rs b/poulpy-cpu-ref/src/vec_znx.rs index 927a85e..3093257 100644 --- a/poulpy-cpu-ref/src/vec_znx.rs +++ b/poulpy-cpu-ref/src/vec_znx.rs @@ -53,11 +53,12 @@ where { fn vec_znx_normalize_impl( module: &Module, - res_basek: usize, res: &mut R, + res_base2k: usize, + res_offset: i64, res_col: usize, - a_basek: usize, a: &A, + a_base2k: usize, a_col: usize, scratch: &mut Scratch, ) where @@ -65,7 +66,7 @@ where A: VecZnxToRef, { let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::()); - vec_znx_normalize::(res_basek, res, res_col, a_basek, a, a_col, carry); + vec_znx_normalize::(res, res_base2k, res_offset, res_col, a, a_base2k, a_col, carry); } } diff --git a/poulpy-cpu-ref/src/vec_znx_big.rs b/poulpy-cpu-ref/src/vec_znx_big.rs index 6f8c3f2..c7d4097 100644 --- a/poulpy-cpu-ref/src/vec_znx_big.rs +++ b/poulpy-cpu-ref/src/vec_znx_big.rs @@ -26,7 +26,7 @@ use poulpy_hal::{ source::Source, }; -unsafe impl VecZnxBigAllocBytesImpl for FFT64Ref { +unsafe impl VecZnxBigAllocBytesImpl for FFT64Ref { fn vec_znx_big_bytes_of_impl(n: usize, cols: usize, size: usize) -> usize { Self::layout_big_word_count() * n * cols * size * size_of::() } @@ -280,11 +280,12 @@ where { fn vec_znx_big_normalize_impl( module: &Module, - res_basek: usize, res: &mut R, + res_base2k: usize, + res_offset: i64, res_col: usize, - a_basek: usize, a: &A, + a_base2k: usize, a_col: usize, scratch: &mut Scratch, ) where @@ -292,7 +293,7 @@ where A: VecZnxBigToRef, { let (carry, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes() / size_of::()); - vec_znx_big_normalize(res_basek, res, res_col, a_basek, a, a_col, carry); + vec_znx_big_normalize(res, res_base2k, res_offset, res_col, a, a_base2k, a_col, carry); } } diff --git a/poulpy-hal/src/api/convolution.rs b/poulpy-hal/src/api/convolution.rs index d1c6c5e..d07f1b0 100644 --- a/poulpy-hal/src/api/convolution.rs +++ b/poulpy-hal/src/api/convolution.rs @@ -1,120 +1,97 @@ -use crate::{ - api::{ - ModuleN, ScratchTakeBasic, SvpApplyDftToDft, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxDftAddScaledInplace, - VecZnxDftBytesOf, VecZnxDftZero, - }, - layouts::{Backend, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, ZnxInfos}, +use crate::layouts::{ + Backend, CnvPVecL, 
CnvPVecLToMut, CnvPVecLToRef, CnvPVecR, CnvPVecRToMut, CnvPVecRToRef, Scratch, VecZnxBigToMut, + VecZnxDftToMut, VecZnxToRef, ZnxInfos, ZnxViewMut, }; -impl BivariateTensoring for Module -where - Self: BivariateConvolution, - Scratch: ScratchTakeBasic, -{ +pub trait CnvPVecAlloc { + fn cnv_pvec_left_alloc(&self, cols: usize, size: usize) -> CnvPVecL, BE>; + fn cnv_pvec_right_alloc(&self, cols: usize, size: usize) -> CnvPVecR, BE>; } -pub trait BivariateTensoring -where - Self: BivariateConvolution, - Scratch: ScratchTakeBasic, -{ - fn bivariate_tensoring(&self, k: i64, res: &mut R, a: &A, b: &B, scratch: &mut Scratch) +pub trait CnvPVecBytesOf { + fn bytes_of_cnv_pvec_left(&self, cols: usize, size: usize) -> usize; + fn bytes_of_cnv_pvec_right(&self, cols: usize, size: usize) -> usize; +} + +pub trait Convolution { + fn cnv_prepare_left_tmp_bytes(&self, res_size: usize, a_size: usize) -> usize; + fn cnv_prepare_left(&self, res: &mut R, a: &A, scratch: &mut Scratch) where - R: VecZnxDftToMut, - A: VecZnxToRef, - B: VecZnxDftToRef, - { - let res: &mut crate::layouts::VecZnxDft<&mut [u8], BE> = &mut res.to_mut(); - let a: &crate::layouts::VecZnx<&[u8]> = &a.to_ref(); - let b: &crate::layouts::VecZnxDft<&[u8], BE> = &b.to_ref(); + R: CnvPVecLToMut + ZnxInfos + ZnxViewMut, + A: VecZnxToRef + ZnxInfos; - let res_cols: usize = res.cols(); - let a_cols: usize = a.cols(); - let b_cols: usize = b.cols(); + fn cnv_prepare_right_tmp_bytes(&self, res_size: usize, a_size: usize) -> usize; + fn cnv_prepare_right(&self, res: &mut R, a: &A, scratch: &mut Scratch) + where + R: CnvPVecRToMut + ZnxInfos + ZnxViewMut, + A: VecZnxToRef + ZnxInfos; - assert!(res_cols >= a_cols + b_cols - 1); + fn cnv_apply_dft_tmp_bytes(&self, res_size: usize, res_offset: usize, a_size: usize, b_size: usize) -> usize; - for res_col in 0..res_cols { - self.vec_znx_dft_zero(res, res_col); - } + fn cnv_by_const_apply_tmp_bytes(&self, res_size: usize, res_offset: usize, a_size: usize, b_size: usize) -> usize; - for a_col in 0..a_cols { - for b_col in 0..b_cols { - self.bivariate_convolution_add(k, res, a_col + b_col, a, a_col, b, b_col, scratch); - } - } - } -} - -impl BivariateConvolution for Module -where - Self: Sized - + ModuleN - + SvpPPolAlloc - + SvpApplyDftToDft - + SvpPrepare - + SvpPPolBytesOf - + VecZnxDftBytesOf - + VecZnxDftAddScaledInplace - + VecZnxDftZero, - Scratch: ScratchTakeBasic, -{ -} - -pub trait BivariateConvolution -where - Self: Sized - + ModuleN - + SvpPPolAlloc - + SvpApplyDftToDft - + SvpPrepare - + SvpPPolBytesOf - + VecZnxDftBytesOf - + VecZnxDftAddScaledInplace - + VecZnxDftZero, - Scratch: ScratchTakeBasic, -{ - fn convolution_tmp_bytes(&self, b_size: usize) -> usize { - self.bytes_of_svp_ppol(1) + self.bytes_of_vec_znx_dft(1, b_size) - } + /// Evaluates a bivariate convolution over Z[X, Y] (x) Z[Y] mod (X^N + 1) where Y = 2^-K over the + /// selected columns and stores the result on the selected column, scaled by 2^{res_offset * Base2K} + /// + /// Behavior is identical to [Convolution::cnv_apply_dft] with `b` treated as a constant polynomial + /// in the X variable, for example: + ///```text + /// 1 X X^2 X^3 + /// a = 1 [a00, a10, a20, a30] = (a00 + a01 * 2^-K) + (a10 + a11 * 2^-K) * X ... + /// Y [a01, a11, a21, a31] + /// + /// b = 1 [b0] = (b00 + b01 * 2^-K) + /// Y [b0] + /// ``` + /// This method is intended to be used for multiplications by constants that are greater than the base2k. 
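+    /// Since `b` is constant in the X variable, no DFT of `b` is involved and the result
+    /// is accumulated directly in the coefficient domain via the [VecZnxBigToMut] receiver.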
+ #[allow(clippy::too_many_arguments)] + fn cnv_by_const_apply( + &self, + res: &mut R, + res_offset: usize, + res_col: usize, + a: &A, + a_col: usize, + b: &[i64], + scratch: &mut Scratch, + ) where + R: VecZnxBigToMut, + A: VecZnxToRef; #[allow(clippy::too_many_arguments)] - /// Evaluates a bivariate convolution over Z[X, Y] / (X^N + 1) where Y = 2^-K over the - /// selected columsn and stores the result on the selected column, scaled by 2^{k * Base2K} + /// Evaluates a bivariate convolution over Z[X, Y] (x) Z[X, Y] mod (X^N + 1) where Y = 2^-K over the + /// selected columns and stores the result on the selected column, scaled by 2^{res_offset * Base2K} /// /// # Example - /// a = [a00, a10, a20, a30] = (a00 * 2^-K + a01 * 2^-2K) + (a10 * 2^-K + a11 * 2^-2K) * X ... - /// [a01, a11, a21, a31] + ///```text + /// 1 X X^2 X^3 + /// a = 1 [a00, a10, a20, a30] = (a00 + a01 * 2^-K) + (a10 + a11 * 2^-K) * X ... + /// Y [a01, a11, a21, a31] /// - /// b = [b00, b10, b20, b30] = (b00 * 2^-K + b01 * 2^-2K) + (b10 * 2^-K + b11 * 2^-2K) * X ... - /// [b01, b11, b21, b31] + /// b = 1 [b00, b10, b20, b30] = (b00 + b01 * 2^-K) + (b10 + b11 * 2^-K) * X ... + /// Y [b01, b11, b21, b31] /// - /// If k = 0: - /// res = [ 0, 0, 0, 0] = (r01 * 2^-2K + r02 * 2^-3K + r03 * 2^-4K + r04 * 2^-5K) + ... - /// [r01, r11, r21, r31] - /// [r02, r12, r22, r32] - /// [r03, r13, r23, r33] - /// [r04, r14, r24, r34] + /// If res_offset = 0: /// - /// If k = 1: - /// res = [r01, r11, r21, r31] = (r01 * 2^-K + r02 * 2^-2K + r03 * 2^-3K + r04 * 2^-4K + r05 * 2^-5K) + ... - /// [r02, r12, r22, r32] - /// [r03, r13, r23, r33] - /// [r04, r14, r24, r34] - /// [r05, r15, r25, r35] + /// 1 X X^2 X^3 + /// res = 1 [r00, r10, r20, r30] = (r00 + r01 * 2^-K + r02 * 2^-2K + r03 * 2^-3K) + ... * X + ... + /// Y [r01, r11, r21, r31] + /// Y^2[r02, r12, r22, r32] + /// Y^3[r03, r13, r23, r33] /// - /// If k = -1: - /// res = [ 0, 0, 0, 0] = (r01 * 2^-3K + r02 * 2^-4K + r03 * 2^-5K) + ... - /// [ 0, 0, 0, 0] - /// [r01, r11, r21, r31] - /// [r02, r12, r22, r32] - /// [r03, r13, r23, r33] + /// If res_offset = 1: /// - /// If res.size() < a.size() + b.size() + 1 + k, result is truncated accordingly in the Y dimension. - fn bivariate_convolution_add( + /// 1 X X^2 X^3 + /// res = 1 [r01, r11, r21, r31] = (r01 + r02 * 2^-K + r03 * 2^-2K) + ... * X + ... + /// Y [r02, r12, r22, r32] + /// Y^2[r03, r13, r23, r33] + /// Y^3[ 0, 0, 0 , 0] + /// ``` + /// If res.size() < a.size() + b.size() + k, result is truncated accordingly in the Y dimension. 
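+    /// Both operands must have been prepared beforehand: `a` with [Convolution::cnv_prepare_left]
+    /// and `b` with [Convolution::cnv_prepare_right].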
+ fn cnv_apply_dft( &self, - k: i64, res: &mut R, + res_offset: usize, res_col: usize, a: &A, a_col: usize, @@ -123,40 +100,27 @@ where scratch: &mut Scratch, ) where R: VecZnxDftToMut, - A: VecZnxToRef, - B: VecZnxDftToRef, - { - let res: &mut crate::layouts::VecZnxDft<&mut [u8], BE> = &mut res.to_mut(); - let a: &crate::layouts::VecZnx<&[u8]> = &a.to_ref(); - let b: &crate::layouts::VecZnxDft<&[u8], BE> = &b.to_ref(); + A: CnvPVecLToRef, + B: CnvPVecRToRef; - let (mut ppol, scratch_1) = scratch.take_svp_ppol(self, 1); - let (mut res_tmp, _) = scratch_1.take_vec_znx_dft(self, 1, b.size()); - - for a_limb in 0..a.size() { - self.svp_prepare(&mut ppol, 0, &a.as_scalar_znx_ref(a_col, a_limb), 0); - self.svp_apply_dft_to_dft(&mut res_tmp, 0, &ppol, 0, b, b_col); - self.vec_znx_dft_add_scaled_inplace(res, res_col, &res_tmp, 0, -(1 + a_limb as i64) + k); - } - } + fn cnv_pairwise_apply_dft_tmp_bytes(&self, res_size: usize, res_offset: usize, a_size: usize, b_size: usize) -> usize; #[allow(clippy::too_many_arguments)] - fn bivariate_convolution( + /// Evaluates the bivariate pair-wise convolution res = (a[i] + a[j]) * (b[i] + b[j]). + /// If i == j then calls [Convolution::cnv_apply_dft], i.e. res = a[i] * b[i]. + /// See [Convolution::cnv_apply_dft] for information about the bivariate convolution. + fn cnv_pairwise_apply_dft( &self, - k: i64, res: &mut R, + res_offset: usize, res_col: usize, a: &A, - a_col: usize, b: &B, - b_col: usize, + i: usize, + j: usize, scratch: &mut Scratch, ) where R: VecZnxDftToMut, - A: VecZnxToRef, - B: VecZnxDftToRef, - { - self.vec_znx_dft_zero(res, res_col); - self.bivariate_convolution_add(k, res, res_col, a, a_col, b, b_col, scratch); - } + A: CnvPVecLToRef, + B: CnvPVecRToRef; } diff --git a/poulpy-hal/src/api/scratch.rs b/poulpy-hal/src/api/scratch.rs index 4dbb14b..4db7627 100644 --- a/poulpy-hal/src/api/scratch.rs +++ b/poulpy-hal/src/api/scratch.rs @@ -1,6 +1,6 @@ use crate::{ - api::{ModuleN, SvpPPolBytesOf, VecZnxBigBytesOf, VecZnxDftBytesOf, VmpPMatBytesOf}, - layouts::{Backend, MatZnx, ScalarZnx, Scratch, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat}, + api::{CnvPVecBytesOf, ModuleN, SvpPPolBytesOf, VecZnxBigBytesOf, VecZnxDftBytesOf, VmpPMatBytesOf}, + layouts::{Backend, CnvPVecL, CnvPVecR, MatZnx, ScalarZnx, Scratch, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat}, }; /// Allocates a new [crate::layouts::ScratchOwned] of `size` aligned bytes. 
@@ -56,6 +56,22 @@ pub trait ScratchTakeBasic where Self: TakeSlice, { + fn take_cnv_pvec_left(&mut self, module: &M, cols: usize, size: usize) -> (CnvPVecL<&mut [u8], B>, &mut Self) + where + M: ModuleN + CnvPVecBytesOf, + { + let (take_slice, rem_slice) = self.take_slice(module.bytes_of_cnv_pvec_left(cols, size)); + (CnvPVecL::from_data(take_slice, module.n(), cols, size), rem_slice) + } + + fn take_cnv_pvec_right(&mut self, module: &M, cols: usize, size: usize) -> (CnvPVecR<&mut [u8], B>, &mut Self) + where + M: ModuleN + CnvPVecBytesOf, + { + let (take_slice, rem_slice) = self.take_slice(module.bytes_of_cnv_pvec_right(cols, size)); + (CnvPVecR::from_data(take_slice, module.n(), cols, size), rem_slice) + } + fn take_scalar_znx(&mut self, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) { let (take_slice, rem_slice) = self.take_slice(ScalarZnx::bytes_of(n, cols)); (ScalarZnx::from_data(take_slice, n, cols), rem_slice) @@ -79,10 +95,7 @@ where M: VecZnxBigBytesOf + ModuleN, { let (take_slice, rem_slice) = self.take_slice(module.bytes_of_vec_znx_big(cols, size)); - ( - VecZnxBig::from_data(take_slice, module.n(), cols, size), - rem_slice, - ) + (VecZnxBig::from_data(take_slice, module.n(), cols, size), rem_slice) } fn take_vec_znx_dft(&mut self, module: &M, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, &mut Self) @@ -91,10 +104,7 @@ where { let (take_slice, rem_slice) = self.take_slice(module.bytes_of_vec_znx_dft(cols, size)); - ( - VecZnxDft::from_data(take_slice, module.n(), cols, size), - rem_slice, - ) + (VecZnxDft::from_data(take_slice, module.n(), cols, size), rem_slice) } fn take_vec_znx_dft_slice( @@ -155,9 +165,6 @@ where size: usize, ) -> (MatZnx<&mut [u8]>, &mut Self) { let (take_slice, rem_slice) = self.take_slice(MatZnx::bytes_of(n, rows, cols_in, cols_out, size)); - ( - MatZnx::from_data(take_slice, n, rows, cols_in, cols_out, size), - rem_slice, - ) + (MatZnx::from_data(take_slice, n, rows, cols_in, cols_out, size), rem_slice) } } diff --git a/poulpy-hal/src/api/vec_znx.rs b/poulpy-hal/src/api/vec_znx.rs index 8bf0e65..4e1e686 100644 --- a/poulpy-hal/src/api/vec_znx.rs +++ b/poulpy-hal/src/api/vec_znx.rs @@ -19,11 +19,12 @@ pub trait VecZnxNormalize { /// Normalizes the selected column of `a` and stores the result into the selected column of `res`. 
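Editor's note: the `take_cnv_pvec_left`/`take_cnv_pvec_right` additions to `ScratchTakeBasic` above follow the same bump-allocation idiom as every other `take_*` method: carve `bytes_of_*` bytes off the front of the scratch, wrap them in a typed view via `from_data`, and hand back the remainder for further takes. A minimal sketch of that idiom with plain byte slices; the names `take_bytes` and `backing` are illustrative, not poulpy's:

```rust
// Carve `n` bytes off the front of `scratch`, returning (taken, remainder).
// `split_at_mut` plays the role of `take_slice` in the trait above.
fn take_bytes(scratch: &mut [u8], n: usize) -> (&mut [u8], &mut [u8]) {
    assert!(scratch.len() >= n, "scratch exhausted: need {n}, have {}", scratch.len());
    scratch.split_at_mut(n)
}

fn main() {
    let mut backing = vec![0u8; 1024];
    // Carve two temporaries out of one arena; the remainder threads through,
    // exactly like `(CnvPVecL::from_data(..), rem_slice)` above.
    let (left, rest) = take_bytes(&mut backing, 256);
    let (right, _rest) = take_bytes(rest, 256);
    left.fill(1);
    right.fill(2);
    assert_eq!((left[0], right[0]), (1, 2));
}
```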
fn vec_znx_normalize( &self, - res_basek: usize, res: &mut R, + res_base2k: usize, + res_offset: i64, res_col: usize, - a_basek: usize, a: &A, + a_base2k: usize, a_col: usize, scratch: &mut Scratch, ) where diff --git a/poulpy-hal/src/api/vec_znx_big.rs b/poulpy-hal/src/api/vec_znx_big.rs index 2cf9bba..591445c 100644 --- a/poulpy-hal/src/api/vec_znx_big.rs +++ b/poulpy-hal/src/api/vec_znx_big.rs @@ -164,11 +164,12 @@ pub trait VecZnxBigNormalizeTmpBytes { pub trait VecZnxBigNormalize { fn vec_znx_big_normalize( &self, - res_base2k: usize, res: &mut R, + res_base2k: usize, + res_offset: i64, res_col: usize, - a_base2k: usize, a: &A, + a_base2k: usize, a_col: usize, scratch: &mut Scratch, ) where diff --git a/poulpy-hal/src/bench_suite/convolution.rs b/poulpy-hal/src/bench_suite/convolution.rs new file mode 100644 index 0000000..3d6c4f1 --- /dev/null +++ b/poulpy-hal/src/bench_suite/convolution.rs @@ -0,0 +1,268 @@ +use std::hint::black_box; + +use criterion::{BenchmarkId, Criterion}; + +use crate::{ + api::{CnvPVecAlloc, Convolution, ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxBigAlloc, VecZnxDftAlloc}, + layouts::{Backend, CnvPVecL, CnvPVecR, FillUniform, Module, ScratchOwned, VecZnx, VecZnxBig}, + source::Source, +}; + +pub fn bench_cnv_prepare_left(c: &mut Criterion, label: &str) +where + Module: ModuleNew + Convolution + CnvPVecAlloc, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + let group_name: String = format!("cnv_prepare_left::{label}"); + + let mut group = c.benchmark_group(group_name); + + fn runner(n: usize, size: usize) -> impl FnMut() + where + Module: ModuleNew + Convolution + CnvPVecAlloc, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + { + let mut source: Source = Source::new([0u8; 32]); + + let base2k: usize = 12; + + let c_size: usize = size + size - 1; + + let module: Module = Module::::new(n as u64); + + let mut a_prep: CnvPVecL, BE> = module.cnv_pvec_left_alloc(1, size); + + let mut a: VecZnx> = VecZnx::alloc(module.n(), 1, size); + + a.fill_uniform(base2k, &mut source); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(module.cnv_prepare_left_tmp_bytes(c_size, size)); + + move || { + module.cnv_prepare_left(&mut a_prep, &a, scratch.borrow()); + black_box(()); + } + } + + for params in [[10, 1], [11, 2], [12, 4], [13, 8], [14, 16], [15, 32], [16, 64]] { + let log_n: usize = params[0]; + let size: usize = params[1]; + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x{}", 1 << log_n, size)); + let mut runner = runner(1 << log_n, size); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_cnv_prepare_right(c: &mut Criterion, label: &str) +where + Module: ModuleNew + Convolution + CnvPVecAlloc, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + let group_name: String = format!("cnv_prepare_right::{label}"); + + let mut group = c.benchmark_group(group_name); + + fn runner(n: usize, size: usize) -> impl FnMut() + where + Module: ModuleNew + Convolution + CnvPVecAlloc, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + { + let mut source: Source = Source::new([0u8; 32]); + + let base2k: usize = 12; + + let c_size: usize = size + size - 1; + + let module: Module = Module::::new(n as u64); + + let mut a_prep: CnvPVecR, BE> = module.cnv_pvec_right_alloc(1, size); + + let mut a: VecZnx> = VecZnx::alloc(module.n(), 1, size); + + a.fill_uniform(base2k, &mut source); + + let mut scratch: ScratchOwned = 
ScratchOwned::alloc(module.cnv_prepare_right_tmp_bytes(c_size, size)); + + move || { + module.cnv_prepare_right(&mut a_prep, &a, scratch.borrow()); + black_box(()); + } + } + + for params in [[10, 1], [11, 2], [12, 4], [13, 8], [14, 16], [15, 32], [16, 64]] { + let log_n: usize = params[0]; + let size: usize = params[1]; + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x{}", 1 << log_n, size)); + let mut runner = runner(1 << log_n, size); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_cnv_apply_dft(c: &mut Criterion, label: &str) +where + Module: ModuleNew + Convolution + VecZnxDftAlloc + CnvPVecAlloc, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + let group_name: String = format!("cnv_apply_dft::{label}"); + + let mut group = c.benchmark_group(group_name); + + fn runner(n: usize, size: usize) -> impl FnMut() + where + Module: ModuleNew + Convolution + VecZnxDftAlloc + CnvPVecAlloc, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + { + let mut source: Source = Source::new([0u8; 32]); + + let base2k: usize = 12; + + let c_size: usize = size + size - 1; + + let module: Module = Module::::new(n as u64); + + let mut a_prep: CnvPVecL, BE> = module.cnv_pvec_left_alloc(1, size); + let mut b_prep: CnvPVecR, BE> = module.cnv_pvec_right_alloc(1, size); + + let mut a: VecZnx> = VecZnx::alloc(module.n(), 1, size); + let mut b: VecZnx> = VecZnx::alloc(module.n(), 1, size); + let mut c_dft = module.vec_znx_dft_alloc(1, c_size); + + a.fill_uniform(base2k, &mut source); + b.fill_uniform(base2k, &mut source); + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + module + .cnv_apply_dft_tmp_bytes(c_size, 0, size, size) + .max(module.cnv_prepare_left_tmp_bytes(c_size, size)) + .max(module.cnv_prepare_right_tmp_bytes(c_size, size)), + ); + module.cnv_prepare_left(&mut a_prep, &a, scratch.borrow()); + module.cnv_prepare_right(&mut b_prep, &b, scratch.borrow()); + move || { + module.cnv_apply_dft(&mut c_dft, 0, 0, &a_prep, 0, &b_prep, 0, scratch.borrow()); + black_box(()); + } + } + + for params in [[10, 1], [11, 2], [12, 4], [13, 8], [14, 16], [15, 32], [16, 64]] { + let log_n: usize = params[0]; + let size: usize = params[1]; + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x{}", 1 << log_n, size)); + let mut runner = runner(1 << log_n, size); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_cnv_pairwise_apply_dft(c: &mut Criterion, label: &str) +where + Module: ModuleNew + Convolution + VecZnxDftAlloc + CnvPVecAlloc, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + let group_name: String = format!("cnv_pairwise_apply_dft::{label}"); + + let mut group = c.benchmark_group(group_name); + + fn runner(n: usize, size: usize) -> impl FnMut() + where + Module: ModuleNew + Convolution + VecZnxDftAlloc + CnvPVecAlloc, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + { + let mut source: Source = Source::new([0u8; 32]); + + let base2k: usize = 12; + + let module: Module = Module::::new(n as u64); + + let cols = 2; + let c_size: usize = size + size - 1; + + let mut a_prep: CnvPVecL, BE> = module.cnv_pvec_left_alloc(cols, size); + let mut b_prep: CnvPVecR, BE> = module.cnv_pvec_right_alloc(cols, size); + + let mut a: VecZnx> = VecZnx::alloc(module.n(), cols, size); + let mut b: VecZnx> = VecZnx::alloc(module.n(), cols, size); + let mut c_dft = module.vec_znx_dft_alloc(1, c_size); + + a.fill_uniform(base2k, &mut source); + 
b.fill_uniform(base2k, &mut source); + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + module + .cnv_pairwise_apply_dft_tmp_bytes(c_size, 0, size, size) + .max(module.cnv_prepare_left_tmp_bytes(c_size, size)) + .max(module.cnv_prepare_right_tmp_bytes(c_size, size)), + ); + module.cnv_prepare_left(&mut a_prep, &a, scratch.borrow()); + module.cnv_prepare_right(&mut b_prep, &b, scratch.borrow()); + move || { + module.cnv_pairwise_apply_dft(&mut c_dft, 0, 0, &a_prep, &b_prep, 0, 1, scratch.borrow()); + black_box(()); + } + } + + for params in [[10, 1], [11, 2], [12, 4], [13, 8], [14, 16], [15, 32], [16, 64]] { + let log_n: usize = params[0]; + let size: usize = params[1]; + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x{}", 1 << log_n, size)); + let mut runner = runner(1 << log_n, size); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_cnv_by_const_apply(c: &mut Criterion, label: &str) +where + Module: ModuleNew + Convolution + VecZnxBigAlloc + CnvPVecAlloc, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + let group_name: String = format!("cnv_by_const::{label}"); + + let mut group = c.benchmark_group(group_name); + + fn runner(n: usize, size: usize) -> impl FnMut() + where + Module: ModuleNew + Convolution + VecZnxBigAlloc + CnvPVecAlloc, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + { + let mut source: Source = Source::new([0u8; 32]); + + let base2k: usize = 12; + + let module: Module = Module::::new(n as u64); + + let cols = 2; + let c_size: usize = size + size - 1; + + let mut a: VecZnx> = VecZnx::alloc(module.n(), cols, size); + let mut c_big: VecZnxBig, BE> = module.vec_znx_big_alloc(1, c_size); + + a.fill_uniform(base2k, &mut source); + let mut b = vec![0i64; size]; + for x in &mut b { + *x = source.next_i64(); + } + + let mut scratch: ScratchOwned = ScratchOwned::alloc(module.cnv_by_const_apply_tmp_bytes(c_size, 0, size, size)); + move || { + module.cnv_by_const_apply(&mut c_big, 0, 0, &a, 0, &b, scratch.borrow()); + black_box(()); + } + } + + for params in [[10, 1], [11, 2], [12, 4], [13, 8], [14, 16], [15, 32], [16, 64]] { + let log_n: usize = params[0]; + let size: usize = params[1]; + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x{}", 1 << log_n, size)); + let mut runner = runner(1 << log_n, size); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} diff --git a/poulpy-hal/src/bench_suite/mod.rs b/poulpy-hal/src/bench_suite/mod.rs index 57b60b4..af09a1a 100644 --- a/poulpy-hal/src/bench_suite/mod.rs +++ b/poulpy-hal/src/bench_suite/mod.rs @@ -1,3 +1,4 @@ +pub mod convolution; pub mod svp; pub mod vec_znx; pub mod vec_znx_big; diff --git a/poulpy-hal/src/bench_suite/vec_znx_big.rs b/poulpy-hal/src/bench_suite/vec_znx_big.rs index 01e6812..24f06a9 100644 --- a/poulpy-hal/src/bench_suite/vec_znx_big.rs +++ b/poulpy-hal/src/bench_suite/vec_znx_big.rs @@ -404,7 +404,7 @@ where move || { for i in 0..cols { - module.vec_znx_big_normalize(base2k, &mut res, i, base2k, &a, i, scratch.borrow()); + module.vec_znx_big_normalize(&mut res, base2k, 0, i, &a, base2k, i, scratch.borrow()); } black_box(()); } diff --git a/poulpy-hal/src/delegates/convolution.rs b/poulpy-hal/src/delegates/convolution.rs new file mode 100644 index 0000000..3250e85 --- /dev/null +++ b/poulpy-hal/src/delegates/convolution.rs @@ -0,0 +1,125 @@ +use crate::{ + api::{CnvPVecAlloc, CnvPVecBytesOf, Convolution}, + layouts::{ + Backend, CnvPVecL, CnvPVecLToMut, 
CnvPVecLToRef, CnvPVecR, CnvPVecRToMut, CnvPVecRToRef, Module, Scratch, VecZnxBigToMut, + VecZnxDftToMut, VecZnxToRef, ZnxInfos, ZnxViewMut, + }, + oep::{CnvPVecBytesOfImpl, CnvPVecLAllocImpl, ConvolutionImpl}, +}; + +impl CnvPVecAlloc for Module +where + BE: CnvPVecLAllocImpl, +{ + fn cnv_pvec_left_alloc(&self, cols: usize, size: usize) -> CnvPVecL, BE> { + BE::cnv_pvec_left_alloc_impl(self.n(), cols, size) + } + + fn cnv_pvec_right_alloc(&self, cols: usize, size: usize) -> CnvPVecR, BE> { + BE::cnv_pvec_right_alloc_impl(self.n(), cols, size) + } +} + +impl CnvPVecBytesOf for Module +where + BE: CnvPVecBytesOfImpl, +{ + fn bytes_of_cnv_pvec_left(&self, cols: usize, size: usize) -> usize { + BE::bytes_of_cnv_pvec_left_impl(self.n(), cols, size) + } + + fn bytes_of_cnv_pvec_right(&self, cols: usize, size: usize) -> usize { + BE::bytes_of_cnv_pvec_right_impl(self.n(), cols, size) + } +} + +impl Convolution for Module +where + BE: ConvolutionImpl, +{ + fn cnv_prepare_left_tmp_bytes(&self, res_size: usize, a_size: usize) -> usize { + BE::cnv_prepare_left_tmp_bytes_impl(self, res_size, a_size) + } + fn cnv_prepare_left(&self, res: &mut R, a: &A, scratch: &mut Scratch) + where + R: CnvPVecLToMut + ZnxInfos + ZnxViewMut, + A: VecZnxToRef + ZnxInfos, + { + BE::cnv_prepare_left_impl(self, res, a, scratch); + } + + fn cnv_prepare_right_tmp_bytes(&self, res_size: usize, a_size: usize) -> usize { + BE::cnv_prepare_right_tmp_bytes_impl(self, res_size, a_size) + } + fn cnv_prepare_right(&self, res: &mut R, a: &A, scratch: &mut Scratch) + where + R: CnvPVecRToMut + ZnxInfos + ZnxViewMut, + A: VecZnxToRef + ZnxInfos, + { + BE::cnv_prepare_right_impl(self, res, a, scratch); + } + + fn cnv_apply_dft_tmp_bytes(&self, res_size: usize, res_offset: usize, a_size: usize, b_size: usize) -> usize { + BE::cnv_apply_dft_tmp_bytes_impl(self, res_size, res_offset, a_size, b_size) + } + + fn cnv_by_const_apply_tmp_bytes(&self, res_size: usize, res_offset: usize, a_size: usize, b_size: usize) -> usize { + BE::cnv_by_const_apply_tmp_bytes_impl(self, res_size, res_offset, a_size, b_size) + } + + fn cnv_by_const_apply( + &self, + res: &mut R, + res_offset: usize, + res_col: usize, + a: &A, + a_col: usize, + b: &[i64], + scratch: &mut Scratch, + ) where + R: VecZnxBigToMut, + A: VecZnxToRef, + { + BE::cnv_by_const_apply_impl(self, res, res_offset, res_col, a, a_col, b, scratch); + } + + fn cnv_apply_dft( + &self, + res: &mut R, + res_offset: usize, + res_col: usize, + a: &A, + a_col: usize, + b: &B, + b_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxDftToMut, + A: CnvPVecLToRef, + B: CnvPVecRToRef, + { + BE::cnv_apply_dft_impl(self, res, res_offset, res_col, a, a_col, b, b_col, scratch); + } + + fn cnv_pairwise_apply_dft_tmp_bytes(&self, res_size: usize, res_offset: usize, a_size: usize, b_size: usize) -> usize { + BE::cnv_pairwise_apply_dft_tmp_bytes(self, res_size, res_offset, a_size, b_size) + } + + fn cnv_pairwise_apply_dft( + &self, + res: &mut R, + res_offset: usize, + res_col: usize, + a: &A, + b: &B, + i: usize, + j: usize, + scratch: &mut Scratch, + ) where + R: VecZnxDftToMut, + A: CnvPVecLToRef, + B: CnvPVecRToRef, + { + BE::cnv_pairwise_apply_dft_impl(self, res, res_offset, res_col, a, b, i, j, scratch); + } +} diff --git a/poulpy-hal/src/delegates/mod.rs b/poulpy-hal/src/delegates/mod.rs index 595a641..d4200c7 100644 --- a/poulpy-hal/src/delegates/mod.rs +++ b/poulpy-hal/src/delegates/mod.rs @@ -1,3 +1,4 @@ +mod convolution; mod module; mod scratch; mod svp_ppol; diff --git 
a/poulpy-hal/src/delegates/vec_znx.rs b/poulpy-hal/src/delegates/vec_znx.rs index 02f512a..9fc3ef3 100644 --- a/poulpy-hal/src/delegates/vec_znx.rs +++ b/poulpy-hal/src/delegates/vec_znx.rs @@ -51,18 +51,19 @@ where #[allow(clippy::too_many_arguments)] fn vec_znx_normalize( &self, - res_basek: usize, res: &mut R, + res_base2k: usize, + res_offset: i64, res_col: usize, - a_basek: usize, a: &A, + a_base2k: usize, a_col: usize, scratch: &mut Scratch, ) where R: VecZnxToMut, A: VecZnxToRef, { - B::vec_znx_normalize_impl(self, res_basek, res, res_col, a_basek, a, a_col, scratch) + B::vec_znx_normalize_impl(self, res, res_base2k, res_offset, res_col, a, a_base2k, a_col, scratch) } } diff --git a/poulpy-hal/src/delegates/vec_znx_big.rs b/poulpy-hal/src/delegates/vec_znx_big.rs index a1cc307..d3eb3ee 100644 --- a/poulpy-hal/src/delegates/vec_znx_big.rs +++ b/poulpy-hal/src/delegates/vec_znx_big.rs @@ -51,7 +51,7 @@ where impl VecZnxBigBytesOf for Module where - B: Backend + VecZnxBigAllocBytesImpl, + B: Backend + VecZnxBigAllocBytesImpl, { fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize { B::vec_znx_big_bytes_of_impl(self.n(), cols, size) @@ -264,18 +264,19 @@ where { fn vec_znx_big_normalize( &self, - res_basek: usize, res: &mut R, + res_base2k: usize, + res_offset: i64, res_col: usize, - a_basek: usize, a: &A, + a_base2k: usize, a_col: usize, scratch: &mut Scratch, ) where R: VecZnxToMut, A: VecZnxBigToRef, { - B::vec_znx_big_normalize_impl(self, res_basek, res, res_col, a_basek, a, a_col, scratch); + B::vec_znx_big_normalize_impl(self, res, res_base2k, res_offset, res_col, a, a_base2k, a_col, scratch); } } diff --git a/poulpy-hal/src/delegates/vmp_pmat.rs b/poulpy-hal/src/delegates/vmp_pmat.rs index 69598cb..6a3ec1e 100644 --- a/poulpy-hal/src/delegates/vmp_pmat.rs +++ b/poulpy-hal/src/delegates/vmp_pmat.rs @@ -76,9 +76,7 @@ where b_cols_out: usize, b_size: usize, ) -> usize { - B::vmp_apply_dft_tmp_bytes_impl( - self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size, - ) + B::vmp_apply_dft_tmp_bytes_impl(self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size) } } @@ -109,9 +107,7 @@ where b_cols_out: usize, b_size: usize, ) -> usize { - B::vmp_apply_dft_to_dft_tmp_bytes_impl( - self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size, - ) + B::vmp_apply_dft_to_dft_tmp_bytes_impl(self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size) } } @@ -142,9 +138,7 @@ where b_cols_out: usize, b_size: usize, ) -> usize { - B::vmp_apply_dft_to_dft_add_tmp_bytes_impl( - self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size, - ) + B::vmp_apply_dft_to_dft_add_tmp_bytes_impl(self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size) } } diff --git a/poulpy-hal/src/layouts/convolution.rs b/poulpy-hal/src/layouts/convolution.rs new file mode 100644 index 0000000..bd2d0be --- /dev/null +++ b/poulpy-hal/src/layouts/convolution.rs @@ -0,0 +1,237 @@ +use std::marker::PhantomData; + +use crate::{ + alloc_aligned, + layouts::{Backend, Data, DataMut, DataRef, DataView, DataViewMut, ZnxInfos, ZnxView}, + oep::CnvPVecBytesOfImpl, +}; + +pub struct CnvPVecR { + data: D, + n: usize, + size: usize, + cols: usize, + _phantom: PhantomData, +} + +impl ZnxInfos for CnvPVecR { + fn cols(&self) -> usize { + self.cols + } + + fn n(&self) -> usize { + self.n + } + + fn rows(&self) -> usize { + 1 + } + + fn size(&self) -> usize { + self.size + } +} + +impl DataView for CnvPVecR { + type D = D; + fn data(&self) -> &Self::D { + &self.data + } +} + +impl DataViewMut for CnvPVecR { 
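+    // Mutable view over the raw backing bytes; ZnxViewMut::at_mut_ptr builds typed scalar pointers on top of this.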
+ fn data_mut(&mut self) -> &mut Self::D { + &mut self.data + } +} + +impl ZnxView for CnvPVecR { + type Scalar = BE::ScalarPrep; +} + +impl>, B: Backend> CnvPVecR +where + B: CnvPVecBytesOfImpl, +{ + pub fn alloc(n: usize, cols: usize, size: usize) -> Self { + let data: Vec = alloc_aligned::(B::bytes_of_cnv_pvec_right_impl(n, cols, size)); + Self { + data: data.into(), + n, + size, + cols, + _phantom: PhantomData, + } + } + + pub fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into>) -> Self { + let data: Vec = bytes.into(); + assert!(data.len() == B::bytes_of_cnv_pvec_right_impl(n, cols, size)); + Self { + data: data.into(), + n, + size, + cols, + _phantom: PhantomData, + } + } +} + +impl CnvPVecR { + pub fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self { + Self { + data, + n, + cols, + size, + _phantom: PhantomData, + } + } +} + +pub struct CnvPVecL { + data: D, + n: usize, + size: usize, + cols: usize, + _phantom: PhantomData, +} + +impl ZnxInfos for CnvPVecL { + fn cols(&self) -> usize { + self.cols + } + + fn n(&self) -> usize { + self.n + } + + fn rows(&self) -> usize { + 1 + } + + fn size(&self) -> usize { + self.size + } +} + +impl DataView for CnvPVecL { + type D = D; + fn data(&self) -> &Self::D { + &self.data + } +} + +impl DataViewMut for CnvPVecL { + fn data_mut(&mut self) -> &mut Self::D { + &mut self.data + } +} + +impl ZnxView for CnvPVecL { + type Scalar = BE::ScalarPrep; +} + +impl>, B: Backend> CnvPVecL +where + B: CnvPVecBytesOfImpl, +{ + pub fn alloc(n: usize, cols: usize, size: usize) -> Self { + let data: Vec = alloc_aligned::(B::bytes_of_cnv_pvec_left_impl(n, cols, size)); + Self { + data: data.into(), + n, + size, + cols, + _phantom: PhantomData, + } + } + + pub fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into>) -> Self { + let data: Vec = bytes.into(); + assert!(data.len() == B::bytes_of_cnv_pvec_left_impl(n, cols, size)); + Self { + data: data.into(), + n, + size, + cols, + _phantom: PhantomData, + } + } +} + +impl CnvPVecL { + pub fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self { + Self { + data, + n, + cols, + size, + _phantom: PhantomData, + } + } +} + +pub trait CnvPVecRToRef { + fn to_ref(&self) -> CnvPVecR<&[u8], BE>; +} + +impl CnvPVecRToRef for CnvPVecR { + fn to_ref(&self) -> CnvPVecR<&[u8], BE> { + CnvPVecR { + data: self.data.as_ref(), + n: self.n, + size: self.size, + cols: self.cols, + _phantom: self._phantom, + } + } +} + +pub trait CnvPVecRToMut { + fn to_mut(&mut self) -> CnvPVecR<&mut [u8], BE>; +} + +impl CnvPVecRToMut for CnvPVecR { + fn to_mut(&mut self) -> CnvPVecR<&mut [u8], BE> { + CnvPVecR { + data: self.data.as_mut(), + n: self.n, + size: self.size, + cols: self.cols, + _phantom: self._phantom, + } + } +} + +pub trait CnvPVecLToRef { + fn to_ref(&self) -> CnvPVecL<&[u8], BE>; +} + +impl CnvPVecLToRef for CnvPVecL { + fn to_ref(&self) -> CnvPVecL<&[u8], BE> { + CnvPVecL { + data: self.data.as_ref(), + n: self.n, + size: self.size, + cols: self.cols, + _phantom: self._phantom, + } + } +} + +pub trait CnvPVecLToMut { + fn to_mut(&mut self) -> CnvPVecL<&mut [u8], BE>; +} + +impl CnvPVecLToMut for CnvPVecL { + fn to_mut(&mut self) -> CnvPVecL<&mut [u8], BE> { + CnvPVecL { + data: self.data.as_mut(), + n: self.n, + size: self.size, + cols: self.cols, + _phantom: self._phantom, + } + } +} diff --git a/poulpy-hal/src/layouts/encoding.rs b/poulpy-hal/src/layouts/encoding.rs index 6934eec..51366c6 100644 --- a/poulpy-hal/src/layouts/encoding.rs +++ 
b/poulpy-hal/src/layouts/encoding.rs @@ -223,22 +223,22 @@ impl VecZnx { let a: VecZnx<&[u8]> = self.to_ref(); let size: usize = a.size(); - let prec: u32 = (base2k * size) as u32; + let prec: u32 = data[0].prec(); // 2^{base2k} - let base: Float = Float::with_val(prec, (1u64 << base2k) as f64); + let scale: Float = Float::with_val(prec, Float::u_pow_u(2, base2k as u32)); // y[i] = sum x[j][i] * 2^{-base2k*j} (0..size).for_each(|i| { if i == 0 { izip!(a.at(col, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { y.assign(*x); - *y /= &base; + *y /= &scale; }); } else { izip!(a.at(col, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { *y += Float::with_val(prec, *x); - *y /= &base; + *y /= &scale; }); } }); diff --git a/poulpy-hal/src/layouts/mod.rs b/poulpy-hal/src/layouts/mod.rs index d164234..9ece7e3 100644 --- a/poulpy-hal/src/layouts/mod.rs +++ b/poulpy-hal/src/layouts/mod.rs @@ -1,3 +1,4 @@ +mod convolution; mod encoding; mod mat_znx; mod module; @@ -12,6 +13,7 @@ mod vec_znx_dft; mod vmp_pmat; mod znx_base; +pub use convolution::*; pub use mat_znx::*; pub use module::*; pub use scalar_znx::*; diff --git a/poulpy-hal/src/layouts/module.rs b/poulpy-hal/src/layouts/module.rs index 2195e49..0422707 100644 --- a/poulpy-hal/src/layouts/module.rs +++ b/poulpy-hal/src/layouts/module.rs @@ -123,10 +123,8 @@ where panic!("cannot invert 0") } - let g_exp: u64 = mod_exp_u64( - gal_el.unsigned_abs(), - (self.cyclotomic_order() - 1) as usize, - ) & (self.cyclotomic_order() - 1) as u64; + let g_exp: u64 = + mod_exp_u64(gal_el.unsigned_abs(), (self.cyclotomic_order() - 1) as usize) & (self.cyclotomic_order() - 1) as u64; g_exp as i64 * gal_el.signum() } } diff --git a/poulpy-hal/src/layouts/vec_znx.rs b/poulpy-hal/src/layouts/vec_znx.rs index 1435243..14346a7 100644 --- a/poulpy-hal/src/layouts/vec_znx.rs +++ b/poulpy-hal/src/layouts/vec_znx.rs @@ -187,11 +187,7 @@ impl VecZnx { impl fmt::Display for VecZnx { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - writeln!( - f, - "VecZnx(n={}, cols={}, size={})", - self.n, self.cols, self.size - )?; + writeln!(f, "VecZnx(n={}, cols={}, size={})", self.n, self.cols, self.size)?; for col in 0..self.cols { writeln!(f, "Column {col}:")?; diff --git a/poulpy-hal/src/layouts/vec_znx_big.rs b/poulpy-hal/src/layouts/vec_znx_big.rs index 73a3e0f..e748517 100644 --- a/poulpy-hal/src/layouts/vec_znx_big.rs +++ b/poulpy-hal/src/layouts/vec_znx_big.rs @@ -93,7 +93,7 @@ where impl>, B: Backend> VecZnxBig where - B: VecZnxBigAllocBytesImpl, + B: VecZnxBigAllocBytesImpl, { pub fn alloc(n: usize, cols: usize, size: usize) -> Self { let data = alloc_aligned::(B::vec_znx_big_bytes_of_impl(n, cols, size)); @@ -172,11 +172,7 @@ impl VecZnxBigToMut for VecZnxBig { impl fmt::Display for VecZnxBig { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - writeln!( - f, - "VecZnxBig(n={}, cols={}, size={})", - self.n, self.cols, self.size - )?; + writeln!(f, "VecZnxBig(n={}, cols={}, size={})", self.n, self.cols, self.size)?; for col in 0..self.cols { writeln!(f, "Column {col}:")?; diff --git a/poulpy-hal/src/layouts/vec_znx_dft.rs b/poulpy-hal/src/layouts/vec_znx_dft.rs index 19d28e1..6c5aba5 100644 --- a/poulpy-hal/src/layouts/vec_znx_dft.rs +++ b/poulpy-hal/src/layouts/vec_znx_dft.rs @@ -192,11 +192,7 @@ impl VecZnxDftToMut for VecZnxDft { impl fmt::Display for VecZnxDft { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - writeln!( - f, - "VecZnxDft(n={}, cols={}, size={})", - self.n, self.cols, self.size - )?; + writeln!(f, 
"VecZnxDft(n={}, cols={}, size={})", self.n, self.cols, self.size)?; for col in 0..self.cols { writeln!(f, "Column {col}:")?; diff --git a/poulpy-hal/src/layouts/znx_base.rs b/poulpy-hal/src/layouts/znx_base.rs index a2c5dd3..62288af 100644 --- a/poulpy-hal/src/layouts/znx_base.rs +++ b/poulpy-hal/src/layouts/znx_base.rs @@ -65,11 +65,8 @@ pub trait ZnxView: ZnxInfos + DataView { /// Returns a non-mutable pointer starting at the j-th small polynomial of the i-th column. fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar { - #[cfg(debug_assertions)] - { - assert!(i < self.cols(), "cols: {} >= {}", i, self.cols()); - assert!(j < self.size(), "size: {} >= {}", j, self.size()); - } + assert!(i < self.cols(), "cols: {} >= self.cols(): {}", i, self.cols()); + assert!(j < self.size(), "size: {} >= self.size(): {}", j, self.size()); let offset: usize = self.n() * (j * self.cols() + i); unsafe { self.as_ptr().add(offset) } } @@ -93,11 +90,8 @@ pub trait ZnxViewMut: ZnxView + DataViewMut { /// Returns a mutable pointer starting at the j-th small polynomial of the i-th column. fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar { - #[cfg(debug_assertions)] - { - assert!(i < self.cols(), "cols: {} >= {}", i, self.cols()); - assert!(j < self.size(), "size: {} >= {}", j, self.size()); - } + assert!(i < self.cols(), "cols: {} >= self.cols(): {}", i, self.cols()); + assert!(j < self.size(), "size: {} >= self.size(): {}", j, self.size()); let offset: usize = self.n() * (j * self.cols() + i); unsafe { self.as_mut_ptr().add(offset) } } diff --git a/poulpy-hal/src/lib.rs b/poulpy-hal/src/lib.rs index 92d874f..4eae149 100644 --- a/poulpy-hal/src/lib.rs +++ b/poulpy-hal/src/lib.rs @@ -54,10 +54,7 @@ pub fn cast_mut(data: &[T]) -> &mut [V] { /// Alignement must be a power of two and size a multiple of the alignement. /// Allocated memory is initialized to zero. fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec { - assert!( - align.is_power_of_two(), - "Alignment must be a power of two but is {align}" - ); + assert!(align.is_power_of_two(), "Alignment must be a power of two but is {align}"); assert_eq!( (size * size_of::()) % align, 0, @@ -82,10 +79,7 @@ fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec { /// Allocates a block of T aligned with [DEFAULTALIGN]. /// Size of T * size msut be a multiple of [DEFAULTALIGN]. pub fn alloc_aligned_custom(size: usize, align: usize) -> Vec { - assert!( - align.is_power_of_two(), - "Alignment must be a power of two but is {align}" - ); + assert!(align.is_power_of_two(), "Alignment must be a power of two but is {align}"); assert_eq!( (size * size_of::()) % align, diff --git a/poulpy-hal/src/oep/convolution.rs b/poulpy-hal/src/oep/convolution.rs new file mode 100644 index 0000000..91d890e --- /dev/null +++ b/poulpy-hal/src/oep/convolution.rs @@ -0,0 +1,106 @@ +use crate::layouts::{ + Backend, CnvPVecL, CnvPVecLToMut, CnvPVecLToRef, CnvPVecR, CnvPVecRToMut, CnvPVecRToRef, Module, Scratch, VecZnxBigToMut, + VecZnxDftToMut, VecZnxToRef, ZnxInfos, +}; + +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See the TODO reference implementation. +/// * See [crate::api::CnvPVecLAlloc] for corresponding public API. +/// # Safety [crate::doc::backend_safety] for safety contract. 
+pub unsafe trait CnvPVecLAllocImpl { + fn cnv_pvec_left_alloc_impl(n: usize, cols: usize, size: usize) -> CnvPVecL, BE>; + fn cnv_pvec_right_alloc_impl(n: usize, cols: usize, size: usize) -> CnvPVecR, BE>; +} + +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See the TODO reference implementation. +/// * See [crate::api::CnvPVecLBytesOf] for corresponding public API. +/// # Safety [crate::doc::backend_safety] for safety contract. +pub unsafe trait CnvPVecBytesOfImpl { + fn bytes_of_cnv_pvec_left_impl(n: usize, cols: usize, size: usize) -> usize; + fn bytes_of_cnv_pvec_right_impl(n: usize, cols: usize, size: usize) -> usize; +} + +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See the TODO reference implementation. +/// * See [crate::api::Convolution] for corresponding public API. +/// # Safety [crate::doc::backend_safety] for safety contract. +pub unsafe trait ConvolutionImpl { + fn cnv_prepare_left_tmp_bytes_impl(module: &Module, res_size: usize, a_size: usize) -> usize; + fn cnv_prepare_left_impl(module: &Module, res: &mut R, a: &A, scratch: &mut Scratch) + where + R: CnvPVecLToMut, + A: VecZnxToRef; + fn cnv_prepare_right_tmp_bytes_impl(module: &Module, res_size: usize, a_size: usize) -> usize; + fn cnv_prepare_right_impl(module: &Module, res: &mut R, a: &A, scratch: &mut Scratch) + where + R: CnvPVecRToMut, + A: VecZnxToRef + ZnxInfos; + fn cnv_apply_dft_tmp_bytes_impl( + module: &Module, + res_size: usize, + res_offset: usize, + a_size: usize, + b_size: usize, + ) -> usize; + + fn cnv_by_const_apply_tmp_bytes_impl( + module: &Module, + res_size: usize, + res_offset: usize, + a_size: usize, + b_size: usize, + ) -> usize; + + #[allow(clippy::too_many_arguments)] + fn cnv_by_const_apply_impl( + module: &Module, + res: &mut R, + res_offset: usize, + res_col: usize, + a: &A, + a_col: usize, + b: &[i64], + scratch: &mut Scratch, + ) where + R: VecZnxBigToMut, + A: VecZnxToRef; + + #[allow(clippy::too_many_arguments)] + fn cnv_apply_dft_impl( + module: &Module, + res: &mut R, + res_offset: usize, + res_col: usize, + a: &A, + a_col: usize, + b: &B, + b_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxDftToMut, + A: CnvPVecLToRef, + B: CnvPVecRToRef; + fn cnv_pairwise_apply_dft_tmp_bytes( + module: &Module, + res_size: usize, + res_offset: usize, + a_size: usize, + b_size: usize, + ) -> usize; + #[allow(clippy::too_many_arguments)] + fn cnv_pairwise_apply_dft_impl( + module: &Module, + res: &mut R, + res_offset: usize, + res_col: usize, + a: &A, + b: &B, + i: usize, + j: usize, + scratch: &mut Scratch, + ) where + R: VecZnxDftToMut, + A: CnvPVecLToRef, + B: CnvPVecRToRef; +} diff --git a/poulpy-hal/src/oep/mod.rs b/poulpy-hal/src/oep/mod.rs index bc53c0e..9af22de 100644 --- a/poulpy-hal/src/oep/mod.rs +++ b/poulpy-hal/src/oep/mod.rs @@ -1,3 +1,4 @@ +mod convolution; mod module; mod scratch; mod svp_ppol; @@ -6,6 +7,7 @@ mod vec_znx_big; mod vec_znx_dft; mod vmp_pmat; +pub use convolution::*; pub use module::*; pub use scratch::*; pub use svp_ppol::*; diff --git a/poulpy-hal/src/oep/vec_znx.rs b/poulpy-hal/src/oep/vec_znx.rs index 47bc94a..f5cf68f 100644 --- a/poulpy-hal/src/oep/vec_znx.rs +++ b/poulpy-hal/src/oep/vec_znx.rs @@ -29,11 +29,12 @@ pub unsafe trait VecZnxNormalizeImpl { #[allow(clippy::too_many_arguments)] fn vec_znx_normalize_impl( module: &Module, - res_basek: usize, res: &mut R, + res_base2k: usize, + res_offset: i64, res_col: usize, - a_basek: usize, a: &A, + a_base2k: usize, a_col: usize, scratch: &mut Scratch, ) where diff --git 
a/poulpy-hal/src/oep/vec_znx_big.rs b/poulpy-hal/src/oep/vec_znx_big.rs index 4c12e6a..2b78f47 100644 --- a/poulpy-hal/src/oep/vec_znx_big.rs +++ b/poulpy-hal/src/oep/vec_znx_big.rs @@ -34,7 +34,7 @@ pub unsafe trait VecZnxBigFromBytesImpl { /// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation. /// * See [crate::api::VecZnxBigAllocBytes] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait VecZnxBigAllocBytesImpl { +pub unsafe trait VecZnxBigAllocBytesImpl { fn vec_znx_big_bytes_of_impl(n: usize, cols: usize, size: usize) -> usize; } @@ -46,7 +46,7 @@ pub unsafe trait VecZnxBigAllocBytesImpl { pub unsafe trait VecZnxBigAddNormalImpl { fn add_normal_impl>( module: &Module, - res_basek: usize, + res_base2k: usize, res: &mut R, res_col: usize, k: usize, @@ -240,11 +240,12 @@ pub unsafe trait VecZnxBigNormalizeTmpBytesImpl { pub unsafe trait VecZnxBigNormalizeImpl { fn vec_znx_big_normalize_impl( module: &Module, - res_basek: usize, res: &mut R, + res_base2k: usize, + res_offset: i64, res_col: usize, - a_basek: usize, a: &A, + a_base2k: usize, a_col: usize, scratch: &mut Scratch, ) where diff --git a/poulpy-hal/src/reference/fft64/convolution.rs b/poulpy-hal/src/reference/fft64/convolution.rs new file mode 100644 index 0000000..3d2d4b1 --- /dev/null +++ b/poulpy-hal/src/reference/fft64/convolution.rs @@ -0,0 +1,405 @@ +use crate::{ + layouts::{ + Backend, CnvPVecL, CnvPVecLToMut, CnvPVecLToRef, CnvPVecR, CnvPVecRToMut, CnvPVecRToRef, VecZnx, VecZnxBig, + VecZnxBigToMut, VecZnxDft, VecZnxDftToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut, ZnxZero, + }, + reference::fft64::{ + reim::{ReimAdd, ReimCopy, ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimZero}, + reim4::{ + Reim4Convolution, Reim4Convolution1Coeff, Reim4Convolution2Coeffs, Reim4Extract1BlkContiguous, + Reim4Save1BlkContiguous, + }, + vec_znx_dft::vec_znx_dft_apply, + }, +}; + +pub fn convolution_prepare_left(table: &ReimFFTTable, res: &mut R, a: &A, tmp: &mut T) +where + BE: Backend + + ReimZero + + Reim4Extract1BlkContiguous + + ReimDFTExecute, f64> + + ReimFromZnx + + ReimZero, + R: CnvPVecLToMut + ZnxInfos + ZnxViewMut, + A: VecZnxToRef, + T: VecZnxDftToMut, +{ + convolution_prepare(table, res, a, tmp) +} + +pub fn convolution_prepare_right(table: &ReimFFTTable, res: &mut R, a: &A, tmp: &mut T) +where + BE: Backend + + ReimZero + + Reim4Extract1BlkContiguous + + ReimDFTExecute, f64> + + ReimFromZnx + + ReimZero, + R: CnvPVecRToMut + ZnxInfos + ZnxViewMut, + A: VecZnxToRef, + T: VecZnxDftToMut, +{ + convolution_prepare(table, res, a, tmp) +} + +fn convolution_prepare(table: &ReimFFTTable, res: &mut R, a: &A, tmp: &mut T) +where + BE: Backend + + ReimZero + + Reim4Extract1BlkContiguous + + ReimDFTExecute, f64> + + ReimFromZnx + + ReimZero, + R: ZnxInfos + ZnxViewMut, + A: VecZnxToRef, + T: VecZnxDftToMut, +{ + let a: &VecZnx<&[u8]> = &a.to_ref(); + let tmp: &mut VecZnxDft<&mut [u8], BE> = &mut tmp.to_mut(); + + let cols: usize = res.cols(); + assert_eq!(a.cols(), cols, "a.cols():{} != res.cols():{cols}", a.cols()); + + let res_size: usize = res.size(); + let min_size: usize = res_size.min(a.size()); + + let m: usize = a.n() >> 1; + + let n: usize = table.m() << 1; + + let res_raw: &mut [f64] = res.raw_mut(); + + for i in 0..cols { + vec_znx_dft_apply(table, 1, 0, tmp, 0, a, i); + + let tmp_raw: &[f64] = tmp.raw(); + let res_col: &mut [f64] 
= &mut res_raw[i * n * res_size..]; + + for blk_i in 0..m / 4 { + BE::reim4_extract_1blk_contiguous(m, min_size, blk_i, &mut res_col[blk_i * res_size * 8..], tmp_raw); + BE::reim_zero(&mut res_col[blk_i * res_size * 8 + min_size * 8..(blk_i + 1) * res_size * 8]); + } + } +} + +pub fn convolution_by_const_apply_tmp_bytes(res_size: usize, a_size: usize, b_size: usize) -> usize { + let min_size: usize = res_size.min(a_size + b_size - 1); + size_of::() * (min_size + a_size) * 8 +} + +pub fn convolution_by_const_apply( + res: &mut R, + res_offset: usize, + res_col: usize, + a: &A, + a_col: usize, + b: &[i64], + tmp: &mut [i64], +) where + BE: Backend + + I64ConvolutionByConst1Coeff + + I64ConvolutionByConst2Coeffs + + I64Extract1BlkContiguous + + I64Save1BlkContiguous, + R: VecZnxBigToMut, + A: VecZnxToRef, +{ + let res: &mut VecZnxBig<&mut [u8], BE> = &mut res.to_mut(); + let a: &VecZnx<&[u8]> = &a.to_ref(); + + let n: usize = res.n(); + assert_eq!(a.n(), n); + + let res_size: usize = res.size(); + let a_size: usize = a.size(); + let b_size: usize = b.len(); + + let bound: usize = a_size + b_size - 1; + let min_size: usize = res_size.min(bound); + let offset: usize = res_offset.min(bound); + + let a_sl: usize = n * a.cols(); + let res_sl: usize = n * res.cols(); + + let res_raw: &mut [i64] = res.raw_mut(); + let a_raw: &[i64] = a.raw(); + + let a_idx: usize = n * a_col; + let res_idx: usize = n * res_col; + + let (res_blk, a_blk) = tmp[..(min_size + a_size) * 8].split_at_mut(min_size * 8); + + for blk_i in 0..n / 8 { + BE::i64_extract_1blk_contiguous(a_sl, a_idx, a_size, blk_i, a_blk, a_raw); + BE::i64_convolution_by_const(res_blk, min_size, offset, a_blk, a_size, b); + BE::i64_save_1blk_contiguous(res_sl, res_idx, min_size, blk_i, res_raw, res_blk); + } + + for j in min_size..res_size { + res.zero_at(res_col, j); + } +} + +pub fn convolution_apply_dft_tmp_bytes(res_size: usize, a_size: usize, b_size: usize) -> usize { + let min_size: usize = res_size.min(a_size + b_size - 1); + size_of::() * 8 * min_size +} + +#[allow(clippy::too_many_arguments)] +pub fn convolution_apply_dft( + res: &mut R, + res_offset: usize, + res_col: usize, + a: &A, + a_col: usize, + b: &B, + b_col: usize, + tmp: &mut [f64], +) where + BE: Backend + Reim4Save1BlkContiguous + Reim4Convolution1Coeff + Reim4Convolution2Coeffs, + R: VecZnxDftToMut, + A: CnvPVecLToRef, + B: CnvPVecRToRef, +{ + let res: &mut VecZnxDft<&mut [u8], BE> = &mut res.to_mut(); + let a: &CnvPVecL<&[u8], BE> = &a.to_ref(); + let b: &CnvPVecR<&[u8], BE> = &b.to_ref(); + + let n: usize = res.n(); + assert_eq!(a.n(), n); + assert_eq!(b.n(), n); + let m: usize = n >> 1; + + let res_size: usize = res.size(); + let a_size: usize = a.size(); + let b_size: usize = b.size(); + + let bound: usize = a_size + b_size - 1; + let min_size: usize = res_size.min(bound); + let offset: usize = res_offset.min(bound); + + let dst: &mut [f64] = res.raw_mut(); + let a_raw: &[f64] = a.raw(); + let b_raw: &[f64] = b.raw(); + + let mut a_idx: usize = a_col * n * a_size; + let mut b_idx: usize = b_col * n * b_size; + let a_offset: usize = a_size * 8; + let b_offset: usize = b_size * 8; + for blk_i in 0..m / 4 { + BE::reim4_convolution(tmp, min_size, offset, &a_raw[a_idx..], a_size, &b_raw[b_idx..], b_size); + BE::reim4_save_1blk_contiguous(m, min_size, blk_i, dst, tmp); + a_idx += a_offset; + b_idx += b_offset; + } + + for j in min_size..res_size { + res.zero_at(res_col, j); + } +} + +pub fn convolution_pairwise_apply_dft_tmp_bytes(res_size: usize, a_size: usize, b_size: usize) 
-> usize { + convolution_apply_dft_tmp_bytes(res_size, a_size, b_size) + (a_size + b_size) * size_of::() * 8 +} + +#[allow(clippy::too_many_arguments)] +pub fn convolution_pairwise_apply_dft( + res: &mut R, + res_offset: usize, + res_col: usize, + a: &A, + b: &B, + col_i: usize, + col_j: usize, + tmp: &mut [f64], +) where + BE: Backend + + ReimZero + + ReimAdd + + ReimCopy + + Reim4Save1BlkContiguous + + Reim4Convolution1Coeff + + Reim4Convolution2Coeffs, + R: VecZnxDftToMut, + A: CnvPVecLToRef, + B: CnvPVecRToRef, +{ + if col_i == col_j { + convolution_apply_dft(res, res_offset, res_col, a, col_i, b, col_j, tmp); + return; + } + + let res: &mut VecZnxDft<&mut [u8], BE> = &mut res.to_mut(); + let a: &CnvPVecL<&[u8], BE> = &a.to_ref(); + let b: &CnvPVecR<&[u8], BE> = &b.to_ref(); + + let n: usize = res.n(); + let m: usize = n >> 1; + + assert_eq!(a.n(), n); + assert_eq!(b.n(), n); + + let res_size: usize = res.size(); + let a_size: usize = a.size(); + let b_size: usize = b.size(); + + assert_eq!( + tmp.len(), + convolution_pairwise_apply_dft_tmp_bytes(res_size, a_size, b_size) / size_of::() + ); + + let bound: usize = a_size + b_size - 1; + let min_size: usize = res_size.min(bound); + let offset: usize = res_offset.min(bound); + + let res_raw: &mut [f64] = res.raw_mut(); + let a_raw: &[f64] = a.raw(); + let b_raw: &[f64] = b.raw(); + + let a_row_size: usize = a_size * 8; + let b_row_size: usize = b_size * 8; + + let mut a0_idx: usize = col_i * n * a_size; + let mut a1_idx: usize = col_j * n * a_size; + let mut b0_idx: usize = col_i * n * b_size; + let mut b1_idx: usize = col_j * n * b_size; + + let (tmp_a, tmp) = tmp.split_at_mut(a_row_size); + let (tmp_b, tmp_res) = tmp.split_at_mut(b_row_size); + + for blk_i in 0..m / 4 { + let a0: &[f64] = &a_raw[a0_idx..]; + let a1: &[f64] = &a_raw[a1_idx..]; + let b0: &[f64] = &b_raw[b0_idx..]; + let b1: &[f64] = &b_raw[b1_idx..]; + + BE::reim_add(tmp_a, &a0[..a_row_size], &a1[..a_row_size]); + BE::reim_add(tmp_b, &b0[..b_row_size], &b1[..b_row_size]); + + BE::reim4_convolution(tmp_res, min_size, offset, tmp_a, a_size, tmp_b, b_size); + BE::reim4_save_1blk_contiguous(m, min_size, blk_i, res_raw, tmp_res); + + a0_idx += a_row_size; + a1_idx += a_row_size; + b0_idx += b_row_size; + b1_idx += b_row_size; + } + + for j in min_size..res_size { + res.zero_at(res_col, j); + } +} + +pub trait I64Extract1BlkContiguous { + fn i64_extract_1blk_contiguous(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]); +} + +pub trait I64Save1BlkContiguous { + fn i64_save_1blk_contiguous(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]); +} + +#[inline(always)] +pub fn i64_extract_1blk_contiguous_ref(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) { + debug_assert!(blk < (n >> 3)); + debug_assert!(dst.len() >= rows * 8, "dst.len(): {} < rows*8: {}", dst.len(), 8 * rows); + + let offset: usize = offset + (blk << 3); + + // src = 8-values chunks spaced by n, dst = sequential 8-values chunks + let src_rows = src.chunks_exact(n).take(rows); + let dst_chunks = dst.chunks_exact_mut(8).take(rows); + + for (dst_chunk, src_row) in dst_chunks.zip(src_rows) { + dst_chunk.copy_from_slice(&src_row[offset..offset + 8]); + } +} + +#[inline(always)] +pub fn i64_save_1blk_contiguous_ref(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) { + debug_assert!(blk < (n >> 3)); + debug_assert!(src.len() >= rows * 8); + + let offset: usize = offset + (blk << 3); + + // dst = 4-values 
chunks spaced by m, src = sequential 4-values chunks + let dst_rows = dst.chunks_exact_mut(n).take(rows); + let src_chunks = src.chunks_exact(8).take(rows); + + for (dst_row, src_chunk) in dst_rows.zip(src_chunks) { + dst_row[offset..offset + 8].copy_from_slice(src_chunk); + } +} + +pub trait I64ConvolutionByConst1Coeff { + fn i64_convolution_by_const_1coeff(k: usize, dst: &mut [i64; 8], a: &[i64], a_size: usize, b: &[i64]); +} + +#[inline(always)] +pub fn i64_convolution_by_const_1coeff_ref(k: usize, dst: &mut [i64; 8], a: &[i64], a_size: usize, b: &[i64]) { + dst.fill(0); + + let b_size: usize = b.len(); + + if k >= a_size + b_size { + return; + } + let j_min: usize = k.saturating_sub(a_size - 1); + let j_max: usize = (k + 1).min(b_size); + + for j in j_min..j_max { + let ai: &[i64] = &a[8 * (k - j)..]; + let bi: i64 = b[j]; + + dst[0] = dst[0].wrapping_add(ai[0].wrapping_mul(bi)); + dst[1] = dst[1].wrapping_add(ai[1].wrapping_mul(bi)); + dst[2] = dst[2].wrapping_add(ai[2].wrapping_mul(bi)); + dst[3] = dst[3].wrapping_add(ai[3].wrapping_mul(bi)); + dst[4] = dst[4].wrapping_add(ai[4].wrapping_mul(bi)); + dst[5] = dst[5].wrapping_add(ai[5].wrapping_mul(bi)); + dst[6] = dst[6].wrapping_add(ai[6].wrapping_mul(bi)); + dst[7] = dst[7].wrapping_add(ai[7].wrapping_mul(bi)); + } +} + +#[inline(always)] +pub(crate) fn as_arr_i64(x: &[i64]) -> &[i64; size] { + debug_assert!(x.len() >= size, "x.len():{} < size:{}", x.len(), size); + unsafe { &*(x.as_ptr() as *const [i64; size]) } +} + +#[inline(always)] +pub(crate) fn as_arr_i64_mut(x: &mut [i64]) -> &mut [i64; size] { + debug_assert!(x.len() >= size, "x.len():{} < size:{}", x.len(), size); + unsafe { &mut *(x.as_mut_ptr() as *mut [i64; size]) } +} + +pub trait I64ConvolutionByConst2Coeffs { + fn i64_convolution_by_const_2coeffs(k: usize, dst: &mut [i64; 16], a: &[i64], a_size: usize, b: &[i64]); +} + +#[inline(always)] +pub fn i64_convolution_by_const_2coeffs_ref(k: usize, dst: &mut [i64; 16], a: &[i64], a_size: usize, b: &[i64]) { + i64_convolution_by_const_1coeff_ref(k, as_arr_i64_mut(&mut dst[..8]), a, a_size, b); + i64_convolution_by_const_1coeff_ref(k + 1, as_arr_i64_mut(&mut dst[8..]), a, a_size, b); +} + +impl I64ConvolutionByConst for BE where Self: I64ConvolutionByConst1Coeff + I64ConvolutionByConst2Coeffs {} + +pub trait I64ConvolutionByConst +where + BE: I64ConvolutionByConst1Coeff + I64ConvolutionByConst2Coeffs, +{ + fn i64_convolution_by_const(dst: &mut [i64], dst_size: usize, offset: usize, a: &[i64], a_size: usize, b: &[i64]) { + assert!(a_size > 0); + + for k in (0..dst_size - 1).step_by(2) { + BE::i64_convolution_by_const_2coeffs(k + offset, as_arr_i64_mut(&mut dst[8 * k..]), a, a_size, b); + } + + if !dst_size.is_multiple_of(2) { + let k: usize = dst_size - 1; + BE::i64_convolution_by_const_1coeff(k + offset, as_arr_i64_mut(&mut dst[8 * k..]), a, a_size, b); + } + } +} diff --git a/poulpy-hal/src/reference/fft64/mod.rs b/poulpy-hal/src/reference/fft64/mod.rs index a1cf49a..9de3d97 100644 --- a/poulpy-hal/src/reference/fft64/mod.rs +++ b/poulpy-hal/src/reference/fft64/mod.rs @@ -1,3 +1,4 @@ +pub mod convolution; pub mod reim; pub mod reim4; pub mod svp; diff --git a/poulpy-hal/src/reference/fft64/reim/fft_ref.rs b/poulpy-hal/src/reference/fft64/reim/fft_ref.rs index 849a58e..e663266 100644 --- a/poulpy-hal/src/reference/fft64/reim/fft_ref.rs +++ b/poulpy-hal/src/reference/fft64/reim/fft_ref.rs @@ -12,26 +12,10 @@ pub fn fft_ref(m: usize, omg: &[R], data: &mut [R if m <= 16 { match m { 1 => {} - 2 => fft2_ref( - as_arr_mut::<2, 
R>(re), - as_arr_mut::<2, R>(im), - *as_arr::<2, R>(omg), - ), - 4 => fft4_ref( - as_arr_mut::<4, R>(re), - as_arr_mut::<4, R>(im), - *as_arr::<4, R>(omg), - ), - 8 => fft8_ref( - as_arr_mut::<8, R>(re), - as_arr_mut::<8, R>(im), - *as_arr::<8, R>(omg), - ), - 16 => fft16_ref( - as_arr_mut::<16, R>(re), - as_arr_mut::<16, R>(im), - *as_arr::<16, R>(omg), - ), + 2 => fft2_ref(as_arr_mut::<2, R>(re), as_arr_mut::<2, R>(im), *as_arr::<2, R>(omg)), + 4 => fft4_ref(as_arr_mut::<4, R>(re), as_arr_mut::<4, R>(im), *as_arr::<4, R>(omg)), + 8 => fft8_ref(as_arr_mut::<8, R>(re), as_arr_mut::<8, R>(im), *as_arr::<8, R>(omg)), + 16 => fft16_ref(as_arr_mut::<16, R>(re), as_arr_mut::<16, R>(im), *as_arr::<16, R>(omg)), _ => {} } } else if m <= 2048 { @@ -257,12 +241,7 @@ fn fft_bfs_16_ref(m: usize, re: &mut [R], im: &mu while mm > 16 { let h: usize = mm >> 2; for off in (0..m).step_by(mm) { - bitwiddle_fft_ref( - h, - &mut re[off..], - &mut im[off..], - as_arr::<4, R>(&omg[pos..]), - ); + bitwiddle_fft_ref(h, &mut re[off..], &mut im[off..], as_arr::<4, R>(&omg[pos..])); pos += 4; } mm = h @@ -289,14 +268,7 @@ fn twiddle_fft_ref(h: usize, re: &mut [R], im: &mut [R], let (im_lhs, im_rhs) = im.split_at_mut(h); for i in 0..h { - cplx_twiddle( - &mut re_lhs[i], - &mut im_lhs[i], - &mut re_rhs[i], - &mut im_rhs[i], - romg, - iomg, - ); + cplx_twiddle(&mut re_lhs[i], &mut im_lhs[i], &mut re_rhs[i], &mut im_rhs[i], romg, iomg); } } diff --git a/poulpy-hal/src/reference/fft64/reim/ifft_ref.rs b/poulpy-hal/src/reference/fft64/reim/ifft_ref.rs index e0fe8a2..0253cef 100644 --- a/poulpy-hal/src/reference/fft64/reim/ifft_ref.rs +++ b/poulpy-hal/src/reference/fft64/reim/ifft_ref.rs @@ -11,26 +11,10 @@ pub fn ifft_ref(m: usize, omg: &[R], data: &mut [ if m <= 16 { match m { 1 => {} - 2 => ifft2_ref( - as_arr_mut::<2, R>(re), - as_arr_mut::<2, R>(im), - *as_arr::<2, R>(omg), - ), - 4 => ifft4_ref( - as_arr_mut::<4, R>(re), - as_arr_mut::<4, R>(im), - *as_arr::<4, R>(omg), - ), - 8 => ifft8_ref( - as_arr_mut::<8, R>(re), - as_arr_mut::<8, R>(im), - *as_arr::<8, R>(omg), - ), - 16 => ifft16_ref( - as_arr_mut::<16, R>(re), - as_arr_mut::<16, R>(im), - *as_arr::<16, R>(omg), - ), + 2 => ifft2_ref(as_arr_mut::<2, R>(re), as_arr_mut::<2, R>(im), *as_arr::<2, R>(omg)), + 4 => ifft4_ref(as_arr_mut::<4, R>(re), as_arr_mut::<4, R>(im), *as_arr::<4, R>(omg)), + 8 => ifft8_ref(as_arr_mut::<8, R>(re), as_arr_mut::<8, R>(im), *as_arr::<8, R>(omg)), + 16 => ifft16_ref(as_arr_mut::<16, R>(re), as_arr_mut::<16, R>(im), *as_arr::<16, R>(omg)), _ => {} } } else if m <= 2048 { @@ -72,12 +56,7 @@ fn ifft_bfs_16_ref(m: usize, re: &mut [R], im: &mut [R], while h < m_half { let mm: usize = h << 2; for off in (0..m).step_by(mm) { - inv_bitwiddle_ifft_ref( - h, - &mut re[off..], - &mut im[off..], - as_arr::<4, R>(&omg[pos..]), - ); + inv_bitwiddle_ifft_ref(h, &mut re[off..], &mut im[off..], as_arr::<4, R>(&omg[pos..])); pos += 4; } h = mm; @@ -284,14 +263,7 @@ fn inv_twiddle_ifft_ref(h: usize, re: &mut [R], im: &mut let (im_lhs, im_rhs) = im.split_at_mut(h); for i in 0..h { - inv_twiddle( - &mut re_lhs[i], - &mut im_lhs[i], - &mut re_rhs[i], - &mut im_rhs[i], - romg, - iomg, - ); + inv_twiddle(&mut re_lhs[i], &mut im_lhs[i], &mut re_rhs[i], &mut im_rhs[i], romg, iomg); } } diff --git a/poulpy-hal/src/reference/fft64/reim/mod.rs b/poulpy-hal/src/reference/fft64/reim/mod.rs index 7decf3a..416cb63 100644 --- a/poulpy-hal/src/reference/fft64/reim/mod.rs +++ b/poulpy-hal/src/reference/fft64/reim/mod.rs @@ -35,7 +35,7 @@ pub use zero::*; 
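Editor's note: `as_arr`/`as_arr_mut` (next hunk) reinterpret a slice prefix as a fixed-size array reference without copying; the pointer cast is sound whenever the debug-asserted length bound holds, since alignment and layout match the element type. A standalone sketch of the pattern, including a safe alternative via `slice::first_chunk` (stable since Rust 1.77):

```rust
// Reinterpret the first SIZE elements of `x` as a &[T; SIZE] (no copy).
// Sound only if x.len() >= SIZE, which the debug_assert checks.
#[inline(always)]
fn as_arr<const SIZE: usize, T>(x: &[T]) -> &[T; SIZE] {
    debug_assert!(x.len() >= SIZE, "x.len():{} < SIZE:{}", x.len(), SIZE);
    unsafe { &*(x.as_ptr() as *const [T; SIZE]) }
}

fn main() {
    let data = [1.0f64, 2.0, 3.0, 4.0, 5.0];
    let head: &[f64; 4] = as_arr::<4, f64>(&data);
    // Safe alternative on the same slice, at the cost of an Option:
    let safe_head: &[f64; 4] = data.first_chunk::<4>().expect("slice too short");
    assert_eq!(head, safe_head);
}
```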
#[inline(always)] pub(crate) fn as_arr(x: &[R]) -> &[R; size] { - debug_assert!(x.len() >= size); + debug_assert!(x.len() >= size, "x.len():{} < size:{}", x.len(), size); unsafe { &*(x.as_ptr() as *const [R; size]) } } diff --git a/poulpy-hal/src/reference/fft64/reim4/arithmetic_ref.rs b/poulpy-hal/src/reference/fft64/reim4/arithmetic_ref.rs index 24e8665..5ba09e0 100644 --- a/poulpy-hal/src/reference/fft64/reim4/arithmetic_ref.rs +++ b/poulpy-hal/src/reference/fft64/reim4/arithmetic_ref.rs @@ -1,15 +1,34 @@ -use crate::reference::fft64::reim::as_arr; +use crate::reference::fft64::reim::{as_arr, as_arr_mut, reim_zero_ref}; #[inline(always)] -pub fn reim4_extract_1blk_from_reim_ref(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) { - let mut offset: usize = blk << 2; - +pub fn reim4_extract_1blk_from_reim_contiguous_ref(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) { debug_assert!(blk < (m >> 2)); debug_assert!(dst.len() >= 2 * rows * 4); - for chunk in dst.chunks_exact_mut(4).take(2 * rows) { - chunk.copy_from_slice(&src[offset..offset + 4]); - offset += m + let offset: usize = blk << 2; + + // src = 4-values chunks spaced by m, dst = sequential 4-values chunks + let src_rows = src.chunks_exact(m).take(2 * rows); + let dst_chunks = dst.chunks_exact_mut(4).take(2 * rows); + + for (dst_chunk, src_row) in dst_chunks.zip(src_rows) { + dst_chunk.copy_from_slice(&src_row[offset..offset + 4]); + } +} + +#[inline(always)] +pub fn reim4_save_1blk_to_reim_contiguous_ref(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) { + debug_assert!(blk < (m >> 2)); + debug_assert!(src.len() >= 2 * rows * 4); + + let offset: usize = blk << 2; + + // dst = 4-values chunks spaced by m, src = sequential 4-values chunks + let dst_rows = dst.chunks_exact_mut(m).take(2 * rows); + let src_chunks = src.chunks_exact(4).take(2 * rows); + + for (dst_row, src_chunk) in dst_rows.zip(src_chunks) { + dst_row[offset..offset + 4].copy_from_slice(src_chunk); } } @@ -53,7 +72,7 @@ pub fn reim4_save_2blk_to_reim_ref(m: usize, blk: usize, debug_assert!(dst.len() >= offset + 3 * m + 4); debug_assert!(src.len() >= 16); - let dst_off = &mut dst[offset..offset + 4]; + let dst_off: &mut [f64] = &mut dst[offset..offset + 4]; if OVERWRITE { dst_off.copy_from_slice(&src[0..4]); } else { @@ -64,7 +83,7 @@ pub fn reim4_save_2blk_to_reim_ref(m: usize, blk: usize, } offset += m; - let dst_off = &mut dst[offset..offset + 4]; + let dst_off: &mut [f64] = &mut dst[offset..offset + 4]; if OVERWRITE { dst_off.copy_from_slice(&src[4..8]); } else { @@ -76,7 +95,7 @@ pub fn reim4_save_2blk_to_reim_ref(m: usize, blk: usize, offset += m; - let dst_off = &mut dst[offset..offset + 4]; + let dst_off: &mut [f64] = &mut dst[offset..offset + 4]; if OVERWRITE { dst_off.copy_from_slice(&src[8..12]); } else { @@ -87,7 +106,7 @@ pub fn reim4_save_2blk_to_reim_ref(m: usize, blk: usize, } offset += m; - let dst_off = &mut dst[offset..offset + 4]; + let dst_off: &mut [f64] = &mut dst[offset..offset + 4]; if OVERWRITE { dst_off.copy_from_slice(&src[12..16]); } else { @@ -132,10 +151,7 @@ pub fn reim4_vec_mat2cols_product_ref( { assert_eq!(dst.len(), 16, "dst must have 16 doubles"); assert!(u.len() >= nrows * 8, "u must be at least nrows * 8 doubles"); - assert!( - v.len() >= nrows * 16, - "v must be at least nrows * 16 doubles" - ); + assert!(v.len() >= nrows * 16, "v must be at least nrows * 16 doubles"); } // zero accumulators @@ -161,11 +177,7 @@ pub fn reim4_vec_mat2cols_2ndcol_product_ref( ) { 
#[cfg(debug_assertions)] { - assert!( - dst.len() >= 8, - "dst must be at least 8 doubles but is {}", - dst.len() - ); + assert!(dst.len() >= 8, "dst must be at least 8 doubles but is {}", dst.len()); assert!( u.len() >= nrows * 8, "u must be at least nrows={} * 8 doubles but is {}", @@ -201,3 +213,57 @@ pub fn reim4_add_mul(dst: &mut [f64; 8], a: &[f64; 8], b: &[f64; 8]) { dst[k + 4] += ar * bi + ai * br; } } + +#[inline(always)] +pub fn reim4_convolution_1coeff_ref(k: usize, dst: &mut [f64; 8], a: &[f64], a_size: usize, b: &[f64], b_size: usize) { + reim_zero_ref(dst); + + if k >= a_size + b_size { + return; + } + let j_min: usize = k.saturating_sub(a_size - 1); + let j_max: usize = (k + 1).min(b_size); + + for j in j_min..j_max { + reim4_add_mul(dst, as_arr(&a[8 * (k - j)..]), as_arr(&b[8 * j..])); + } +} + +#[inline(always)] +pub fn reim4_convolution_2coeffs_ref(k: usize, dst: &mut [f64; 16], a: &[f64], a_size: usize, b: &[f64], b_size: usize) { + reim4_convolution_1coeff_ref(k, as_arr_mut(dst), a, a_size, b, b_size); + reim4_convolution_1coeff_ref(k + 1, as_arr_mut(&mut dst[8..]), a, a_size, b, b_size); +} + +#[inline(always)] +pub fn reim4_add_mul_b_real_const(dst: &mut [f64; 8], a: &[f64; 8], b: f64) { + for k in 0..4 { + let ar: f64 = a[k]; + let ai: f64 = a[k + 4]; + dst[k] += ar * b; + dst[k + 4] += ai * b; + } +} + +#[inline(always)] +pub fn reim4_convolution_by_real_const_1coeff_ref(k: usize, dst: &mut [f64; 8], a: &[f64], a_size: usize, b: &[f64]) { + reim_zero_ref(dst); + + let b_size: usize = b.len(); + + if k >= a_size + b_size { + return; + } + let j_min: usize = k.saturating_sub(a_size - 1); + let j_max: usize = (k + 1).min(b_size); + + for j in j_min..j_max { + reim4_add_mul_b_real_const(dst, as_arr(&a[8 * (k - j)..]), b[j]); + } +} + +#[inline(always)] +pub fn reim4_convolution_by_real_const_2coeffs_ref(k: usize, dst: &mut [f64; 16], a: &[f64], a_size: usize, b: &[f64]) { + reim4_convolution_by_real_const_1coeff_ref(k, as_arr_mut(dst), a, a_size, b); + reim4_convolution_by_real_const_1coeff_ref(k + 1, as_arr_mut(&mut dst[8..]), a, a_size, b); +} diff --git a/poulpy-hal/src/reference/fft64/reim4/mod.rs b/poulpy-hal/src/reference/fft64/reim4/mod.rs index 04bcf9c..2a7596b 100644 --- a/poulpy-hal/src/reference/fft64/reim4/mod.rs +++ b/poulpy-hal/src/reference/fft64/reim4/mod.rs @@ -2,8 +2,14 @@ mod arithmetic_ref; pub use arithmetic_ref::*; -pub trait Reim4Extract1Blk { - fn reim4_extract_1blk(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]); +use crate::{layouts::Backend, reference::fft64::reim::as_arr_mut}; + +pub trait Reim4Extract1BlkContiguous { + fn reim4_extract_1blk_contiguous(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]); +} + +pub trait Reim4Save1BlkContiguous { + fn reim4_save_1blk_contiguous(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]); } pub trait Reim4Save1Blk { @@ -25,3 +31,63 @@ pub trait Reim4Mat2ColsProd { pub trait Reim4Mat2Cols2ndColProd { fn reim4_mat2cols_2ndcol_prod(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]); } + +pub trait Reim4Convolution1Coeff { + fn reim4_convolution_1coeff(k: usize, dst: &mut [f64; 8], a: &[f64], a_size: usize, b: &[f64], b_size: usize); +} + +pub trait Reim4Convolution2Coeffs { + fn reim4_convolution_2coeffs(k: usize, dst: &mut [f64; 16], a: &[f64], a_size: usize, b: &[f64], b_size: usize); +} + +pub trait Reim4ConvolutionByRealConst1Coeff { + fn reim4_convolution_by_real_const_1coeff(k: usize, dst: &mut [f64; 8], a: &[f64], a_size: usize, b: &[f64]); +} + +pub 
trait Reim4ConvolutionByRealConst2Coeffs { + fn reim4_convolution_by_real_const_2coeffs(k: usize, dst: &mut [f64; 16], a: &[f64], a_size: usize, b: &[f64]); +} + +impl Reim4Convolution for BE where Self: Reim4Convolution1Coeff + Reim4Convolution2Coeffs {} + +pub trait Reim4Convolution +where + BE: Reim4Convolution1Coeff + Reim4Convolution2Coeffs, +{ + fn reim4_convolution(dst: &mut [f64], dst_size: usize, offset: usize, a: &[f64], a_size: usize, b: &[f64], b_size: usize) { + assert!(a_size > 0); + assert!(b_size > 0); + + for k in (0..dst_size - 1).step_by(2) { + BE::reim4_convolution_2coeffs(k + offset, as_arr_mut(&mut dst[8 * k..]), a, a_size, b, b_size); + } + + if !dst_size.is_multiple_of(2) { + let k: usize = dst_size - 1; + BE::reim4_convolution_1coeff(k + offset, as_arr_mut(&mut dst[8 * k..]), a, a_size, b, b_size); + } + } +} + +impl Reim4ConvolutionByRealConst for BE where + Self: Reim4ConvolutionByRealConst1Coeff + Reim4ConvolutionByRealConst2Coeffs +{ +} + +pub trait Reim4ConvolutionByRealConst +where + BE: Reim4ConvolutionByRealConst1Coeff + Reim4ConvolutionByRealConst2Coeffs, +{ + fn reim4_convolution_by_real_const(dst: &mut [f64], dst_size: usize, offset: usize, a: &[f64], a_size: usize, b: &[f64]) { + assert!(a_size > 0); + + for k in (0..dst_size - 1).step_by(2) { + BE::reim4_convolution_by_real_const_2coeffs(k + offset, as_arr_mut(&mut dst[8 * k..]), a, a_size, b); + } + + if !dst_size.is_multiple_of(2) { + let k: usize = dst_size - 1; + BE::reim4_convolution_by_real_const_1coeff(k + offset, as_arr_mut(&mut dst[8 * k..]), a, a_size, b); + } + } +} diff --git a/poulpy-hal/src/reference/fft64/vec_znx_big.rs b/poulpy-hal/src/reference/fft64/vec_znx_big.rs index 64a643e..1a894f7 100644 --- a/poulpy-hal/src/reference/fft64/vec_znx_big.rs +++ b/poulpy-hal/src/reference/fft64/vec_znx_big.rs @@ -9,13 +9,14 @@ use crate::{ reference::{ vec_znx::{ vec_znx_add, vec_znx_add_inplace, vec_znx_automorphism, vec_znx_automorphism_inplace, vec_znx_negate, - vec_znx_negate_inplace, vec_znx_normalize, vec_znx_sub, vec_znx_sub_inplace, vec_znx_sub_negate_inplace, + vec_znx_negate_inplace, vec_znx_normalize, vec_znx_normalize_tmp_bytes, vec_znx_sub, vec_znx_sub_inplace, + vec_znx_sub_negate_inplace, }, znx::{ ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxExtractDigitAddMul, ZnxMulPowerOfTwoInplace, ZnxNegate, - ZnxNegateInplace, ZnxNormalizeDigit, ZnxNormalizeFinalStep, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, - ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxSub, ZnxSubInplace, ZnxSubNegateInplace, ZnxZero, - znx_add_normal_f64_ref, + ZnxNegateInplace, ZnxNormalizeDigit, ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, + ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, + ZnxNormalizeMiddleStepInplace, ZnxSub, ZnxSubInplace, ZnxSubNegateInplace, ZnxZero, znx_add_normal_f64_ref, }, }, source::Source, @@ -231,15 +232,17 @@ where } pub fn vec_znx_big_normalize_tmp_bytes(n: usize) -> usize { - 2 * n * size_of::() + vec_znx_normalize_tmp_bytes(n) } +#[allow(clippy::too_many_arguments)] pub fn vec_znx_big_normalize( - res_base2k: usize, res: &mut R, + res_base2k: usize, + res_offset: i64, res_col: usize, - a_base2k: usize, a: &A, + a_base2k: usize, a_col: usize, carry: &mut [i64], ) where @@ -256,7 +259,9 @@ pub fn vec_znx_big_normalize( + ZnxNormalizeFinalStep + ZnxNormalizeFirstStep + ZnxExtractDigitAddMul - + ZnxNormalizeDigit, + + ZnxNormalizeDigit + + ZnxNormalizeMiddleStepInplace + + 
ZnxNormalizeFinalStepInplace, { let a: VecZnxBig<&[u8], _> = a.to_ref(); let a_vznx: VecZnx<&[u8]> = VecZnx { @@ -267,7 +272,7 @@ pub fn vec_znx_big_normalize( max_size: a.max_size, }; - vec_znx_normalize::<_, _, BE>(res_base2k, res, res_col, a_base2k, &a_vznx, a_col, carry); + vec_znx_normalize::<_, _, BE>(res, res_base2k, res_offset, res_col, &a_vznx, a_base2k, a_col, carry); } pub fn vec_znx_big_add_normal_ref>( @@ -290,18 +295,13 @@ pub fn vec_znx_big_add_normal_ref>( let limb: usize = k.div_ceil(base2k) - 1; let scale: f64 = (1 << ((limb + 1) * base2k - k)) as f64; - znx_add_normal_f64_ref( - res.at_mut(res_col, limb), - sigma * scale, - bound * scale, - source, - ) + znx_add_normal_f64_ref(res.at_mut(res_col, limb), sigma * scale, bound * scale, source) } pub fn test_vec_znx_big_add_normal(module: &Module) where Module: VecZnxBigAddNormal, - B: Backend + VecZnxBigAllocBytesImpl, + B: Backend + VecZnxBigAllocBytesImpl, { let n: usize = module.n(); let base2k: usize = 17; @@ -325,12 +325,7 @@ where }) } else { let std: f64 = a.stats(base2k, col_i).std() * k_f64; - assert!( - (std - sigma * sqrt2).abs() < 0.1, - "std={} ~!= {}", - std, - sigma * sqrt2 - ); + assert!((std - sigma * sqrt2).abs() < 0.1, "std={} ~!= {}", std, sigma * sqrt2); } }) }); diff --git a/poulpy-hal/src/reference/fft64/vmp.rs b/poulpy-hal/src/reference/fft64/vmp.rs index ac401b3..d5a781a 100644 --- a/poulpy-hal/src/reference/fft64/vmp.rs +++ b/poulpy-hal/src/reference/fft64/vmp.rs @@ -4,7 +4,10 @@ use crate::{ oep::VecZnxDftAllocBytesImpl, reference::fft64::{ reim::{ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimZero}, - reim4::{Reim4Extract1Blk, Reim4Mat1ColProd, Reim4Mat2Cols2ndColProd, Reim4Mat2ColsProd, Reim4Save1Blk, Reim4Save2Blks}, + reim4::{ + Reim4Extract1BlkContiguous, Reim4Mat1ColProd, Reim4Mat2Cols2ndColProd, Reim4Mat2ColsProd, Reim4Save1Blk, + Reim4Save2Blks, + }, vec_znx_dft::vec_znx_dft_apply, }, }; @@ -17,7 +20,7 @@ pub fn vmp_prepare_tmp_bytes(n: usize) -> usize { pub fn vmp_prepare(table: &ReimFFTTable, pmat: &mut R, mat: &A, tmp: &mut [f64]) where - BE: Backend + ReimDFTExecute, f64> + ReimFromZnx + Reim4Extract1Blk, + BE: Backend + ReimDFTExecute, f64> + ReimFromZnx + Reim4Extract1BlkContiguous, R: VmpPMatToMut, A: MatZnxToRef, { @@ -34,13 +37,7 @@ where res.cols_in(), a.cols_in() ); - assert_eq!( - res.rows(), - a.rows(), - "res.rows: {} != a.rows: {}", - res.rows(), - a.rows() - ); + assert_eq!(res.rows(), a.rows(), "res.rows: {} != a.rows: {}", res.rows(), a.rows()); assert_eq!( res.cols_out(), a.cols_out(), @@ -48,13 +45,7 @@ where res.cols_out(), a.cols_out() ); - assert_eq!( - res.size(), - a.size(), - "res.size: {} != a.size: {}", - res.size(), - a.size() - ); + assert_eq!(res.size(), a.size(), "res.size: {} != a.size: {}", res.size(), a.size()); } let nrows: usize = a.cols_in() * a.rows(); @@ -70,7 +61,7 @@ pub(crate) fn vmp_prepare_core( ncols: usize, tmp: &mut [f64], ) where - REIM: ReimDFTExecute, f64> + ReimFromZnx + Reim4Extract1Blk, + REIM: ReimDFTExecute, f64> + ReimFromZnx + Reim4Extract1BlkContiguous, { let m: usize = table.m(); let n: usize = m << 1; @@ -99,7 +90,7 @@ pub(crate) fn vmp_prepare_core( }; for blk_i in 0..m >> 2 { - REIM::reim4_extract_1blk(m, 1, blk_i, &mut dst[blk_i * offset..], tmp); + REIM::reim4_extract_1blk_contiguous(m, 1, blk_i, &mut dst[blk_i * offset..], tmp); } } } @@ -116,7 +107,7 @@ where + VecZnxDftAllocBytesImpl + ReimDFTExecute, f64> + ReimZero - + Reim4Extract1Blk + + Reim4Extract1BlkContiguous + Reim4Mat1ColProd + Reim4Mat2Cols2ndColProd + 
Reim4Mat2ColsProd @@ -168,7 +159,7 @@ pub fn vmp_apply_dft_to_dft(res: &mut R, a: &A, pmat: &M, tmp_bytes where BE: Backend + ReimZero - + Reim4Extract1Blk + + Reim4Extract1BlkContiguous + Reim4Mat1ColProd + Reim4Mat2Cols2ndColProd + Reim4Mat2ColsProd @@ -207,7 +198,7 @@ pub fn vmp_apply_dft_to_dft_add(res: &mut R, a: &A, pmat: &M, limb_ where BE: Backend + ReimZero - + Reim4Extract1Blk + + Reim4Extract1BlkContiguous + Reim4Mat1ColProd + Reim4Mat2Cols2ndColProd + Reim4Mat2ColsProd @@ -239,16 +230,7 @@ where let a_raw: &[f64] = a.raw(); let res_raw: &mut [f64] = res.raw_mut(); - vmp_apply_dft_to_dft_core::( - n, - res_raw, - a_raw, - pmat_raw, - limb_offset, - nrows, - ncols, - tmp_bytes, - ) + vmp_apply_dft_to_dft_core::(n, res_raw, a_raw, pmat_raw, limb_offset, nrows, ncols, tmp_bytes) } #[allow(clippy::too_many_arguments)] @@ -263,7 +245,7 @@ fn vmp_apply_dft_to_dft_core( tmp_bytes: &mut [f64], ) where REIM: ReimZero - + Reim4Extract1Blk + + Reim4Extract1BlkContiguous + Reim4Mat1ColProd + Reim4Mat2Cols2ndColProd + Reim4Mat2ColsProd @@ -299,41 +281,23 @@ fn vmp_apply_dft_to_dft_core( for blk_i in 0..(m >> 2) { let mat_blk_start: &[f64] = &pmat[blk_i * (8 * nrows * ncols)..]; - REIM::reim4_extract_1blk(m, row_max, blk_i, extracted_blk, a); + REIM::reim4_extract_1blk_contiguous(m, row_max, blk_i, extracted_blk, a); if limb_offset.is_multiple_of(2) { for (col_res, col_pmat) in (0..).step_by(2).zip((limb_offset..col_max - 1).step_by(2)) { let col_offset: usize = col_pmat * (8 * nrows); - REIM::reim4_mat2cols_prod( - row_max, - mat2cols_output, - extracted_blk, - &mat_blk_start[col_offset..], - ); + REIM::reim4_mat2cols_prod(row_max, mat2cols_output, extracted_blk, &mat_blk_start[col_offset..]); REIM::reim4_save_2blks::(m, blk_i, &mut res[col_res * n..], mat2cols_output); } } else { let col_offset: usize = (limb_offset - 1) * (8 * nrows); - REIM::reim4_mat2cols_2ndcol_prod( - row_max, - mat2cols_output, - extracted_blk, - &mat_blk_start[col_offset..], - ); + REIM::reim4_mat2cols_2ndcol_prod(row_max, mat2cols_output, extracted_blk, &mat_blk_start[col_offset..]); REIM::reim4_save_1blk::(m, blk_i, res, mat2cols_output); - for (col_res, col_pmat) in (1..) 
- .step_by(2) - .zip((limb_offset + 1..col_max - 1).step_by(2)) - { + for (col_res, col_pmat) in (1..).step_by(2).zip((limb_offset + 1..col_max - 1).step_by(2)) { let col_offset: usize = col_pmat * (8 * nrows); - REIM::reim4_mat2cols_prod( - row_max, - mat2cols_output, - extracted_blk, - &mat_blk_start[col_offset..], - ); + REIM::reim4_mat2cols_prod(row_max, mat2cols_output, extracted_blk, &mat_blk_start[col_offset..]); REIM::reim4_save_2blks::(m, blk_i, &mut res[col_res * n..], mat2cols_output); } } @@ -344,26 +308,11 @@ fn vmp_apply_dft_to_dft_core( if last_col >= limb_offset { if ncols == col_max { - REIM::reim4_mat1col_prod( - row_max, - mat2cols_output, - extracted_blk, - &mat_blk_start[col_offset..], - ); + REIM::reim4_mat1col_prod(row_max, mat2cols_output, extracted_blk, &mat_blk_start[col_offset..]); } else { - REIM::reim4_mat2cols_prod( - row_max, - mat2cols_output, - extracted_blk, - &mat_blk_start[col_offset..], - ); + REIM::reim4_mat2cols_prod(row_max, mat2cols_output, extracted_blk, &mat_blk_start[col_offset..]); } - REIM::reim4_save_1blk::( - m, - blk_i, - &mut res[(last_col - limb_offset) * n..], - mat2cols_output, - ); + REIM::reim4_save_1blk::(m, blk_i, &mut res[(last_col - limb_offset) * n..], mat2cols_output); } } } diff --git a/poulpy-hal/src/reference/vec_znx/convolution.rs b/poulpy-hal/src/reference/vec_znx/convolution.rs deleted file mode 100644 index e69de29..0000000 diff --git a/poulpy-hal/src/reference/vec_znx/merge_rings.rs b/poulpy-hal/src/reference/vec_znx/merge_rings.rs index 76dca32..22d7fea 100644 --- a/poulpy-hal/src/reference/vec_znx/merge_rings.rs +++ b/poulpy-hal/src/reference/vec_znx/merge_rings.rs @@ -24,10 +24,7 @@ where { assert_eq!(tmp.len(), res.n()); - debug_assert!( - _n_out > _n_in, - "invalid a: output ring degree should be greater" - ); + debug_assert!(_n_out > _n_in, "invalid a: output ring degree should be greater"); a[1..].iter().for_each(|ai| { debug_assert_eq!( ai.to_ref().n(), diff --git a/poulpy-hal/src/reference/vec_znx/normalize.rs b/poulpy-hal/src/reference/vec_znx/normalize.rs index 021f57c..bbb10b7 100644 --- a/poulpy-hal/src/reference/vec_znx/normalize.rs +++ b/poulpy-hal/src/reference/vec_znx/normalize.rs @@ -14,15 +14,17 @@ use crate::{ }; pub fn vec_znx_normalize_tmp_bytes(n: usize) -> usize { - 2 * n * size_of::() + 3 * n * size_of::() } +#[allow(clippy::too_many_arguments)] pub fn vec_znx_normalize( - res_base2k: usize, res: &mut R, + res_base2k: usize, + res_offset: i64, res_col: usize, - a_base2k: usize, a: &A, + a_base2k: usize, a_col: usize, carry: &mut [i64], ) where @@ -38,14 +40,40 @@ pub fn vec_znx_normalize( + ZnxNormalizeFinalStep + ZnxNormalizeFirstStep + ZnxExtractDigitAddMul + + ZnxNormalizeMiddleStepInplace + + ZnxNormalizeFinalStepInplace + ZnxNormalizeDigit, +{ + match res_base2k == a_base2k { + true => vec_znx_normalize_inter_base2k::(res_base2k, res, res_offset, res_col, a, a_col, carry), + false => vec_znx_normalize_cross_base2k::(res, res_base2k, res_offset, res_col, a, a_base2k, a_col, carry), + } +} + +fn vec_znx_normalize_inter_base2k( + base2k: usize, + res: &mut R, + res_offset: i64, + res_col: usize, + a: &A, + a_col: usize, + carry: &mut [i64], +) where + R: VecZnxToMut, + A: VecZnxToRef, + ZNXARI: ZnxZero + + ZnxNormalizeFirstStepCarryOnly + + ZnxNormalizeMiddleStepCarryOnly + + ZnxNormalizeMiddleStep + + ZnxNormalizeFinalStepInplace + + ZnxNormalizeMiddleStepInplace, { let mut res: VecZnx<&mut [u8]> = res.to_mut(); let a: VecZnx<&[u8]> = a.to_ref(); #[cfg(debug_assertions)] { - assert!(carry.len() 
>= 2 * res.n()); + assert!(carry.len() >= vec_znx_normalize_tmp_bytes(res.n()) / size_of::()); assert_eq!(res.n(), a.n()); } @@ -53,153 +81,323 @@ pub fn vec_znx_normalize( let res_size: usize = res.size(); let a_size: usize = a.size(); - let carry: &mut [i64] = &mut carry[..2 * n]; + let (carry, _) = carry.split_at_mut(n); - if res_base2k == a_base2k { - if a_size > res_size { - for j in (res_size..a_size).rev() { - if j == a_size - 1 { - ZNXARI::znx_normalize_first_step_carry_only(res_base2k, 0, a.at(a_col, j), carry); - } else { - ZNXARI::znx_normalize_middle_step_carry_only(res_base2k, 0, a.at(a_col, j), carry); - } - } + let mut lsh: i64 = res_offset % base2k as i64; + let mut limbs_offset: i64 = res_offset / base2k as i64; - for j in (1..res_size).rev() { - ZNXARI::znx_normalize_middle_step(res_base2k, 0, res.at_mut(res_col, j), a.at(a_col, j), carry); - } + // If res_offset is negative, makes it positive + // and corrects by adding an additional offset + // on the limbs. + if res_offset < 0 && lsh != 0 { + lsh = (lsh + base2k as i64) % (base2k as i64); + limbs_offset -= 1; + } - ZNXARI::znx_normalize_final_step(res_base2k, 0, res.at_mut(res_col, 0), a.at(a_col, 0), carry); + let lsh_pos: usize = lsh as usize; + + let res_end: usize = (-limbs_offset).clamp(0, res_size as i64) as usize; + let res_start: usize = (a_size as i64 - limbs_offset).clamp(0, res_size as i64) as usize; + let a_end: usize = limbs_offset.clamp(0, a_size as i64) as usize; + let a_start: usize = (res_size as i64 + limbs_offset).clamp(0, a_size as i64) as usize; + + let a_out_range: usize = a_size.saturating_sub(a_start); + + // Computes the carry over the discarded limbs of a + for j in 0..a_out_range { + if j == 0 { + ZNXARI::znx_normalize_first_step_carry_only(base2k, lsh_pos, a.at(a_col, a_size - j - 1), carry); } else { - for j in (0..a_size).rev() { - if j == a_size - 1 { - ZNXARI::znx_normalize_first_step(res_base2k, 0, res.at_mut(res_col, j), a.at(a_col, j), carry); - } else if j == 0 { - ZNXARI::znx_normalize_final_step(res_base2k, 0, res.at_mut(res_col, j), a.at(a_col, j), carry); - } else { - ZNXARI::znx_normalize_middle_step(res_base2k, 0, res.at_mut(res_col, j), a.at(a_col, j), carry); - } - } - - for j in a_size..res_size { - ZNXARI::znx_zero(res.at_mut(res_col, j)); - } + ZNXARI::znx_normalize_middle_step_carry_only(base2k, lsh_pos, a.at(a_col, a_size - j - 1), carry); } - } else { - let (a_norm, carry) = carry.split_at_mut(n); + } - // Relevant limbs of res - let res_min_size: usize = (a_size * a_base2k).div_ceil(res_base2k).min(res_size); + // If no limbs were discarded, initialize carry to zero + if a_out_range == 0 { + ZNXARI::znx_zero(carry); + } - // Relevant limbs of a - let a_min_size: usize = (res_size * res_base2k).div_ceil(a_base2k).min(a_size); + // Zeroes bottom limbs that will not be interacted with + for j in res_start..res_size { + ZNXARI::znx_zero(res.at_mut(res_col, j)); + } - // Get carry for limbs of a that have higher precision than res - for j in (a_min_size..a_size).rev() { - if j == a_size - 1 { - ZNXARI::znx_normalize_first_step_carry_only(a_base2k, 0, a.at(a_col, j), carry); - } else { - ZNXARI::znx_normalize_middle_step_carry_only(a_base2k, 0, a.at(a_col, j), carry); + let mid_range: usize = a_start.saturating_sub(a_end); + + // Regular normalization over the overlapping limbs of res and a. 
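+ // Illustrative example (hypothetical sizes, limb 0 being the most + // significant): with a_size = 4, res_size = 2 and res_offset = 0, + // a[3] and a[2] fall below the bottom of res and only feed the + // carry (loop above), while the loop below maps a[1] -> res[1] + // and a[0] -> res[0]: + // + // a: [a0][a1][a2][a3] + // res: [r0][r1]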
+ for j in 0..mid_range { + ZNXARI::znx_normalize_middle_step( + base2k, + lsh_pos, + res.at_mut(res_col, res_start - j - 1), + a.at(a_col, a_start - j - 1), + carry, + ); + } + + // Propagates the carry over the non-overlapping limbs between res and a + for j in 0..res_end { + ZNXARI::znx_zero(res.at_mut(res_col, res_end - j - 1)); + if j == res_end - 1 { + ZNXARI::znx_normalize_final_step_inplace(base2k, lsh_pos, res.at_mut(res_col, res_end - j - 1), carry); + } else { + ZNXARI::znx_normalize_middle_step_inplace(base2k, lsh_pos, res.at_mut(res_col, res_end - j - 1), carry); + } + } +} + +#[allow(clippy::too_many_arguments)] +fn vec_znx_normalize_cross_base2k( + res: &mut R, + res_base2k: usize, + res_offset: i64, + res_col: usize, + a: &A, + a_base2k: usize, + a_col: usize, + carry: &mut [i64], +) where + R: VecZnxToMut, + A: VecZnxToRef, + ZNXARI: ZnxZero + + ZnxCopy + + ZnxAddInplace + + ZnxMulPowerOfTwoInplace + + ZnxNormalizeFirstStepCarryOnly + + ZnxNormalizeMiddleStepCarryOnly + + ZnxNormalizeMiddleStep + + ZnxNormalizeFinalStep + + ZnxNormalizeFirstStep + + ZnxExtractDigitAddMul + + ZnxNormalizeMiddleStepInplace + + ZnxNormalizeFinalStepInplace + + ZnxNormalizeDigit, +{ + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + let a: VecZnx<&[u8]> = a.to_ref(); + + #[cfg(debug_assertions)] + { + assert!(carry.len() >= vec_znx_normalize_tmp_bytes(res.n()) / size_of::()); + assert_eq!(res.n(), a.n()); + } + + let n: usize = res.n(); + let res_size: usize = res.size(); + let a_size: usize = a.size(); + + let (a_norm, carry) = carry.split_at_mut(n); + let (res_carry, a_carry) = carry[..2 * n].split_at_mut(n); + ZNXARI::znx_zero(res_carry); + + // Total precision (in bits) that `a` and `res` can represent. + let a_tot_bits: usize = a_size * a_base2k; + let res_tot_bits: usize = res_size * res_base2k; + + // Derive intra-limb shift and cross-limb offset. + let mut lsh: i64 = res_offset % a_base2k as i64; + let mut limbs_offset: i64 = res_offset / a_base2k as i64; + + // If res_offset is negative, ensures it is positive + // and corrects by incrementing the cross-limb offset. + if res_offset < 0 && lsh != 0 { + lsh = (lsh + a_base2k as i64) % (a_base2k as i64); + limbs_offset -= 1; + } + + let lsh_pos: usize = lsh as usize; + + // Derive start/stop bit indexes of the overlap between `a` and `res` (after taking into account the offset).. + let res_end_bit: usize = (-limbs_offset * a_base2k as i64).clamp(0, res_tot_bits as i64) as usize; // Stop bit + let res_start_bit: usize = (a_tot_bits as i64 - limbs_offset * a_base2k as i64).clamp(0, res_tot_bits as i64) as usize; // Start bit + let a_end_bit: usize = (limbs_offset * a_base2k as i64).clamp(0, a_tot_bits as i64) as usize; // Stop bit + let a_start_bit: usize = (res_tot_bits as i64 + limbs_offset * a_base2k as i64).clamp(0, a_tot_bits as i64) as usize; // Start bit + + // Convert bits to limb indexes. + let res_end: usize = res_end_bit / res_base2k; + let res_start: usize = res_start_bit.div_ceil(res_base2k); + let a_end: usize = a_end_bit / a_base2k; + let a_start: usize = a_start_bit.div_ceil(a_base2k); + + // Zero all limbs of `res`. Unlike the simple case + // where `res_base2k` is equal to `a_base2k`, we also + // need to ensure that the limbs starting from `res_end` + // are zero. + for j in 0..res_size { + ZNXARI::znx_zero(res.at_mut(res_col, j)); + } + + // Case where offset is positive and greater or equal + // to the precision of a. + if res_start == 0 { + return; + } + + // Limbs of `a` that have a greater precision than `res`. 
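+ // Illustrative example (hypothetical parameters, matching the diagrams + // below): a_base2k = 5, a_size = 4 (a_tot_bits = 20) and res_base2k = 6, + // res_size = 3 (res_tot_bits = 18), with res_offset = 0. Then + // a_start_bit = res_start_bit = 18, so a_start = ceil(18/5) = 4 and + // res_start = ceil(18/6) = 3: no limb of `a` is fully discarded + // (a_out_range = 0) and the two excess bits of `a` are handled by the + // j == 0 precision-matching below.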
+ let a_out_range: usize = a_size.saturating_sub(a_start); + + for j in 0..a_out_range { + if j == 0 { + ZNXARI::znx_normalize_first_step_carry_only(a_base2k, lsh_pos, a.at(a_col, a_size - j - 1), a_carry); + } else { + ZNXARI::znx_normalize_middle_step_carry_only(a_base2k, lsh_pos, a.at(a_col, a_size - j - 1), a_carry); + } + } + + // Zero carry if the above loop didn't trigger. + if a_out_range == 0 { + ZNXARI::znx_zero(a_carry); + } + + // How much is left to accumulate to fill a limb of `res`. + let mut res_acc_left: usize = res_base2k; + + // Starting limb of `res`. + let mut res_limb: usize = res_start - 1; + + // How many limbs of `a` overlap with `res` (after taking into account the offset). + let mid_range: usize = a_start.saturating_sub(a_end); + + // Regular normalization over the overlapping limbs of res and a. + 'outer: for j in 0..mid_range { + let a_limb: usize = a_start - j - 1; + + // Current limb of `a`. + let a_slice: &[i64] = a.at(a_col, a_limb); + + // Tracker: how much of `a_norm` is left to + // be flushed onto `res`. + let mut a_take_left: usize = a_base2k; + + // Normalizes the j-th limb of `a` and stores the result into `a_norm`. + // This step is required to avoid overflow in the next step, + // which assumes that |a| is bounded by 2^{a_base2k - 1} (i.e. normalized). + ZNXARI::znx_normalize_middle_step(a_base2k, lsh_pos, a_norm, a_slice, a_carry); + + // In the first iteration we need to match the precision of `res` and `a`. + if j == 0 { + // Case where `a` has more precision than `res` (after taking into account the offset) + // + // For example: + // + // a: [x x x x x][x x x x x][x x x x x][x x x x x] + // res: [x x x x x x][x x x x x x][x x x x x x] + if !(a_tot_bits - a_start_bit).is_multiple_of(a_base2k) { + let take: usize = (a_tot_bits - a_start_bit) % a_base2k; + ZNXARI::znx_mul_power_of_two_inplace(-(take as i64), a_norm); + a_take_left -= take; + // Case where `res` has more precision than `a` (after taking into account the offset) + // + // For example: + // + // a: [x x x x x][x x x x x][x x x x x] + // res: [x x x x x x][x x x x x x][x x x x x x] + } else if !(res_tot_bits - res_start_bit).is_multiple_of(res_base2k) { + res_acc_left -= (res_tot_bits - res_start_bit) % res_base2k; } }
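+ // Continuing the example above (a_base2k = 5, res_base2k = 6, + // res_offset = 0): at j = 0 the least significant overlapping limb a[3] + // is divrounded by 2^2 and its remaining 3 bits land in res[2] + // (res_acc_left: 6 -> 3); at j = 1, a[2] contributes 3 bits to complete + // res[2] and 2 bits to start res[1], and so on until a[0] is consumed.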
- // This step is required to avoid overflow in the next step, - // which assumes that |a| is bounded by 2^{a_base2k -1}. - if j != 0 { - ZNXARI::znx_normalize_middle_step(a_base2k, 0, a_norm, a.at(a_col, j), carry); - } else { - ZNXARI::znx_normalize_final_step(a_base2k, 0, a_norm, a.at(a_col, j), carry); + if a_take != 0 { + // Extract `a_take` bits from a_norm and accumulates them on `res_slice`. + let scale: usize = res_base2k - res_acc_left; + ZNXARI::znx_extract_digit_addmul(a_take, scale, res_slice, a_norm); + a_take_left -= a_take; + res_acc_left -= a_take; } - // In the first iteration we need to match the precision of the input/output. - // If a_min_size * a_base2k > res_min_size * res_base2k - // then divround a_norm by the difference of precision and - // acts like if a_norm has already been partially consummed. - // Else acts like if res has been already populated - // by the difference. - if j == a_min_size - 1 { - if a_prec > res_prec { - ZNXARI::znx_mul_power_of_two_inplace(res_prec as i64 - a_prec as i64, a_norm); - a_left -= a_prec - res_prec; - } else if res_prec > a_prec { - res_left -= res_prec - a_prec; - } - } + // If either: + // * At least `res_base2k` bits have been accumulated + // * We have reached the last limb of a + // Then: Flushes them onto res + if res_acc_left == 0 || a_limb == 0 { + // This case happens only if `res_offset` is negative. + // If `res_offset` is negative, we need to apply the offset BEFORE + // the normalization to ensure the `res-offset` overflowing bits of `a` + // are in the MSB of `res` instead of being discarded. + if a_limb == 0 && a_take_left == 0 { + // TODO: prove no overflow can happen here (should not intuitively) + ZNXARI::znx_add_inplace(a_carry, a_norm); - // Flushes a into res - loop { - // Selects the maximum amount of a that can be flushed - let a_take: usize = a_base2k.min(a_left).min(res_left); - - // Output limb - let res_slice: &mut [i64] = res.at_mut(res_col, res_idx); - - // Scaling of the value to flush - let lsh: usize = res_base2k - res_left; - - // Extract the bits to flush on the output and updates - // a_norm accordingly. - ZNXARI::znx_extract_digit_addmul(a_take, lsh, res_slice, a_norm); - - // Updates the trackers - a_left -= a_take; - res_left -= a_take; - - // If the current limb of res is full, - // then normalizes this limb and adds - // the carry on a_norm. - if res_left == 0 { - // Updates tracker - res_left += res_base2k; - - // Normalizes res and propagates the carry on a. - ZNXARI::znx_normalize_digit(res_base2k, res_slice, a_norm); - - // If we reached the last limb of res breaks, - // but we might rerun the above loop if the - // base2k of a is much smaller than the base2k - // of res. - if res_idx == 0 { - ZNXARI::znx_add_inplace(carry, a_norm); - break; + // Usual case where for example + // a: [ overflow ][x x x x x][x x x x x][x x x x x][x x x x x] + // res: [x x x x x x][x x x x x x][x x x x x x][x x x x x x] + // + // where [overflow] are the overflowing bits of `a` (note that they are not a limb, but + // stored in a[0] & carry from a[1]) that are moved into the MSB of `res` due to the + // negative offset. + // + // In this case we populate what is left of `res_acc_left` using `a_carry` + // + // TODO: see if this can be simplified (e.g. just add). + if res_acc_left != 0 { + let scale: usize = res_base2k - res_acc_left; + ZNXARI::znx_extract_digit_addmul(res_acc_left, scale, res_slice, a_carry); } - // Else updates the limb index of res. 
- res_idx -= 1 + ZNXARI::znx_normalize_middle_step_inplace(res_base2k, 0, res_slice, res_carry); + + // Previous step might not consume all bits of a_carry + // TODO: prove no overflow can happen here + ZNXARI::znx_add_inplace(res_carry, a_carry); + + // We are done, so breaks out of the loop (yes we are at a[0], but + // this avoids possible over/under flows of tracking variables) + break 'outer; } - // If a_norm is exhausted, breaks the loop. - if a_left == 0 { - ZNXARI::znx_add_inplace(carry, a_norm); - break; + // If we reached the last limb of res + if res_limb == 0 { + break 'outer; } + + res_acc_left += res_base2k; + res_limb -= 1; + } + + // If a_norm is exhausted, breaks the inner loop. + if a_take_left == 0 { + ZNXARI::znx_add_inplace(a_carry, a_norm); + break 'inner; + } + } + } + + // This case will happen if offset is negative. + if res_end != 0 { + // If there are no overlapping limbs between `res` and `a` + // (can happen if offset is negative), then we propagate the + // carry of `a` on res. Note that the carry of `a` can be + // greater than the precision of res. + // + // For example with offset = -8: + // a carry a[0] a[1] a[2] a[3] + // a: [---------------------- ][x x x][x x x][x x x][x x x] + // b: [x x x x][x x x x ] + // res[0] res[1] + // + // If there are overlapping limbs between `res` and `a`, + // we can use `res_carry`, which contains the carry of propagating + // the shifted reconstruction of `a` in `res_base2k` along with + // the carry of a[0]. + let carry_to_use = if a_start == a_end { a_carry } else { res_carry }; + + for j in 0..res_end { + if j == res_end - 1 { + ZNXARI::znx_normalize_final_step_inplace(res_base2k, 0, res.at_mut(res_col, res_end - j - 1), carry_to_use); + } else { + ZNXARI::znx_normalize_middle_step_inplace(res_base2k, 0, res.at_mut(res_col, res_end - j - 1), carry_to_use); } } } @@ -229,6 +427,191 @@ where } } +#[test] +fn test_vec_znx_normalize_cross_base2k() { + let n: usize = 8; + + let mut carry: Vec = vec![0i64; vec_znx_normalize_tmp_bytes(n) / size_of::()]; + + use crate::reference::znx::ZnxRef; + use rug::ops::SubAssignRound; + use rug::{Float, float::Round}; + + let prec: usize = 128; + + for in_base2k in 1..=51 { + for out_base2k in 1..=51 { + for offset in [ + -(prec as i64), + -(prec as i64 - 1), + -(prec as i64 - in_base2k as i64), + -(in_base2k as i64 + 1), + in_base2k as i64, + -(in_base2k as i64 - 1), + 0, + (in_base2k as i64 - 1), + in_base2k as i64, + (in_base2k as i64 + 1), + (prec as i64 - in_base2k as i64), + (prec - 1) as i64, + prec as i64, + ] { + let mut source: Source = Source::new([1u8; 32]); + + let in_size: usize = prec.div_ceil(in_base2k); + let in_prec: u32 = (in_size * in_base2k) as u32; + + // Ensures no loss of precision (mostly for testing purpose) + let out_size: usize = (in_prec as usize).div_ceil(out_base2k); + + let out_prec: u32 = (out_size * out_base2k) as u32; + let min_prec: u32 = (in_size * in_base2k).min(out_size * out_base2k) as u32; + let mut want: VecZnx> = VecZnx::alloc(n, 1, in_size); + want.fill_uniform(60, &mut source); + + let mut have: VecZnx> = VecZnx::alloc(n, 1, out_size); + have.fill_uniform(60, &mut source); + vec_znx_normalize_cross_base2k::<_, _, ZnxRef>(&mut have, out_base2k, offset, 0, &want, in_base2k, 0, &mut carry); + + let mut data_have: Vec = (0..n).map(|_| Float::with_val(out_prec + 60, 0)).collect(); + let mut data_want: Vec = (0..n).map(|_| Float::with_val(in_prec + 60, 0)).collect(); + + have.decode_vec_float(out_base2k, 0, &mut data_have); + 
want.decode_vec_float(in_base2k, 0, &mut data_want); + + let scale: Float = Float::with_val(out_prec + 60, Float::u_pow_u(2, offset.unsigned_abs() as u32)); + + if offset > 0 { + for x in &mut data_want { + *x *= &scale; + *x %= 1; + } + } else if offset < 0 { + for x in &mut data_want { + *x /= &scale; + *x %= 1; + } + } else { + for x in &mut data_want { + *x %= 1; + } + } + + for x in &mut data_have { + if *x >= 0.5 { + *x -= 1; + } else if *x < -0.5 { + *x += 1; + } + } + + for x in &mut data_want { + if *x >= 0.5 { + *x -= 1; + } else if *x < -0.5 { + *x += 1; + } + } + + for i in 0..n { + //println!("i:{i:02} {} {}", data_want[i], data_have[i]); + + let mut err: Float = data_have[i].clone(); + err.sub_assign_round(&data_want[i], Round::Nearest); + err = err.abs(); + + let err_log2: f64 = err.clone().max(&Float::with_val(prec as u32, 1e-60)).log2().to_f64(); + + assert!(err_log2 <= -(min_prec as f64) + 1.0, "{} {}", err_log2, -(min_prec as f64)) + } + } + } + } +} + +#[test] +fn test_vec_znx_normalize_inter_base2k() { + let n: usize = 8; + + let mut carry: Vec = vec![0i64; vec_znx_normalize_tmp_bytes(n) / size_of::()]; + + use crate::reference::znx::ZnxRef; + use rug::ops::SubAssignRound; + use rug::{Float, float::Round}; + + let mut source: Source = Source::new([1u8; 32]); + + let prec: usize = 128; + let offset_range: i64 = prec as i64; + + for base2k in 1..=51 { + for offset in (-offset_range..=offset_range).step_by(base2k + 1) { + let size: usize = prec.div_ceil(base2k); + let out_prec: u32 = (size * base2k) as u32; + + // Fills "want" with uniform values + let mut want: VecZnx> = VecZnx::alloc(n, 1, size); + want.fill_uniform(60, &mut source); + + // Fills "have" with the shifted normalization of "want" + let mut have: VecZnx> = VecZnx::alloc(n, 1, size); + have.fill_uniform(60, &mut source); + vec_znx_normalize_inter_base2k::<_, _, ZnxRef>(base2k, &mut have, offset, 0, &want, 0, &mut carry); + + let mut data_have: Vec = (0..n).map(|_| Float::with_val(out_prec + 60, 0)).collect(); + let mut data_want: Vec = (0..n).map(|_| Float::with_val(out_prec + 60, 0)).collect(); + + have.decode_vec_float(base2k, 0, &mut data_have); + want.decode_vec_float(base2k, 0, &mut data_want); + + let scale: Float = Float::with_val(out_prec + 60, Float::u_pow_u(2, offset.unsigned_abs() as u32)); + + if offset > 0 { + for x in &mut data_want { + *x *= &scale; + *x %= 1; + } + } else if offset < 0 { + for x in &mut data_want { + *x /= &scale; + *x %= 1; + } + } else { + for x in &mut data_want { + *x %= 1; + } + } + + for x in &mut data_have { + if *x >= 0.5 { + *x -= 1; + } else if *x < -0.5 { + *x += 1; + } + } + + for x in &mut data_want { + if *x >= 0.5 { + *x -= 1; + } else if *x < -0.5 { + *x += 1; + } + } + + for i in 0..n { + //println!("i:{i:02} {} {}", data_want[i], data_have[i]); + + let mut err: Float = data_have[i].clone(); + err.sub_assign_round(&data_want[i], Round::Nearest); + err = err.abs(); + + let err_log2: f64 = err.clone().max(&Float::with_val(prec as u32, 1e-60)).log2().to_f64(); + + assert!(err_log2 <= -(out_prec as f64), "{} {}", err_log2, -(out_prec as f64)) + } + } + } +} pub fn bench_vec_znx_normalize(c: &mut Criterion, label: &str) where Module: VecZnxNormalize + ModuleNew + VecZnxNormalizeTmpBytes, @@ -261,10 +644,10 @@ where res.fill_uniform(50, &mut source); let mut scratch: ScratchOwned = ScratchOwned::alloc(module.vec_znx_normalize_tmp_bytes()); - + let res_offset: i64 = 0; move || { for i in 0..cols { - module.vec_znx_normalize(base2k, &mut res, i, base2k, &a, i, 
scratch.borrow()); + module.vec_znx_normalize(&mut res, base2k, res_offset, i, &a, base2k, i, scratch.borrow()); } black_box(()); } @@ -326,71 +709,3 @@ where group.finish(); } - -#[test] -fn test_vec_znx_normalize_conv() { - let n: usize = 8; - - let mut carry: Vec = vec![0i64; 2 * n]; - - use crate::reference::znx::ZnxRef; - use rug::ops::SubAssignRound; - use rug::{Float, float::Round}; - - let mut source: Source = Source::new([1u8; 32]); - - let prec: usize = 128; - - let mut data: Vec = vec![0i128; n]; - - data.iter_mut().for_each(|x| *x = source.next_i128()); - - for start_base2k in 1..50 { - for end_base2k in 1..50 { - let end_size: usize = prec.div_ceil(end_base2k); - - let mut want: VecZnx> = VecZnx::alloc(n, 1, end_size); - want.encode_vec_i128(end_base2k, 0, prec, &data); - vec_znx_normalize_inplace::<_, ZnxRef>(end_base2k, &mut want, 0, &mut carry); - - // Creates a temporary poly where encoding is in start_base2k - let mut tmp: VecZnx> = VecZnx::alloc(n, 1, prec.div_ceil(start_base2k)); - tmp.encode_vec_i128(start_base2k, 0, prec, &data); - - vec_znx_normalize_inplace::<_, ZnxRef>(start_base2k, &mut tmp, 0, &mut carry); - - let mut data_tmp: Vec = (0..n).map(|_| Float::with_val(prec as u32, 0)).collect(); - tmp.decode_vec_float(start_base2k, 0, &mut data_tmp); - - let mut have: VecZnx> = VecZnx::alloc(n, 1, end_size); - vec_znx_normalize::<_, _, ZnxRef>(end_base2k, &mut have, 0, start_base2k, &tmp, 0, &mut carry); - - let out_prec: u32 = (end_size * end_base2k) as u32; - - let mut data_want: Vec = (0..n).map(|_| Float::with_val(out_prec, 0)).collect(); - let mut data_res: Vec = (0..n).map(|_| Float::with_val(out_prec, 0)).collect(); - - have.decode_vec_float(end_base2k, 0, &mut data_want); - want.decode_vec_float(end_base2k, 0, &mut data_res); - - for i in 0..n { - let mut err: Float = data_want[i].clone(); - err.sub_assign_round(&data_res[i], Round::Nearest); - err = err.abs(); - - let err_log2: f64 = err - .clone() - .max(&Float::with_val(prec as u32, 1e-60)) - .log2() - .to_f64(); - - assert!( - err_log2 <= -(out_prec as f64) + 1., - "{} {}", - err_log2, - -(out_prec as f64) + 1. 
- ) - } - } - } -} diff --git a/poulpy-hal/src/reference/vec_znx/sampling.rs b/poulpy-hal/src/reference/vec_znx/sampling.rs index d1e12eb..6b0bff0 100644 --- a/poulpy-hal/src/reference/vec_znx/sampling.rs +++ b/poulpy-hal/src/reference/vec_znx/sampling.rs @@ -34,12 +34,7 @@ pub fn vec_znx_fill_normal_ref( let limb: usize = k.div_ceil(base2k) - 1; let scale: f64 = (1 << ((limb + 1) * base2k - k)) as f64; - znx_fill_normal_f64_ref( - res.at_mut(res_col, limb), - sigma * scale, - bound * scale, - source, - ) + znx_fill_normal_f64_ref(res.at_mut(res_col, limb), sigma * scale, bound * scale, source) } pub fn vec_znx_add_normal_ref( @@ -62,10 +57,5 @@ pub fn vec_znx_add_normal_ref( let limb: usize = k.div_ceil(base2k) - 1; let scale: f64 = (1 << ((limb + 1) * base2k - k)) as f64; - znx_add_normal_f64_ref( - res.at_mut(res_col, limb), - sigma * scale, - bound * scale, - source, - ) + znx_add_normal_f64_ref(res.at_mut(res_col, limb), sigma * scale, bound * scale, source) } diff --git a/poulpy-hal/src/reference/vec_znx/shift.rs b/poulpy-hal/src/reference/vec_znx/shift.rs index a7366f8..a4d9170 100644 --- a/poulpy-hal/src/reference/vec_znx/shift.rs +++ b/poulpy-hal/src/reference/vec_znx/shift.rs @@ -5,13 +5,10 @@ use criterion::{BenchmarkId, Criterion}; use crate::{ api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxLsh, VecZnxLshInplace, VecZnxRsh, VecZnxRshInplace}, layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut}, - reference::{ - vec_znx::vec_znx_copy, - znx::{ - ZnxCopy, ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, - ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, - ZnxZero, - }, + reference::znx::{ + ZnxCopy, ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, + ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, + ZnxZero, }, source::Source, }; @@ -54,10 +51,7 @@ where (0..size - steps).for_each(|j| { let (lhs, rhs) = res_raw.split_at_mut(slice_size * (j + steps)); - ZNXARI::znx_copy( - &mut lhs[start + j * slice_size..end + j * slice_size], - &rhs[start..end], - ); + ZNXARI::znx_copy(&mut lhs[start + j * slice_size..end + j * slice_size], &rhs[start..end]); }); for j in size - steps..size { @@ -65,16 +59,13 @@ where } } - // Inplace normalization with left shift of k % base2k - if !k.is_multiple_of(base2k) { - for j in (0..size - steps).rev() { - if j == size - steps - 1 { - ZNXARI::znx_normalize_first_step_inplace(base2k, k_rem, res.at_mut(res_col, j), carry); - } else if j == 0 { - ZNXARI::znx_normalize_final_step_inplace(base2k, k_rem, res.at_mut(res_col, j), carry); - } else { - ZNXARI::znx_normalize_middle_step_inplace(base2k, k_rem, res.at_mut(res_col, j), carry); - } + for j in (0..size - steps).rev() { + if j == size - steps - 1 { + ZNXARI::znx_normalize_first_step_inplace(base2k, k_rem, res.at_mut(res_col, j), carry); + } else if j == 0 { + ZNXARI::znx_normalize_final_step_inplace(base2k, k_rem, res.at_mut(res_col, j), carry); + } else { + ZNXARI::znx_normalize_middle_step_inplace(base2k, k_rem, res.at_mut(res_col, j), carry); } } } @@ -104,38 +95,13 @@ where // Simply a left shifted normalization of limbs // by k/base2k and intra-limb by base2k - k%base2k - if !k.is_multiple_of(base2k) { - for j in (0..min_size).rev() { - if j == min_size - 1 { - 
ZNXARI::znx_normalize_first_step( - base2k, - k_rem, - res.at_mut(res_col, j), - a.at(a_col, j + steps), - carry, - ); - } else if j == 0 { - ZNXARI::znx_normalize_final_step( - base2k, - k_rem, - res.at_mut(res_col, j), - a.at(a_col, j + steps), - carry, - ); - } else { - ZNXARI::znx_normalize_middle_step( - base2k, - k_rem, - res.at_mut(res_col, j), - a.at(a_col, j + steps), - carry, - ); - } - } - } else { - // If k % base2k = 0, then this is simply a copy. - for j in (0..min_size).rev() { - ZNXARI::znx_copy(res.at_mut(res_col, j), a.at(a_col, j + steps)); + for j in (0..min_size).rev() { + if j == min_size - 1 { + ZNXARI::znx_normalize_first_step(base2k, k_rem, res.at_mut(res_col, j), a.at(a_col, j + steps), carry); + } else if j == 0 { + ZNXARI::znx_normalize_final_step(base2k, k_rem, res.at_mut(res_col, j), a.at(a_col, j + steps), carry); + } else { + ZNXARI::znx_normalize_middle_step(base2k, k_rem, res.at_mut(res_col, j), a.at(a_col, j + steps), carry); } } @@ -146,10 +112,10 @@ where } pub fn vec_znx_rsh_tmp_bytes(n: usize) -> usize { - n * size_of::<i64>() + 2 * n * size_of::<i64>() } -pub fn vec_znx_rsh_inplace<R, ZNXARI>(base2k: usize, k: usize, res: &mut R, res_col: usize, carry: &mut [i64]) +pub fn vec_znx_rsh_inplace<R, ZNXARI>(base2k: usize, k: usize, res: &mut R, res_col: usize, tmp: &mut [i64]) where R: VecZnxToMut, ZNXARI: ZnxZero + @@ -163,76 +129,48 @@ where { let mut res: VecZnx<&mut [u8]> = res.to_mut(); let n: usize = res.n(); - let cols: usize = res.cols(); + let size: usize = res.size(); let mut steps: usize = k / base2k; let k_rem: usize = k % base2k; - if k == 0 { - return; - } - - if steps >= size { - for j in 0..size { - ZNXARI::znx_zero(res.at_mut(res_col, j)); - } - return; - } - - let start: usize = n * res_col; - let end: usize = start + n; - let slice_size: usize = n * cols; - if !k.is_multiple_of(base2k) { // We rsh by an additional base2k and then lsh by base2k-k // Allows re-use of the efficient normalization code, avoids // overflows & produces output that is normalized steps += 1; + } - // All limbs of a that would fall outside of the limbs of res are discarded, - // but the carry still need to be computed. - (size - steps..size).rev().for_each(|j| { - if j == size - 1 { - ZNXARI::znx_normalize_first_step_carry_only(base2k, base2k - k_rem, res.at(res_col, j), carry); - } else { - ZNXARI::znx_normalize_middle_step_carry_only(base2k, base2k - k_rem, res.at(res_col, j), carry); - } - }); + let (carry, tmp) = tmp[..2 * n].split_at_mut(n); - // Continues with shifted normalization - let res_raw: &mut [i64] = res.raw_mut(); - (steps..size).rev().for_each(|j| { - let (lhs, rhs) = res_raw.split_at_mut(slice_size * j); - let rhs_slice: &mut [i64] = &mut rhs[start..end]; - let lhs_slice: &[i64] = &lhs[(j - steps) * slice_size + start..(j - steps) * slice_size + end]; - ZNXARI::znx_normalize_middle_step(base2k, base2k - k_rem, rhs_slice, lhs_slice, carry); - }); + let lsh: usize = (base2k - k_rem) % base2k; - // Propagates carry on the rest of the limbs of res - for j in (0..steps).rev() { - ZNXARI::znx_zero(res.at_mut(res_col, j)); - if j == 0 { - ZNXARI::znx_normalize_final_step_inplace(base2k, base2k - k_rem, res.at_mut(res_col, j), carry); - } else { - ZNXARI::znx_normalize_middle_step_inplace(base2k, base2k - k_rem, res.at_mut(res_col, j), carry); - } + // All limbs of a that would fall outside of the limbs of res are discarded, + // but the carry still needs to be computed. 
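+ // Illustrative example (hypothetical parameters): base2k = 50 and k = 7 + // give steps = 1 and lsh = 43, i.e. a right shift by 7 bits is realized + // as a one-limb (50-bit) downward shift combined with an intra-limb + // left shift by 43 bits. The loop below then computes the carry from + // the single discarded bottom limb res[size - 1].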
+ for j in 0..steps { + if j == 0 { + ZNXARI::znx_normalize_first_step_carry_only(base2k, lsh, res.at(res_col, size - j - 1), carry); + } else { + ZNXARI::znx_normalize_middle_step_carry_only(base2k, lsh, res.at(res_col, size - j - 1), carry); } - } else { - // Shift by multiples of base2k - let res_raw: &mut [i64] = res.raw_mut(); - (steps..size).rev().for_each(|j| { - let (lhs, rhs) = res_raw.split_at_mut(slice_size * j); - ZNXARI::znx_copy( - &mut rhs[start..end], - &lhs[(j - steps) * slice_size + start..(j - steps) * slice_size + end], - ); - }); + } - // Zeroes the top - (0..steps).for_each(|j| { - ZNXARI::znx_zero(res.at_mut(res_col, j)); - }); + // Continues with shifted normalization + for j in 0..size - steps { + ZNXARI::znx_copy(tmp, res.at(res_col, size - steps - j - 1)); + ZNXARI::znx_normalize_middle_step_inplace(base2k, lsh, tmp, carry); + ZNXARI::znx_copy(res.at_mut(res_col, size - j - 1), tmp); + } + + // Propagates carry on the rest of the limbs of res + for j in 0..steps { + ZNXARI::znx_zero(res.at_mut(res_col, steps - j - 1)); + if j == steps - 1 { + ZNXARI::znx_normalize_final_step_inplace(base2k, lsh, res.at_mut(res_col, steps - j - 1), carry); + } else { + ZNXARI::znx_normalize_middle_step_inplace(base2k, lsh, res.at_mut(res_col, steps - j - 1), carry); + } } } @@ -259,90 +197,59 @@ where let mut steps: usize = k / base2k; let k_rem: usize = k % base2k; - if k == 0 { - vec_znx_copy::<_, _, ZNXARI>(&mut res, res_col, &a, a_col); - return; - } - - if steps >= res_size { - for j in 0..res_size { - ZNXARI::znx_zero(res.at_mut(res_col, j)); - } - return; - } - if !k.is_multiple_of(base2k) { // We rsh by an additional base2k and then lsh by base2k-k // Allows re-use of the efficient normalization code, avoids // overflows & produces output that is normalized steps += 1; + } - // All limbs of a that are moved outside of the limbs of res are discarded, - // but the carry still need to be computed. - for j in (res_size..a_size + steps).rev() { - if j == a_size + steps - 1 { - ZNXARI::znx_normalize_first_step_carry_only(base2k, base2k - k_rem, a.at(a_col, j - steps), carry); - } else { - ZNXARI::znx_normalize_middle_step_carry_only(base2k, base2k - k_rem, a.at(a_col, j - steps), carry); - } + let lsh: usize = (base2k - k_rem) % base2k; // 0 if base2k divides k + let res_end: usize = res_size.min(steps); + let res_start: usize = res_size.min(a_size + steps); + let a_start: usize = a_size.min(res_size.saturating_sub(steps)); + + // All limbs of a that are moved outside of the limbs of res are discarded, + // but the carry still needs to be computed. 
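+ // Illustrative example (hypothetical sizes): res_size = a_size = 4 and + // steps = 2 give res_end = 2, res_start = 4 and a_start = 2: a[3] and + // a[2] feed the carry only, a[1] -> res[3] and a[0] -> res[2], then the + // carry is propagated through res[1] and res[0].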
+ let a_out_range: usize = a_size.saturating_sub(a_start); + + for j in 0..a_out_range { + if j == 0 { + ZNXARI::znx_normalize_first_step_carry_only(base2k, lsh, a.at(a_col, a_size - j - 1), carry); + } else { + ZNXARI::znx_normalize_middle_step_carry_only(base2k, lsh, a.at(a_col, a_size - j - 1), carry); } + } - // Avoids over flow of limbs of res - let min_size: usize = res_size.min(a_size + steps); + if a_out_range == 0 { + ZNXARI::znx_zero(carry); + } - // Zeroes lower limbs of res if a_size + steps < res_size - (min_size..res_size).for_each(|j| { - ZNXARI::znx_zero(res.at_mut(res_col, j)); - }); + // Zeroes lower limbs of res if a_size + steps < res_size + for j in 0..res_size { + ZNXARI::znx_zero(res.at_mut(res_col, j)); + } - // Continues with shifted normalization - for j in (steps..min_size).rev() { - // Case if no limb of a was previously discarded - if res_size.saturating_sub(steps) >= a_size && j == min_size - 1 { - ZNXARI::znx_normalize_first_step( - base2k, - base2k - k_rem, - res.at_mut(res_col, j), - a.at(a_col, j - steps), - carry, - ); - } else { - ZNXARI::znx_normalize_middle_step( - base2k, - base2k - k_rem, - res.at_mut(res_col, j), - a.at(a_col, j - steps), - carry, - ); - } + // Continues with shifted normalization + let mid_range: usize = res_start.saturating_sub(res_end); + + for j in 0..mid_range { + ZNXARI::znx_normalize_middle_step( + base2k, + lsh, + res.at_mut(res_col, res_start - j - 1), + a.at(a_col, a_start - j - 1), + carry, + ); + } + + // Propagates carry on the rest of the limbs of res + for j in 0..res_end { + if j == res_end - 1 { + ZNXARI::znx_normalize_final_step_inplace(base2k, lsh, res.at_mut(res_col, res_end - j - 1), carry); + } else { + ZNXARI::znx_normalize_middle_step_inplace(base2k, lsh, res.at_mut(res_col, res_end - j - 1), carry); } - - // Propagates carry on the rest of the limbs of res - for j in (0..steps).rev() { - ZNXARI::znx_zero(res.at_mut(res_col, j)); - if j == 0 { - ZNXARI::znx_normalize_final_step_inplace(base2k, base2k - k_rem, res.at_mut(res_col, j), carry); - } else { - ZNXARI::znx_normalize_middle_step_inplace(base2k, base2k - k_rem, res.at_mut(res_col, j), carry); - } - } - } else { - let min_size: usize = res_size.min(a_size + steps); - - // Zeroes the top - (0..steps).for_each(|j| { - ZNXARI::znx_zero(res.at_mut(res_col, j)); - }); - - // Shift a into res, up to the maximum - for j in (steps..min_size).rev() { - ZNXARI::znx_copy(res.at_mut(res_col, j), a.at(a_col, j - steps)); - } - - // Zeroes bottom if a_size + steps < res_size - (min_size..res_size).for_each(|j| { - ZNXARI::znx_zero(res.at_mut(res_col, j)); - }); } } @@ -373,7 +280,7 @@ where let mut a: VecZnx> = VecZnx::alloc(n, cols, size); let mut b: VecZnx> = VecZnx::alloc(n, cols, size); - let mut scratch: ScratchOwned = ScratchOwned::alloc(n * size_of::()); + let mut scratch: ScratchOwned = ScratchOwned::alloc(vec_znx_lsh_tmp_bytes(n)); // Fill a with random i64 a.fill_uniform(50, &mut source); @@ -423,7 +330,7 @@ where let mut a: VecZnx> = VecZnx::alloc(n, cols, size); let mut res: VecZnx> = VecZnx::alloc(n, cols, size); - let mut scratch: ScratchOwned = ScratchOwned::alloc(n * size_of::()); + let mut scratch: ScratchOwned = ScratchOwned::alloc(vec_znx_lsh_tmp_bytes(n)); // Fill a with random i64 a.fill_uniform(50, &mut source); @@ -473,7 +380,7 @@ where let mut a: VecZnx> = VecZnx::alloc(n, cols, size); let mut b: VecZnx> = VecZnx::alloc(n, cols, size); - let mut scratch: ScratchOwned = ScratchOwned::alloc(n * size_of::()); + let mut scratch: ScratchOwned = 
ScratchOwned::alloc(vec_znx_rsh_tmp_bytes(n)); // Fill a with random i64 a.fill_uniform(50, &mut source); @@ -523,7 +430,7 @@ where let mut a: VecZnx> = VecZnx::alloc(n, cols, size); let mut res: VecZnx> = VecZnx::alloc(n, cols, size); - let mut scratch: ScratchOwned = ScratchOwned::alloc(n * size_of::()); + let mut scratch: ScratchOwned = ScratchOwned::alloc(vec_znx_rsh_tmp_bytes(n)); // Fill a with random i64 a.fill_uniform(50, &mut source); @@ -552,8 +459,8 @@ mod tests { layouts::{FillUniform, VecZnx, ZnxView}, reference::{ vec_znx::{ - vec_znx_copy, vec_znx_lsh, vec_znx_lsh_inplace, vec_znx_normalize_inplace, vec_znx_rsh, vec_znx_rsh_inplace, - vec_znx_sub_inplace, + vec_znx_copy, vec_znx_lsh, vec_znx_lsh_inplace, vec_znx_lsh_tmp_bytes, vec_znx_normalize_inplace, vec_znx_rsh, + vec_znx_rsh_inplace, vec_znx_rsh_tmp_bytes, vec_znx_sub_inplace, }, znx::ZnxRef, }, @@ -572,7 +479,7 @@ mod tests { let mut source: Source = Source::new([0u8; 32]); - let mut carry: Vec = vec![0i64; n]; + let mut carry: Vec = vec![0i64; vec_znx_lsh_tmp_bytes(n) / size_of::()]; let base2k: usize = 50; @@ -604,7 +511,7 @@ mod tests { let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); - let mut carry: Vec = vec![0i64; n]; + let mut carry: Vec = vec![0i64; vec_znx_rsh_tmp_bytes(n) / size_of::()]; let base2k: usize = 50; diff --git a/poulpy-hal/src/reference/vec_znx/split_ring.rs b/poulpy-hal/src/reference/vec_znx/split_ring.rs index f450f3a..5c4a645 100644 --- a/poulpy-hal/src/reference/vec_znx/split_ring.rs +++ b/poulpy-hal/src/reference/vec_znx/split_ring.rs @@ -22,10 +22,7 @@ where { assert_eq!(tmp.len(), a.n()); - assert!( - _n_out < _n_in, - "invalid a: output ring degree should be smaller" - ); + assert!(_n_out < _n_in, "invalid a: output ring degree should be smaller"); res[1..].iter_mut().for_each(|bi| { assert_eq!( diff --git a/poulpy-hal/src/reference/znx/automorphism.rs b/poulpy-hal/src/reference/znx/automorphism.rs index 6ef9e84..29ec0c8 100644 --- a/poulpy-hal/src/reference/znx/automorphism.rs +++ b/poulpy-hal/src/reference/znx/automorphism.rs @@ -12,10 +12,6 @@ pub fn znx_automorphism_ref(p: i64, res: &mut [i64], a: &[i64]) { res[0] = a[0]; for ai in a.iter().take(n).skip(1) { k = (k + p_2n) & mask; - if k < n { - res[k] = *ai - } else { - res[k - n] = -*ai - } + if k < n { res[k] = *ai } else { res[k - n] = -*ai } } } diff --git a/poulpy-hal/src/reference/znx/normalization.rs b/poulpy-hal/src/reference/znx/normalization.rs index 95100e4..154b810 100644 --- a/poulpy-hal/src/reference/znx/normalization.rs +++ b/poulpy-hal/src/reference/znx/normalization.rs @@ -33,9 +33,9 @@ pub fn znx_normalize_first_step_carry_only_ref(base2k: usize, lsh: usize, x: &[i *c = get_carry_i64(base2k, *x, get_digit_i64(base2k, *x)); }); } else { - let basek_lsh: usize = base2k - lsh; + let base2k_lsh: usize = base2k - lsh; x.iter().zip(carry.iter_mut()).for_each(|(x, c)| { - *c = get_carry_i64(basek_lsh, *x, get_digit_i64(basek_lsh, *x)); + *c = get_carry_i64(base2k_lsh, *x, get_digit_i64(base2k_lsh, *x)); }); } } @@ -55,10 +55,10 @@ pub fn znx_normalize_first_step_inplace_ref(base2k: usize, lsh: usize, x: &mut [ *x = digit; }); } else { - let basek_lsh: usize = base2k - lsh; + let base2k_lsh: usize = base2k - lsh; x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| { - let digit: i64 = get_digit_i64(basek_lsh, *x); - *c = get_carry_i64(basek_lsh, *x, digit); + let digit: i64 = get_digit_i64(base2k_lsh, *x); + *c = get_carry_i64(base2k_lsh, *x, digit); 
*x = digit << lsh; }); } @@ -80,10 +80,10 @@ pub fn znx_normalize_first_step_ref(base2k: usize, lsh: usize, x: &mut [i64], a: *x = digit; }); } else { - let basek_lsh: usize = base2k - lsh; + let base2k_lsh: usize = base2k - lsh; izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| { - let digit: i64 = get_digit_i64(basek_lsh, *a); - *c = get_carry_i64(basek_lsh, *a, digit); + let digit: i64 = get_digit_i64(base2k_lsh, *a); + *c = get_carry_i64(base2k_lsh, *a, digit); *x = digit << lsh; }); } @@ -104,10 +104,10 @@ pub fn znx_normalize_middle_step_carry_only_ref(base2k: usize, lsh: usize, x: &[ *c = carry + get_carry_i64(base2k, digit_plus_c, get_digit_i64(base2k, digit_plus_c)); }); } else { - let basek_lsh: usize = base2k - lsh; + let base2k_lsh: usize = base2k - lsh; x.iter().zip(carry.iter_mut()).for_each(|(x, c)| { - let digit: i64 = get_digit_i64(basek_lsh, *x); - let carry: i64 = get_carry_i64(basek_lsh, *x, digit); + let digit: i64 = get_digit_i64(base2k_lsh, *x); + let carry: i64 = get_carry_i64(base2k_lsh, *x, digit); let digit_plus_c: i64 = (digit << lsh) + *c; *c = carry + get_carry_i64(base2k, digit_plus_c, get_digit_i64(base2k, digit_plus_c)); }); @@ -131,10 +131,10 @@ pub fn znx_normalize_middle_step_inplace_ref(base2k: usize, lsh: usize, x: &mut *c = carry + get_carry_i64(base2k, digit_plus_c, *x); }); } else { - let basek_lsh: usize = base2k - lsh; + let base2k_lsh: usize = base2k - lsh; x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| { - let digit: i64 = get_digit_i64(basek_lsh, *x); - let carry: i64 = get_carry_i64(basek_lsh, *x, digit); + let digit: i64 = get_digit_i64(base2k_lsh, *x); + let carry: i64 = get_carry_i64(base2k_lsh, *x, digit); let digit_plus_c: i64 = (digit << lsh) + *c; *x = get_digit_i64(base2k, digit_plus_c); *c = carry + get_carry_i64(base2k, digit_plus_c, *x); @@ -178,10 +178,10 @@ pub fn znx_normalize_middle_step_ref(base2k: usize, lsh: usize, x: &mut [i64], a *c = carry + get_carry_i64(base2k, digit_plus_c, *x); }); } else { - let basek_lsh: usize = base2k - lsh; + let base2k_lsh: usize = base2k - lsh; izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| { - let digit: i64 = get_digit_i64(basek_lsh, *a); - let carry: i64 = get_carry_i64(basek_lsh, *a, digit); + let digit: i64 = get_digit_i64(base2k_lsh, *a); + let carry: i64 = get_carry_i64(base2k_lsh, *a, digit); let digit_plus_c: i64 = (digit << lsh) + *c; *x = get_digit_i64(base2k, digit_plus_c); *c = carry + get_carry_i64(base2k, digit_plus_c, *x); @@ -202,9 +202,9 @@ pub fn znx_normalize_final_step_inplace_ref(base2k: usize, lsh: usize, x: &mut [ *x = get_digit_i64(base2k, get_digit_i64(base2k, *x) + *c); }); } else { - let basek_lsh: usize = base2k - lsh; + let base2k_lsh: usize = base2k - lsh; x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| { - *x = get_digit_i64(base2k, (get_digit_i64(basek_lsh, *x) << lsh) + *c); + *x = get_digit_i64(base2k, (get_digit_i64(base2k_lsh, *x) << lsh) + *c); }); } } @@ -221,9 +221,9 @@ pub fn znx_normalize_final_step_ref(base2k: usize, lsh: usize, x: &mut [i64], a: *x = get_digit_i64(base2k, get_digit_i64(base2k, *a) + *c); }); } else { - let basek_lsh: usize = base2k - lsh; + let base2k_lsh: usize = base2k - lsh; izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| { - *x = get_digit_i64(base2k, (get_digit_i64(basek_lsh, *a) << lsh) + *c); + *x = get_digit_i64(base2k, (get_digit_i64(base2k_lsh, *a) << lsh) + *c); }); } } diff --git a/poulpy-hal/src/test_suite/convolution.rs 
b/poulpy-hal/src/test_suite/convolution.rs index d175656..0cfc278 100644 --- a/poulpy-hal/src/test_suite/convolution.rs +++ b/poulpy-hal/src/test_suite/convolution.rs @@ -1,19 +1,80 @@ +use rand::RngCore; + use crate::{ api::{ - BivariateTensoring, ModuleN, ScratchOwnedAlloc, ScratchOwnedBorrow, ScratchTakeBasic, TakeSlice, VecZnxBigAlloc, - VecZnxBigNormalize, VecZnxDftAlloc, VecZnxDftApply, VecZnxIdftApplyTmpA, VecZnxNormalizeInplace, + CnvPVecAlloc, Convolution, ModuleN, ScratchOwnedAlloc, ScratchOwnedBorrow, ScratchTakeBasic, TakeSlice, VecZnxAdd, + VecZnxBigAlloc, VecZnxBigNormalize, VecZnxCopy, VecZnxDftAlloc, VecZnxDftApply, VecZnxIdftApplyTmpA, + VecZnxNormalizeInplace, }, layouts::{ - Backend, FillUniform, Scratch, ScratchOwned, VecZnx, VecZnxBig, VecZnxDft, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, - ZnxViewMut, ZnxZero, + Backend, CnvPVecL, CnvPVecR, FillUniform, Scratch, ScratchOwned, VecZnx, VecZnxBig, VecZnxDft, VecZnxToMut, VecZnxToRef, + ZnxInfos, ZnxView, ZnxViewMut, ZnxZero, }, source::Source, }; -pub fn test_bivariate_tensoring(module: &M) +pub fn test_convolution_by_const(module: &M) +where + M: ModuleN + Convolution + VecZnxBigNormalize + VecZnxNormalizeInplace + VecZnxBigAlloc, + Scratch: ScratchTakeBasic, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + let mut source: Source = Source::new([0u8; 32]); + + let base2k: usize = 12; + + let a_cols: usize = 2; + let a_size: usize = 15; + let b_size: usize = 15; + let res_size: usize = a_size + b_size; + + let mut a: VecZnx> = VecZnx::alloc(module.n(), a_cols, a_size); + let mut b: VecZnx> = VecZnx::alloc(module.n(), 1, b_size); + + let mut res_want: VecZnx> = VecZnx::alloc(module.n(), 1, res_size); + let mut res_have: VecZnx> = VecZnx::alloc(module.n(), 1, res_size); + let mut res_big: VecZnxBig, BE> = module.vec_znx_big_alloc(1, res_size); + + a.fill_uniform(base2k, &mut source); + + let mut b_const = vec![0i64; b_size]; + let mask = (1 << base2k) - 1; + for (j, x) in b_const[..1].iter_mut().enumerate() { + let r = source.next_u64() & mask; + *x = ((r << (64 - base2k)) as i64) >> (64 - base2k); + b.at_mut(0, j)[0] = *x + } + + let mut scratch: ScratchOwned = ScratchOwned::alloc(module.cnv_by_const_apply_tmp_bytes(res_size, 0, a_size, b_size)); + + for a_col in 0..a.cols() { + for offset in 0..res_size { + module.cnv_by_const_apply(&mut res_big, offset, 0, &a, a_col, &b_const, scratch.borrow()); + module.vec_znx_big_normalize(&mut res_have, base2k, 0, 0, &res_big, base2k, 0, scratch.borrow()); + + bivariate_convolution_naive( + module, + base2k, + (offset + 1) as i64, + &mut res_want, + 0, + &a, + a_col, + &b, + 0, + scratch.borrow(), + ); + + assert_eq!(res_want, res_have); + } + } +} + +pub fn test_convolution(module: &M) where M: ModuleN - + BivariateTensoring + + Convolution + + CnvPVecAlloc + VecZnxDftAlloc + VecZnxDftApply + VecZnxIdftApplyTmpA @@ -27,56 +88,199 @@ where let base2k: usize = 12; - let a_cols: usize = 3; - let b_cols: usize = 3; - let a_size: usize = 3; - let b_size: usize = 3; - let c_cols: usize = a_cols + b_cols - 1; - let c_size: usize = a_size + b_size; + let a_cols: usize = 2; + let b_cols: usize = 2; + let a_size: usize = 15; + let b_size: usize = 15; + let res_size: usize = a_size + b_size; let mut a: VecZnx> = VecZnx::alloc(module.n(), a_cols, a_size); let mut b: VecZnx> = VecZnx::alloc(module.n(), b_cols, b_size); - let mut c_want: VecZnx> = VecZnx::alloc(module.n(), c_cols, c_size); - let mut c_have: VecZnx> = VecZnx::alloc(module.n(), c_cols, c_size); - let mut c_have_dft: 
VecZnxDft, BE> = module.vec_znx_dft_alloc(c_cols, c_size); - let mut c_have_big: VecZnxBig, BE> = module.vec_znx_big_alloc(c_cols, c_size); - - let mut scratch: ScratchOwned = ScratchOwned::alloc(module.convolution_tmp_bytes(b_size)); + let mut res_want: VecZnx> = VecZnx::alloc(module.n(), 1, res_size); + let mut res_have: VecZnx> = VecZnx::alloc(module.n(), 1, res_size); + let mut res_dft: VecZnxDft, BE> = module.vec_znx_dft_alloc(1, res_size); + let mut res_big: VecZnxBig, BE> = module.vec_znx_big_alloc(1, res_size); a.fill_uniform(base2k, &mut source); b.fill_uniform(base2k, &mut source); - let mut b_dft: VecZnxDft, BE> = module.vec_znx_dft_alloc(b_cols, b_size); - for i in 0..b.cols() { - module.vec_znx_dft_apply(1, 0, &mut b_dft, i, &b, i); + let mut a_prep: CnvPVecL, BE> = module.cnv_pvec_left_alloc(a_cols, a_size); + let mut b_prep: CnvPVecR, BE> = module.cnv_pvec_right_alloc(b_cols, b_size); + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + module + .cnv_apply_dft_tmp_bytes(res_size, 0, a_size, b_size) + .max(module.cnv_prepare_left_tmp_bytes(res_size, a_size)) + .max(module.cnv_prepare_right_tmp_bytes(res_size, b_size)), + ); + + module.cnv_prepare_left(&mut a_prep, &a, scratch.borrow()); + module.cnv_prepare_right(&mut b_prep, &b, scratch.borrow()); + + for a_col in 0..a.cols() { + for b_col in 0..b.cols() { + for offset in 0..res_size { + module.cnv_apply_dft(&mut res_dft, offset, 0, &a_prep, a_col, &b_prep, b_col, scratch.borrow()); + + module.vec_znx_idft_apply_tmpa(&mut res_big, 0, &mut res_dft, 0); + module.vec_znx_big_normalize(&mut res_have, base2k, 0, 0, &res_big, base2k, 0, scratch.borrow()); + + bivariate_convolution_naive( + module, + base2k, + (offset + 1) as i64, + &mut res_want, + 0, + &a, + a_col, + &b, + b_col, + scratch.borrow(), + ); + + assert_eq!(res_want, res_have); + } + } + } +} + +pub fn test_convolution_pairwise(module: &M) +where + M: ModuleN + + Convolution + + CnvPVecAlloc + + VecZnxDftAlloc + + VecZnxDftApply + + VecZnxIdftApplyTmpA + + VecZnxBigNormalize + + VecZnxNormalizeInplace + + VecZnxBigAlloc + + VecZnxAdd + + VecZnxCopy, + Scratch: ScratchTakeBasic, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + let mut source: Source = Source::new([0u8; 32]); + + let base2k: usize = 12; + + let cols = 2; + let a_size: usize = 15; + let b_size: usize = 15; + let res_size: usize = a_size + b_size; + + let mut a: VecZnx> = VecZnx::alloc(module.n(), cols, a_size); + let mut b: VecZnx> = VecZnx::alloc(module.n(), cols, b_size); + let mut tmp_a: VecZnx> = VecZnx::alloc(module.n(), 1, a_size); + let mut tmp_b: VecZnx> = VecZnx::alloc(module.n(), 1, b_size); + + let mut res_want: VecZnx> = VecZnx::alloc(module.n(), 1, res_size); + let mut res_have: VecZnx> = VecZnx::alloc(module.n(), 1, res_size); + let mut res_dft: VecZnxDft, BE> = module.vec_znx_dft_alloc(1, res_size); + let mut res_big: VecZnxBig, BE> = module.vec_znx_big_alloc(1, res_size); + + a.fill_uniform(base2k, &mut source); + b.fill_uniform(base2k, &mut source); + + let mut a_prep: CnvPVecL, BE> = module.cnv_pvec_left_alloc(cols, a_size); + let mut b_prep: CnvPVecR, BE> = module.cnv_pvec_right_alloc(cols, b_size); + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + module + .cnv_pairwise_apply_dft_tmp_bytes(res_size, 0, a_size, b_size) + .max(module.cnv_prepare_left_tmp_bytes(res_size, a_size)) + .max(module.cnv_prepare_right_tmp_bytes(res_size, b_size)), + ); + + module.cnv_prepare_left(&mut a_prep, &a, scratch.borrow()); + module.cnv_prepare_right(&mut b_prep, &b, 
scratch.borrow()); + + for col_i in 0..cols { + for col_j in 0..cols { + for offset in 0..res_size { + module.cnv_pairwise_apply_dft(&mut res_dft, offset, 0, &a_prep, &b_prep, col_i, col_j, scratch.borrow()); + + module.vec_znx_idft_apply_tmpa(&mut res_big, 0, &mut res_dft, 0); + module.vec_znx_big_normalize(&mut res_have, base2k, 0, 0, &res_big, base2k, 0, scratch.borrow()); + + if col_i != col_j { + module.vec_znx_add(&mut tmp_a, 0, &a, col_i, &a, col_j); + module.vec_znx_add(&mut tmp_b, 0, &b, col_i, &b, col_j); + } else { + module.vec_znx_copy(&mut tmp_a, 0, &a, col_i); + module.vec_znx_copy(&mut tmp_b, 0, &b, col_j); + } + + bivariate_convolution_naive( + module, + base2k, + (offset + 1) as i64, + &mut res_want, + 0, + &tmp_a, + 0, + &tmp_b, + 0, + scratch.borrow(), + ); + + assert_eq!(res_want, res_have); + } + } + } +} + +#[allow(clippy::too_many_arguments)] +pub fn bivariate_convolution_naive( + module: &M, + base2k: usize, + k: i64, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + b: &B, + b_col: usize, + scratch: &mut Scratch, +) where + R: VecZnxToMut, + A: VecZnxToRef, + B: VecZnxToRef, + M: VecZnxNormalizeInplace, + Scratch: TakeSlice, +{ + let res: &mut VecZnx<&mut [u8]> = &mut res.to_mut(); + let a: &VecZnx<&[u8]> = &a.to_ref(); + let b: &VecZnx<&[u8]> = &b.to_ref(); + + for j in 0..res.size() { + res.zero_at(res_col, j); } - for mut k in 0..(2 * c_size + 1) as i64 { - k -= c_size as i64; + for a_limb in 0..a.size() { + for b_limb in 0..b.size() { + let res_scale_abs = k.unsigned_abs() as usize; - module.bivariate_tensoring(k, &mut c_have_dft, &a, &b_dft, scratch.borrow()); + let mut res_limb: usize = a_limb + b_limb + 1; - for i in 0..c_cols { - module.vec_znx_idft_apply_tmpa(&mut c_have_big, i, &mut c_have_dft, i); + if k <= 0 { + res_limb += res_scale_abs; + + if res_limb < res.size() { + negacyclic_convolution_naive_add(res.at_mut(res_col, res_limb), a.at(a_col, a_limb), b.at(b_col, b_limb)); + } + } else if res_limb >= res_scale_abs { + res_limb -= res_scale_abs; + + if res_limb < res.size() { + negacyclic_convolution_naive_add(res.at_mut(res_col, res_limb), a.at(a_col, a_limb), b.at(b_col, b_limb)); + } + } } - - for i in 0..c_cols { - module.vec_znx_big_normalize( - base2k, - &mut c_have, - i, - base2k, - &c_have_big, - i, - scratch.borrow(), - ); - } - - bivariate_tensoring_naive(module, base2k, k, &mut c_want, &a, &b, scratch.borrow()); - - assert_eq!(c_want, c_have); } + + module.vec_znx_normalize_inplace(base2k, res, res_col, scratch); } fn bivariate_tensoring_naive( @@ -154,3 +358,18 @@ fn negacyclic_convolution_naive_add(res: &mut [i64], a: &[i64], b: &[i64]) { } } } + +fn negacyclic_convolution_naive(res: &mut [i64], a: &[i64], b: &[i64]) { + let n: usize = res.len(); + res.fill(0); + for i in 0..n { + let ai: i64 = a[i]; + let lim: usize = n - i; + for j in 0..lim { + res[i + j] += ai * b[j]; + } + for j in lim..n { + res[i + j - n] -= ai * b[j]; + } + } +} diff --git a/poulpy-hal/src/test_suite/serialization.rs b/poulpy-hal/src/test_suite/serialization.rs index dacd656..e1652cf 100644 --- a/poulpy-hal/src/test_suite/serialization.rs +++ b/poulpy-hal/src/test_suite/serialization.rs @@ -29,10 +29,7 @@ where receiver.read_from(&mut reader).expect("read_from failed"); // Ensure serialization round-trip correctness - assert_eq!( - &original, &receiver, - "Deserialized object does not match the original" - ); + assert_eq!(&original, &receiver, "Deserialized object does not match the original"); } #[test] diff --git a/poulpy-hal/src/test_suite/svp.rs 
b/poulpy-hal/src/test_suite/svp.rs index e72dc5a..1ef64af 100644 --- a/poulpy-hal/src/test_suite/svp.rs +++ b/poulpy-hal/src/test_suite/svp.rs @@ -90,24 +90,8 @@ where let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); for j in 0..cols { - module_ref.vec_znx_big_normalize( - base2k, - &mut res_ref, - j, - base2k, - &res_big_ref, - j, - scratch_ref.borrow(), - ); - module_test.vec_znx_big_normalize( - base2k, - &mut res_test, - j, - base2k, - &res_big_test, - j, - scratch_test.borrow(), - ); + module_ref.vec_znx_big_normalize(&mut res_ref, base2k, 0, j, &res_big_ref, base2k, j, scratch_ref.borrow()); + module_test.vec_znx_big_normalize(&mut res_test, base2k, 0, j, &res_big_test, base2k, j, scratch_test.borrow()); } assert_eq!(res_ref, res_test); @@ -212,24 +196,8 @@ where let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); for j in 0..cols { - module_ref.vec_znx_big_normalize( - base2k, - &mut res_ref, - j, - base2k, - &res_big_ref, - j, - scratch_ref.borrow(), - ); - module_test.vec_znx_big_normalize( - base2k, - &mut res_test, - j, - base2k, - &res_big_test, - j, - scratch_test.borrow(), - ); + module_ref.vec_znx_big_normalize(&mut res_ref, base2k, 0, j, &res_big_ref, base2k, j, scratch_ref.borrow()); + module_test.vec_znx_big_normalize(&mut res_test, base2k, 0, j, &res_big_test, base2k, j, scratch_test.borrow()); } assert_eq!(res_ref, res_test); @@ -339,24 +307,8 @@ where let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); for j in 0..cols { - module_ref.vec_znx_big_normalize( - base2k, - &mut res_ref, - j, - base2k, - &res_big_ref, - j, - scratch_ref.borrow(), - ); - module_test.vec_znx_big_normalize( - base2k, - &mut res_test, - j, - base2k, - &res_big_test, - j, - scratch_test.borrow(), - ); + module_ref.vec_znx_big_normalize(&mut res_ref, base2k, 0, j, &res_big_ref, base2k, j, scratch_ref.borrow()); + module_test.vec_znx_big_normalize(&mut res_test, base2k, 0, j, &res_big_test, base2k, j, scratch_test.borrow()); } assert_eq!(res_ref, res_test); @@ -447,24 +399,8 @@ pub fn test_svp_apply_dft_to_dft_inplace( let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); for j in 0..cols { - module_ref.vec_znx_big_normalize( - base2k, - &mut res_ref, - j, - base2k, - &res_big_ref, - j, - scratch_ref.borrow(), - ); - module_test.vec_znx_big_normalize( - base2k, - &mut res_test, - j, - base2k, - &res_big_test, - j, - scratch_test.borrow(), - ); + module_ref.vec_znx_big_normalize(&mut res_ref, base2k, 0, j, &res_big_ref, base2k, j, scratch_ref.borrow()); + module_test.vec_znx_big_normalize(&mut res_test, base2k, 0, j, &res_big_test, base2k, j, scratch_test.borrow()); } assert_eq!(res_ref, res_test); diff --git a/poulpy-hal/src/test_suite/vec_znx.rs b/poulpy-hal/src/test_suite/vec_znx.rs index 058d0ec..577fc73 100644 --- a/poulpy-hal/src/test_suite/vec_znx.rs +++ b/poulpy-hal/src/test_suite/vec_znx.rs @@ -7,8 +7,9 @@ use crate::{ VecZnxFillNormal, VecZnxFillUniform, VecZnxLsh, VecZnxLshInplace, VecZnxLshTmpBytes, VecZnxMergeRings, VecZnxMergeRingsTmpBytes, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxMulXpMinusOneInplaceTmpBytes, VecZnxNegate, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, - VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, VecZnxRsh, VecZnxRshInplace, VecZnxSplitRing, VecZnxSplitRingTmpBytes, - VecZnxSub, VecZnxSubInplace, VecZnxSubNegateInplace, VecZnxSubScalar, VecZnxSubScalarInplace, VecZnxSwitchRing, + VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, VecZnxRsh, 
VecZnxRshInplace, VecZnxRshTmpBytes, VecZnxSplitRing, + VecZnxSplitRingTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSubNegateInplace, VecZnxSubScalar, VecZnxSubScalarInplace, + VecZnxSwitchRing, }, layouts::{Backend, DigestU64, FillUniform, Module, ScalarZnx, ScratchOwned, VecZnx, ZnxInfos, ZnxView, ZnxViewMut}, reference::znx::znx_copy_ref, @@ -341,10 +342,7 @@ where let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_merge_rings_tmp_bytes()); for a_size in [1, 2, 3, 4] { - let mut a: [VecZnx>; 2] = [ - VecZnx::alloc(n >> 1, cols, a_size), - VecZnx::alloc(n >> 1, cols, a_size), - ]; + let mut a: [VecZnx>; 2] = [VecZnx::alloc(n >> 1, cols, a_size), VecZnx::alloc(n >> 1, cols, a_size)]; a.iter_mut().for_each(|ai| { ai.fill_uniform(base2k, &mut source); @@ -549,26 +547,20 @@ where let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); - // Set d to garbage - res_ref.fill_uniform(base2k, &mut source); - res_test.fill_uniform(base2k, &mut source); + for res_offset in -(base2k as i64)..=(base2k as i64) { + // Set d to garbage + res_ref.fill_uniform(base2k, &mut source); + res_test.fill_uniform(base2k, &mut source); - // Reference - for i in 0..cols { - module_ref.vec_znx_normalize(base2k, &mut res_ref, i, base2k, &a, i, scratch_ref.borrow()); - module_test.vec_znx_normalize( - base2k, - &mut res_test, - i, - base2k, - &a, - i, - scratch_test.borrow(), - ); + // Reference + for i in 0..cols { + module_ref.vec_znx_normalize(&mut res_ref, base2k, res_offset, i, &a, base2k, i, scratch_ref.borrow()); + module_test.vec_znx_normalize(&mut res_test, base2k, res_offset, i, &a, base2k, i, scratch_test.borrow()); + } + + assert_eq!(a.digest_u64(), a_digest); + assert_eq!(res_ref, res_test); } - - assert_eq!(a.digest_u64(), a_digest); - assert_eq!(res_ref, res_test); } } } @@ -718,10 +710,7 @@ where }) } else { let std: f64 = a.stats(base2k, col_i).std(); - assert!( - (std - one_12_sqrt).abs() < 0.01, - "std={std} ~!= {one_12_sqrt}", - ); + assert!((std - one_12_sqrt).abs() < 0.01, "std={std} ~!= {one_12_sqrt}",); } }) }); @@ -783,11 +772,7 @@ where }) } else { let std: f64 = a.stats(base2k, col_i).std() * k_f64; - assert!( - (std - sigma * sqrt2).abs() < 0.1, - "std={std} ~!= {}", - sigma * sqrt2 - ); + assert!((std - sigma * sqrt2).abs() < 0.1, "std={std} ~!= {}", sigma * sqrt2); } }) }); @@ -872,9 +857,9 @@ where pub fn test_vec_znx_rsh(base2k: usize, module_ref: &Module
<BR>, module_test: &Module<BT>)
 where
-    Module<BR>: VecZnxRsh<BR> + VecZnxLshTmpBytes,
+    Module<BR>: VecZnxRsh<BR> + VecZnxRshTmpBytes,
     ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
-    Module<BT>: VecZnxRsh<BT> + VecZnxLshTmpBytes,
+    Module<BT>: VecZnxRsh<BT> + VecZnxRshTmpBytes,
     ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
 {
     assert_eq!(module_ref.n(), module_test.n());
@@ -882,8 +867,8 @@ where
     let mut source: Source = Source::new([0u8; 32]);
     let cols: usize = 2;
-    let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_lsh_tmp_bytes());
-    let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_lsh_tmp_bytes());
+    let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_rsh_tmp_bytes());
+    let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_rsh_tmp_bytes());
     for a_size in [1, 2, 3, 4] {
         let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
@@ -914,9 +899,9 @@ where
 pub fn test_vec_znx_rsh_inplace<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
 where
-    Module<BR>: VecZnxRshInplace<BR> + VecZnxLshTmpBytes,
+    Module<BR>: VecZnxRshInplace<BR> + VecZnxRshTmpBytes,
     ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
-    Module<BT>: VecZnxRshInplace<BT> + VecZnxLshTmpBytes,
+    Module<BT>: VecZnxRshInplace<BT> + VecZnxRshTmpBytes,
     ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
 {
     assert_eq!(module_ref.n(), module_test.n());
@@ -924,8 +909,8 @@ where
     let mut source: Source = Source::new([0u8; 32]);
     let cols: usize = 2;
-    let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_lsh_tmp_bytes());
-    let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_lsh_tmp_bytes());
+    let mut scratch_ref: ScratchOwned<BR>
= ScratchOwned::alloc(module_ref.vec_znx_rsh_tmp_bytes()); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_rsh_tmp_bytes()); for res_size in [1, 2, 3, 4] { for k in 0..base2k * res_size { @@ -966,15 +951,11 @@ where let a_digest = a.digest_u64(); for res_size in [1, 2, 3, 4] { - let mut res_ref: [VecZnx>; 2] = [ - VecZnx::alloc(n >> 1, cols, res_size), - VecZnx::alloc(n >> 1, cols, res_size), - ]; + let mut res_ref: [VecZnx>; 2] = + [VecZnx::alloc(n >> 1, cols, res_size), VecZnx::alloc(n >> 1, cols, res_size)]; - let mut res_test: [VecZnx>; 2] = [ - VecZnx::alloc(n >> 1, cols, res_size), - VecZnx::alloc(n >> 1, cols, res_size), - ]; + let mut res_test: [VecZnx>; 2] = + [VecZnx::alloc(n >> 1, cols, res_size), VecZnx::alloc(n >> 1, cols, res_size)]; res_ref.iter_mut().for_each(|ri| { ri.fill_uniform(base2k, &mut source); diff --git a/poulpy-hal/src/test_suite/vec_znx_big.rs b/poulpy-hal/src/test_suite/vec_znx_big.rs index 2de43ad..15a0648 100644 --- a/poulpy-hal/src/test_suite/vec_znx_big.rs +++ b/poulpy-hal/src/test_suite/vec_znx_big.rs @@ -93,20 +93,22 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - base2k, &mut res_small_ref, - j, base2k, + 0, + j, &res_big_ref, + base2k, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - base2k, &mut res_small_test, - j, base2k, + 0, + j, &res_big_test, + base2k, j, scratch_test.borrow(), ); @@ -188,20 +190,22 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - base2k, &mut res_small_ref, - j, base2k, + 0, + j, &res_big_ref, + base2k, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - base2k, &mut res_small_test, - j, base2k, + 0, + j, &res_big_test, + base2k, j, scratch_test.borrow(), ); @@ -279,20 +283,22 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - base2k, &mut res_small_ref, - j, base2k, + 0, + j, &res_big_ref, + base2k, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - base2k, &mut res_small_test, - j, base2k, + 0, + j, &res_big_test, + base2k, j, scratch_test.borrow(), ); @@ -367,20 +373,22 @@ pub fn test_vec_znx_big_add_small_inplace( for j in 0..cols { module_ref.vec_znx_big_normalize( - base2k, &mut res_small_ref, - j, base2k, + 0, + j, &res_big_ref, + base2k, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - base2k, &mut res_small_test, - j, base2k, + 0, + j, &res_big_test, + base2k, j, scratch_test.borrow(), ); @@ -459,20 +467,22 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - base2k, &mut res_small_ref, - j, base2k, + 0, + j, &res_big_ref, + base2k, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - base2k, &mut res_small_test, - j, base2k, + 0, + j, &res_big_test, + base2k, j, scratch_test.borrow(), ); @@ -546,20 +556,22 @@ pub fn test_vec_znx_big_automorphism_inplace( for j in 0..cols { module_ref.vec_znx_big_normalize( - base2k, &mut res_small_ref, - j, base2k, + 0, + j, &res_big_ref, + base2k, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - base2k, &mut res_small_test, - j, base2k, + 0, + j, &res_big_test, + base2k, j, scratch_test.borrow(), ); @@ -631,20 +643,22 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - base2k, &mut res_small_ref, - j, base2k, + 0, + j, &res_big_ref, + base2k, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - base2k, &mut res_small_test, - j, base2k, + 0, + j, &res_big_test, + base2k, j, scratch_test.borrow(), ); @@ -709,20 +723,22 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - base2k, 
&mut res_small_ref, - j, base2k, + 0, + j, &res_big_ref, + base2k, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - base2k, &mut res_small_test, - j, base2k, + 0, + j, &res_big_test, + base2k, j, scratch_test.borrow(), ); @@ -782,36 +798,40 @@ where let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); - // Set d to garbage - source.fill_bytes(res_ref.data_mut()); - source.fill_bytes(res_test.data_mut()); + for res_offset in -(base2k as i64)..=(base2k as i64) { + // Set d to garbage + source.fill_bytes(res_ref.data_mut()); + source.fill_bytes(res_test.data_mut()); - // Reference - for j in 0..cols { - module_ref.vec_znx_big_normalize( - base2k, - &mut res_ref, - j, - base2k, - &a_ref, - j, - scratch_ref.borrow(), - ); - module_test.vec_znx_big_normalize( - base2k, - &mut res_test, - j, - base2k, - &a_test, - j, - scratch_test.borrow(), - ); + // Reference + for j in 0..cols { + module_ref.vec_znx_big_normalize( + &mut res_ref, + base2k, + res_offset, + j, + &a_ref, + base2k, + j, + scratch_ref.borrow(), + ); + module_test.vec_znx_big_normalize( + &mut res_test, + base2k, + res_offset, + j, + &a_test, + base2k, + j, + scratch_test.borrow(), + ); + } + + assert_eq!(a_ref.digest_u64(), a_ref_digest); + assert_eq!(a_test.digest_u64(), a_test_digest); + + assert_eq!(res_ref, res_test); } - - assert_eq!(a_ref.digest_u64(), a_ref_digest); - assert_eq!(a_test.digest_u64(), a_test_digest); - - assert_eq!(res_ref, res_test); } } } @@ -891,20 +911,22 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - base2k, &mut res_small_ref, - j, base2k, + 0, + j, &res_big_ref, + base2k, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - base2k, &mut res_small_test, - j, base2k, + 0, + j, &res_big_test, + base2k, j, scratch_test.borrow(), ); @@ -986,20 +1008,22 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - base2k, &mut res_small_ref, - j, base2k, + 0, + j, &res_big_ref, + base2k, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - base2k, &mut res_small_test, - j, base2k, + 0, + j, &res_big_test, + base2k, j, scratch_test.borrow(), ); @@ -1083,20 +1107,22 @@ pub fn test_vec_znx_big_sub_negate_inplace( for j in 0..cols { module_ref.vec_znx_big_normalize( - base2k, &mut res_small_ref, - j, base2k, + 0, + j, &res_big_ref, + base2k, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - base2k, &mut res_small_test, - j, base2k, + 0, + j, &res_big_test, + base2k, j, scratch_test.borrow(), ); @@ -1180,20 +1206,22 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - base2k, &mut res_small_ref, - j, base2k, + 0, + j, &res_big_ref, + base2k, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - base2k, &mut res_small_test, - j, base2k, + 0, + j, &res_big_test, + base2k, j, scratch_test.borrow(), ); @@ -1278,20 +1306,22 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - base2k, &mut res_small_ref, - j, base2k, + 0, + j, &res_big_ref, + base2k, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - base2k, &mut res_small_test, - j, base2k, + 0, + j, &res_big_test, + base2k, j, scratch_test.borrow(), ); @@ -1366,20 +1396,22 @@ pub fn test_vec_znx_big_sub_small_a_inplace( for j in 0..cols { module_ref.vec_znx_big_normalize( - base2k, &mut res_small_ref, - j, base2k, + 0, + j, &res_big_ref, + base2k, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - base2k, &mut res_small_test, - j, base2k, + 0, + j, &res_big_test, + 
base2k, j, scratch_test.borrow(), ); @@ -1427,55 +1459,59 @@ pub fn test_vec_znx_big_sub_small_b_inplace( let a_digest: u64 = a.digest_u64(); for res_size in [1, 2, 3, 4] { - let mut res: VecZnx> = VecZnx::alloc(n, cols, res_size); - res.fill_uniform(base2k, &mut source); + for res_offset in -(base2k as i64)..=(base2k as i64) { + let mut res: VecZnx> = VecZnx::alloc(n, cols, res_size); + res.fill_uniform(base2k, &mut source); - let mut res_big_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, res_size); - let mut res_big_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, res_size); + let mut res_big_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, res_size); + let mut res_big_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, res_size); - for j in 0..cols { - module_ref.vec_znx_big_from_small(&mut res_big_ref, j, &res, j); - module_test.vec_znx_big_from_small(&mut res_big_test, j, &res, j); + for j in 0..cols { + module_ref.vec_znx_big_from_small(&mut res_big_ref, j, &res, j); + module_test.vec_znx_big_from_small(&mut res_big_test, j, &res, j); + } + + for i in 0..cols { + module_ref.vec_znx_big_sub_small_negate_inplace(&mut res_big_ref, i, &a, i); + module_test.vec_znx_big_sub_small_negate_inplace(&mut res_big_test, i, &a, i); + } + + assert_eq!(a.digest_u64(), a_digest); + + let mut res_small_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_small_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + let res_ref_digest: u64 = res_big_ref.digest_u64(); + let res_test_digest: u64 = res_big_test.digest_u64(); + + for j in 0..cols { + module_ref.vec_znx_big_normalize( + &mut res_small_ref, + base2k, + res_offset, + j, + &res_big_ref, + base2k, + j, + scratch_ref.borrow(), + ); + module_test.vec_znx_big_normalize( + &mut res_small_test, + base2k, + res_offset, + j, + &res_big_test, + base2k, + j, + scratch_test.borrow(), + ); + } + + assert_eq!(res_big_ref.digest_u64(), res_ref_digest); + assert_eq!(res_big_test.digest_u64(), res_test_digest); + + assert_eq!(res_small_ref, res_small_test); } - - for i in 0..cols { - module_ref.vec_znx_big_sub_small_negate_inplace(&mut res_big_ref, i, &a, i); - module_test.vec_znx_big_sub_small_negate_inplace(&mut res_big_test, i, &a, i); - } - - assert_eq!(a.digest_u64(), a_digest); - - let mut res_small_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); - let mut res_small_test: VecZnx> = VecZnx::alloc(n, cols, res_size); - - let res_ref_digest: u64 = res_big_ref.digest_u64(); - let res_test_digest: u64 = res_big_test.digest_u64(); - - for j in 0..cols { - module_ref.vec_znx_big_normalize( - base2k, - &mut res_small_ref, - j, - base2k, - &res_big_ref, - j, - scratch_ref.borrow(), - ); - module_test.vec_znx_big_normalize( - base2k, - &mut res_small_test, - j, - base2k, - &res_big_test, - j, - scratch_test.borrow(), - ); - } - - assert_eq!(res_big_ref.digest_u64(), res_ref_digest); - assert_eq!(res_big_test.digest_u64(), res_test_digest); - - assert_eq!(res_small_ref, res_small_test); } } } diff --git a/poulpy-hal/src/test_suite/vec_znx_dft.rs b/poulpy-hal/src/test_suite/vec_znx_dft.rs index 87e2992..aa2e261 100644 --- a/poulpy-hal/src/test_suite/vec_znx_dft.rs +++ b/poulpy-hal/src/test_suite/vec_znx_dft.rs @@ -102,20 +102,22 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - base2k, &mut res_small_ref, - j, base2k, + 0, + j, &res_big_ref, + base2k, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - base2k, &mut res_small_test, - j, base2k, + 0, + j, &res_big_test, + base2k, j, 
scratch_test.borrow(), ); @@ -208,20 +210,22 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - base2k, &mut res_small_ref, - j, base2k, + 0, + j, &res_big_ref, + base2k, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - base2k, &mut res_small_test, - j, base2k, + 0, + j, &res_big_test, + base2k, j, scratch_test.borrow(), ); @@ -311,20 +315,22 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - base2k, &mut res_small_ref, - j, base2k, + 0, + j, &res_big_ref, + base2k, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - base2k, &mut res_small_test, - j, base2k, + 0, + j, &res_big_test, + base2k, j, scratch_test.borrow(), ); @@ -392,13 +398,7 @@ where for j in 0..cols { module_ref.vec_znx_idft_apply(&mut res_big_ref, j, &res_dft_ref, j, scratch_ref.borrow()); - module_test.vec_znx_idft_apply( - &mut res_big_test, - j, - &res_dft_test, - j, - scratch_test.borrow(), - ); + module_test.vec_znx_idft_apply(&mut res_big_test, j, &res_dft_test, j, scratch_test.borrow()); } assert_eq!(res_dft_ref.digest_u64(), res_dft_ref_digest); @@ -412,20 +412,22 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - base2k, &mut res_small_ref, - j, base2k, + 0, + j, &res_big_ref, + base2k, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - base2k, &mut res_small_test, - j, base2k, + 0, + j, &res_big_test, + base2k, j, scratch_test.borrow(), ); @@ -502,20 +504,22 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - base2k, &mut res_small_ref, - j, base2k, + 0, + j, &res_big_ref, + base2k, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - base2k, &mut res_small_test, - j, base2k, + 0, + j, &res_big_test, + base2k, j, scratch_test.borrow(), ); @@ -589,20 +593,22 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - base2k, &mut res_small_ref, - j, base2k, + 0, + j, &res_big_ref, + base2k, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - base2k, &mut res_small_test, - j, base2k, + 0, + j, &res_big_test, + base2k, j, scratch_test.borrow(), ); @@ -709,20 +715,22 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - base2k, &mut res_small_ref, - j, base2k, + 0, + j, &res_big_ref, + base2k, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - base2k, &mut res_small_test, - j, base2k, + 0, + j, &res_big_test, + base2k, j, scratch_test.borrow(), ); @@ -815,20 +823,22 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - base2k, &mut res_small_ref, - j, base2k, + 0, + j, &res_big_ref, + base2k, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - base2k, &mut res_small_test, - j, base2k, + 0, + j, &res_big_test, + base2k, j, scratch_test.borrow(), ); @@ -923,20 +933,22 @@ pub fn test_vec_znx_dft_sub_negate_inplace( for j in 0..cols { module_ref.vec_znx_big_normalize( - base2k, &mut res_small_ref, - j, base2k, + 0, + j, &res_big_ref, + base2k, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - base2k, &mut res_small_test, - j, base2k, + 0, + j, &res_big_test, + base2k, j, scratch_test.borrow(), ); diff --git a/poulpy-hal/src/test_suite/vmp.rs b/poulpy-hal/src/test_suite/vmp.rs index 2a7f720..7a71a31 100644 --- a/poulpy-hal/src/test_suite/vmp.rs +++ b/poulpy-hal/src/test_suite/vmp.rs @@ -90,20 +90,22 @@ where for j in 0..cols_out { module_ref.vec_znx_big_normalize( - base2k, &mut res_small_ref, - j, base2k, + 0, + j, &res_big_ref, + base2k, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - base2k, &mut res_small_test, - j, 
base2k, + 0, + j, &res_big_test, + base2k, j, scratch_test.borrow(), ); @@ -205,18 +207,8 @@ where source.fill_bytes(res_dft_ref.data_mut()); source.fill_bytes(res_dft_test.data_mut()); - module_ref.vmp_apply_dft_to_dft( - &mut res_dft_ref, - &a_dft_ref, - &pmat_ref, - scratch_ref.borrow(), - ); - module_test.vmp_apply_dft_to_dft( - &mut res_dft_test, - &a_dft_test, - &pmat_test, - scratch_test.borrow(), - ); + module_ref.vmp_apply_dft_to_dft(&mut res_dft_ref, &a_dft_ref, &pmat_ref, scratch_ref.borrow()); + module_test.vmp_apply_dft_to_dft(&mut res_dft_test, &a_dft_test, &pmat_test, scratch_test.borrow()); let res_big_ref: VecZnxBig, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref); let res_big_test: VecZnxBig, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test); @@ -229,20 +221,22 @@ where for j in 0..cols_out { module_ref.vec_znx_big_normalize( - base2k, &mut res_small_ref, - j, base2k, + 0, + j, &res_big_ref, + base2k, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - base2k, &mut res_small_test, - j, base2k, + 0, + j, &res_big_test, + base2k, j, scratch_test.borrow(), ); @@ -379,20 +373,22 @@ where for j in 0..cols_out { module_ref.vec_znx_big_normalize( - base2k, &mut res_small_ref, - j, base2k, + 0, + j, &res_big_ref, + base2k, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - base2k, &mut res_small_test, - j, base2k, + 0, + j, &res_big_test, + base2k, j, scratch_test.borrow(), ); diff --git a/poulpy-schemes/benches/bdd_arithmetic.rs b/poulpy-schemes/benches/bdd_arithmetic.rs index e4b9e9c..60456b8 100644 --- a/poulpy-schemes/benches/bdd_arithmetic.rs +++ b/poulpy-schemes/benches/bdd_arithmetic.rs @@ -91,14 +91,7 @@ where // Circuit bootstrapping evaluation key let mut cbt_key: CircuitBootstrappingKey, BRA> = CircuitBootstrappingKey::alloc_from_infos(¶ms.bdd_layout.cbt_layout); - cbt_key.encrypt_sk( - &module, - &sk_lwe, - &sk_glwe, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + cbt_key.encrypt_sk(&module, &sk_lwe, &sk_glwe, &mut source_xa, &mut source_xe, scratch.borrow()); let mut cbt_key_prepared: CircuitBootstrappingKeyPrepared, BRA, BE> = CircuitBootstrappingKeyPrepared::alloc_from_infos(&module, ¶ms.bdd_layout.cbt_layout); @@ -108,14 +101,7 @@ where sk_glwe_prepared.prepare(&module, &sk_glwe); let mut bdd_key: BDDKey, BRA> = BDDKey::alloc_from_infos(¶ms.bdd_layout); - bdd_key.encrypt_sk( - &module, - &sk_lwe, - &sk_glwe, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + bdd_key.encrypt_sk(&module, &sk_lwe, &sk_glwe, &mut source_xa, &mut source_xe, scratch.borrow()); let input_a = 255_u32; let input_b = 30_u32; @@ -332,95 +318,45 @@ where }; // Benchmark each operation - bench_operation::( - &mut group, - ¶ms, - "add", - |c_enc, module, a, b, key, scratch| { - c_enc.add(module, a, b, key, scratch); - }, - ); + bench_operation::(&mut group, ¶ms, "add", |c_enc, module, a, b, key, scratch| { + c_enc.add(module, a, b, key, scratch); + }); - bench_operation::( - &mut group, - ¶ms, - "sub", - |c_enc, module, a, b, key, scratch| { - c_enc.sub(module, a, b, key, scratch); - }, - ); + bench_operation::(&mut group, ¶ms, "sub", |c_enc, module, a, b, key, scratch| { + c_enc.sub(module, a, b, key, scratch); + }); - bench_operation::( - &mut group, - ¶ms, - "sll", - |c_enc, module, a, b, key, scratch| { - c_enc.sll(module, a, b, key, scratch); - }, - ); + bench_operation::(&mut group, ¶ms, "sll", |c_enc, module, a, b, key, scratch| { + c_enc.sll(module, a, b, key, scratch); + }); - bench_operation::( - &mut group, - 
¶ms, - "sra", - |c_enc, module, a, b, key, scratch| { - c_enc.sra(module, a, b, key, scratch); - }, - ); + bench_operation::(&mut group, ¶ms, "sra", |c_enc, module, a, b, key, scratch| { + c_enc.sra(module, a, b, key, scratch); + }); - bench_operation::( - &mut group, - ¶ms, - "srl", - |c_enc, module, a, b, key, scratch| { - c_enc.srl(module, a, b, key, scratch); - }, - ); + bench_operation::(&mut group, ¶ms, "srl", |c_enc, module, a, b, key, scratch| { + c_enc.srl(module, a, b, key, scratch); + }); - bench_operation::( - &mut group, - ¶ms, - "slt", - |c_enc, module, a, b, key, scratch| { - c_enc.slt(module, a, b, key, scratch); - }, - ); + bench_operation::(&mut group, ¶ms, "slt", |c_enc, module, a, b, key, scratch| { + c_enc.slt(module, a, b, key, scratch); + }); - bench_operation::( - &mut group, - ¶ms, - "sltu", - |c_enc, module, a, b, key, scratch| { - c_enc.sltu(module, a, b, key, scratch); - }, - ); + bench_operation::(&mut group, ¶ms, "sltu", |c_enc, module, a, b, key, scratch| { + c_enc.sltu(module, a, b, key, scratch); + }); - bench_operation::( - &mut group, - ¶ms, - "or", - |c_enc, module, a, b, key, scratch| { - c_enc.or(module, a, b, key, scratch); - }, - ); + bench_operation::(&mut group, ¶ms, "or", |c_enc, module, a, b, key, scratch| { + c_enc.or(module, a, b, key, scratch); + }); - bench_operation::( - &mut group, - ¶ms, - "and", - |c_enc, module, a, b, key, scratch| { - c_enc.and(module, a, b, key, scratch); - }, - ); + bench_operation::(&mut group, ¶ms, "and", |c_enc, module, a, b, key, scratch| { + c_enc.and(module, a, b, key, scratch); + }); - bench_operation::( - &mut group, - ¶ms, - "xor", - |c_enc, module, a, b, key, scratch| { - c_enc.xor(module, a, b, key, scratch); - }, - ); + bench_operation::(&mut group, ¶ms, "xor", |c_enc, module, a, b, key, scratch| { + c_enc.xor(module, a, b, key, scratch); + }); group.finish(); } diff --git a/poulpy-schemes/benches/bdd_prepare.rs b/poulpy-schemes/benches/bdd_prepare.rs index e3eccc9..a75a4f1 100644 --- a/poulpy-schemes/benches/bdd_prepare.rs +++ b/poulpy-schemes/benches/bdd_prepare.rs @@ -98,14 +98,7 @@ where sk_glwe_prepared.prepare(&module, &sk_glwe); let mut bdd_key: BDDKey, BRA> = BDDKey::alloc_from_infos(¶ms.bdd_layout); - bdd_key.encrypt_sk( - &module, - &sk_lwe, - &sk_glwe, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + bdd_key.encrypt_sk(&module, &sk_lwe, &sk_glwe, &mut source_xa, &mut source_xe, scratch.borrow()); let input_a = 255_u32; diff --git a/poulpy-schemes/benches/blind_rotate.rs b/poulpy-schemes/benches/blind_rotate.rs new file mode 100644 index 0000000..27a1793 --- /dev/null +++ b/poulpy-schemes/benches/blind_rotate.rs @@ -0,0 +1,149 @@ +use std::hint::black_box; + +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use poulpy_core::{ + GLWEDecrypt, LWEEncryptSk, ScratchTakeCore, + layouts::{ + Base2K, Dnum, GLWE, GLWELayout, GLWESecret, GLWESecretPrepared, GLWESecretPreparedFactory, LWE, LWEInfos, LWELayout, + LWESecret, TorusPrecision, + }, +}; + +#[cfg(all(feature = "enable-avx", target_arch = "x86_64"))] +pub use poulpy_cpu_avx::FFT64Avx as BackendImpl; + +#[cfg(not(all(feature = "enable-avx", target_arch = "x86_64")))] +pub use poulpy_cpu_ref::FFT64Ref as BackendImpl; + +use poulpy_hal::{ + api::{ModuleN, ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow}, + layouts::{Backend, FillUniform, Module, Scratch, ScratchOwned}, + source::Source, +}; +use poulpy_schemes::bin_fhe::blind_rotation::{ + BlindRotationAlgo, BlindRotationExecute, BlindRotationKey, 
BlindRotationKeyEncryptSk, BlindRotationKeyFactory, + BlindRotationKeyLayout, BlindRotationKeyPrepared, BlindRotationKeyPreparedFactory, CGGI, LookUpTableLayout, LookupTable, + LookupTableFactory, +}; + +pub fn benc_blind_rotate(c: &mut Criterion, label: &str) +where + Module: ModuleN + + ModuleNew + + BlindRotationKeyEncryptSk + + BlindRotationKeyPreparedFactory + + BlindRotationExecute + + LookupTableFactory + + GLWESecretPreparedFactory + + GLWEDecrypt + + LWEEncryptSk, + BlindRotationKey, BRA>: BlindRotationKeyFactory, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchTakeCore, +{ + let group_name: String = format!("blind_rotate::{label}"); + + let mut group = c.benchmark_group(group_name); + + let n_glwe: usize = 512; + let n_lwe: usize = 687; + let rank: usize = 3; + let block_size: usize = 3; + let extension_factor: usize = 1; + + let log_message_modulus: usize = 2; + let message_modulus: usize = 1 << log_message_modulus; + + let mut scratch: ScratchOwned = ScratchOwned::alloc(1 << 24); + + let module: Module = Module::::new(n_glwe as u64); + + let mut source_xs: Source = Source::new([2u8; 32]); + let mut source_xe: Source = Source::new([2u8; 32]); + let mut source_xa: Source = Source::new([1u8; 32]); + + let brk_infos: BlindRotationKeyLayout = BlindRotationKeyLayout { + n_glwe: n_glwe.into(), + n_lwe: n_lwe.into(), + base2k: Base2K(18), + k: TorusPrecision(36), + dnum: Dnum(1), + rank: rank.into(), + }; + + let glwe_infos: GLWELayout = GLWELayout { + n: n_glwe.into(), + base2k: Base2K(18), + k: TorusPrecision(18), + rank: rank.into(), + }; + + let lwe_infos: LWELayout = LWELayout { + n: n_lwe.into(), + k: TorusPrecision(18), + base2k: Base2K(18), + }; + + let mut sk_glwe: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_infos); + sk_glwe.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_glwe_dft: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc_from_infos(&module, &glwe_infos); + sk_glwe_dft.prepare(&module, &sk_glwe); + + let mut sk_lwe: LWESecret> = LWESecret::alloc(n_lwe.into()); + sk_lwe.fill_binary_block(block_size, &mut source_xs); + + let mut brk: BlindRotationKey, BRA> = BlindRotationKey::, BRA>::alloc(&brk_infos); + + brk.encrypt_sk( + &module, + &sk_glwe_dft, + &sk_lwe, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + let mut brk_prepared: BlindRotationKeyPrepared, BRA, BE> = BlindRotationKeyPrepared::alloc(&module, &brk); + brk_prepared.prepare(&module, &brk, scratch.borrow()); + + let mut res: GLWE> = GLWE::alloc_from_infos(&glwe_infos); + res.data_mut().fill_uniform(glwe_infos.base2k().as_usize(), &mut source_xa); + let mut lwe: LWE> = LWE::alloc_from_infos(&lwe_infos); + lwe.data_mut().fill_uniform(lwe_infos.base2k().as_usize(), &mut source_xa); + + let f = |x: i64| -> i64 { 2 * x + 1 }; + + let mut f_vec: Vec = vec![0i64; message_modulus]; + f_vec.iter_mut().enumerate().for_each(|(i, x)| *x = f(i as i64)); + + let lut_infos = LookUpTableLayout { + n: module.n().into(), + extension_factor, + k: TorusPrecision(2), + base2k: Base2K(17), + }; + + let mut lut: LookupTable = LookupTable::alloc(&lut_infos); + lut.set(&module, &f_vec, log_message_modulus + 1); + + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{} - {}", n_glwe, n_lwe)); + + group.bench_with_input(id, &(), |b, _| { + b.iter(|| { + brk_prepared.execute(&module, &mut res, &lwe, &lut, scratch.borrow()); + black_box(()) + }) + }); + group.finish(); +} + +fn bench_blind_rotate_fft64(c: &mut Criterion) { + #[cfg(all(feature = "enable-avx", target_arch = 
"x86_64"))] + let label = "fft64_avx"; + #[cfg(not(all(feature = "enable-avx", target_arch = "x86_64")))] + let label = "fft64_ref"; + benc_blind_rotate::(c, label); +} + +criterion_group!(benches, bench_blind_rotate_fft64); +criterion_main!(benches); diff --git a/poulpy-schemes/benches/circuit_bootstrapping.rs b/poulpy-schemes/benches/circuit_bootstrapping.rs index 855fdec..f8f25dc 100644 --- a/poulpy-schemes/benches/circuit_bootstrapping.rs +++ b/poulpy-schemes/benches/circuit_bootstrapping.rs @@ -105,14 +105,7 @@ where // Circuit bootstrapping evaluation key let mut cbt_key: CircuitBootstrappingKey, BRA> = CircuitBootstrappingKey::alloc_from_infos(¶ms.cbt_infos); - cbt_key.encrypt_sk( - &module, - &sk_lwe, - &sk_glwe, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + cbt_key.encrypt_sk(&module, &sk_lwe, &sk_glwe, &mut source_xa, &mut source_xe, scratch.borrow()); let mut res: GGSW> = GGSW::alloc_from_infos(¶ms.ggsw_infos); let mut cbt_prepared: CircuitBootstrappingKeyPrepared, BRA, BE> = diff --git a/poulpy-schemes/examples/bdd_arithmetic.rs b/poulpy-schemes/examples/bdd_arithmetic.rs index f1756bd..6daea8f 100644 --- a/poulpy-schemes/examples/bdd_arithmetic.rs +++ b/poulpy-schemes/examples/bdd_arithmetic.rs @@ -162,14 +162,7 @@ where // This key is required to prepare all Fhe Integers for operations, // and for performing the operations themselves let mut bdd_key: BDDKey, BRA> = BDDKey::alloc_from_infos(&bdd_layout); - bdd_key.encrypt_sk( - &module, - &sk_lwe, - &sk_glwe, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + bdd_key.encrypt_sk(&module, &sk_lwe, &sk_glwe, &mut source_xa, &mut source_xe, scratch.borrow()); ////////// Input Encryption // Encrypting the inputs @@ -216,13 +209,7 @@ where let mut c_enc: FheUint, u32> = FheUint::alloc_from_infos(&glwe_layout); // Performing the operation - c_enc.add( - &module, - &a_enc_prepared, - &b_enc_prepared, - &bdd_key_prepared, - scratch.borrow(), - ); + c_enc.add(&module, &a_enc_prepared, &b_enc_prepared, &bdd_key_prepared, scratch.borrow()); // Preparing the intermediate result ciphertext, c_enc, for the next operation let mut c_enc_prepared: FheUintPrepared, u32, BE> = FheUintPrepared::alloc_from_infos(&module, &ggsw_layout); @@ -230,13 +217,7 @@ where // Creating the output ciphertext d_enc let mut selected_enc: FheUint, u32> = FheUint::alloc_from_infos(&glwe_layout); - selected_enc.xor( - &module, - &c_enc_prepared, - &a_enc_prepared, - &bdd_key_prepared, - scratch.borrow(), - ); + selected_enc.xor(&module, &c_enc_prepared, &a_enc_prepared, &bdd_key_prepared, scratch.borrow()); //////// Homomorphic computation ends here //////// @@ -301,12 +282,7 @@ where ); let mut input_selector_enc_prepared: FheUintPrepared, u32, BE> = FheUintPrepared::alloc_from_infos(&module, &ggsw_layout); - input_selector_enc_prepared.prepare( - &module, - &input_selector_enc, - &bdd_key_prepared, - scratch.borrow(), - ); + input_selector_enc_prepared.prepare(&module, &input_selector_enc, &bdd_key_prepared, scratch.borrow()); module.glwe_blind_selection( &mut selected_enc, diff --git a/poulpy-schemes/examples/circuit_bootstrapping.rs b/poulpy-schemes/examples/circuit_bootstrapping.rs index 7a0b30a..7a8b88b 100644 --- a/poulpy-schemes/examples/circuit_bootstrapping.rs +++ b/poulpy-schemes/examples/circuit_bootstrapping.rs @@ -8,10 +8,20 @@ use poulpy_core::{ }; use std::time::Instant; -#[cfg(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))] +#[cfg(all( + feature = "enable-avx", + 
target_arch = "x86_64", + target_feature = "avx2", + target_feature = "fma" +))] use poulpy_cpu_avx::FFT64Avx as BackendImpl; -#[cfg(not(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma")))] +#[cfg(not(all( + feature = "enable-avx", + target_arch = "x86_64", + target_feature = "avx2", + target_feature = "fma" +)))] use poulpy_cpu_ref::FFT64Ref as BackendImpl; use poulpy_hal::{ @@ -162,28 +172,14 @@ fn main() { let mut ct_lwe: LWE> = LWE::alloc_from_infos(&lwe_infos); // Encrypt LWE Plaintext - ct_lwe.encrypt_sk( - &module, - &pt_lwe, - &sk_lwe, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + ct_lwe.encrypt_sk(&module, &pt_lwe, &sk_lwe, &mut source_xa, &mut source_xe, scratch.borrow()); let now: Instant = Instant::now(); // Circuit bootstrapping evaluation key let mut cbt_key: CircuitBootstrappingKey, CGGI> = CircuitBootstrappingKey::alloc_from_infos(&cbt_layout); - cbt_key.encrypt_sk( - &module, - &sk_lwe, - &sk_glwe, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + cbt_key.encrypt_sk(&module, &sk_lwe, &sk_glwe, &mut source_xa, &mut source_xe, scratch.borrow()); println!("CBT-KGEN: {} ms", now.elapsed().as_millis()); @@ -197,14 +193,7 @@ fn main() { // Apply circuit bootstrapping: LWE(data * 2^{- (k_lwe_pt + 2)}) -> GGSW(data) let now: Instant = Instant::now(); - cbt_prepared.execute_to_constant( - &module, - &mut res, - &ct_lwe, - k_lwe_pt, - extension_factor, - scratch.borrow(), - ); + cbt_prepared.execute_to_constant(&module, &mut res, &ct_lwe, k_lwe_pt, extension_factor, scratch.borrow()); println!("CBT: {} ms", now.elapsed().as_millis()); // Allocate "ideal" GGSW(data) plaintext @@ -216,16 +205,9 @@ fn main() { for col in 0..res.rank().as_usize() + 1 { println!( "row:{row} col:{col} -> {}", - res.noise( - &module, - row, - col, - &pt_ggsw, - &sk_glwe_prepared, - scratch.borrow() - ) - .std() - .log2() + res.noise(&module, row, col, &pt_ggsw, &sk_glwe_prepared, scratch.borrow()) + .std() + .log2() ) } } diff --git a/poulpy-schemes/examples/max_array.rs b/poulpy-schemes/examples/max_array.rs index 43ef444..3ae5b11 100644 --- a/poulpy-schemes/examples/max_array.rs +++ b/poulpy-schemes/examples/max_array.rs @@ -152,14 +152,7 @@ where // This key is required to prepare all Fhe Integers for operations, // and for performing the operations themselves let mut bdd_key: BDDKey, BRA> = BDDKey::alloc_from_infos(&bdd_layout); - bdd_key.encrypt_sk( - &module, - &sk_lwe, - &sk_glwe, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + bdd_key.encrypt_sk(&module, &sk_lwe, &sk_glwe, &mut source_xa, &mut source_xe, scratch.borrow()); ////////// Input Encryption // Encrypting the inputs diff --git a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/blind_retrieval.rs b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/blind_retrieval.rs index bcbd640..f4d971d 100644 --- a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/blind_retrieval.rs +++ b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/blind_retrieval.rs @@ -19,9 +19,7 @@ impl GLWEBlindRetriever { { let bit_size: usize = (u32::BITS - (size as u32 - 1).leading_zeros()) as usize; Self { - accumulators: (0..bit_size) - .map(|_| Accumulator::alloc(infos)) - .collect_vec(), + accumulators: (0..bit_size).map(|_| Accumulator::alloc(infos)).collect_vec(), counter: 0, } } @@ -70,15 +68,7 @@ impl GLWEBlindRetriever { 1 << self.accumulators.len() ); - add_core( - module, - a, - &mut self.accumulators, - 0, - selector, - offset, - scratch, - ); + add_core(module, a, &mut self.accumulators, 0, 
selector, offset, scratch); self.counter += 1; } @@ -92,15 +82,7 @@ impl GLWEBlindRetriever { for i in 0..self.accumulators.len() - 1 { let (acc_prev, acc_next) = self.accumulators.split_at_mut(i + 1); if acc_prev[i].num != 0 { - add_core( - module, - &acc_prev[i].data, - acc_next, - i + 1, - selector, - offset, - scratch, - ); + add_core(module, &acc_prev[i].data, acc_next, i + 1, selector, offset, scratch); acc_prev[0].num = 0 } } @@ -156,23 +138,10 @@ fn add_core( acc_prev[0].num = 1; } 1 => { - module.cmux_inplace_neg( - &mut acc_prev[0].data, - a, - &selector.get_bit(i + offset), - scratch, - ); + module.cmux_inplace_neg(&mut acc_prev[0].data, a, &selector.get_bit(i + offset), scratch); if !acc_next.is_empty() { - add_core( - module, - &acc_prev[0].data, - acc_next, - i + 1, - selector, - offset, - scratch, - ); + add_core(module, &acc_prev[0].data, acc_next, i + 1, selector, offset, scratch); } acc_prev[0].num = 0 diff --git a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/blind_rotation.rs b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/blind_rotation.rs index 7e11149..5a272ae 100644 --- a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/blind_rotation.rs +++ b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/blind_rotation.rs @@ -48,15 +48,7 @@ where for col in 0..(res.rank() + 1).into() { for row in 0..res.dnum().into() { - self.glwe_blind_rotation_inplace( - &mut res.at_mut(row, col), - fhe_uint, - sign, - bit_rsh, - bit_mask, - bit_lsh, - scratch, - ); + self.glwe_blind_rotation_inplace(&mut res.at_mut(row, col), fhe_uint, sign, bit_rsh, bit_mask, bit_lsh, scratch); } } } @@ -137,13 +129,7 @@ where for col in 0..(res.rank() + 1).into() { for row in 0..res.dnum().into() { tmp_glwe.data_mut().zero(); - self.vec_znx_add_scalar_inplace( - tmp_glwe.data_mut(), - col, - (dsize - 1) + row * dsize, - test_vector, - 0, - ); + self.vec_znx_add_scalar_inplace(tmp_glwe.data_mut(), col, (dsize - 1) + row * dsize, test_vector, 0); self.vec_znx_normalize_inplace(base2k, tmp_glwe.data_mut(), col, scratch_1); self.glwe_blind_rotation( diff --git a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/ciphertexts/fhe_uint.rs b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/ciphertexts/fhe_uint.rs index 5f1fc09..f45c0a9 100644 --- a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/ciphertexts/fhe_uint.rs +++ b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/ciphertexts/fhe_uint.rs @@ -121,8 +121,7 @@ impl FheUint { let (mut pt, scratch_1) = scratch.take_glwe_plaintext(&pt_infos); pt.encode_vec_i64(&data_bits, TorusPrecision(2)); - self.bits - .encrypt_sk(module, &pt, sk_glwe, source_xa, source_xe, scratch_1); + self.bits.encrypt_sk(module, &pt, sk_glwe, source_xa, source_xe, scratch_1); } } @@ -233,15 +232,7 @@ impl FheUint { let (mut tmp, scratch_1) = scratch.take_fhe_uint(self); tmp.splice_u8(module, dst << 1, src << 1, a, b, keys, scratch_1); - self.splice_u8( - module, - (dst << 1) + 1, - (src << 1) + 1, - &tmp, - b, - keys, - scratch_1, - ); + self.splice_u8(module, (dst << 1) + 1, (src << 1) + 1, &tmp, b, keys, scratch_1); } #[allow(clippy::too_many_arguments)] @@ -279,11 +270,7 @@ impl FheUint { let (mut tmp_fhe_uint_byte, scratch_1) = scratch.take_fhe_uint(b); // Move a[byte_a] into a[dst] - module.glwe_rotate( - -((T::bit_index(src << 3) << log_gap) as i64), - &mut tmp_fhe_uint_byte, - b, - ); + module.glwe_rotate(-((T::bit_index(src << 3) << log_gap) as i64), &mut tmp_fhe_uint_byte, b); // Zeroes all other bytes module.glwe_trace_inplace(&mut tmp_fhe_uint_byte, trace_start, keys, scratch_1); @@ -348,13 +335,8 @@ impl FheUint { rank: 
ks_lwe.rank_out(), }); module.glwe_keyswitch(&mut res_tmp, self, ks_glwe, scratch_1); - res.to_mut().from_glwe( - module, - &res_tmp, - T::bit_index(bit) << log_gap, - ks_lwe, - scratch_1, - ); + res.to_mut() + .from_glwe(module, &res_tmp, T::bit_index(bit) << log_gap, ks_lwe, scratch_1); } else { res.to_mut() .from_glwe(module, self, T::bit_index(bit) << log_gap, ks_lwe, scratch); @@ -415,8 +397,7 @@ impl FheUint { { let zero: GLWE> = GLWE::alloc_from_infos(self); let mut one: GLWE> = GLWE::alloc_from_infos(self); - one.data_mut() - .encode_coeff_i64(self.base2k().into(), 0, 2, 0, 1); + one.data_mut().encode_coeff_i64(self.base2k().into(), 0, 2, 0, 1); let (mut out_bits, scratch_1) = scratch.take_glwe_slice(T::BITS as usize, self); diff --git a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared.rs b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared.rs index d3be3e0..76bfe89 100644 --- a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared.rs +++ b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared.rs @@ -58,10 +58,7 @@ impl GetGGSWBitMut for FheUi } fn get_bits(&mut self, start: usize, count: usize) -> Vec> { assert!(start + count <= self.bits.len()); - self.bits[start..start + count] - .iter_mut() - .map(|bit| bit.to_mut()) - .collect() + self.bits[start..start + count].iter_mut().map(|bit| bit.to_mut()).collect() } } @@ -95,13 +92,7 @@ where where A: GGSWInfos, { - self.alloc_fhe_uint_prepared( - infos.base2k(), - infos.k(), - infos.dnum(), - infos.dsize(), - infos.rank(), - ) + self.alloc_fhe_uint_prepared(infos.base2k(), infos.k(), infos.dnum(), infos.dsize(), infos.rank()) } } @@ -316,12 +307,8 @@ where A: GLWEInfos, B: BDDKeyInfos, { - self.circuit_bootstrapping_execute_tmp_bytes( - block_size, - extension_factor, - res_infos, - &bdd_infos.cbt_infos(), - ) + GGSW::bytes_of_from_infos(res_infos) + self.circuit_bootstrapping_execute_tmp_bytes(block_size, extension_factor, res_infos, &bdd_infos.cbt_infos()) + + GGSW::bytes_of_from_infos(res_infos) + LWE::bytes_of_from_infos(bits_infos) } @@ -371,14 +358,7 @@ where let (mut tmp_ggsw, scratch_1) = scratch_thread.take_ggsw(ggsw_infos); let (mut tmp_lwe, scratch_2) = scratch_1.take_lwe(bits); for (local_bit, dst) in res_bits_chunk.iter_mut().enumerate() { - bits.get_bit_lwe( - self, - start + local_bit, - &mut tmp_lwe, - ks_glwe, - ks_lwe, - scratch_2, - ); + bits.get_bit_lwe(self, start + local_bit, &mut tmp_lwe, ks_glwe, ks_lwe, scratch_2); cbt.execute_to_constant(self, &mut tmp_ggsw, &tmp_lwe, 1, 1, scratch_2); dst.prepare(self, &tmp_ggsw, scratch_2); } diff --git a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared_debug.rs b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared_debug.rs index 7b4066b..bb54d40 100644 --- a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared_debug.rs +++ b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared_debug.rs @@ -27,14 +27,7 @@ impl FheUintPreparedDebug, T> { M: ModuleN, A: GGSWInfos, { - Self::alloc( - module, - infos.base2k(), - infos.k(), - infos.dnum(), - infos.dsize(), - infos.rank(), - ) + Self::alloc(module, infos.base2k(), infos.k(), infos.dnum(), infos.dsize(), infos.rank()) } pub fn alloc(module: &M, base2k: Base2K, k: TorusPrecision, dnum: Dnum, dsize: Dsize, rank: Rank) -> Self @@ -126,16 +119,8 @@ where let (_, scratch_1) = scratch.take_ggsw(res); let (mut tmp_lwe, scratch_2) = scratch_1.take_lwe(bits); for (bit, dst) in 
res.bits.iter_mut().enumerate() { - bits.get_bit_lwe( - self, - bit, - &mut tmp_lwe, - key.ks_glwe.as_ref(), - &key.ks_lwe, - scratch_2, - ); - key.cbt - .execute_to_constant(self, dst, &tmp_lwe, 1, 1, scratch_2); + bits.get_bit_lwe(self, bit, &mut tmp_lwe, key.ks_glwe.as_ref(), &key.ks_lwe, scratch_2); + key.cbt.execute_to_constant(self, dst, &tmp_lwe, 1, 1, scratch_2); } } } diff --git a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/circuits/u32/add_codegen.rs b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/circuits/u32/add_codegen.rs index 0ad70f3..5765490 100644 --- a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/circuits/u32/add_codegen.rs +++ b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/circuits/u32/add_codegen.rs @@ -79,12 +79,7 @@ impl BitCircuitFamily for AnyBitCircuit { pub(crate) static OUTPUT_CIRCUITS: Circuit = Circuit([ AnyBitCircuit::B0(BitCircuit::new( - [ - Node::Cmux(32, 1, 0), - Node::Cmux(32, 0, 1), - Node::Cmux(0, 1, 0), - Node::None, - ], + [Node::Cmux(32, 1, 0), Node::Cmux(32, 0, 1), Node::Cmux(0, 1, 0), Node::None], 2, )), AnyBitCircuit::B1(BitCircuit::new( diff --git a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/circuits/u32/and_codegen.rs b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/circuits/u32/and_codegen.rs index d38c827..2bfd1ae 100644 --- a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/circuits/u32/and_codegen.rs +++ b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/circuits/u32/and_codegen.rs @@ -79,291 +79,131 @@ impl BitCircuitFamily for AnyBitCircuit { pub(crate) static OUTPUT_CIRCUITS: Circuit = Circuit([ AnyBitCircuit::B0(BitCircuit::new( - [ - Node::Copy, - Node::Cmux(32, 1, 0), - Node::Cmux(0, 1, 0), - Node::None, - ], + [Node::Copy, Node::Cmux(32, 1, 0), Node::Cmux(0, 1, 0), Node::None], 2, )), AnyBitCircuit::B1(BitCircuit::new( - [ - Node::Copy, - Node::Cmux(33, 1, 0), - Node::Cmux(1, 1, 0), - Node::None, - ], + [Node::Copy, Node::Cmux(33, 1, 0), Node::Cmux(1, 1, 0), Node::None], 2, )), AnyBitCircuit::B2(BitCircuit::new( - [ - Node::Copy, - Node::Cmux(34, 1, 0), - Node::Cmux(2, 1, 0), - Node::None, - ], + [Node::Copy, Node::Cmux(34, 1, 0), Node::Cmux(2, 1, 0), Node::None], 2, )), AnyBitCircuit::B3(BitCircuit::new( - [ - Node::Copy, - Node::Cmux(35, 1, 0), - Node::Cmux(3, 1, 0), - Node::None, - ], + [Node::Copy, Node::Cmux(35, 1, 0), Node::Cmux(3, 1, 0), Node::None], 2, )), AnyBitCircuit::B4(BitCircuit::new( - [ - Node::Copy, - Node::Cmux(36, 1, 0), - Node::Cmux(4, 1, 0), - Node::None, - ], + [Node::Copy, Node::Cmux(36, 1, 0), Node::Cmux(4, 1, 0), Node::None], 2, )), AnyBitCircuit::B5(BitCircuit::new( - [ - Node::Copy, - Node::Cmux(37, 1, 0), - Node::Cmux(5, 1, 0), - Node::None, - ], + [Node::Copy, Node::Cmux(37, 1, 0), Node::Cmux(5, 1, 0), Node::None], 2, )), AnyBitCircuit::B6(BitCircuit::new( - [ - Node::Copy, - Node::Cmux(38, 1, 0), - Node::Cmux(6, 1, 0), - Node::None, - ], + [Node::Copy, Node::Cmux(38, 1, 0), Node::Cmux(6, 1, 0), Node::None], 2, )), AnyBitCircuit::B7(BitCircuit::new( - [ - Node::Copy, - Node::Cmux(39, 1, 0), - Node::Cmux(7, 1, 0), - Node::None, - ], + [Node::Copy, Node::Cmux(39, 1, 0), Node::Cmux(7, 1, 0), Node::None], 2, )), AnyBitCircuit::B8(BitCircuit::new( - [ - Node::Copy, - Node::Cmux(40, 1, 0), - Node::Cmux(8, 1, 0), - Node::None, - ], + [Node::Copy, Node::Cmux(40, 1, 0), Node::Cmux(8, 1, 0), Node::None], 2, )), AnyBitCircuit::B9(BitCircuit::new( - [ - Node::Copy, - Node::Cmux(41, 1, 0), - Node::Cmux(9, 1, 0), - Node::None, - ], + [Node::Copy, Node::Cmux(41, 1, 0), Node::Cmux(9, 1, 0), Node::None], 2, )), AnyBitCircuit::B10(BitCircuit::new( - [ - 
Node::Copy, - Node::Cmux(42, 1, 0), - Node::Cmux(10, 1, 0), - Node::None, - ], + [Node::Copy, Node::Cmux(42, 1, 0), Node::Cmux(10, 1, 0), Node::None], 2, )), AnyBitCircuit::B11(BitCircuit::new( - [ - Node::Copy, - Node::Cmux(43, 1, 0), - Node::Cmux(11, 1, 0), - Node::None, - ], + [Node::Copy, Node::Cmux(43, 1, 0), Node::Cmux(11, 1, 0), Node::None], 2, )), AnyBitCircuit::B12(BitCircuit::new( - [ - Node::Copy, - Node::Cmux(44, 1, 0), - Node::Cmux(12, 1, 0), - Node::None, - ], + [Node::Copy, Node::Cmux(44, 1, 0), Node::Cmux(12, 1, 0), Node::None], 2, )), AnyBitCircuit::B13(BitCircuit::new( - [ - Node::Copy, - Node::Cmux(45, 1, 0), - Node::Cmux(13, 1, 0), - Node::None, - ], + [Node::Copy, Node::Cmux(45, 1, 0), Node::Cmux(13, 1, 0), Node::None], 2, )), AnyBitCircuit::B14(BitCircuit::new( - [ - Node::Copy, - Node::Cmux(46, 1, 0), - Node::Cmux(14, 1, 0), - Node::None, - ], + [Node::Copy, Node::Cmux(46, 1, 0), Node::Cmux(14, 1, 0), Node::None], 2, )), AnyBitCircuit::B15(BitCircuit::new( - [ - Node::Copy, - Node::Cmux(47, 1, 0), - Node::Cmux(15, 1, 0), - Node::None, - ], + [Node::Copy, Node::Cmux(47, 1, 0), Node::Cmux(15, 1, 0), Node::None], 2, )), AnyBitCircuit::B16(BitCircuit::new( - [ - Node::Copy, - Node::Cmux(48, 1, 0), - Node::Cmux(16, 1, 0), - Node::None, - ], + [Node::Copy, Node::Cmux(48, 1, 0), Node::Cmux(16, 1, 0), Node::None], 2, )), AnyBitCircuit::B17(BitCircuit::new( - [ - Node::Copy, - Node::Cmux(49, 1, 0), - Node::Cmux(17, 1, 0), - Node::None, - ], + [Node::Copy, Node::Cmux(49, 1, 0), Node::Cmux(17, 1, 0), Node::None], 2, )), AnyBitCircuit::B18(BitCircuit::new( - [ - Node::Copy, - Node::Cmux(50, 1, 0), - Node::Cmux(18, 1, 0), - Node::None, - ], + [Node::Copy, Node::Cmux(50, 1, 0), Node::Cmux(18, 1, 0), Node::None], 2, )), AnyBitCircuit::B19(BitCircuit::new( - [ - Node::Copy, - Node::Cmux(51, 1, 0), - Node::Cmux(19, 1, 0), - Node::None, - ], + [Node::Copy, Node::Cmux(51, 1, 0), Node::Cmux(19, 1, 0), Node::None], 2, )), AnyBitCircuit::B20(BitCircuit::new( - [ - Node::Copy, - Node::Cmux(52, 1, 0), - Node::Cmux(20, 1, 0), - Node::None, - ], + [Node::Copy, Node::Cmux(52, 1, 0), Node::Cmux(20, 1, 0), Node::None], 2, )), AnyBitCircuit::B21(BitCircuit::new( - [ - Node::Copy, - Node::Cmux(53, 1, 0), - Node::Cmux(21, 1, 0), - Node::None, - ], + [Node::Copy, Node::Cmux(53, 1, 0), Node::Cmux(21, 1, 0), Node::None], 2, )), AnyBitCircuit::B22(BitCircuit::new( - [ - Node::Copy, - Node::Cmux(54, 1, 0), - Node::Cmux(22, 1, 0), - Node::None, - ], + [Node::Copy, Node::Cmux(54, 1, 0), Node::Cmux(22, 1, 0), Node::None], 2, )), AnyBitCircuit::B23(BitCircuit::new( - [ - Node::Copy, - Node::Cmux(55, 1, 0), - Node::Cmux(23, 1, 0), - Node::None, - ], + [Node::Copy, Node::Cmux(55, 1, 0), Node::Cmux(23, 1, 0), Node::None], 2, )), AnyBitCircuit::B24(BitCircuit::new( - [ - Node::Copy, - Node::Cmux(56, 1, 0), - Node::Cmux(24, 1, 0), - Node::None, - ], + [Node::Copy, Node::Cmux(56, 1, 0), Node::Cmux(24, 1, 0), Node::None], 2, )), AnyBitCircuit::B25(BitCircuit::new( - [ - Node::Copy, - Node::Cmux(57, 1, 0), - Node::Cmux(25, 1, 0), - Node::None, - ], + [Node::Copy, Node::Cmux(57, 1, 0), Node::Cmux(25, 1, 0), Node::None], 2, )), AnyBitCircuit::B26(BitCircuit::new( - [ - Node::Copy, - Node::Cmux(58, 1, 0), - Node::Cmux(26, 1, 0), - Node::None, - ], + [Node::Copy, Node::Cmux(58, 1, 0), Node::Cmux(26, 1, 0), Node::None], 2, )), AnyBitCircuit::B27(BitCircuit::new( - [ - Node::Copy, - Node::Cmux(59, 1, 0), - Node::Cmux(27, 1, 0), - Node::None, - ], + [Node::Copy, Node::Cmux(59, 1, 0), Node::Cmux(27, 1, 0), Node::None], 2, 
)), AnyBitCircuit::B28(BitCircuit::new( - [ - Node::Copy, - Node::Cmux(60, 1, 0), - Node::Cmux(28, 1, 0), - Node::None, - ], + [Node::Copy, Node::Cmux(60, 1, 0), Node::Cmux(28, 1, 0), Node::None], 2, )), AnyBitCircuit::B29(BitCircuit::new( - [ - Node::Copy, - Node::Cmux(61, 1, 0), - Node::Cmux(29, 1, 0), - Node::None, - ], + [Node::Copy, Node::Cmux(61, 1, 0), Node::Cmux(29, 1, 0), Node::None], 2, )), AnyBitCircuit::B30(BitCircuit::new( - [ - Node::Copy, - Node::Cmux(62, 1, 0), - Node::Cmux(30, 1, 0), - Node::None, - ], + [Node::Copy, Node::Cmux(62, 1, 0), Node::Cmux(30, 1, 0), Node::None], 2, )), AnyBitCircuit::B31(BitCircuit::new( - [ - Node::Copy, - Node::Cmux(63, 1, 0), - Node::Cmux(31, 1, 0), - Node::None, - ], + [Node::Copy, Node::Cmux(63, 1, 0), Node::Cmux(31, 1, 0), Node::None], 2, )), ]); diff --git a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/circuits/u32/or_codegen.rs b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/circuits/u32/or_codegen.rs index 5dac043..ff58ff6 100644 --- a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/circuits/u32/or_codegen.rs +++ b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/circuits/u32/or_codegen.rs @@ -79,291 +79,131 @@ impl BitCircuitFamily for AnyBitCircuit { pub(crate) static OUTPUT_CIRCUITS: Circuit = Circuit([ AnyBitCircuit::B0(BitCircuit::new( - [ - Node::Cmux(32, 1, 0), - Node::Copy, - Node::Cmux(0, 1, 0), - Node::None, - ], + [Node::Cmux(32, 1, 0), Node::Copy, Node::Cmux(0, 1, 0), Node::None], 2, )), AnyBitCircuit::B1(BitCircuit::new( - [ - Node::Cmux(33, 1, 0), - Node::Copy, - Node::Cmux(1, 1, 0), - Node::None, - ], + [Node::Cmux(33, 1, 0), Node::Copy, Node::Cmux(1, 1, 0), Node::None], 2, )), AnyBitCircuit::B2(BitCircuit::new( - [ - Node::Cmux(34, 1, 0), - Node::Copy, - Node::Cmux(2, 1, 0), - Node::None, - ], + [Node::Cmux(34, 1, 0), Node::Copy, Node::Cmux(2, 1, 0), Node::None], 2, )), AnyBitCircuit::B3(BitCircuit::new( - [ - Node::Cmux(35, 1, 0), - Node::Copy, - Node::Cmux(3, 1, 0), - Node::None, - ], + [Node::Cmux(35, 1, 0), Node::Copy, Node::Cmux(3, 1, 0), Node::None], 2, )), AnyBitCircuit::B4(BitCircuit::new( - [ - Node::Cmux(36, 1, 0), - Node::Copy, - Node::Cmux(4, 1, 0), - Node::None, - ], + [Node::Cmux(36, 1, 0), Node::Copy, Node::Cmux(4, 1, 0), Node::None], 2, )), AnyBitCircuit::B5(BitCircuit::new( - [ - Node::Cmux(37, 1, 0), - Node::Copy, - Node::Cmux(5, 1, 0), - Node::None, - ], + [Node::Cmux(37, 1, 0), Node::Copy, Node::Cmux(5, 1, 0), Node::None], 2, )), AnyBitCircuit::B6(BitCircuit::new( - [ - Node::Cmux(38, 1, 0), - Node::Copy, - Node::Cmux(6, 1, 0), - Node::None, - ], + [Node::Cmux(38, 1, 0), Node::Copy, Node::Cmux(6, 1, 0), Node::None], 2, )), AnyBitCircuit::B7(BitCircuit::new( - [ - Node::Cmux(39, 1, 0), - Node::Copy, - Node::Cmux(7, 1, 0), - Node::None, - ], + [Node::Cmux(39, 1, 0), Node::Copy, Node::Cmux(7, 1, 0), Node::None], 2, )), AnyBitCircuit::B8(BitCircuit::new( - [ - Node::Cmux(40, 1, 0), - Node::Copy, - Node::Cmux(8, 1, 0), - Node::None, - ], + [Node::Cmux(40, 1, 0), Node::Copy, Node::Cmux(8, 1, 0), Node::None], 2, )), AnyBitCircuit::B9(BitCircuit::new( - [ - Node::Cmux(41, 1, 0), - Node::Copy, - Node::Cmux(9, 1, 0), - Node::None, - ], + [Node::Cmux(41, 1, 0), Node::Copy, Node::Cmux(9, 1, 0), Node::None], 2, )), AnyBitCircuit::B10(BitCircuit::new( - [ - Node::Cmux(42, 1, 0), - Node::Copy, - Node::Cmux(10, 1, 0), - Node::None, - ], + [Node::Cmux(42, 1, 0), Node::Copy, Node::Cmux(10, 1, 0), Node::None], 2, )), AnyBitCircuit::B11(BitCircuit::new( - [ - Node::Cmux(43, 1, 0), - Node::Copy, - Node::Cmux(11, 1, 0), - Node::None, - ], + 
[Node::Cmux(43, 1, 0), Node::Copy, Node::Cmux(11, 1, 0), Node::None], 2, )), AnyBitCircuit::B12(BitCircuit::new( - [ - Node::Cmux(44, 1, 0), - Node::Copy, - Node::Cmux(12, 1, 0), - Node::None, - ], + [Node::Cmux(44, 1, 0), Node::Copy, Node::Cmux(12, 1, 0), Node::None], 2, )), AnyBitCircuit::B13(BitCircuit::new( - [ - Node::Cmux(45, 1, 0), - Node::Copy, - Node::Cmux(13, 1, 0), - Node::None, - ], + [Node::Cmux(45, 1, 0), Node::Copy, Node::Cmux(13, 1, 0), Node::None], 2, )), AnyBitCircuit::B14(BitCircuit::new( - [ - Node::Cmux(46, 1, 0), - Node::Copy, - Node::Cmux(14, 1, 0), - Node::None, - ], + [Node::Cmux(46, 1, 0), Node::Copy, Node::Cmux(14, 1, 0), Node::None], 2, )), AnyBitCircuit::B15(BitCircuit::new( - [ - Node::Cmux(47, 1, 0), - Node::Copy, - Node::Cmux(15, 1, 0), - Node::None, - ], + [Node::Cmux(47, 1, 0), Node::Copy, Node::Cmux(15, 1, 0), Node::None], 2, )), AnyBitCircuit::B16(BitCircuit::new( - [ - Node::Cmux(48, 1, 0), - Node::Copy, - Node::Cmux(16, 1, 0), - Node::None, - ], + [Node::Cmux(48, 1, 0), Node::Copy, Node::Cmux(16, 1, 0), Node::None], 2, )), AnyBitCircuit::B17(BitCircuit::new( - [ - Node::Cmux(49, 1, 0), - Node::Copy, - Node::Cmux(17, 1, 0), - Node::None, - ], + [Node::Cmux(49, 1, 0), Node::Copy, Node::Cmux(17, 1, 0), Node::None], 2, )), AnyBitCircuit::B18(BitCircuit::new( - [ - Node::Cmux(50, 1, 0), - Node::Copy, - Node::Cmux(18, 1, 0), - Node::None, - ], + [Node::Cmux(50, 1, 0), Node::Copy, Node::Cmux(18, 1, 0), Node::None], 2, )), AnyBitCircuit::B19(BitCircuit::new( - [ - Node::Cmux(51, 1, 0), - Node::Copy, - Node::Cmux(19, 1, 0), - Node::None, - ], + [Node::Cmux(51, 1, 0), Node::Copy, Node::Cmux(19, 1, 0), Node::None], 2, )), AnyBitCircuit::B20(BitCircuit::new( - [ - Node::Cmux(52, 1, 0), - Node::Copy, - Node::Cmux(20, 1, 0), - Node::None, - ], + [Node::Cmux(52, 1, 0), Node::Copy, Node::Cmux(20, 1, 0), Node::None], 2, )), AnyBitCircuit::B21(BitCircuit::new( - [ - Node::Cmux(53, 1, 0), - Node::Copy, - Node::Cmux(21, 1, 0), - Node::None, - ], + [Node::Cmux(53, 1, 0), Node::Copy, Node::Cmux(21, 1, 0), Node::None], 2, )), AnyBitCircuit::B22(BitCircuit::new( - [ - Node::Cmux(54, 1, 0), - Node::Copy, - Node::Cmux(22, 1, 0), - Node::None, - ], + [Node::Cmux(54, 1, 0), Node::Copy, Node::Cmux(22, 1, 0), Node::None], 2, )), AnyBitCircuit::B23(BitCircuit::new( - [ - Node::Cmux(55, 1, 0), - Node::Copy, - Node::Cmux(23, 1, 0), - Node::None, - ], + [Node::Cmux(55, 1, 0), Node::Copy, Node::Cmux(23, 1, 0), Node::None], 2, )), AnyBitCircuit::B24(BitCircuit::new( - [ - Node::Cmux(56, 1, 0), - Node::Copy, - Node::Cmux(24, 1, 0), - Node::None, - ], + [Node::Cmux(56, 1, 0), Node::Copy, Node::Cmux(24, 1, 0), Node::None], 2, )), AnyBitCircuit::B25(BitCircuit::new( - [ - Node::Cmux(57, 1, 0), - Node::Copy, - Node::Cmux(25, 1, 0), - Node::None, - ], + [Node::Cmux(57, 1, 0), Node::Copy, Node::Cmux(25, 1, 0), Node::None], 2, )), AnyBitCircuit::B26(BitCircuit::new( - [ - Node::Cmux(58, 1, 0), - Node::Copy, - Node::Cmux(26, 1, 0), - Node::None, - ], + [Node::Cmux(58, 1, 0), Node::Copy, Node::Cmux(26, 1, 0), Node::None], 2, )), AnyBitCircuit::B27(BitCircuit::new( - [ - Node::Cmux(59, 1, 0), - Node::Copy, - Node::Cmux(27, 1, 0), - Node::None, - ], + [Node::Cmux(59, 1, 0), Node::Copy, Node::Cmux(27, 1, 0), Node::None], 2, )), AnyBitCircuit::B28(BitCircuit::new( - [ - Node::Cmux(60, 1, 0), - Node::Copy, - Node::Cmux(28, 1, 0), - Node::None, - ], + [Node::Cmux(60, 1, 0), Node::Copy, Node::Cmux(28, 1, 0), Node::None], 2, )), AnyBitCircuit::B29(BitCircuit::new( - [ - Node::Cmux(61, 1, 0), - Node::Copy, 
- Node::Cmux(29, 1, 0), - Node::None, - ], + [Node::Cmux(61, 1, 0), Node::Copy, Node::Cmux(29, 1, 0), Node::None], 2, )), AnyBitCircuit::B30(BitCircuit::new( - [ - Node::Cmux(62, 1, 0), - Node::Copy, - Node::Cmux(30, 1, 0), - Node::None, - ], + [Node::Cmux(62, 1, 0), Node::Copy, Node::Cmux(30, 1, 0), Node::None], 2, )), AnyBitCircuit::B31(BitCircuit::new( - [ - Node::Cmux(63, 1, 0), - Node::Copy, - Node::Cmux(31, 1, 0), - Node::None, - ], + [Node::Cmux(63, 1, 0), Node::Copy, Node::Cmux(31, 1, 0), Node::None], 2, )), ]); diff --git a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/circuits/u32/sub_codegen.rs b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/circuits/u32/sub_codegen.rs index e0454e0..50caacd 100644 --- a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/circuits/u32/sub_codegen.rs +++ b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/circuits/u32/sub_codegen.rs @@ -79,12 +79,7 @@ impl BitCircuitFamily for AnyBitCircuit { pub(crate) static OUTPUT_CIRCUITS: Circuit = Circuit([ AnyBitCircuit::B0(BitCircuit::new( - [ - Node::Cmux(32, 1, 0), - Node::Cmux(32, 0, 1), - Node::Cmux(0, 1, 0), - Node::None, - ], + [Node::Cmux(32, 1, 0), Node::Cmux(32, 0, 1), Node::Cmux(0, 1, 0), Node::None], 2, )), AnyBitCircuit::B1(BitCircuit::new( diff --git a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/circuits/u32/xor_codegen.rs b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/circuits/u32/xor_codegen.rs index 46b2cf2..b1b0f9f 100644 --- a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/circuits/u32/xor_codegen.rs +++ b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/circuits/u32/xor_codegen.rs @@ -79,291 +79,131 @@ impl BitCircuitFamily for AnyBitCircuit { pub(crate) static OUTPUT_CIRCUITS: Circuit = Circuit([ AnyBitCircuit::B0(BitCircuit::new( - [ - Node::Cmux(32, 0, 1), - Node::Cmux(32, 1, 0), - Node::Cmux(0, 0, 1), - Node::None, - ], + [Node::Cmux(32, 0, 1), Node::Cmux(32, 1, 0), Node::Cmux(0, 0, 1), Node::None], 2, )), AnyBitCircuit::B1(BitCircuit::new( - [ - Node::Cmux(33, 0, 1), - Node::Cmux(33, 1, 0), - Node::Cmux(1, 0, 1), - Node::None, - ], + [Node::Cmux(33, 0, 1), Node::Cmux(33, 1, 0), Node::Cmux(1, 0, 1), Node::None], 2, )), AnyBitCircuit::B2(BitCircuit::new( - [ - Node::Cmux(34, 1, 0), - Node::Cmux(34, 0, 1), - Node::Cmux(2, 1, 0), - Node::None, - ], + [Node::Cmux(34, 1, 0), Node::Cmux(34, 0, 1), Node::Cmux(2, 1, 0), Node::None], 2, )), AnyBitCircuit::B3(BitCircuit::new( - [ - Node::Cmux(35, 0, 1), - Node::Cmux(35, 1, 0), - Node::Cmux(3, 0, 1), - Node::None, - ], + [Node::Cmux(35, 0, 1), Node::Cmux(35, 1, 0), Node::Cmux(3, 0, 1), Node::None], 2, )), AnyBitCircuit::B4(BitCircuit::new( - [ - Node::Cmux(36, 0, 1), - Node::Cmux(36, 1, 0), - Node::Cmux(4, 0, 1), - Node::None, - ], + [Node::Cmux(36, 0, 1), Node::Cmux(36, 1, 0), Node::Cmux(4, 0, 1), Node::None], 2, )), AnyBitCircuit::B5(BitCircuit::new( - [ - Node::Cmux(37, 0, 1), - Node::Cmux(37, 1, 0), - Node::Cmux(5, 0, 1), - Node::None, - ], + [Node::Cmux(37, 0, 1), Node::Cmux(37, 1, 0), Node::Cmux(5, 0, 1), Node::None], 2, )), AnyBitCircuit::B6(BitCircuit::new( - [ - Node::Cmux(38, 1, 0), - Node::Cmux(38, 0, 1), - Node::Cmux(6, 1, 0), - Node::None, - ], + [Node::Cmux(38, 1, 0), Node::Cmux(38, 0, 1), Node::Cmux(6, 1, 0), Node::None], 2, )), AnyBitCircuit::B7(BitCircuit::new( - [ - Node::Cmux(39, 1, 0), - Node::Cmux(39, 0, 1), - Node::Cmux(7, 1, 0), - Node::None, - ], + [Node::Cmux(39, 1, 0), Node::Cmux(39, 0, 1), Node::Cmux(7, 1, 0), Node::None], 2, )), AnyBitCircuit::B8(BitCircuit::new( - [ - Node::Cmux(40, 0, 1), - Node::Cmux(40, 1, 0), - Node::Cmux(8, 0, 1), - Node::None, - ], + 
[Node::Cmux(40, 0, 1), Node::Cmux(40, 1, 0), Node::Cmux(8, 0, 1), Node::None], 2, )), AnyBitCircuit::B9(BitCircuit::new( - [ - Node::Cmux(41, 1, 0), - Node::Cmux(41, 0, 1), - Node::Cmux(9, 1, 0), - Node::None, - ], + [Node::Cmux(41, 1, 0), Node::Cmux(41, 0, 1), Node::Cmux(9, 1, 0), Node::None], 2, )), AnyBitCircuit::B10(BitCircuit::new( - [ - Node::Cmux(42, 0, 1), - Node::Cmux(42, 1, 0), - Node::Cmux(10, 0, 1), - Node::None, - ], + [Node::Cmux(42, 0, 1), Node::Cmux(42, 1, 0), Node::Cmux(10, 0, 1), Node::None], 2, )), AnyBitCircuit::B11(BitCircuit::new( - [ - Node::Cmux(43, 0, 1), - Node::Cmux(43, 1, 0), - Node::Cmux(11, 0, 1), - Node::None, - ], + [Node::Cmux(43, 0, 1), Node::Cmux(43, 1, 0), Node::Cmux(11, 0, 1), Node::None], 2, )), AnyBitCircuit::B12(BitCircuit::new( - [ - Node::Cmux(44, 0, 1), - Node::Cmux(44, 1, 0), - Node::Cmux(12, 0, 1), - Node::None, - ], + [Node::Cmux(44, 0, 1), Node::Cmux(44, 1, 0), Node::Cmux(12, 0, 1), Node::None], 2, )), AnyBitCircuit::B13(BitCircuit::new( - [ - Node::Cmux(45, 1, 0), - Node::Cmux(45, 0, 1), - Node::Cmux(13, 1, 0), - Node::None, - ], + [Node::Cmux(45, 1, 0), Node::Cmux(45, 0, 1), Node::Cmux(13, 1, 0), Node::None], 2, )), AnyBitCircuit::B14(BitCircuit::new( - [ - Node::Cmux(46, 1, 0), - Node::Cmux(46, 0, 1), - Node::Cmux(14, 1, 0), - Node::None, - ], + [Node::Cmux(46, 1, 0), Node::Cmux(46, 0, 1), Node::Cmux(14, 1, 0), Node::None], 2, )), AnyBitCircuit::B15(BitCircuit::new( - [ - Node::Cmux(47, 1, 0), - Node::Cmux(47, 0, 1), - Node::Cmux(15, 1, 0), - Node::None, - ], + [Node::Cmux(47, 1, 0), Node::Cmux(47, 0, 1), Node::Cmux(15, 1, 0), Node::None], 2, )), AnyBitCircuit::B16(BitCircuit::new( - [ - Node::Cmux(48, 0, 1), - Node::Cmux(48, 1, 0), - Node::Cmux(16, 0, 1), - Node::None, - ], + [Node::Cmux(48, 0, 1), Node::Cmux(48, 1, 0), Node::Cmux(16, 0, 1), Node::None], 2, )), AnyBitCircuit::B17(BitCircuit::new( - [ - Node::Cmux(49, 0, 1), - Node::Cmux(49, 1, 0), - Node::Cmux(17, 0, 1), - Node::None, - ], + [Node::Cmux(49, 0, 1), Node::Cmux(49, 1, 0), Node::Cmux(17, 0, 1), Node::None], 2, )), AnyBitCircuit::B18(BitCircuit::new( - [ - Node::Cmux(50, 0, 1), - Node::Cmux(50, 1, 0), - Node::Cmux(18, 0, 1), - Node::None, - ], + [Node::Cmux(50, 0, 1), Node::Cmux(50, 1, 0), Node::Cmux(18, 0, 1), Node::None], 2, )), AnyBitCircuit::B19(BitCircuit::new( - [ - Node::Cmux(51, 1, 0), - Node::Cmux(51, 0, 1), - Node::Cmux(19, 1, 0), - Node::None, - ], + [Node::Cmux(51, 1, 0), Node::Cmux(51, 0, 1), Node::Cmux(19, 1, 0), Node::None], 2, )), AnyBitCircuit::B20(BitCircuit::new( - [ - Node::Cmux(52, 1, 0), - Node::Cmux(52, 0, 1), - Node::Cmux(20, 1, 0), - Node::None, - ], + [Node::Cmux(52, 1, 0), Node::Cmux(52, 0, 1), Node::Cmux(20, 1, 0), Node::None], 2, )), AnyBitCircuit::B21(BitCircuit::new( - [ - Node::Cmux(53, 0, 1), - Node::Cmux(53, 1, 0), - Node::Cmux(21, 0, 1), - Node::None, - ], + [Node::Cmux(53, 0, 1), Node::Cmux(53, 1, 0), Node::Cmux(21, 0, 1), Node::None], 2, )), AnyBitCircuit::B22(BitCircuit::new( - [ - Node::Cmux(54, 1, 0), - Node::Cmux(54, 0, 1), - Node::Cmux(22, 1, 0), - Node::None, - ], + [Node::Cmux(54, 1, 0), Node::Cmux(54, 0, 1), Node::Cmux(22, 1, 0), Node::None], 2, )), AnyBitCircuit::B23(BitCircuit::new( - [ - Node::Cmux(55, 1, 0), - Node::Cmux(55, 0, 1), - Node::Cmux(23, 1, 0), - Node::None, - ], + [Node::Cmux(55, 1, 0), Node::Cmux(55, 0, 1), Node::Cmux(23, 1, 0), Node::None], 2, )), AnyBitCircuit::B24(BitCircuit::new( - [ - Node::Cmux(56, 0, 1), - Node::Cmux(56, 1, 0), - Node::Cmux(24, 0, 1), - Node::None, - ], + [Node::Cmux(56, 0, 1), Node::Cmux(56, 
1, 0), Node::Cmux(24, 0, 1), Node::None], 2, )), AnyBitCircuit::B25(BitCircuit::new( - [ - Node::Cmux(57, 0, 1), - Node::Cmux(57, 1, 0), - Node::Cmux(25, 0, 1), - Node::None, - ], + [Node::Cmux(57, 0, 1), Node::Cmux(57, 1, 0), Node::Cmux(25, 0, 1), Node::None], 2, )), AnyBitCircuit::B26(BitCircuit::new( - [ - Node::Cmux(58, 1, 0), - Node::Cmux(58, 0, 1), - Node::Cmux(26, 1, 0), - Node::None, - ], + [Node::Cmux(58, 1, 0), Node::Cmux(58, 0, 1), Node::Cmux(26, 1, 0), Node::None], 2, )), AnyBitCircuit::B27(BitCircuit::new( - [ - Node::Cmux(59, 0, 1), - Node::Cmux(59, 1, 0), - Node::Cmux(27, 0, 1), - Node::None, - ], + [Node::Cmux(59, 0, 1), Node::Cmux(59, 1, 0), Node::Cmux(27, 0, 1), Node::None], 2, )), AnyBitCircuit::B28(BitCircuit::new( - [ - Node::Cmux(60, 0, 1), - Node::Cmux(60, 1, 0), - Node::Cmux(28, 0, 1), - Node::None, - ], + [Node::Cmux(60, 0, 1), Node::Cmux(60, 1, 0), Node::Cmux(28, 0, 1), Node::None], 2, )), AnyBitCircuit::B29(BitCircuit::new( - [ - Node::Cmux(61, 0, 1), - Node::Cmux(61, 1, 0), - Node::Cmux(29, 0, 1), - Node::None, - ], + [Node::Cmux(61, 0, 1), Node::Cmux(61, 1, 0), Node::Cmux(29, 0, 1), Node::None], 2, )), AnyBitCircuit::B30(BitCircuit::new( - [ - Node::Cmux(62, 0, 1), - Node::Cmux(62, 1, 0), - Node::Cmux(30, 0, 1), - Node::None, - ], + [Node::Cmux(62, 0, 1), Node::Cmux(62, 1, 0), Node::Cmux(30, 0, 1), Node::None], 2, )), AnyBitCircuit::B31(BitCircuit::new( - [ - Node::Cmux(63, 1, 0), - Node::Cmux(63, 0, 1), - Node::Cmux(31, 1, 0), - Node::None, - ], + [Node::Cmux(63, 1, 0), Node::Cmux(63, 0, 1), Node::Cmux(31, 1, 0), Node::None], 2, )), ]); diff --git a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/eval.rs b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/eval.rs index 66e1c02..988bcaf 100644 --- a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/eval.rs +++ b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/eval.rs @@ -194,9 +194,7 @@ fn eval_level( level.iter_mut().for_each(|ct| ct.data_mut().zero()); // TODO: implement API on GLWE - level[1] - .data_mut() - .encode_coeff_i64(res.base2k().into(), 0, 2, 0, 1); + level[1].data_mut().encode_coeff_i64(res.base2k().into(), 0, 2, 0, 1); let mut level_ref: Vec<&mut GLWE<&mut [u8]>> = level.iter_mut().collect_vec(); let (mut prev_level, mut next_level) = level_ref.split_at_mut(state_size); @@ -243,10 +241,7 @@ fn eval_level( impl BitCircuit { pub const fn new(nodes: [Node; N], max_inter_state: usize) -> Self { - Self { - nodes, - max_inter_state, - } + Self { nodes, max_inter_state } } } impl BitCircuitInfo for BitCircuit { @@ -369,29 +364,13 @@ where // res_a = (b-a) * bit + a for j in 0..(res_a.rank() + 1).into() { self.vec_znx_big_add_small(&mut res_big_tmp, 0, &res_big, j, res_a.data(), j); - self.vec_znx_big_normalize( - res_base2k, - res_a.data_mut(), - j, - s_base2k, - &res_big_tmp, - 0, - scratch_2, - ); + self.vec_znx_big_normalize(res_a.data_mut(), res_base2k, 0, j, &res_big_tmp, s_base2k, 0, scratch_2); } // res_b = a - (a - b) * bit = (b - a) * bit + a for j in 0..(res_b.rank() + 1).into() { self.vec_znx_big_sub_small_a(&mut res_big_tmp, 0, res_b.data(), j, &res_big, j); - self.vec_znx_big_normalize( - res_base2k, - res_b.data_mut(), - j, - s_base2k, - &res_big_tmp, - 0, - scratch_2, - ); + self.vec_znx_big_normalize(res_b.data_mut(), res_base2k, 0, j, &res_big_tmp, s_base2k, 0, scratch_2); } } else { let (mut tmp_a, scratch_1) = scratch.take_glwe(&GLWELayout { @@ -432,29 +411,13 @@ where // res_a = (b-a) * bit + a for j in 0..(res_a.rank() + 1).into() { self.vec_znx_big_add_small(&mut res_big_tmp, 0, &res_big, j, tmp_a.data(), j); 
- self.vec_znx_big_normalize( - res_base2k, - res_a.data_mut(), - j, - s_base2k, - &res_big_tmp, - 0, - scratch_4, - ); + self.vec_znx_big_normalize(res_a.data_mut(), res_base2k, 0, j, &res_big_tmp, s_base2k, 0, scratch_4); } // res_b = a - (a - b) * bit = (b - a) * bit + a for j in 0..(res_b.rank() + 1).into() { self.vec_znx_big_sub_small_a(&mut res_big_tmp, 0, tmp_b.data(), j, &res_big, j); - self.vec_znx_big_normalize( - res_base2k, - res_b.data_mut(), - j, - s_base2k, - &res_big_tmp, - 0, - scratch_4, - ); + self.vec_znx_big_normalize(res_b.data_mut(), res_base2k, 0, j, &res_big_tmp, s_base2k, 0, scratch_4); } } } @@ -505,15 +468,7 @@ where let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_external_product_internal(res_dft, res, s, scratch_1); for j in 0..(res.rank() + 1).into() { self.vec_znx_big_add_small_inplace(&mut res_big, j, f.data(), j); - self.vec_znx_big_normalize( - res_base2k, - res.data_mut(), - j, - ggsw_base2k, - &res_big, - j, - scratch_1, - ); + self.vec_znx_big_normalize(res.data_mut(), res_base2k, 0, j, &res_big, ggsw_base2k, j, scratch_1); } } @@ -544,15 +499,7 @@ where let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_external_product_internal(res_dft, &tmp, s, scratch_2); for j in 0..(res.rank() + 1).into() { self.vec_znx_big_add_small_inplace(&mut res_big, j, res.data(), j); - self.vec_znx_big_normalize( - res_base2k, - res.data_mut(), - j, - ggsw_base2k, - &res_big, - j, - scratch_2, - ); + self.vec_znx_big_normalize(res.data_mut(), res_base2k, 0, j, &res_big, ggsw_base2k, j, scratch_2); } } @@ -574,15 +521,7 @@ where let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_external_product_internal(res_dft, res, s, scratch_1); for j in 0..(res.rank() + 1).into() { self.vec_znx_big_add_small_inplace(&mut res_big, j, a.data(), j); - self.vec_znx_big_normalize( - res_base2k, - res.data_mut(), - j, - ggsw_base2k, - &res_big, - j, - scratch_1, - ); + self.vec_znx_big_normalize(res.data_mut(), res_base2k, 0, j, &res_big, ggsw_base2k, j, scratch_1); } } } diff --git a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/key.rs b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/key.rs index e4c2cf9..5fc9c28 100644 --- a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/key.rs +++ b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/key.rs @@ -73,10 +73,7 @@ where { Self { cbt: CircuitBootstrappingKey::alloc_from_infos(&infos.cbt_infos()), - ks_glwe: infos - .ks_glwe_infos() - .as_ref() - .map(GLWESwitchingKey::alloc_from_infos), + ks_glwe: infos.ks_glwe_infos().as_ref().map(GLWESwitchingKey::alloc_from_infos), ks_lwe: GLWEToLWEKey::alloc_from_infos(&infos.ks_lwe_infos()), } } @@ -131,15 +128,12 @@ where let mut sk_out: GLWESecret> = GLWESecret::alloc(sk_glwe.n(), key.rank_out()); sk_out.fill_ternary_prob(0.5, source_xe); key.encrypt_sk(self, sk_glwe, &sk_out, source_xa, source_xe, scratch); - res.ks_lwe - .encrypt_sk(self, sk_lwe, &sk_out, source_xa, source_xe, scratch); + res.ks_lwe.encrypt_sk(self, sk_lwe, &sk_out, source_xa, source_xe, scratch); } else { - res.ks_lwe - .encrypt_sk(self, sk_lwe, sk_glwe, source_xa, source_xe, scratch); + res.ks_lwe.encrypt_sk(self, sk_lwe, sk_glwe, source_xa, source_xe, scratch); } - res.cbt - .encrypt_sk(self, sk_lwe, sk_glwe, source_xa, source_xe, scratch); + res.cbt.encrypt_sk(self, sk_lwe, sk_glwe, source_xa, source_xe, scratch); } } @@ -224,10 +218,7 @@ where A: BDDKeyInfos, { let ks_glwe = if let Some(ks_glwe_infos) = &infos.ks_glwe_infos() { - Some(GLWESwitchingKeyPrepared::alloc_from_infos( - self, - ks_glwe_infos, - )) + 
Some(GLWESwitchingKeyPrepared::alloc_from_infos(self, ks_glwe_infos)) } else { None }; diff --git a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/mod.rs b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/mod.rs index 38bb35e..17843ba 100644 --- a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/mod.rs +++ b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/mod.rs @@ -1,9 +1,19 @@ pub mod test_suite; #[cfg(test)] -#[cfg(not(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma")))] +#[cfg(not(all( + feature = "enable-avx", + target_arch = "x86_64", + target_feature = "avx2", + target_feature = "fma" +)))] mod fft64_ref; #[cfg(test)] -#[cfg(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))] +#[cfg(all( + feature = "enable-avx", + target_arch = "x86_64", + target_feature = "avx2", + target_feature = "fma" +))] mod fft64_avx; diff --git a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/add.rs b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/add.rs index 53f7a46..8f945a4 100644 --- a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/add.rs +++ b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/add.rs @@ -60,35 +60,12 @@ where let b: u32 = source.next_u32(); source.fill_bytes(&mut scratch.borrow().data); - a_enc_prep.encrypt_sk( - module, - a, - sk_glwe_prep, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + a_enc_prep.encrypt_sk(module, a, sk_glwe_prep, &mut source_xa, &mut source_xe, scratch.borrow()); source.fill_bytes(&mut scratch.borrow().data); - b_enc_prep.encrypt_sk( - module, - b, - sk_glwe_prep, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + b_enc_prep.encrypt_sk(module, b, sk_glwe_prep, &mut source_xa, &mut source_xe, scratch.borrow()); // a + b - res.add( - module, - &a_enc_prep, - &b_enc_prep, - bdd_key_prepared, - scratch.borrow(), - ); + res.add(module, &a_enc_prep, &b_enc_prep, bdd_key_prepared, scratch.borrow()); - assert_eq!( - res.decrypt(module, sk_glwe_prep, scratch.borrow()), - a.wrapping_add(b) - ); + assert_eq!(res.decrypt(module, sk_glwe_prep, scratch.borrow()), a.wrapping_add(b)); } diff --git a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/and.rs b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/and.rs index e122485..ee498dd 100644 --- a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/and.rs +++ b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/and.rs @@ -60,31 +60,11 @@ where let b: u32 = source.next_u32(); source.fill_bytes(&mut scratch.borrow().data); - a_enc_prep.encrypt_sk( - module, - a, - sk_glwe_prep, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + a_enc_prep.encrypt_sk(module, a, sk_glwe_prep, &mut source_xa, &mut source_xe, scratch.borrow()); source.fill_bytes(&mut scratch.borrow().data); - b_enc_prep.encrypt_sk( - module, - b, - sk_glwe_prep, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + b_enc_prep.encrypt_sk(module, b, sk_glwe_prep, &mut source_xa, &mut source_xe, scratch.borrow()); - res.and( - module, - &a_enc_prep, - &b_enc_prep, - bdd_key_prepared, - scratch.borrow(), - ); + res.and(module, &a_enc_prep, &b_enc_prep, bdd_key_prepared, scratch.borrow()); assert_eq!(res.decrypt(module, sk_glwe_prep, scratch.borrow()), a & b); } diff --git a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/fheuint.rs b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/fheuint.rs index eb13d30..0649a54 100644 --- 
a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/fheuint.rs +++ b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/fheuint.rs @@ -38,14 +38,7 @@ where for j in 0..3 { let a: u32 = 0x8483_8281; - a_enc.encrypt_sk( - module, - a, - sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + a_enc.encrypt_sk(module, a, sk, &mut source_xa, &mut source_xe, scratch.borrow()); a_enc.sext(module, j, keys, scratch.borrow()); @@ -57,14 +50,7 @@ where for j in 0..3 { let a: u32 = 0x4443_4241; - a_enc.encrypt_sk( - module, - a, - sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + a_enc.encrypt_sk(module, a, sk, &mut source_xa, &mut source_xe, scratch.borrow()); a_enc.sext(module, j, keys, scratch.borrow()); @@ -105,22 +91,8 @@ where let a: u32 = 0xFFFFFFFF; let b: u32 = 0xAABBCCDD; - b_enc.encrypt_sk( - module, - b, - sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - a_enc.encrypt_sk( - module, - a, - sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + b_enc.encrypt_sk(module, b, sk, &mut source_xa, &mut source_xe, scratch.borrow()); + a_enc.encrypt_sk(module, a, sk, &mut source_xa, &mut source_xe, scratch.borrow()); for dst in 0..4 { for src in 0..4 { @@ -162,22 +134,8 @@ where let a: u32 = 0xFFFFFFFF; let b: u32 = 0xAABBCCDD; - b_enc.encrypt_sk( - module, - b, - sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - a_enc.encrypt_sk( - module, - a, - sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + b_enc.encrypt_sk(module, b, sk, &mut source_xa, &mut source_xe, scratch.borrow()); + a_enc.encrypt_sk(module, a, sk, &mut source_xa, &mut source_xe, scratch.borrow()); for dst in 0..2 { for src in 0..2 { @@ -214,14 +172,7 @@ where let a: u32 = source_xa.next_u32(); - a_enc.encrypt_sk( - module, - a, - sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + a_enc.encrypt_sk(module, a, sk, &mut source_xa, &mut source_xe, scratch.borrow()); for i in 0..32 { a_enc.get_bit_glwe(module, i, &mut c_enc, keys, scratch.borrow()); diff --git a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/ggsw_blind_rotations.rs b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/ggsw_blind_rotations.rs index 0d9272f..c01ce39 100644 --- a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/ggsw_blind_rotations.rs +++ b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/ggsw_blind_rotations.rs @@ -68,24 +68,13 @@ where let mut res: GGSW> = GGSW::alloc_from_infos(&ggsw_res_infos); let mut scalar: ScalarZnx> = ScalarZnx::alloc(module.n(), 1); - scalar - .raw_mut() - .iter_mut() - .enumerate() - .for_each(|(i, x)| *x = i as i64); + scalar.raw_mut().iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); let k: u32 = source.next_u32(); let mut k_enc_prep: FheUintPrepared, u32, BE> = FheUintPrepared::, u32, BE>::alloc_from_infos(module, &ggsw_k_infos); - k_enc_prep.encrypt_sk( - module, - k, - sk_glwe_prep, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + k_enc_prep.encrypt_sk(module, k, sk_glwe_prep, &mut source_xa, &mut source_xe, scratch.borrow()); let base: [usize; 2] = [module.log_n() >> 1, module.log_n() - (module.log_n() >> 1)]; @@ -134,16 +123,9 @@ where for row in 0..res.dnum().as_usize() { for col in 0..res.rank().as_usize() + 1 { assert!( - res.noise( - module, - row, - col, - &scalar_want, - sk_glwe_prep, - scratch.borrow() - ) - .std() - .log2() + res.noise(module, row, col, &scalar_want, sk_glwe_prep, scratch.borrow()) + .std() + .log2() <= max_noise(col) ) 
} diff --git a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/glwe_blind_rotation.rs b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/glwe_blind_rotation.rs index 7797713..da03e02 100644 --- a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/glwe_blind_rotation.rs +++ b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/glwe_blind_rotation.rs @@ -73,14 +73,7 @@ where let mut k_enc_prep: FheUintPrepared, u32, BE> = FheUintPrepared::, u32, BE>::alloc_from_infos(module, &ggsw_infos); - k_enc_prep.encrypt_sk( - module, - k, - sk_glwe_prep, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + k_enc_prep.encrypt_sk(module, k, sk_glwe_prep, &mut source_xa, &mut source_xe, scratch.borrow()); let base: [usize; 2] = [module.log_n() >> 1, module.log_n() - (module.log_n() >> 1)]; diff --git a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/glwe_blind_selection.rs b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/glwe_blind_selection.rs index 0803043..6366567 100644 --- a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/glwe_blind_selection.rs +++ b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/glwe_blind_selection.rs @@ -70,14 +70,7 @@ where let mut k_enc_prep: FheUintPrepared, u32, BE> = FheUintPrepared::, u32, BE>::alloc_from_infos(module, &ggsw_infos); - k_enc_prep.encrypt_sk( - module, - k, - sk_glwe_prep, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + k_enc_prep.encrypt_sk(module, k, sk_glwe_prep, &mut source_xa, &mut source_xe, scratch.borrow()); let digit = 5; let mask: u32 = (1 << digit) - 1; @@ -97,14 +90,7 @@ where for value in data.iter().take(1 << digit) { pt.encode_coeff_i64(*value, TorusPrecision(base2k.as_u32()), 0); let mut ct = GLWE::alloc_from_infos(&glwe_infos); - ct.encrypt_sk( - module, - &pt, - sk_glwe_prep, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + ct.encrypt_sk(module, &pt, sk_glwe_prep, &mut source_xa, &mut source_xe, scratch.borrow()); cts.push(ct); } @@ -117,14 +103,7 @@ where // How many bits to take let bit_size: usize = (32 - bit_start).min(digit); - module.glwe_blind_selection( - &mut res, - cts_map, - &k_enc_prep, - bit_start, - bit_size, - scratch.borrow(), - ); + module.glwe_blind_selection(&mut res, cts_map, &k_enc_prep, bit_start, bit_size, scratch.borrow()); res.decrypt(module, &mut pt, sk_glwe_prep, scratch.borrow()); @@ -132,10 +111,7 @@ where if !idx.is_multiple_of(3) { assert_eq!(0, pt.decode_coeff_i64(TorusPrecision(base2k.as_u32()), 0)); } else { - assert_eq!( - data[idx], - pt.decode_coeff_i64(TorusPrecision(base2k.as_u32()), 0) - ); + assert_eq!(data[idx], pt.decode_coeff_i64(TorusPrecision(base2k.as_u32()), 0)); } bit_start += digit; diff --git a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/mod.rs b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/mod.rs index 4310c49..38e9db8 100644 --- a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/mod.rs +++ b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/mod.rs @@ -116,14 +116,7 @@ impl TestContext { sk_lwe.fill_binary_block(block_size as usize, &mut source_xs); let bdd_key_infos: BDDKeyLayout = TEST_BDD_KEY_LAYOUT; let mut bdd_key: BDDKey, BRA> = BDDKey::alloc_from_infos(&bdd_key_infos); - bdd_key.encrypt_sk( - &module, - &sk_lwe, - &sk_glwe, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + bdd_key.encrypt_sk(&module, &sk_lwe, &sk_glwe, &mut source_xa, &mut source_xe, scratch.borrow()); let mut bdd_key_prepared: 
BDDKeyPrepared, BRA, BE> = BDDKeyPrepared::alloc_from_infos(&module, &bdd_key_infos); bdd_key_prepared.prepare(&module, &bdd_key, scratch.borrow()); diff --git a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/or.rs b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/or.rs index ac9e3d4..79da8f8 100644 --- a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/or.rs +++ b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/or.rs @@ -60,31 +60,11 @@ where let b: u32 = source.next_u32(); source.fill_bytes(&mut scratch.borrow().data); - a_enc_prep.encrypt_sk( - module, - a, - sk_glwe_prep, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + a_enc_prep.encrypt_sk(module, a, sk_glwe_prep, &mut source_xa, &mut source_xe, scratch.borrow()); source.fill_bytes(&mut scratch.borrow().data); - b_enc_prep.encrypt_sk( - module, - b, - sk_glwe_prep, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + b_enc_prep.encrypt_sk(module, b, sk_glwe_prep, &mut source_xa, &mut source_xe, scratch.borrow()); - res.or( - module, - &a_enc_prep, - &b_enc_prep, - bdd_key_prepared, - scratch.borrow(), - ); + res.or(module, &a_enc_prep, &b_enc_prep, bdd_key_prepared, scratch.borrow()); assert_eq!(res.decrypt(module, sk_glwe_prep, scratch.borrow()), a | b); } diff --git a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/prepare.rs b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/prepare.rs index 382669e..6945bfc 100644 --- a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/prepare.rs +++ b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/prepare.rs @@ -54,14 +54,7 @@ where // GLWE(value) let mut c_enc: FheUint, u32> = FheUint::alloc_from_infos(&glwe_infos); let value: u32 = source.next_u32(); - c_enc.encrypt_sk( - module, - value, - sk_glwe_prep, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + c_enc.encrypt_sk(module, value, sk_glwe_prep, &mut source_xa, &mut source_xe, scratch.borrow()); // GGSW(0) let mut c_enc_prep_debug: FheUintPreparedDebug, u32> = diff --git a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/sll.rs b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/sll.rs index dd2bd06..428dada 100644 --- a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/sll.rs +++ b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/sll.rs @@ -60,34 +60,11 @@ where let b: u32 = source.next_u32() & 15; source.fill_bytes(&mut scratch.borrow().data); - a_enc_prep.encrypt_sk( - module, - a, - sk_glwe_prep, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + a_enc_prep.encrypt_sk(module, a, sk_glwe_prep, &mut source_xa, &mut source_xe, scratch.borrow()); source.fill_bytes(&mut scratch.borrow().data); - b_enc_prep.encrypt_sk( - module, - b, - sk_glwe_prep, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + b_enc_prep.encrypt_sk(module, b, sk_glwe_prep, &mut source_xa, &mut source_xe, scratch.borrow()); - res.sll( - module, - &a_enc_prep, - &b_enc_prep, - bdd_key_prepared, - scratch.borrow(), - ); + res.sll(module, &a_enc_prep, &b_enc_prep, bdd_key_prepared, scratch.borrow()); - assert_eq!( - res.decrypt(module, sk_glwe_prep, scratch.borrow()), - a.wrapping_shl(b) - ); + assert_eq!(res.decrypt(module, sk_glwe_prep, scratch.borrow()), a.wrapping_shl(b)); } diff --git a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/slt.rs b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/slt.rs index 7d5f01f..f6809d1 100644 --- 
a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/slt.rs +++ b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/slt.rs @@ -60,32 +60,12 @@ where let b: u32 = source.next_u32(); source.fill_bytes(&mut scratch.borrow().data); - a_enc_prep.encrypt_sk( - module, - a, - sk_glwe_prep, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + a_enc_prep.encrypt_sk(module, a, sk_glwe_prep, &mut source_xa, &mut source_xe, scratch.borrow()); source.fill_bytes(&mut scratch.borrow().data); - b_enc_prep.encrypt_sk( - module, - b, - sk_glwe_prep, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + b_enc_prep.encrypt_sk(module, b, sk_glwe_prep, &mut source_xa, &mut source_xe, scratch.borrow()); - res.slt( - module, - &a_enc_prep, - &b_enc_prep, - bdd_key_prepared, - scratch.borrow(), - ); + res.slt(module, &a_enc_prep, &b_enc_prep, bdd_key_prepared, scratch.borrow()); assert_eq!( res.decrypt(module, sk_glwe_prep, scratch.borrow()), diff --git a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/sltu.rs b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/sltu.rs index 88146e6..9f963ce 100644 --- a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/sltu.rs +++ b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/sltu.rs @@ -60,35 +60,12 @@ where let b: u32 = source.next_u32(); source.fill_bytes(&mut scratch.borrow().data); - a_enc_prep.encrypt_sk( - module, - a, - sk_glwe_prep, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + a_enc_prep.encrypt_sk(module, a, sk_glwe_prep, &mut source_xa, &mut source_xe, scratch.borrow()); source.fill_bytes(&mut scratch.borrow().data); - b_enc_prep.encrypt_sk( - module, - b, - sk_glwe_prep, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + b_enc_prep.encrypt_sk(module, b, sk_glwe_prep, &mut source_xa, &mut source_xe, scratch.borrow()); - res.sltu( - module, - &a_enc_prep, - &b_enc_prep, - bdd_key_prepared, - scratch.borrow(), - ); + res.sltu(module, &a_enc_prep, &b_enc_prep, bdd_key_prepared, scratch.borrow()); - assert_eq!( - res.decrypt(module, sk_glwe_prep, scratch.borrow()), - (a < b) as u32 - ); + assert_eq!(res.decrypt(module, sk_glwe_prep, scratch.borrow()), (a < b) as u32); } diff --git a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/sra.rs b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/sra.rs index 7ff0b17..9765e88 100644 --- a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/sra.rs +++ b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/sra.rs @@ -60,34 +60,11 @@ where let b: u32 = source.next_u32() & 15; source.fill_bytes(&mut scratch.borrow().data); - a_enc_prep.encrypt_sk( - module, - a, - sk_glwe_prep, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + a_enc_prep.encrypt_sk(module, a, sk_glwe_prep, &mut source_xa, &mut source_xe, scratch.borrow()); source.fill_bytes(&mut scratch.borrow().data); - b_enc_prep.encrypt_sk( - module, - b, - sk_glwe_prep, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + b_enc_prep.encrypt_sk(module, b, sk_glwe_prep, &mut source_xa, &mut source_xe, scratch.borrow()); - res.sra( - module, - &a_enc_prep, - &b_enc_prep, - bdd_key_prepared, - scratch.borrow(), - ); + res.sra(module, &a_enc_prep, &b_enc_prep, bdd_key_prepared, scratch.borrow()); - assert_eq!( - res.decrypt(module, sk_glwe_prep, scratch.borrow()), - ((a as i32) >> b) as u32 - ); + assert_eq!(res.decrypt(module, sk_glwe_prep, scratch.borrow()), ((a as i32) >> b) as u32); } diff --git 
a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/srl.rs b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/srl.rs index 03e00b1..0418b75 100644 --- a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/srl.rs +++ b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/srl.rs @@ -60,31 +60,11 @@ where let b: u32 = source.next_u32() & 15; source.fill_bytes(&mut scratch.borrow().data); - a_enc_prep.encrypt_sk( - module, - a, - sk_glwe_prep, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + a_enc_prep.encrypt_sk(module, a, sk_glwe_prep, &mut source_xa, &mut source_xe, scratch.borrow()); source.fill_bytes(&mut scratch.borrow().data); - b_enc_prep.encrypt_sk( - module, - b, - sk_glwe_prep, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + b_enc_prep.encrypt_sk(module, b, sk_glwe_prep, &mut source_xa, &mut source_xe, scratch.borrow()); - res.srl( - module, - &a_enc_prep, - &b_enc_prep, - bdd_key_prepared, - scratch.borrow(), - ); + res.srl(module, &a_enc_prep, &b_enc_prep, bdd_key_prepared, scratch.borrow()); assert_eq!(res.decrypt(module, sk_glwe_prep, scratch.borrow()), a >> b); } diff --git a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/sub.rs b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/sub.rs index e765ed3..315037f 100644 --- a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/sub.rs +++ b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/sub.rs @@ -60,34 +60,11 @@ where let b: u32 = source.next_u32(); source.fill_bytes(&mut scratch.borrow().data); - a_enc_prep.encrypt_sk( - module, - a, - sk_glwe_prep, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + a_enc_prep.encrypt_sk(module, a, sk_glwe_prep, &mut source_xa, &mut source_xe, scratch.borrow()); source.fill_bytes(&mut scratch.borrow().data); - b_enc_prep.encrypt_sk( - module, - b, - sk_glwe_prep, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + b_enc_prep.encrypt_sk(module, b, sk_glwe_prep, &mut source_xa, &mut source_xe, scratch.borrow()); - res.sub( - module, - &a_enc_prep, - &b_enc_prep, - bdd_key_prepared, - scratch.borrow(), - ); + res.sub(module, &a_enc_prep, &b_enc_prep, bdd_key_prepared, scratch.borrow()); - assert_eq!( - res.decrypt(module, sk_glwe_prep, scratch.borrow()), - a.wrapping_sub(b) - ); + assert_eq!(res.decrypt(module, sk_glwe_prep, scratch.borrow()), a.wrapping_sub(b)); } diff --git a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/swap.rs b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/swap.rs index 150ac36..36353d4 100644 --- a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/swap.rs +++ b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/swap.rs @@ -45,34 +45,13 @@ where let mut a_enc: FheUint, u32> = FheUint::, u32>::alloc_from_infos(&glwe_infos); let mut b_enc: FheUint, u32> = FheUint::, u32>::alloc_from_infos(&glwe_infos); - a_enc.encrypt_sk( - module, - a, - sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + a_enc.encrypt_sk(module, a, sk, &mut source_xa, &mut source_xe, scratch.borrow()); - b_enc.encrypt_sk( - module, - b, - sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + b_enc.encrypt_sk(module, b, sk, &mut source_xa, &mut source_xe, scratch.borrow()); let mut pt: ScalarZnx> = ScalarZnx::alloc(module.n(), 1); pt.raw_mut()[0] = bit; - s.encrypt_sk( - module, - &pt, - sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + s.encrypt_sk(module, &pt, sk, &mut source_xa, &mut source_xe, 
scratch.borrow()); s_prepared.prepare(module, &s, scratch.borrow()); module.cswap(&mut a_enc, &mut b_enc, &s_prepared, scratch.borrow()); @@ -106,35 +85,18 @@ where let mut data_enc: Vec, u32>> = (0..data.len()) .map(|i| { let mut ct: FheUint, u32> = FheUint::, u32>::alloc_from_infos(&glwe_infos); - ct.encrypt_sk( - module, - data[i], - sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + ct.encrypt_sk(module, data[i], sk, &mut source_xa, &mut source_xe, scratch.borrow()); ct }) .collect_vec(); for idx in 0..data.len() as u32 { let mut idx_enc = FheUintPrepared::alloc_from_infos(module, &ggsw_infos); - idx_enc.encrypt_sk( - module, - idx, - sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + idx_enc.encrypt_sk(module, idx, sk, &mut source_xa, &mut source_xe, scratch.borrow()); module.glwe_blind_retrieval_statefull(&mut data_enc, &idx_enc, 0, 5, scratch.borrow()); - assert_eq!( - data[idx as usize], - data_enc[0].decrypt(module, sk, scratch.borrow()) - ); + assert_eq!(data[idx as usize], data_enc[0].decrypt(module, sk, scratch.borrow())); module.glwe_blind_retrieval_statefull_rev(&mut data_enc, &idx_enc, 0, 5, scratch.borrow()); @@ -166,14 +128,7 @@ where let data_enc: Vec, u32>> = (0..data.len()) .map(|i| { let mut ct: FheUint, u32> = FheUint::, u32>::alloc_from_infos(&glwe_infos); - ct.encrypt_sk( - module, - data[i], - sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + ct.encrypt_sk(module, data[i], sk, &mut source_xa, &mut source_xe, scratch.borrow()); ct }) .collect_vec(); @@ -182,28 +137,11 @@ where for idx in 0..data.len() as u32 { let offset = 2; let mut idx_enc: FheUintPrepared, u32, BE> = FheUintPrepared::alloc_from_infos(module, &ggsw_infos); - idx_enc.encrypt_sk( - module, - idx << offset, - sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + idx_enc.encrypt_sk(module, idx << offset, sk, &mut source_xa, &mut source_xe, scratch.borrow()); let mut res: FheUint, u32> = FheUint::alloc_from_infos(&glwe_infos); - retriever.retrieve( - module, - &mut res, - &data_enc, - &idx_enc, - offset, - scratch.borrow(), - ); + retriever.retrieve(module, &mut res, &data_enc, &idx_enc, offset, scratch.borrow()); - assert_eq!( - data[idx as usize], - res.decrypt(module, sk, scratch.borrow()) - ); + assert_eq!(data[idx as usize], res.decrypt(module, sk, scratch.borrow())); } } diff --git a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/xor.rs b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/xor.rs index 8b58c18..3991544 100644 --- a/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/xor.rs +++ b/poulpy-schemes/src/bin_fhe/bdd_arithmetic/tests/test_suite/xor.rs @@ -60,31 +60,11 @@ where let b: u32 = source.next_u32(); source.fill_bytes(&mut scratch.borrow().data); - a_enc_prep.encrypt_sk( - module, - a, - sk_glwe_prep, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + a_enc_prep.encrypt_sk(module, a, sk_glwe_prep, &mut source_xa, &mut source_xe, scratch.borrow()); source.fill_bytes(&mut scratch.borrow().data); - b_enc_prep.encrypt_sk( - module, - b, - sk_glwe_prep, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + b_enc_prep.encrypt_sk(module, b, sk_glwe_prep, &mut source_xa, &mut source_xe, scratch.borrow()); - res.xor( - module, - &a_enc_prep, - &b_enc_prep, - bdd_key_prepared, - scratch.borrow(), - ); + res.xor(module, &a_enc_prep, &b_enc_prep, bdd_key_prepared, scratch.borrow()); assert_eq!(res.decrypt(module, sk_glwe_prep, scratch.borrow()), a ^ b); } diff --git 
a/poulpy-schemes/src/bin_fhe/blind_rotation/algorithms/cggi/algorithm.rs b/poulpy-schemes/src/bin_fhe/blind_rotation/algorithms/cggi/algorithm.rs index 18b6f05..d56a7d5 100644 --- a/poulpy-schemes/src/bin_fhe/blind_rotation/algorithms/cggi/algorithm.rs +++ b/poulpy-schemes/src/bin_fhe/blind_rotation/algorithms/cggi/algorithm.rs @@ -184,11 +184,7 @@ fn execute_block_binary_extended( let block_size: usize = brk.block_size(); - izip!( - a.chunks_exact(block_size), - brk.data.chunks_exact(block_size) - ) - .for_each(|(ai, ski)| { + izip!(a.chunks_exact(block_size), brk.data.chunks_exact(block_size)).for_each(|(ai, ski)| { for i in 0..extension_factor { for j in 0..cols { module.vec_znx_dft_apply(1, 0, &mut acc_dft[i], j, &acc[i], j); @@ -258,7 +254,7 @@ fn execute_block_binary_extended( (0..cols).for_each(|i| { module.vec_znx_idft_apply(&mut acc_add_big, 0, &acc_add_dft[j], i, scratch7); module.vec_znx_big_add_small_inplace(&mut acc_add_big, 0, &acc[j], i); - module.vec_znx_big_normalize(base2k, &mut acc[j], i, base2k, &acc_add_big, 0, scratch7); + module.vec_znx_big_normalize(&mut acc[j], base2k, 0, i, &acc_add_big, base2k, 0, scratch7); }); }); } @@ -306,12 +302,7 @@ fn execute_block_binary( let cols: usize = (out_mut.rank() + 1).into(); - mod_switch_2n( - 2 * lut.domain_size(), - &mut lwe_2n, - &lwe_ref, - lut.rotation_direction(), - ); + mod_switch_2n(2 * lut.domain_size(), &mut lwe_2n, &lwe_ref, lut.rotation_direction()); let a: &[i64] = &lwe_2n[1..]; let b: i64 = lwe_2n[0]; @@ -337,11 +328,7 @@ fn execute_block_binary( panic!("invalid key: x_pow_a has not been initialized") } - izip!( - a.chunks_exact(block_size), - brk.data.chunks_exact(block_size) - ) - .for_each(|(ai, ski)| { + izip!(a.chunks_exact(block_size), brk.data.chunks_exact(block_size)).for_each(|(ai, ski)| { for j in 0..cols { module.vec_znx_dft_apply(1, 0, &mut acc_dft, j, out_mut.data_mut(), j); module.vec_znx_dft_zero(&mut acc_add_dft, j) @@ -367,15 +354,7 @@ fn execute_block_binary( (0..cols).for_each(|i| { module.vec_znx_idft_apply(&mut acc_add_big, 0, &acc_add_dft, i, scratch_5); module.vec_znx_big_add_small_inplace(&mut acc_add_big, 0, out_mut.data_mut(), i); - module.vec_znx_big_normalize( - base2k, - out_mut.data_mut(), - i, - base2k, - &acc_add_big, - 0, - scratch_5, - ); + module.vec_znx_big_normalize(out_mut.data_mut(), base2k, 0, i, &acc_add_big, base2k, 0, scratch_5); }); } }); @@ -397,13 +376,7 @@ fn execute_standard( { #[cfg(debug_assertions)] { - assert_eq!( - res.n(), - brk.n(), - "res.n(): {} != brk.n(): {}", - res.n(), - brk.n() - ); + assert_eq!(res.n(), brk.n(), "res.n(): {} != brk.n(): {}", res.n(), brk.n()); assert_eq!( lut.domain_size(), brk.n_glwe().as_usize(), @@ -431,12 +404,7 @@ fn execute_standard( let mut out_mut: GLWE<&mut [u8]> = res.to_mut(); let lwe_ref: LWE<&[u8]> = lwe.to_ref(); - mod_switch_2n( - 2 * lut.domain_size(), - &mut lwe_2n, - &lwe_ref, - lut.rotation_direction(), - ); + mod_switch_2n(2 * lut.domain_size(), &mut lwe_2n, &lwe_ref, lut.rotation_direction()); let a: &[i64] = &lwe_2n[1..]; let b: i64 = lwe_2n[0]; diff --git a/poulpy-schemes/src/bin_fhe/blind_rotation/algorithms/cggi/key.rs b/poulpy-schemes/src/bin_fhe/blind_rotation/algorithms/cggi/key.rs index 4877334..661b234 100644 --- a/poulpy-schemes/src/bin_fhe/blind_rotation/algorithms/cggi/key.rs +++ b/poulpy-schemes/src/bin_fhe/blind_rotation/algorithms/cggi/key.rs @@ -20,9 +20,7 @@ impl BlindRotationKeyFactory for BlindRotationKey { A: BlindRotationKeyInfos, { BlindRotationKey { - keys: (0..infos.n_lwe().as_usize()) - 
.map(|_| GGSW::alloc_from_infos(infos)) - .collect(), + keys: (0..infos.n_lwe().as_usize()).map(|_| GGSW::alloc_from_infos(infos)).collect(), dist: Distribution::NONE, _phantom: PhantomData, } } diff --git a/poulpy-schemes/src/bin_fhe/blind_rotation/layouts/key.rs b/poulpy-schemes/src/bin_fhe/blind_rotation/layouts/key.rs index 1fc57a8..07bb002 100644 --- a/poulpy-schemes/src/bin_fhe/blind_rotation/layouts/key.rs +++ b/poulpy-schemes/src/bin_fhe/blind_rotation/layouts/key.rs @@ -137,9 +137,7 @@ impl fmt::Display for BlindRotationKey FillUniform for BlindRotationKey { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { - self.keys - .iter_mut() - .for_each(|key| key.fill_uniform(log_bound, source)); + self.keys.iter_mut().for_each(|key| key.fill_uniform(log_bound, source)); } } diff --git a/poulpy-schemes/src/bin_fhe/blind_rotation/layouts/key_compressed.rs b/poulpy-schemes/src/bin_fhe/blind_rotation/layouts/key_compressed.rs index 6ec1fee..c127e29 100644 --- a/poulpy-schemes/src/bin_fhe/blind_rotation/layouts/key_compressed.rs +++ b/poulpy-schemes/src/bin_fhe/blind_rotation/layouts/key_compressed.rs @@ -71,9 +71,7 @@ impl fmt::Display for BlindRotationKeyCompre impl FillUniform for BlindRotationKeyCompressed { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { - self.keys - .iter_mut() - .for_each(|key| key.fill_uniform(log_bound, source)); + self.keys.iter_mut().for_each(|key| key.fill_uniform(log_bound, source)); } } diff --git a/poulpy-schemes/src/bin_fhe/blind_rotation/lut.rs b/poulpy-schemes/src/bin_fhe/blind_rotation/lut.rs index 74c0441..0a01dd2 100644 --- a/poulpy-schemes/src/bin_fhe/blind_rotation/lut.rs +++ b/poulpy-schemes/src/bin_fhe/blind_rotation/lut.rs @@ -163,13 +163,7 @@ impl DivRound for usize { #[allow(dead_code)] fn max_bit_size(vec: &[i64]) -> u32 { vec.iter() - .map(|&v| { - if v == 0 { - 0 - } else { - v.unsigned_abs().ilog2() + 1 - } - }) + .map(|&v| if v == 0 { 0 } else { v.unsigned_abs().ilog2() + 1 }) .max() .unwrap_or(0) } @@ -192,10 +186,7 @@ where let base2k: usize = res.base2k.into(); - let mut scratch: ScratchOwned = ScratchOwned::alloc( - self.vec_znx_normalize_tmp_bytes() - .max(res.domain_size() << 3), - ); + let mut scratch: ScratchOwned = ScratchOwned::alloc(self.vec_znx_normalize_tmp_bytes().max(res.domain_size() << 3)); // Get the minimum number of limbs needed to store the message modulus let limbs: usize = k.div_ceil(base2k); @@ -222,7 +213,7 @@ where // If LUT size > TakeScalarZnx let domain_size: usize = res.domain_size(); - let size: usize = res.k.div_ceil(res.base2k) as usize; + let size: usize = res.k.as_usize().div_ceil(base2k); // Equivalent to AUTO([f(0), -f(n-1), -f(n-2), ..., -f(1)], -1) let mut lut_full: VecZnx> = VecZnx::alloc(domain_size, 1, size); @@ -231,16 +222,15 @@ where let step: usize = domain_size.div_round(f_len); - f.iter().enumerate().for_each(|(i, fi)| { + for (i, fi) in f.iter().enumerate() { let start: usize = i * step; let end: usize = start + step; lut_at[start..end].fill(fi * scale); - }); + } let drift: usize = step >> 1; // Rotates half the step to the left - if res.extension_factor() > 1 { let (tmp, _) = scratch.borrow().take_slice(lut_full.n()); diff --git a/poulpy-schemes/src/bin_fhe/blind_rotation/tests/mod.rs b/poulpy-schemes/src/bin_fhe/blind_rotation/tests/mod.rs index b5fa520..c89bfc4 100644 --- a/poulpy-schemes/src/bin_fhe/blind_rotation/tests/mod.rs +++ b/poulpy-schemes/src/bin_fhe/blind_rotation/tests/mod.rs @@ -1,9 +1,19 @@ #[cfg(test)] -#[cfg(not(all(feature = "enable-avx", target_arch 
= "x86_64", target_feature = "avx2", target_feature = "fma")))] +#[cfg(not(all( + feature = "enable-avx", + target_arch = "x86_64", + target_feature = "avx2", + target_feature = "fma" +)))] mod fft64_ref; #[cfg(test)] -#[cfg(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))] +#[cfg(all( + feature = "enable-avx", + target_arch = "x86_64", + target_feature = "avx2", + target_feature = "fma" +))] mod fft64_avx; #[cfg(test)] diff --git a/poulpy-schemes/src/bin_fhe/blind_rotation/tests/test_suite/generic_blind_rotation.rs b/poulpy-schemes/src/bin_fhe/blind_rotation/tests/test_suite/generic_blind_rotation.rs index 7686325..9675f92 100644 --- a/poulpy-schemes/src/bin_fhe/blind_rotation/tests/test_suite/generic_blind_rotation.rs +++ b/poulpy-schemes/src/bin_fhe/blind_rotation/tests/test_suite/generic_blind_rotation.rs @@ -1,5 +1,5 @@ use poulpy_hal::{ - api::{ScratchOwnedAlloc, ScratchOwnedBorrow}, + api::{ModuleN, ScratchOwnedAlloc, ScratchOwnedBorrow}, layouts::{Backend, Scratch, ScratchOwned, ZnxView}, source::Source, }; @@ -24,7 +24,8 @@ pub fn test_blind_rotation( block_size: usize, extension_factor: usize, ) where - M: BlindRotationKeyEncryptSk + M: ModuleN + + BlindRotationKeyEncryptSk + BlindRotationKeyPreparedFactory + BlindRotationExecute + LookupTableFactory @@ -111,22 +112,12 @@ pub fn test_blind_rotation( pt_lwe.encode_i64(x, (log_message_modulus + 1).into()); - lwe.encrypt_sk( - module, - &pt_lwe, - &sk_lwe, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); + lwe.encrypt_sk(module, &pt_lwe, &sk_lwe, &mut source_xa, &mut source_xe, scratch.borrow()); let f = |x: i64| -> i64 { 2 * x + 1 }; let mut f_vec: Vec = vec![0i64; message_modulus]; - f_vec - .iter_mut() - .enumerate() - .for_each(|(i, x)| *x = f(i as i64)); + f_vec.iter_mut().enumerate().for_each(|(i, x)| *x = f(i as i64)); let lut_infos = LookUpTableLayout { n: module.n().into(), @@ -151,20 +142,10 @@ pub fn test_blind_rotation( let mut lwe_2n: Vec = vec![0i64; (lwe.n() + 1).into()]; // TODO: from scratch space - mod_switch_2n( - 2 * lut.domain_size(), - &mut lwe_2n, - &lwe.to_ref(), - lut.rotation_direction(), - ); + mod_switch_2n(2 * lut.domain_size(), &mut lwe_2n, &lwe.to_ref(), lut.rotation_direction()); - let pt_want: i64 = (lwe_2n[0] - + lwe_2n[1..] 
-            .iter()
-            .zip(sk_lwe.raw())
-            .map(|(x, y)| x * y)
-            .sum::<i64>())
-        & (2 * lut.domain_size() - 1) as i64;
+    let pt_want: i64 =
+        (lwe_2n[0] + lwe_2n[1..].iter().zip(sk_lwe.raw()).map(|(x, y)| x * y).sum::<i64>()) & (2 * lut.domain_size() - 1) as i64;

     lut.rotate(module, pt_want);
diff --git a/poulpy-schemes/src/bin_fhe/blind_rotation/tests/test_suite/generic_lut.rs b/poulpy-schemes/src/bin_fhe/blind_rotation/tests/test_suite/generic_lut.rs
index f933913..f58654f 100644
--- a/poulpy-schemes/src/bin_fhe/blind_rotation/tests/test_suite/generic_lut.rs
+++ b/poulpy-schemes/src/bin_fhe/blind_rotation/tests/test_suite/generic_lut.rs
@@ -16,9 +16,7 @@ where
     let log_scale: usize = base2k + 1;

     let mut f: Vec<i64> = vec![0i64; message_modulus];
-    f.iter_mut()
-        .enumerate()
-        .for_each(|(i, x)| *x = (i as i64) - 8);
+    f.iter_mut().enumerate().for_each(|(i, x)| *x = (i as i64) - 8);

     let lut_infos: LookUpTableLayout = LookUpTableLayout {
         n: module.n().into(),
@@ -57,9 +55,7 @@ where
     let log_scale: usize = base2k + 1;

     let mut f: Vec<i64> = vec![0i64; message_modulus];
-    f.iter_mut()
-        .enumerate()
-        .for_each(|(i, x)| *x = (i as i64) - 8);
+    f.iter_mut().enumerate().for_each(|(i, x)| *x = (i as i64) - 8);

     let lut_infos: LookUpTableLayout = LookUpTableLayout {
         n: module.n().into(),
diff --git a/poulpy-schemes/src/bin_fhe/circuit_bootstrapping/circuit.rs b/poulpy-schemes/src/bin_fhe/circuit_bootstrapping/circuit.rs
index f47dcc2..cca221f 100644
--- a/poulpy-schemes/src/bin_fhe/circuit_bootstrapping/circuit.rs
+++ b/poulpy-schemes/src/bin_fhe/circuit_bootstrapping/circuit.rs
@@ -95,15 +95,7 @@ impl CircuitBootstrappingKeyPre
     R: GGSWToMut + GGSWInfos,
     L: LWEToRef + LWEInfos,
 {
-        module.circuit_bootstrapping_execute_to_exponent(
-            log_gap_out,
-            res,
-            lwe,
-            self,
-            log_domain,
-            extension_factor,
-            scratch,
-        );
+        module.circuit_bootstrapping_execute_to_exponent(log_gap_out, res, lwe, self, log_domain, extension_factor, scratch);
     }
 }
@@ -144,14 +136,9 @@ where
             rank_out: res_infos.rank(),
         };

-        self.blind_rotation_execute_tmp_bytes(
-            block_size,
-            extension_factor,
-            res_infos,
-            &cbt_infos.brk_infos(),
-        )
-        .max(self.glwe_trace_tmp_bytes(res_infos, res_infos, &cbt_infos.atk_infos()))
-        .max(self.ggsw_from_gglwe_tmp_bytes(res_infos, &cbt_infos.tsk_infos()))
+        self.blind_rotation_execute_tmp_bytes(block_size, extension_factor, res_infos, &cbt_infos.brk_infos())
+            .max(self.glwe_trace_tmp_bytes(res_infos, res_infos, &cbt_infos.atk_infos()))
+            .max(self.ggsw_from_gglwe_tmp_bytes(res_infos, &cbt_infos.tsk_infos()))
             + GLWE::bytes_of_from_infos(res_infos)
             + GGLWE::bytes_of_from_infos(&gglwe_infos)
     }
@@ -173,17 +160,7 @@ where
             scratch.available() >= self.circuit_bootstrapping_execute_tmp_bytes(key.block_size(), extension_factor, res, key)
         );
-        circuit_bootstrap_core(
-            false,
-            self,
-            0,
-            res,
-            lwe,
-            log_domain,
-            extension_factor,
-            key,
-            scratch,
-        );
+        circuit_bootstrap_core(false, self, 0, res, lwe, log_domain, extension_factor, key, scratch);
     }

     fn circuit_bootstrapping_execute_to_exponent(
@@ -204,17 +181,7 @@ where
             scratch.available() >= self.circuit_bootstrapping_execute_tmp_bytes(key.block_size(), extension_factor, res, key)
         );
-        circuit_bootstrap_core(
-            true,
-            self,
-            log_gap_out,
-            res,
-            lwe,
-            log_domain,
-            extension_factor,
-            key,
-            scratch,
-        );
+        circuit_bootstrap_core(true, self, log_gap_out, res, lwe, log_domain, extension_factor, key, scratch);
     }
 }
@@ -254,7 +221,7 @@ pub fn circuit_bootstrap_core(
     assert_eq!(res.n(), key.brk.n());

-    let base2k_res: usize = res.base2k().as_usize();
+    let res_base2k: usize = res.base2k().as_usize();
     let dnum_res: usize = res.dnum().into();
     let alpha: usize = dnum_res.next_power_of_two();
@@ -263,12 +230,12 @@
     if to_exponent {
         (0..dnum_res).for_each(|i| {
-            f[i] = 1 << (base2k_res * (dnum_res - 1 - i));
+            f[i] = 1 << (res_base2k * (dnum_res - 1 - i));
         });
     } else {
         (0..1 << log_domain).for_each(|j| {
             (0..dnum_res).for_each(|i| {
-                f[j * alpha + i] = j as i64 * (1 << (base2k_res * (dnum_res - 1 - i)));
+                f[j * alpha + i] = j as i64 * (1 << (res_base2k * (dnum_res - 1 - i)));
             });
         });
     }
@@ -276,13 +243,13 @@
     let lut_infos: LookUpTableLayout = LookUpTableLayout {
         n: module.n().into(),
         extension_factor,
-        k: (base2k_res * dnum_res).into(),
+        k: (res_base2k * dnum_res).into(),
         base2k: key.brk.base2k(),
     };

     // Lut precision, basically must be able to hold the decomposition power basis of the GGSW
     let mut lut: LookupTable = LookupTable::alloc(&lut_infos);
-    lut.set(module, &f, base2k_res * dnum_res);
+    lut.set(module, &f, res_base2k * dnum_res);

     if to_exponent {
         lut.set_rotation_direction(LookUpTableRotationDirection::Right);
@@ -309,8 +276,7 @@
     // Execute blind rotation over BRK layout and returns result over ATK layout
     {
         let (mut res_glwe_brk_layout, scratch_2) = scratch_1.take_glwe(glwe_brk_layout);
-        key.brk
-            .execute(module, &mut res_glwe_brk_layout, lwe, &lut, scratch_2);
+        key.brk.execute(module, &mut res_glwe_brk_layout, lwe, &lut, scratch_2);

         if res_glwe_brk_layout.base2k() == res_glwe_atk_layout.base2k() {
             module.glwe_copy(&mut res_glwe_atk_layout, &res_glwe_brk_layout);
@@ -376,13 +342,7 @@ fn post_process(
     // First partial trace, vanishes all coefficients which are not multiples of gap_in
     // [1, 1, 1, 1, 0, 0, 0, ..., 0, 0, -1, -1, -1, -1] -> [1, 0, 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0, 0]
-    module.glwe_trace(
-        &mut a_trace,
-        module.log_n() - log_gap_in + 1,
-        a,
-        auto_keys,
-        scratch_1,
-    );
+    module.glwe_trace(&mut a_trace, module.log_n() - log_gap_in + 1, a, auto_keys, scratch_1);

     let steps: usize = 1 << log_domain;
diff --git a/poulpy-schemes/src/bin_fhe/circuit_bootstrapping/key.rs b/poulpy-schemes/src/bin_fhe/circuit_bootstrapping/key.rs
index ae64b8c..135d82c 100644
--- a/poulpy-schemes/src/bin_fhe/circuit_bootstrapping/key.rs
+++ b/poulpy-schemes/src/bin_fhe/circuit_bootstrapping/key.rs
@@ -167,17 +167,10 @@ where
         let mut sk_glwe_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(self, brk_infos.rank());
         sk_glwe_prepared.prepare(self, sk_glwe);
-        res.brk.encrypt_sk(
-            self,
-            &sk_glwe_prepared,
-            sk_lwe,
-            source_xa,
-            source_xe,
-            scratch,
-        );
+        res.brk
+            .encrypt_sk(self, &sk_glwe_prepared, sk_lwe, source_xa, source_xe, scratch);
-        res.tsk
-            .encrypt_sk(self, sk_glwe, source_xa, source_xe, scratch);
+        res.tsk.encrypt_sk(self, sk_glwe, source_xa, source_xe, scratch);
     }
 }
diff --git a/poulpy-schemes/src/bin_fhe/circuit_bootstrapping/tests/circuit_bootstrapping.rs b/poulpy-schemes/src/bin_fhe/circuit_bootstrapping/tests/circuit_bootstrapping.rs
index 665d13f..c0e2c2a 100644
--- a/poulpy-schemes/src/bin_fhe/circuit_bootstrapping/tests/circuit_bootstrapping.rs
+++ b/poulpy-schemes/src/bin_fhe/circuit_bootstrapping/tests/circuit_bootstrapping.rs
@@ -46,11 +46,11 @@ where
     Scratch: ScratchTakeCore,
 {
     let n_glwe: usize = module.n();
-    let base2k_res: usize = 15;
+    let res_base2k: usize = 15;
     let base2k_lwe: usize = 14;
     let base2k_brk: usize = 13;
-    let base2k_tsk: usize = 12;
-    let base2k_atk: usize = 11;
+    let tsk_base2k: usize = 12;
+    let atk_base2k: usize = 11;
     let extension_factor: usize = 1;
     let rank: usize = 1;
@@ -59,16 +59,16 @@
     let k_lwe_ct: usize = 22;
     let block_size: usize = 7;

-    let k_ggsw_res: usize = 4 * base2k_res;
+    let k_ggsw_res: usize = 4 * res_base2k;
     let rows_ggsw_res: usize = 3;

     let k_brk: usize = k_ggsw_res + base2k_brk;
     let rows_brk: usize = 4;

-    let k_atk: usize = k_ggsw_res + base2k_tsk;
+    let k_atk: usize = k_ggsw_res + tsk_base2k;
     let rows_atk: usize = 4;

-    let k_tsk: usize = k_ggsw_res + base2k_atk;
+    let k_tsk: usize = k_ggsw_res + atk_base2k;
     let rows_tsk: usize = 4;

     let lwe_infos: LWELayout = LWELayout {
@@ -88,7 +88,7 @@
         },
         atk_layout: GLWEAutomorphismKeyLayout {
             n: n_glwe.into(),
-            base2k: base2k_atk.into(),
+            base2k: atk_base2k.into(),
             k: k_atk.into(),
             dnum: rows_atk.into(),
             rank: rank.into(),
@@ -96,7 +96,7 @@
         },
         tsk_layout: GGLWEToGGSWKeyLayout {
             n: n_glwe.into(),
-            base2k: base2k_tsk.into(),
+            base2k: tsk_base2k.into(),
             k: k_tsk.into(),
             dnum: rows_tsk.into(),
             dsize: Dsize(1),
@@ -106,7 +106,7 @@
     let ggsw_infos: GGSWLayout = GGSWLayout {
         n: n_glwe.into(),
-        base2k: base2k_res.into(),
+        base2k: res_base2k.into(),
         k: k_ggsw_res.into(),
         dnum: rows_ggsw_res.into(),
         dsize: Dsize(1),
@@ -136,28 +136,14 @@
     println!("pt_lwe: {pt_lwe}");

     let mut ct_lwe: LWE> = LWE::alloc_from_infos(&lwe_infos);
-    ct_lwe.encrypt_sk(
-        module,
-        &pt_lwe,
-        &sk_lwe,
-        &mut source_xa,
-        &mut source_xe,
-        scratch.borrow(),
-    );
+    ct_lwe.encrypt_sk(module, &pt_lwe, &sk_lwe, &mut source_xa, &mut source_xe, scratch.borrow());

     let now: Instant = Instant::now();
     let mut cbt_key: CircuitBootstrappingKey, BRA> = CircuitBootstrappingKey::alloc_from_infos(&cbt_infos);
     println!("CBT-ALLOC: {} ms", now.elapsed().as_millis());

     let now: Instant = Instant::now();
-    cbt_key.encrypt_sk(
-        module,
-        &sk_lwe,
-        &sk_glwe,
-        &mut source_xa,
-        &mut source_xe,
-        scratch.borrow(),
-    );
+    cbt_key.encrypt_sk(module, &sk_lwe, &sk_glwe, &mut source_xa, &mut source_xe, scratch.borrow());
     println!("CBT-ENCRYPT: {} ms", now.elapsed().as_millis());

     let mut res: GGSW> = GGSW::alloc_from_infos(&ggsw_infos);
@@ -183,33 +169,21 @@
     // X^{data * 2^log_gap_out}
     let mut pt_ggsw: ScalarZnx> = ScalarZnx::alloc(n_glwe, 1);
     pt_ggsw.at_mut(0, 0)[0] = 1;
-    module.vec_znx_rotate_inplace(
-        data * (1 << log_gap_out),
-        &mut pt_ggsw.as_vec_znx_mut(),
-        0,
-        scratch.borrow(),
-    );
+    module.vec_znx_rotate_inplace(data * (1 << log_gap_out), &mut pt_ggsw.as_vec_znx_mut(), 0, scratch.borrow());

     for row in 0..res.dnum().as_usize() {
         for col in 0..res.rank().as_usize() + 1 {
             println!(
                 "row:{row} col:{col} -> {}",
-                res.noise(
-                    module,
-                    row,
-                    col,
-                    &pt_ggsw,
-                    &sk_glwe_prepared,
-                    scratch.borrow()
-                )
-                .std()
-                .log2()
+                res.noise(module, row, col, &pt_ggsw, &sk_glwe_prepared, scratch.borrow())
+                    .std()
+                    .log2()
             )
         }
     }

     let mut ct_glwe: GLWE> = GLWE::alloc_from_infos(&ggsw_infos);
     let mut pt_glwe: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&ggsw_infos);
-    pt_glwe.data.at_mut(0, 0)[0] = 1 << (base2k_res - 2);
+    pt_glwe.data.at_mut(0, 0)[0] = 1 << (res_base2k - 2);

     ct_glwe.encrypt_sk(
         module,
@@ -253,11 +227,11 @@ where
     Scratch: ScratchTakeCore,
 {
     let n_glwe: usize = module.n();
-    let base2k_res: usize = 15;
+    let res_base2k: usize = 15;
     let base2k_lwe: usize = 14;
     let base2k_brk: usize = 13;
-    let base2k_tsk: usize = 12;
-    let base2k_atk: usize = 11;
+    let tsk_base2k: usize = 12;
+    let atk_base2k: usize = 11;
     let extension_factor: usize = 1;
     let rank: usize = 1;
@@ -266,16 +240,16 @@
     let k_lwe_ct: usize = 13;
     let block_size: usize = 7;
-    let k_ggsw_res: usize = 4 * base2k_res;
+    let k_ggsw_res: usize = 4 * res_base2k;
     let rows_ggsw_res: usize = 3;

     let k_brk: usize = k_ggsw_res + base2k_brk;
     let rows_brk: usize = 4;

-    let k_atk: usize = k_ggsw_res + base2k_tsk;
+    let k_atk: usize = k_ggsw_res + tsk_base2k;
     let rows_atk: usize = 4;

-    let k_tsk: usize = k_ggsw_res + base2k_atk;
+    let k_tsk: usize = k_ggsw_res + atk_base2k;
     let rows_tsk: usize = 4;

     let lwe_infos: LWELayout = LWELayout {
@@ -295,7 +269,7 @@
         },
         atk_layout: GLWEAutomorphismKeyLayout {
             n: n_glwe.into(),
-            base2k: base2k_atk.into(),
+            base2k: atk_base2k.into(),
             k: k_atk.into(),
             dnum: rows_atk.into(),
             rank: rank.into(),
@@ -303,7 +277,7 @@
         },
         tsk_layout: GGLWEToGGSWKeyLayout {
             n: n_glwe.into(),
-            base2k: base2k_tsk.into(),
+            base2k: tsk_base2k.into(),
             k: k_tsk.into(),
             dnum: rows_tsk.into(),
             dsize: Dsize(1),
@@ -313,7 +287,7 @@
     let ggsw_infos: GGSWLayout = GGSWLayout {
         n: n_glwe.into(),
-        base2k: base2k_res.into(),
+        base2k: res_base2k.into(),
         k: k_ggsw_res.into(),
         dnum: rows_ggsw_res.into(),
         dsize: Dsize(1),
@@ -343,28 +317,14 @@
     println!("pt_lwe: {pt_lwe}");

     let mut ct_lwe: LWE> = LWE::alloc_from_infos(&lwe_infos);
-    ct_lwe.encrypt_sk(
-        module,
-        &pt_lwe,
-        &sk_lwe,
-        &mut source_xa,
-        &mut source_xe,
-        scratch.borrow(),
-    );
+    ct_lwe.encrypt_sk(module, &pt_lwe, &sk_lwe, &mut source_xa, &mut source_xe, scratch.borrow());

     let now: Instant = Instant::now();
     let mut cbt_key: CircuitBootstrappingKey, BRA> = CircuitBootstrappingKey::alloc_from_infos(&cbt_infos);
     println!("CBT-ALLOC: {} ms", now.elapsed().as_millis());

     let now: Instant = Instant::now();
-    cbt_key.encrypt_sk(
-        module,
-        &sk_lwe,
-        &sk_glwe,
-        &mut source_xa,
-        &mut source_xe,
-        scratch.borrow(),
-    );
+    cbt_key.encrypt_sk(module, &sk_lwe, &sk_glwe, &mut source_xa, &mut source_xe, scratch.borrow());
     println!("CBT-ENCRYPT: {} ms", now.elapsed().as_millis());

     let mut res: GGSW> = GGSW::alloc_from_infos(&ggsw_infos);
@@ -374,14 +334,7 @@
     cbt_prepared.prepare(module, &cbt_key, scratch.borrow());

     let now: Instant = Instant::now();
-    cbt_prepared.execute_to_constant(
-        module,
-        &mut res,
-        &ct_lwe,
-        k_lwe_pt,
-        extension_factor,
-        scratch.borrow(),
-    );
+    cbt_prepared.execute_to_constant(module, &mut res, &ct_lwe, k_lwe_pt, extension_factor, scratch.borrow());
     println!("CBT: {} ms", now.elapsed().as_millis());

     // X^{data * 2^log_gap_out}
@@ -392,23 +345,16 @@
         for col in 0..res.rank().as_usize() + 1 {
             println!(
                 "row:{row} col:{col} -> {}",
-                res.noise(
-                    module,
-                    row,
-                    col,
-                    &pt_ggsw,
-                    &sk_glwe_prepared,
-                    scratch.borrow()
-                )
-                .std()
-                .log2()
+                res.noise(module, row, col, &pt_ggsw, &sk_glwe_prepared, scratch.borrow())
+                    .std()
+                    .log2()
             )
         }
     }

     let mut ct_glwe: GLWE> = GLWE::alloc_from_infos(&ggsw_infos);
     let mut pt_glwe: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&ggsw_infos);
-    pt_glwe.data.at_mut(0, 0)[0] = 1 << (base2k_res - k_lwe_pt - 1);
+    pt_glwe.data.at_mut(0, 0)[0] = 1 << (res_base2k - k_lwe_pt - 1);

     ct_glwe.encrypt_sk(
         module,
diff --git a/poulpy-schemes/src/bin_fhe/circuit_bootstrapping/tests/mod.rs b/poulpy-schemes/src/bin_fhe/circuit_bootstrapping/tests/mod.rs
index b685dbd..be57640 100644
--- a/poulpy-schemes/src/bin_fhe/circuit_bootstrapping/tests/mod.rs
+++ b/poulpy-schemes/src/bin_fhe/circuit_bootstrapping/tests/mod.rs
@@ -1,9 +1,19 @@
 pub mod circuit_bootstrapping;

 #[cfg(test)]
-#[cfg(not(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma")))]
+#[cfg(not(all(
+    feature = "enable-avx",
target_arch = "x86_64", + target_feature = "avx2", + target_feature = "fma" +)))] mod fft64_ref; #[cfg(test)] -#[cfg(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))] +#[cfg(all( + feature = "enable-avx", + target_arch = "x86_64", + target_feature = "avx2", + target_feature = "fma" +))] mod fft64_avx; diff --git a/rustfmt.toml b/rustfmt.toml index 1467cc6..dfc0980 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -2,78 +2,8 @@ max_width = 130 hard_tabs = false tab_spaces = 4 newline_style = "Auto" -indent_style = "Block" use_small_heuristics = "Default" -fn_call_width = 60 -attr_fn_like_width = 100 -struct_lit_width = 18 -struct_variant_width = 35 -array_width = 60 -chain_width = 60 -single_line_if_else_max_width = 50 -single_line_let_else_max_width = 50 wrap_comments = false format_code_in_doc_comments = true -doc_comment_code_block_width = 100 -comment_width = 80 -normalize_comments = true -normalize_doc_attributes = true -format_strings = true -format_macro_matchers = false -format_macro_bodies = true -skip_macro_invocations = [] -hex_literal_case = "Preserve" -empty_item_single_line = true -struct_lit_single_line = true -fn_single_line = false -where_single_line = false -imports_indent = "Block" -imports_layout = "Mixed" -imports_granularity = "Preserve" -group_imports = "Preserve" reorder_imports = true -reorder_modules = true -reorder_impl_items = false -type_punctuation_density = "Wide" -space_before_colon = false -space_after_colon = true -spaces_around_ranges = false -binop_separator = "Front" -remove_nested_parens = true -combine_control_expr = true -short_array_element_width_threshold = 10 -overflow_delimited_expr = false -struct_field_align_threshold = 0 -enum_discrim_align_threshold = 0 -match_arm_blocks = true -match_arm_leading_pipes = "Never" -force_multiline_blocks = false -fn_params_layout = "Tall" -brace_style = "SameLineWhere" -control_brace_style = "AlwaysSameLine" -trailing_semicolon = true -trailing_comma = "Vertical" -match_block_trailing_comma = false -blank_lines_upper_bound = 1 -blank_lines_lower_bound = 0 -edition = "2024" -style_edition = "2024" -inline_attribute_width = 0 -format_generated_files = true -generated_marker_line_search_limit = 5 -merge_derives = true -use_try_shorthand = false -use_field_init_shorthand = false -force_explicit_abi = true -condense_wildcard_suffixes = false -color = "Auto" -required_version = "1.8.0" -unstable_features = true -disable_all_formatting = false -skip_children = false -show_parse_errors = true -error_on_line_overflow = false -error_on_unformatted = false -ignore = [] -emit_mode = "Files" -make_backup = false +merge_derives = true \ No newline at end of file