Support for bivariate convolution & normalization with offset (#126)

* Add bivariate-convolution
* Add pair-wise convolution + tests + benches
* Add take_cnv_pvec_[left/right] to Scratch & update CHANGELOG.md
* Cross-base2k normalization with positive offset
* Clippy fixes & fix CI doctest AVX compile error
* More streamlined bounds derivation for normalization
* Working cross-base2k normalization with pos/neg offset
* Update normalization API & tests
* Add glwe tensoring test
* Add relinearization + preliminary test
* Fix GGLWEToGGSW key infos
* Add (X,Y) convolution by const (1, Y) poly
* Faster normalization test + add bench for cnv_by_const
* Update changelog
Jean-Philippe Bossuat, 2025-12-21 16:56:42 +01:00 (committed by GitHub)
parent 76424d0ab5
commit 4e90e08a71
219 changed files with 6571 additions and 5041 deletions


@@ -62,6 +62,7 @@ jobs:
      - name: Tests (AVX enabled)
        if: steps.avxcheck.outputs.supported == 'true'
        run: |
          RUSTDOCFLAGS="-C target-feature=+avx2 -C target-feature=+fma" \
          RUSTFLAGS="-C target-feature=+avx2,+fma" \
          cargo test --workspace --features enable-avx


@@ -1,10 +1,101 @@
# CHANGELOG

## [0.4.1] - 2025-11-26
### Summary
- Update the convolution API to match spqlios-arithmetic and remove the API for bivariate tensoring.
### `poulpy-hal`
- Removed `Backend` generic from `VecZnxBigAllocBytesImpl`.
- Add `CnvPVecL` and `CnvPVecR` structs.
- Add `CnvPVecBytesOf` and `CnvPVecAlloc` traits.
- Add `Convolution` trait, which groups the following methods:
- `cnv_prepare_left_tmp_bytes`
- `cnv_prepare_left`
- `cnv_prepare_right_tmp_bytes`
- `cnv_prepare_right`
- `cnv_by_const_apply`
- `cnv_by_const_apply_tmp_bytes`
- `cnv_apply_dft_tmp_bytes`
- `cnv_apply_dft`
- `cnv_pairwise_apply_dft_tmp_bytes`
- `cnv_pairwise_apply_dft`
- Add the following Reim4 traits:
- `Reim4Convolution`
- `Reim4Convolution1Coeff`
- `Reim4Convolution2Coeffs`
- `Reim4Save1BlkContiguous`
- Add the following traits:
- `i64Save1BlkContiguous`
- `i64Extract1BlkContiguous`
- `i64ConvolutionByConst1Coeff`
- `i64ConvolutionByConst2Coeffs`
- Update signature `Reim4Extract1Blk` to `Reim4Extract1BlkContiguous`.
- Add fft64 backend reference code for:
- `reim4_save_1blk_to_reim_contiguous_ref`
- `reim4_convolution_1coeff_ref`
- `reim4_convolution_2coeffs_ref`
- `convolution_prepare_left`
- `convolution_prepare_right`
- `convolution_apply_dft_tmp_bytes`
- `convolution_apply_dft`
- `convolution_pairwise_apply_dft_tmp_bytes`
- `convolution_pairwise_apply_dft`
- `convolution_by_const_apply_tmp_bytes`
- `convolution_by_const_apply`
- Add `take_cnv_pvec_left` and `take_cnv_pvec_right` methods to the `ScratchTakeBasic` trait.
- Add the following test methods for convolution:
- `test_convolution`
- `test_convolution_by_const`
- `test_convolution_pairwise`
- Add the following bench methods for convolution:
- `bench_cnv_prepare_left`
- `bench_cnv_prepare_right`
- `bench_cnv_apply_dft`
- `bench_cnv_pairwise_apply_dft`
- `bench_cnv_by_const`
- Update normalization API and OEP to take `res_offset: i64`. This lets the user specify a bit-shift (positive or negative) applied during normalization; behavior-wise, the shift is applied before the normalization (i.e. before the mod-1 reduction). Since this is an API break, the opportunity was taken to also re-order inputs for better consistency.
- `VecZnxNormalize` & `VecZnxNormalizeImpl`
- `VecZnxBigNormalize` & `VecZnxBigNormalizeImpl`
This change completes the work needed to unlock full support for cross-base2k normalization with arbitrary positive/negative offsets. The code is not guaranteed to be optimal, but correctness is ensured.
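A minimal scalar model of the new semantics (this is not the `poulpy-hal` API; it treats a single torus coefficient as a vector of signed base-2^k limbs, and exists only to pin down where the shift happens relative to the mod-1 reduction):

```rust
// Scalar toy model of cross-base2k normalization with an offset.
// NOT the poulpy-hal API: a torus coefficient is modelled as signed
// limbs in base 2^base2k, most significant limb first.

/// Interpret `limbs` as a torus value in [-0.5, 0.5).
fn to_torus(limbs: &[i64], base2k: usize) -> f64 {
    let mut x: f64 = 0.0;
    for (i, &l) in limbs.iter().enumerate() {
        x += (l as f64) / 2f64.powi(((i + 1) * base2k) as i32);
    }
    x - x.round()
}

/// Re-encode `x` into `size` limbs of base 2^out_base2k after scaling by
/// 2^offset; note that the mod-1 reduction happens *after* the shift.
fn normalize(x: f64, out_base2k: usize, size: usize, offset: i64) -> Vec<i64> {
    let mut y: f64 = x * 2f64.powi(offset as i32);
    y -= y.round(); // mod-1 reduction, applied after the bit-shift
    (0..size)
        .map(|_| {
            y *= 2f64.powi(out_base2k as i32);
            let l: f64 = y.round();
            y -= l;
            l as i64
        })
        .collect()
}

fn main() {
    // 3/2^19 encoded with base2k = 19, renormalized to base2k = 12 with
    // offset +2: the output must encode 3/2^19 * 2^2 = 3/2^17.
    let x: f64 = to_torus(&[3], 19);
    let out: Vec<i64> = normalize(x, 12, 2, 2);
    assert!((to_torus(&out, 12) - x * 4.0).abs() < 1e-12);
}
```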
### `poulpy-cpu-ref`
- Implemented the `ConvolutionImpl` OEP on the `FFT64Ref` backend.
- Add benchmark for convolution.
- Add test for convolution.
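For intuition, the operation this backend implements in the DFT domain reduces to a negacyclic convolution; a schoolbook reference (assuming, as elsewhere in the stack, the ring Z[X]/(X^N + 1)) is sketched below. The accelerated path instead prepares the operands (`cnv_prepare_left` / `cnv_prepare_right`, typically on buffers obtained via `take_cnv_pvec_left` / `take_cnv_pvec_right`) and multiplies them pointwise with `cnv_apply_dft`.

```rust
// Schoolbook negacyclic convolution in Z[X]/(X^N + 1), the product the
// Convolution trait computes via the DFT. Reference only; coefficients
// wrap with i64 arithmetic, mirroring the fixed-size limb representation.
fn negacyclic_convolution(a: &[i64], b: &[i64]) -> Vec<i64> {
    let n: usize = a.len();
    assert_eq!(n, b.len());
    let mut res: Vec<i64> = vec![0i64; n];
    for i in 0..n {
        for j in 0..n {
            let prod: i64 = a[i].wrapping_mul(b[j]);
            if i + j < n {
                res[i + j] = res[i + j].wrapping_add(prod);
            } else {
                // X^N = -1: terms of degree >= N fold back negated.
                res[i + j - n] = res[i + j - n].wrapping_sub(prod);
            }
        }
    }
    res
}
```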
### `poulpy-cpu-avx`
- Implemented the `ConvolutionImpl` OEP on the `FFT64Avx` backend.
- Add benchmark for convolution.
- Add test for convolution.
- Add fft64 AVX code for:
- `reim4_save_1blk_to_reim_contiguous_avx`
- `reim4_convolution_1coeff_avx`
- `reim4_convolution_2coeffs_avx`
### `poulpy-core`
- Renamed `size` to `limbs`.
- Add `GLWEMulPlain` trait:
- `glwe_mul_plain_tmp_bytes`
- `glwe_mul_plain`
- `glwe_mul_plain_inplace`
- Add `GLWEMulConst` trait:
- `glwe_mul_const_tmp_bytes`
- `glwe_mul_const`
- `glwe_mul_const_inplace`
- Add `GLWETensoring` trait:
- `glwe_tensor_apply_tmp_bytes`
- `glwe_tensor_apply`
- `glwe_tensor_relinearize_tmp_bytes`
- `glwe_tensor_relinearize`
- Add test methods:
- `test_glwe_tensoring`
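The algebra behind `glwe_tensor_apply` and `glwe_tensor_relinearize` can be checked on a noiseless scalar toy (plain integers mod q, not the poulpy-core API): the product ciphertext decrypts under the extended secret (1, s, s^2), and relinearization re-encrypts the s^2 component under s using the relinearization key to return to the original rank.

```rust
// Noiseless scalar toy of the tensoring identity:
// (b1 + a1*s)(b2 + a2*s) = b1*b2 + (a1*b2 + a2*b1)*s + (a1*a2)*s^2,
// i.e. the tensored ciphertext decrypts under the powers (1, s, s^2).
const Q: i128 = 1 << 62;

/// (b, a) such that b + a*s = m (mod Q); noise omitted for clarity.
fn encrypt(m: i128, a: i128, s: i128) -> (i128, i128) {
    (((m - a * s) % Q + Q) % Q, a)
}

fn main() {
    let s: i128 = 3;
    let (b1, a1) = encrypt(5, 123_456, s);
    let (b2, a2) = encrypt(7, 654_321, s);

    // Scalar analogue of glwe_tensor_apply: a degree-2 ciphertext.
    let t0: i128 = b1 * b2 % Q;
    let t1: i128 = (a1 * b2 + a2 * b1) % Q;
    let t2: i128 = a1 * a2 % Q;

    // Decrypt under (1, s, s^2); glwe_tensor_relinearize would instead
    // fold the s^2 term back to a rank-1 ciphertext via the relin key.
    let m: i128 = ((t0 + t1 * s + t2 * s * s) % Q + Q) % Q;
    assert_eq!(m, 5 * 7);
}
```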
## [0.4.0] - 2025-11-20
### Summary
- Full support for base2k operations.
- Many improvements to BDD arithmetic.
- Removal of **poulpy-backend** & spqlios backend.
- Addition of individual crates for each specific backend.
- Some minor bug fixes.
@@ -28,7 +119,7 @@
- Improved Cmux speed
### `poulpy-cpu-ref`
- A new crate that provides the reference CPU implementation of **poulpy-hal**. This replaces the previous **poulpy-backend/cpu_ref**.
### `poulpy-cpu-avx`
- A new crate that provides an AVX/FMA accelerated CPU implementation of **poulpy-hal**. This replaces the previous **poulpy-backend/cpu_avx**.
@@ -76,7 +167,7 @@
- Added functionality-based traits, which remove the need to import the low-level traits of `poulpy-hal` and make backend-agnostic code much cleaner. For example, instead of having to import each individual trait required for the encryption of a GLWE, only the trait `GLWEEncryptSk` is needed.
### `poulpy-schemes`
- Added basic framework for binary decision circuit (BDD) arithmetic along with some operations.

## [0.2.0] - 2025-09-15

Cargo.lock (generated)

@@ -370,6 +370,7 @@ dependencies = [
"poulpy-cpu-avx", "poulpy-cpu-avx",
"poulpy-cpu-ref", "poulpy-cpu-ref",
"poulpy-hal", "poulpy-hal",
"rand",
"rug", "rug",
] ]


@@ -23,6 +23,7 @@ byteorder = {workspace = true}
bytemuck = {workspace = true}
once_cell = {workspace = true}
paste = {workspace = true}
rand = {workspace = true}

[[bench]]
name = "external_product_glwe_fft64"


@@ -87,22 +87,9 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) {
    let mut sk_dft: GLWESecretPrepared<Vec<u8>, BackendImpl> = GLWESecretPrepared::alloc(&module, rank);
    sk_dft.prepare(&module, &sk);

    ct_ggsw.encrypt_sk(&module, &pt_rgsw, &sk_dft, &mut source_xa, &mut source_xe, scratch.borrow());
    ct_glwe_in.encrypt_zero_sk(&module, &sk_dft, &mut source_xa, &mut source_xe, scratch.borrow());

    let mut ggsw_prepared: GGSWPrepared<Vec<u8>, BackendImpl> = GGSWPrepared::alloc_from_infos(&module, &ct_ggsw);
    ggsw_prepared.prepare(&module, &ct_ggsw, scratch.borrow());

@@ -190,22 +177,9 @@ fn bench_external_product_glwe_inplace_fft64(c: &mut Criterion) {
    let mut sk_dft: GLWESecretPrepared<Vec<u8>, BackendImpl> = GLWESecretPrepared::alloc(&module, rank);
    sk_dft.prepare(&module, &sk);

    ct_ggsw.encrypt_sk(&module, &pt_rgsw, &sk_dft, &mut source_xa, &mut source_xe, scratch.borrow());
    ct_glwe.encrypt_zero_sk(&module, &sk_dft, &mut source_xa, &mut source_xe, scratch.borrow());

    let mut ggsw_prepared: GGSWPrepared<Vec<u8>, BackendImpl> = GGSWPrepared::alloc_from_infos(&module, &ct_ggsw);
    ggsw_prepared.prepare(&module, &ct_ggsw, scratch.borrow());


@@ -75,12 +75,7 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) {
    let mut scratch: ScratchOwned<BackendImpl> = ScratchOwned::alloc(
        GLWESwitchingKey::encrypt_sk_tmp_bytes(&module, &gglwe_atk_layout)
            | GLWE::encrypt_sk_tmp_bytes(&module, &glwe_in_layout)
            | GLWE::keyswitch_tmp_bytes(&module, &glwe_out_layout, &glwe_in_layout, &gglwe_atk_layout),
    );

    let mut source_xs: Source = Source::new([0u8; 32]);

@@ -93,22 +88,9 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) {
    let mut sk_in_dft: GLWESecretPrepared<Vec<u8>, BackendImpl> = GLWESecretPrepared::alloc(&module, rank);
    sk_in_dft.prepare(&module, &sk_in);

    ksk.encrypt_sk(&module, -1, &sk_in, &mut source_xa, &mut source_xe, scratch.borrow());
    ct_in.encrypt_zero_sk(&module, &sk_in_dft, &mut source_xa, &mut source_xe, scratch.borrow());

    let mut ksk_prepared: GLWEAutomorphismKeyPrepared<Vec<u8>, _> =
        GLWEAutomorphismKeyPrepared::alloc_from_infos(&module, &ksk);

@@ -206,22 +188,9 @@ fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) {
    let mut sk_out: GLWESecret<Vec<u8>> = GLWESecret::alloc_from_infos(&glwe_layout);
    sk_out.fill_ternary_prob(0.5, &mut source_xs);

    ksk.encrypt_sk(&module, &sk_in, &sk_out, &mut source_xa, &mut source_xe, scratch.borrow());
    ct.encrypt_zero_sk(&module, &sk_in_dft, &mut source_xa, &mut source_xe, scratch.borrow());

    let mut ksk_prepared: GLWESwitchingKeyPrepared<Vec<u8>, _> = GLWESwitchingKeyPrepared::alloc_from_infos(&module, &ksk);
    ksk_prepared.prepare(&module, &ksk, scratch.borrow());

@@ -249,9 +218,5 @@ fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) {
    group.finish();
}

criterion_group!(benches, bench_keyswitch_glwe_fft64, bench_keyswitch_glwe_inplace_fft64);
criterion_main!(benches);


@@ -75,13 +75,7 @@ where
        a.dnum()
    );

    assert_eq!(res.dsize(), a.dsize(), "res dsize: {} != a dsize: {}", res.dsize(), a.dsize());

    assert_eq!(res.base2k(), a.base2k());

@@ -139,13 +133,7 @@ where
    K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos,
    Scratch<BE>: ScratchTakeCore<BE>,
{
    assert_eq!(res.rank(), key.rank(), "res rank: {} != key rank: {}", res.rank(), key.rank());

    let cols_out: usize = (key.rank_out() + 1).into();
    let cols_in: usize = key.rank_in().into();


@@ -218,13 +218,13 @@ where
    let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
    let a: &GLWE<&[u8]> = &a.to_ref();

    let a_base2k: usize = a.base2k().into();
    let key_base2k: usize = key.base2k().into();
    let res_base2k: usize = res.base2k().into();

    let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size

    if a_base2k != key_base2k {
        let (mut a_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout {
            n: a.n(),
            base2k: key.base2k(),
@@ -236,30 +236,14 @@ where
        for i in 0..res.rank().as_usize() + 1 {
            self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_2);
            self.vec_znx_big_add_small_inplace(&mut res_big, i, a_conv.data(), i);
            self.vec_znx_big_normalize(res.data_mut(), res_base2k, 0, i, &res_big, key_base2k, i, scratch_2);
        }
    } else {
        let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, a, key, scratch_1);
        for i in 0..res.rank().as_usize() + 1 {
            self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1);
            self.vec_znx_big_add_small_inplace(&mut res_big, i, a.data(), i);
            self.vec_znx_big_normalize(res.data_mut(), res_base2k, 0, i, &res_big, key_base2k, i, scratch_1);
        }
    };
}

@@ -272,12 +256,12 @@ where
{
    let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();

    let key_base2k: usize = key.base2k().into();
    let res_base2k: usize = res.base2k().into();

    let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size

    if res_base2k != key_base2k {
        let (mut res_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout {
            n: res.n(),
            base2k: key.base2k(),
@@ -289,30 +273,14 @@ where
        for i in 0..res.rank().as_usize() + 1 {
            self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_2);
            self.vec_znx_big_add_small_inplace(&mut res_big, i, res_conv.data(), i);
            self.vec_znx_big_normalize(res.data_mut(), res_base2k, 0, i, &res_big, key_base2k, i, scratch_2);
        }
    } else {
        let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, res, key, scratch_1);
        for i in 0..res.rank().as_usize() + 1 {
            self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1);
            self.vec_znx_big_add_small_inplace(&mut res_big, i, res.data(), i);
            self.vec_znx_big_normalize(res.data_mut(), res_base2k, 0, i, &res_big, key_base2k, i, scratch_1);
        }
    };
}

@@ -327,13 +295,13 @@ where
    let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
    let a: &GLWE<&[u8]> = &a.to_ref();

    let a_base2k: usize = a.base2k().into();
    let key_base2k: usize = key.base2k().into();
    let res_base2k: usize = res.base2k().into();

    let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size

    if a_base2k != key_base2k {
        let (mut a_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout {
            n: a.n(),
            base2k: key.base2k(),
@@ -345,30 +313,14 @@ where
        for i in 0..res.rank().as_usize() + 1 {
            self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_2);
            self.vec_znx_big_sub_small_inplace(&mut res_big, i, a_conv.data(), i);
            self.vec_znx_big_normalize(res.data_mut(), res_base2k, 0, i, &res_big, key_base2k, i, scratch_2);
        }
    } else {
        let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, a, key, scratch_1);
        for i in 0..res.rank().as_usize() + 1 {
            self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1);
            self.vec_znx_big_sub_small_inplace(&mut res_big, i, a.data(), i);
            self.vec_znx_big_normalize(res.data_mut(), res_base2k, 0, i, &res_big, key_base2k, i, scratch_1);
        }
    };
}

@@ -383,13 +335,13 @@ where
    let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
    let a: &GLWE<&[u8]> = &a.to_ref();

    let a_base2k: usize = a.base2k().into();
    let key_base2k: usize = key.base2k().into();
    let res_base2k: usize = res.base2k().into();

    let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size

    if a_base2k != key_base2k {
        let (mut a_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout {
            n: a.n(),
            base2k: key.base2k(),
@@ -401,30 +353,14 @@ where
        for i in 0..res.rank().as_usize() + 1 {
            self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_2);
            self.vec_znx_big_sub_small_negate_inplace(&mut res_big, i, a_conv.data(), i);
            self.vec_znx_big_normalize(res.data_mut(), res_base2k, 0, i, &res_big, key_base2k, i, scratch_2);
        }
    } else {
        let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, a, key, scratch_1);
        for i in 0..res.rank().as_usize() + 1 {
            self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1);
            self.vec_znx_big_sub_small_negate_inplace(&mut res_big, i, a.data(), i);
            self.vec_znx_big_normalize(res.data_mut(), res_base2k, 0, i, &res_big, key_base2k, i, scratch_1);
        }
    };
}

@@ -437,12 +373,12 @@ where
{
    let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();

    let key_base2k: usize = key.base2k().into();
    let res_base2k: usize = res.base2k().into();

    let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size

    if res_base2k != key_base2k {
        let (mut res_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout {
            n: res.n(),
            base2k: key.base2k(),
@@ -454,30 +390,14 @@ where
        for i in 0..res.rank().as_usize() + 1 {
            self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_2);
            self.vec_znx_big_sub_small_inplace(&mut res_big, i, res_conv.data(), i);
            self.vec_znx_big_normalize(res.data_mut(), res_base2k, 0, i, &res_big, key_base2k, i, scratch_2);
        }
    } else {
        let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, res, key, scratch_1);
        for i in 0..res.rank().as_usize() + 1 {
            self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1);
            self.vec_znx_big_sub_small_inplace(&mut res_big, i, res.data(), i);
            self.vec_znx_big_normalize(res.data_mut(), res_base2k, 0, i, &res_big, key_base2k, i, scratch_1);
        }
    };
}

@@ -490,12 +410,12 @@ where
{
    let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();

    let key_base2k: usize = key.base2k().into();
    let res_base2k: usize = res.base2k().into();

    let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size

    if res_base2k != key_base2k {
        let (mut res_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout {
            n: res.n(),
            base2k: key.base2k(),
@@ -507,30 +427,14 @@ where
        for i in 0..res.rank().as_usize() + 1 {
            self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_2);
            self.vec_znx_big_sub_small_negate_inplace(&mut res_big, i, res_conv.data(), i);
            self.vec_znx_big_normalize(res.data_mut(), res_base2k, 0, i, &res_big, key_base2k, i, scratch_2);
        }
    } else {
        let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, res, key, scratch_1);
        for i in 0..res.rank().as_usize() + 1 {
            self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1);
            self.vec_znx_big_sub_small_negate_inplace(&mut res_big, i, res.data(), i);
            self.vec_znx_big_normalize(res.data_mut(), res_base2k, 0, i, &res_big, key_base2k, i, scratch_1);
        }
    };
}
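Note on the hunks above: every updated normalization call site follows the same shape, with the output first and the new offset argument third. The names below are descriptive placeholders inferred from the surrounding locals, not the actual parameter names, and every call site in this commit passes an offset of 0:

// module.vec_znx_big_normalize(
//     res,         // output VecZnx, written in base 2^res_base2k
//     res_base2k,
//     res_offset,  // new i64 bit-shift, applied before the mod-1 reduction
//     res_col,
//     a,           // input VecZnxBig, read in base 2^a_base2k
//     a_base2k,
//     a_col,
//     scratch,
// );
//
// vec_znx_normalize follows the same ordering with a plain VecZnx input.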


@@ -120,13 +120,13 @@ where
    R: GGSWInfos,
    A: GGLWEInfos,
{
    let tsk_base2k: usize = tsk_infos.base2k().into();
    let rank: usize = res_infos.rank().into();
    let cols: usize = rank + 1;

    let res_size: usize = res_infos.size();
    let a_size: usize = res_infos.max_k().as_usize().div_ceil(tsk_base2k);

    let a_0: usize = VecZnx::bytes_of(self.n(), 1, a_size);
    let a_dft: usize = self.bytes_of_vec_znx_dft(cols - 1, a_size);

@@ -146,15 +146,15 @@ where
    let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
    let tsk: &GGLWEToGGSWKeyPrepared<&[u8], BE> = &tsk.to_ref();

    let res_base2k: usize = res.base2k().into();
    let tsk_base2k: usize = tsk.base2k().into();

    assert!(scratch.available() >= self.ggsw_expand_rows_tmp_bytes(res, tsk));

    let rank: usize = res.rank().into();
    let cols: usize = rank + 1;

    let res_conv_size: usize = res.max_k().as_usize().div_ceil(tsk_base2k);

    let (mut a_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols - 1, res_conv_size);
    let (mut a_0, scratch_2) = scratch_1.take_vec_znx(self.n(), 1, res_conv_size);

@@ -163,33 +163,17 @@ where
    for row in 0..res.dnum().as_usize() {
        let glwe_mi_1: &GLWE<&[u8]> = &res.at(row, 0);

        if res_base2k == tsk_base2k {
            for col_i in 0..cols - 1 {
                self.vec_znx_dft_apply(1, 0, &mut a_dft, col_i, glwe_mi_1.data(), col_i + 1);
            }
            self.vec_znx_copy(&mut a_0, 0, glwe_mi_1.data(), 0);
        } else {
            for i in 0..cols - 1 {
                self.vec_znx_normalize(&mut a_0, tsk_base2k, 0, 0, glwe_mi_1.data(), res_base2k, i + 1, scratch_2);
                self.vec_znx_dft_apply(1, 0, &mut a_dft, i, &a_0, 0);
            }
            self.vec_znx_normalize(&mut a_0, tsk_base2k, 0, 0, glwe_mi_1.data(), res_base2k, 0, scratch_2);
        }

        ggsw_expand_rows_internal(self, row, res, &a_0, &a_dft, tsk, scratch_2)

@@ -267,13 +251,16 @@ fn ggsw_expand_rows_internal<M, R, C, A, T, BE: Backend>(
    // (-(x0s0 + x1s1 + x2s2), x0 + M[i], x1, x2)
    module.vec_znx_big_add_small_inplace(&mut res_big, col, a_0, 0);

    let res_base2k: usize = res.base2k().as_usize();
    for j in 0..cols {
        module.vec_znx_big_normalize(
            res.at_mut(row, col).data_mut(),
            res_base2k,
            0,
            j,
            &res_big,
            tsk.base2k().as_usize(),
            j,
            scratch_1,
        );


@@ -56,12 +56,8 @@ where
        rank: Rank(1),
    };

    GLWE::bytes_of(self.n().into(), lwe_infos.base2k(), lwe_infos.k(), 1u32.into())
        + GLWE::bytes_of_from_infos(glwe_infos)
        + self.glwe_keyswitch_tmp_bytes(&res_infos, glwe_infos, key_infos)
}


@@ -73,11 +73,12 @@ where
    }

    self.vec_znx_normalize(
        &mut glwe.data,
        ksk.base2k().into(),
        0,
        0,
        &a_conv,
        lwe.base2k().into(),
        0,
        scratch_2,
    );

@@ -89,11 +90,12 @@ where
    }

    self.vec_znx_normalize(
        &mut glwe.data,
        ksk.base2k().into(),
        0,
        1,
        &a_conv,
        lwe.base2k().into(),
        0,
        scratch_2,
    );


@@ -33,10 +33,21 @@ impl<DataSelf: DataRef> GLWE<DataSelf> {
    }
}

pub trait GLWEDecrypt<BE: Backend> {
    fn glwe_decrypt_tmp_bytes<A>(&self, infos: &A) -> usize
    where
        A: GLWEInfos;

    fn glwe_decrypt<R, P, S>(&self, res: &R, pt: &mut P, sk: &S, scratch: &mut Scratch<BE>)
    where
        R: GLWEToRef,
        P: GLWEPlaintextToMut,
        S: GLWESecretPreparedToRef<BE>;
}

impl<BE: Backend> GLWEDecrypt<BE> for Module<BE>
where
    Self: ModuleN
        + VecZnxDftBytesOf
        + VecZnxNormalizeTmpBytes
        + VecZnxBigBytesOf
@@ -46,6 +57,7 @@ where
        + VecZnxBigAddInplace<BE>
        + VecZnxBigAddSmallInplace<BE>
        + VecZnxBigNormalize<BE>,
    Scratch<BE>: ScratchTakeBasic,
{
    fn glwe_decrypt_tmp_bytes<A>(&self, infos: &A) -> usize
    where
@@ -60,7 +72,6 @@ where
        R: GLWEToRef,
        P: GLWEPlaintextToMut,
        S: GLWESecretPreparedToRef<BE>,
    {
        let res: &GLWE<&[u8]> = &res.to_ref();
        let pt: &mut GLWEPlaintext<&mut [u8]> = &mut pt.to_mut();
@@ -94,32 +105,12 @@ where
        // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e)
        self.vec_znx_big_add_small_inplace(&mut c0_big, 0, &res.data, 0);

        let res_base2k: usize = res.base2k().into();

        // pt = norm(BIG(m + e))
        self.vec_znx_big_normalize(&mut pt.data, res_base2k, 0, 0, &c0_big, res_base2k, 0, scratch_1);

        pt.base2k = res.base2k();
        pt.k = pt.k().min(res.k());
    }
}


@@ -0,0 +1,97 @@
use poulpy_hal::{
    api::{ScratchTakeBasic, SvpPPolBytesOf},
    layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut},
};

use crate::{
    GLWEDecrypt, ScratchTakeCore,
    layouts::{GLWEInfos, GLWEPlaintext, GLWESecretPrepared, GLWESecretTensor, GLWESecretTensorPrepared, GLWETensor},
};

impl GLWETensor<Vec<u8>> {
    pub fn decrypt_tmp_bytes<A, M, BE: Backend>(module: &M, a_infos: &A) -> usize
    where
        A: GLWEInfos,
        M: GLWETensorDecrypt<BE>,
    {
        module.glwe_tensor_decrypt_tmp_bytes(a_infos)
    }
}

impl<DataSelf: DataRef> GLWETensor<DataSelf> {
    pub fn decrypt<P, S0, S1, M, BE: Backend>(
        &self,
        module: &M,
        pt: &mut GLWEPlaintext<P>,
        sk: &GLWESecretPrepared<S0, BE>,
        sk_tensor: &GLWESecretTensorPrepared<S1, BE>,
        scratch: &mut Scratch<BE>,
    ) where
        P: DataMut,
        S0: DataRef,
        S1: DataRef,
        M: GLWETensorDecrypt<BE>,
        Scratch<BE>: ScratchTakeBasic,
    {
        module.glwe_tensor_decrypt(self, pt, sk, sk_tensor, scratch);
    }
}

pub trait GLWETensorDecrypt<BE: Backend> {
    fn glwe_tensor_decrypt_tmp_bytes<A>(&self, infos: &A) -> usize
    where
        A: GLWEInfos;

    fn glwe_tensor_decrypt<R, P, S0, S1>(
        &self,
        res: &GLWETensor<R>,
        pt: &mut GLWEPlaintext<P>,
        sk: &GLWESecretPrepared<S0, BE>,
        sk_tensor: &GLWESecretTensorPrepared<S1, BE>,
        scratch: &mut Scratch<BE>,
    ) where
        R: DataRef,
        P: DataMut,
        S0: DataRef,
        S1: DataRef;
}

impl<BE: Backend> GLWETensorDecrypt<BE> for Module<BE>
where
    Self: GLWEDecrypt<BE> + SvpPPolBytesOf,
    Scratch<BE>: ScratchTakeCore<BE>,
{
    fn glwe_tensor_decrypt_tmp_bytes<A>(&self, infos: &A) -> usize
    where
        A: GLWEInfos,
    {
        self.glwe_decrypt_tmp_bytes(infos)
    }

    fn glwe_tensor_decrypt<R, P, S0, S1>(
        &self,
        res: &GLWETensor<R>,
        pt: &mut GLWEPlaintext<P>,
        sk: &GLWESecretPrepared<S0, BE>,
        sk_tensor: &GLWESecretTensorPrepared<S1, BE>,
        scratch: &mut Scratch<BE>,
    ) where
        R: DataRef,
        P: DataMut,
        S0: DataRef,
        S1: DataRef,
    {
        let rank: usize = sk.rank().as_usize();

        // Group (s, s (x) s) into a single prepared secret of extended rank,
        // so the tensor ciphertext can be decrypted with plain glwe_decrypt.
        let (mut sk_grouped, scratch_1) = scratch.take_glwe_secret_prepared(self, (GLWESecretTensor::pairs(rank) + rank).into());
        for i in 0..rank {
            sk_grouped.data.at_mut(i, 0).copy_from_slice(sk.data.at(i, 0));
        }
        for i in 0..sk_grouped.rank().as_usize() - rank {
            sk_grouped.data.at_mut(i + rank, 0).copy_from_slice(sk_tensor.data.at(i, 0));
        }

        self.glwe_decrypt(res, pt, &sk_grouped, scratch_1);
    }
}


@@ -1,5 +1,7 @@
mod glwe;
mod glwe_tensor;
mod lwe;

pub use glwe::*;
pub use glwe_tensor::*;
pub use lwe::*;


@@ -63,10 +63,7 @@ impl Distribution {
            TAG_ZERO => Distribution::ZERO,
            TAG_NONE => Distribution::NONE,
            _ => {
                return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "Invalid tag"));
            }
        };

        Ok(dist)


@@ -77,9 +77,7 @@ where
    where
        A: GGLWEInfos,
    {
        self.glwe_encrypt_sk_tmp_bytes(infos).max(self.vec_znx_normalize_tmp_bytes()) + GLWEPlaintext::bytes_of_from_infos(infos)
    }

    fn gglwe_compressed_encrypt_sk<R, P, S>(


@@ -101,24 +101,13 @@ where
        for i in 0..rank {
            for j in 0..rank {
                self.vec_znx_copy(&mut sk_ij.as_vec_znx_mut(), j, &sk_tensor.at(i, j).as_vec_znx(), 0);
            }

            let (seed_xa_tmp, _) = source_xa.branch();

            res.at_mut(i)
                .encrypt_sk(self, &sk_ij, &sk_prepared, seed_xa_tmp, source_xe, scratch_3);
        }
    }
}


@@ -112,14 +112,7 @@ where
        sk_out_prepared.prepare(self, &sk_out);
    }

    self.gglwe_compressed_encrypt_sk(res, &sk.data, &sk_out_prepared, seed_xa, source_xe, scratch_1);

    res.set_p(p);
}


@@ -104,12 +104,7 @@ where
    let (mut sk_in_tmp, scratch_1) = scratch.take_scalar_znx(self.n(), sk_in.rank().into());
    for i in 0..sk_in.rank().into() {
        self.vec_znx_switch_ring(&mut sk_in_tmp.as_vec_znx_mut(), i, &sk_in.data.as_vec_znx(), i);
    }

    let (mut sk_out_tmp, scratch_2) = scratch_1.take_glwe_secret_prepared(self, sk_out.rank());


@@ -102,13 +102,6 @@ where
        sk_prepared.prepare(self, sk);
        sk_tensor.prepare(self, sk, scratch_2);

        self.gglwe_compressed_encrypt_sk(res, &sk_tensor.data, &sk_prepared, seed_xa, source_xe, scratch_2);
    }
}


@@ -160,14 +160,7 @@ where
                tmp_pt.data.zero(); // zeroes for next iteration
                self.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (dsize - 1) + row_i * dsize, pt, col_i);
                self.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scrach_1);

                self.glwe_encrypt_sk(&mut res.at_mut(row_i, col_i), &tmp_pt, sk, source_xa, source_xe, scrach_1);
            }
        }
    }


@@ -97,12 +97,7 @@ where
        for i in 0..rank {
            for j in 0..rank {
                self.vec_znx_copy(&mut sk_ij.as_vec_znx_mut(), j, &sk_tensor.at(i, j).as_vec_znx(), 0);
            }

            res.at_mut(i)


@@ -76,9 +76,7 @@ where
    where
        A: GGSWInfos,
    {
        self.glwe_encrypt_sk_tmp_bytes(infos).max(self.vec_znx_normalize_tmp_bytes()) + GLWEPlaintext::bytes_of_from_infos(infos)
    }

    fn ggsw_encrypt_sk<R, P, S>(


@@ -402,15 +402,7 @@ where
        let mut ci_big = self.vec_znx_idft_apply_consume(ci_dft);

        // ci_big = u * pk[i] + e
        self.vec_znx_big_add_normal(base2k, &mut ci_big, 0, pk.k().into(), source_xe, SIGMA, SIGMA_BOUND);

        // ci_big = u * pk[i] + e + m (if col = i)
        if let Some((pt, col)) = pt
@@ -420,7 +412,7 @@ where
        }

        // ct[i] = norm(ci_big)
        self.vec_znx_big_normalize(&mut res.data, base2k, 0, i, &ci_big, base2k, 0, scratch_2);
        }
    }
}

@@ -487,12 +479,7 @@ where
    let sk: GLWESecretPrepared<&[u8], BE> = sk.to_ref();

    if compressed {
        assert_eq!(ct.cols(), 1, "invalid glwe: compressed tag=true but #cols={} != 1", ct.cols())
    }

    assert!(
@@ -537,7 +524,7 @@ where
        let ci_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(ci_dft);

        // use c[0] as buffer, which is overwritten later by the normalization step
        self.vec_znx_big_normalize(&mut ci, base2k, 0, 0, &ci_big, base2k, 0, scratch_3);

        // c0_tmp = -c[i] * s[i] (use c[0] as buffer)
        self.vec_znx_sub_inplace(&mut c0, 0, &ci, 0);

@@ -555,6 +542,6 @@ where
    }

    // c[0] = norm(c[0])
    self.vec_znx_normalize(ct, base2k, 0, 0, &c0, base2k, 0, scratch_1);
}


@@ -130,14 +130,7 @@ where
        sk_out_prepared.prepare(self, &sk_out);
    }

    self.gglwe_encrypt_sk(res, &sk.data, &sk_out_prepared, source_xa, source_xe, scratch_1);

    res.set_p(p);
}


@@ -78,8 +78,7 @@ where
    where
        A: GGLWEInfos,
    {
        self.gglwe_encrypt_sk_tmp_bytes(infos).max(ScalarZnx::bytes_of(self.n(), 1))
            + ScalarZnx::bytes_of(self.n(), infos.rank_in().into())
            + self.bytes_of_glwe_secret_prepared_from_infos(infos)
    }

@@ -111,12 +110,7 @@ where
    let (mut sk_in_tmp, scratch_1) = scratch.take_scalar_znx(self.n(), sk_in.rank().into());
    for i in 0..sk_in.rank().into() {
        self.vec_znx_switch_ring(&mut sk_in_tmp.as_vec_znx_mut(), i, &sk_in.data.as_vec_znx(), i);
    }

    let (mut sk_out_tmp, scratch_2) = scratch_1.take_glwe_secret_prepared(self, sk_out.rank());

@@ -130,14 +124,7 @@ where
    sk_out_tmp.dist = sk_out.dist;

    self.gglwe_encrypt_sk(res, &sk_in_tmp, &sk_out_tmp, source_xa, source_xe, scratch_2);

    *res.input_degree() = sk_in.n();
    *res.output_degree() = sk_out.n();


@@ -103,13 +103,6 @@ where
        sk_prepared.prepare(self, sk);
        sk_tensor.prepare(self, sk, scratch_2);

        self.gglwe_encrypt_sk(res, &sk_tensor.data, &sk_prepared, source_xa, source_xe, scratch_2);
    }
}


@@ -107,13 +107,6 @@ where
            sk_lwe_as_glwe_prep.prepare(self, &sk_lwe_as_glwe);
        }

        self.gglwe_encrypt_sk(res, &sk_glwe.data, &sk_lwe_as_glwe_prep, source_xa, source_xe, scratch_1);
    }
}


@@ -70,21 +70,9 @@ where
    where
        A: GGLWEInfos,
    {
        assert_eq!(infos.dsize().0, 1, "dsize > 1 is not supported for LWESwitchingKey");
        assert_eq!(infos.rank_in().0, 1, "rank_in > 1 is not supported for LWESwitchingKey");
        assert_eq!(infos.rank_out().0, 1, "rank_out > 1 is not supported for LWESwitchingKey");

        GLWESecret::bytes_of(self.n().into(), Rank(1))
            + GLWESecretPrepared::bytes_of(self, Rank(1))
            + GLWESwitchingKey::encrypt_sk_tmp_bytes(self, infos)

@@ -125,13 +113,6 @@ where
        sk_glwe_in.data.at_mut(0, 0)[sk_lwe_in.n().into()..].fill(0);
        self.vec_znx_automorphism_inplace(-1, &mut sk_glwe_in.data.as_vec_znx_mut(), 0, scratch_2);

        self.glwe_switching_key_encrypt_sk(res, &sk_glwe_in, &sk_glwe_out, source_xa, source_xe, scratch_2);
    }
}


@@ -106,13 +106,6 @@ where
        sk_lwe_as_glwe.data.at_mut(0, 0)[sk_lwe.n().into()..].fill(0);
        self.vec_znx_automorphism_inplace(-1, &mut sk_lwe_as_glwe.data.as_vec_znx_mut(), 0, scratch_1);

        self.gglwe_encrypt_sk(res, &sk_lwe_as_glwe.data, sk_glwe, source_xa, source_xe, scratch_1);
    }
}


@@ -35,20 +35,8 @@ where
    let a: &GGSW<&[u8]> = &a.to_ref();
    let b: &GGSWPrepared<&[u8], BE> = &b.to_ref();

    assert_eq!(res.rank(), a.rank(), "res rank: {} != a rank: {}", res.rank(), a.rank());
    assert_eq!(res.rank(), b.rank(), "res rank: {} != b rank: {}", res.rank(), b.rank());

    assert_eq!(res.base2k(), a.base2k());

@@ -80,13 +68,7 @@ where
    assert_eq!(res.n(), self.n() as u32);
    assert_eq!(a.n(), self.n() as u32);

    assert_eq!(res.rank(), a.rank(), "res rank: {} != a rank: {}", res.rank(), a.rank());

    for row in 0..res.dnum().into() {
        for col in 0..(res.rank() + 1).into() {


@@ -110,12 +110,12 @@ where
    assert_eq!(ggsw.n(), res.n());
    assert!(scratch.available() >= self.glwe_external_product_tmp_bytes(res, res, ggsw));

    let res_base2k: usize = res.base2k().as_usize();
    let ggsw_base2k: usize = ggsw.base2k().as_usize();

    let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), ggsw.size()); // Todo optimise

    let res_big: VecZnxBig<&mut [u8], BE> = if res_base2k != ggsw_base2k {
        let (mut res_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout {
            n: res.n(),
            base2k: ggsw.base2k(),
@@ -130,15 +130,7 @@ where
    let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
    for j in 0..(res.rank() + 1).into() {
        self.vec_znx_big_normalize(res.data_mut(), res_base2k, 0, j, &res_big, ggsw_base2k, j, scratch_1);
    }
}

@@ -155,13 +147,13 @@ where
    assert_eq!(a.n(), res.n());
    assert!(scratch.available() >= self.glwe_external_product_tmp_bytes(res, a, ggsw));

    let a_base2k: usize = a.base2k().into();
    let ggsw_base2k: usize = ggsw.base2k().into();
    let res_base2k: usize = res.base2k().into();

    let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), ggsw.size()); // Todo optimise

    let res_big: VecZnxBig<&mut [u8], BE> = if a_base2k != ggsw_base2k {
        let (mut a_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout {
            n: a.n(),
            base2k: ggsw.base2k(),
@@ -176,15 +168,7 @@ where
    let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
    for j in 0..(res.rank() + 1).into() {
        self.vec_znx_big_normalize(res.data_mut(), res_base2k, 0, j, &res_big, ggsw_base2k, j, scratch_1);
    }
}

@@ -231,10 +215,7 @@ where
    A: GLWEInfos,
    B: GGSWInfos,
{
    let in_size: usize = a_infos.k().div_ceil(b_infos.base2k()).div_ceil(b_infos.dsize().into()) as usize;
    let out_size: usize = res_infos.size();
    let ggsw_size: usize = b_infos.size();

    let a_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), in_size);


@@ -265,23 +265,9 @@ fn pack_core<A, K, H, M, BE: Backend>(
    // Propagates to next accumulator
    if acc_prev[0].value {
        pack_core(module, Some(&acc_prev[0].data), acc_next, i + 1, auto_keys, scratch);
    } else {
        pack_core(module, None::<&GLWE<Vec<u8>>>, acc_next, i + 1, auto_keys, scratch);
    }
}

@@ -319,11 +305,7 @@ fn combine<B, K, H, M, BE: Backend>(
    let log_n: usize = acc.data.n().log2();
    let a: &mut GLWE<Vec<u8>> = &mut acc.data;

    let gal_el: i64 = if i == 0 { -1 } else { module.galois_element(1 << (i - 1)) };

    let t: i64 = 1 << (log_n - i - 1);


@@ -88,8 +88,7 @@ where
    let key: &K = if i == 0 {
        keys.get_automorphism_key(-1).unwrap()
    } else {
        keys.get_automorphism_key(self.galois_element(1 << (i - 1))).unwrap()
    };

    for j in 0..t {


@@ -169,11 +169,7 @@ where
    for i in skip..log_n {
        self.glwe_rsh(1, res, scratch);

        let p: i64 = if i == 0 { -1 } else { self.galois_element(1 << (i - 1)) };

        if let Some(key) = keys.get_automorphism_key(p) {
            self.glwe_automorphism_add_inplace(res, key, scratch);


@@ -148,19 +148,8 @@ where
        res.rank_out(),
        b.rank_out()
    );

    assert!(res.dnum() <= a.dnum(), "res.dnum()={} > a.dnum()={}", res.dnum(), a.dnum());
    assert_eq!(res.dsize(), a.dsize(), "res dsize: {} != a dsize: {}", res.dsize(), a.dsize());

    assert_eq!(res.base2k(), a.base2k());

    let res: &mut GGLWE<&mut [u8]> = &mut res.to_mut();


@@ -105,13 +105,13 @@ where
scratch.available(),
);
- let base2k_a: usize = a.base2k().into();
- let base2k_key: usize = key.base2k().into();
- let base2k_res: usize = res.base2k().into();
let a_base2k: usize = a.base2k().into();
let key_base2k: usize = key.base2k().into();
let res_base2k: usize = res.base2k().into();
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // Todo optimise
- let res_big: VecZnxBig<&mut [u8], BE> = if base2k_a != base2k_key {
let res_big: VecZnxBig<&mut [u8], BE> = if a_base2k != key_base2k {
let (mut a_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout {
n: a.n(),
base2k: key.base2k(),
@@ -126,15 +126,7 @@ where
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
for i in 0..(res.rank() + 1).into() {
- self.vec_znx_big_normalize(
-     base2k_res,
-     res.data_mut(),
-     i,
-     base2k_key,
-     &res_big,
-     i,
-     scratch_1,
- );
self.vec_znx_big_normalize(res.data_mut(), res_base2k, 0, i, &res_big, key_base2k, i, scratch_1);
}
}
@@ -169,12 +161,12 @@ where
scratch.available(),
);
- let base2k_res: usize = res.base2k().as_usize();
- let base2k_key: usize = key.base2k().as_usize();
let res_base2k: usize = res.base2k().as_usize();
let key_base2k: usize = key.base2k().as_usize();
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // Todo optimise
- let res_big: VecZnxBig<&mut [u8], BE> = if base2k_res != base2k_key {
let res_big: VecZnxBig<&mut [u8], BE> = if res_base2k != key_base2k {
let (mut res_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout {
n: res.n(),
base2k: key.base2k(),
@@ -190,15 +182,7 @@ where
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
for i in 0..(res.rank() + 1).into() {
- self.vec_znx_big_normalize(
-     base2k_res,
-     res.data_mut(),
-     i,
-     base2k_key,
-     &res_big,
-     i,
-     scratch_1,
- );
self.vec_znx_big_normalize(res.data_mut(), res_base2k, 0, i, &res_big, key_base2k, i, scratch_1);
}
}
}
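For readers tracking the API break: every call site above now passes the destination first, then its base, then the new offset. A schematic of the call shape as used throughout this commit (argument names are ours, inferred from the surrounding call sites, not copied from the trait definition):

    // module.vec_znx_big_normalize(
    //     res,         // destination VecZnx
    //     res_base2k,  // destination base-2^k
    //     res_offset,  // bit-shift (i64) applied before the mod-1 reduction; 0 = plain normalization
    //     res_col,     // destination column
    //     &a_big,      // source VecZnxBig
    //     a_base2k,    // source base-2^k
    //     a_col,       // source column
    //     scratch,
    // );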

View File

@@ -137,13 +137,7 @@ impl GGLWECompressed<Vec<u8>> {
);
GGLWECompressed {
- data: MatZnx::alloc(
-     n.into(),
-     dnum.into(),
-     rank_in.into(),
-     1,
-     k.0.div_ceil(base2k.0) as usize,
- ),
data: MatZnx::alloc(n.into(), dnum.into(), rank_in.into(), 1, k.0.div_ceil(base2k.0) as usize),
k,
base2k,
dsize,
@@ -181,13 +175,7 @@ impl GGLWECompressed<Vec<u8>> {
dsize.0,
);
- MatZnx::bytes_of(
-     n.into(),
-     dnum.into(),
-     rank_in.into(),
-     1,
-     k.0.div_ceil(base2k.0) as usize,
- )
MatZnx::bytes_of(n.into(), dnum.into(), rank_in.into(), 1, k.0.div_ceil(base2k.0) as usize)
}
}

View File

@@ -127,13 +127,7 @@ impl GGSWCompressed<Vec<u8>> {
);
GGSWCompressed {
- data: MatZnx::alloc(
-     n.into(),
-     dnum.into(),
-     (rank + 1).into(),
-     1,
-     k.0.div_ceil(base2k.0) as usize,
- ),
data: MatZnx::alloc(n.into(), dnum.into(), (rank + 1).into(), 1, k.0.div_ceil(base2k.0) as usize),
k,
base2k,
dsize,
@@ -171,13 +165,7 @@ impl GGSWCompressed<Vec<u8>> {
dsize.0,
);
- MatZnx::bytes_of(
-     n.into(),
-     dnum.into(),
-     (rank + 1).into(),
-     1,
-     k.0.div_ceil(base2k.0) as usize,
- )
MatZnx::bytes_of(n.into(), dnum.into(), (rank + 1).into(), 1, k.0.div_ceil(base2k.0) as usize)
}
}

View File

@@ -95,15 +95,7 @@ impl GLWETensorKeyCompressed<Vec<u8>> {
pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self {
let pairs: u32 = (((rank.as_u32() + 1) * rank.as_u32()) >> 1).max(1);
- GLWETensorKeyCompressed(GGLWECompressed::alloc(
-     n,
-     base2k,
-     k,
-     Rank(pairs),
-     rank,
-     dnum,
-     dsize,
- ))
GLWETensorKeyCompressed(GGLWECompressed::alloc(n, base2k, k, Rank(pairs), rank, dnum, dsize))
}
pub fn bytes_of_from_infos<A>(infos: &A) -> usize

View File

@@ -100,13 +100,7 @@ impl GLWEToLWESwitchingKeyCompressed<Vec<u8>> {
1,
"dsize > 1 is unsupported for GLWEToLWESwitchingKeyCompressed"
);
- Self::alloc(
-     infos.n(),
-     infos.base2k(),
-     infos.k(),
-     infos.rank_in(),
-     infos.dnum(),
- )
Self::alloc(infos.n(), infos.base2k(), infos.k(), infos.rank_in(), infos.dnum())
}
pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> Self {

View File

@@ -88,11 +88,7 @@ impl LWESwitchingKeyCompressed<Vec<u8>> {
where
A: GGLWEInfos,
{
- assert_eq!(
-     infos.dsize().0,
-     1,
-     "dsize > 1 is not supported for LWESwitchingKeyCompressed"
- );
assert_eq!(infos.dsize().0, 1, "dsize > 1 is not supported for LWESwitchingKeyCompressed");
assert_eq!(
infos.rank_in().0,
1,
@@ -122,11 +118,7 @@ impl LWESwitchingKeyCompressed<Vec<u8>> {
where
A: GGLWEInfos,
{
- assert_eq!(
-     infos.dsize().0,
-     1,
-     "dsize > 1 is not supported for LWESwitchingKeyCompressed"
- );
assert_eq!(infos.dsize().0, 1, "dsize > 1 is not supported for LWESwitchingKeyCompressed");
assert_eq!(
infos.rank_in().0,
1,

View File

@@ -98,13 +98,7 @@ impl LWEToGLWEKeyCompressed<Vec<u8>> {
1,
"rank_in > 1 is not supported for LWEToGLWESwitchingKeyCompressed"
);
- Self::alloc(
-     infos.n(),
-     infos.base2k(),
-     infos.k(),
-     infos.rank_out(),
-     infos.dnum(),
- )
Self::alloc(infos.n(), infos.base2k(), infos.k(), infos.rank_out(), infos.dnum())
}
pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> Self {

View File

@@ -129,13 +129,7 @@ impl<D: DataRef> fmt::Debug for GLWE<D> {
impl<D: DataRef> fmt::Display for GLWE<D> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- write!(
-     f,
-     "GLWE: base2k={} k={}: {}",
-     self.base2k().0,
-     self.k().0,
-     self.data
- )
write!(f, "GLWE: base2k={} k={}: {}", self.base2k().0, self.k().0, self.data)
}
}

View File

@@ -73,13 +73,7 @@ impl<D: Data> GLWEInfos for GLWEPlaintext<D> {
impl<D: DataRef> fmt::Display for GLWEPlaintext<D> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- write!(
-     f,
-     "GLWEPlaintext: base2k={} k={}: {}",
-     self.base2k().0,
-     self.k().0,
-     self.data
- )
write!(f, "GLWEPlaintext: base2k={} k={}: {}", self.base2k().0, self.k().0, self.data)
}
}

View File

@@ -10,7 +10,7 @@ use poulpy_hal::{
};
use crate::{
- ScratchTakeCore,
GetDistribution, ScratchTakeCore,
dist::Distribution,
layouts::{
Base2K, Degree, GLWEInfos, GLWESecret, GLWESecretPreparedFactory, GLWESecretToMut, GLWESecretToRef, LWEInfos, Rank,
@@ -30,6 +30,12 @@ impl GLWESecretTensor<Vec<u8>> {
}
}
impl<D: Data> GetDistribution for GLWESecretTensor<D> {
fn dist(&self) -> &Distribution {
&self.dist
}
}
impl<D: Data> LWEInfos for GLWESecretTensor<D> {
fn base2k(&self) -> Base2K {
Base2K(0)
@@ -204,12 +210,14 @@ where
let idx: usize = i * rank + j - (i * (i + 1) / 2);
self.svp_apply_dft_to_dft(&mut a_ij_dft, 0, &a_prepared.data, j, &a_dft, i);
self.vec_znx_idft_apply_tmpa(&mut a_ij_big, 0, &mut a_ij_dft, 0);
self.vec_znx_big_normalize(
- base2k,
&mut res.data.as_vec_znx_mut(),
- idx,
base2k,
0,
idx,
&a_ij_big,
base2k,
0,
scratch_4,
);

View File

@@ -3,7 +3,7 @@ use poulpy_hal::{
source::Source,
};
- use crate::layouts::{Base2K, Degree, GLWEInfos, LWEInfos, Rank, SetGLWEInfos, TorusPrecision};
use crate::layouts::{Base2K, Degree, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, Rank, SetGLWEInfos, TorusPrecision};
use std::fmt;
#[derive(PartialEq, Eq, Clone)]
@@ -68,13 +68,7 @@ impl<D: DataRef> fmt::Debug for GLWETensor<D> {
impl<D: DataRef> fmt::Display for GLWETensor<D> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- write!(
-     f,
-     "GLWETensor: base2k={} k={}: {}",
-     self.base2k().0,
-     self.k().0,
-     self.data
- )
write!(f, "GLWETensor: base2k={} k={}: {}", self.base2k().0, self.k().0, self.data)
}
}
@@ -93,9 +87,10 @@ impl GLWETensor<Vec<u8>> {
}
pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> Self {
- let pairs: usize = (((rank + 1) * rank).as_usize() >> 1).max(1);
let cols: usize = rank.as_usize() + 1;
let pairs: usize = (((cols + 1) * cols) >> 1).max(1);
GLWETensor {
- data: VecZnx::alloc(n.into(), pairs + 1, k.0.div_ceil(base2k.0) as usize),
data: VecZnx::alloc(n.into(), pairs, k.0.div_ceil(base2k.0) as usize),
base2k,
k,
rank,
@@ -110,36 +105,27 @@ impl GLWETensor<Vec<u8>> {
}
pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> usize {
- let pairs: usize = (((rank + 1) * rank).as_usize() >> 1).max(1);
- VecZnx::bytes_of(n.into(), pairs + 1, k.0.div_ceil(base2k.0) as usize)
let cols: usize = rank.as_usize() + 1;
let pairs: usize = (((cols + 1) * cols) >> 1).max(1);
VecZnx::bytes_of(n.into(), pairs, k.0.div_ceil(base2k.0) as usize)
}
}
- pub trait GLWETensorToRef {
-     fn to_ref(&self) -> GLWETensor<&[u8]>;
- }
- impl<D: DataRef> GLWETensorToRef for GLWETensor<D> {
-     fn to_ref(&self) -> GLWETensor<&[u8]> {
-         GLWETensor {
impl<D: DataRef> GLWEToRef for GLWETensor<D> {
fn to_ref(&self) -> GLWE<&[u8]> {
GLWE {
k: self.k,
base2k: self.base2k,
data: self.data.to_ref(),
- rank: self.rank,
}
}
}
- pub trait GLWETensorToMut {
-     fn to_mut(&mut self) -> GLWETensor<&mut [u8]>;
- }
- impl<D: DataMut> GLWETensorToMut for GLWETensor<D> {
-     fn to_mut(&mut self) -> GLWETensor<&mut [u8]> {
-         GLWETensor {
impl<D: DataMut> GLWEToMut for GLWETensor<D> {
fn to_mut(&mut self) -> GLWE<&mut [u8]> {
GLWE {
k: self.k,
base2k: self.base2k,
- rank: self.rank,
data: self.data.to_mut(),
}
}
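The tensor ciphertext now reserves one column per unordered pair over the extended basis (1, s_1, ..., s_r), so the body column is counted inside `pairs` instead of being appended with `+ 1`. A quick sanity check of the new count (illustrative helper, not part of the diff):

    // cols = rank + 1; pairs = cols * (cols + 1) / 2
    // rank = 1 -> 3 columns (1, s, s^2); rank = 2 -> 6; rank = 3 -> 10
    fn tensor_pairs(rank: usize) -> usize {
        let cols = rank + 1;
        (((cols + 1) * cols) >> 1).max(1)
    }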

View File

@@ -48,7 +48,9 @@ impl<D: Data> GLWEInfos for GLWETensorKey<D> {
impl<D: Data> GGLWEInfos for GLWETensorKey<D> {
fn rank_in(&self) -> Rank {
- self.rank_out()
let rank_out: usize = self.rank_out().as_usize();
let pairs: usize = (((rank_out + 1) * rank_out) >> 1).max(1);
pairs.into()
}
fn rank_out(&self) -> Rank {
@@ -86,7 +88,9 @@ impl GLWEInfos for GLWETensorKeyLayout {
impl GGLWEInfos for GLWETensorKeyLayout {
fn rank_in(&self) -> Rank {
- self.rank
let rank_out: usize = self.rank_out().as_usize();
let pairs: usize = (((rank_out + 1) * rank_out) >> 1).max(1);
pairs.into()
}
fn dsize(&self) -> Dsize {
@@ -127,11 +131,6 @@ impl GLWETensorKey<Vec<u8>> {
where
A: GGLWEInfos,
{
- assert_eq!(
-     infos.rank_in(),
-     infos.rank_out(),
-     "rank_in != rank_out is not supported for GGLWETensorKey"
- );
Self::alloc(
infos.n(),
infos.base2k(),
@@ -151,11 +150,6 @@ impl GLWETensorKey<Vec<u8>> {
where
A: GGLWEInfos,
{
- assert_eq!(
-     infos.rank_in(),
-     infos.rank_out(),
-     "rank_in != rank_out is not supported for GGLWETensorKey"
- );
Self::bytes_of(
infos.n(),
infos.base2k(),

View File

@@ -137,58 +137,22 @@ impl GLWEToLWEKey<Vec<u8>> {
where
A: GGLWEInfos,
{
- assert_eq!(
-     infos.rank_out().0,
-     1,
-     "rank_out > 1 is not supported for GLWEToLWEKey"
- );
- assert_eq!(
-     infos.dsize().0,
-     1,
-     "dsize > 1 is not supported for GLWEToLWEKey"
- );
- Self::alloc(
-     infos.n(),
-     infos.base2k(),
-     infos.k(),
-     infos.rank_in(),
-     infos.dnum(),
- )
assert_eq!(infos.rank_out().0, 1, "rank_out > 1 is not supported for GLWEToLWEKey");
assert_eq!(infos.dsize().0, 1, "dsize > 1 is not supported for GLWEToLWEKey");
Self::alloc(infos.n(), infos.base2k(), infos.k(), infos.rank_in(), infos.dnum())
}
pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> Self {
- GLWEToLWEKey(GLWESwitchingKey::alloc(
-     n,
-     base2k,
-     k,
-     rank_in,
-     Rank(1),
-     dnum,
-     Dsize(1),
- ))
GLWEToLWEKey(GLWESwitchingKey::alloc(n, base2k, k, rank_in, Rank(1), dnum, Dsize(1)))
}
pub fn bytes_of_from_infos<A>(infos: &A) -> usize
where
A: GGLWEInfos,
{
- assert_eq!(
-     infos.rank_out().0,
-     1,
-     "rank_out > 1 is not supported for GLWEToLWEKey"
- );
- assert_eq!(
-     infos.dsize().0,
-     1,
-     "dsize > 1 is not supported for GLWEToLWEKey"
- );
- Self::bytes_of(
-     infos.n(),
-     infos.base2k(),
-     infos.k(),
-     infos.rank_in(),
-     infos.dnum(),
- )
assert_eq!(infos.rank_out().0, 1, "rank_out > 1 is not supported for GLWEToLWEKey");
assert_eq!(infos.dsize().0, 1, "dsize > 1 is not supported for GLWEToLWEKey");
Self::bytes_of(infos.n(), infos.base2k(), infos.k(), infos.rank_in(), infos.dnum())
}
pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> usize {

View File

@@ -96,7 +96,7 @@ impl<D: DataRef> LWE<D> {
}
impl<D: DataMut> LWE<D> {
- pub fn data_mut(&mut self) -> &VecZnx<D> {
pub fn data_mut(&mut self) -> &mut VecZnx<D> {
&mut self.data
}
}
@@ -109,13 +109,7 @@ impl<D: DataRef> fmt::Debug for LWE<D> {
impl<D: DataRef> fmt::Display for LWE<D> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- write!(
-     f,
-     "LWE: base2k={} k={}: {}",
-     self.base2k().0,
-     self.k().0,
-     self.data
- )
write!(f, "LWE: base2k={} k={}: {}", self.base2k().0, self.k().0, self.data)
}
}

View File

@@ -71,13 +71,7 @@ impl LWEPlaintext<Vec<u8>> {
impl<D: DataRef> fmt::Display for LWEPlaintext<D> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- write!(
-     f,
-     "LWEPlaintext: base2k={} k={}: {}",
-     self.base2k().0,
-     self.k().0,
-     self.data
- )
write!(f, "LWEPlaintext: base2k={} k={}: {}", self.base2k().0, self.k().0, self.data)
}
}

View File

@@ -106,55 +106,23 @@ impl LWESwitchingKey<Vec<u8>> {
where
A: GGLWEInfos,
{
- assert_eq!(
-     infos.dsize().0,
-     1,
-     "dsize > 1 is not supported for LWESwitchingKey"
- );
- assert_eq!(
-     infos.rank_in().0,
-     1,
-     "rank_in > 1 is not supported for LWESwitchingKey"
- );
- assert_eq!(
-     infos.rank_out().0,
-     1,
-     "rank_out > 1 is not supported for LWESwitchingKey"
- );
assert_eq!(infos.dsize().0, 1, "dsize > 1 is not supported for LWESwitchingKey");
assert_eq!(infos.rank_in().0, 1, "rank_in > 1 is not supported for LWESwitchingKey");
assert_eq!(infos.rank_out().0, 1, "rank_out > 1 is not supported for LWESwitchingKey");
Self::alloc(infos.n(), infos.base2k(), infos.k(), infos.dnum())
}
pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> Self {
- LWESwitchingKey(GLWESwitchingKey::alloc(
-     n,
-     base2k,
-     k,
-     Rank(1),
-     Rank(1),
-     dnum,
-     Dsize(1),
- ))
LWESwitchingKey(GLWESwitchingKey::alloc(n, base2k, k, Rank(1), Rank(1), dnum, Dsize(1)))
}
pub fn bytes_of_from_infos<A>(infos: &A) -> usize
where
A: GGLWEInfos,
{
- assert_eq!(
-     infos.dsize().0,
-     1,
-     "dsize > 1 is not supported for LWESwitchingKey"
- );
- assert_eq!(
-     infos.rank_in().0,
-     1,
-     "rank_in > 1 is not supported for LWESwitchingKey"
- );
- assert_eq!(
-     infos.rank_out().0,
-     1,
-     "rank_out > 1 is not supported for LWESwitchingKey"
- );
assert_eq!(infos.dsize().0, 1, "dsize > 1 is not supported for LWESwitchingKey");
assert_eq!(infos.rank_in().0, 1, "rank_in > 1 is not supported for LWESwitchingKey");
assert_eq!(infos.rank_out().0, 1, "rank_out > 1 is not supported for LWESwitchingKey");
Self::bytes_of(infos.n(), infos.base2k(), infos.k(), infos.dnum())
}

View File

@@ -136,59 +136,23 @@ impl LWEToGLWEKey<Vec<u8>> {
where
A: GGLWEInfos,
{
- assert_eq!(
-     infos.rank_in().0,
-     1,
-     "rank_in > 1 is not supported for LWEToGLWEKey"
- );
- assert_eq!(
-     infos.dsize().0,
-     1,
-     "dsize > 1 is not supported for LWEToGLWEKey"
- );
- Self::alloc(
-     infos.n(),
-     infos.base2k(),
-     infos.k(),
-     infos.rank_out(),
-     infos.dnum(),
- )
assert_eq!(infos.rank_in().0, 1, "rank_in > 1 is not supported for LWEToGLWEKey");
assert_eq!(infos.dsize().0, 1, "dsize > 1 is not supported for LWEToGLWEKey");
Self::alloc(infos.n(), infos.base2k(), infos.k(), infos.rank_out(), infos.dnum())
}
pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> Self {
- LWEToGLWEKey(GLWESwitchingKey::alloc(
-     n,
-     base2k,
-     k,
-     Rank(1),
-     rank_out,
-     dnum,
-     Dsize(1),
- ))
LWEToGLWEKey(GLWESwitchingKey::alloc(n, base2k, k, Rank(1), rank_out, dnum, Dsize(1)))
}
pub fn bytes_of_from_infos<A>(infos: &A) -> usize
where
A: GGLWEInfos,
{
- assert_eq!(
-     infos.rank_in().0,
-     1,
-     "rank_in > 1 is not supported for LWEToGLWEKey"
- );
- assert_eq!(
-     infos.dsize().0,
-     1,
-     "dsize > 1 is not supported for LWEToGLWEKey"
- );
- Self::bytes_of(
-     infos.n(),
-     infos.base2k(),
-     infos.k(),
-     infos.rank_out(),
-     infos.dnum(),
- )
assert_eq!(infos.rank_in().0, 1, "rank_in > 1 is not supported for LWEToGLWEKey");
assert_eq!(infos.dsize().0, 1, "dsize > 1 is not supported for LWEToGLWEKey");
Self::bytes_of(infos.n(), infos.base2k(), infos.k(), infos.rank_out(), infos.dnum())
}
pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> usize {

View File

@@ -94,13 +94,7 @@ where
infos.rank_out(),
"rank_in != rank_out is not supported for GGLWEToGGSWKeyPrepared"
);
- self.alloc_gglwe_to_ggsw_key_prepared(
-     infos.base2k(),
-     infos.k(),
-     infos.rank(),
-     infos.dnum(),
-     infos.dsize(),
- )
self.alloc_gglwe_to_ggsw_key_prepared(infos.base2k(), infos.k(), infos.rank(), infos.dnum(), infos.dsize())
}
fn alloc_gglwe_to_ggsw_key_prepared(
@@ -127,13 +121,7 @@ where
infos.rank_out(),
"rank_in != rank_out is not supported for GGLWEToGGSWKeyPrepared"
);
- self.bytes_of_gglwe_to_ggsw(
-     infos.base2k(),
-     infos.k(),
-     infos.rank(),
-     infos.dnum(),
-     infos.dsize(),
- )
self.bytes_of_gglwe_to_ggsw(infos.base2k(), infos.k(), infos.rank(), infos.dnum(), infos.dsize())
}
fn bytes_of_gglwe_to_ggsw(&self, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize {

View File

@@ -93,13 +93,7 @@ where
A: GGSWInfos,
{
assert_eq!(self.ring_degree(), infos.n());
- self.alloc_ggsw_prepared(
-     infos.base2k(),
-     infos.k(),
-     infos.dnum(),
-     infos.dsize(),
-     infos.rank(),
- )
self.alloc_ggsw_prepared(infos.base2k(), infos.k(), infos.dnum(), infos.dsize(), infos.rank())
}
fn bytes_of_ggsw_prepared(&self, base2k: Base2K, k: TorusPrecision, dnum: Dnum, dsize: Dsize, rank: Rank) -> usize {
@@ -125,13 +119,7 @@ where
A: GGSWInfos,
{
assert_eq!(self.ring_degree(), infos.n());
- self.bytes_of_ggsw_prepared(
-     infos.base2k(),
-     infos.k(),
-     infos.dnum(),
-     infos.dsize(),
-     infos.rank(),
- )
self.bytes_of_ggsw_prepared(infos.base2k(), infos.k(), infos.dnum(), infos.dsize(), infos.rank())
}
fn ggsw_prepare_tmp_bytes<A>(&self, infos: &A) -> usize

View File

@@ -17,9 +17,7 @@ where
}
fn automorphism_key_infos(&self) -> GGLWELayout {
- self.get(self.keys().next().unwrap())
-     .unwrap()
-     .gglwe_layout()
self.get(self.keys().next().unwrap()).unwrap().gglwe_layout()
}
}
@@ -110,13 +108,7 @@ where
infos.rank_out(),
"rank_in != rank_out is not supported for AutomorphismKeyPrepared"
);
- self.alloc_glwe_automorphism_key_prepared(
-     infos.base2k(),
-     infos.k(),
-     infos.rank(),
-     infos.dnum(),
-     infos.dsize(),
- )
self.alloc_glwe_automorphism_key_prepared(infos.base2k(), infos.k(), infos.rank(), infos.dnum(), infos.dsize())
}
fn bytes_of_glwe_automorphism_key_prepared(
@@ -139,13 +131,7 @@ where
infos.rank_out(),
"rank_in != rank_out is not supported for AutomorphismKeyPrepared"
);
- self.bytes_of_glwe_automorphism_key_prepared(
-     infos.base2k(),
-     infos.k(),
-     infos.rank(),
-     infos.dnum(),
-     infos.dsize(),
- )
self.bytes_of_glwe_automorphism_key_prepared(infos.base2k(), infos.k(), infos.rank(), infos.dnum(), infos.dsize())
}
fn prepare_glwe_automorphism_key_tmp_bytes<A>(&self, infos: &A) -> usize

View File

@@ -86,7 +86,6 @@ where
{
let mut res: GLWESecretPrepared<&mut [u8], _> = res.to_mut();
let other: GLWESecret<&[u8]> = other.to_ref();
-
for i in 0..res.rank().into() {
self.svp_prepare(&mut res.data, i, &other.data, i);
}

View File

@@ -0,0 +1,180 @@
use poulpy_hal::{
api::{SvpPPolAlloc, SvpPPolBytesOf},
layouts::{Backend, Data, DataMut, DataRef, Module, SvpPPol, SvpPPolToMut, SvpPPolToRef, ZnxInfos},
};
use crate::{
GetDistribution, GetDistributionMut,
dist::Distribution,
layouts::{
Base2K, Degree, GLWEInfos, GLWESecretPrepared, GLWESecretPreparedFactory, GLWESecretPreparedToMut,
GLWESecretPreparedToRef, GLWESecretTensor, GLWESecretToRef, GetDegree, LWEInfos, Rank, TorusPrecision,
},
};
pub struct GLWESecretTensorPrepared<D: Data, B: Backend> {
pub(crate) data: SvpPPol<D, B>,
pub(crate) rank: Rank,
pub(crate) dist: Distribution,
}
impl<D: DataRef, BE: Backend> GetDistribution for GLWESecretTensorPrepared<D, BE> {
fn dist(&self) -> &Distribution {
&self.dist
}
}
impl<D: DataMut, BE: Backend> GetDistributionMut for GLWESecretTensorPrepared<D, BE> {
fn dist_mut(&mut self) -> &mut Distribution {
&mut self.dist
}
}
impl<D: Data, B: Backend> LWEInfos for GLWESecretTensorPrepared<D, B> {
fn base2k(&self) -> Base2K {
Base2K(0)
}
fn k(&self) -> TorusPrecision {
TorusPrecision(0)
}
fn n(&self) -> Degree {
Degree(self.data.n() as u32)
}
fn size(&self) -> usize {
self.data.size()
}
}
impl<D: Data, B: Backend> GLWEInfos for GLWESecretTensorPrepared<D, B> {
fn rank(&self) -> Rank {
self.rank
}
}
pub trait GLWESecretTensorPreparedFactory<B: Backend> {
fn alloc_glwe_secret_tensor_prepared(&self, rank: Rank) -> GLWESecretTensorPrepared<Vec<u8>, B>;
fn alloc_glwe_secret_tensor_prepared_from_infos<A>(&self, infos: &A) -> GLWESecretTensorPrepared<Vec<u8>, B>
where
A: GLWEInfos;
fn bytes_of_glwe_secret_tensor_prepared(&self, rank: Rank) -> usize;
fn bytes_of_glwe_secret_tensor_prepared_from_infos<A>(&self, infos: &A) -> usize
where
A: GLWEInfos;
fn prepare_glwe_secret_tensor<R, O>(&self, res: &mut R, other: &O)
where
R: GLWESecretPreparedToMut<B> + GetDistributionMut,
O: GLWESecretToRef + GetDistribution;
}
impl<B: Backend> GLWESecretTensorPreparedFactory<B> for Module<B>
where
Self: GLWESecretPreparedFactory<B>,
{
fn alloc_glwe_secret_tensor_prepared(&self, rank: Rank) -> GLWESecretTensorPrepared<Vec<u8>, B> {
GLWESecretTensorPrepared {
data: self.svp_ppol_alloc(GLWESecretTensor::pairs(rank.into())),
rank,
dist: Distribution::NONE,
}
}
fn alloc_glwe_secret_tensor_prepared_from_infos<A>(&self, infos: &A) -> GLWESecretTensorPrepared<Vec<u8>, B>
where
A: GLWEInfos,
{
assert_eq!(self.ring_degree(), infos.n());
self.alloc_glwe_secret_tensor_prepared(infos.rank())
}
fn bytes_of_glwe_secret_tensor_prepared(&self, rank: Rank) -> usize {
self.bytes_of_svp_ppol(GLWESecretTensor::pairs(rank.into()))
}
fn bytes_of_glwe_secret_tensor_prepared_from_infos<A>(&self, infos: &A) -> usize
where
A: GLWEInfos,
{
assert_eq!(self.ring_degree(), infos.n());
self.bytes_of_glwe_secret_tensor_prepared(infos.rank())
}
fn prepare_glwe_secret_tensor<R, O>(&self, res: &mut R, other: &O)
where
R: GLWESecretPreparedToMut<B> + GetDistributionMut,
O: GLWESecretToRef + GetDistribution,
{
self.prepare_glwe_secret(res, other);
}
}
impl<B: Backend> GLWESecretTensorPrepared<Vec<u8>, B> {
pub fn alloc_from_infos<A, M>(module: &M, infos: &A) -> Self
where
A: GLWEInfos,
M: GLWESecretTensorPreparedFactory<B>,
{
module.alloc_glwe_secret_tensor_prepared_from_infos(infos)
}
pub fn alloc<M>(module: &M, rank: Rank) -> Self
where
M: GLWESecretTensorPreparedFactory<B>,
{
module.alloc_glwe_secret_tensor_prepared(rank)
}
pub fn bytes_of_from_infos<A, M>(module: &M, infos: &A) -> usize
where
A: GLWEInfos,
M: GLWESecretTensorPreparedFactory<B>,
{
module.bytes_of_glwe_secret_tensor_prepared_from_infos(infos)
}
pub fn bytes_of<M>(module: &M, rank: Rank) -> usize
where
M: GLWESecretTensorPreparedFactory<B>,
{
module.bytes_of_glwe_secret_tensor_prepared(rank)
}
}
impl<D: Data, B: Backend> GLWESecretTensorPrepared<D, B> {
pub fn n(&self) -> Degree {
Degree(self.data.n() as u32)
}
pub fn rank(&self) -> Rank {
Rank(self.data.cols() as u32)
}
}
impl<D: DataMut, B: Backend> GLWESecretTensorPrepared<D, B> {
pub fn prepare<M, O>(&mut self, module: &M, other: &O)
where
M: GLWESecretTensorPreparedFactory<B>,
O: GLWESecretToRef + GetDistribution,
{
module.prepare_glwe_secret_tensor(self, other);
}
}
impl<D: DataRef, B: Backend> GLWESecretPreparedToRef<B> for GLWESecretTensorPrepared<D, B> {
fn to_ref(&self) -> GLWESecretPrepared<&[u8], B> {
GLWESecretPrepared {
data: self.data.to_ref(),
dist: self.dist,
}
}
}
impl<D: DataMut, B: Backend> GLWESecretPreparedToMut<B> for GLWESecretTensorPrepared<D, B> {
fn to_mut(&mut self) -> GLWESecretPrepared<&mut [u8], B> {
GLWESecretPrepared {
dist: self.dist,
data: self.data.to_mut(),
}
}
}
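A minimal allocation-and-preparation sketch for the new prepared tensor secret; the module construction and the origin of `sk_tensor` are assumptions, not code from this commit:

    // let module = ...;                        // any Module<B> providing the required factories
    // let mut sk_t_prep = GLWESecretTensorPrepared::alloc(&module, Rank(2));
    // sk_t_prep.prepare(&module, &sk_tensor);  // sk_tensor: a GLWESecretTensor produced elsewhere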

View File

@@ -34,7 +34,7 @@ impl<D: Data, B: Backend> GLWEInfos for GLWETensorKeyPrepared<D, B> {
impl<D: Data, B: Backend> GGLWEInfos for GLWETensorKeyPrepared<D, B> {
fn rank_in(&self) -> Rank {
- self.rank_out()
self.0.rank_in()
}
fn rank_out(&self) -> Rank {
@@ -70,18 +70,7 @@ where
where
A: GGLWEInfos,
{
- assert_eq!(
-     infos.rank_in(),
-     infos.rank_out(),
-     "rank_in != rank_out is not supported for TensorKeyPrepared"
- );
- self.alloc_tensor_key_prepared(
-     infos.base2k(),
-     infos.k(),
-     infos.dnum(),
-     infos.dsize(),
-     infos.rank_out(),
- )
self.alloc_tensor_key_prepared(infos.base2k(), infos.k(), infos.dnum(), infos.dsize(), infos.rank_out())
}
fn bytes_of_tensor_key_prepared(&self, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize {
@@ -93,13 +82,7 @@ where
where
A: GGLWEInfos,
{
- self.bytes_of_tensor_key_prepared(
-     infos.base2k(),
-     infos.k(),
-     infos.rank(),
-     infos.dnum(),
-     infos.dsize(),
- )
self.bytes_of_tensor_key_prepared(infos.base2k(), infos.k(), infos.rank(), infos.dnum(), infos.dsize())
}
fn prepare_tensor_key_tmp_bytes<A>(&self, infos: &A) -> usize
fn prepare_tensor_key_tmp_bytes<A>(&self, infos: &A) -> usize fn prepare_tensor_key_tmp_bytes<A>(&self, infos: &A) -> usize

View File

@@ -73,11 +73,7 @@ where
1,
"rank_out > 1 is not supported for GLWEToLWEKeyPrepared"
);
- debug_assert_eq!(
-     infos.dsize().0,
-     1,
-     "dsize > 1 is not supported for GLWEToLWEKeyPrepared"
- );
debug_assert_eq!(infos.dsize().0, 1, "dsize > 1 is not supported for GLWEToLWEKeyPrepared");
self.alloc_glwe_to_lwe_key_prepared(infos.base2k(), infos.k(), infos.rank_in(), infos.dnum())
}
@@ -94,11 +90,7 @@ where
1,
"rank_out > 1 is not supported for GLWEToLWEKeyPrepared"
);
- debug_assert_eq!(
-     infos.dsize().0,
-     1,
-     "dsize > 1 is not supported for GLWEToLWEKeyPrepared"
- );
debug_assert_eq!(infos.dsize().0, 1, "dsize > 1 is not supported for GLWEToLWEKeyPrepared");
self.bytes_of_glwe_to_lwe_key_prepared(infos.base2k(), infos.k(), infos.rank_in(), infos.dnum())
}
} }

View File

@@ -67,21 +67,9 @@ where
where
A: GGLWEInfos,
{
- debug_assert_eq!(
-     infos.dsize().0,
-     1,
-     "dsize > 1 is not supported for LWESwitchingKey"
- );
- debug_assert_eq!(
-     infos.rank_in().0,
-     1,
-     "rank_in > 1 is not supported for LWESwitchingKey"
- );
- debug_assert_eq!(
-     infos.rank_out().0,
-     1,
-     "rank_out > 1 is not supported for LWESwitchingKey"
- );
debug_assert_eq!(infos.dsize().0, 1, "dsize > 1 is not supported for LWESwitchingKey");
debug_assert_eq!(infos.rank_in().0, 1, "rank_in > 1 is not supported for LWESwitchingKey");
debug_assert_eq!(infos.rank_out().0, 1, "rank_out > 1 is not supported for LWESwitchingKey");
self.alloc_lwe_switching_key_prepared(infos.base2k(), infos.k(), infos.dnum())
}
@@ -93,21 +81,9 @@ where
where
A: GGLWEInfos,
{
- debug_assert_eq!(
-     infos.dsize().0,
-     1,
-     "dsize > 1 is not supported for LWESwitchingKey"
- );
- debug_assert_eq!(
-     infos.rank_in().0,
-     1,
-     "rank_in > 1 is not supported for LWESwitchingKey"
- );
- debug_assert_eq!(
-     infos.rank_out().0,
-     1,
-     "rank_out > 1 is not supported for LWESwitchingKey"
- );
debug_assert_eq!(infos.dsize().0, 1, "dsize > 1 is not supported for LWESwitchingKey");
debug_assert_eq!(infos.rank_in().0, 1, "rank_in > 1 is not supported for LWESwitchingKey");
debug_assert_eq!(infos.rank_out().0, 1, "rank_out > 1 is not supported for LWESwitchingKey");
self.bytes_of_lwe_switching_key_prepared(infos.base2k(), infos.k(), infos.dnum())
}

View File

@@ -69,16 +69,8 @@ where
where
A: GGLWEInfos,
{
- debug_assert_eq!(
-     infos.rank_in().0,
-     1,
-     "rank_in > 1 is not supported for LWEToGLWEKey"
- );
- debug_assert_eq!(
-     infos.dsize().0,
-     1,
-     "dsize > 1 is not supported for LWEToGLWEKey"
- );
debug_assert_eq!(infos.rank_in().0, 1, "rank_in > 1 is not supported for LWEToGLWEKey");
debug_assert_eq!(infos.dsize().0, 1, "dsize > 1 is not supported for LWEToGLWEKey");
self.alloc_lwe_to_glwe_key_prepared(infos.base2k(), infos.k(), infos.rank_out(), infos.dnum())
}
@@ -90,16 +82,8 @@ where
where
A: GGLWEInfos,
{
- debug_assert_eq!(
-     infos.rank_in().0,
-     1,
-     "rank_in > 1 is not supported for LWEToGLWEKey"
- );
- debug_assert_eq!(
-     infos.dsize().0,
-     1,
-     "dsize > 1 is not supported for LWEToGLWEKey"
- );
debug_assert_eq!(infos.rank_in().0, 1, "rank_in > 1 is not supported for LWEToGLWEKey");
debug_assert_eq!(infos.dsize().0, 1, "dsize > 1 is not supported for LWEToGLWEKey");
self.bytes_of_lwe_to_glwe_key_prepared(infos.base2k(), infos.k(), infos.rank_out(), infos.dnum())
}
} }

View File

@@ -5,6 +5,7 @@ mod glwe;
mod glwe_automorphism_key;
mod glwe_public_key;
mod glwe_secret;
mod glwe_secret_tensor;
mod glwe_switching_key;
mod glwe_tensor_key;
mod glwe_to_lwe_key;
@@ -18,6 +19,7 @@ pub use glwe::*;
pub use glwe_automorphism_key::*;
pub use glwe_public_key::*;
pub use glwe_secret::*;
pub use glwe_secret_tensor::*;
pub use glwe_switching_key::*;
pub use glwe_tensor_key::*;
pub use glwe_to_lwe_key::*;

View File

@@ -78,13 +78,7 @@ where
let dsize: usize = res.dsize().into();
let (mut pt, scratch_1) = scratch.take_glwe_plaintext(res);
pt.data_mut().zero();
- self.vec_znx_add_scalar_inplace(
-     &mut pt.data,
-     0,
-     (dsize - 1) + res_row * dsize,
-     pt_want,
-     res_col,
- );
self.vec_znx_add_scalar_inplace(&mut pt.data, 0, (dsize - 1) + res_row * dsize, pt_want, res_col);
self.glwe_noise(&res.at(res_row, res_col), &pt, sk_prepared, scratch_1)
}
}
} }

View File

@@ -102,7 +102,7 @@ where
self.vec_znx_dft_apply(1, 0, &mut pt_dft, 0, &pt.data, 0);
self.svp_apply_dft_to_dft_inplace(&mut pt_dft, 0, &sk_prepared.data, res_col - 1);
let pt_big = self.vec_znx_idft_apply_consume(pt_dft);
- self.vec_znx_big_normalize(base2k, &mut pt.data, 0, base2k, &pt_big, 0, scratch_2);
self.vec_znx_big_normalize(&mut pt.data, base2k, 0, 0, &pt_big, base2k, 0, scratch_2);
}
self.glwe_noise(&res.at(res_row, res_col), &pt, sk_prepared, scratch_1)

View File

@@ -38,10 +38,7 @@ where
where
A: GLWEInfos,
{
- GLWEPlaintext::bytes_of_from_infos(infos)
-     + self
-         .glwe_normalize_tmp_bytes()
-         .max(self.glwe_decrypt_tmp_bytes(infos))
GLWEPlaintext::bytes_of_from_infos(infos) + self.glwe_normalize_tmp_bytes().max(self.glwe_decrypt_tmp_bytes(infos))
}
fn glwe_noise<R, P, S>(&self, res: &R, pt_want: &P, sk_prepared: &S, scratch: &mut Scratch<BE>) -> Stats
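The `*_tmp_bytes` helpers throughout this commit follow one sizing rule: scratch regions that are alive at the same time are summed, while regions with disjoint lifetimes share space via `max`. A generic illustration of the pattern (not code from this diff):

    // one persistent buffer plus the larger of two mutually exclusive phases
    fn tmp_bytes(persistent: usize, phase_a: usize, phase_b: usize) -> usize {
        persistent + phase_a.max(phase_b)
    }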

View File

@@ -1,80 +1,568 @@
use poulpy_hal::{
api::{
- BivariateTensoring, ModuleN, ScratchTakeBasic, VecZnxAdd, VecZnxAddInplace, VecZnxBigNormalize, VecZnxCopy,
CnvPVecBytesOf, Convolution, ModuleN, ScratchTakeBasic, VecZnxAdd, VecZnxAddInplace, VecZnxBigAddSmallInplace,
VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftApply, VecZnxDftBytesOf,
VecZnxIdftApplyConsume, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegate, VecZnxNormalize,
VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub,
VecZnxSubInplace, VecZnxSubNegateInplace, VecZnxZero,
},
- layouts::{Backend, Module, Scratch, VecZnx, VecZnxBig, ZnxInfos},
layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx, VecZnxBig},
reference::vec_znx::vec_znx_rotate_inplace_tmp_bytes,
};
use crate::{
- ScratchTakeCore,
GGLWEProduct, ScratchTakeCore,
layouts::{
- GLWE, GLWEInfos, GLWEPrepared, GLWEPreparedToRef, GLWETensor, GLWETensorToMut, GLWEToMut, GLWEToRef, LWEInfos,
Base2K, GGLWEInfos, GLWE, GLWEInfos, GLWEPlaintext, GLWETensor, GLWETensorKeyPrepared, GLWEToMut, GLWEToRef, LWEInfos,
TorusPrecision,
},
};
- pub trait GLWETensoring<BE: Backend>
- where
-     Self: BivariateTensoring<BE> + VecZnxIdftApplyConsume<BE> + VecZnxBigNormalize<BE>,
-     Scratch<BE>: ScratchTakeCore<BE>,
- {
-     /// res = (a (x) b) * 2^{k * a_base2k}
-     ///
-     /// # Requires
-     /// * a.base2k() == b.base2k()
-     /// * res.cols() >= a.cols() + b.cols() - 1
-     ///
-     /// # Behavior
-     /// * res precision is truncated to res.max_k().min(a.max_k() + b.max_k() + k * a_base2k)
-     fn glwe_tensor<R, A, B>(&self, k: i64, res: &mut R, a: &A, b: &B, scratch: &mut Scratch<BE>)
-     where
-         R: GLWETensorToMut,
-         A: GLWEToRef,
-         B: GLWEPreparedToRef<BE>,
-     {
-         let res: &mut GLWETensor<&mut [u8]> = &mut res.to_mut();
-         let a: &GLWE<&[u8]> = &a.to_ref();
-         let b: &GLWEPrepared<&[u8], BE> = &b.to_ref();
-         assert_eq!(a.base2k(), b.base2k());
-         assert_eq!(a.rank(), res.rank());
-         let res_cols: usize = res.data.cols();
-         // Get tmp buffer of min precision between a_prec * b_prec and res_prec
-         let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self, res_cols, res.max_k().div_ceil(a.base2k()) as usize);
-         // DFT(res) = DFT(a) (x) DFT(b)
-         self.bivariate_tensoring(k, &mut res_dft, &a.data, &b.data, scratch_1);
-         // res = IDFT(res)
-         let res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft);
-         // Normalize and switches basis if required
-         for res_col in 0..res_cols {
-             self.vec_znx_big_normalize(
-                 res.base2k().into(),
-                 &mut res.data,
-                 res_col,
-                 a.base2k().into(),
-                 &res_big,
-                 res_col,
-                 scratch_1,
-             );
-         }
-     }
-     // fn glwe_relinearize<R, A, T>(&self, res: &mut R, a: &A, tsk: &T, scratch: &mut Scratch<BE>)
-     // where
-     //     R: GLWEToRef,
-     //     A: GLWETensorToRef,
-     //     T: GLWETensorKeyPreparedToRef<BE>,
-     // {
-     // }
- }
pub trait GLWEMulConst<BE: Backend> {
fn glwe_mul_const_tmp_bytes<R, A>(&self, res: &R, res_offset: usize, a: &A, b_size: usize) -> usize
where
R: GLWEInfos,
A: GLWEInfos;
fn glwe_mul_const<R, A>(&self, res: &mut GLWE<R>, res_offset: usize, a: &GLWE<A>, b: &[i64], scratch: &mut Scratch<BE>)
where
R: DataMut,
A: DataRef;
fn glwe_mul_const_inplace<R>(&self, res: &mut GLWE<R>, res_offset: usize, b: &[i64], scratch: &mut Scratch<BE>)
where
R: DataMut;
}
impl<BE: Backend> GLWEMulConst<BE> for Module<BE>
where
Self: Convolution<BE> + VecZnxBigBytesOf + VecZnxBigNormalize<BE> + VecZnxBigNormalizeTmpBytes,
Scratch<BE>: ScratchTakeCore<BE>,
{
fn glwe_mul_const_tmp_bytes<R, A>(&self, res: &R, res_offset: usize, a: &A, b_size: usize) -> usize
where
R: GLWEInfos,
A: GLWEInfos,
{
let a_base2k: usize = a.base2k().as_usize();
let res_base2k: usize = res.base2k().as_usize();
let res_size: usize = (res.size() * res_base2k).div_ceil(a_base2k);
let res_big: usize = self.bytes_of_vec_znx_big(1, res_size);
let cnv: usize = self.cnv_by_const_apply_tmp_bytes(res_size, res_offset, a.size(), b_size);
let normalize: usize = self.vec_znx_big_normalize_tmp_bytes();
res_big + cnv.max(normalize)
}
fn glwe_mul_const<R, A>(&self, res: &mut GLWE<R>, res_offset: usize, a: &GLWE<A>, b: &[i64], scratch: &mut Scratch<BE>)
where
R: DataMut,
A: DataRef,
{
assert_eq!(res.rank(), a.rank());
let cols: usize = res.rank().as_usize() + 1;
let a_base2k: usize = a.base2k().as_usize();
let res_base2k: usize = res.base2k().as_usize();
let (res_offset_hi, res_offset_lo) = if res_offset < a_base2k {
(0, -((a_base2k - (res_offset % a_base2k)) as i64))
} else {
((res_offset / a_base2k).saturating_sub(1), (res_offset % a_base2k) as i64)
};
let res_dft_size = res
.k()
.as_usize()
.div_ceil(a.base2k().as_usize())
.min(a.size() + b.len() - res_offset_hi);
let (mut res_big, scratch_1) = scratch.take_vec_znx_big(self, 1, res_dft_size);
for i in 0..cols {
self.cnv_by_const_apply(&mut res_big, res_offset_hi, 0, a.data(), i, b, scratch_1);
self.vec_znx_big_normalize(res.data_mut(), res_base2k, res_offset_lo, i, &res_big, a_base2k, 0, scratch_1);
}
}
fn glwe_mul_const_inplace<R>(&self, res: &mut GLWE<R>, res_offset: usize, b: &[i64], scratch: &mut Scratch<BE>)
where
R: DataMut,
{
let cols: usize = res.rank().as_usize() + 1;
let res_base2k: usize = res.base2k().as_usize();
let (res_offset_hi, res_offset_lo) = if res_offset < res_base2k {
(0, -((res_base2k - (res_offset % res_base2k)) as i64))
} else {
((res_offset / res_base2k).saturating_sub(1), (res_offset % res_base2k) as i64)
};
let (mut res_big, scratch_1) = scratch.take_vec_znx_big(self, 1, res.size());
for i in 0..cols {
self.cnv_by_const_apply(&mut res_big, res_offset_hi, 0, res.data(), i, b, scratch_1);
self.vec_znx_big_normalize(
res.data_mut(),
res_base2k,
res_offset_lo,
i,
&res_big,
res_base2k,
0,
scratch_1,
);
}
}
}
impl<BE: Backend> GLWEMulPlain<BE> for Module<BE>
where
Self: Sized
+ ModuleN
+ CnvPVecBytesOf
+ VecZnxDftBytesOf
+ VecZnxIdftApplyConsume<BE>
+ VecZnxBigNormalize<BE>
+ Convolution<BE>
+ VecZnxBigNormalizeTmpBytes,
Scratch<BE>: ScratchTakeCore<BE>,
{
fn glwe_mul_plain_tmp_bytes<R, A, B>(&self, res: &R, res_offset: usize, a: &A, b: &B) -> usize
where
R: GLWEInfos,
A: GLWEInfos,
B: GLWEInfos,
{
let ab_base2k: Base2K = a.base2k();
assert_eq!(b.base2k(), ab_base2k);
let cols: usize = res.rank().as_usize() + 1;
let a_size: usize = a.size();
let b_size: usize = b.size();
let res_size: usize = res.size();
let cnv_pvec: usize = self.bytes_of_cnv_pvec_left(cols, a_size) + self.bytes_of_cnv_pvec_right(1, b_size);
let cnv_prep: usize = self
.cnv_prepare_left_tmp_bytes(a_size, a_size)
.max(self.cnv_prepare_right_tmp_bytes(a_size, a_size));
let cnv_apply: usize = self.cnv_apply_dft_tmp_bytes(res_size, res_offset, a_size, b_size);
let res_dft_size = res
.k()
.as_usize()
.div_ceil(a.base2k().as_usize())
.min(a_size + b_size - res_offset / ab_base2k.as_usize());
let res_dft: usize = self.bytes_of_vec_znx_dft(1, res_dft_size);
let norm: usize = self.vec_znx_big_normalize_tmp_bytes();
cnv_pvec + cnv_prep + res_dft + cnv_apply.max(norm)
}
fn glwe_mul_plain<R, A, B>(
&self,
res: &mut GLWE<R>,
res_offset: usize,
a: &GLWE<A>,
b: &GLWEPlaintext<B>,
scratch: &mut Scratch<BE>,
) where
R: DataMut,
A: DataRef,
B: DataRef,
{
assert_eq!(res.rank(), a.rank());
let a_base2k: usize = a.base2k().as_usize();
assert_eq!(b.base2k().as_usize(), a_base2k);
let res_base2k: usize = res.base2k().as_usize();
let cols: usize = res.rank().as_usize() + 1;
let (mut a_prep, scratch_1) = scratch.take_cnv_pvec_left(self, cols, a.size());
let (mut b_prep, scratch_2) = scratch_1.take_cnv_pvec_right(self, 1, b.size());
self.cnv_prepare_left(&mut a_prep, a.data(), scratch_2);
self.cnv_prepare_right(&mut b_prep, b.data(), scratch_2);
let (res_offset_hi, res_offset_lo) = if res_offset < a_base2k {
(0, -((a_base2k - (res_offset % a_base2k)) as i64))
} else {
((res_offset / a_base2k).saturating_sub(1), (res_offset % a_base2k) as i64)
};
let res_dft_size = res
.k()
.as_usize()
.div_ceil(a.base2k().as_usize())
.min(a.size() + b.size() - res_offset_hi);
for i in 0..cols {
let (mut res_dft, scratch_3) = scratch_2.take_vec_znx_dft(self, 1, res_dft_size);
self.cnv_apply_dft(&mut res_dft, res_offset_hi, 0, &a_prep, i, &b_prep, 0, scratch_3);
let res_big = self.vec_znx_idft_apply_consume(res_dft);
self.vec_znx_big_normalize(res.data_mut(), res_base2k, res_offset_lo, i, &res_big, a_base2k, 0, scratch_3);
}
}
fn glwe_mul_plain_inplace<R, A>(&self, res: &mut GLWE<R>, res_offset: usize, a: &GLWEPlaintext<A>, scratch: &mut Scratch<BE>)
where
R: DataMut,
A: DataRef,
{
assert_eq!(res.rank(), a.rank());
let a_base2k: usize = a.base2k().as_usize();
let res_base2k: usize = res.base2k().as_usize();
assert_eq!(res_base2k, a_base2k);
let cols: usize = res.rank().as_usize() + 1;
let (mut res_prep, scratch_1) = scratch.take_cnv_pvec_left(self, cols, res.size());
let (mut a_prep, scratch_2) = scratch_1.take_cnv_pvec_right(self, 1, a.size());
self.cnv_prepare_left(&mut res_prep, res.data(), scratch_2);
self.cnv_prepare_right(&mut a_prep, a.data(), scratch_2);
let (res_offset_hi, res_offset_lo) = if res_offset < a_base2k {
(0, -((a_base2k - (res_offset % a_base2k)) as i64))
} else {
((res_offset / a_base2k).saturating_sub(1), (res_offset % a_base2k) as i64)
};
let res_dft_size = res
.k()
.as_usize()
.div_ceil(a.base2k().as_usize())
.min(a.size() + res.size() - res_offset_hi);
for i in 0..cols {
let (mut res_dft, scratch_3) = scratch_2.take_vec_znx_dft(self, 1, res_dft_size);
self.cnv_apply_dft(&mut res_dft, res_offset_hi, 0, &res_prep, i, &a_prep, 0, scratch_3);
let res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft);
self.vec_znx_big_normalize(res.data_mut(), res_base2k, res_offset_lo, i, &res_big, a_base2k, 0, scratch_3);
}
}
}
pub trait GLWEMulPlain<BE: Backend> {
fn glwe_mul_plain_tmp_bytes<R, A, B>(&self, res: &R, res_offset: usize, a: &A, b: &B) -> usize
where
R: GLWEInfos,
A: GLWEInfos,
B: GLWEInfos;
fn glwe_mul_plain<R, A, B>(
&self,
res: &mut GLWE<R>,
res_offset: usize,
a: &GLWE<A>,
b: &GLWEPlaintext<B>,
scratch: &mut Scratch<BE>,
) where
R: DataMut,
A: DataRef,
B: DataRef;
fn glwe_mul_plain_inplace<R, A>(&self, res: &mut GLWE<R>, res_offset: usize, a: &GLWEPlaintext<A>, scratch: &mut Scratch<BE>)
where
R: DataMut,
A: DataRef;
}
pub trait GLWETensoring<BE: Backend> {
fn glwe_tensor_apply_tmp_bytes<R, A, B>(&self, res: &R, res_offset: usize, a: &A, b: &B) -> usize
where
R: GLWEInfos,
A: GLWEInfos,
B: GLWEInfos;
/// res = (a (x) b) * 2^{res_offset}
///
/// # Requires
/// * a.base2k() == b.base2k()
/// * a.rank() == b.rank()
///
/// # Behavior
/// * res precision is truncated to res.max_k().min(a.max_k() + b.max_k() + res_offset)
fn glwe_tensor_apply<R, A, B>(
&self,
res: &mut GLWETensor<R>,
res_offset: usize,
a: &GLWE<A>,
b: &GLWE<B>,
scratch: &mut Scratch<BE>,
) where
R: DataMut,
A: DataRef,
B: DataRef;
fn glwe_tensor_relinearize<R, A, B>(
&self,
res: &mut GLWE<R>,
a: &GLWETensor<A>,
tsk: &GLWETensorKeyPrepared<B, BE>,
tsk_size: usize,
scratch: &mut Scratch<BE>,
) where
R: DataMut,
A: DataRef,
B: DataRef;
fn glwe_tensor_relinearize_tmp_bytes<R, A, B>(&self, res: &R, a: &A, tsk: &B) -> usize
where
R: GLWEInfos,
A: GLWEInfos,
B: GGLWEInfos;
}
impl<BE: Backend> GLWETensoring<BE> for Module<BE>
where
Self: Sized
+ ModuleN
+ CnvPVecBytesOf
+ VecZnxDftBytesOf
+ VecZnxIdftApplyConsume<BE>
+ VecZnxBigNormalize<BE>
+ Convolution<BE>
+ VecZnxSubInplace
+ VecZnxNegate
+ VecZnxAddInplace
+ VecZnxBigNormalizeTmpBytes
+ VecZnxCopy
+ VecZnxNormalize<BE>
+ VecZnxDftApply<BE>
+ GGLWEProduct<BE>
+ VecZnxBigAddSmallInplace<BE>
+ VecZnxNormalizeTmpBytes,
Scratch<BE>: ScratchTakeCore<BE>,
{
fn glwe_tensor_apply_tmp_bytes<R, A, B>(&self, res: &R, res_offset: usize, a: &A, b: &B) -> usize
where
R: GLWEInfos,
A: GLWEInfos,
B: GLWEInfos,
{
let ab_base2k: Base2K = a.base2k();
assert_eq!(b.base2k(), ab_base2k);
let cols: usize = res.rank().as_usize() + 1;
let a_size: usize = a.size();
let b_size: usize = b.size();
let res_size: usize = res.size();
let cnv_pvec: usize = self.bytes_of_cnv_pvec_left(cols, a_size) + self.bytes_of_cnv_pvec_right(cols, b_size);
let cnv_prep: usize = self
.cnv_prepare_left_tmp_bytes(a_size, a_size)
.max(self.cnv_prepare_right_tmp_bytes(a_size, a_size));
let cnv_apply: usize = self
.cnv_apply_dft_tmp_bytes(res_size, res_offset, a_size, b_size)
.max(self.cnv_pairwise_apply_dft_tmp_bytes(res_size, res_offset, a_size, b_size));
let res_dft_size = res
.k()
.as_usize()
.div_ceil(a.base2k().as_usize())
.min(a_size + b_size - res_offset / ab_base2k.as_usize());
let res_dft: usize = self.bytes_of_vec_znx_dft(1, res_dft_size);
let tmp: usize = VecZnx::bytes_of(self.n(), 1, res.size());
let norm: usize = self.vec_znx_big_normalize_tmp_bytes();
cnv_pvec + cnv_prep + res_dft + cnv_apply.max(tmp + norm)
}
fn glwe_tensor_relinearize_tmp_bytes<R, A, B>(&self, res: &R, a: &A, tsk: &B) -> usize
where
R: GLWEInfos,
A: GLWEInfos,
B: GGLWEInfos,
{
let a_base2k: usize = a.base2k().into();
let key_base2k: usize = tsk.base2k().into();
let res_base2k: usize = res.base2k().into();
let cols: usize = tsk.rank_out().as_usize() + 1;
let pairs: usize = tsk.rank_in().as_usize();
let a_dft_size: usize = (a.size() * a_base2k).div_ceil(key_base2k);
let a_dft = self.bytes_of_vec_znx_dft(pairs, a_dft_size);
let a_conv: usize = if a_base2k != key_base2k || res_base2k != key_base2k {
VecZnx::bytes_of(self.n(), 1, a_dft_size) + self.vec_znx_normalize_tmp_bytes()
} else {
0
};
let res_dft: usize = self.bytes_of_vec_znx_dft(cols, tsk.size());
let gglwe_product: usize = self.gglwe_product_dft_tmp_bytes(res.size(), a_dft_size, tsk);
let big_normalize: usize = self.vec_znx_big_normalize_tmp_bytes();
a_dft.max(a_conv + big_normalize) + res_dft + gglwe_product.max(a_conv).max(big_normalize)
}
fn glwe_tensor_relinearize<R, A, B>(
&self,
res: &mut GLWE<R>,
a: &GLWETensor<A>,
tsk: &GLWETensorKeyPrepared<B, BE>,
tsk_size: usize,
scratch: &mut Scratch<BE>,
) where
R: DataMut,
A: DataRef,
B: DataRef,
{
let a_base2k: usize = a.base2k().into();
let key_base2k: usize = tsk.base2k().into();
let res_base2k: usize = res.base2k().into();
assert_eq!(res.rank(), tsk.rank_out());
assert_eq!(a.rank(), tsk.rank_out());
let cols: usize = tsk.rank_out().as_usize() + 1;
let pairs: usize = tsk.rank_in().as_usize();
let a_dft_size: usize = (a.size() * a_base2k).div_ceil(key_base2k);
let (mut a_dft, scratch_1) = scratch.take_vec_znx_dft(self, pairs, a_dft_size);
if a_base2k != key_base2k {
let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(self.n(), 1, a_dft_size);
for i in 0..pairs {
self.vec_znx_normalize(&mut a_conv, key_base2k, 0, 0, a.data(), a_base2k, cols + i, scratch_2);
self.vec_znx_dft_apply(1, 0, &mut a_dft, i, &a_conv, 0);
}
} else {
for i in 0..pairs {
self.vec_znx_dft_apply(1, 0, &mut a_dft, i, a.data(), cols + i);
}
}
let (mut res_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, cols, tsk_size); // Todo optimise
self.gglwe_product_dft(&mut res_dft, &a_dft, &tsk.0, scratch_2);
let mut res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft);
if res_base2k == key_base2k {
for i in 0..cols {
self.vec_znx_big_add_small_inplace(&mut res_big, i, a.data(), i);
}
} else {
let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(self.n(), 1, a_dft_size);
for i in 0..cols {
self.vec_znx_normalize(&mut a_conv, key_base2k, 0, 0, a.data(), a_base2k, i, scratch_3);
self.vec_znx_big_add_small_inplace(&mut res_big, i, &a_conv, 0);
}
}
for i in 0..(res.rank() + 1).into() {
self.vec_znx_big_normalize(res.data_mut(), res_base2k, 0, i, &res_big, key_base2k, i, scratch_2);
}
}
fn glwe_tensor_apply<R, A, B>(
&self,
res: &mut GLWETensor<R>,
res_offset: usize,
a: &GLWE<A>,
b: &GLWE<B>,
scratch: &mut Scratch<BE>,
) where
R: DataMut,
A: DataRef,
B: DataRef,
{
let a_base2k: usize = a.base2k().as_usize();
assert_eq!(b.base2k().as_usize(), a_base2k);
let res_base2k: usize = res.base2k().as_usize();
let cols: usize = res.rank().as_usize() + 1;
let (mut a_prep, scratch_1) = scratch.take_cnv_pvec_left(self, cols, a.size());
let (mut b_prep, scratch_2) = scratch_1.take_cnv_pvec_right(self, cols, b.size());
self.cnv_prepare_left(&mut a_prep, a.data(), scratch_2);
self.cnv_prepare_right(&mut b_prep, b.data(), scratch_2);
// Example for rank=3
//
// (a0, a1, a2, a3) x (b0, b1, b2, b3)
//   L   L   L   L      R   R   R   R
//
// c(1)    = a0 * b0            <- (L(a0) * R(b0))
// c(s1)   = a0 * b1 + a1 * b0  <- (L(a0) + L(a1)) * (R(b0) + R(b1)) + NEG(L(a0) * R(b0)) + SUB(L(a1) * R(b1))
// c(s2)   = a0 * b2 + a2 * b0  <- (L(a0) + L(a2)) * (R(b0) + R(b2)) + NEG(L(a0) * R(b0)) + SUB(L(a2) * R(b2))
// c(s3)   = a0 * b3 + a3 * b0  <- (L(a0) + L(a3)) * (R(b0) + R(b3)) + NEG(L(a0) * R(b0)) + SUB(L(a3) * R(b3))
// c(s1^2) = a1 * b1            <- (L(a1) * R(b1))
// c(s1s2) = a1 * b2 + a2 * b1  <- (L(a1) + L(a2)) * (R(b1) + R(b2)) + NEG(L(a1) * R(b1)) + SUB(L(a2) * R(b2))
// c(s1s3) = a1 * b3 + a3 * b1  <- (L(a1) + L(a3)) * (R(b1) + R(b3)) + NEG(L(a1) * R(b1)) + SUB(L(a3) * R(b3))
// c(s2^2) = a2 * b2            <- (L(a2) * R(b2))
// c(s2s3) = a2 * b3 + a3 * b2  <- (L(a2) + L(a3)) * (R(b2) + R(b3)) + NEG(L(a2) * R(b2)) + SUB(L(a3) * R(b3))
// c(s3^2) = a3 * b3            <- (L(a3) * R(b3))
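// The table above is the Karatsuba-style rearrangement: for i != j,
// a_i * b_j + a_j * b_i = (a_i + a_j) * (b_i + b_j) - a_i * b_i - a_j * b_j,
// so each cross term reuses the diagonal products already computed for the c(s_i^2) entries.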
// Derive the offset. If res_offset < a_base2k, then we shift to a negative offset
// since the convolution doesn't support negative offset (yet).
let (res_offset_hi, res_offset_lo) = if res_offset < a_base2k {
(0, -((a_base2k - (res_offset % a_base2k)) as i64))
} else {
((res_offset / a_base2k).saturating_sub(1), (res_offset % a_base2k) as i64)
};
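// Worked example (illustrative): with a_base2k = 13, res_offset = 20 yields
// (res_offset_hi, res_offset_lo) = (0, 7), and res_offset = 5 yields (0, -8);
// in both branches (res_offset_hi + 1) * a_base2k + res_offset_lo == res_offset.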
let res_dft_size = res
.k()
.as_usize()
.div_ceil(a.base2k().as_usize())
.min(a.size() + b.size() - res_offset_hi);
for i in 0..cols {
let col_i: usize = i * cols - (i * (i + 1) / 2);
let (mut res_dft, scratch_3) = scratch_2.take_vec_znx_dft(self, 1, res_dft_size);
self.cnv_apply_dft(&mut res_dft, res_offset_hi, 0, &a_prep, i, &b_prep, i, scratch_3);
let res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft);
let (mut tmp, scratch_4) = scratch_3.take_vec_znx(self.n(), 1, res_dft_size);
self.vec_znx_big_normalize(&mut tmp, res_base2k, res_offset_lo, 0, &res_big, a_base2k, 0, scratch_4);
self.vec_znx_copy(res.data_mut(), col_i + i, &tmp, 0);
// Pre-subtracts
// res[i!=j] = NEG(a[i] * b[i]) + SUB(a[j] * b[j])
for j in 0..cols {
if j != i {
if j < i {
let col_j = j * cols - (j * (j + 1) / 2);
self.vec_znx_sub_inplace(res.data_mut(), col_j + i, &tmp, 0);
} else {
self.vec_znx_negate(res.data_mut(), col_i + j, &tmp, 0);
}
}
}
}
for i in 0..cols {
let col_i: usize = i * cols - (i * (i + 1) / 2);
for j in i..cols {
if j != i {
// res_dft = (a[i] + a[j]) * (b[i] + b[j])
let (mut res_dft, scratch_3) = scratch_2.take_vec_znx_dft(self, 1, res.size());
self.cnv_pairwise_apply_dft(&mut res_dft, res_offset_hi, 0, &a_prep, &b_prep, i, j, scratch_3);
let res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft);
let (mut tmp, scratch_3) = scratch_3.take_vec_znx(self.n(), 1, res.size());
self.vec_znx_big_normalize(&mut tmp, res_base2k, res_offset_lo, 0, &res_big, a_base2k, 0, scratch_3);
self.vec_znx_add_inplace(res.data_mut(), col_i + j, &tmp, 0);
}
}
}
}
}
pub trait GLWEAdd
@@ -431,9 +919,7 @@ where
*tmp_slot = Some(tmp);
// Get a mutable handle to the temp and normalize into it
- let tmp_ref: &mut GLWE<&mut [u8]> = tmp_slot
-     .as_mut()
-     .expect("tmp_slot just set to Some, but found None");
let tmp_ref: &mut GLWE<&mut [u8]> = tmp_slot.as_mut().expect("tmp_slot just set to Some, but found None");
self.glwe_normalize(tmp_ref, glwe, scratch2);
@@ -470,9 +956,7 @@ where
*tmp_slot = Some(tmp);
// Get a mutable handle to the temp and normalize into it
- let tmp_ref: &mut GLWE<&mut [u8]> = tmp_slot
-     .as_mut()
-     .expect("tmp_slot just set to Some, but found None");
let tmp_ref: &mut GLWE<&mut [u8]> = tmp_slot.as_mut().expect("tmp_slot just set to Some, but found None");
self.glwe_normalize(tmp_ref, glwe, scratch2);
@@ -493,16 +977,10 @@ where
assert_eq!(a.n(), self.n() as u32);
assert_eq!(res.rank(), a.rank());
let res_base2k = res.base2k().into();
for i in 0..res.rank().as_usize() + 1 {
- self.vec_znx_normalize(
-     res.base2k().into(),
-     res.data_mut(),
-     i,
-     a.base2k().into(),
-     a.data(),
-     i,
-     scratch,
- );
self.vec_znx_normalize(res.data_mut(), res_base2k, 0, i, a.data(), a.base2k().into(), i, scratch);
}
}
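Putting the new surface together: a degree-2 multiplication is a tensor followed by a relinearization, and multiplication by a public integer polynomial goes through `glwe_mul_const`. A sketch assuming `ct_a`, `ct_b`, `ct_res`, `tsk` and a sufficiently sized `scratch` were set up as in the tests registered below (setup elided, hence commented):

    // let mut ct_t = GLWETensor::alloc(n, base2k, k_tensor, rank);
    // module.glwe_tensor_apply(&mut ct_t, 0, &ct_a, &ct_b, scratch.borrow());
    // module.glwe_tensor_relinearize(&mut ct_res, &ct_t, &tsk, tsk.size(), scratch.borrow());
    //
    // multiplication by a constant polynomial b(X) with coefficients b_coeffs: &[i64]:
    // module.glwe_mul_const(&mut ct_res, 0, &ct_a, b_coeffs, scratch.borrow());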

View File

@@ -20,6 +20,10 @@ mod poulpy_core {
glwe_encrypt_pk => crate::tests::test_suite::encryption::test_glwe_encrypt_pk,
// GLWE Base2k Conversion
glwe_base2k_conv => crate::tests::test_suite::test_glwe_base2k_conversion,
+// GLWE Tensoring
+test_glwe_tensoring => crate::tests::test_suite::glwe_tensor::test_glwe_tensoring,
+test_glwe_mul_plain => crate::tests::test_suite::glwe_tensor::test_glwe_mul_plain,
+test_glwe_mul_const => crate::tests::test_suite::glwe_tensor::test_glwe_mul_const,
// GLWE Keyswitch
glwe_keyswitch => crate::tests::test_suite::keyswitch::test_glwe_keyswitch,
glwe_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_glwe_keyswitch_inplace,
@@ -88,6 +92,10 @@ mod poulpy_core {
glwe_encrypt_pk => crate::tests::test_suite::encryption::test_glwe_encrypt_pk,
// GLWE Base2k Conversion
glwe_base2k_conv => crate::tests::test_suite::test_glwe_base2k_conversion,
+// GLWE Tensoring
+test_glwe_tensoring => crate::tests::test_suite::glwe_tensor::test_glwe_tensoring,
+test_glwe_mul_plain => crate::tests::test_suite::glwe_tensor::test_glwe_mul_plain,
+test_glwe_mul_const => crate::tests::test_suite::glwe_tensor::test_glwe_mul_const,
// GLWE Keyswitch
glwe_keyswitch => crate::tests::test_suite::keyswitch::test_glwe_keyswitch,
glwe_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_glwe_keyswitch_inplace,
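The table above registers each suite test under a per-backend name. The macro behind the `name => path` entries is not shown in this diff; a minimal sketch of the pattern (hypothetical macro, simplified to zero-argument functions, whereas the real suite passes a backend module into each function) could look like this:

```rust
// Hypothetical sketch of the `name => path` registration pattern:
// a macro that stamps out one #[test] wrapper per table entry.
macro_rules! backend_tests {
    ($($name:ident => $func:path,)*) => {
        $(
            #[test]
            fn $name() {
                $func();
            }
        )*
    };
}

fn check_nothing() {}

backend_tests! {
    smoke => check_nothing,
}
```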

View File

@@ -29,27 +29,27 @@ where
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{
-let base2k_in: usize = 17;
-let base2k_key: usize = 13;
-let base2k_out: usize = base2k_in; // MUST BE SAME
+let in_base2k: usize = 17;
+let key_base2k: usize = 13;
+let out_base2k: usize = in_base2k; // MUST BE SAME
let k_in: usize = 102;
-let max_dsize: usize = k_in.div_ceil(base2k_key);
+let max_dsize: usize = k_in.div_ceil(key_base2k);
let p0: i64 = -1;
let p1: i64 = -5;
for rank in 1_usize..3 {
for dsize in 1..max_dsize + 1 {
-let k_ksk: usize = k_in + base2k_key * dsize;
+let k_ksk: usize = k_in + key_base2k * dsize;
let k_out: usize = k_ksk; // Better capture noise.
let n: usize = module.n();
let dsize_in: usize = 1;
-let dnum_in: usize = k_in / base2k_in;
-let dnum_ksk: usize = k_in.div_ceil(base2k_key * dsize);
+let dnum_in: usize = k_in / in_base2k;
+let dnum_ksk: usize = k_in.div_ceil(key_base2k * dsize);
let auto_key_in_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout {
n: n.into(),
-base2k: base2k_in.into(),
+base2k: in_base2k.into(),
k: k_in.into(),
dnum: dnum_in.into(),
dsize: dsize_in.into(),
@@ -58,7 +58,7 @@ where
let auto_key_out_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout {
n: n.into(),
-base2k: base2k_out.into(),
+base2k: out_base2k.into(),
k: k_out.into(),
dnum: dnum_in.into(),
dsize: dsize_in.into(),
@@ -67,7 +67,7 @@ where
let auto_key_apply_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout {
n: n.into(),
-base2k: base2k_key.into(),
+base2k: key_base2k.into(),
k: k_ksk.into(),
dnum: dnum_ksk.into(),
dsize: dsize.into(),
@@ -84,10 +84,7 @@ where
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(
GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &auto_key_in_infos)
-.max(GLWEAutomorphismKey::encrypt_sk_tmp_bytes(
-    module,
-    &auto_key_apply_infos,
-))
+.max(GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &auto_key_apply_infos))
.max(GLWEAutomorphismKey::automorphism_tmp_bytes(
module,
&auto_key_out_infos,
@@ -100,24 +97,10 @@ where
sk.fill_ternary_prob(0.5, &mut source_xs);
// gglwe_{s1}(s0) = s0 -> s1
-auto_key_in.encrypt_sk(
-    module,
-    p0,
-    &sk,
-    &mut source_xa,
-    &mut source_xe,
-    scratch.borrow(),
-);
+auto_key_in.encrypt_sk(module, p0, &sk, &mut source_xa, &mut source_xe, scratch.borrow());
// gglwe_{s2}(s1) -> s1 -> s2
-auto_key_apply.encrypt_sk(
-    module,
-    p1,
-    &sk,
-    &mut source_xa,
-    &mut source_xe,
-    scratch.borrow(),
-);
+auto_key_apply.encrypt_sk(module, p1, &sk, &mut source_xa, &mut source_xe, scratch.borrow());
let mut auto_key_apply_prepared: GLWEAutomorphismKeyPrepared<Vec<u8>, BE> =
GLWEAutomorphismKeyPrepared::alloc_from_infos(module, &auto_key_apply_infos);
@@ -125,12 +108,7 @@ where
auto_key_apply_prepared.prepare(module, &auto_key_apply, scratch.borrow());
// gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0)
-auto_key_out.automorphism(
-    module,
-    &auto_key_in,
-    &auto_key_apply_prepared,
-    scratch.borrow(),
-);
+auto_key_out.automorphism(module, &auto_key_in, &auto_key_apply_prepared, scratch.borrow());
let mut sk_auto: GLWESecret<Vec<u8>> = GLWESecret::alloc_from_infos(&auto_key_out_infos);
sk_auto.fill_zero(); // Necessary to avoid panic of unfilled sk
@@ -152,7 +130,7 @@ where
k_ksk,
dnum_ksk,
dsize,
-base2k_key,
+key_base2k,
0.5,
0.5,
0f64,
@@ -171,11 +149,7 @@ where
.std()
.log2();
-assert!(
-    noise_have < max_noise + 0.5,
-    "{noise_have} > {}",
-    max_noise + 0.5
-);
+assert!(noise_have < max_noise + 0.5, "{noise_have} > {}", max_noise + 0.5);
}
}
}
@@ -196,26 +170,26 @@ where
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{
-let base2k_out: usize = 17;
-let base2k_key: usize = 13;
+let out_base2k: usize = 17;
+let key_base2k: usize = 13;
let k_out: usize = 102;
-let max_dsize: usize = k_out.div_ceil(base2k_key);
+let max_dsize: usize = k_out.div_ceil(key_base2k);
let p0: i64 = -1;
let p1: i64 = -5;
for rank in 1_usize..3 {
for dsize in 1..max_dsize + 1 {
-let k_ksk: usize = k_out + base2k_key * dsize;
+let k_ksk: usize = k_out + key_base2k * dsize;
let n: usize = module.n();
let dsize_in: usize = 1;
-let dnum_in: usize = k_out / base2k_out;
-let dnum_ksk: usize = k_out.div_ceil(base2k_key * dsize);
+let dnum_in: usize = k_out / out_base2k;
+let dnum_ksk: usize = k_out.div_ceil(key_base2k * dsize);
let auto_key_layout: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout {
n: n.into(),
-base2k: base2k_out.into(),
+base2k: out_base2k.into(),
k: k_out.into(),
dnum: dnum_in.into(),
dsize: dsize_in.into(),
@@ -224,7 +198,7 @@ where
let auto_key_apply_layout: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout {
n: n.into(),
-base2k: base2k_key.into(),
+base2k: key_base2k.into(),
k: k_ksk.into(),
dnum: dnum_ksk.into(),
dsize: dsize.into(),
@@ -248,24 +222,10 @@ where
sk.fill_ternary_prob(0.5, &mut source_xs);
// gglwe_{s1}(s0) = s0 -> s1
-auto_key.encrypt_sk(
-    module,
-    p0,
-    &sk,
-    &mut source_xa,
-    &mut source_xe,
-    scratch.borrow(),
-);
+auto_key.encrypt_sk(module, p0, &sk, &mut source_xa, &mut source_xe, scratch.borrow());
// gglwe_{s2}(s1) -> s1 -> s2
-auto_key_apply.encrypt_sk(
-    module,
-    p1,
-    &sk,
-    &mut source_xa,
-    &mut source_xe,
-    scratch.borrow(),
-);
+auto_key_apply.encrypt_sk(module, p1, &sk, &mut source_xa, &mut source_xe, scratch.borrow());
let mut auto_key_apply_prepared: GLWEAutomorphismKeyPrepared<Vec<u8>, BE> =
GLWEAutomorphismKeyPrepared::alloc_from_infos(module, &auto_key_apply_layout);
@@ -296,7 +256,7 @@ where
k_ksk,
dnum_ksk,
dsize,
-base2k_key,
+key_base2k,
0.5,
0.5,
0f64,
@@ -315,11 +275,7 @@ where
.std()
.log2();
-assert!(
-    noise_have < max_noise + 0.5,
-    "{noise_have} {}",
-    max_noise + 0.5
-);
+assert!(noise_have < max_noise + 0.5, "{noise_have} {}", max_noise + 0.5);
}
}
}
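Throughout these tests the gadget dimensions are derived from the target precision rather than chosen independently: the key precision adds one digit row of headroom on top of the input precision, and the digit count is whatever covers the input precision. The arithmetic, extracted for clarity (plain helper mirroring the test code, not an API from the crate):

```rust
/// Parameter derivation used by the tests above: given the input precision
/// k_in (bits), the key digit width key_base2k, and the digits-per-row dsize,
/// compute the key precision k_ksk and the number of gadget rows dnum.
fn gadget_dims(k_in: usize, key_base2k: usize, dsize: usize) -> (usize, usize) {
    let k_ksk = k_in + key_base2k * dsize; // one extra digit row of headroom
    let dnum = k_in.div_ceil(key_base2k * dsize); // rows needed to cover k_in
    (k_ksk, dnum)
}

fn main() {
    // k_in = 102, key_base2k = 13, dsize = 2 -> k_ksk = 128, dnum = 4
    assert_eq!(gadget_dims(102, 13, 2), (128, 4));
}
```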

View File

@@ -29,28 +29,28 @@ where
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{
-let base2k_in: usize = 17;
-let base2k_key: usize = 13;
-let base2k_out: usize = base2k_in; // MUST BE SAME
+let in_base2k: usize = 17;
+let key_base2k: usize = 13;
+let out_base2k: usize = in_base2k; // MUST BE SAME
let k_in: usize = 102;
-let max_dsize: usize = k_in.div_ceil(base2k_key);
+let max_dsize: usize = k_in.div_ceil(key_base2k);
let p: i64 = -5;
for rank in 1_usize..3 {
for dsize in 1..max_dsize + 1 {
-let k_ksk: usize = k_in + base2k_key * dsize;
+let k_ksk: usize = k_in + key_base2k * dsize;
let k_tsk: usize = k_ksk;
let k_out: usize = k_ksk; // Better capture noise.
let n: usize = module.n();
-let dnum_in: usize = k_in / base2k_in;
-let dnum_ksk: usize = k_in.div_ceil(base2k_key * dsize);
+let dnum_in: usize = k_in / in_base2k;
+let dnum_ksk: usize = k_in.div_ceil(key_base2k * dsize);
let dsize_in: usize = 1;
let ggsw_in_layout: GGSWLayout = GGSWLayout {
n: n.into(),
-base2k: base2k_in.into(),
+base2k: in_base2k.into(),
k: k_in.into(),
dnum: dnum_in.into(),
dsize: dsize_in.into(),
@@ -59,7 +59,7 @@ where
let ggsw_out_layout: GGSWLayout = GGSWLayout {
n: n.into(),
-base2k: base2k_out.into(),
+base2k: out_base2k.into(),
k: k_out.into(),
dnum: dnum_in.into(),
dsize: dsize_in.into(),
@@ -68,7 +68,7 @@ where
let tsk_layout: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout {
n: n.into(),
-base2k: base2k_key.into(),
+base2k: key_base2k.into(),
k: k_tsk.into(),
dnum: dnum_ksk.into(),
dsize: dsize.into(),
@@ -77,7 +77,7 @@ where
let auto_key_layout: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout {
n: n.into(),
-base2k: base2k_key.into(),
+base2k: key_base2k.into(),
k: k_ksk.into(),
dnum: dnum_ksk.into(),
dsize: dsize.into(),
@@ -109,21 +109,8 @@ where
let mut sk_prepared: GLWESecretPrepared<Vec<u8>, BE> = GLWESecretPrepared::alloc_from_infos(module, &sk);
sk_prepared.prepare(module, &sk);
-auto_key.encrypt_sk(
-    module,
-    p,
-    &sk,
-    &mut source_xa,
-    &mut source_xe,
-    scratch.borrow(),
-);
-tsk.encrypt_sk(
-    module,
-    &sk,
-    &mut source_xa,
-    &mut source_xe,
-    scratch.borrow(),
-);
+auto_key.encrypt_sk(module, p, &sk, &mut source_xa, &mut source_xe, scratch.borrow());
+tsk.encrypt_sk(module, &sk, &mut source_xa, &mut source_xe, scratch.borrow());
pt_scalar.fill_ternary_hw(0, n, &mut source_xs);
@@ -143,20 +130,14 @@ where
let mut tsk_prepared: GGLWEToGGSWKeyPrepared<Vec<u8>, BE> = GGLWEToGGSWKeyPrepared::alloc_from_infos(module, &tsk);
tsk_prepared.prepare(module, &tsk, scratch.borrow());
-ct_out.automorphism(
-    module,
-    &ct_in,
-    &auto_key_prepared,
-    &tsk_prepared,
-    scratch.borrow(),
-);
+ct_out.automorphism(module, &ct_in, &auto_key_prepared, &tsk_prepared, scratch.borrow());
module.vec_znx_automorphism_inplace(p, &mut pt_scalar.as_vec_znx_mut(), 0, scratch.borrow());
let max_noise = |col_j: usize| -> f64 {
noise_ggsw_keyswitch(
n as f64,
-base2k_key * dsize,
+key_base2k * dsize,
col_j,
var_xs,
0f64,
@@ -199,25 +180,25 @@ where
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{
-let base2k_out: usize = 17;
-let base2k_key: usize = 13;
+let out_base2k: usize = 17;
+let key_base2k: usize = 13;
let k_out: usize = 102;
-let max_dsize: usize = k_out.div_ceil(base2k_key);
+let max_dsize: usize = k_out.div_ceil(key_base2k);
let p: i64 = -1;
for rank in 1_usize..3 {
for dsize in 1..max_dsize + 1 {
-let k_ksk: usize = k_out + base2k_key * dsize;
+let k_ksk: usize = k_out + key_base2k * dsize;
let k_tsk: usize = k_ksk;
let n: usize = module.n();
-let dnum_in: usize = k_out / base2k_out;
-let dnum_ksk: usize = k_out.div_ceil(base2k_key * dsize);
+let dnum_in: usize = k_out / out_base2k;
+let dnum_ksk: usize = k_out.div_ceil(key_base2k * dsize);
let dsize_in: usize = 1;
let ggsw_out_layout: GGSWLayout = GGSWLayout {
n: n.into(),
-base2k: base2k_out.into(),
+base2k: out_base2k.into(),
k: k_out.into(),
dnum: dnum_in.into(),
dsize: dsize_in.into(),
@@ -226,7 +207,7 @@ where
let tsk_layout: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout {
n: n.into(),
-base2k: base2k_key.into(),
+base2k: key_base2k.into(),
k: k_tsk.into(),
dnum: dnum_ksk.into(),
dsize: dsize.into(),
@@ -235,7 +216,7 @@ where
let auto_key_layout: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout {
n: n.into(),
-base2k: base2k_key.into(),
+base2k: key_base2k.into(),
k: k_ksk.into(),
dnum: dnum_ksk.into(),
dsize: dsize.into(),
@@ -266,21 +247,8 @@ where
let mut sk_prepared: GLWESecretPrepared<Vec<u8>, BE> = GLWESecretPrepared::alloc_from_infos(module, &sk);
sk_prepared.prepare(module, &sk);
-auto_key.encrypt_sk(
-    module,
-    p,
-    &sk,
-    &mut source_xa,
-    &mut source_xe,
-    scratch.borrow(),
-);
-tsk.encrypt_sk(
-    module,
-    &sk,
-    &mut source_xa,
-    &mut source_xe,
-    scratch.borrow(),
-);
+auto_key.encrypt_sk(module, p, &sk, &mut source_xa, &mut source_xe, scratch.borrow());
+tsk.encrypt_sk(module, &sk, &mut source_xa, &mut source_xe, scratch.borrow());
pt_scalar.fill_ternary_hw(0, n, &mut source_xs);
@@ -307,7 +275,7 @@ where
let max_noise = |col_j: usize| -> f64 {
noise_ggsw_keyswitch(
n as f64,
-base2k_key * dsize,
+key_base2k * dsize,
col_j,
var_xs,
0f64,
@@ -327,10 +295,7 @@ where
.std()
.log2();
let noise_max: f64 = max_noise(col);
-assert!(
-    noise_have <= noise_max,
-    "noise_have:{noise_have} > noise_max:{noise_max}",
-)
+assert!(noise_have <= noise_max, "noise_have:{noise_have} > noise_max:{noise_max}",)
}
}
}

View File

@@ -30,37 +30,37 @@ where
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{
-let base2k_in: usize = 17;
-let base2k_key: usize = 13;
-let base2k_out: usize = 15;
+let in_base2k: usize = 17;
+let key_base2k: usize = 13;
+let out_base2k: usize = 15;
let k_in: usize = 102;
-let max_dsize: usize = k_in.div_ceil(base2k_key);
+let max_dsize: usize = k_in.div_ceil(key_base2k);
let p: i64 = -5;
for rank in 1_usize..3 {
for dsize in 1..max_dsize + 1 {
-let k_ksk: usize = k_in + base2k_key * dsize;
+let k_ksk: usize = k_in + key_base2k * dsize;
let k_out: usize = k_ksk; // Better capture noise.
let n: usize = module.n();
-let dnum: usize = k_in.div_ceil(base2k_key * dsize);
+let dnum: usize = k_in.div_ceil(key_base2k * dsize);
let ct_in_infos: GLWELayout = GLWELayout {
n: n.into(),
-base2k: base2k_in.into(),
+base2k: in_base2k.into(),
k: k_in.into(),
rank: rank.into(),
};
let ct_out_infos: GLWELayout = GLWELayout {
n: n.into(),
-base2k: base2k_out.into(),
+base2k: out_base2k.into(),
k: k_out.into(),
rank: rank.into(),
};
let autokey_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout {
n: n.into(),
-base2k: base2k_key.into(),
+base2k: key_base2k.into(),
k: k_out.into(),
rank: rank.into(),
dnum: dnum.into(),
@@ -77,7 +77,7 @@ where
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
-module.vec_znx_fill_uniform(base2k_in, &mut pt_in.data, 0, &mut source_xa);
+module.vec_znx_fill_uniform(in_base2k, &mut pt_in.data, 0, &mut source_xa);
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(
GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &autokey)
@@ -92,23 +92,9 @@ where
let mut sk_prepared: GLWESecretPrepared<Vec<u8>, BE> = GLWESecretPrepared::alloc_from_infos(module, &sk);
sk_prepared.prepare(module, &sk);
-autokey.encrypt_sk(
-    module,
-    p,
-    &sk,
-    &mut source_xa,
-    &mut source_xe,
-    scratch.borrow(),
-);
+autokey.encrypt_sk(module, p, &sk, &mut source_xa, &mut source_xe, scratch.borrow());
-ct_in.encrypt_sk(
-    module,
-    &pt_in,
-    &sk_prepared,
-    &mut source_xa,
-    &mut source_xe,
-    scratch.borrow(),
-);
+ct_in.encrypt_sk(module, &pt_in, &sk_prepared, &mut source_xa, &mut source_xe, scratch.borrow());
let mut autokey_prepared: GLWEAutomorphismKeyPrepared<Vec<u8>, BE> =
GLWEAutomorphismKeyPrepared::alloc_from_infos(module, &autokey_infos);
@@ -121,7 +107,7 @@ where
k_ksk,
dnum,
max_dsize,
-base2k_key,
+key_base2k,
0.5,
0.5,
0f64,
@@ -135,13 +121,7 @@ where
module.glwe_normalize(&mut pt_out, &pt_in, scratch.borrow());
module.vec_znx_automorphism_inplace(p, &mut pt_out.data, 0, scratch.borrow());
-assert!(
-    ct_out
-        .noise(module, &pt_out, &sk_prepared, scratch.borrow())
-        .std()
-        .log2()
-        <= max_noise + 1.0
-)
+assert!(ct_out.noise(module, &pt_out, &sk_prepared, scratch.borrow()).std().log2() <= max_noise + 1.0)
}
}
}
@@ -161,29 +141,29 @@ where
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{
-let base2k_out: usize = 17;
-let base2k_key: usize = 13;
+let out_base2k: usize = 17;
+let key_base2k: usize = 13;
let k_out: usize = 102;
-let max_dsize: usize = k_out.div_ceil(base2k_key);
+let max_dsize: usize = k_out.div_ceil(key_base2k);
let p: i64 = -5;
for rank in 1_usize..3 {
for dsize in 1..max_dsize + 1 {
-let k_ksk: usize = k_out + base2k_key * dsize;
+let k_ksk: usize = k_out + key_base2k * dsize;
let n: usize = module.n();
-let dnum: usize = k_out.div_ceil(base2k_key * dsize);
+let dnum: usize = k_out.div_ceil(key_base2k * dsize);
let ct_out_infos: GLWELayout = GLWELayout {
n: n.into(),
-base2k: base2k_out.into(),
+base2k: out_base2k.into(),
k: k_out.into(),
rank: rank.into(),
};
let autokey_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout {
n: n.into(),
-base2k: base2k_key.into(),
+base2k: key_base2k.into(),
k: k_ksk.into(),
rank: rank.into(),
dnum: dnum.into(),
@@ -198,7 +178,7 @@ where
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
-module.vec_znx_fill_uniform(base2k_out, &mut pt_want.data, 0, &mut source_xa);
+module.vec_znx_fill_uniform(out_base2k, &mut pt_want.data, 0, &mut source_xa);
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(
GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &autokey)
@@ -213,14 +193,7 @@ where
let mut sk_prepared: GLWESecretPrepared<Vec<u8>, BE> = GLWESecretPrepared::alloc_from_infos(module, &sk);
sk_prepared.prepare(module, &sk);
-autokey.encrypt_sk(
-    module,
-    p,
-    &sk,
-    &mut source_xa,
-    &mut source_xe,
-    scratch.borrow(),
-);
+autokey.encrypt_sk(module, p, &sk, &mut source_xa, &mut source_xe, scratch.borrow());
ct.encrypt_sk(
module,
@@ -242,7 +215,7 @@ where
k_ksk,
dnum,
dsize,
-base2k_key,
+key_base2k,
0.5,
0.5,
0f64,
@@ -255,12 +228,7 @@ where
module.vec_znx_automorphism_inplace(p, &mut pt_want.data, 0, scratch.borrow());
-assert!(
-    ct.noise(module, &pt_want, &sk_prepared, scratch.borrow())
-        .std()
-        .log2()
-        <= max_noise + 1.0
-)
+assert!(ct.noise(module, &pt_want, &sk_prepared, scratch.borrow()).std().log2() <= max_noise + 1.0)
}
}
}
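Each assertion above compares measured noise against an analytic bound, both on a log2 scale. The measured side is just the log2 of the sample standard deviation of the decryption error; a plain-Rust equivalent of that reduction (sketch only; the crate computes it through a `.noise(...).std()` helper on ciphertexts):

```rust
/// log2 of the sample standard deviation: the quantity the noise
/// assertions above compare against a predicted max_noise bound.
fn std_log2(err: &[f64]) -> f64 {
    let n = err.len() as f64;
    let mean = err.iter().sum::<f64>() / n;
    let var = err.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
    var.sqrt().log2()
}

fn main() {
    // Small errors on the order of 2^-42 have std around 2^-41,
    // comfortably below a hypothetical -30-bit noise budget.
    let err: Vec<f64> = (0..1024).map(|i| ((i % 7) as f64 - 3.0) * 2f64.powi(-42)).collect();
    assert!(std_log2(&err) <= -30.0);
}
```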

View File

@@ -65,33 +65,19 @@ where
let pt_in: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc_from_infos(&glwe_infos_in);
let pt_out: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc_from_infos(&glwe_infos_out);
-ct_in.encrypt_sk(
-    module,
-    &pt_in,
-    &sk_prep,
-    &mut source_xa,
-    &mut source_xe,
-    scratch.borrow(),
-);
+ct_in.encrypt_sk(module, &pt_in, &sk_prep, &mut source_xa, &mut source_xe, scratch.borrow());
let mut data: Vec<Float> = (0..module.n()).map(|_| Float::with_val(128, 0)).collect();
-ct_in
-    .data()
-    .decode_vec_float(ct_in.base2k().into(), 0, &mut data);
+ct_in.data().decode_vec_float(ct_in.base2k().into(), 0, &mut data);
ct_out.fill_uniform(ct_out.base2k().into(), &mut source_xa);
module.glwe_normalize(&mut ct_out, &ct_in, scratch.borrow());
let mut data_conv: Vec<Float> = (0..module.n()).map(|_| Float::with_val(128, 0)).collect();
-ct_out
-    .data()
-    .decode_vec_float(ct_out.base2k().into(), 0, &mut data_conv);
+ct_out.data().decode_vec_float(ct_out.base2k().into(), 0, &mut data_conv);
assert!(
-ct_out
-    .noise(module, &pt_out, &sk_prep, scratch.borrow())
-    .std()
-    .log2()
+ct_out.noise(module, &pt_out, &sk_prep, scratch.borrow()).std().log2()
<= -(ct_out.k().as_u32() as f64) + SIGMA.log2() + 0.50
)
}
@@ -162,14 +148,7 @@ where
lwe_pt.encode_i64(data, k_lwe_pt);
let mut lwe_ct: LWE<Vec<u8>> = LWE::alloc_from_infos(&lwe_infos);
-lwe_ct.encrypt_sk(
-    module,
-    &lwe_pt,
-    &sk_lwe,
-    &mut source_xa,
-    &mut source_xe,
-    scratch.borrow(),
-);
+lwe_ct.encrypt_sk(module, &lwe_pt, &sk_lwe, &mut source_xa, &mut source_xe, scratch.borrow());
let mut ksk: LWEToGLWEKey<Vec<u8>> = LWEToGLWEKey::alloc_from_infos(&lwe_to_glwe_infos);
@@ -195,11 +174,12 @@ where
let mut lwe_pt_conv = LWEPlaintext::alloc(glwe_pt.base2k(), lwe_pt.k());
module.vec_znx_normalize(
-glwe_pt.base2k().as_usize(),
lwe_pt_conv.data_mut(),
+glwe_pt.base2k().as_usize(),
+0,
0,
-lwe_pt.base2k().as_usize(),
lwe_pt.data(),
+lwe_pt.base2k().as_usize(),
0,
scratch.borrow(),
);
@@ -287,14 +267,7 @@ where
let mut ksk: GLWEToLWEKey<Vec<u8>> = GLWEToLWEKey::alloc_from_infos(&glwe_to_lwe_infos);
-ksk.encrypt_sk(
-    module,
-    &sk_lwe,
-    &sk_glwe,
-    &mut source_xa,
-    &mut source_xe,
-    scratch.borrow(),
-);
+ksk.encrypt_sk(module, &sk_lwe, &sk_glwe, &mut source_xa, &mut source_xe, scratch.borrow());
let mut lwe_ct: LWE<Vec<u8>> = LWE::alloc_from_infos(&lwe_infos);
@@ -309,11 +282,12 @@ where
let mut glwe_pt_conv = GLWEPlaintext::alloc(glwe_ct.n(), lwe_pt.base2k(), lwe_pt.k());
module.vec_znx_normalize(
-lwe_pt.base2k().as_usize(),
glwe_pt_conv.data_mut(),
+lwe_pt.base2k().as_usize(),
+0,
0,
-glwe_ct.base2k().as_usize(),
glwe_pt.data(),
+glwe_ct.base2k().as_usize(),
0,
scratch.borrow(),
);

View File

@@ -53,23 +53,15 @@ where
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
-let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(GLWEAutomorphismKey::encrypt_sk_tmp_bytes(
-    module, &atk_infos,
-));
+let mut scratch: ScratchOwned<BE> =
+    ScratchOwned::alloc(GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &atk_infos));
let mut sk: GLWESecret<Vec<u8>> = GLWESecret::alloc_from_infos(&atk_infos);
sk.fill_ternary_prob(0.5, &mut source_xs);
let p = -5;
-atk.encrypt_sk(
-    module,
-    p,
-    &sk,
-    &mut source_xa,
-    &mut source_xe,
-    scratch.borrow(),
-);
+atk.encrypt_sk(module, p, &sk, &mut source_xa, &mut source_xe, scratch.borrow());
let mut sk_out: GLWESecret<Vec<u8>> = sk.clone();
(0..atk.rank().into()).for_each(|i| {
@@ -90,14 +82,7 @@ where
for col in 0..atk.rank().as_usize() {
assert!(
atk.key
-.noise(
-    module,
-    row,
-    col,
-    &sk.data,
-    &sk_out_prepared,
-    scratch.borrow()
-)
+.noise(module, row, col, &sk.data, &sk_out_prepared, scratch.borrow())
.std()
.log2()
<= max_noise
@@ -145,9 +130,8 @@ where
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
-let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(GLWEAutomorphismKeyCompressed::encrypt_sk_tmp_bytes(
-    module, &atk_infos,
-));
+let mut scratch: ScratchOwned<BE> =
+    ScratchOwned::alloc(GLWEAutomorphismKeyCompressed::encrypt_sk_tmp_bytes(module, &atk_infos));
let mut sk: GLWESecret<Vec<u8>> = GLWESecret::alloc_from_infos(&atk_infos);
sk.fill_ternary_prob(0.5, &mut source_xs);
@@ -180,14 +164,7 @@ where
for col in 0..atk.rank().as_usize() {
let noise_have = atk
.key
-.noise(
-    module,
-    row,
-    col,
-    &sk.data,
-    &sk_out_prepared,
-    scratch.borrow(),
-)
+.noise(module, row, col, &sk.data, &sk_out_prepared, scratch.borrow())
.std()
.log2();

View File

@@ -66,14 +66,7 @@ where
let mut sk_out_prepared: GLWESecretPrepared<Vec<u8>, BE> = GLWESecretPrepared::alloc(module, rank_out.into());
sk_out_prepared.prepare(module, &sk_out);
-ksk.encrypt_sk(
-    module,
-    &sk_in,
-    &sk_out,
-    &mut source_xa,
-    &mut source_xe,
-    scratch.borrow(),
-);
+ksk.encrypt_sk(module, &sk_in, &sk_out, &mut source_xa, &mut source_xe, scratch.borrow());
let max_noise: f64 = SIGMA.log2() - (ksk.k().as_usize() as f64) + 0.5;
@@ -81,14 +74,7 @@ where
for col in 0..ksk.rank_in().as_usize() {
let noise_have = ksk
.key
-.noise(
-    module,
-    row,
-    col,
-    &sk_in.data,
-    &sk_out_prepared,
-    scratch.borrow(),
-)
+.noise(module, row, col, &sk_in.data, &sk_out_prepared, scratch.borrow())
.std()
.log2();
@@ -144,10 +130,8 @@ where
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
-let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(GLWESwitchingKeyCompressed::encrypt_sk_tmp_bytes(
-    module,
-    &gglwe_infos,
-));
+let mut scratch: ScratchOwned<BE> =
+    ScratchOwned::alloc(GLWESwitchingKeyCompressed::encrypt_sk_tmp_bytes(module, &gglwe_infos));
let mut sk_in: GLWESecret<Vec<u8>> = GLWESecret::alloc(n.into(), rank_in.into());
sk_in.fill_ternary_prob(0.5, &mut source_xs);
@@ -159,14 +143,7 @@ where
let seed_xa = [1u8; 32];
-ksk_compressed.encrypt_sk(
-    module,
-    &sk_in,
-    &sk_out,
-    seed_xa,
-    &mut source_xe,
-    scratch.borrow(),
-);
+ksk_compressed.encrypt_sk(module, &sk_in, &sk_out, seed_xa, &mut source_xe, scratch.borrow());
let mut ksk: GLWESwitchingKey<Vec<u8>> = GLWESwitchingKey::alloc_from_infos(&gglwe_infos);
ksk.decompress(module, &ksk_compressed);
@@ -177,14 +154,7 @@ where
for col in 0..ksk.rank_in().as_usize() {
let noise_have = ksk
.key
-.noise(
-    module,
-    row,
-    col,
-    &sk_in.data,
-    &sk_out_prepared,
-    scratch.borrow(),
-)
+.noise(module, row, col, &sk_in.data, &sk_out_prepared, scratch.borrow())
.std()
.log2();
@@ -269,14 +239,7 @@ where
for row in 0..ksk.dnum().as_usize() {
for col in 0..ksk.rank_in().as_usize() {
let noise_have = ksk
-.noise(
-    module,
-    row,
-    col,
-    &sk_in.data,
-    &sk_out_prepared,
-    scratch.borrow(),
-)
+.noise(module, row, col, &sk_in.data, &sk_out_prepared, scratch.borrow())
.std()
.log2();

View File

@@ -55,13 +55,7 @@ where
let mut sk_prepared: GLWESecretPrepared<Vec<u8>, BE> = GLWESecretPrepared::alloc(module, rank.into());
sk_prepared.prepare(module, &sk);
-key.encrypt_sk(
-    module,
-    &sk,
-    &mut source_xa,
-    &mut source_xe,
-    scratch.borrow(),
-);
+key.encrypt_sk(module, &sk, &mut source_xa, &mut source_xe, scratch.borrow());
let mut sk_tensor: GLWESecretTensor<Vec<u8>> = GLWESecretTensor::alloc_from_infos(&sk);
sk_tensor.prepare(module, &sk, scratch.borrow());
@@ -72,12 +66,7 @@ where
for i in 0..rank {
for j in 0..rank {
-module.vec_znx_copy(
-    &mut pt_want.as_vec_znx_mut(),
-    j,
-    &sk_tensor.at(i, j).as_vec_znx(),
-    0,
-);
+module.vec_znx_copy(&mut pt_want.as_vec_znx_mut(), j, &sk_tensor.at(i, j).as_vec_znx(), 0);
}
let ksk: &GGLWE<Vec<u8>> = key.at(i);
@@ -127,9 +116,8 @@ where
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
-let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(GGLWEToGGSWKeyCompressed::encrypt_sk_tmp_bytes(
-    module, &key_infos,
-));
+let mut scratch: ScratchOwned<BE> =
+    ScratchOwned::alloc(GGLWEToGGSWKeyCompressed::encrypt_sk_tmp_bytes(module, &key_infos));
let mut sk: GLWESecret<Vec<u8>> = GLWESecret::alloc_from_infos(&key_infos);
sk.fill_ternary_prob(0.5, &mut source_xs);
@@ -152,12 +140,7 @@ where
for i in 0..rank {
for j in 0..rank {
-module.vec_znx_copy(
-    &mut pt_want.as_vec_znx_mut(),
-    j,
-    &sk_tensor.at(i, j).as_vec_znx(),
-    0,
-);
+module.vec_znx_copy(&mut pt_want.as_vec_znx_mut(), j, &sk_tensor.at(i, j).as_vec_znx(), 0);
}
let ksk: &GGLWE<Vec<u8>> = key.at(i);

View File

@@ -121,14 +121,7 @@ where
let seed_xa: [u8; 32] = [1u8; 32];
-ct_compressed.encrypt_sk(
-    module,
-    &pt_scalar,
-    &sk_prepared,
-    seed_xa,
-    &mut source_xe,
-    scratch.borrow(),
-);
+ct_compressed.encrypt_sk(module, &pt_scalar, &sk_prepared, seed_xa, &mut source_xe, scratch.borrow());
let noise_f = |_col_i: usize| -(k as f64) + SIGMA.log2() + 0.5;

View File

@@ -68,10 +68,7 @@ where
scratch.borrow(),
);
-let noise_have: f64 = ct
-    .noise(module, &pt_want, &sk_prepared, scratch.borrow())
-    .std()
-    .log2();
+let noise_have: f64 = ct.noise(module, &pt_want, &sk_prepared, scratch.borrow()).std().log2();
let noise_want: f64 = SIGMA.log2() - (ct.k().as_usize() as f64) + 0.5;
assert!(noise_have <= noise_want + 0.2);
@@ -126,22 +123,12 @@ where
let seed_xa: [u8; 32] = [1u8; 32];
-ct_compressed.encrypt_sk(
-    module,
-    &pt_want,
-    &sk_prepared,
-    seed_xa,
-    &mut source_xe,
-    scratch.borrow(),
-);
+ct_compressed.encrypt_sk(module, &pt_want, &sk_prepared, seed_xa, &mut source_xe, scratch.borrow());
let mut ct: GLWE<Vec<u8>> = GLWE::alloc_from_infos(&glwe_infos);
ct.decompress(module, &ct_compressed);
-let noise_have: f64 = ct
-    .noise(module, &pt_want, &sk_prepared, scratch.borrow())
-    .std()
-    .log2();
+let noise_have: f64 = ct.noise(module, &pt_want, &sk_prepared, scratch.borrow()).std().log2();
let noise_want: f64 = SIGMA.log2() - (ct.k().as_usize() as f64) + 0.5;
assert!(noise_have <= noise_want + 0.2);
}
@@ -186,18 +173,9 @@ where
let mut ct: GLWE<Vec<u8>> = GLWE::alloc_from_infos(&glwe_infos);
-ct.encrypt_zero_sk(
-    module,
-    &sk_prepared,
-    &mut source_xa,
-    &mut source_xe,
-    scratch.borrow(),
-);
+ct.encrypt_zero_sk(module, &sk_prepared, &mut source_xa, &mut source_xe, scratch.borrow());
-let noise_have: f64 = ct
-    .noise(module, &pt, &sk_prepared, scratch.borrow())
-    .std()
-    .log2();
+let noise_have: f64 = ct.noise(module, &pt, &sk_prepared, scratch.borrow()).std().log2();
let noise_want: f64 = SIGMA.log2() - (ct.k().as_usize() as f64) + 0.5;
assert!(noise_have <= noise_want + 0.2);
}
@@ -265,10 +243,7 @@ where
scratch.borrow(),
);
-let noise_have: f64 = ct
-    .noise(module, &pt_want, &sk_prepared, scratch.borrow())
-    .std()
-    .log2();
+let noise_have: f64 = ct.noise(module, &pt_want, &sk_prepared, scratch.borrow()).std().log2();
let noise_want: f64 = ((((rank as f64) + 1.0) * n as f64 * 0.5 * SIGMA * SIGMA).sqrt()).log2() - (k_ct as f64);
assert!(noise_have <= noise_want + 0.2);
}

View File

@@ -46,23 +46,14 @@ where
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
-let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(GLWETensorKey::encrypt_sk_tmp_bytes(
-    module,
-    &tensor_key_infos,
-));
+let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(GLWETensorKey::encrypt_sk_tmp_bytes(module, &tensor_key_infos));
let mut sk: GLWESecret<Vec<u8>> = GLWESecret::alloc_from_infos(&tensor_key_infos);
sk.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_prepared: GLWESecretPrepared<Vec<u8>, BE> = GLWESecretPrepared::alloc(module, rank.into());
sk_prepared.prepare(module, &sk);
-tensor_key.encrypt_sk(
-    module,
-    &sk,
-    &mut source_xa,
-    &mut source_xe,
-    scratch.borrow(),
-);
+tensor_key.encrypt_sk(module, &sk, &mut source_xa, &mut source_xe, scratch.borrow());
let mut sk_tensor: GLWESecretTensor<Vec<u8>> = GLWESecretTensor::alloc_from_infos(&sk);
sk_tensor.prepare(module, &sk, scratch.borrow());
@@ -74,14 +65,7 @@ where
assert!(
tensor_key
.0
-.noise(
-    module,
-    row,
-    col,
-    &sk_tensor.data,
-    &sk_prepared,
-    scratch.borrow()
-)
+.noise(module, row, col, &sk_tensor.data, &sk_prepared, scratch.borrow())
.std()
.log2()
<= max_noise
@@ -124,10 +108,8 @@ where
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
-let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(GLWETensorKeyCompressed::encrypt_sk_tmp_bytes(
-    module,
-    &tensor_key_infos,
-));
+let mut scratch: ScratchOwned<BE> =
+    ScratchOwned::alloc(GLWETensorKeyCompressed::encrypt_sk_tmp_bytes(module, &tensor_key_infos));
let mut sk: GLWESecret<Vec<u8>> = GLWESecret::alloc_from_infos(&tensor_key_infos);
sk.fill_ternary_prob(0.5, &mut source_xs);
@@ -151,14 +133,7 @@ where
assert!(
tensor_key
.0
-.noise(
-    module,
-    row,
-    col,
-    &sk_tensor.data,
-    &sk_prepared,
-    scratch.borrow()
-)
+.noise(module, row, col, &sk_tensor.data, &sk_prepared, scratch.borrow())
.std()
.log2()
<= max_noise

View File

@@ -28,26 +28,26 @@ where
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{
-let base2k_in: usize = 17;
-let base2k_key: usize = 13;
-let base2k_out: usize = base2k_in; // MUST BE SAME
+let in_base2k: usize = 17;
+let key_base2k: usize = 13;
+let out_base2k: usize = in_base2k; // MUST BE SAME
let k_in: usize = 102;
-let max_dsize: usize = k_in.div_ceil(base2k_key);
+let max_dsize: usize = k_in.div_ceil(key_base2k);
for rank_in in 1_usize..3 {
for rank_out in 1_usize..3 {
for dsize in 1_usize..max_dsize + 1 {
-let k_ggsw: usize = k_in + base2k_key * dsize;
+let k_ggsw: usize = k_in + key_base2k * dsize;
let k_out: usize = k_in; // Better capture noise.
let n: usize = module.n();
-let dnum_in: usize = k_in / base2k_in;
-let dnum: usize = k_in.div_ceil(base2k_key * dsize);
+let dnum_in: usize = k_in / in_base2k;
+let dnum: usize = k_in.div_ceil(key_base2k * dsize);
let dsize_in: usize = 1;
let gglwe_in_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout {
n: n.into(),
-base2k: base2k_in.into(),
+base2k: in_base2k.into(),
k: k_in.into(),
dnum: dnum_in.into(),
dsize: dsize_in.into(),
@@ -57,7 +57,7 @@ where
let gglwe_out_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout {
n: n.into(),
-base2k: base2k_out.into(),
+base2k: out_base2k.into(),
k: k_out.into(),
dnum: dnum_in.into(),
dsize: dsize_in.into(),
@@ -67,7 +67,7 @@ where
let ggsw_infos: GGSWLayout = GGSWLayout {
n: n.into(),
-base2k: base2k_key.into(),
+base2k: key_base2k.into(),
k: k_ggsw.into(),
dnum: dnum.into(),
dsize: dsize.into(),
@@ -106,14 +106,7 @@ where
sk_out_prepared.prepare(module, &sk_out);
// gglwe_{s1}(s0) = s0 -> s1
-ct_gglwe_in.encrypt_sk(
-    module,
-    &sk_in,
-    &sk_out,
-    &mut source_xa,
-    &mut source_xe,
-    scratch.borrow(),
-);
+ct_gglwe_in.encrypt_sk(module, &sk_in, &sk_out, &mut source_xa, &mut source_xe, scratch.borrow());
ct_rgsw.encrypt_sk(
module,
@@ -131,12 +124,7 @@ where
ct_gglwe_out.external_product(module, &ct_gglwe_in, &ct_rgsw_prepared, scratch.borrow());
(0..rank_in).for_each(|i| {
-module.vec_znx_rotate_inplace(
-    r as i64,
-    &mut sk_in.data.as_vec_znx_mut(),
-    i,
-    scratch.borrow(),
-); // * X^{r}
+module.vec_znx_rotate_inplace(r as i64, &mut sk_in.data.as_vec_znx_mut(), i, scratch.borrow()); // * X^{r}
});
let var_gct_err_lhs: f64 = SIGMA * SIGMA;
@@ -148,7 +136,7 @@ where
let max_noise: f64 = noise_ggsw_product(
n as f64,
-base2k_key * dsize,
+key_base2k * dsize,
var_xs,
var_msg,
var_a0_err,
@@ -165,14 +153,7 @@ where
assert!(
ct_gglwe_out
.key
-.noise(
-    module,
-    row,
-    col,
-    &sk_in.data,
-    &sk_out_prepared,
-    scratch.borrow()
-)
+.noise(module, row, col, &sk_in.data, &sk_out_prepared, scratch.borrow())
.std()
.log2()
<= max_noise + 0.5
@@ -197,25 +178,25 @@ where
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{
-let base2k_out: usize = 17;
-let base2k_key: usize = 13;
+let out_base2k: usize = 17;
+let key_base2k: usize = 13;
let k_out: usize = 102;
-let max_dsize: usize = k_out.div_ceil(base2k_key);
+let max_dsize: usize = k_out.div_ceil(key_base2k);
for rank_in in 1_usize..3 {
for rank_out in 1_usize..3 {
for dsize in 1_usize..max_dsize + 1 {
-let k_ggsw: usize = k_out + base2k_key * dsize;
+let k_ggsw: usize = k_out + key_base2k * dsize;
let n: usize = module.n();
-let dnum_in: usize = k_out / base2k_out;
-let dnum: usize = k_out.div_ceil(base2k_key * dsize);
+let dnum_in: usize = k_out / out_base2k;
+let dnum: usize = k_out.div_ceil(key_base2k * dsize);
let dsize_in: usize = 1;
let gglwe_out_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout {
n: n.into(),
-base2k: base2k_out.into(),
+base2k: out_base2k.into(),
k: k_out.into(),
dnum: dnum_in.into(),
dsize: dsize_in.into(),
@@ -225,7 +206,7 @@ where
let ggsw_infos: GGSWLayout = GGSWLayout {
n: n.into(),
-base2k: base2k_key.into(),
+base2k: key_base2k.into(),
k: k_ggsw.into(),
dnum: dnum.into(),
dsize: dsize.into(),
@@ -263,14 +244,7 @@ where
sk_out_prepared.prepare(module, &sk_out);
// gglwe_{s1}(s0) = s0 -> s1
-ct_gglwe.encrypt_sk(
-    module,
-    &sk_in,
-    &sk_out,
-    &mut source_xa,
-    &mut source_xe,
-    scratch.borrow(),
-);
+ct_gglwe.encrypt_sk(module, &sk_in, &sk_out, &mut source_xa, &mut source_xe, scratch.borrow());
ct_rgsw.encrypt_sk(
module,
@@ -288,12 +262,7 @@ where
ct_gglwe.external_product_inplace(module, &ct_rgsw_prepared, scratch.borrow());
(0..rank_in).for_each(|i| {
-module.vec_znx_rotate_inplace(
-    r as i64,
-    &mut sk_in.data.as_vec_znx_mut(),
-    i,
-    scratch.borrow(),
-); // * X^{r}
+module.vec_znx_rotate_inplace(r as i64, &mut sk_in.data.as_vec_znx_mut(), i, scratch.borrow()); // * X^{r}
});
let var_gct_err_lhs: f64 = SIGMA * SIGMA;
@@ -305,7 +274,7 @@ where
let max_noise: f64 = noise_ggsw_product(
n as f64,
-base2k_key * dsize,
+key_base2k * dsize,
var_xs,
var_msg,
var_a0_err,
@@ -322,14 +291,7 @@ where
assert!(
ct_gglwe
.key
-.noise(
-    module,
-    row,
-    col,
-    &sk_in.data,
-    &sk_out_prepared,
-    scratch.borrow()
-)
+.noise(module, row, col, &sk_in.data, &sk_out_prepared, scratch.borrow())
.std()
.log2()
<= max_noise + 0.5
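The `// * X^{r}` calls above rotate the secret's coefficients because multiplying by the monomial X^r permutes coefficients negacyclically in Z[X]/(X^n + 1). A self-contained model of that operation (plain helper, not the crate's `vec_znx_rotate_inplace`):

```rust
/// Multiply a polynomial in Z[X]/(X^n + 1) by X^r: indices shift by r, and
/// coefficients that wrap past X^n pick up a sign flip, since X^n = -1.
fn rotate(coeffs: &[i64], r: usize) -> Vec<i64> {
    let n = coeffs.len();
    let mut out = vec![0i64; n];
    for (i, &c) in coeffs.iter().enumerate() {
        let j = (i + r) % n;
        out[j] = if (i + r) / n % 2 == 1 { -c } else { c };
    }
    out
}

fn main() {
    // (1 + 2X) * X^2 in Z[X]/(X^3 + 1) = X^2 + 2X^3 = -2 + X^2
    assert_eq!(rotate(&[1, 2, 0], 2), vec![-2, 0, 1]);
}
```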

View File

@@ -26,26 +26,26 @@ where
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{
-let base2k_in: usize = 17;
-let base2k_key: usize = 13;
-let base2k_out: usize = base2k_in; // MUST BE SAME
+let in_base2k: usize = 17;
+let key_base2k: usize = 13;
+let out_base2k: usize = in_base2k; // MUST BE SAME
let k_in: usize = 102;
-let max_dsize: usize = k_in.div_ceil(base2k_key);
+let max_dsize: usize = k_in.div_ceil(key_base2k);
for rank in 1_usize..3 {
for dsize in 1..max_dsize + 1 {
-let k_apply: usize = k_in + base2k_key * dsize;
+let k_apply: usize = k_in + key_base2k * dsize;
let k_out: usize = k_in; // Better capture noise.
let n: usize = module.n();
-let dnum: usize = k_in.div_ceil(base2k_key * dsize);
-let dnum_in: usize = k_in / base2k_in;
+let dnum: usize = k_in.div_ceil(key_base2k * dsize);
+let dnum_in: usize = k_in / in_base2k;
let dsize_in: usize = 1;
let ggsw_in_infos: GGSWLayout = GGSWLayout {
n: n.into(),
-base2k: base2k_in.into(),
+base2k: in_base2k.into(),
k: k_in.into(),
dnum: dnum_in.into(),
dsize: dsize_in.into(),
@@ -54,7 +54,7 @@ where
let ggsw_out_infos: GGSWLayout = GGSWLayout {
n: n.into(),
-base2k: base2k_out.into(),
+base2k: out_base2k.into(),
k: k_out.into(),
dnum: dnum_in.into(),
dsize: dsize_in.into(),
@@ -63,7 +63,7 @@ where
let ggsw_apply_infos: GGSWLayout = GGSWLayout {
n: n.into(),
-base2k: base2k_key.into(),
+base2k: key_base2k.into(),
k: k_apply.into(),
dnum: dnum.into(),
dsize: dsize.into(),
@@ -107,14 +107,7 @@ where
scratch.borrow(),
);
-ggsw_in.encrypt_sk(
-    module,
-    &pt_in,
-    &sk_prepared,
-    &mut source_xa,
-    &mut source_xe,
-    scratch.borrow(),
-);
+ggsw_in.encrypt_sk(module, &pt_in, &sk_prepared, &mut source_xa, &mut source_xe, scratch.borrow());
let mut ct_rhs_prepared: GGSWPrepared<Vec<u8>, BE> = GGSWPrepared::alloc_from_infos(module, &ggsw_apply);
ct_rhs_prepared.prepare(module, &ggsw_apply, scratch.borrow());
@@ -133,7 +126,7 @@ where
let max_noise = |_col_j: usize| -> f64 {
noise_ggsw_product(
n as f64,
-base2k_key * dsize,
+key_base2k * dsize,
0.5,
var_msg,
var_a0_err,
@@ -173,23 +166,23 @@ where
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{
-let base2k_out: usize = 17;
-let base2k_key: usize = 13;
+let out_base2k: usize = 17;
+let key_base2k: usize = 13;
let k_out: usize = 102;
-let max_dsize: usize = k_out.div_ceil(base2k_key);
+let max_dsize: usize = k_out.div_ceil(key_base2k);
for rank in 1_usize..3 {
for dsize in 1..max_dsize + 1 {
-let k_apply: usize = k_out + base2k_key * dsize;
+let k_apply: usize = k_out + key_base2k * dsize;
let n: usize = module.n();
-let dnum: usize = k_out.div_ceil(dsize * base2k_key);
-let dnum_in: usize = k_out / base2k_out;
+let dnum: usize = k_out.div_ceil(dsize * key_base2k);
+let dnum_in: usize = k_out / out_base2k;
let dsize_in: usize = 1;
let ggsw_out_infos: GGSWLayout = GGSWLayout {
n: n.into(),
-base2k: base2k_out.into(),
+base2k: out_base2k.into(),
k: k_out.into(),
dnum: dnum_in.into(),
dsize: dsize_in.into(),
@@ -198,7 +191,7 @@ where
let ggsw_apply_infos: GGSWLayout = GGSWLayout {
n: n.into(),
-base2k: base2k_key.into(),
+base2k: key_base2k.into(),
k: k_apply.into(),
dnum: dnum.into(),
dsize: dsize.into(),
@@ -242,14 +235,7 @@ where
scratch.borrow(),
);
-ggsw_out.encrypt_sk(
-    module,
-    &pt_in,
-    &sk_prepared,
-    &mut source_xa,
-    &mut source_xe,
-    scratch.borrow(),
-);
+ggsw_out.encrypt_sk(module, &pt_in, &sk_prepared, &mut source_xa, &mut source_xe, scratch.borrow());
let mut ct_rhs_prepared: GGSWPrepared<Vec<u8>, BE> = GGSWPrepared::alloc_from_infos(module, &ggsw_apply);
ct_rhs_prepared.prepare(module, &ggsw_apply, scratch.borrow());
@@ -268,7 +254,7 @@ where
let max_noise = |_col_j: usize| -> f64 {
noise_ggsw_product(
n as f64,
-base2k_key * dsize,
+key_base2k * dsize,
0.5,
var_msg,
var_a0_err,

View File

@@ -29,14 +29,14 @@ where
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{
-let base2k_in: usize = 17;
-let base2k_key: usize = 13;
-let base2k_out: usize = 15;
+let in_base2k: usize = 17;
+let key_base2k: usize = 13;
+let out_base2k: usize = 15;
let k_in: usize = 102;
-let max_dsize: usize = k_in.div_ceil(base2k_key);
+let max_dsize: usize = k_in.div_ceil(key_base2k);
for rank in 1_usize..3 {
for dsize in 1..max_dsize + 1 {
-let k_ggsw: usize = k_in + base2k_key * dsize;
+let k_ggsw: usize = k_in + key_base2k * dsize;
let k_out: usize = k_ggsw; // Better capture noise
let n: usize = module.n();
@@ -44,21 +44,21 @@ where
let glwe_in_infos: GLWELayout = GLWELayout {
n: n.into(),
-base2k: base2k_in.into(),
+base2k: in_base2k.into(),
k: k_in.into(),
rank: rank.into(),
};
let glwe_out_infos: GLWELayout = GLWELayout {
n: n.into(),
-base2k: base2k_out.into(),
+base2k: out_base2k.into(),
k: k_out.into(),
rank: rank.into(),
};
let ggsw_apply_infos: GGSWLayout = GGSWLayout {
n: n.into(),
-base2k: base2k_key.into(),
+base2k: key_base2k.into(),
k: k_ggsw.into(),
dnum: dnum.into(),
dsize: dsize.into(),
@@ -77,7 +77,7 @@ where
let mut source_xa: Source = Source::new([0u8; 32]);
// Random input plaintext
-module.vec_znx_fill_uniform(base2k_in, &mut pt_in.data, 0, &mut source_xa);
+module.vec_znx_fill_uniform(in_base2k, &mut pt_in.data, 0, &mut source_xa);
pt_in.data.at_mut(0, 0)[1] = 1;
@@ -106,14 +106,7 @@ where
scratch.borrow(),
);
-glwe_in.encrypt_sk(
-    module,
-    &pt_in,
-    &sk_prepared,
-    &mut source_xa,
-    &mut source_xe,
-    scratch.borrow(),
-);
+glwe_in.encrypt_sk(module, &pt_in, &sk_prepared, &mut source_xa, &mut source_xe, scratch.borrow());
let mut ct_ggsw_prepared: GGSWPrepared<Vec<u8>, BE> = GGSWPrepared::alloc_from_infos(module, &ggsw_apply);
ct_ggsw_prepared.prepare(module, &ggsw_apply, scratch.borrow());
@@ -133,7 +126,7 @@ where
let max_noise: f64 = noise_ggsw_product(
n as f64,
-base2k_key * max_dsize,
+key_base2k * max_dsize,
0.5,
var_msg,
var_a0_err,
@@ -145,13 +138,7 @@ where
k_ggsw,
);
-assert!(
-    glwe_out
-        .noise(module, &pt_out, &sk_prepared, scratch.borrow())
-        .std()
-        .log2()
-        <= max_noise + 1.0
-)
+assert!(glwe_out.noise(module, &pt_out, &sk_prepared, scratch.borrow()).std().log2() <= max_noise + 1.0)
}
}
}
@@ -170,28 +157,28 @@ where
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{
-let base2k_out: usize = 17;
-let base2k_key: usize = 13;
+let out_base2k: usize = 17;
+let key_base2k: usize = 13;
let k_out: usize = 102;
-let max_dsize: usize = k_out.div_ceil(base2k_key);
+let max_dsize: usize = k_out.div_ceil(key_base2k);
for rank in 1_usize..3 {
for dsize in 1..max_dsize + 1 {
-let k_ggsw: usize = k_out + base2k_key * dsize;
+let k_ggsw: usize = k_out + key_base2k * dsize;
let n: usize = module.n();
-let dnum: usize = k_out.div_ceil(base2k_out * max_dsize);
+let dnum: usize = k_out.div_ceil(out_base2k * max_dsize);
let glwe_out_infos: GLWELayout = GLWELayout {
n: n.into(),
-base2k: base2k_out.into(),
+base2k: out_base2k.into(),
k: k_out.into(),
rank: rank.into(),
};
let ggsw_apply_infos: GGSWLayout = GGSWLayout {
n: n.into(),
-base2k: base2k_key.into(),
+base2k: key_base2k.into(),
k: k_ggsw.into(),
dnum: dnum.into(),
dsize: dsize.into(),
@@ -208,7 +195,7 @@ where
let mut source_xa: Source = Source::new([0u8; 32]);
// Random input plaintext
-module.vec_znx_fill_uniform(base2k_out, &mut pt_want.data, 0, &mut source_xa);
+module.vec_znx_fill_uniform(out_base2k, &mut pt_want.data, 0, &mut source_xa);
pt_want.data.at_mut(0, 0)[1] = 1;
@@ -262,7 +249,7 @@ where
let max_noise: f64 = noise_ggsw_product(
n as f64,
-base2k_key * max_dsize,
+key_base2k * max_dsize,
0.5,
var_msg,
var_a0_err,
@@ -274,13 +261,7 @@ where
k_ggsw,
);
-assert!(
-    glwe_out
-        .noise(module, &pt_want, &sk_prepared, scratch.borrow())
-        .std()
-        .log2()
-        <= max_noise + 1.0
-)
+assert!(glwe_out.noise(module, &pt_want, &sk_prepared, scratch.borrow()).std().log2() <= max_noise + 1.0)
}
}
}

View File

@@ -33,26 +33,26 @@ where
let mut source_xa: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]);
let n: usize = module.n(); let n: usize = module.n();
let base2k_out: usize = 15; let out_base2k: usize = 15;
let base2k_key: usize = 10; let key_base2k: usize = 10;
let k_ct: usize = 36; let k_ct: usize = 36;
let pt_k: usize = 18; let pt_k: usize = 18;
let rank: usize = 3; let rank: usize = 3;
let dsize: usize = 1; let dsize: usize = 1;
let k_ksk: usize = k_ct + base2k_key * dsize; let k_ksk: usize = k_ct + key_base2k * dsize;
let dnum: usize = k_ct.div_ceil(base2k_key * dsize); let dnum: usize = k_ct.div_ceil(key_base2k * dsize);
let glwe_out_infos: GLWELayout = GLWELayout { let glwe_out_infos: GLWELayout = GLWELayout {
n: n.into(), n: n.into(),
base2k: base2k_out.into(), base2k: out_base2k.into(),
k: k_ct.into(), k: k_ct.into(),
rank: rank.into(), rank: rank.into(),
}; };
let key_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { let key_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout {
n: n.into(), n: n.into(),
base2k: base2k_key.into(), base2k: key_base2k.into(),
k: k_ksk.into(), k: k_ksk.into(),
rank: rank.into(), rank: rank.into(),
dsize: dsize.into(), dsize: dsize.into(),
@@ -84,14 +84,7 @@ where
let mut auto_keys: HashMap<i64, GLWEAutomorphismKeyPrepared<Vec<u8>, BE>> = HashMap::new(); let mut auto_keys: HashMap<i64, GLWEAutomorphismKeyPrepared<Vec<u8>, BE>> = HashMap::new();
let mut tmp: GLWEAutomorphismKey<Vec<u8>> = GLWEAutomorphismKey::alloc_from_infos(&key_infos); let mut tmp: GLWEAutomorphismKey<Vec<u8>> = GLWEAutomorphismKey::alloc_from_infos(&key_infos);
gal_els.iter().for_each(|gal_el| { gal_els.iter().for_each(|gal_el| {
tmp.encrypt_sk( tmp.encrypt_sk(module, *gal_el, &sk, &mut source_xa, &mut source_xe, scratch.borrow());
module,
*gal_el,
&sk,
&mut source_xa,
&mut source_xe,
scratch.borrow(),
);
let mut atk_prepared: GLWEAutomorphismKeyPrepared<Vec<u8>, BE> = let mut atk_prepared: GLWEAutomorphismKeyPrepared<Vec<u8>, BE> =
GLWEAutomorphismKeyPrepared::alloc_from_infos(module, &tmp); GLWEAutomorphismKeyPrepared::alloc_from_infos(module, &tmp);
atk_prepared.prepare(module, &tmp, scratch.borrow()); atk_prepared.prepare(module, &tmp, scratch.borrow());
@@ -104,26 +97,12 @@ where
let mut ct: GLWE<Vec<u8>> = GLWE::alloc_from_infos(&glwe_out_infos); let mut ct: GLWE<Vec<u8>> = GLWE::alloc_from_infos(&glwe_out_infos);
ct.encrypt_sk( ct.encrypt_sk(module, &pt, &sk_dft, &mut source_xa, &mut source_xe, scratch.borrow());
module,
&pt,
&sk_dft,
&mut source_xa,
&mut source_xe,
scratch.borrow(),
);
let log_n: usize = module.log_n(); let log_n: usize = module.log_n();
(0..n >> log_batch).for_each(|i| { (0..n >> log_batch).for_each(|i| {
ct.encrypt_sk( ct.encrypt_sk(module, &pt, &sk_dft, &mut source_xa, &mut source_xe, scratch.borrow());
module,
&pt,
&sk_dft,
&mut source_xa,
&mut source_xe,
scratch.borrow(),
);
module.glwe_rotate_inplace(-(1 << log_batch), &mut pt, scratch.borrow()); // X^-batch * pt module.glwe_rotate_inplace(-(1 << log_batch), &mut pt, scratch.borrow()); // X^-batch * pt
@@ -153,10 +132,7 @@ where
let noise_have: f64 = pt.stats().std().log2(); let noise_have: f64 = pt.stats().std().log2();
assert!( assert!(noise_have < -((k_ct - out_base2k) as f64), "noise: {noise_have}");
noise_have < -((k_ct - base2k_out) as f64),
"noise: {noise_have}"
);
} }
#[inline(always)] #[inline(always)]

View File

@@ -35,26 +35,26 @@ where
let mut source_xa: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]);
let n: usize = module.n(); let n: usize = module.n();
let base2k_out: usize = 15; let out_base2k: usize = 15;
let base2k_key: usize = 10; let key_base2k: usize = 10;
let k_ct: usize = 36; let k_ct: usize = 36;
let pt_k: usize = base2k_out; let pt_k: usize = out_base2k;
let rank: usize = 3; let rank: usize = 3;
let dsize: usize = 1; let dsize: usize = 1;
let k_ksk: usize = k_ct + base2k_key * dsize; let k_ksk: usize = k_ct + key_base2k * dsize;
let dnum: usize = k_ct.div_ceil(base2k_key * dsize); let dnum: usize = k_ct.div_ceil(key_base2k * dsize);
let glwe_out_infos: GLWELayout = GLWELayout { let glwe_out_infos: GLWELayout = GLWELayout {
n: n.into(), n: n.into(),
base2k: base2k_out.into(), base2k: out_base2k.into(),
k: k_ct.into(), k: k_ct.into(),
rank: rank.into(), rank: rank.into(),
}; };
let key_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { let key_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout {
n: n.into(), n: n.into(),
base2k: base2k_key.into(), base2k: key_base2k.into(),
k: k_ksk.into(), k: k_ksk.into(),
rank: rank.into(), rank: rank.into(),
dsize: dsize.into(), dsize: dsize.into(),
@@ -63,9 +63,7 @@ where
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc( let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(
GLWE::encrypt_sk_tmp_bytes(module, &glwe_out_infos) GLWE::encrypt_sk_tmp_bytes(module, &glwe_out_infos)
.max(GLWEAutomorphismKey::encrypt_sk_tmp_bytes( .max(GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &key_infos))
module, &key_infos,
))
.max(module.glwe_pack_tmp_bytes(&glwe_out_infos, &key_infos)), .max(module.glwe_pack_tmp_bytes(&glwe_out_infos, &key_infos)),
); );
@@ -88,14 +86,7 @@ where
let mut auto_keys: HashMap<i64, GLWEAutomorphismKeyPrepared<Vec<u8>, BE>> = HashMap::new(); let mut auto_keys: HashMap<i64, GLWEAutomorphismKeyPrepared<Vec<u8>, BE>> = HashMap::new();
let mut tmp: GLWEAutomorphismKey<Vec<u8>> = GLWEAutomorphismKey::alloc_from_infos(&key_infos); let mut tmp: GLWEAutomorphismKey<Vec<u8>> = GLWEAutomorphismKey::alloc_from_infos(&key_infos);
gal_els.iter().for_each(|gal_el| { gal_els.iter().for_each(|gal_el| {
tmp.encrypt_sk( tmp.encrypt_sk(module, *gal_el, &sk, &mut source_xa, &mut source_xe, scratch.borrow());
module,
*gal_el,
&sk,
&mut source_xa,
&mut source_xe,
scratch.borrow(),
);
let mut atk_prepared: GLWEAutomorphismKeyPrepared<Vec<u8>, BE> = let mut atk_prepared: GLWEAutomorphismKeyPrepared<Vec<u8>, BE> =
GLWEAutomorphismKeyPrepared::alloc_from_infos(module, &tmp); GLWEAutomorphismKeyPrepared::alloc_from_infos(module, &tmp);
atk_prepared.prepare(module, &tmp, scratch.borrow()); atk_prepared.prepare(module, &tmp, scratch.borrow());
@@ -106,14 +97,7 @@ where
.step_by(5) .step_by(5)
.map(|_| { .map(|_| {
let mut ct = GLWE::alloc_from_infos(&glwe_out_infos); let mut ct = GLWE::alloc_from_infos(&glwe_out_infos);
ct.encrypt_sk( ct.encrypt_sk(module, &pt, &sk_prep, &mut source_xa, &mut source_xe, scratch.borrow());
module,
&pt,
&sk_prep,
&mut source_xa,
&mut source_xe,
scratch.borrow(),
);
module.glwe_rotate_inplace(-5, &mut pt, scratch.borrow()); // X^-batch * pt module.glwe_rotate_inplace(-5, &mut pt, scratch.borrow()); // X^-batch * pt
ct ct
}) })
@@ -139,10 +123,5 @@ where
pt_want.encode_vec_i64(&data, pt_k.into()); pt_want.encode_vec_i64(&data, pt_k.into());
assert!( assert!(res.noise(module, &pt_want, &sk_prep, scratch.borrow()).std().log2() <= ((k_ct - out_base2k) as f64));
res.noise(module, &pt_want, &sk_prep, scratch.borrow())
.std()
.log2()
<= ((k_ct - base2k_out) as f64)
);
} }

View File

@@ -0,0 +1,390 @@
use poulpy_hal::{
api::{
ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxCopy, VecZnxFillUniform, VecZnxNormalize,
VecZnxNormalizeInplace,
},
layouts::{Backend, FillUniform, Module, Scratch, ScratchOwned, VecZnx, ZnxViewMut},
source::Source,
test_suite::convolution::bivariate_convolution_naive,
};
use rand::RngCore;
use std::f64::consts::SQRT_2;
use crate::{
GLWEDecrypt, GLWEEncryptSk, GLWEMulConst, GLWEMulPlain, GLWESub, GLWETensorKeyEncryptSk, GLWETensoring, ScratchTakeCore,
layouts::{
Dsize, GLWE, GLWELayout, GLWEPlaintext, GLWESecret, GLWESecretPreparedFactory, GLWESecretTensor, GLWESecretTensorFactory,
GLWESecretTensorPrepared, GLWETensor, GLWETensorKey, GLWETensorKeyLayout, GLWETensorKeyPrepared,
GLWETensorKeyPreparedFactory, LWEInfos, TorusPrecision, prepared::GLWESecretPrepared,
},
};
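/// Encrypts the same small random plaintext twice, tensors the two ciphertexts for
/// every normalization offset in `0..2 * in_base2k`, and checks both the raw tensor
/// decryption and its relinearization against a naive bivariate-convolution reference.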
pub fn test_glwe_tensoring<BE: Backend>(module: &Module<BE>)
where
Module<BE>: GLWETensoring<BE>
+ GLWEEncryptSk<BE>
+ GLWEDecrypt<BE>
+ VecZnxFillUniform
+ GLWESecretPreparedFactory<BE>
+ GLWESub
+ VecZnxNormalizeInplace<BE>
+ GLWESecretTensorFactory<BE>
+ VecZnxCopy
+ VecZnxNormalize<BE>
+ GLWETensorKeyEncryptSk<BE>
+ GLWETensorKeyPreparedFactory<BE>,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{
let in_base2k: usize = 16;
let out_base2k: usize = 13;
let tsk_base2k: usize = 15;
let k: usize = 128;
for rank in 1_usize..=3 {
let n: usize = module.n();
let glwe_in_infos: GLWELayout = GLWELayout {
n: n.into(),
base2k: in_base2k.into(),
k: k.into(),
rank: rank.into(),
};
let glwe_out_infos: GLWELayout = GLWELayout {
n: n.into(),
base2k: out_base2k.into(),
k: k.into(),
rank: rank.into(),
};
let tsk_infos: GLWETensorKeyLayout = GLWETensorKeyLayout {
n: n.into(),
base2k: tsk_base2k.into(),
k: (k + tsk_base2k).into(),
rank: rank.into(),
dnum: k.div_ceil(tsk_base2k).into(),
dsize: Dsize(1),
};
let mut a: GLWE<Vec<u8>> = GLWE::alloc_from_infos(&glwe_in_infos);
let mut b: GLWE<Vec<u8>> = GLWE::alloc_from_infos(&glwe_in_infos);
let mut res_tensor: GLWETensor<Vec<u8>> = GLWETensor::alloc_from_infos(&glwe_out_infos);
let mut res_relin: GLWE<Vec<u8>> = GLWE::alloc_from_infos(&glwe_out_infos);
let mut pt_in: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc_from_infos(&glwe_in_infos);
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc_from_infos(&glwe_out_infos);
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc_from_infos(&glwe_out_infos);
let mut pt_tmp: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc_from_infos(&glwe_out_infos);
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(
GLWE::encrypt_sk_tmp_bytes(module, &glwe_in_infos)
.max(GLWE::decrypt_tmp_bytes(module, &glwe_out_infos))
.max(module.glwe_tensor_apply_tmp_bytes(&res_tensor, 0, &a, &b))
.max(module.glwe_secret_tensor_prepare_tmp_bytes(rank.into()))
.max(module.glwe_tensor_relinearize_tmp_bytes(&res_relin, &res_tensor, &tsk_infos)),
);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
let mut sk: GLWESecret<Vec<u8>> = GLWESecret::alloc(module.n().into(), rank.into());
sk.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_dft: GLWESecretPrepared<Vec<u8>, BE> = GLWESecretPrepared::alloc_from_infos(module, &sk);
sk_dft.prepare(module, &sk);
let mut sk_tensor: GLWESecretTensor<Vec<u8>> = GLWESecretTensor::alloc(module.n().into(), rank.into());
sk_tensor.prepare(module, &sk, scratch.borrow());
let mut sk_tensor_prep: GLWESecretTensorPrepared<Vec<u8>, BE> = GLWESecretTensorPrepared::alloc(module, rank.into());
sk_tensor_prep.prepare(module, &sk_tensor);
let mut tsk: GLWETensorKey<Vec<u8>> = GLWETensorKey::alloc_from_infos(&tsk_infos);
tsk.encrypt_sk(module, &sk, &mut source_xa, &mut source_xe, scratch.borrow());
let mut tsk_prep: GLWETensorKeyPrepared<Vec<u8>, BE> = GLWETensorKeyPrepared::alloc_from_infos(module, &tsk_infos);
tsk_prep.prepare(module, &tsk, scratch.borrow());
let scale: usize = 2 * in_base2k;
let mut data = vec![0i64; n];
for i in data.iter_mut() {
*i = (source_xa.next_i64() & 7) - 4;
}
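// The loop above draws each coefficient uniformly from [-4, 3]; keeping the message
// this small presumably ensures its square stays well within the 2 * in_base2k bits
// of torus precision that `scale` reserves for it below.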
pt_in.encode_vec_i64(&data, TorusPrecision(scale as u32));
let mut pt_want_base2k_in = VecZnx::alloc(n, 1, pt_in.size());
bivariate_convolution_naive(
module,
in_base2k,
2,
&mut pt_want_base2k_in,
0,
pt_in.data(),
0,
pt_in.data(),
0,
scratch.borrow(),
);
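// pt_want_base2k_in now holds the plaintext product pt_in * pt_in, computed by the
// naive (X, Y) bivariate convolution over base-2^in_base2k limbs; every encrypted
// result below is compared against a renormalization of this single reference.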
a.encrypt_sk(module, &pt_in, &sk_dft, &mut source_xa, &mut source_xe, scratch.borrow());
b.encrypt_sk(module, &pt_in, &sk_dft, &mut source_xa, &mut source_xe, scratch.borrow());
for res_offset in 0..scale {
module.glwe_tensor_apply(&mut res_tensor, scale + res_offset, &a, &b, scratch.borrow());
res_tensor.decrypt(module, &mut pt_have, &sk_dft, &sk_tensor_prep, scratch.borrow());
module.vec_znx_normalize(
pt_want.data_mut(),
out_base2k,
res_offset as i64,
0,
&pt_want_base2k_in,
in_base2k,
0,
scratch.borrow(),
);
module.glwe_sub(&mut pt_tmp, &pt_have, &pt_want);
module.vec_znx_normalize_inplace(pt_tmp.base2k().as_usize(), &mut pt_tmp.data, 0, scratch.borrow());
let noise_have: f64 = pt_tmp.stats().std().log2();
let noise_want = -((k - scale - res_offset - module.log_n()) as f64 - ((rank - 1) as f64) / SQRT_2);
assert!(noise_have - noise_want <= 0.5, "{} > {}", noise_have, noise_want);
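// Illustrative magnitude check (log_n is hypothetical here, not fixed by the suite):
// with k = 128, scale = 32, res_offset = 0, log_n = 11 and rank = 2 the bound is
// -(128 - 32 - 0 - 11 - 1/sqrt(2)) ~ -84.3, i.e. the tensoring error is expected to
// sit roughly 84 bits below the torus.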
module.glwe_tensor_relinearize(&mut res_relin, &res_tensor, &tsk_prep, tsk_prep.size(), scratch.borrow());
res_relin.decrypt(module, &mut pt_have, &sk_dft, scratch.borrow());
module.glwe_sub(&mut pt_tmp, &pt_have, &pt_want);
module.vec_znx_normalize_inplace(pt_tmp.base2k().as_usize(), &mut pt_tmp.data, 0, scratch.borrow());
// We can reuse the same noise bound because the relinearization noise (which is additive)
// is much smaller than the tensoring noise (which is multiplicative)
let noise_have: f64 = pt_tmp.stats().std().log2();
assert!(noise_have - noise_want <= 0.5, "{} > {}", noise_have, noise_want);
}
}
}
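/// Same harness as `test_glwe_tensoring`, but multiplies a ciphertext by a uniform
/// plaintext polynomial through `glwe_mul_plain` and compares against the naive
/// bivariate convolution of the two plaintexts.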
pub fn test_glwe_mul_plain<BE: Backend>(module: &Module<BE>)
where
Module<BE>: GLWEEncryptSk<BE>
+ GLWEDecrypt<BE>
+ VecZnxFillUniform
+ GLWESecretPreparedFactory<BE>
+ GLWESub
+ VecZnxNormalizeInplace<BE>
+ VecZnxCopy
+ VecZnxNormalize<BE>
+ GLWEMulPlain<BE>,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{
let in_base2k: usize = 16;
let out_base2k: usize = 13;
let k: usize = 128;
for rank in 1_usize..=3 {
let n: usize = module.n();
let glwe_in_infos: GLWELayout = GLWELayout {
n: n.into(),
base2k: in_base2k.into(),
k: k.into(),
rank: rank.into(),
};
let glwe_out_infos: GLWELayout = GLWELayout {
n: n.into(),
base2k: out_base2k.into(),
k: k.into(),
rank: rank.into(),
};
let mut a: GLWE<Vec<u8>> = GLWE::alloc_from_infos(&glwe_in_infos);
let mut res: GLWE<Vec<u8>> = GLWE::alloc_from_infos(&glwe_out_infos);
let mut pt_a: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc_from_infos(&glwe_in_infos);
let mut pt_b: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(module.n().into(), in_base2k.into(), (2 * in_base2k).into());
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc_from_infos(&glwe_out_infos);
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc_from_infos(&glwe_out_infos);
let mut pt_tmp: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc_from_infos(&glwe_out_infos);
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(
GLWE::encrypt_sk_tmp_bytes(module, &glwe_in_infos)
.max(GLWE::decrypt_tmp_bytes(module, &glwe_out_infos))
.max(module.glwe_mul_plain_tmp_bytes(&res, 0, &a, &pt_b)),
);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
let mut sk: GLWESecret<Vec<u8>> = GLWESecret::alloc(module.n().into(), rank.into());
sk.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_dft: GLWESecretPrepared<Vec<u8>, BE> = GLWESecretPrepared::alloc_from_infos(module, &sk);
sk_dft.prepare(module, &sk);
let scale: usize = 2 * in_base2k;
pt_b.data_mut().fill_uniform(in_base2k, &mut source_xa);
pt_a.data_mut().fill_uniform(in_base2k, &mut source_xa);
let mut pt_want_base2k_in = VecZnx::alloc(n, 1, pt_a.size() + pt_b.size());
bivariate_convolution_naive(
module,
in_base2k,
2,
&mut pt_want_base2k_in,
0,
pt_a.data(),
0,
pt_b.data(),
0,
scratch.borrow(),
);
a.encrypt_sk(module, &pt_a, &sk_dft, &mut source_xa, &mut source_xe, scratch.borrow());
for res_offset in 0..scale {
module.glwe_mul_plain(&mut res, scale + res_offset, &a, &pt_b, scratch.borrow());
res.decrypt(module, &mut pt_have, &sk_dft, scratch.borrow());
module.vec_znx_normalize(
pt_want.data_mut(),
out_base2k,
res_offset as i64,
0,
&pt_want_base2k_in,
in_base2k,
0,
scratch.borrow(),
);
module.glwe_sub(&mut pt_tmp, &pt_have, &pt_want);
module.vec_znx_normalize_inplace(pt_tmp.base2k().as_usize(), &mut pt_tmp.data, 0, scratch.borrow());
let noise_have: f64 = pt_tmp.stats().std().log2();
let noise_want = -((k - scale - res_offset - module.log_n()) as f64 - ((rank - 1) as f64) / SQRT_2);
assert!(noise_have - noise_want <= 0.5, "{} > {}", noise_have, noise_want);
}
}
}
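/// Variant of `test_glwe_mul_plain` where the right operand is a constant: only the
/// constant coefficient of its first limb is populated, so `glwe_mul_const`
/// effectively scales the ciphertext by a single small integer.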
pub fn test_glwe_mul_const<BE: Backend>(module: &Module<BE>)
where
Module<BE>: GLWEEncryptSk<BE>
+ GLWEDecrypt<BE>
+ VecZnxFillUniform
+ GLWESecretPreparedFactory<BE>
+ GLWESub
+ VecZnxNormalizeInplace<BE>
+ VecZnxCopy
+ VecZnxNormalize<BE>
+ GLWEMulConst<BE>,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{
let in_base2k: usize = 16;
let out_base2k: usize = 13;
let k: usize = 128;
let b_size: usize = 3;
for rank in 1_usize..=3 {
let n: usize = module.n();
let glwe_in_infos: GLWELayout = GLWELayout {
n: n.into(),
base2k: in_base2k.into(),
k: k.into(),
rank: rank.into(),
};
let glwe_out_infos: GLWELayout = GLWELayout {
n: n.into(),
base2k: out_base2k.into(),
k: k.into(),
rank: rank.into(),
};
let mut a: GLWE<Vec<u8>> = GLWE::alloc_from_infos(&glwe_in_infos);
let mut res: GLWE<Vec<u8>> = GLWE::alloc_from_infos(&glwe_out_infos);
let mut pt_a: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc_from_infos(&glwe_in_infos);
let mut pt_b: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(module.n().into(), in_base2k.into(), (2 * in_base2k).into());
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc_from_infos(&glwe_out_infos);
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc_from_infos(&glwe_out_infos);
let mut pt_tmp: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc_from_infos(&glwe_out_infos);
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(
GLWE::encrypt_sk_tmp_bytes(module, &glwe_in_infos)
.max(GLWE::decrypt_tmp_bytes(module, &glwe_out_infos))
.max(module.glwe_mul_const_tmp_bytes(&res, 0, &a, b_size)),
);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
let mut sk: GLWESecret<Vec<u8>> = GLWESecret::alloc(module.n().into(), rank.into());
sk.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_dft: GLWESecretPrepared<Vec<u8>, BE> = GLWESecretPrepared::alloc_from_infos(module, &sk);
sk_dft.prepare(module, &sk);
let scale: usize = 2 * in_base2k;
pt_a.data_mut().fill_uniform(in_base2k, &mut source_xa);
let mut b_const = vec![0i64; b_size];
let mask = (1 << in_base2k) - 1;
for (j, x) in b_const[..1].iter_mut().enumerate() {
let r = source_xa.next_u64() & mask;
*x = ((r << (64 - in_base2k)) as i64) >> (64 - in_base2k);
pt_b.data_mut().at_mut(0, j)[0] = *x
}
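// The two shifts above sign-extend the low in_base2k bits of r into an i64; for
// example (illustrative), with in_base2k = 16 and r = 0x8001, ((r << 48) as i64) >> 48
// yields -32767. Only b_const[0] is populated: the remaining b_size - 1 entries stay
// zero, matching the single coefficient written into pt_b.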
let mut pt_want_base2k_in = VecZnx::alloc(n, 1, pt_a.size() + pt_b.size());
bivariate_convolution_naive(
module,
in_base2k,
2,
&mut pt_want_base2k_in,
0,
pt_a.data(),
0,
pt_b.data(),
0,
scratch.borrow(),
);
a.encrypt_sk(module, &pt_a, &sk_dft, &mut source_xa, &mut source_xe, scratch.borrow());
for res_offset in 0..scale {
module.glwe_mul_const(&mut res, scale + res_offset, &a, &b_const, scratch.borrow());
res.decrypt(module, &mut pt_have, &sk_dft, scratch.borrow());
module.vec_znx_normalize(
pt_want.data_mut(),
out_base2k,
res_offset as i64,
0,
&pt_want_base2k_in,
in_base2k,
0,
scratch.borrow(),
);
module.glwe_sub(&mut pt_tmp, &pt_have, &pt_want);
module.vec_znx_normalize_inplace(pt_tmp.base2k().as_usize(), &mut pt_tmp.data, 0, scratch.borrow());
let noise_have: f64 = pt_tmp.stats().std().log2();
let noise_want = -((k - scale - res_offset - module.log_n()) as f64 - ((rank - 1) as f64) / SQRT_2);
assert!(noise_have - noise_want <= 0.5, "{} > {}", noise_have, noise_want);
}
}
}

View File

@@ -26,27 +26,27 @@ where
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>, ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>, Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{ {
let base2k_in: usize = 17; let in_base2k: usize = 17;
let base2k_key: usize = 13; let key_base2k: usize = 13;
let base2k_out: usize = base2k_in; // MUST BE SAME let out_base2k: usize = in_base2k; // MUST BE SAME
let k_in: usize = 102; let k_in: usize = 102;
let max_dsize: usize = k_in.div_ceil(base2k_key); let max_dsize: usize = k_in.div_ceil(key_base2k);
for rank_in_s0s1 in 1_usize..2 { for rank_in_s0s1 in 1_usize..2 {
for rank_out_s0s1 in 1_usize..3 { for rank_out_s0s1 in 1_usize..3 {
for rank_out_s1s2 in 1_usize..3 { for rank_out_s1s2 in 1_usize..3 {
for dsize in 1_usize..max_dsize + 1 { for dsize in 1_usize..max_dsize + 1 {
let k_ksk: usize = k_in + base2k_key * dsize; let k_ksk: usize = k_in + key_base2k * dsize;
let k_out: usize = k_ksk; // To better capture the noise. let k_out: usize = k_ksk; // To better capture the noise.
let n: usize = module.n(); let n: usize = module.n();
let dsize_in: usize = 1; let dsize_in: usize = 1;
let dnum_in: usize = k_in / base2k_in; let dnum_in: usize = k_in / in_base2k;
let dnum_ksk: usize = k_in.div_ceil(base2k_key * dsize); let dnum_ksk: usize = k_in.div_ceil(key_base2k * dsize);
let gglwe_s0s1_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { let gglwe_s0s1_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout {
n: n.into(), n: n.into(),
base2k: base2k_in.into(), base2k: in_base2k.into(),
k: k_in.into(), k: k_in.into(),
dnum: dnum_in.into(), dnum: dnum_in.into(),
dsize: dsize_in.into(), dsize: dsize_in.into(),
@@ -56,7 +56,7 @@ where
let gglwe_s1s2_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { let gglwe_s1s2_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout {
n: n.into(), n: n.into(),
base2k: base2k_key.into(), base2k: key_base2k.into(),
k: k_ksk.into(), k: k_ksk.into(),
dnum: dnum_ksk.into(), dnum: dnum_ksk.into(),
dsize: dsize.into(), dsize: dsize.into(),
@@ -66,7 +66,7 @@ where
let gglwe_s0s2_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { let gglwe_s0s2_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout {
n: n.into(), n: n.into(),
base2k: base2k_out.into(), base2k: out_base2k.into(),
k: k_out.into(), k: k_out.into(),
dnum: dnum_in.into(), dnum: dnum_in.into(),
dsize: dsize_in.into(), dsize: dsize_in.into(),
@@ -108,43 +108,24 @@ where
sk2_prepared.prepare(module, &sk2); sk2_prepared.prepare(module, &sk2);
// gglwe_{s1}(s0) = s0 -> s1 // gglwe_{s1}(s0) = s0 -> s1
gglwe_s0s1.encrypt_sk( gglwe_s0s1.encrypt_sk(module, &sk0, &sk1, &mut source_xa, &mut source_xe, scratch_enc.borrow());
module,
&sk0,
&sk1,
&mut source_xa,
&mut source_xe,
scratch_enc.borrow(),
);
// gglwe_{s2}(s1) = s1 -> s2 // gglwe_{s2}(s1) = s1 -> s2
gglwe_s1s2.encrypt_sk( gglwe_s1s2.encrypt_sk(module, &sk1, &sk2, &mut source_xa, &mut source_xe, scratch_enc.borrow());
module,
&sk1,
&sk2,
&mut source_xa,
&mut source_xe,
scratch_enc.borrow(),
);
let mut gglwe_s1s2_prepared: GLWESwitchingKeyPrepared<Vec<u8>, BE> = let mut gglwe_s1s2_prepared: GLWESwitchingKeyPrepared<Vec<u8>, BE> =
GLWESwitchingKeyPrepared::alloc_from_infos(module, &gglwe_s1s2); GLWESwitchingKeyPrepared::alloc_from_infos(module, &gglwe_s1s2);
gglwe_s1s2_prepared.prepare(module, &gglwe_s1s2, scratch_apply.borrow()); gglwe_s1s2_prepared.prepare(module, &gglwe_s1s2, scratch_apply.borrow());
// gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0)
gglwe_s0s2.keyswitch( gglwe_s0s2.keyswitch(module, &gglwe_s0s1, &gglwe_s1s2_prepared, scratch_apply.borrow());
module,
&gglwe_s0s1,
&gglwe_s1s2_prepared,
scratch_apply.borrow(),
);
let max_noise: f64 = var_noise_gglwe_product_v2( let max_noise: f64 = var_noise_gglwe_product_v2(
module.n() as f64, module.n() as f64,
k_ksk, k_ksk,
dnum_ksk, dnum_ksk,
dsize, dsize,
base2k_key, key_base2k,
0.5, 0.5,
0.5, 0.5,
0f64, 0f64,
@@ -157,21 +138,12 @@ where
for row in 0..gglwe_s0s2.dnum().as_usize() { for row in 0..gglwe_s0s2.dnum().as_usize() {
for col in 0..gglwe_s0s2.rank_in().as_usize() { for col in 0..gglwe_s0s2.rank_in().as_usize() {
assert!( let noise: f64 = gglwe_s0s2
gglwe_s0s2
.key .key
.noise( .noise(module, row, col, &sk0.data, &sk2_prepared, scratch_apply.borrow())
module,
row,
col,
&sk0.data,
&sk2_prepared,
scratch_apply.borrow()
)
.std() .std()
.log2() .log2();
<= max_noise + 0.5 assert!(noise <= max_noise + 0.5, "{noise} > {max_noise}",)
)
} }
} }
} }
@@ -191,25 +163,25 @@ where
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>, ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>, Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{ {
let base2k_out: usize = 17; let out_base2k: usize = 17;
let base2k_key: usize = 13; let key_base2k: usize = 13;
let k_out: usize = 102; let k_out: usize = 102;
let max_dsize: usize = k_out.div_ceil(base2k_key); let max_dsize: usize = k_out.div_ceil(key_base2k);
for rank_in in 1_usize..3 { for rank_in in 1_usize..3 {
for rank_out in 1_usize..3 { for rank_out in 1_usize..3 {
for dsize in 1_usize..max_dsize + 1 { for dsize in 1_usize..max_dsize + 1 {
let k_ksk: usize = k_out + base2k_key * dsize; let k_ksk: usize = k_out + key_base2k * dsize;
let n: usize = module.n(); let n: usize = module.n();
let dsize_in: usize = 1; let dsize_in: usize = 1;
let dnum_in: usize = k_out / base2k_out; let dnum_in: usize = k_out / out_base2k;
let dnum_ksk: usize = k_out.div_ceil(base2k_key * dsize); let dnum_ksk: usize = k_out.div_ceil(key_base2k * dsize);
let gglwe_s0s1_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { let gglwe_s0s1_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout {
n: n.into(), n: n.into(),
base2k: base2k_out.into(), base2k: out_base2k.into(),
k: k_out.into(), k: k_out.into(),
dnum: dnum_in.into(), dnum: dnum_in.into(),
dsize: dsize_in.into(), dsize: dsize_in.into(),
@@ -219,7 +191,7 @@ where
let gglwe_s1s2_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { let gglwe_s1s2_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout {
n: n.into(), n: n.into(),
base2k: base2k_key.into(), base2k: key_base2k.into(),
k: k_ksk.into(), k: k_ksk.into(),
dnum: dnum_ksk.into(), dnum: dnum_ksk.into(),
dsize: dsize.into(), dsize: dsize.into(),
@@ -260,24 +232,10 @@ where
sk2_prepared.prepare(module, &sk2); sk2_prepared.prepare(module, &sk2);
// gglwe_{s1}(s0) = s0 -> s1 // gglwe_{s1}(s0) = s0 -> s1
gglwe_s0s1.encrypt_sk( gglwe_s0s1.encrypt_sk(module, &sk0, &sk1, &mut source_xa, &mut source_xe, scratch_enc.borrow());
module,
&sk0,
&sk1,
&mut source_xa,
&mut source_xe,
scratch_enc.borrow(),
);
// gglwe_{s2}(s1) = s1 -> s2 // gglwe_{s2}(s1) = s1 -> s2
gglwe_s1s2.encrypt_sk( gglwe_s1s2.encrypt_sk(module, &sk1, &sk2, &mut source_xa, &mut source_xe, scratch_enc.borrow());
module,
&sk1,
&sk2,
&mut source_xa,
&mut source_xe,
scratch_enc.borrow(),
);
let mut gglwe_s1s2_prepared: GLWESwitchingKeyPrepared<Vec<u8>, BE> = let mut gglwe_s1s2_prepared: GLWESwitchingKeyPrepared<Vec<u8>, BE> =
GLWESwitchingKeyPrepared::alloc_from_infos(module, &gglwe_s1s2); GLWESwitchingKeyPrepared::alloc_from_infos(module, &gglwe_s1s2);
@@ -290,7 +248,7 @@ where
let max_noise: f64 = log2_std_noise_gglwe_product( let max_noise: f64 = log2_std_noise_gglwe_product(
n as f64, n as f64,
base2k_key * dsize, key_base2k * dsize,
var_xs, var_xs,
var_xs, var_xs,
0f64, 0f64,
@@ -303,21 +261,12 @@ where
for row in 0..gglwe_s0s2.dnum().as_usize() { for row in 0..gglwe_s0s2.dnum().as_usize() {
for col in 0..gglwe_s0s2.rank_in().as_usize() { for col in 0..gglwe_s0s2.rank_in().as_usize() {
assert!( let noise = gglwe_s0s2
gglwe_s0s2
.key .key
.noise( .noise(module, row, col, &sk0.data, &sk2_prepared, scratch_apply.borrow())
module,
row,
col,
&sk0.data,
&sk2_prepared,
scratch_apply.borrow()
)
.std() .std()
.log2() .log2();
<= max_noise + 0.5 assert!(noise <= max_noise + 0.5, "{noise} > {max_noise}")
)
} }
} }
} }

View File

@@ -8,9 +8,9 @@ use crate::{
GGLWEToGGSWKeyEncryptSk, GGSWEncryptSk, GGSWKeyswitch, GGSWNoise, GLWESwitchingKeyEncryptSk, ScratchTakeCore, GGLWEToGGSWKeyEncryptSk, GGSWEncryptSk, GGSWKeyswitch, GGSWNoise, GLWESwitchingKeyEncryptSk, ScratchTakeCore,
encryption::SIGMA, encryption::SIGMA,
layouts::{ layouts::{
GGLWEToGGSWKey, GGLWEToGGSWKeyPrepared, GGLWEToGGSWKeyPreparedFactory, GGSW, GGSWInfos, GGSWLayout, GLWEInfos, GGLWEToGGSWKey, GGLWEToGGSWKeyLayout, GGLWEToGGSWKeyPrepared, GGLWEToGGSWKeyPreparedFactory, GGSW, GGSWInfos, GGSWLayout,
GLWESecret, GLWESecretPreparedFactory, GLWESwitchingKey, GLWESwitchingKeyLayout, GLWESwitchingKeyPreparedFactory, GLWEInfos, GLWESecret, GLWESecretPreparedFactory, GLWESwitchingKey, GLWESwitchingKeyLayout,
GLWETensorKeyLayout, GLWESwitchingKeyPreparedFactory,
prepared::{GLWESecretPrepared, GLWESwitchingKeyPrepared}, prepared::{GLWESecretPrepared, GLWESwitchingKeyPrepared},
}, },
noise::noise_ggsw_keyswitch, noise::noise_ggsw_keyswitch,
@@ -30,27 +30,27 @@ where
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>, ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>, Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{ {
let base2k_in: usize = 17; let in_base2k: usize = 17;
let base2k_key: usize = 13; let key_base2k: usize = 13;
let base2k_out: usize = base2k_in; // MUST BE SAME let out_base2k: usize = in_base2k; // MUST BE SAME
let k_in: usize = 102; let k_in: usize = 102;
let max_dsize: usize = k_in.div_ceil(base2k_key); let max_dsize: usize = k_in.div_ceil(key_base2k);
for rank in 1_usize..3 { for rank in 1_usize..3 {
for dsize in 1..max_dsize + 1 { for dsize in 1..max_dsize + 1 {
let k_ksk: usize = k_in + base2k_key * dsize; let k_ksk: usize = k_in + key_base2k * dsize;
let k_tsk: usize = k_ksk; let k_tsk: usize = k_ksk;
let k_out: usize = k_ksk; // To better capture the noise. let k_out: usize = k_ksk; // To better capture the noise.
let n: usize = module.n(); let n: usize = module.n();
let dnum_in: usize = k_in / base2k_in; let dnum_in: usize = k_in / in_base2k;
let dnum_ksk: usize = k_in.div_ceil(base2k_key * dsize); let dnum_ksk: usize = k_in.div_ceil(key_base2k * dsize);
let dsize_in: usize = 1; let dsize_in: usize = 1;
let ggsw_in_infos: GGSWLayout = GGSWLayout { let ggsw_in_infos: GGSWLayout = GGSWLayout {
n: n.into(), n: n.into(),
base2k: base2k_in.into(), base2k: in_base2k.into(),
k: k_in.into(), k: k_in.into(),
dnum: dnum_in.into(), dnum: dnum_in.into(),
dsize: dsize_in.into(), dsize: dsize_in.into(),
@@ -59,16 +59,16 @@ where
let ggsw_out_infos: GGSWLayout = GGSWLayout { let ggsw_out_infos: GGSWLayout = GGSWLayout {
n: n.into(), n: n.into(),
base2k: base2k_out.into(), base2k: out_base2k.into(),
k: k_out.into(), k: k_out.into(),
dnum: dnum_in.into(), dnum: dnum_in.into(),
dsize: dsize_in.into(), dsize: dsize_in.into(),
rank: rank.into(), rank: rank.into(),
}; };
let tsk_infos: GLWETensorKeyLayout = GLWETensorKeyLayout { let tsk_infos: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout {
n: n.into(), n: n.into(),
base2k: base2k_key.into(), base2k: key_base2k.into(),
k: k_tsk.into(), k: k_tsk.into(),
dnum: dnum_ksk.into(), dnum: dnum_ksk.into(),
dsize: dsize.into(), dsize: dsize.into(),
@@ -77,7 +77,7 @@ where
let ksk_apply_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { let ksk_apply_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout {
n: n.into(), n: n.into(),
base2k: base2k_key.into(), base2k: key_base2k.into(),
k: k_ksk.into(), k: k_ksk.into(),
dnum: dnum_ksk.into(), dnum: dnum_ksk.into(),
dsize: dsize.into(), dsize: dsize.into(),
@@ -99,13 +99,7 @@ where
GGSW::encrypt_sk_tmp_bytes(module, &ggsw_in_infos) GGSW::encrypt_sk_tmp_bytes(module, &ggsw_in_infos)
| GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &ksk_apply_infos) | GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &ksk_apply_infos)
| GGLWEToGGSWKey::encrypt_sk_tmp_bytes(module, &tsk_infos) | GGLWEToGGSWKey::encrypt_sk_tmp_bytes(module, &tsk_infos)
| GGSW::keyswitch_tmp_bytes( | GGSW::keyswitch_tmp_bytes(module, &ggsw_out_infos, &ggsw_in_infos, &ksk_apply_infos, &tsk_infos),
module,
&ggsw_out_infos,
&ggsw_in_infos,
&ksk_apply_infos,
&tsk_infos,
),
); );
let var_xs: f64 = 0.5; let var_xs: f64 = 0.5;
@@ -122,21 +116,8 @@ where
let mut sk_out_prepared: GLWESecretPrepared<Vec<u8>, BE> = GLWESecretPrepared::alloc(module, rank.into()); let mut sk_out_prepared: GLWESecretPrepared<Vec<u8>, BE> = GLWESecretPrepared::alloc(module, rank.into());
sk_out_prepared.prepare(module, &sk_out); sk_out_prepared.prepare(module, &sk_out);
ksk.encrypt_sk( ksk.encrypt_sk(module, &sk_in, &sk_out, &mut source_xa, &mut source_xe, scratch.borrow());
module, tsk.encrypt_sk(module, &sk_out, &mut source_xa, &mut source_xe, scratch.borrow());
&sk_in,
&sk_out,
&mut source_xa,
&mut source_xe,
scratch.borrow(),
);
tsk.encrypt_sk(
module,
&sk_out,
&mut source_xa,
&mut source_xe,
scratch.borrow(),
);
pt_scalar.fill_ternary_hw(0, n, &mut source_xs); pt_scalar.fill_ternary_hw(0, n, &mut source_xs);
@@ -156,18 +137,12 @@ where
let mut tsk_prepared: GGLWEToGGSWKeyPrepared<Vec<u8>, BE> = GGLWEToGGSWKeyPrepared::alloc_from_infos(module, &tsk); let mut tsk_prepared: GGLWEToGGSWKeyPrepared<Vec<u8>, BE> = GGLWEToGGSWKeyPrepared::alloc_from_infos(module, &tsk);
tsk_prepared.prepare(module, &tsk, scratch.borrow()); tsk_prepared.prepare(module, &tsk, scratch.borrow());
ggsw_out.keyswitch( ggsw_out.keyswitch(module, &ggsw_in, &ksk_prepared, &tsk_prepared, scratch.borrow());
module,
&ggsw_in,
&ksk_prepared,
&tsk_prepared,
scratch.borrow(),
);
let max_noise = |col_j: usize| -> f64 { let max_noise = |col_j: usize| -> f64 {
noise_ggsw_keyswitch( noise_ggsw_keyswitch(
n as f64, n as f64,
base2k_key * dsize, key_base2k * dsize,
col_j, col_j,
var_xs, var_xs,
0f64, 0f64,
@@ -184,14 +159,7 @@ where
for col in 0..ggsw_out.rank().as_usize() + 1 { for col in 0..ggsw_out.rank().as_usize() + 1 {
assert!( assert!(
ggsw_out ggsw_out
.noise( .noise(module, row, col, &pt_scalar, &sk_out_prepared, scratch.borrow())
module,
row,
col,
&pt_scalar,
&sk_out_prepared,
scratch.borrow()
)
.std() .std()
.log2() .log2()
<= max_noise(col) <= max_noise(col)
@@ -216,33 +184,33 @@ where
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>, ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>, Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{ {
let base2k_out: usize = 17; let out_base2k: usize = 17;
let base2k_key: usize = 13; let key_base2k: usize = 13;
let k_out: usize = 102; let k_out: usize = 102;
let max_dsize: usize = k_out.div_ceil(base2k_key); let max_dsize: usize = k_out.div_ceil(key_base2k);
for rank in 1_usize..3 { for rank in 1_usize..3 {
for dsize in 1..max_dsize + 1 { for dsize in 1..max_dsize + 1 {
let k_ksk: usize = k_out + base2k_key * dsize; let k_ksk: usize = k_out + key_base2k * dsize;
let k_tsk: usize = k_ksk; let k_tsk: usize = k_ksk;
let n: usize = module.n(); let n: usize = module.n();
let dnum_in: usize = k_out / base2k_out; let dnum_in: usize = k_out / out_base2k;
let dnum_ksk: usize = k_out.div_ceil(base2k_key * dsize); let dnum_ksk: usize = k_out.div_ceil(key_base2k * dsize);
let dsize_in: usize = 1; let dsize_in: usize = 1;
let ggsw_out_infos: GGSWLayout = GGSWLayout { let ggsw_out_infos: GGSWLayout = GGSWLayout {
n: n.into(), n: n.into(),
base2k: base2k_out.into(), base2k: out_base2k.into(),
k: k_out.into(), k: k_out.into(),
dnum: dnum_in.into(), dnum: dnum_in.into(),
dsize: dsize_in.into(), dsize: dsize_in.into(),
rank: rank.into(), rank: rank.into(),
}; };
let tsk_infos: GLWETensorKeyLayout = GLWETensorKeyLayout { let tsk_infos: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout {
n: n.into(), n: n.into(),
base2k: base2k_key.into(), base2k: key_base2k.into(),
k: k_tsk.into(), k: k_tsk.into(),
dnum: dnum_ksk.into(), dnum: dnum_ksk.into(),
dsize: dsize.into(), dsize: dsize.into(),
@@ -251,7 +219,7 @@ where
let ksk_apply_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { let ksk_apply_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout {
n: n.into(), n: n.into(),
base2k: base2k_key.into(), base2k: key_base2k.into(),
k: k_ksk.into(), k: k_ksk.into(),
dnum: dnum_ksk.into(), dnum: dnum_ksk.into(),
dsize: dsize.into(), dsize: dsize.into(),
@@ -272,13 +240,7 @@ where
GGSW::encrypt_sk_tmp_bytes(module, &ggsw_out_infos) GGSW::encrypt_sk_tmp_bytes(module, &ggsw_out_infos)
| GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &ksk_apply_infos) | GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &ksk_apply_infos)
| GGLWEToGGSWKey::encrypt_sk_tmp_bytes(module, &tsk_infos) | GGLWEToGGSWKey::encrypt_sk_tmp_bytes(module, &tsk_infos)
| GGSW::keyswitch_tmp_bytes( | GGSW::keyswitch_tmp_bytes(module, &ggsw_out_infos, &ggsw_out_infos, &ksk_apply_infos, &tsk_infos),
module,
&ggsw_out_infos,
&ggsw_out_infos,
&ksk_apply_infos,
&tsk_infos,
),
); );
let var_xs: f64 = 0.5; let var_xs: f64 = 0.5;
@@ -295,21 +257,8 @@ where
let mut sk_out_prepared: GLWESecretPrepared<Vec<u8>, BE> = GLWESecretPrepared::alloc(module, rank.into()); let mut sk_out_prepared: GLWESecretPrepared<Vec<u8>, BE> = GLWESecretPrepared::alloc(module, rank.into());
sk_out_prepared.prepare(module, &sk_out); sk_out_prepared.prepare(module, &sk_out);
ksk.encrypt_sk( ksk.encrypt_sk(module, &sk_in, &sk_out, &mut source_xa, &mut source_xe, scratch.borrow());
module, tsk.encrypt_sk(module, &sk_out, &mut source_xa, &mut source_xe, scratch.borrow());
&sk_in,
&sk_out,
&mut source_xa,
&mut source_xe,
scratch.borrow(),
);
tsk.encrypt_sk(
module,
&sk_out,
&mut source_xa,
&mut source_xe,
scratch.borrow(),
);
pt_scalar.fill_ternary_hw(0, n, &mut source_xs); pt_scalar.fill_ternary_hw(0, n, &mut source_xs);
@@ -334,7 +283,7 @@ where
let max_noise = |col_j: usize| -> f64 { let max_noise = |col_j: usize| -> f64 {
noise_ggsw_keyswitch( noise_ggsw_keyswitch(
n as f64, n as f64,
base2k_key * dsize, key_base2k * dsize,
col_j, col_j,
var_xs, var_xs,
0f64, 0f64,
@@ -351,14 +300,7 @@ where
for col in 0..ggsw_out.rank().as_usize() + 1 { for col in 0..ggsw_out.rank().as_usize() + 1 {
assert!( assert!(
ggsw_out ggsw_out
.noise( .noise(module, row, col, &pt_scalar, &sk_out_prepared, scratch.borrow())
module,
row,
col,
&pt_scalar,
&sk_out_prepared,
scratch.borrow()
)
.std() .std()
.log2() .log2()
<= max_noise(col) <= max_noise(col)

View File

@@ -29,38 +29,38 @@ where
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>, ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>, Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{ {
let base2k_in: usize = 17; let in_base2k: usize = 17;
let base2k_key: usize = 13; let key_base2k: usize = 13;
let base2k_out: usize = 15; let out_base2k: usize = 15;
let k_in: usize = 102; let k_in: usize = 102;
let max_dsize: usize = k_in.div_ceil(base2k_key); let max_dsize: usize = k_in.div_ceil(key_base2k);
for rank_in in 1_usize..3 { for rank_in in 1_usize..3 {
for rank_out in 1_usize..3 { for rank_out in 1_usize..3 {
for dsize in 1_usize..max_dsize + 1 { for dsize in 1_usize..max_dsize + 1 {
let k_ksk: usize = k_in + base2k_key * dsize; let k_ksk: usize = k_in + key_base2k * dsize;
let k_out: usize = k_ksk; // to better capture the noise let k_out: usize = k_ksk; // to better capture the noise
let n: usize = module.n(); let n: usize = module.n();
let dnum: usize = k_in.div_ceil(base2k_key * dsize); let dnum: usize = k_in.div_ceil(key_base2k * dsize);
let glwe_in_infos: GLWELayout = GLWELayout { let glwe_in_infos: GLWELayout = GLWELayout {
n: n.into(), n: n.into(),
base2k: base2k_in.into(), base2k: in_base2k.into(),
k: k_in.into(), k: k_in.into(),
rank: rank_in.into(), rank: rank_in.into(),
}; };
let glwe_out_infos: GLWELayout = GLWELayout { let glwe_out_infos: GLWELayout = GLWELayout {
n: n.into(), n: n.into(),
base2k: base2k_out.into(), base2k: out_base2k.into(),
k: k_out.into(), k: k_out.into(),
rank: rank_out.into(), rank: rank_out.into(),
}; };
let ksk: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { let ksk: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout {
n: n.into(), n: n.into(),
base2k: base2k_key.into(), base2k: key_base2k.into(),
k: k_ksk.into(), k: k_ksk.into(),
dnum: dnum.into(), dnum: dnum.into(),
dsize: dsize.into(), dsize: dsize.into(),
@@ -98,14 +98,7 @@ where
let mut sk_out_prepared: GLWESecretPrepared<Vec<u8>, BE> = GLWESecretPrepared::alloc(module, rank_out.into()); let mut sk_out_prepared: GLWESecretPrepared<Vec<u8>, BE> = GLWESecretPrepared::alloc(module, rank_out.into());
sk_out_prepared.prepare(module, &sk_out); sk_out_prepared.prepare(module, &sk_out);
ksk.encrypt_sk( ksk.encrypt_sk(module, &sk_in, &sk_out, &mut source_xa, &mut source_xe, scratch.borrow());
module,
&sk_in,
&sk_out,
&mut source_xa,
&mut source_xe,
scratch.borrow(),
);
glwe_in.encrypt_sk( glwe_in.encrypt_sk(
module, module,
@@ -127,7 +120,7 @@ where
k_ksk, k_ksk,
dnum, dnum,
dsize, dsize,
base2k_key, key_base2k,
0.5, 0.5,
0.5, 0.5,
0f64, 0f64,
@@ -164,28 +157,28 @@ where
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>, ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>, Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{ {
let base2k_out: usize = 17; let out_base2k: usize = 17;
let base2k_key: usize = 13; let key_base2k: usize = 13;
let k_out: usize = 102; let k_out: usize = 102;
let max_dsize: usize = k_out.div_ceil(base2k_key); let max_dsize: usize = k_out.div_ceil(key_base2k);
for rank in 1_usize..3 { for rank in 1_usize..3 {
for dsize in 1..max_dsize + 1 { for dsize in 1..max_dsize + 1 {
let k_ksk: usize = k_out + base2k_key * dsize; let k_ksk: usize = k_out + key_base2k * dsize;
let n: usize = module.n(); let n: usize = module.n();
let dnum: usize = k_out.div_ceil(base2k_key * dsize); let dnum: usize = k_out.div_ceil(key_base2k * dsize);
let glwe_out_infos: GLWELayout = GLWELayout { let glwe_out_infos: GLWELayout = GLWELayout {
n: n.into(), n: n.into(),
base2k: base2k_out.into(), base2k: out_base2k.into(),
k: k_out.into(), k: k_out.into(),
rank: rank.into(), rank: rank.into(),
}; };
let ksk_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { let ksk_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout {
n: n.into(), n: n.into(),
base2k: base2k_key.into(), base2k: key_base2k.into(),
k: k_ksk.into(), k: k_ksk.into(),
dnum: dnum.into(), dnum: dnum.into(),
dsize: dsize.into(), dsize: dsize.into(),
@@ -201,12 +194,7 @@ where
let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]);
module.vec_znx_fill_uniform( module.vec_znx_fill_uniform(pt_want.base2k().into(), &mut pt_want.data, 0, &mut source_xa);
pt_want.base2k().into(),
&mut pt_want.data,
0,
&mut source_xa,
);
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc( let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(
GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &ksk_infos) GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &ksk_infos)
@@ -226,14 +214,7 @@ where
let mut sk_out_prepared: GLWESecretPrepared<Vec<u8>, BE> = GLWESecretPrepared::alloc(module, rank.into()); let mut sk_out_prepared: GLWESecretPrepared<Vec<u8>, BE> = GLWESecretPrepared::alloc(module, rank.into());
sk_out_prepared.prepare(module, &sk_out); sk_out_prepared.prepare(module, &sk_out);
ksk.encrypt_sk( ksk.encrypt_sk(module, &sk_in, &sk_out, &mut source_xa, &mut source_xe, scratch.borrow());
module,
&sk_in,
&sk_out,
&mut source_xa,
&mut source_xe,
scratch.borrow(),
);
glwe_out.encrypt_sk( glwe_out.encrypt_sk(
module, module,
@@ -255,7 +236,7 @@ where
k_ksk, k_ksk,
dnum, dnum,
dsize, dsize,
base2k_key, key_base2k,
0.5, 0.5,
0.5, 0.5,
0f64, 0f64,

View File

@@ -24,17 +24,17 @@ where
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>, Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{ {
let n: usize = module.n(); let n: usize = module.n();
let base2k_in: usize = 17; let in_base2k: usize = 17;
let base2k_out: usize = 15; let out_base2k: usize = 15;
let base2k_key: usize = 13; let key_base2k: usize = 13;
let n_lwe_in: usize = module.n() >> 1; let n_lwe_in: usize = module.n() >> 1;
let n_lwe_out: usize = module.n() >> 1; let n_lwe_out: usize = module.n() >> 1;
let k_lwe_ct: usize = 102; let k_lwe_ct: usize = 102;
let k_lwe_pt: usize = 8; let k_lwe_pt: usize = 8;
let k_ksk: usize = k_lwe_ct + base2k_key; let k_ksk: usize = k_lwe_ct + key_base2k;
let dnum: usize = k_lwe_ct.div_ceil(base2k_key); let dnum: usize = k_lwe_ct.div_ceil(key_base2k);
let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]);
@@ -42,21 +42,21 @@ where
let key_apply_infos: LWESwitchingKeyLayout = LWESwitchingKeyLayout { let key_apply_infos: LWESwitchingKeyLayout = LWESwitchingKeyLayout {
n: n.into(), n: n.into(),
base2k: base2k_key.into(), base2k: key_base2k.into(),
k: k_ksk.into(), k: k_ksk.into(),
dnum: dnum.into(), dnum: dnum.into(),
}; };
let lwe_in_infos: LWELayout = LWELayout { let lwe_in_infos: LWELayout = LWELayout {
n: n_lwe_in.into(), n: n_lwe_in.into(),
base2k: base2k_in.into(), base2k: in_base2k.into(),
k: k_lwe_ct.into(), k: k_lwe_ct.into(),
}; };
let lwe_out_infos: LWELayout = LWELayout { let lwe_out_infos: LWELayout = LWELayout {
n: n_lwe_out.into(), n: n_lwe_out.into(),
k: k_lwe_ct.into(), k: k_lwe_ct.into(),
base2k: base2k_out.into(), base2k: out_base2k.into(),
}; };
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc( let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(
@@ -72,7 +72,7 @@ where
let data: i64 = 17; let data: i64 = 17;
let mut lwe_pt_in: LWEPlaintext<Vec<u8>> = LWEPlaintext::alloc(base2k_in.into(), k_lwe_pt.into()); let mut lwe_pt_in: LWEPlaintext<Vec<u8>> = LWEPlaintext::alloc(in_base2k.into(), k_lwe_pt.into());
lwe_pt_in.encode_i64(data, k_lwe_pt.into()); lwe_pt_in.encode_i64(data, k_lwe_pt.into());
let mut lwe_ct_in: LWE<Vec<u8>> = LWE::alloc_from_infos(&lwe_in_infos); let mut lwe_ct_in: LWE<Vec<u8>> = LWE::alloc_from_infos(&lwe_in_infos);
@@ -108,11 +108,12 @@ where
let mut lwe_pt_want: LWEPlaintext<Vec<u8>> = LWEPlaintext::alloc_from_infos(&lwe_out_infos); let mut lwe_pt_want: LWEPlaintext<Vec<u8>> = LWEPlaintext::alloc_from_infos(&lwe_out_infos);
module.vec_znx_normalize( module.vec_znx_normalize(
base2k_out,
lwe_pt_want.data_mut(), lwe_pt_want.data_mut(),
out_base2k,
0,
0, 0,
base2k_in,
lwe_pt_in.data(), lwe_pt_in.data(),
in_base2k,
0, 0,
scratch.borrow(), scratch.borrow(),
); );

View File

@@ -1,6 +1,7 @@
pub mod automorphism; pub mod automorphism;
pub mod encryption; pub mod encryption;
pub mod external_product; pub mod external_product;
pub mod glwe_tensor;
pub mod keyswitch; pub mod keyswitch;
mod conversion; mod conversion;

View File

@@ -32,27 +32,27 @@ where
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>, ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>, Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{ {
let base2k_out: usize = 15; let out_base2k: usize = 15;
let base2k_key: usize = 10; let key_base2k: usize = 10;
let k: usize = 54; let k: usize = 54;
for rank in 1_usize..3 { for rank in 1_usize..3 {
let n: usize = module.n(); let n: usize = module.n();
let k_autokey: usize = k + base2k_key; let k_autokey: usize = k + key_base2k;
let dsize: usize = 1; let dsize: usize = 1;
let dnum: usize = k.div_ceil(base2k_key * dsize); let dnum: usize = k.div_ceil(key_base2k * dsize);
let glwe_out_infos: GLWELayout = GLWELayout { let glwe_out_infos: GLWELayout = GLWELayout {
n: n.into(), n: n.into(),
base2k: base2k_out.into(), base2k: out_base2k.into(),
k: k.into(), k: k.into(),
rank: rank.into(), rank: rank.into(),
}; };
let key_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { let key_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout {
n: n.into(), n: n.into(),
base2k: base2k_key.into(), base2k: key_base2k.into(),
k: k_autokey.into(), k: k_autokey.into(),
rank: rank.into(), rank: rank.into(),
dsize: dsize.into(), dsize: dsize.into(),
@@ -82,33 +82,17 @@ where
let mut data_want: Vec<i64> = vec![0i64; n]; let mut data_want: Vec<i64> = vec![0i64; n];
data_want data_want.iter_mut().for_each(|x| *x = source_xa.next_i64() & 0xFF);
.iter_mut()
.for_each(|x| *x = source_xa.next_i64() & 0xFF);
module.vec_znx_fill_uniform(base2k_out, &mut pt_have.data, 0, &mut source_xa); module.vec_znx_fill_uniform(out_base2k, &mut pt_have.data, 0, &mut source_xa);
glwe_out.encrypt_sk( glwe_out.encrypt_sk(module, &pt_have, &sk_dft, &mut source_xa, &mut source_xe, scratch.borrow());
module,
&pt_have,
&sk_dft,
&mut source_xa,
&mut source_xe,
scratch.borrow(),
);
let mut auto_keys: HashMap<i64, GLWEAutomorphismKeyPrepared<Vec<u8>, BE>> = HashMap::new(); let mut auto_keys: HashMap<i64, GLWEAutomorphismKeyPrepared<Vec<u8>, BE>> = HashMap::new();
let gal_els: Vec<i64> = GLWE::trace_galois_elements(module); let gal_els: Vec<i64> = GLWE::trace_galois_elements(module);
let mut tmp: GLWEAutomorphismKey<Vec<u8>> = GLWEAutomorphismKey::alloc_from_infos(&key_infos); let mut tmp: GLWEAutomorphismKey<Vec<u8>> = GLWEAutomorphismKey::alloc_from_infos(&key_infos);
gal_els.iter().for_each(|gal_el| { gal_els.iter().for_each(|gal_el| {
tmp.encrypt_sk( tmp.encrypt_sk(module, *gal_el, &sk, &mut source_xa, &mut source_xe, scratch.borrow());
module,
*gal_el,
&sk,
&mut source_xa,
&mut source_xe,
scratch.borrow(),
);
let mut atk_prepared: GLWEAutomorphismKeyPrepared<Vec<u8>, BE> = let mut atk_prepared: GLWEAutomorphismKeyPrepared<Vec<u8>, BE> =
GLWEAutomorphismKeyPrepared::alloc_from_infos(module, &tmp); GLWEAutomorphismKeyPrepared::alloc_from_infos(module, &tmp);
atk_prepared.prepare(module, &tmp, scratch.borrow()); atk_prepared.prepare(module, &tmp, scratch.borrow());
@@ -122,18 +106,13 @@ where
glwe_out.decrypt(module, &mut pt_have, &sk_dft, scratch.borrow()); glwe_out.decrypt(module, &mut pt_have, &sk_dft, scratch.borrow());
module.vec_znx_sub_inplace(&mut pt_want.data, 0, &pt_have.data, 0); module.vec_znx_sub_inplace(&mut pt_want.data, 0, &pt_have.data, 0);
module.vec_znx_normalize_inplace( module.vec_znx_normalize_inplace(pt_want.base2k().as_usize(), &mut pt_want.data, 0, scratch.borrow());
pt_want.base2k().as_usize(),
&mut pt_want.data,
0,
scratch.borrow(),
);
let noise_have: f64 = pt_want.stats().std().log2(); let noise_have: f64 = pt_want.stats().std().log2();
let mut noise_want: f64 = var_noise_gglwe_product( let mut noise_want: f64 = var_noise_gglwe_product(
n as f64, n as f64,
base2k_key * dsize, key_base2k * dsize,
0.5, 0.5,
0.5, 0.5,
1.0 / 12.0, 1.0 / 12.0,
@@ -147,9 +126,6 @@ where
noise_want += n as f64 * 1.0 / 12.0 * 0.5 * rank as f64 * (-2.0 * (k) as f64).exp2(); noise_want += n as f64 * 1.0 / 12.0 * 0.5 * rank as f64 * (-2.0 * (k) as f64).exp2();
noise_want = noise_want.sqrt().log2(); noise_want = noise_want.sqrt().log2();
assert!( assert!((noise_have - noise_want).abs() < 1.0, "{noise_have} > {noise_want}");
(noise_have - noise_want).abs() < 1.0,
"{noise_have} > {noise_want}"
);
} }
} }

View File

@@ -16,13 +16,11 @@ impl<D: DataMut> GLWEPlaintext<D> {
impl<D: DataRef> GLWEPlaintext<D> { impl<D: DataRef> GLWEPlaintext<D> {
pub fn decode_vec_i64(&self, data: &mut [i64], k: TorusPrecision) { pub fn decode_vec_i64(&self, data: &mut [i64], k: TorusPrecision) {
self.data self.data.decode_vec_i64(self.base2k().into(), 0, k.into(), data);
.decode_vec_i64(self.base2k().into(), 0, k.into(), data);
} }
pub fn decode_coeff_i64(&self, k: TorusPrecision, idx: usize) -> i64 { pub fn decode_coeff_i64(&self, k: TorusPrecision, idx: usize) -> i64 {
self.data self.data.decode_coeff_i64(self.base2k().into(), 0, k.into(), idx)
.decode_coeff_i64(self.base2k().into(), 0, k.into(), idx)
} }
pub fn decode_vec_float(&self, data: &mut [Float]) { pub fn decode_vec_float(&self, data: &mut [Float]) {
@@ -43,14 +41,12 @@ impl<D: DataMut> LWEPlaintext<D> {
impl<D: DataRef> LWEPlaintext<D> { impl<D: DataRef> LWEPlaintext<D> {
pub fn decode_i64(&self, k: TorusPrecision) -> i64 { pub fn decode_i64(&self, k: TorusPrecision) -> i64 {
self.data self.data.decode_coeff_i64(self.base2k().into(), 0, k.into(), 0)
.decode_coeff_i64(self.base2k().into(), 0, k.into(), 0)
} }
pub fn decode_float(&self) -> Float { pub fn decode_float(&self) -> Float {
let mut out: [Float; 1] = [Float::new(self.k().as_u32())]; let mut out: [Float; 1] = [Float::new(self.k().as_u32())];
self.data self.data.decode_vec_float(self.base2k().into(), 0, &mut out);
.decode_vec_float(self.base2k().into(), 0, &mut out);
out[0].clone() out[0].clone()
} }
} }

View File

@@ -32,5 +32,5 @@ rustdoc-args = ["--cfg", "docsrs"]
[[bench]] [[bench]]
name = "vmp" name = "convolution"
harness = false harness = false

View File

@@ -0,0 +1,116 @@
use criterion::{Criterion, criterion_group, criterion_main};
#[cfg(not(all(
feature = "enable-avx",
target_arch = "x86_64",
target_feature = "avx2",
target_feature = "fma"
)))]
fn bench_cnv_prepare_left_cpu_avx_fft64(_c: &mut Criterion) {
eprintln!("Skipping: AVX IFft benchmark requires x86_64 + AVX2 + FMA");
}
#[cfg(all(
feature = "enable-avx",
target_arch = "x86_64",
target_feature = "avx2",
target_feature = "fma"
))]
fn bench_cnv_prepare_left_cpu_avx_fft64(c: &mut Criterion) {
use poulpy_cpu_avx::FFT64Avx;
poulpy_hal::bench_suite::convolution::bench_cnv_prepare_left::<FFT64Avx>(c, "cpu_avx::fft64");
}
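// The same cfg pair repeats for every benchmark in this file: the fallback variant is
// compiled exactly when any of the four conditions fails, so the criterion_group!
// below always has one function of each name to reference, with or without AVX.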
#[cfg(not(all(
feature = "enable-avx",
target_arch = "x86_64",
target_feature = "avx2",
target_feature = "fma"
)))]
fn bench_cnv_prepare_right_cpu_avx_fft64(_c: &mut Criterion) {
eprintln!("Skipping: AVX IFft benchmark requires x86_64 + AVX2 + FMA");
}
#[cfg(all(
feature = "enable-avx",
target_arch = "x86_64",
target_feature = "avx2",
target_feature = "fma"
))]
fn bench_cnv_prepare_right_cpu_avx_fft64(c: &mut Criterion) {
use poulpy_cpu_avx::FFT64Avx;
poulpy_hal::bench_suite::convolution::bench_cnv_prepare_right::<FFT64Avx>(c, "cpu_avx::fft64");
}
#[cfg(not(all(
feature = "enable-avx",
target_arch = "x86_64",
target_feature = "avx2",
target_feature = "fma"
)))]
fn bench_cnv_apply_dft_cpu_avx_fft64(_c: &mut Criterion) {
eprintln!("Skipping: AVX IFft benchmark requires x86_64 + AVX2 + FMA");
}
#[cfg(all(
feature = "enable-avx",
target_arch = "x86_64",
target_feature = "avx2",
target_feature = "fma"
))]
fn bench_cnv_apply_dft_cpu_avx_fft64(c: &mut Criterion) {
use poulpy_cpu_avx::FFT64Avx;
poulpy_hal::bench_suite::convolution::bench_cnv_apply_dft::<FFT64Avx>(c, "cpu_avx::fft64");
}
#[cfg(not(all(
feature = "enable-avx",
target_arch = "x86_64",
target_feature = "avx2",
target_feature = "fma"
)))]
fn bench_cnv_pairwise_apply_dft_cpu_avx_fft64(_c: &mut Criterion) {
eprintln!("Skipping: AVX IFft benchmark requires x86_64 + AVX2 + FMA");
}
#[cfg(all(
feature = "enable-avx",
target_arch = "x86_64",
target_feature = "avx2",
target_feature = "fma"
))]
fn bench_cnv_pairwise_apply_dft_cpu_avx_fft64(c: &mut Criterion) {
use poulpy_cpu_avx::FFT64Avx;
poulpy_hal::bench_suite::convolution::bench_cnv_pairwise_apply_dft::<FFT64Avx>(c, "cpu_avx::fft64");
}
#[cfg(not(all(
feature = "enable-avx",
target_arch = "x86_64",
target_feature = "avx2",
target_feature = "fma"
)))]
fn bench_cnv_by_const_apply_cpu_avx_fft64(_c: &mut Criterion) {
eprintln!("Skipping: AVX IFft benchmark requires x86_64 + AVX2 + FMA");
}
#[cfg(all(
feature = "enable-avx",
target_arch = "x86_64",
target_feature = "avx2",
target_feature = "fma"
))]
fn bench_cnv_by_const_apply_cpu_avx_fft64(c: &mut Criterion) {
use poulpy_cpu_avx::FFT64Avx;
poulpy_hal::bench_suite::convolution::bench_cnv_by_const_apply::<FFT64Avx>(c, "cpu_avx::fft64");
}
criterion_group!(
benches,
bench_cnv_prepare_left_cpu_avx_fft64,
bench_cnv_prepare_right_cpu_avx_fft64,
bench_cnv_apply_dft_cpu_avx_fft64,
bench_cnv_pairwise_apply_dft_cpu_avx_fft64,
bench_cnv_by_const_apply_cpu_avx_fft64,
);
criterion_main!(benches);
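The five skip/run pairs above differ only in the function and suite names. A minimal sketch of a macro that could generate them (hypothetical `avx_bench!`, not part of the crate):

    macro_rules! avx_bench {
        ($name:ident, $suite:ident) => {
            #[cfg(not(all(
                feature = "enable-avx",
                target_arch = "x86_64",
                target_feature = "avx2",
                target_feature = "fma"
            )))]
            fn $name(_c: &mut Criterion) {
                eprintln!("Skipping: AVX convolution benchmark requires x86_64 + AVX2 + FMA");
            }
            #[cfg(all(
                feature = "enable-avx",
                target_arch = "x86_64",
                target_feature = "avx2",
                target_feature = "fma"
            ))]
            fn $name(c: &mut Criterion) {
                use poulpy_cpu_avx::FFT64Avx;
                poulpy_hal::bench_suite::convolution::$suite::<FFT64Avx>(c, "cpu_avx::fft64");
            }
        };
    }

    // e.g. avx_bench!(bench_cnv_apply_dft_cpu_avx_fft64, bench_cnv_apply_dft);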

View File

@@ -1,11 +1,21 @@
 use criterion::{Criterion, criterion_group, criterion_main};

-#[cfg(not(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma")))]
+#[cfg(not(all(
+    feature = "enable-avx",
+    target_arch = "x86_64",
+    target_feature = "avx2",
+    target_feature = "fma"
+)))]
 fn bench_ifft_avx2_fma(_c: &mut Criterion) {
     eprintln!("Skipping: AVX IFft benchmark requires x86_64 + AVX2 + FMA");
 }

-#[cfg(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))]
+#[cfg(all(
+    feature = "enable-avx",
+    target_arch = "x86_64",
+    target_feature = "avx2",
+    target_feature = "fma"
+))]
 pub fn bench_ifft_avx2_fma(c: &mut Criterion) {
     use criterion::BenchmarkId;
     use poulpy_cpu_avx::ReimIFFTAvx;
@@ -21,10 +31,7 @@ pub fn bench_ifft_avx2_fma(c: &mut Criterion) {
     let mut values: Vec<f64> = vec![0f64; m << 1];
     let scale = 1.0f64 / (2 * m) as f64;
-    values
-        .iter_mut()
-        .enumerate()
-        .for_each(|(i, x)| *x = (i + 1) as f64 * scale);
+    values.iter_mut().enumerate().for_each(|(i, x)| *x = (i + 1) as f64 * scale);

     let table: ReimIFFTTable<f64> = ReimIFFTTable::<f64>::new(m);
     move || {
@@ -47,12 +54,22 @@ pub fn bench_ifft_avx2_fma(c: &mut Criterion) {
     group.finish();
 }

-#[cfg(not(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma")))]
+#[cfg(not(all(
+    feature = "enable-avx",
+    target_arch = "x86_64",
+    target_feature = "avx2",
+    target_feature = "fma"
+)))]
 fn bench_fft_avx2_fma(_c: &mut Criterion) {
     eprintln!("Skipping: AVX FFT benchmark requires x86_64 + AVX2 + FMA");
 }

-#[cfg(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))]
+#[cfg(all(
+    feature = "enable-avx",
+    target_arch = "x86_64",
+    target_feature = "avx2",
+    target_feature = "fma"
+))]
 pub fn bench_fft_avx2_fma(c: &mut Criterion) {
     use criterion::BenchmarkId;
     use poulpy_cpu_avx::ReimFFTAvx;
@@ -68,10 +85,7 @@ pub fn bench_fft_avx2_fma(c: &mut Criterion) {
     let mut values: Vec<f64> = vec![0f64; m << 1];
     let scale = 1.0f64 / (2 * m) as f64;
-    values
-        .iter_mut()
-        .enumerate()
-        .for_each(|(i, x)| *x = (i + 1) as f64 * scale);
+    values.iter_mut().enumerate().for_each(|(i, x)| *x = (i + 1) as f64 * scale);

     let table: ReimFFTTable<f64> = ReimFFTTable::<f64>::new(m);
     move || {

View File

@@ -1,33 +1,63 @@
 use criterion::{Criterion, criterion_group, criterion_main};

-#[cfg(not(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma")))]
+#[cfg(not(all(
+    feature = "enable-avx",
+    target_arch = "x86_64",
+    target_feature = "avx2",
+    target_feature = "fma"
+)))]
 fn bench_vec_znx_add_cpu_avx_fft64(_c: &mut Criterion) {
     eprintln!("Skipping: AVX IFft benchmark requires x86_64 + AVX2 + FMA");
 }

-#[cfg(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))]
+#[cfg(all(
+    feature = "enable-avx",
+    target_arch = "x86_64",
+    target_feature = "avx2",
+    target_feature = "fma"
+))]
 fn bench_vec_znx_add_cpu_avx_fft64(c: &mut Criterion) {
     use poulpy_cpu_avx::FFT64Avx;
     poulpy_hal::reference::vec_znx::bench_vec_znx_add::<FFT64Avx>(c, "FFT64Avx");
 }

-#[cfg(not(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma")))]
+#[cfg(not(all(
+    feature = "enable-avx",
+    target_arch = "x86_64",
+    target_feature = "avx2",
+    target_feature = "fma"
+)))]
 fn bench_vec_znx_normalize_inplace_cpu_avx_fft64(_c: &mut Criterion) {
     eprintln!("Skipping: AVX IFft benchmark requires x86_64 + AVX2 + FMA");
 }

-#[cfg(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))]
+#[cfg(all(
+    feature = "enable-avx",
+    target_arch = "x86_64",
+    target_feature = "avx2",
+    target_feature = "fma"
+))]
 fn bench_vec_znx_normalize_inplace_cpu_avx_fft64(c: &mut Criterion) {
     use poulpy_cpu_avx::FFT64Avx;
     poulpy_hal::reference::vec_znx::bench_vec_znx_normalize_inplace::<FFT64Avx>(c, "FFT64Avx");
 }

-#[cfg(not(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma")))]
+#[cfg(not(all(
+    feature = "enable-avx",
+    target_arch = "x86_64",
+    target_feature = "avx2",
+    target_feature = "fma"
+)))]
 fn bench_vec_znx_automorphism_cpu_avx_fft64(_c: &mut Criterion) {
     eprintln!("Skipping: AVX IFft benchmark requires x86_64 + AVX2 + FMA");
 }

-#[cfg(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))]
+#[cfg(all(
+    feature = "enable-avx",
+    target_arch = "x86_64",
+    target_feature = "avx2",
+    target_feature = "fma"
+))]
 fn bench_vec_znx_automorphism_cpu_avx_fft64(c: &mut Criterion) {
     use poulpy_cpu_avx::FFT64Avx;
     poulpy_hal::reference::vec_znx::bench_vec_znx_automorphism::<FFT64Avx>(c, "FFT64Avx");

View File

@@ -1,11 +1,21 @@
 use criterion::{Criterion, criterion_group, criterion_main};

-#[cfg(not(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma")))]
+#[cfg(not(all(
+    feature = "enable-avx",
+    target_arch = "x86_64",
+    target_feature = "avx2",
+    target_feature = "fma"
+)))]
 fn bench_vmp_apply_dft_to_dft_cpu_avx_fft64(_c: &mut Criterion) {
     eprintln!("Skipping: AVX IFft benchmark requires x86_64 + AVX2 + FMA");
 }

-#[cfg(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))]
+#[cfg(all(
+    feature = "enable-avx",
+    target_arch = "x86_64",
+    target_feature = "avx2",
+    target_feature = "fma"
+))]
 fn bench_vmp_apply_dft_to_dft_cpu_avx_fft64(c: &mut Criterion) {
     use poulpy_cpu_avx::FFT64Avx;
     poulpy_hal::bench_suite::vmp::bench_vmp_apply_dft_to_dft::<FFT64Avx>(c, "FFT64Avx");

View File

@@ -1,8 +1,18 @@
 use itertools::izip;

-#[cfg(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))]
+#[cfg(all(
+    feature = "enable-avx",
+    target_arch = "x86_64",
+    target_feature = "avx2",
+    target_feature = "fma"
+))]
 use poulpy_cpu_avx::FFT64Avx as BackendImpl;
-#[cfg(not(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma")))]
+#[cfg(not(all(
+    feature = "enable-avx",
+    target_arch = "x86_64",
+    target_feature = "avx2",
+    target_feature = "fma"
+)))]
 use poulpy_cpu_ref::FFT64Ref as BackendImpl;

 use poulpy_hal::{
@@ -73,8 +83,7 @@ fn main() {
         msg_size, // Number of small polynomials
     );

     let mut want: Vec<i64> = vec![0; n];
-    want.iter_mut()
-        .for_each(|x| *x = source.next_u64n(16, 15) as i64);
+    want.iter_mut().for_each(|x| *x = source.next_u64n(16, 15) as i64);
     m.encode_vec_i64(base2k, 0, log_scale, &want);
     module.vec_znx_normalize_inplace(base2k, &mut m, 0, scratch.borrow());
@@ -89,11 +98,12 @@ fn main() {
     // Normalizes back to VecZnx
     // ct[0] <- m - BIG(c1 * s)
     module.vec_znx_big_normalize(
-        base2k,
         &mut ct,
-        0, // Selects the first column of ct (ct[0])
         base2k,
+        0,
+        0, // Selects the first column of ct (ct[0])
         &buf_big,
+        base2k,
         0, // Selects the first column of buf_big
         scratch.borrow(),
     );
@@ -131,15 +141,13 @@ fn main() {
     // m + e <- BIG(ct[1] * s + ct[0])
     let mut res = VecZnx::alloc(module.n(), 1, ct_size);
-    module.vec_znx_big_normalize(base2k, &mut res, 0, base2k, &buf_big, 0, scratch.borrow());
+    module.vec_znx_big_normalize(&mut res, base2k, 0, 0, &buf_big, base2k, 0, scratch.borrow());

     // have = m * 2^{log_scale} + e
     let mut have: Vec<i64> = vec![i64::default(); n];
     res.decode_vec_i64(base2k, 0, ct_size * base2k, &mut have);

     let scale: f64 = (1 << (res.size() * base2k - log_scale)) as f64;
-    izip!(want.iter(), have.iter())
-        .enumerate()
-        .for_each(|(i, (a, b))| {
+    izip!(want.iter(), have.iter()).enumerate().for_each(|(i, (a, b))| {
         println!("{}: {} {}", i, a, (*b as f64) / scale);
     });
 }
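For clarity, the parameter roles in the reordered normalization call, following the res/res_offset/res_col naming used by the convolution implementation later in this diff (annotations are the editor's reading, not doc comments from the crate):

    module.vec_znx_big_normalize(
        &mut res,         // destination VecZnx
        base2k,           // destination base-2^k
        0,                // res_offset: bit-shift applied during normalization
        0,                // destination column
        &buf_big,         // source VecZnxBig
        base2k,           // source base-2^k
        0,                // source column
        scratch.borrow(), // temporary buffers
    );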

View File

@@ -0,0 +1,401 @@
use poulpy_hal::{
api::{Convolution, ModuleN, ScratchTakeBasic, TakeSlice, VecZnxDftApply, VecZnxDftBytesOf},
layouts::{
Backend, CnvPVecL, CnvPVecLToMut, CnvPVecLToRef, CnvPVecR, CnvPVecRToMut, CnvPVecRToRef, Module, Scratch, VecZnx,
VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftToMut, VecZnxToRef, ZnxInfos,
},
oep::{CnvPVecBytesOfImpl, CnvPVecLAllocImpl, ConvolutionImpl},
reference::fft64::convolution::{
convolution_apply_dft, convolution_apply_dft_tmp_bytes, convolution_by_const_apply, convolution_by_const_apply_tmp_bytes,
convolution_pairwise_apply_dft, convolution_pairwise_apply_dft_tmp_bytes, convolution_prepare_left,
convolution_prepare_right,
},
};
use crate::{FFT64Avx, module::FFT64ModuleHandle};
unsafe impl CnvPVecLAllocImpl<Self> for FFT64Avx {
fn cnv_pvec_left_alloc_impl(n: usize, cols: usize, size: usize) -> CnvPVecL<Vec<u8>, Self> {
CnvPVecL::alloc(n, cols, size)
}
fn cnv_pvec_right_alloc_impl(n: usize, cols: usize, size: usize) -> CnvPVecR<Vec<u8>, Self> {
CnvPVecR::alloc(n, cols, size)
}
}
unsafe impl CnvPVecBytesOfImpl for FFT64Avx {
fn bytes_of_cnv_pvec_left_impl(n: usize, cols: usize, size: usize) -> usize {
Self::layout_prep_word_count() * n * cols * size * size_of::<<FFT64Avx as Backend>::ScalarPrep>()
}
fn bytes_of_cnv_pvec_right_impl(n: usize, cols: usize, size: usize) -> usize {
Self::layout_prep_word_count() * n * cols * size * size_of::<<FFT64Avx as Backend>::ScalarPrep>()
}
}
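// Editor's note: both prepared layouts reserve the same footprint —
// layout_prep_word_count() prepared words per coefficient, across n
// coefficients, cols columns and size limbs of ScalarPrep (presumably f64
// for this FFT64 backend); left and right differ in data arrangement only.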
unsafe impl ConvolutionImpl<Self> for FFT64Avx
where
Module<Self>: ModuleN + VecZnxDftBytesOf + VecZnxDftApply<Self>,
{
fn cnv_prepare_left_tmp_bytes_impl(module: &Module<Self>, res_size: usize, a_size: usize) -> usize {
module.bytes_of_vec_znx_dft(1, res_size.min(a_size))
}
fn cnv_prepare_left_impl<R, A>(module: &Module<Self>, res: &mut R, a: &A, scratch: &mut Scratch<Self>)
where
R: CnvPVecLToMut<Self>,
A: VecZnxToRef,
{
let res: &mut CnvPVecL<&mut [u8], FFT64Avx> = &mut res.to_mut();
let a: &VecZnx<&[u8]> = &a.to_ref();
let (mut tmp, _) = scratch.take_vec_znx_dft(module, 1, res.size().min(a.size()));
convolution_prepare_left(module.get_fft_table(), res, a, &mut tmp);
}
fn cnv_prepare_right_tmp_bytes_impl(module: &Module<Self>, res_size: usize, a_size: usize) -> usize {
module.bytes_of_vec_znx_dft(1, res_size.min(a_size))
}
fn cnv_prepare_right_impl<R, A>(module: &Module<Self>, res: &mut R, a: &A, scratch: &mut Scratch<Self>)
where
R: CnvPVecRToMut<Self>,
A: VecZnxToRef,
{
let res: &mut CnvPVecR<&mut [u8], FFT64Avx> = &mut res.to_mut();
let a: &VecZnx<&[u8]> = &a.to_ref();
let (mut tmp, _) = scratch.take_vec_znx_dft(module, 1, res.size().min(a.size()));
convolution_prepare_right(module.get_fft_table(), res, a, &mut tmp);
}
fn cnv_apply_dft_tmp_bytes_impl(
_module: &Module<Self>,
res_size: usize,
_res_offset: usize,
a_size: usize,
b_size: usize,
) -> usize {
convolution_apply_dft_tmp_bytes(res_size, a_size, b_size)
}
fn cnv_by_const_apply_tmp_bytes_impl(
_module: &Module<Self>,
res_size: usize,
_res_offset: usize,
a_size: usize,
b_size: usize,
) -> usize {
convolution_by_const_apply_tmp_bytes(res_size, a_size, b_size)
}
fn cnv_by_const_apply_impl<R, A>(
module: &Module<Self>,
res: &mut R,
res_offset: usize,
res_col: usize,
a: &A,
a_col: usize,
b: &[i64],
scratch: &mut Scratch<Self>,
) where
R: VecZnxBigToMut<Self>,
A: VecZnxToRef,
{
let res: &mut VecZnxBig<&mut [u8], Self> = &mut res.to_mut();
let a: &VecZnx<&[u8]> = &a.to_ref();
let (tmp, _) =
scratch.take_slice(module.cnv_by_const_apply_tmp_bytes(res.size(), res_offset, a.size(), b.len()) / size_of::<i64>());
convolution_by_const_apply(res, res_offset, res_col, a, a_col, b, tmp);
}
fn cnv_apply_dft_impl<R, A, B>(
module: &Module<Self>,
res: &mut R,
res_offset: usize,
res_col: usize,
a: &A,
a_col: usize,
b: &B,
b_col: usize,
scratch: &mut Scratch<Self>,
) where
R: VecZnxDftToMut<Self>,
A: CnvPVecLToRef<Self>,
B: CnvPVecRToRef<Self>,
{
let res: &mut VecZnxDft<&mut [u8], FFT64Avx> = &mut res.to_mut();
let a: &CnvPVecL<&[u8], FFT64Avx> = &a.to_ref();
let b: &CnvPVecR<&[u8], FFT64Avx> = &b.to_ref();
let (tmp, _) =
scratch.take_slice(module.cnv_apply_dft_tmp_bytes(res.size(), res_offset, a.size(), b.size()) / size_of::<f64>());
convolution_apply_dft(res, res_offset, res_col, a, a_col, b, b_col, tmp);
}
fn cnv_pairwise_apply_dft_tmp_bytes(
_module: &Module<Self>,
res_size: usize,
_res_offset: usize,
a_size: usize,
b_size: usize,
) -> usize {
convolution_pairwise_apply_dft_tmp_bytes(res_size, a_size, b_size)
}
fn cnv_pairwise_apply_dft_impl<R, A, B>(
module: &Module<Self>,
res: &mut R,
res_offset: usize,
res_col: usize,
a: &A,
b: &B,
i: usize,
j: usize,
scratch: &mut Scratch<Self>,
) where
R: VecZnxDftToMut<Self>,
A: CnvPVecLToRef<Self>,
B: CnvPVecRToRef<Self>,
{
let res: &mut VecZnxDft<&mut [u8], FFT64Avx> = &mut res.to_mut();
let a: &CnvPVecL<&[u8], FFT64Avx> = &a.to_ref();
let b: &CnvPVecR<&[u8], FFT64Avx> = &b.to_ref();
let (tmp, _) = scratch
.take_slice(module.cnv_pairwise_apply_dft_tmp_bytes(res.size(), res_offset, a.size(), b.size()) / size_of::<f64>());
convolution_pairwise_apply_dft(res, res_offset, res_col, a, b, i, j, tmp);
}
}
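// Hedged usage sketch (editor's illustration, not code from the crate):
// prepare both operands, then accumulate their convolution in the DFT domain,
// mirroring the alloc/scratch calls visible in this file.
//
//     let mut a_prep: CnvPVecL<Vec<u8>, FFT64Avx> = CnvPVecL::alloc(module.n(), 1, a_size);
//     let mut b_prep: CnvPVecR<Vec<u8>, FFT64Avx> = CnvPVecR::alloc(module.n(), 1, b_size);
//     module.cnv_prepare_left(&mut a_prep, &a, scratch);
//     module.cnv_prepare_right(&mut b_prep, &b, scratch);
//     let (mut res_dft, _) = scratch.take_vec_znx_dft(module, 1, res_size);
//     module.cnv_apply_dft(&mut res_dft, 0, 0, &a_prep, 0, &b_prep, 0, scratch);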
/// # Safety
/// Caller must ensure the CPU supports AVX2.
/// Assumes all inputs fit in i32 (so i32×i32→i64 is exact).
#[target_feature(enable = "avx2")]
pub unsafe fn i64_convolution_by_const_1coeff_avx(k: usize, dst: &mut [i64; 8], a: &[i64], a_size: usize, b: &[i64]) {
use core::arch::x86_64::{
__m256i, _mm256_add_epi64, _mm256_loadu_si256, _mm256_mul_epi32, _mm256_set1_epi32, _mm256_setzero_si256,
_mm256_storeu_si256,
};
dst.fill(0);
let b_size = b.len();
if k >= a_size + b_size {
return;
}
let j_min = (k + 1).saturating_sub(a_size); // == k - (a_size - 1), without underflow when a_size == 0
let j_max = (k + 1).min(b_size);
unsafe {
// Two accumulators = 8 outputs total
let mut acc_lo: __m256i = _mm256_setzero_si256(); // dst[0..4)
let mut acc_hi: __m256i = _mm256_setzero_si256(); // dst[4..8)
let mut a_ptr: *const i64 = a.as_ptr().add(8 * (k - j_min));
let mut b_ptr: *const i64 = b.as_ptr().add(j_min);
for _ in 0..(j_max - j_min) {
// Broadcast scalar b[j] as i32
let br: __m256i = _mm256_set1_epi32(*b_ptr as i32);
// ---- lower half: a[0..4) ----
let a_lo: __m256i = _mm256_loadu_si256(a_ptr as *const __m256i);
let prod_lo: __m256i = _mm256_mul_epi32(a_lo, br);
acc_lo = _mm256_add_epi64(acc_lo, prod_lo);
// ---- upper half: a[4..8) ----
let a_hi: __m256i = _mm256_loadu_si256(a_ptr.add(4) as *const __m256i);
let prod_hi: __m256i = _mm256_mul_epi32(a_hi, br);
acc_hi = _mm256_add_epi64(acc_hi, prod_hi);
a_ptr = a_ptr.sub(8);
b_ptr = b_ptr.add(1);
}
// Store final result
_mm256_storeu_si256(dst.as_mut_ptr() as *mut __m256i, acc_lo);
_mm256_storeu_si256(dst.as_mut_ptr().add(4) as *mut __m256i, acc_hi);
}
}
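// Scalar mirror of the kernel above (editor's sketch, not part of the crate):
// output coefficient k accumulates the 8-lane blocks a[8 * (k - j)..][..8]
// scaled by b[j] over the valid range of j.
//
//     fn i64_convolution_by_const_1coeff_ref(k: usize, dst: &mut [i64; 8], a: &[i64], a_size: usize, b: &[i64]) {
//         dst.fill(0);
//         let b_size = b.len();
//         if k >= a_size + b_size {
//             return;
//         }
//         for j in (k + 1).saturating_sub(a_size)..(k + 1).min(b_size) {
//             for i in 0..8 {
//                 dst[i] += a[8 * (k - j) + i] * b[j];
//             }
//         }
//     }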
/// # Safety
/// Caller must ensure the CPU supports AVX2.
/// Assumes all values in `a` and `b` fit in i32 (so i32×i32→i64 is exact).
#[target_feature(enable = "avx2")]
pub unsafe fn i64_convolution_by_real_const_2coeffs_avx(
k: usize,
dst: &mut [i64; 16],
a: &[i64],
a_size: usize,
b: &[i64], // real scalars, stride-1
) {
use core::arch::x86_64::{
__m256i, _mm256_add_epi64, _mm256_loadu_si256, _mm256_mul_epi32, _mm256_set1_epi32, _mm256_setzero_si256,
_mm256_storeu_si256,
};
let b_size: usize = b.len();
debug_assert!(a.len() >= 8 * a_size);
let k0: usize = k;
let k1: usize = k + 1;
let bound: usize = a_size + b_size;
if k0 >= bound {
unsafe {
let zero: __m256i = _mm256_setzero_si256();
let dst_ptr: *mut i64 = dst.as_mut_ptr();
_mm256_storeu_si256(dst_ptr as *mut __m256i, zero);
_mm256_storeu_si256(dst_ptr.add(4) as *mut __m256i, zero);
_mm256_storeu_si256(dst_ptr.add(8) as *mut __m256i, zero);
_mm256_storeu_si256(dst_ptr.add(12) as *mut __m256i, zero);
}
return;
}
unsafe {
let mut acc_lo_k0: __m256i = _mm256_setzero_si256();
let mut acc_hi_k0: __m256i = _mm256_setzero_si256();
let mut acc_lo_k1: __m256i = _mm256_setzero_si256();
let mut acc_hi_k1: __m256i = _mm256_setzero_si256();
let j0_min: usize = (k0 + 1).saturating_sub(a_size);
let j0_max: usize = (k0 + 1).min(b_size);
if k1 >= bound {
let mut a_k0_ptr: *const i64 = a.as_ptr().add(8 * (k0 - j0_min));
let mut b_ptr: *const i64 = b.as_ptr().add(j0_min);
// Contributions to k0 only
for _ in 0..j0_max - j0_min {
// Broadcast b[j] as i32
let br: __m256i = _mm256_set1_epi32(*b_ptr as i32);
// Load 4×i64 (low half) and 4×i64 (high half)
let a_lo_k0: __m256i = _mm256_loadu_si256(a_k0_ptr as *const __m256i);
let a_hi_k0: __m256i = _mm256_loadu_si256(a_k0_ptr.add(4) as *const __m256i);
acc_lo_k0 = _mm256_add_epi64(acc_lo_k0, _mm256_mul_epi32(a_lo_k0, br));
acc_hi_k0 = _mm256_add_epi64(acc_hi_k0, _mm256_mul_epi32(a_hi_k0, br));
a_k0_ptr = a_k0_ptr.sub(8);
b_ptr = b_ptr.add(1);
}
} else {
let j1_min: usize = (k1 + 1).saturating_sub(a_size);
let j1_max: usize = (k1 + 1).min(b_size);
let mut a_k0_ptr: *const i64 = a.as_ptr().add(8 * (k0 - j0_min));
let mut a_k1_ptr: *const i64 = a.as_ptr().add(8 * (k1 - j1_min));
let mut b_ptr: *const i64 = b.as_ptr().add(j0_min);
// Region 1: k0 only, j ∈ [j0_min, j1_min)
for _ in 0..j1_min - j0_min {
let br: __m256i = _mm256_set1_epi32(*b_ptr as i32);
let a_k0_lo: __m256i = _mm256_loadu_si256(a_k0_ptr as *const __m256i);
let a_k0_hi: __m256i = _mm256_loadu_si256(a_k0_ptr.add(4) as *const __m256i);
acc_lo_k0 = _mm256_add_epi64(acc_lo_k0, _mm256_mul_epi32(a_k0_lo, br));
acc_hi_k0 = _mm256_add_epi64(acc_hi_k0, _mm256_mul_epi32(a_k0_hi, br));
a_k0_ptr = a_k0_ptr.sub(8);
b_ptr = b_ptr.add(1);
}
// Region 2: overlap, contributions to both k0 and k1, j ∈ [j1_min, j0_max)
// Save one load on b: broadcast once and reuse.
for _ in 0..j0_max - j1_min {
let br: __m256i = _mm256_set1_epi32(*b_ptr as i32);
let a_lo_k0: __m256i = _mm256_loadu_si256(a_k0_ptr as *const __m256i);
let a_hi_k0: __m256i = _mm256_loadu_si256(a_k0_ptr.add(4) as *const __m256i);
let a_lo_k1: __m256i = _mm256_loadu_si256(a_k1_ptr as *const __m256i);
let a_hi_k1: __m256i = _mm256_loadu_si256(a_k1_ptr.add(4) as *const __m256i);
// k0
acc_lo_k0 = _mm256_add_epi64(acc_lo_k0, _mm256_mul_epi32(a_lo_k0, br));
acc_hi_k0 = _mm256_add_epi64(acc_hi_k0, _mm256_mul_epi32(a_hi_k0, br));
// k1
acc_lo_k1 = _mm256_add_epi64(acc_lo_k1, _mm256_mul_epi32(a_lo_k1, br));
acc_hi_k1 = _mm256_add_epi64(acc_hi_k1, _mm256_mul_epi32(a_hi_k1, br));
a_k0_ptr = a_k0_ptr.sub(8);
a_k1_ptr = a_k1_ptr.sub(8);
b_ptr = b_ptr.add(1);
}
// Region 3: k1 only, j ∈ [j0_max, j1_max)
for _ in 0..j1_max - j0_max {
let br: __m256i = _mm256_set1_epi32(*b_ptr as i32);
let a_lo_k1: __m256i = _mm256_loadu_si256(a_k1_ptr as *const __m256i);
let a_hi_k1: __m256i = _mm256_loadu_si256(a_k1_ptr.add(4) as *const __m256i);
acc_lo_k1 = _mm256_add_epi64(acc_lo_k1, _mm256_mul_epi32(a_lo_k1, br));
acc_hi_k1 = _mm256_add_epi64(acc_hi_k1, _mm256_mul_epi32(a_hi_k1, br));
a_k1_ptr = a_k1_ptr.sub(8);
b_ptr = b_ptr.add(1);
}
}
let dst_ptr: *mut i64 = dst.as_mut_ptr();
_mm256_storeu_si256(dst_ptr as *mut __m256i, acc_lo_k0);
_mm256_storeu_si256(dst_ptr.add(4) as *mut __m256i, acc_hi_k0);
_mm256_storeu_si256(dst_ptr.add(8) as *mut __m256i, acc_lo_k1);
_mm256_storeu_si256(dst_ptr.add(12) as *mut __m256i, acc_hi_k1);
}
}
/// # Safety
/// Caller must ensure the CPU supports AVX (e.g., via `is_x86_feature_detected!("avx")`).
#[target_feature(enable = "avx")]
pub unsafe fn i64_extract_1blk_contiguous_avx(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) {
use core::arch::x86_64::{__m256i, _mm256_loadu_si256, _mm256_storeu_si256};
unsafe {
let mut src_ptr: *const __m256i = src.as_ptr().add(offset + (blk << 3)) as *const __m256i; // src + offset + 8*blk
let mut dst_ptr: *mut __m256i = dst.as_mut_ptr() as *mut __m256i;
let step: usize = n >> 2;
// Each iteration copies 8 i64; advance src by n i64 each row
for _ in 0..rows {
let v: __m256i = _mm256_loadu_si256(src_ptr);
_mm256_storeu_si256(dst_ptr, v);
let v: __m256i = _mm256_loadu_si256(src_ptr.add(1));
_mm256_storeu_si256(dst_ptr.add(1), v);
dst_ptr = dst_ptr.add(2);
src_ptr = src_ptr.add(step);
}
}
}
/// # Safety
/// Caller must ensure the CPU supports AVX (e.g., via `is_x86_feature_detected!("avx")`).
#[target_feature(enable = "avx")]
pub unsafe fn i64_save_1blk_contiguous_avx(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) {
use core::arch::x86_64::{__m256i, _mm256_loadu_si256, _mm256_storeu_si256};
unsafe {
let mut src_ptr: *const __m256i = src.as_ptr() as *const __m256i;
let mut dst_ptr: *mut __m256i = dst.as_mut_ptr().add(offset + (blk << 3)) as *mut __m256i; // dst + offset + 8*blk
let step: usize = n >> 2;
// Each iteration copies 8 i64; advance dst by n i64 each row
for _ in 0..rows {
let v: __m256i = _mm256_loadu_si256(src_ptr);
_mm256_storeu_si256(dst_ptr, v);
let v: __m256i = _mm256_loadu_si256(src_ptr.add(1));
_mm256_storeu_si256(dst_ptr.add(1), v);
dst_ptr = dst_ptr.add(step);
src_ptr = src_ptr.add(2);
}
}
}
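// Scalar mirror of the two copy kernels above (editor's sketch, not part of
// the crate): extract reads one 8-lane block per row at stride n, save writes
// it back; shown for the save direction only.
//
//     fn i64_save_1blk_contiguous_ref(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) {
//         for r in 0..rows {
//             let d = offset + 8 * blk + r * n;
//             dst[d..d + 8].copy_from_slice(&src[8 * r..8 * r + 8]);
//         }
//     }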

View File

@@ -1,7 +1,7 @@
 // ─────────────────────────────────────────────────────────────
 // Build the backend **only when ALL conditions are satisfied**
 // ─────────────────────────────────────────────────────────────
-#![cfg(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))]
+//#![cfg(all(feature = "enable-avx", target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))]

 // If the user enables this backend but targets a non-x86_64 CPU → abort
 #[cfg(all(feature = "enable-avx", not(target_arch = "x86_64")))]
@@ -15,6 +15,7 @@ compile_error!("feature `enable-avx` requires AVX2. Build with RUSTFLAGS=\"-C ta
 #[cfg(all(feature = "enable-avx", target_arch = "x86_64", not(target_feature = "fma")))]
 compile_error!("feature `enable-avx` requires FMA. Build with RUSTFLAGS=\"-C target-feature=+fma\".");

+mod convolution;
 mod module;
 mod reim;
 mod reim4;

View File

@@ -5,13 +5,18 @@ use poulpy_hal::{
     oep::ModuleNewImpl,
     reference::{
         fft64::{
+            convolution::{
+                I64ConvolutionByConst1Coeff, I64ConvolutionByConst2Coeffs, I64Extract1BlkContiguous, I64Save1BlkContiguous,
+            },
             reim::{
                 ReimAdd, ReimAddInplace, ReimAddMul, ReimCopy, ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimIFFTTable, ReimMul,
                 ReimMulInplace, ReimNegate, ReimNegateInplace, ReimSub, ReimSubInplace, ReimSubNegateInplace, ReimToZnx,
                 ReimToZnxInplace, ReimZero, reim_copy_ref, reim_zero_ref,
             },
             reim4::{
-                Reim4Extract1Blk, Reim4Mat1ColProd, Reim4Mat2Cols2ndColProd, Reim4Mat2ColsProd, Reim4Save1Blk, Reim4Save2Blks,
+                Reim4Convolution1Coeff, Reim4Convolution2Coeffs, Reim4ConvolutionByRealConst1Coeff,
+                Reim4ConvolutionByRealConst2Coeffs, Reim4Extract1BlkContiguous, Reim4Mat1ColProd, Reim4Mat2Cols2ndColProd,
+                Reim4Mat2ColsProd, Reim4Save1Blk, Reim4Save1BlkContiguous, Reim4Save2Blks,
             },
         },
         znx::{
@@ -26,6 +31,10 @@
 use crate::{
     FFT64Avx,
+    convolution::{
+        i64_convolution_by_const_1coeff_avx, i64_convolution_by_real_const_2coeffs_avx, i64_extract_1blk_contiguous_avx,
+        i64_save_1blk_contiguous_avx,
+    },
     reim::{
         ReimFFTAvx, ReimIFFTAvx, reim_add_avx2_fma, reim_add_inplace_avx2_fma, reim_addmul_avx2_fma, reim_from_znx_i64_bnd50_fma,
         reim_mul_avx2_fma, reim_mul_inplace_avx2_fma, reim_negate_avx2_fma, reim_negate_inplace_avx2_fma, reim_sub_avx2_fma,
@@ -33,8 +42,10 @@
     },
     reim_to_znx_i64_bnd63_avx2_fma,
     reim4::{
-        reim4_extract_1blk_from_reim_avx, reim4_save_1blk_to_reim_avx, reim4_save_2blk_to_reim_avx,
-        reim4_vec_mat1col_product_avx, reim4_vec_mat2cols_2ndcol_product_avx, reim4_vec_mat2cols_product_avx,
+        reim4_convolution_1coeff_avx, reim4_convolution_2coeffs_avx, reim4_convolution_by_real_const_1coeff_avx,
+        reim4_convolution_by_real_const_2coeffs_avx, reim4_extract_1blk_from_reim_contiguous_avx, reim4_save_1blk_to_reim_avx,
+        reim4_save_1blk_to_reim_contiguous_avx, reim4_save_2blk_to_reim_avx, reim4_vec_mat1col_product_avx,
+        reim4_vec_mat2cols_2ndcol_product_avx, reim4_vec_mat2cols_product_avx,
     },
     znx_avx::{
         znx_add_avx, znx_add_inplace_avx, znx_automorphism_avx, znx_extract_digit_addmul_avx, znx_mul_add_power_of_two_avx,
@@ -470,11 +481,55 @@ impl ReimZero for FFT64Avx {
     }
 }

-impl Reim4Extract1Blk for FFT64Avx {
+impl Reim4Convolution1Coeff for FFT64Avx {
     #[inline(always)]
-    fn reim4_extract_1blk(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
+    fn reim4_convolution_1coeff(k: usize, dst: &mut [f64; 8], a: &[f64], a_size: usize, b: &[f64], b_size: usize) {
         unsafe {
-            reim4_extract_1blk_from_reim_avx(m, rows, blk, dst, src);
+            reim4_convolution_1coeff_avx(k, dst, a, a_size, b, b_size);
+        }
+    }
+}
+
+impl Reim4Convolution2Coeffs for FFT64Avx {
+    #[inline(always)]
+    fn reim4_convolution_2coeffs(k: usize, dst: &mut [f64; 16], a: &[f64], a_size: usize, b: &[f64], b_size: usize) {
+        unsafe {
+            reim4_convolution_2coeffs_avx(k, dst, a, a_size, b, b_size);
+        }
+    }
+}
+
+impl Reim4ConvolutionByRealConst1Coeff for FFT64Avx {
+    #[inline(always)]
+    fn reim4_convolution_by_real_const_1coeff(k: usize, dst: &mut [f64; 8], a: &[f64], a_size: usize, b: &[f64]) {
+        unsafe {
+            reim4_convolution_by_real_const_1coeff_avx(k, dst, a, a_size, b);
+        }
+    }
+}
+
+impl Reim4ConvolutionByRealConst2Coeffs for FFT64Avx {
+    #[inline(always)]
+    fn reim4_convolution_by_real_const_2coeffs(k: usize, dst: &mut [f64; 16], a: &[f64], a_size: usize, b: &[f64]) {
+        unsafe {
+            reim4_convolution_by_real_const_2coeffs_avx(k, dst, a, a_size, b);
+        }
+    }
+}
+
+impl Reim4Extract1BlkContiguous for FFT64Avx {
+    #[inline(always)]
+    fn reim4_extract_1blk_contiguous(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
+        unsafe {
+            reim4_extract_1blk_from_reim_contiguous_avx(m, rows, blk, dst, src);
+        }
+    }
+}
+
+impl Reim4Save1BlkContiguous for FFT64Avx {
+    fn reim4_save_1blk_contiguous(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
+        unsafe {
+            reim4_save_1blk_to_reim_contiguous_avx(m, rows, blk, dst, src);
         }
     }
 }
@@ -523,3 +578,39 @@ impl Reim4Mat2Cols2ndColProd for FFT64Avx {
         }
     }
 }
impl I64ConvolutionByConst1Coeff for FFT64Avx {
#[inline(always)]
fn i64_convolution_by_const_1coeff(k: usize, dst: &mut [i64; 8], a: &[i64], a_size: usize, b: &[i64]) {
unsafe {
i64_convolution_by_const_1coeff_avx(k, dst, a, a_size, b);
}
}
}
impl I64ConvolutionByConst2Coeffs for FFT64Avx {
#[inline(always)]
fn i64_convolution_by_const_2coeffs(k: usize, dst: &mut [i64; 16], a: &[i64], a_size: usize, b: &[i64]) {
unsafe {
i64_convolution_by_real_const_2coeffs_avx(k, dst, a, a_size, b);
}
}
}
impl I64Save1BlkContiguous for FFT64Avx {
#[inline(always)]
fn i64_save_1blk_contiguous(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) {
unsafe {
i64_save_1blk_contiguous_avx(n, offset, rows, blk, dst, src);
}
}
}
impl I64Extract1BlkContiguous for FFT64Avx {
#[inline(always)]
fn i64_extract_1blk_contiguous(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) {
unsafe {
i64_extract_1blk_contiguous_avx(n, offset, rows, blk, dst, src);
}
}
}
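These impls bind the AVX kernels at compile time (see the cfg guards in lib.rs). A runtime-dispatch alternative, using the `is_x86_feature_detected!` check already referenced in the safety comments, could look like this (editor's sketch, not part of the crate):

    #[cfg(target_arch = "x86_64")]
    fn i64_save_1blk_contiguous_dispatch(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) {
        if std::is_x86_feature_detected!("avx") {
            // Safety: the AVX path is only taken when the CPU reports AVX support.
            unsafe { i64_save_1blk_contiguous_avx(n, offset, rows, blk, dst, src) }
        } else {
            // scalar fallback, equivalent to the copy kernel above
            for r in 0..rows {
                let d = offset + 8 * blk + r * n;
                dst[d..d + 8].copy_from_slice(&src[8 * r..8 * r + 8]);
            }
        }
    }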

Some files were not shown because too many files have changed in this diff.