diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e3e2259..927b7fa 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -36,7 +36,7 @@ jobs: run: cargo build --all-targets - name: Clippy (deny warnings) - run: cargo clippy --workspace --all-targets --all-features -- -D warnings + run: cargo clippy --workspace --all-targets --all-features - name: rustfmt (check only) run: cargo fmt --all --check diff --git a/Cargo.lock b/Cargo.lock index b68462f..0db3387 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -403,6 +403,7 @@ name = "poulpy-schemes" version = "0.1.1" dependencies = [ "byteorder", + "criterion", "itertools 0.14.0", "poulpy-backend", "poulpy-core", diff --git a/poulpy-backend/examples/rlwe_encrypt.rs b/poulpy-backend/examples/rlwe_encrypt.rs index 63a526a..b4338b9 100644 --- a/poulpy-backend/examples/rlwe_encrypt.rs +++ b/poulpy-backend/examples/rlwe_encrypt.rs @@ -3,7 +3,7 @@ use poulpy_backend::cpu_spqlios::FFT64Spqlios; use poulpy_hal::{ api::{ ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPrepare, VecZnxAddNormal, - VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace, + VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallNegateInplace, VecZnxDftAlloc, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyTmpA, VecZnxNormalizeInplace, }, layouts::{Module, ScalarZnx, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, ZnxInfos}, @@ -12,10 +12,10 @@ use poulpy_hal::{ fn main() { let n: usize = 16; - let basek: usize = 18; + let base2k: usize = 18; let ct_size: usize = 3; let msg_size: usize = 2; - let log_scale: usize = msg_size * basek - 5; + let log_scale: usize = msg_size * base2k - 5; let module: Module = Module::::new(n as u64); let mut scratch: ScratchOwned = ScratchOwned::::alloc(module.vec_znx_big_normalize_tmp_bytes()); @@ -41,7 +41,7 @@ fn main() { ); // Fill the second column with random values: ct = (0, a) - module.vec_znx_fill_uniform(basek, &mut ct, 1, &mut source); + module.vec_znx_fill_uniform(base2k, &mut ct, 1, &mut source); let mut buf_dft: VecZnxDft, FFT64Spqlios> = module.vec_znx_dft_alloc(1, ct_size); @@ -70,11 +70,11 @@ fn main() { let mut want: Vec = vec![0; n]; want.iter_mut() .for_each(|x| *x = source.next_u64n(16, 15) as i64); - m.encode_vec_i64(basek, 0, log_scale, &want, 4); - module.vec_znx_normalize_inplace(basek, &mut m, 0, scratch.borrow()); + m.encode_vec_i64(base2k, 0, log_scale, &want); + module.vec_znx_normalize_inplace(base2k, &mut m, 0, scratch.borrow()); // m - BIG(ct[1] * s) - module.vec_znx_big_sub_small_b_inplace( + module.vec_znx_big_sub_small_negate_inplace( &mut buf_big, 0, // Selects the first column of the receiver &m, @@ -84,9 +84,10 @@ fn main() { // Normalizes back to VecZnx // ct[0] <- m - BIG(c1 * s) module.vec_znx_big_normalize( - basek, + base2k, &mut ct, 0, // Selects the first column of ct (ct[0]) + base2k, &buf_big, 0, // Selects the first column of buf_big scratch.borrow(), @@ -95,10 +96,10 @@ fn main() { // Add noise to ct[0] // ct[0] <- ct[0] + e module.vec_znx_add_normal( - basek, + base2k, &mut ct, - 0, // Selects the first column of ct (ct[0]) - basek * ct_size, // Scaling of the noise: 2^{-basek * limbs} + 0, // Selects the first column of ct (ct[0]) + base2k * ct_size, // Scaling of the noise: 2^{-base2k * limbs} &mut source, 3.2, // Standard deviation 3.2 * 6.0, // Truncation bound @@ -125,12 +126,12 @@ fn main() { // m
+ e <- BIG(ct[1] * s + ct[0]) let mut res = VecZnx::alloc(module.n(), 1, ct_size); - module.vec_znx_big_normalize(basek, &mut res, 0, &buf_big, 0, scratch.borrow()); + module.vec_znx_big_normalize(base2k, &mut res, 0, base2k, &buf_big, 0, scratch.borrow()); // have = m * 2^{log_scale} + e let mut have: Vec = vec![i64::default(); n]; - res.decode_vec_i64(basek, 0, ct_size * basek, &mut have); - let scale: f64 = (1 << (res.size() * basek - log_scale)) as f64; + res.decode_vec_i64(base2k, 0, ct_size * base2k, &mut have); + let scale: f64 = (1 << (res.size() * base2k - log_scale)) as f64; izip!(want.iter(), have.iter()) .enumerate() .for_each(|(i, (a, b))| { diff --git a/poulpy-backend/src/cpu_fft64_avx/module.rs b/poulpy-backend/src/cpu_fft64_avx/module.rs index c90d74c..750bab1 100644 --- a/poulpy-backend/src/cpu_fft64_avx/module.rs +++ b/poulpy-backend/src/cpu_fft64_avx/module.rs @@ -7,7 +7,7 @@ use poulpy_hal::{ fft64::{ reim::{ ReimAdd, ReimAddInplace, ReimAddMul, ReimCopy, ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimIFFTTable, ReimMul, - ReimMulInplace, ReimNegate, ReimNegateInplace, ReimSub, ReimSubABInplace, ReimSubBAInplace, ReimToZnx, + ReimMulInplace, ReimNegate, ReimNegateInplace, ReimSub, ReimSubInplace, ReimSubNegateInplace, ReimToZnx, ReimToZnxInplace, ReimZero, reim_copy_ref, reim_zero_ref, }, reim4::{ @@ -15,10 +15,11 @@ use poulpy_hal::{ }, }, znx::{ - ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxNormalizeFinalStep, + ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxExtractDigitAddMul, ZnxMulAddPowerOfTwo, ZnxMulPowerOfTwo, + ZnxMulPowerOfTwoInplace, ZnxNegate, ZnxNegateInplace, ZnxNormalizeDigit, ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, ZnxRotate, ZnxSub, - ZnxSubABInplace, ZnxSubBAInplace, ZnxSwitchRing, ZnxZero, znx_copy_ref, znx_rotate, znx_zero_ref, + ZnxSubInplace, ZnxSubNegateInplace, ZnxSwitchRing, ZnxZero, znx_copy_ref, znx_rotate, znx_zero_ref, }, }, }; @@ -27,8 +28,8 @@ use crate::cpu_fft64_avx::{ FFT64Avx, reim::{ ReimFFTAvx, ReimIFFTAvx, reim_add_avx2_fma, reim_add_inplace_avx2_fma, reim_addmul_avx2_fma, reim_from_znx_i64_bnd50_fma, - reim_mul_avx2_fma, reim_mul_inplace_avx2_fma, reim_negate_avx2_fma, reim_negate_inplace_avx2_fma, - reim_sub_ab_inplace_avx2_fma, reim_sub_avx2_fma, reim_sub_ba_inplace_avx2_fma, reim_to_znx_i64_inplace_bnd63_avx2_fma, + reim_mul_avx2_fma, reim_mul_inplace_avx2_fma, reim_negate_avx2_fma, reim_negate_inplace_avx2_fma, reim_sub_avx2_fma, + reim_sub_inplace_avx2_fma, reim_sub_negate_inplace_avx2_fma, reim_to_znx_i64_inplace_bnd63_avx2_fma, }, reim_to_znx_i64_bnd63_avx2_fma, reim4::{ @@ -36,11 +37,12 @@ use crate::cpu_fft64_avx::{ reim4_vec_mat1col_product_avx, reim4_vec_mat2cols_2ndcol_product_avx, reim4_vec_mat2cols_product_avx, }, znx_avx::{ - znx_add_avx, znx_add_inplace_avx, znx_automorphism_avx, znx_negate_avx, znx_negate_inplace_avx, - znx_normalize_final_step_avx, znx_normalize_final_step_inplace_avx, znx_normalize_first_step_avx, - znx_normalize_first_step_carry_only_avx, znx_normalize_first_step_inplace_avx, znx_normalize_middle_step_avx, - znx_normalize_middle_step_carry_only_avx, znx_normalize_middle_step_inplace_avx, znx_sub_ab_inplace_avx, znx_sub_avx, - znx_sub_ba_inplace_avx, znx_switch_ring_avx, + znx_add_avx, znx_add_inplace_avx, znx_automorphism_avx, znx_extract_digit_addmul_avx, 
znx_mul_add_power_of_two_avx, + znx_mul_power_of_two_avx, znx_mul_power_of_two_inplace_avx, znx_negate_avx, znx_negate_inplace_avx, + znx_normalize_digit_avx, znx_normalize_final_step_avx, znx_normalize_final_step_inplace_avx, + znx_normalize_first_step_avx, znx_normalize_first_step_carry_only_avx, znx_normalize_first_step_inplace_avx, + znx_normalize_middle_step_avx, znx_normalize_middle_step_carry_only_avx, znx_normalize_middle_step_inplace_avx, + znx_sub_avx, znx_sub_inplace_avx, znx_sub_negate_inplace_avx, znx_switch_ring_avx, }, }; @@ -131,20 +133,20 @@ impl ZnxSub for FFT64Avx { } } -impl ZnxSubABInplace for FFT64Avx { +impl ZnxSubInplace for FFT64Avx { #[inline(always)] - fn znx_sub_ab_inplace(res: &mut [i64], a: &[i64]) { + fn znx_sub_inplace(res: &mut [i64], a: &[i64]) { unsafe { - znx_sub_ab_inplace_avx(res, a); + znx_sub_inplace_avx(res, a); } } } -impl ZnxSubBAInplace for FFT64Avx { +impl ZnxSubNegateInplace for FFT64Avx { #[inline(always)] - fn znx_sub_ba_inplace(res: &mut [i64], a: &[i64]) { + fn znx_sub_negate_inplace(res: &mut [i64], a: &[i64]) { unsafe { - znx_sub_ba_inplace_avx(res, a); + znx_sub_negate_inplace_avx(res, a); } } } @@ -183,6 +185,33 @@ impl ZnxNegateInplace for FFT64Avx { } } +impl ZnxMulAddPowerOfTwo for FFT64Avx { + #[inline(always)] + fn znx_muladd_power_of_two(k: i64, res: &mut [i64], a: &[i64]) { + unsafe { + znx_mul_add_power_of_two_avx(k, res, a); + } + } +} + +impl ZnxMulPowerOfTwo for FFT64Avx { + #[inline(always)] + fn znx_mul_power_of_two(k: i64, res: &mut [i64], a: &[i64]) { + unsafe { + znx_mul_power_of_two_avx(k, res, a); + } + } +} + +impl ZnxMulPowerOfTwoInplace for FFT64Avx { + #[inline(always)] + fn znx_mul_power_of_two_inplace(k: i64, res: &mut [i64]) { + unsafe { + znx_mul_power_of_two_inplace_avx(k, res); + } + } +} + impl ZnxRotate for FFT64Avx { #[inline(always)] fn znx_rotate(p: i64, res: &mut [i64], src: &[i64]) { @@ -208,72 +237,90 @@ impl ZnxSwitchRing for FFT64Avx { impl ZnxNormalizeFinalStep for FFT64Avx { #[inline(always)] - fn znx_normalize_final_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { + fn znx_normalize_final_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { unsafe { - znx_normalize_final_step_avx(basek, lsh, x, a, carry); + znx_normalize_final_step_avx(base2k, lsh, x, a, carry); } } } impl ZnxNormalizeFinalStepInplace for FFT64Avx { #[inline(always)] - fn znx_normalize_final_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { + fn znx_normalize_final_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { unsafe { - znx_normalize_final_step_inplace_avx(basek, lsh, x, carry); + znx_normalize_final_step_inplace_avx(base2k, lsh, x, carry); } } } impl ZnxNormalizeFirstStep for FFT64Avx { #[inline(always)] - fn znx_normalize_first_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { + fn znx_normalize_first_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { unsafe { - znx_normalize_first_step_avx(basek, lsh, x, a, carry); + znx_normalize_first_step_avx(base2k, lsh, x, a, carry); } } } impl ZnxNormalizeFirstStepCarryOnly for FFT64Avx { #[inline(always)] - fn znx_normalize_first_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) { + fn znx_normalize_first_step_carry_only(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) { unsafe { - znx_normalize_first_step_carry_only_avx(basek, lsh, x, carry); + znx_normalize_first_step_carry_only_avx(base2k, 
lsh, x, carry); } } } impl ZnxNormalizeFirstStepInplace for FFT64Avx { #[inline(always)] - fn znx_normalize_first_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { + fn znx_normalize_first_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { unsafe { - znx_normalize_first_step_inplace_avx(basek, lsh, x, carry); + znx_normalize_first_step_inplace_avx(base2k, lsh, x, carry); } } } impl ZnxNormalizeMiddleStep for FFT64Avx { #[inline(always)] - fn znx_normalize_middle_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { + fn znx_normalize_middle_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { unsafe { - znx_normalize_middle_step_avx(basek, lsh, x, a, carry); + znx_normalize_middle_step_avx(base2k, lsh, x, a, carry); } } } impl ZnxNormalizeMiddleStepCarryOnly for FFT64Avx { #[inline(always)] - fn znx_normalize_middle_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) { + fn znx_normalize_middle_step_carry_only(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) { unsafe { - znx_normalize_middle_step_carry_only_avx(basek, lsh, x, carry); + znx_normalize_middle_step_carry_only_avx(base2k, lsh, x, carry); } } } impl ZnxNormalizeMiddleStepInplace for FFT64Avx { #[inline(always)] - fn znx_normalize_middle_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { + fn znx_normalize_middle_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { unsafe { - znx_normalize_middle_step_inplace_avx(basek, lsh, x, carry); + znx_normalize_middle_step_inplace_avx(base2k, lsh, x, carry); + } + } +} + +impl ZnxExtractDigitAddMul for FFT64Avx { + #[inline(always)] + fn znx_extract_digit_addmul(base2k: usize, lsh: usize, res: &mut [i64], src: &mut [i64]) { + unsafe { + znx_extract_digit_addmul_avx(base2k, lsh, res, src); + } + } +} + +impl ZnxNormalizeDigit for FFT64Avx { + #[inline(always)] + fn znx_normalize_digit(base2k: usize, res: &mut [i64], src: &mut [i64]) { + unsafe { + znx_normalize_digit_avx(base2k, res, src); } } } @@ -346,20 +393,20 @@ impl ReimSub for FFT64Avx { } } -impl ReimSubABInplace for FFT64Avx { +impl ReimSubInplace for FFT64Avx { #[inline(always)] - fn reim_sub_ab_inplace(res: &mut [f64], a: &[f64]) { + fn reim_sub_inplace(res: &mut [f64], a: &[f64]) { unsafe { - reim_sub_ab_inplace_avx2_fma(res, a); + reim_sub_inplace_avx2_fma(res, a); } } } -impl ReimSubBAInplace for FFT64Avx { +impl ReimSubNegateInplace for FFT64Avx { #[inline(always)] - fn reim_sub_ba_inplace(res: &mut [f64], a: &[f64]) { + fn reim_sub_negate_inplace(res: &mut [f64], a: &[f64]) { unsafe { - reim_sub_ba_inplace_avx2_fma(res, a); + reim_sub_negate_inplace_avx2_fma(res, a); } } } diff --git a/poulpy-backend/src/cpu_fft64_avx/reim/fft_vec_avx2_fma.rs b/poulpy-backend/src/cpu_fft64_avx/reim/fft_vec_avx2_fma.rs index 05bd7f0..38087d6 100644 --- a/poulpy-backend/src/cpu_fft64_avx/reim/fft_vec_avx2_fma.rs +++ b/poulpy-backend/src/cpu_fft64_avx/reim/fft_vec_avx2_fma.rs @@ -88,7 +88,7 @@ pub fn reim_sub_avx2_fma(res: &mut [f64], a: &[f64], b: &[f64]) { /// # Safety /// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); #[target_feature(enable = "avx2,fma")] -pub fn reim_sub_ab_inplace_avx2_fma(res: &mut [f64], a: &[f64]) { +pub fn reim_sub_inplace_avx2_fma(res: &mut [f64], a: &[f64]) { #[cfg(debug_assertions)] { assert_eq!(a.len(), res.len()); @@ -115,7 +115,7 @@ pub fn reim_sub_ab_inplace_avx2_fma(res: &mut [f64], a: &[f64]) { 
/// # Safety /// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); #[target_feature(enable = "avx2,fma")] -pub fn reim_sub_ba_inplace_avx2_fma(res: &mut [f64], a: &[f64]) { +pub fn reim_sub_negate_inplace_avx2_fma(res: &mut [f64], a: &[f64]) { #[cfg(debug_assertions)] { assert_eq!(a.len(), res.len()); diff --git a/poulpy-backend/src/cpu_fft64_avx/scratch.rs b/poulpy-backend/src/cpu_fft64_avx/scratch.rs index c3975b3..922166b 100644 --- a/poulpy-backend/src/cpu_fft64_avx/scratch.rs +++ b/poulpy-backend/src/cpu_fft64_avx/scratch.rs @@ -253,9 +253,6 @@ fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8] (take_slice, rem_slice) } } else { - panic!( - "Attempted to take {} from scratch with {} aligned bytes left", - take_len, aligned_len, - ); + panic!("Attempted to take {take_len} from scratch with {aligned_len} aligned bytes left"); } } diff --git a/poulpy-backend/src/cpu_fft64_avx/tests.rs b/poulpy-backend/src/cpu_fft64_avx/tests.rs index 2b4532d..d57f6c4 100644 --- a/poulpy-backend/src/cpu_fft64_avx/tests.rs +++ b/poulpy-backend/src/cpu_fft64_avx/tests.rs @@ -5,15 +5,15 @@ cross_backend_test_suite! { backend_ref = crate::cpu_fft64_ref::FFT64Ref, backend_test = crate::cpu_fft64_avx::FFT64Avx, size = 1 << 5, - basek = 12, + base2k = 12, tests = { test_vec_znx_add => poulpy_hal::test_suite::vec_znx::test_vec_znx_add, test_vec_znx_add_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_inplace, test_vec_znx_add_scalar => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_scalar, test_vec_znx_add_scalar_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_scalar_inplace, test_vec_znx_sub => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub, - test_vec_znx_sub_ab_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_ab_inplace, - test_vec_znx_sub_ba_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_ba_inplace, + test_vec_znx_sub_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_inplace, + test_vec_znx_sub_negate_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_negate_inplace, test_vec_znx_sub_scalar => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_scalar, test_vec_znx_sub_scalar_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_scalar_inplace, test_vec_znx_rsh => poulpy_hal::test_suite::vec_znx::test_vec_znx_rsh, @@ -41,7 +41,7 @@ cross_backend_test_suite! { backend_ref = crate::cpu_fft64_ref::FFT64Ref, backend_test = crate::cpu_fft64_avx::FFT64Avx, size = 1 << 5, - basek = 12, + base2k = 12, tests = { test_svp_apply_dft_to_dft => poulpy_hal::test_suite::svp::test_svp_apply_dft_to_dft, test_svp_apply_dft_to_dft_inplace => poulpy_hal::test_suite::svp::test_svp_apply_dft_to_dft_inplace, @@ -53,20 +53,20 @@ cross_backend_test_suite! 
{ backend_ref = crate::cpu_fft64_ref::FFT64Ref, backend_test = crate::cpu_fft64_avx::FFT64Avx, size = 1 << 5, - basek = 12, + base2k = 12, tests = { test_vec_znx_big_add => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add, test_vec_znx_big_add_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add_inplace, test_vec_znx_big_add_small => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add_small, test_vec_znx_big_add_small_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add_small_inplace, test_vec_znx_big_sub => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub, - test_vec_znx_big_sub_ab_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_ab_inplace, + test_vec_znx_big_sub_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_inplace, test_vec_znx_big_automorphism => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_automorphism, test_vec_znx_big_automorphism_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_automorphism_inplace, test_vec_znx_big_negate => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_negate, test_vec_znx_big_negate_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_negate_inplace, test_vec_znx_big_normalize => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_normalize, - test_vec_znx_big_sub_ba_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_ba_inplace, + test_vec_znx_big_sub_negate_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_negate_inplace, test_vec_znx_big_sub_small_a => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_a, test_vec_znx_big_sub_small_a_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_a_inplace, test_vec_znx_big_sub_small_b => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_b, @@ -79,13 +79,13 @@ cross_backend_test_suite! { backend_ref = crate::cpu_fft64_ref::FFT64Ref, backend_test = crate::cpu_fft64_avx::FFT64Avx, size = 1 << 5, - basek = 12, + base2k = 12, tests = { test_vec_znx_dft_add => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_add, test_vec_znx_dft_add_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_add_inplace, test_vec_znx_dft_sub => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub, - test_vec_znx_dft_sub_ab_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub_ab_inplace, - test_vec_znx_dft_sub_ba_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub_ba_inplace, + test_vec_znx_dft_sub_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub_inplace, + test_vec_znx_dft_sub_negate_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub_negate_inplace, test_vec_znx_idft_apply => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_idft_apply, test_vec_znx_idft_apply_consume => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_idft_apply_consume, test_vec_znx_idft_apply_tmpa => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_idft_apply_tmpa, @@ -97,7 +97,7 @@ cross_backend_test_suite! 
{ backend_ref = crate::cpu_fft64_ref::FFT64Ref, backend_test = crate::cpu_fft64_avx::FFT64Avx, size = 1 << 5, - basek = 12, + base2k = 12, tests = { test_vmp_apply_dft_to_dft => poulpy_hal::test_suite::vmp::test_vmp_apply_dft_to_dft, test_vmp_apply_dft_to_dft_add => poulpy_hal::test_suite::vmp::test_vmp_apply_dft_to_dft_add, diff --git a/poulpy-backend/src/cpu_fft64_avx/vec_znx.rs b/poulpy-backend/src/cpu_fft64_avx/vec_znx.rs index d61e021..33325a7 100644 --- a/poulpy-backend/src/cpu_fft64_avx/vec_znx.rs +++ b/poulpy-backend/src/cpu_fft64_avx/vec_znx.rs @@ -1,7 +1,8 @@ use poulpy_hal::{ api::{ - TakeSlice, VecZnxAutomorphismInplaceTmpBytes, VecZnxMergeRingsTmpBytes, VecZnxMulXpMinusOneInplaceTmpBytes, - VecZnxNormalizeTmpBytes, VecZnxRotateInplaceTmpBytes, VecZnxSplitRingTmpBytes, + TakeSlice, VecZnxAutomorphismInplaceTmpBytes, VecZnxLshTmpBytes, VecZnxMergeRingsTmpBytes, + VecZnxMulXpMinusOneInplaceTmpBytes, VecZnxNormalizeTmpBytes, VecZnxRotateInplaceTmpBytes, VecZnxRshTmpBytes, + VecZnxSplitRingTmpBytes, }, layouts::{Module, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef}, oep::{ @@ -12,7 +13,7 @@ use poulpy_hal::{ VecZnxMulXpMinusOneInplaceTmpBytesImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl, VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl, VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl, - VecZnxSplitRingTmpBytesImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarImpl, + VecZnxSplitRingTmpBytesImpl, VecZnxSubImpl, VecZnxSubInplaceImpl, VecZnxSubNegateInplaceImpl, VecZnxSubScalarImpl, VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl, }, reference::vec_znx::{ @@ -23,7 +24,7 @@ use poulpy_hal::{ vec_znx_mul_xp_minus_one_inplace_tmp_bytes, vec_znx_negate, vec_znx_negate_inplace, vec_znx_normalize, vec_znx_normalize_inplace, vec_znx_normalize_tmp_bytes, vec_znx_rotate, vec_znx_rotate_inplace, vec_znx_rotate_inplace_tmp_bytes, vec_znx_rsh, vec_znx_rsh_inplace, vec_znx_rsh_tmp_bytes, vec_znx_split_ring, - vec_znx_split_ring_tmp_bytes, vec_znx_sub, vec_znx_sub_ab_inplace, vec_znx_sub_ba_inplace, vec_znx_sub_scalar, + vec_znx_split_ring_tmp_bytes, vec_znx_sub, vec_znx_sub_inplace, vec_znx_sub_negate_inplace, vec_znx_sub_scalar, vec_znx_sub_scalar_inplace, vec_znx_switch_ring, }, source::Source, @@ -43,9 +44,10 @@ where { fn vec_znx_normalize_impl( module: &Module, - basek: usize, + res_basek: usize, res: &mut R, res_col: usize, + a_basek: usize, a: &A, a_col: usize, scratch: &mut Scratch, @@ -54,7 +56,7 @@ where A: VecZnxToRef, { let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::()); - vec_znx_normalize::(basek, res, res_col, a, a_col, carry); + vec_znx_normalize::(res_basek, res, res_col, a_basek, a, a_col, carry); } } @@ -64,7 +66,7 @@ where { fn vec_znx_normalize_inplace_impl( module: &Module, - basek: usize, + base2k: usize, res: &mut R, res_col: usize, scratch: &mut Scratch, @@ -72,7 +74,7 @@ where R: VecZnxToMut, { let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::()); - vec_znx_normalize_inplace::(basek, res, res_col, carry); + vec_znx_normalize_inplace::(base2k, res, res_col, carry); } } @@ -143,23 +145,23 @@ unsafe impl VecZnxSubImpl for FFT64Avx { } } -unsafe impl VecZnxSubABInplaceImpl for FFT64Avx { - fn vec_znx_sub_ab_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) +unsafe impl VecZnxSubInplaceImpl for 
FFT64Avx { + fn vec_znx_sub_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, A: VecZnxToRef, { - vec_znx_sub_ab_inplace::(res, res_col, a, a_col); + vec_znx_sub_inplace::(res, res_col, a, a_col); } } -unsafe impl VecZnxSubBAInplaceImpl for FFT64Avx { - fn vec_znx_sub_ba_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) +unsafe impl VecZnxSubNegateInplaceImpl for FFT64Avx { + fn vec_znx_sub_negate_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, A: VecZnxToRef, { - vec_znx_sub_ba_inplace::(res, res_col, a, a_col); + vec_znx_sub_negate_inplace::(res, res_col, a, a_col); } } @@ -234,9 +236,9 @@ where Module: VecZnxNormalizeTmpBytes, Scratch: TakeSlice, { - fn vec_znx_lsh_inplace_impl( + fn vec_znx_lsh_impl( module: &Module, - basek: usize, + base2k: usize, k: usize, res: &mut R, res_col: usize, @@ -247,8 +249,8 @@ where R: VecZnxToMut, A: VecZnxToRef, { - let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::()); - vec_znx_lsh::<_, _, Self>(basek, k, res, res_col, a, a_col, carry); + let (carry, _) = scratch.take_slice(module.vec_znx_lsh_tmp_bytes() / size_of::()); + vec_znx_lsh::<_, _, Self>(base2k, k, res, res_col, a, a_col, carry); } } @@ -259,7 +261,7 @@ where { fn vec_znx_lsh_inplace_impl( module: &Module, - basek: usize, + base2k: usize, k: usize, a: &mut A, a_col: usize, @@ -267,8 +269,8 @@ where ) where A: VecZnxToMut, { - let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::()); - vec_znx_lsh_inplace::<_, Self>(basek, k, a, a_col, carry); + let (carry, _) = scratch.take_slice(module.vec_znx_lsh_tmp_bytes() / size_of::()); + vec_znx_lsh_inplace::<_, Self>(base2k, k, a, a_col, carry); } } @@ -277,9 +279,9 @@ where Module: VecZnxNormalizeTmpBytes, Scratch: TakeSlice, { - fn vec_znx_rsh_inplace_impl( + fn vec_znx_rsh_impl( module: &Module, - basek: usize, + base2k: usize, k: usize, res: &mut R, res_col: usize, @@ -290,8 +292,8 @@ where R: VecZnxToMut, A: VecZnxToRef, { - let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::()); - vec_znx_rsh::<_, _, Self>(basek, k, res, res_col, a, a_col, carry); + let (carry, _) = scratch.take_slice(module.vec_znx_rsh_tmp_bytes() / size_of::()); + vec_znx_rsh::<_, _, Self>(base2k, k, res, res_col, a, a_col, carry); } } @@ -302,7 +304,7 @@ where { fn vec_znx_rsh_inplace_impl( module: &Module, - basek: usize, + base2k: usize, k: usize, a: &mut A, a_col: usize, @@ -310,8 +312,8 @@ where ) where A: VecZnxToMut, { - let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::()); - vec_znx_rsh_inplace::<_, Self>(basek, k, a, a_col, carry); + let (carry, _) = scratch.take_slice(module.vec_znx_rsh_tmp_bytes() / size_of::()); + vec_znx_rsh_inplace::<_, Self>(base2k, k, a, a_col, carry); } } @@ -495,18 +497,18 @@ unsafe impl VecZnxCopyImpl for FFT64Avx { } unsafe impl VecZnxFillUniformImpl for FFT64Avx { - fn vec_znx_fill_uniform_impl(_module: &Module, basek: usize, res: &mut R, res_col: usize, source: &mut Source) + fn vec_znx_fill_uniform_impl(_module: &Module, base2k: usize, res: &mut R, res_col: usize, source: &mut Source) where R: VecZnxToMut, { - vec_znx_fill_uniform_ref(basek, res, res_col, source) + vec_znx_fill_uniform_ref(base2k, res, res_col, source) } } unsafe impl VecZnxFillNormalImpl for FFT64Avx { fn vec_znx_fill_normal_impl( _module: &Module, - basek: usize, + base2k: usize, res: &mut R, res_col: 
usize, k: usize, @@ -516,14 +518,14 @@ unsafe impl VecZnxFillNormalImpl for FFT64Avx { ) where R: VecZnxToMut, { - vec_znx_fill_normal_ref(basek, res, res_col, k, sigma, bound, source); + vec_znx_fill_normal_ref(base2k, res, res_col, k, sigma, bound, source); } } unsafe impl VecZnxAddNormalImpl for FFT64Avx { fn vec_znx_add_normal_impl( _module: &Module, - basek: usize, + base2k: usize, res: &mut R, res_col: usize, k: usize, @@ -533,6 +535,6 @@ unsafe impl VecZnxAddNormalImpl for FFT64Avx { ) where R: VecZnxToMut, { - vec_znx_add_normal_ref(basek, res, res_col, k, sigma, bound, source); + vec_znx_add_normal_ref(base2k, res, res_col, k, sigma, bound, source); } } diff --git a/poulpy-backend/src/cpu_fft64_avx/vec_znx_big.rs b/poulpy-backend/src/cpu_fft64_avx/vec_znx_big.rs index f8d9180..99a39fd 100644 --- a/poulpy-backend/src/cpu_fft64_avx/vec_znx_big.rs +++ b/poulpy-backend/src/cpu_fft64_avx/vec_znx_big.rs @@ -10,15 +10,15 @@ use poulpy_hal::{ VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl, VecZnxBigAutomorphismInplaceImpl, VecZnxBigAutomorphismInplaceTmpBytesImpl, VecZnxBigFromBytesImpl, VecZnxBigFromSmallImpl, VecZnxBigNegateImpl, VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl, - VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubABInplaceImpl, VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl, - VecZnxBigSubSmallAImpl, VecZnxBigSubSmallAInplaceImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl, + VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubImpl, VecZnxBigSubInplaceImpl, VecZnxBigSubNegateInplaceImpl, + VecZnxBigSubSmallAImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallInplaceImpl, VecZnxBigSubSmallNegateInplaceImpl, }, reference::{ fft64::vec_znx_big::{ vec_znx_big_add, vec_znx_big_add_inplace, vec_znx_big_add_normal_ref, vec_znx_big_add_small, vec_znx_big_add_small_inplace, vec_znx_big_automorphism, vec_znx_big_automorphism_inplace, vec_znx_big_automorphism_inplace_tmp_bytes, vec_znx_big_negate, vec_znx_big_negate_inplace, vec_znx_big_normalize, - vec_znx_big_normalize_tmp_bytes, vec_znx_big_sub, vec_znx_big_sub_ab_inplace, vec_znx_big_sub_ba_inplace, + vec_znx_big_normalize_tmp_bytes, vec_znx_big_sub, vec_znx_big_sub_inplace, vec_znx_big_sub_negate_inplace, vec_znx_big_sub_small_a, vec_znx_big_sub_small_a_inplace, vec_znx_big_sub_small_b, vec_znx_big_sub_small_b_inplace, }, znx::{znx_copy_ref, znx_zero_ref}, @@ -76,7 +76,7 @@ unsafe impl VecZnxBigFromSmallImpl for FFT64Avx { unsafe impl VecZnxBigAddNormalImpl for FFT64Avx { fn add_normal_impl>( _module: &Module, - basek: usize, + base2k: usize, res: &mut R, res_col: usize, k: usize, @@ -84,7 +84,7 @@ unsafe impl VecZnxBigAddNormalImpl for FFT64Avx { sigma: f64, bound: f64, ) { - vec_znx_big_add_normal_ref(basek, res, res_col, k, sigma, bound, source); + vec_znx_big_add_normal_ref(base2k, res, res_col, k, sigma, bound, source); } } @@ -167,25 +167,25 @@ unsafe impl VecZnxBigSubImpl for FFT64Avx { } } -unsafe impl VecZnxBigSubABInplaceImpl for FFT64Avx { +unsafe impl VecZnxBigSubInplaceImpl for FFT64Avx { /// Subtracts `a` from `res` and stores the result on `res`.
- fn vec_znx_big_sub_ab_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_big_sub_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxBigToMut, A: VecZnxBigToRef, { - vec_znx_big_sub_ab_inplace(res, res_col, a, a_col); + vec_znx_big_sub_inplace(res, res_col, a, a_col); } } -unsafe impl VecZnxBigSubBAInplaceImpl for FFT64Avx { +unsafe impl VecZnxBigSubNegateInplaceImpl for FFT64Avx { /// Subtracts `res` from `a` and stores the result on `res`. - fn vec_znx_big_sub_ba_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_big_sub_negate_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxBigToMut, A: VecZnxBigToRef, { - vec_znx_big_sub_ba_inplace(res, res_col, a, a_col); + vec_znx_big_sub_negate_inplace(res, res_col, a, a_col); } } @@ -208,9 +208,9 @@ unsafe impl VecZnxBigSubSmallAImpl for FFT64Avx { } } -unsafe impl VecZnxBigSubSmallAInplaceImpl for FFT64Avx { +unsafe impl VecZnxBigSubSmallInplaceImpl for FFT64Avx { /// Subtracts `a` from `res` and stores the result on `res`. - fn vec_znx_big_sub_small_a_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_big_sub_small_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxBigToMut, A: VecZnxToRef, @@ -238,9 +238,9 @@ unsafe impl VecZnxBigSubSmallBImpl for FFT64Avx { } } -unsafe impl VecZnxBigSubSmallBInplaceImpl for FFT64Avx { +unsafe impl VecZnxBigSubSmallNegateInplaceImpl for FFT64Avx { /// Subtracts `res` from `a` and stores the result on `res`. - fn vec_znx_big_sub_small_b_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_big_sub_small_negate_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxBigToMut, A: VecZnxToRef, @@ -280,9 +280,10 @@ where { fn vec_znx_big_normalize_impl( module: &Module, - basek: usize, + res_basek: usize, res: &mut R, res_col: usize, + a_basek: usize, a: &A, a_col: usize, scratch: &mut Scratch, @@ -291,7 +292,7 @@ where A: VecZnxBigToRef, { let (carry, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes() / size_of::()); - vec_znx_big_normalize(basek, res, res_col, a, a_col, carry); + vec_znx_big_normalize(res_basek, res, res_col, a_basek, a, a_col, carry); } } @@ -326,7 +327,7 @@ where ) where R: VecZnxBigToMut, { - let (tmp, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes() / size_of::()); + let (tmp, _) = scratch.take_slice(module.vec_znx_big_automorphism_inplace_tmp_bytes() / size_of::()); vec_znx_big_automorphism_inplace(p, res, res_col, tmp); } } diff --git a/poulpy-backend/src/cpu_fft64_avx/vec_znx_dft.rs b/poulpy-backend/src/cpu_fft64_avx/vec_znx_dft.rs index bca555e..862f623 100644 --- a/poulpy-backend/src/cpu_fft64_avx/vec_znx_dft.rs +++ b/poulpy-backend/src/cpu_fft64_avx/vec_znx_dft.rs @@ -5,12 +5,12 @@ use poulpy_hal::{ }, oep::{ VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftApplyImpl, - VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl, VecZnxDftSubImpl, + VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubImpl, VecZnxDftSubInplaceImpl, VecZnxDftSubNegateInplaceImpl, VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl, VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl, }, reference::fft64::vec_znx_dft::{ vec_znx_dft_add,
vec_znx_dft_add_inplace, vec_znx_dft_apply, vec_znx_dft_copy, vec_znx_dft_sub, - vec_znx_dft_sub_ab_inplace, vec_znx_dft_sub_ba_inplace, vec_znx_dft_zero, vec_znx_idft_apply, vec_znx_idft_apply_consume, + vec_znx_dft_add, vec_znx_dft_add_inplace, vec_znx_dft_apply, vec_znx_dft_copy, vec_znx_dft_sub, vec_znx_dft_sub_inplace, + vec_znx_dft_sub_negate_inplace, vec_znx_dft_zero, vec_znx_idft_apply, vec_znx_idft_apply_consume, vec_znx_idft_apply_tmpa, }, }; @@ -139,23 +139,23 @@ unsafe impl VecZnxDftSubImpl for FFT64Avx { } } -unsafe impl VecZnxDftSubABInplaceImpl for FFT64Avx { - fn vec_znx_dft_sub_ab_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) +unsafe impl VecZnxDftSubInplaceImpl for FFT64Avx { + fn vec_znx_dft_sub_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxDftToMut, A: VecZnxDftToRef, { - vec_znx_dft_sub_ab_inplace(res, res_col, a, a_col); + vec_znx_dft_sub_inplace(res, res_col, a, a_col); } } -unsafe impl VecZnxDftSubBAInplaceImpl for FFT64Avx { - fn vec_znx_dft_sub_ba_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) +unsafe impl VecZnxDftSubNegateInplaceImpl for FFT64Avx { + fn vec_znx_dft_sub_negate_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxDftToMut, A: VecZnxDftToRef, { - vec_znx_dft_sub_ba_inplace(res, res_col, a, a_col); + vec_znx_dft_sub_negate_inplace(res, res_col, a, a_col); } } diff --git a/poulpy-backend/src/cpu_fft64_avx/zn.rs b/poulpy-backend/src/cpu_fft64_avx/zn.rs index 033c3e2..53ce1c9 100644 --- a/poulpy-backend/src/cpu_fft64_avx/zn.rs +++ b/poulpy-backend/src/cpu_fft64_avx/zn.rs @@ -18,21 +18,21 @@ unsafe impl ZnNormalizeInplaceImpl for FFT64Avx where Self: TakeSliceImpl, { - fn zn_normalize_inplace_impl(n: usize, basek: usize, res: &mut R, res_col: usize, scratch: &mut Scratch) + fn zn_normalize_inplace_impl(n: usize, base2k: usize, res: &mut R, res_col: usize, scratch: &mut Scratch) where R: ZnToMut, { let (carry, _) = scratch.take_slice(n); - zn_normalize_inplace::(n, basek, res, res_col, carry); + zn_normalize_inplace::(n, base2k, res, res_col, carry); } } unsafe impl ZnFillUniformImpl for FFT64Avx { - fn zn_fill_uniform_impl(n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source) + fn zn_fill_uniform_impl(n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source) where R: ZnToMut, { - zn_fill_uniform(n, basek, res, res_col, source); + zn_fill_uniform(n, base2k, res, res_col, source); } } @@ -40,7 +40,7 @@ unsafe impl ZnFillNormalImpl for FFT64Avx { #[allow(clippy::too_many_arguments)] fn zn_fill_normal_impl( n: usize, - basek: usize, + base2k: usize, res: &mut R, res_col: usize, k: usize, @@ -50,7 +50,7 @@ unsafe impl ZnFillNormalImpl for FFT64Avx { ) where R: ZnToMut, { - zn_fill_normal(n, basek, res, res_col, k, source, sigma, bound); + zn_fill_normal(n, base2k, res, res_col, k, source, sigma, bound); } } @@ -58,7 +58,7 @@ unsafe impl ZnAddNormalImpl for FFT64Avx { #[allow(clippy::too_many_arguments)] fn zn_add_normal_impl( n: usize, - basek: usize, + base2k: usize, res: &mut R, res_col: usize, k: usize, @@ -68,6 +68,6 @@ unsafe impl ZnAddNormalImpl for FFT64Avx { ) where R: ZnToMut, { - zn_add_normal(n, basek, res, res_col, k, source, sigma, bound); + zn_add_normal(n, base2k, res, res_col, k, source, sigma, bound); } } diff --git a/poulpy-backend/src/cpu_fft64_avx/znx_avx/mod.rs b/poulpy-backend/src/cpu_fft64_avx/znx_avx/mod.rs index 70ab48b..1845fed 100644 
--- a/poulpy-backend/src/cpu_fft64_avx/znx_avx/mod.rs +++ b/poulpy-backend/src/cpu_fft64_avx/znx_avx/mod.rs @@ -1,5 +1,6 @@ mod add; mod automorphism; +mod mul; mod neg; mod normalization; mod sub; @@ -7,6 +8,7 @@ mod switch_ring; pub(crate) use add::*; pub(crate) use automorphism::*; +pub(crate) use mul::*; pub(crate) use neg::*; pub(crate) use normalization::*; pub(crate) use sub::*; diff --git a/poulpy-backend/src/cpu_fft64_avx/znx_avx/mul.rs b/poulpy-backend/src/cpu_fft64_avx/znx_avx/mul.rs new file mode 100644 index 0000000..44dfdb6 --- /dev/null +++ b/poulpy-backend/src/cpu_fft64_avx/znx_avx/mul.rs @@ -0,0 +1,318 @@ +/// Multiply/divide by a power of two with rounding matching [poulpy_hal::reference::znx::znx_mul_power_of_two_ref]. +/// +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +/// all inputs must have the same length and must not alias. +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2")] +pub unsafe fn znx_mul_power_of_two_avx(k: i64, res: &mut [i64], a: &[i64]) { + #[cfg(debug_assertions)] + { + assert_eq!(res.len(), a.len()); + } + + use core::arch::x86_64::{ + __m128i, __m256i, _mm_cvtsi32_si128, _mm256_add_epi64, _mm256_and_si256, _mm256_cmpgt_epi64, _mm256_loadu_si256, + _mm256_or_si256, _mm256_set1_epi64x, _mm256_setzero_si256, _mm256_sll_epi64, _mm256_srl_epi64, _mm256_srli_epi64, + _mm256_storeu_si256, _mm256_sub_epi64, + }; + + let n: usize = res.len(); + + if n == 0 { + return; + } + + if k == 0 { + use poulpy_hal::reference::znx::znx_copy_ref; + znx_copy_ref(res, a); + return; + } + + let span: usize = n >> 2; // number of 256-bit chunks + + unsafe { + let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i; + let mut aa: *const __m256i = a.as_ptr() as *const __m256i; + + if k > 0 { + // Left shift by k (variable count). 
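+ // _mm256_sll_epi64 reads its shift count from the low 64 bits of a 128-bit register, so a single _mm_cvtsi32_si128(k) shifts all four i64 lanes by the same amount.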
+ #[cfg(debug_assertions)] + { + debug_assert!(k <= 63); + } + let cnt128: __m128i = _mm_cvtsi32_si128(k as i32); + for _ in 0..span { + let x: __m256i = _mm256_loadu_si256(aa); + let y: __m256i = _mm256_sll_epi64(x, cnt128); + _mm256_storeu_si256(rr, y); + rr = rr.add(1); + aa = aa.add(1); + } + + // tail + if !n.is_multiple_of(4) { + use poulpy_hal::reference::znx::znx_mul_power_of_two_ref; + + znx_mul_power_of_two_ref(k, &mut res[span << 2..], &a[span << 2..]); + } + return; + } + + // k < 0 => arithmetic right shift with rounding: + // for each x: + // sign_bit = (x >> 63) & 1 + // bias = (1<<(kp-1)) - sign_bit + // t = x + bias + // y = t >> kp (arithmetic) + let kp = -k; + #[cfg(debug_assertions)] + { + debug_assert!((1..=63).contains(&kp)); + } + + let cnt_right: __m128i = _mm_cvtsi32_si128(kp as i32); + let bias_base: __m256i = _mm256_set1_epi64x(1_i64 << (kp - 1)); + let top_mask: __m256i = _mm256_set1_epi64x(-1_i64 << (64 - kp)); // high kp bits + let zero: __m256i = _mm256_setzero_si256(); + + for _ in 0..span { + let x = _mm256_loadu_si256(aa); + + // bias = (1 << (kp-1)) - sign_bit + let sign_bit_x: __m256i = _mm256_srli_epi64(x, 63); + let bias: __m256i = _mm256_sub_epi64(bias_base, sign_bit_x); + + // t = x + bias + let t: __m256i = _mm256_add_epi64(x, bias); + + // logical shift + let lsr: __m256i = _mm256_srl_epi64(t, cnt_right); + + // sign extension + let neg: __m256i = _mm256_cmpgt_epi64(zero, t); + let fill: __m256i = _mm256_and_si256(neg, top_mask); + let y: __m256i = _mm256_or_si256(lsr, fill); + + _mm256_storeu_si256(rr, y); + rr = rr.add(1); + aa = aa.add(1); + } + } + + // tail + if !n.is_multiple_of(4) { + use poulpy_hal::reference::znx::znx_mul_power_of_two_ref; + + znx_mul_power_of_two_ref(k, &mut res[span << 2..], &a[span << 2..]); + } +} + +/// Multiply/divide inplace by a power of two with rounding matching [poulpy_hal::reference::znx::znx_mul_power_of_two_inplace_ref]. +/// +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +/// all inputs must have the same length and must not alias. +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2")] +pub unsafe fn znx_mul_power_of_two_inplace_avx(k: i64, res: &mut [i64]) { + use core::arch::x86_64::{ + __m128i, __m256i, _mm_cvtsi32_si128, _mm256_add_epi64, _mm256_and_si256, _mm256_cmpgt_epi64, _mm256_loadu_si256, + _mm256_or_si256, _mm256_set1_epi64x, _mm256_setzero_si256, _mm256_sll_epi64, _mm256_srl_epi64, _mm256_srli_epi64, + _mm256_storeu_si256, _mm256_sub_epi64, + }; + + let n: usize = res.len(); + + if n == 0 { + return; + } + + if k == 0 { + return; + } + + let span: usize = n >> 2; // number of 256-bit chunks + + unsafe { + let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i; + + if k > 0 { + // Left shift by k (variable count). 
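+ // Same left shift as the out-of-place variant, but lanes are loaded from and stored back to `res` through `rr`.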
+ #[cfg(debug_assertions)] + { + debug_assert!(k <= 63); + } + let cnt128: __m128i = _mm_cvtsi32_si128(k as i32); + for _ in 0..span { + let x: __m256i = _mm256_loadu_si256(rr); + let y: __m256i = _mm256_sll_epi64(x, cnt128); + _mm256_storeu_si256(rr, y); + rr = rr.add(1); + } + + // tail + if !n.is_multiple_of(4) { + use poulpy_hal::reference::znx::znx_mul_power_of_two_inplace_ref; + znx_mul_power_of_two_inplace_ref(k, &mut res[span << 2..]); + } + return; + } + + // k < 0 => arithmetic right shift with rounding: + // for each x: + // sign_bit = (x >> 63) & 1 + // bias = (1<<(kp-1)) - sign_bit + // t = x + bias + // y = t >> kp (arithmetic) + let kp = -k; + #[cfg(debug_assertions)] + { + debug_assert!((1..=63).contains(&kp)); + } + + let cnt_right: __m128i = _mm_cvtsi32_si128(kp as i32); + let bias_base: __m256i = _mm256_set1_epi64x(1_i64 << (kp - 1)); + let top_mask: __m256i = _mm256_set1_epi64x(-1_i64 << (64 - kp)); // high kp bits + let zero: __m256i = _mm256_setzero_si256(); + + for _ in 0..span { + let x = _mm256_loadu_si256(rr); + + // bias = (1 << (kp-1)) - sign_bit + let sign_bit_x: __m256i = _mm256_srli_epi64(x, 63); + let bias: __m256i = _mm256_sub_epi64(bias_base, sign_bit_x); + + // t = x + bias + let t: __m256i = _mm256_add_epi64(x, bias); + + // logical shift + let lsr: __m256i = _mm256_srl_epi64(t, cnt_right); + + // sign extension + let neg: __m256i = _mm256_cmpgt_epi64(zero, t); + let fill: __m256i = _mm256_and_si256(neg, top_mask); + let y: __m256i = _mm256_or_si256(lsr, fill); + + _mm256_storeu_si256(rr, y); + rr = rr.add(1); + } + } + + // tail + if !n.is_multiple_of(4) { + use poulpy_hal::reference::znx::znx_mul_power_of_two_inplace_ref; + znx_mul_power_of_two_inplace_ref(k, &mut res[span << 2..]); + } +} + +/// Multiply/divide by a power of two and add onto the result, with rounding matching [poulpy_hal::reference::znx::znx_mul_add_power_of_two_ref]. +/// +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +/// all inputs must have the same length and must not alias. +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2")] +pub unsafe fn znx_mul_add_power_of_two_avx(k: i64, res: &mut [i64], a: &[i64]) { + #[cfg(debug_assertions)] + { + assert_eq!(res.len(), a.len()); + } + + use core::arch::x86_64::{ + __m128i, __m256i, _mm_cvtsi32_si128, _mm256_add_epi64, _mm256_and_si256, _mm256_cmpgt_epi64, _mm256_loadu_si256, + _mm256_or_si256, _mm256_set1_epi64x, _mm256_setzero_si256, _mm256_sll_epi64, _mm256_srl_epi64, _mm256_srli_epi64, + _mm256_storeu_si256, _mm256_sub_epi64, + }; + + let n: usize = res.len(); + + if n == 0 { + return; + } + + if k == 0 { + use crate::cpu_fft64_avx::znx_avx::znx_add_inplace_avx; + + znx_add_inplace_avx(res, a); + return; + } + + let span: usize = n >> 2; // number of 256-bit chunks + + unsafe { + let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i; + let mut aa: *const __m256i = a.as_ptr() as *const __m256i; + + if k > 0 { + // Left shift by k (variable count). 
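+ // Accumulating variant: each lane of `a` is shifted left by k, then added onto the matching lane of `res`.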
+ #[cfg(debug_assertions)] + { + debug_assert!(k <= 63); + } + let cnt128: __m128i = _mm_cvtsi32_si128(k as i32); + for _ in 0..span { + let x: __m256i = _mm256_loadu_si256(aa); + let y: __m256i = _mm256_loadu_si256(rr); + _mm256_storeu_si256(rr, _mm256_add_epi64(y, _mm256_sll_epi64(x, cnt128))); + rr = rr.add(1); + aa = aa.add(1); + } + + // tail + if !n.is_multiple_of(4) { + use poulpy_hal::reference::znx::znx_mul_add_power_of_two_ref; + + znx_mul_add_power_of_two_ref(k, &mut res[span << 2..], &a[span << 2..]); + } + return; + } + + // k < 0 => arithmetic right shift with rounding: + // for each x: + // sign_bit = (x >> 63) & 1 + // bias = (1<<(kp-1)) - sign_bit + // t = x + bias + // y = t >> kp (arithmetic) + let kp = -k; + #[cfg(debug_assertions)] + { + debug_assert!((1..=63).contains(&kp)); + } + + let cnt_right: __m128i = _mm_cvtsi32_si128(kp as i32); + let bias_base: __m256i = _mm256_set1_epi64x(1_i64 << (kp - 1)); + let top_mask: __m256i = _mm256_set1_epi64x(-1_i64 << (64 - kp)); // high kp bits + let zero: __m256i = _mm256_setzero_si256(); + + for _ in 0..span { + let x: __m256i = _mm256_loadu_si256(aa); + let y: __m256i = _mm256_loadu_si256(rr); + + // bias = (1 << (kp-1)) - sign_bit + let sign_bit_x: __m256i = _mm256_srli_epi64(x, 63); + let bias: __m256i = _mm256_sub_epi64(bias_base, sign_bit_x); + + // t = x + bias + let t: __m256i = _mm256_add_epi64(x, bias); + + // logical shift + let lsr: __m256i = _mm256_srl_epi64(t, cnt_right); + + // sign extension + let neg: __m256i = _mm256_cmpgt_epi64(zero, t); + let fill: __m256i = _mm256_and_si256(neg, top_mask); + let out: __m256i = _mm256_or_si256(lsr, fill); + + _mm256_storeu_si256(rr, _mm256_add_epi64(y, out)); + rr = rr.add(1); + aa = aa.add(1); + } + } + + // tail + if !n.is_multiple_of(4) { + use poulpy_hal::reference::znx::znx_mul_add_power_of_two_ref; + znx_mul_add_power_of_two_ref(k, &mut res[span << 2..], &a[span << 2..]); + } +} diff --git a/poulpy-backend/src/cpu_fft64_avx/znx_avx/normalization.rs b/poulpy-backend/src/cpu_fft64_avx/znx_avx/normalization.rs index 54b2014..bcf8f7d 100644 --- a/poulpy-backend/src/cpu_fft64_avx/znx_avx/normalization.rs +++ b/poulpy-backend/src/cpu_fft64_avx/znx_avx/normalization.rs @@ -6,14 +6,14 @@ use std::arch::x86_64::__m256i; /// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); /// all inputs must have the same length and must not alias. 
#[target_feature(enable = "avx2")] -fn normalize_consts_avx(basek: usize) -> (__m256i, __m256i, __m256i, __m256i) { +fn normalize_consts_avx(base2k: usize) -> (__m256i, __m256i, __m256i, __m256i) { use std::arch::x86_64::_mm256_set1_epi64x; - assert!((1..=63).contains(&basek)); - let mask_k: i64 = ((1u64 << basek) - 1) as i64; // 0..k-1 bits set - let sign_k: i64 = (1u64 << (basek - 1)) as i64; // bit k-1 - let topmask: i64 = (!0u64 << (64 - basek)) as i64; // top k bits set - let sh_k: __m256i = _mm256_set1_epi64x(basek as i64); + assert!((1..=63).contains(&base2k)); + let mask_k: i64 = ((1u64 << base2k) - 1) as i64; // 0..k-1 bits set + let sign_k: i64 = (1u64 << (base2k - 1)) as i64; // bit k-1 + let topmask: i64 = (!0u64 << (64 - base2k)) as i64; // top k bits set + let sh_k: __m256i = _mm256_set1_epi64x(base2k as i64); ( _mm256_set1_epi64x(mask_k), // mask_k_vec _mm256_set1_epi64x(sign_k), // sign_k_vec @@ -46,14 +46,14 @@ fn get_digit_avx(x: __m256i, mask_k: __m256i, sign_k: __m256i) -> __m256i { unsafe fn get_carry_avx( x: __m256i, digit: __m256i, - basek: __m256i, // _mm256_set1_epi64x(k) + base2k: __m256i, // _mm256_set1_epi64x(k) top_mask: __m256i, // (!0 << (64 - k)) broadcast ) -> __m256i { use std::arch::x86_64::{ __m256i, _mm256_and_si256, _mm256_cmpgt_epi64, _mm256_or_si256, _mm256_setzero_si256, _mm256_srlv_epi64, _mm256_sub_epi64, }; let diff: __m256i = _mm256_sub_epi64(x, digit); - let lsr: __m256i = _mm256_srlv_epi64(diff, basek); // logical >> + let lsr: __m256i = _mm256_srlv_epi64(diff, base2k); // logical >> let neg: __m256i = _mm256_cmpgt_epi64(_mm256_setzero_si256(), diff); // 0xFFFF.. where v<0 let fill: __m256i = _mm256_and_si256(neg, top_mask); // top k bits if negative _mm256_or_si256(lsr, fill) @@ -61,13 +61,121 @@ unsafe fn get_carry_avx( /// # Safety /// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); -/// all inputs must have the same length and must not alias. +/// `res` and `src` must have the same length and must not alias. 
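+///
+/// Per lane: the signed low-`base2k` digit of `src` is scaled by `2^lsh` and added into `res`, while the carry `(src - digit) >> base2k` replaces `src`.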
+#[cfg(target_arch = "x86_64")] #[target_feature(enable = "avx2")] -pub fn znx_normalize_first_step_carry_only_avx(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) { +pub fn znx_extract_digit_addmul_avx(base2k: usize, lsh: usize, res: &mut [i64], src: &mut [i64]) { #[cfg(debug_assertions)] { - assert_eq!(x.len(), carry.len()); - assert!(lsh < basek); + assert_eq!(res.len(), src.len()); + assert!(lsh < base2k); + } + + use std::arch::x86_64::{ + __m256i, _mm256_add_epi64, _mm256_loadu_si256, _mm256_set1_epi64x, _mm256_sllv_epi64, _mm256_storeu_si256, + }; + + let n: usize = res.len(); + let span: usize = n >> 2; + + unsafe { + let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i; + let mut ss: *mut __m256i = src.as_mut_ptr() as *mut __m256i; + + // constants for digit/carry extraction + let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k); + let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64); + + for _ in 0..span { + // load source & extract digit/carry + let sv: __m256i = _mm256_loadu_si256(ss); + let digit_256: __m256i = get_digit_avx(sv, mask, sign); + let carry_256: __m256i = get_carry_avx(sv, digit_256, basek_vec, top_mask); + + // res += (digit << lsh) + let rv: __m256i = _mm256_loadu_si256(rr); + let madd: __m256i = _mm256_sllv_epi64(digit_256, lsh_v); + let sum: __m256i = _mm256_add_epi64(rv, madd); + + _mm256_storeu_si256(rr, sum); + _mm256_storeu_si256(ss, carry_256); + + rr = rr.add(1); + ss = ss.add(1); + } + } + + // tail (scalar) + if !n.is_multiple_of(4) { + use poulpy_hal::reference::znx::znx_extract_digit_addmul_ref; + + let off: usize = span << 2; + znx_extract_digit_addmul_ref(base2k, lsh, &mut res[off..], &mut src[off..]); + } +} + +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +/// `res` and `src` must have the same length and must not alias. +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2")] +pub fn znx_normalize_digit_avx(base2k: usize, res: &mut [i64], src: &mut [i64]) { + #[cfg(debug_assertions)] + { + assert_eq!(res.len(), src.len()); + } + + use std::arch::x86_64::{__m256i, _mm256_add_epi64, _mm256_loadu_si256, _mm256_storeu_si256}; + + let n: usize = res.len(); + let span: usize = n >> 2; + + unsafe { + // Pointers to 256-bit lanes + let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i; + let mut ss: *mut __m256i = src.as_mut_ptr() as *mut __m256i; + + // Constants for digit/carry extraction + let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k); + + for _ in 0..span { + // Load res lane + let rv: __m256i = _mm256_loadu_si256(rr); + + // Extract digit and carry from res + let digit_256: __m256i = get_digit_avx(rv, mask, sign); + let carry_256: __m256i = get_carry_avx(rv, digit_256, basek_vec, top_mask); + + // src += carry + let sv: __m256i = _mm256_loadu_si256(ss); + let sum: __m256i = _mm256_add_epi64(sv, carry_256); + + _mm256_storeu_si256(ss, sum); + _mm256_storeu_si256(rr, digit_256); + + rr = rr.add(1); + ss = ss.add(1); + } + } + + // scalar tail + if !n.is_multiple_of(4) { + use poulpy_hal::reference::znx::znx_normalize_digit_ref; + + let off = span << 2; + znx_normalize_digit_ref(base2k, &mut res[off..], &mut src[off..]); + } +} + +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +/// all inputs must have the same length and must not alias. 
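+///
+/// Digits are extracted and discarded; only the carries of the first normalization step are written to `carry`.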
+#[target_feature(enable = "avx2")] +pub fn znx_normalize_first_step_carry_only_avx(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) { + #[cfg(debug_assertions)] + { + assert!(x.len() <= carry.len()); + assert!(lsh < base2k); } use std::arch::x86_64::{_mm256_loadu_si256, _mm256_storeu_si256}; @@ -81,19 +189,19 @@ pub fn znx_normalize_first_step_carry_only_avx(basek: usize, lsh: usize, x: &[i6 let mut cc: *mut __m256i = carry.as_ptr() as *mut __m256i; let (mask, sign, basek_vec, top_mask) = if lsh == 0 { - normalize_consts_avx(basek) + normalize_consts_avx(base2k) } else { - normalize_consts_avx(basek - lsh) + normalize_consts_avx(base2k - lsh) }; for _ in 0..span { - let xx_256: __m256i = _mm256_loadu_si256(xx); + let xv: __m256i = _mm256_loadu_si256(xx); - // (x << (64 - basek)) >> (64 - basek) - let digit_256: __m256i = get_digit_avx(xx_256, mask, sign); + // (x << (64 - base2k)) >> (64 - base2k) + let digit_256: __m256i = get_digit_avx(xv, mask, sign); - // (x - digit) >> basek - let carry_256: __m256i = get_carry_avx(xx_256, digit_256, basek_vec, top_mask); + // (x - digit) >> base2k + let carry_256: __m256i = get_carry_avx(xv, digit_256, basek_vec, top_mask); _mm256_storeu_si256(cc, carry_256); @@ -106,7 +214,7 @@ pub fn znx_normalize_first_step_carry_only_avx(basek: usize, lsh: usize, x: &[i6 if !x.len().is_multiple_of(4) { use poulpy_hal::reference::znx::znx_normalize_first_step_carry_only_ref; - znx_normalize_first_step_carry_only_ref(basek, lsh, &x[span << 2..], &mut carry[span << 2..]); + znx_normalize_first_step_carry_only_ref(base2k, lsh, &x[span << 2..], &mut carry[span << 2..]); } } @@ -114,11 +222,11 @@ pub fn znx_normalize_first_step_carry_only_avx(basek: usize, lsh: usize, x: &[i6 /// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); /// all inputs must have the same length and must not alias. 
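+/// In-place first step: each limb of `x` is replaced by its (optionally left-shifted) digit, with the carry written to `carry`.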
#[target_feature(enable = "avx2")] -pub fn znx_normalize_first_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { +pub fn znx_normalize_first_step_inplace_avx(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { #[cfg(debug_assertions)] { - assert_eq!(x.len(), carry.len()); - assert!(lsh < basek); + assert!(x.len() <= carry.len()); + assert!(lsh < base2k); } use std::arch::x86_64::{_mm256_loadu_si256, _mm256_set1_epi64x, _mm256_sllv_epi64, _mm256_storeu_si256}; @@ -132,16 +240,16 @@ pub fn znx_normalize_first_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i let mut cc: *mut __m256i = carry.as_ptr() as *mut __m256i; if lsh == 0 { - let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek); + let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k); for _ in 0..span { - let xx_256: __m256i = _mm256_loadu_si256(xx); + let xv: __m256i = _mm256_loadu_si256(xx); - // (x << (64 - basek)) >> (64 - basek) - let digit_256: __m256i = get_digit_avx(xx_256, mask, sign); + // (x << (64 - base2k)) >> (64 - base2k) + let digit_256: __m256i = get_digit_avx(xv, mask, sign); - // (x - digit) >> basek - let carry_256: __m256i = get_carry_avx(xx_256, digit_256, basek_vec, top_mask); + // (x - digit) >> base2k + let carry_256: __m256i = get_carry_avx(xv, digit_256, basek_vec, top_mask); _mm256_storeu_si256(xx, digit_256); _mm256_storeu_si256(cc, carry_256); @@ -150,18 +258,18 @@ pub fn znx_normalize_first_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i cc = cc.add(1); } } else { - let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek - lsh); + let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k - lsh); let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64); for _ in 0..span { - let xx_256: __m256i = _mm256_loadu_si256(xx); + let xv: __m256i = _mm256_loadu_si256(xx); - // (x << (64 - basek)) >> (64 - basek) - let digit_256: __m256i = get_digit_avx(xx_256, mask, sign); + // (x << (64 - base2k)) >> (64 - base2k) + let digit_256: __m256i = get_digit_avx(xv, mask, sign); - // (x - digit) >> basek - let carry_256: __m256i = get_carry_avx(xx_256, digit_256, basek_vec, top_mask); + // (x - digit) >> base2k + let carry_256: __m256i = get_carry_avx(xv, digit_256, basek_vec, top_mask); _mm256_storeu_si256(xx, _mm256_sllv_epi64(digit_256, lsh_v)); _mm256_storeu_si256(cc, carry_256); @@ -176,7 +284,7 @@ pub fn znx_normalize_first_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i if !x.len().is_multiple_of(4) { use poulpy_hal::reference::znx::znx_normalize_first_step_inplace_ref; - znx_normalize_first_step_inplace_ref(basek, lsh, &mut x[span << 2..], &mut carry[span << 2..]); + znx_normalize_first_step_inplace_ref(base2k, lsh, &mut x[span << 2..], &mut carry[span << 2..]); } } @@ -184,12 +292,12 @@ pub fn znx_normalize_first_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i /// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); /// all inputs must have the same length and must not alias. 
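+/// Out-of-place first step: digits extracted from `a` are written to `x`, carries to `carry`.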
#[target_feature(enable = "avx2")] -pub fn znx_normalize_first_step_avx(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { +pub fn znx_normalize_first_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { #[cfg(debug_assertions)] { - assert_eq!(x.len(), carry.len()); - assert_eq!(a.len(), carry.len()); - assert!(lsh < basek); + assert_eq!(x.len(), a.len()); + assert!(x.len() <= carry.len()); + assert!(lsh < base2k); } use std::arch::x86_64::{_mm256_loadu_si256, _mm256_sllv_epi64, _mm256_storeu_si256}; @@ -204,16 +312,16 @@ pub fn znx_normalize_first_step_avx(basek: usize, lsh: usize, x: &mut [i64], a: let mut cc: *mut __m256i = carry.as_ptr() as *mut __m256i; if lsh == 0 { - let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek); + let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k); for _ in 0..span { - let aa_256: __m256i = _mm256_loadu_si256(aa); + let av: __m256i = _mm256_loadu_si256(aa); - // (x << (64 - basek)) >> (64 - basek) - let digit_256: __m256i = get_digit_avx(aa_256, mask, sign); + // (x << (64 - base2k)) >> (64 - base2k) + let digit_256: __m256i = get_digit_avx(av, mask, sign); - // (x - digit) >> basek - let carry_256: __m256i = get_carry_avx(aa_256, digit_256, basek_vec, top_mask); + // (x - digit) >> base2k + let carry_256: __m256i = get_carry_avx(av, digit_256, basek_vec, top_mask); _mm256_storeu_si256(xx, digit_256); _mm256_storeu_si256(cc, carry_256); @@ -225,18 +333,18 @@ pub fn znx_normalize_first_step_avx(basek: usize, lsh: usize, x: &mut [i64], a: } else { use std::arch::x86_64::_mm256_set1_epi64x; - let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek - lsh); + let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k - lsh); let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64); for _ in 0..span { - let aa_256: __m256i = _mm256_loadu_si256(aa); + let av: __m256i = _mm256_loadu_si256(aa); - // (x << (64 - basek)) >> (64 - basek) - let digit_256: __m256i = get_digit_avx(aa_256, mask, sign); + // (x << (64 - base2k)) >> (64 - base2k) + let digit_256: __m256i = get_digit_avx(av, mask, sign); - // (x - digit) >> basek - let carry_256: __m256i = get_carry_avx(aa_256, digit_256, basek_vec, top_mask); + // (x - digit) >> base2k + let carry_256: __m256i = get_carry_avx(av, digit_256, basek_vec, top_mask); _mm256_storeu_si256(xx, _mm256_sllv_epi64(digit_256, lsh_v)); _mm256_storeu_si256(cc, carry_256); @@ -253,7 +361,7 @@ pub fn znx_normalize_first_step_avx(basek: usize, lsh: usize, x: &mut [i64], a: use poulpy_hal::reference::znx::znx_normalize_first_step_ref; znx_normalize_first_step_ref( - basek, + base2k, lsh, &mut x[span << 2..], &a[span << 2..], @@ -266,11 +374,11 @@ pub fn znx_normalize_first_step_avx(basek: usize, lsh: usize, x: &mut [i64], a: /// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); /// all inputs must have the same length and must not alias. 
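// [editor note] These kernels are #[target_feature(enable = "avx2")], so a
// caller must establish AVX2 support at runtime before entering the unsafe
// call. A minimal dispatch sketch, assuming the matching *_ref fallback (the
// same one used for the scalar tails above):
//
//     pub fn normalize_first_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
//         if std::is_x86_feature_detected!("avx2") {
//             unsafe { znx_normalize_first_step_avx(base2k, lsh, x, a, carry) }
//         } else {
//             poulpy_hal::reference::znx::znx_normalize_first_step_ref(base2k, lsh, x, a, carry)
//         }
//     }
//
// The #[test] wrappers further down use exactly this detect-then-unsafe pattern.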
#[target_feature(enable = "avx2")] -pub fn znx_normalize_middle_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { +pub fn znx_normalize_middle_step_inplace_avx(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { #[cfg(debug_assertions)] { - assert_eq!(x.len(), carry.len()); - assert!(lsh < basek); + assert!(x.len() <= carry.len()); + assert!(lsh < base2k); } use std::arch::x86_64::{_mm256_add_epi64, _mm256_loadu_si256, _mm256_sllv_epi64, _mm256_storeu_si256}; @@ -279,7 +387,7 @@ pub fn znx_normalize_middle_step_inplace_avx(basek: usize, lsh: usize, x: &mut [ let span: usize = n >> 2; - let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek); + let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k); unsafe { let mut xx: *mut __m256i = x.as_mut_ptr() as *mut __m256i; @@ -287,13 +395,13 @@ pub fn znx_normalize_middle_step_inplace_avx(basek: usize, lsh: usize, x: &mut [ if lsh == 0 { for _ in 0..span { - let xx_256: __m256i = _mm256_loadu_si256(xx); - let cc_256: __m256i = _mm256_loadu_si256(cc); + let xv: __m256i = _mm256_loadu_si256(xx); + let cv: __m256i = _mm256_loadu_si256(cc); - let d0: __m256i = get_digit_avx(xx_256, mask, sign); - let c0: __m256i = get_carry_avx(xx_256, d0, basek_vec, top_mask); + let d0: __m256i = get_digit_avx(xv, mask, sign); + let c0: __m256i = get_carry_avx(xv, d0, basek_vec, top_mask); - let s: __m256i = _mm256_add_epi64(d0, cc_256); + let s: __m256i = _mm256_add_epi64(d0, cv); let x1: __m256i = get_digit_avx(s, mask, sign); let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask); let cout: __m256i = _mm256_add_epi64(c0, c1); @@ -307,20 +415,20 @@ pub fn znx_normalize_middle_step_inplace_avx(basek: usize, lsh: usize, x: &mut [ } else { use std::arch::x86_64::_mm256_set1_epi64x; - let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(basek - lsh); + let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(base2k - lsh); let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64); for _ in 0..span { - let xx_256: __m256i = _mm256_loadu_si256(xx); - let cc_256: __m256i = _mm256_loadu_si256(cc); + let xv: __m256i = _mm256_loadu_si256(xx); + let cv: __m256i = _mm256_loadu_si256(cc); - let d0: __m256i = get_digit_avx(xx_256, mask_lsh, sign_lsh); - let c0: __m256i = get_carry_avx(xx_256, d0, basek_vec_lsh, top_mask_lsh); + let d0: __m256i = get_digit_avx(xv, mask_lsh, sign_lsh); + let c0: __m256i = get_carry_avx(xv, d0, basek_vec_lsh, top_mask_lsh); let d0_lsh: __m256i = _mm256_sllv_epi64(d0, lsh_v); - let s: __m256i = _mm256_add_epi64(d0_lsh, cc_256); + let s: __m256i = _mm256_add_epi64(d0_lsh, cv); let x1: __m256i = get_digit_avx(s, mask, sign); let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask); let cout: __m256i = _mm256_add_epi64(c0, c1); @@ -337,7 +445,7 @@ pub fn znx_normalize_middle_step_inplace_avx(basek: usize, lsh: usize, x: &mut [ if !x.len().is_multiple_of(4) { use poulpy_hal::reference::znx::znx_normalize_middle_step_inplace_ref; - znx_normalize_middle_step_inplace_ref(basek, lsh, &mut x[span << 2..], &mut carry[span << 2..]); + znx_normalize_middle_step_inplace_ref(base2k, lsh, &mut x[span << 2..], &mut carry[span << 2..]); } } @@ -345,11 +453,11 @@ pub fn znx_normalize_middle_step_inplace_avx(basek: usize, lsh: usize, x: &mut [ /// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); /// all inputs must have the same length and must not alias. 
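// [editor note] The middle step has to renormalize twice per lane: adding the
// incoming carry to a freshly extracted digit can overflow the digit range
// again. What the loop above computes, per lane:
//
//     // (d0, c0)  = split(x)          first digit/carry split
//     // s         = d0 + carry_in     fold in the carry from the previous limb
//     // (x', c1)  = split(s)          second split to re-center the digit
//     // carry_out = c0 + c1           accumulated carry for the next limb
//
// so carries flow toward the more significant limbs of the base-2^base2k
// representation, one limb per step.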
#[target_feature(enable = "avx2")] -pub fn znx_normalize_middle_step_carry_only_avx(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) { +pub fn znx_normalize_middle_step_carry_only_avx(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) { #[cfg(debug_assertions)] { - assert_eq!(x.len(), carry.len()); - assert!(lsh < basek); + assert!(x.len() <= carry.len()); + assert!(lsh < base2k); } use std::arch::x86_64::{_mm256_add_epi64, _mm256_loadu_si256, _mm256_sllv_epi64, _mm256_storeu_si256}; @@ -358,7 +466,7 @@ pub fn znx_normalize_middle_step_carry_only_avx(basek: usize, lsh: usize, x: &[i let span: usize = n >> 2; - let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek); + let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k); unsafe { let mut xx: *const __m256i = x.as_ptr() as *const __m256i; @@ -366,13 +474,13 @@ pub fn znx_normalize_middle_step_carry_only_avx(basek: usize, lsh: usize, x: &[i if lsh == 0 { for _ in 0..span { - let xx_256: __m256i = _mm256_loadu_si256(xx); - let cc_256: __m256i = _mm256_loadu_si256(cc); + let xv: __m256i = _mm256_loadu_si256(xx); + let cv: __m256i = _mm256_loadu_si256(cc); - let d0: __m256i = get_digit_avx(xx_256, mask, sign); - let c0: __m256i = get_carry_avx(xx_256, d0, basek_vec, top_mask); + let d0: __m256i = get_digit_avx(xv, mask, sign); + let c0: __m256i = get_carry_avx(xv, d0, basek_vec, top_mask); - let s: __m256i = _mm256_add_epi64(d0, cc_256); + let s: __m256i = _mm256_add_epi64(d0, cv); let x1: __m256i = get_digit_avx(s, mask, sign); let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask); let cout: __m256i = _mm256_add_epi64(c0, c1); @@ -385,20 +493,20 @@ pub fn znx_normalize_middle_step_carry_only_avx(basek: usize, lsh: usize, x: &[i } else { use std::arch::x86_64::_mm256_set1_epi64x; - let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(basek - lsh); + let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(base2k - lsh); let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64); for _ in 0..span { - let xx_256: __m256i = _mm256_loadu_si256(xx); - let cc_256: __m256i = _mm256_loadu_si256(cc); + let xv: __m256i = _mm256_loadu_si256(xx); + let cv: __m256i = _mm256_loadu_si256(cc); - let d0: __m256i = get_digit_avx(xx_256, mask_lsh, sign_lsh); - let c0: __m256i = get_carry_avx(xx_256, d0, basek_vec_lsh, top_mask_lsh); + let d0: __m256i = get_digit_avx(xv, mask_lsh, sign_lsh); + let c0: __m256i = get_carry_avx(xv, d0, basek_vec_lsh, top_mask_lsh); let d0_lsh: __m256i = _mm256_sllv_epi64(d0, lsh_v); - let s: __m256i = _mm256_add_epi64(d0_lsh, cc_256); + let s: __m256i = _mm256_add_epi64(d0_lsh, cv); let x1: __m256i = get_digit_avx(s, mask, sign); let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask); let cout: __m256i = _mm256_add_epi64(c0, c1); @@ -414,7 +522,7 @@ pub fn znx_normalize_middle_step_carry_only_avx(basek: usize, lsh: usize, x: &[i if !x.len().is_multiple_of(4) { use poulpy_hal::reference::znx::znx_normalize_middle_step_carry_only_ref; - znx_normalize_middle_step_carry_only_ref(basek, lsh, &x[span << 2..], &mut carry[span << 2..]); + znx_normalize_middle_step_carry_only_ref(base2k, lsh, &x[span << 2..], &mut carry[span << 2..]); } } @@ -422,12 +530,12 @@ pub fn znx_normalize_middle_step_carry_only_avx(basek: usize, lsh: usize, x: &[i /// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); /// all inputs must have the same length and must not alias. 
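// [editor note] Shared shape of every kernel in this file: a 256-bit main loop
// over span = n >> 2 blocks of four i64 lanes, then a scalar tail. Sketch
// (illustrative):
//
//     let span = x.len() >> 2;
//     // ... AVX2 loop over the first span * 4 elements ...
//     if !x.len().is_multiple_of(4) {
//         // same computation, one element at a time
//         znx_normalize_middle_step_carry_only_ref(base2k, lsh, &x[span << 2..], &mut carry[span << 2..]);
//     }
//
// Keeping the tail on the *_ref path means the AVX and reference results can
// be compared element for element, which is what the tests below rely on.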
#[target_feature(enable = "avx2")] -pub fn znx_normalize_middle_step_avx(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { +pub fn znx_normalize_middle_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { #[cfg(debug_assertions)] { - assert_eq!(x.len(), carry.len()); - assert_eq!(a.len(), carry.len()); - assert!(lsh < basek); + assert_eq!(x.len(), a.len()); + assert!(x.len() <= carry.len()); + assert!(lsh < base2k); } use std::arch::x86_64::{_mm256_add_epi64, _mm256_loadu_si256, _mm256_sllv_epi64, _mm256_storeu_si256}; @@ -436,7 +544,7 @@ pub fn znx_normalize_middle_step_avx(basek: usize, lsh: usize, x: &mut [i64], a: let span: usize = n >> 2; - let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek); + let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(base2k); unsafe { let mut xx: *mut __m256i = x.as_mut_ptr() as *mut __m256i; @@ -445,13 +553,13 @@ pub fn znx_normalize_middle_step_avx(basek: usize, lsh: usize, x: &mut [i64], a: if lsh == 0 { for _ in 0..span { - let aa_256: __m256i = _mm256_loadu_si256(aa); - let cc_256: __m256i = _mm256_loadu_si256(cc); + let av: __m256i = _mm256_loadu_si256(aa); + let cv: __m256i = _mm256_loadu_si256(cc); - let d0: __m256i = get_digit_avx(aa_256, mask, sign); - let c0: __m256i = get_carry_avx(aa_256, d0, basek_vec, top_mask); + let d0: __m256i = get_digit_avx(av, mask, sign); + let c0: __m256i = get_carry_avx(av, d0, basek_vec, top_mask); - let s: __m256i = _mm256_add_epi64(d0, cc_256); + let s: __m256i = _mm256_add_epi64(d0, cv); let x1: __m256i = get_digit_avx(s, mask, sign); let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask); let cout: __m256i = _mm256_add_epi64(c0, c1); @@ -466,20 +574,20 @@ pub fn znx_normalize_middle_step_avx(basek: usize, lsh: usize, x: &mut [i64], a: } else { use std::arch::x86_64::_mm256_set1_epi64x; - let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(basek - lsh); + let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(base2k - lsh); let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64); for _ in 0..span { - let aa_256: __m256i = _mm256_loadu_si256(aa); - let cc_256: __m256i = _mm256_loadu_si256(cc); + let av: __m256i = _mm256_loadu_si256(aa); + let cv: __m256i = _mm256_loadu_si256(cc); - let d0: __m256i = get_digit_avx(aa_256, mask_lsh, sign_lsh); - let c0: __m256i = get_carry_avx(aa_256, d0, basek_vec_lsh, top_mask_lsh); + let d0: __m256i = get_digit_avx(av, mask_lsh, sign_lsh); + let c0: __m256i = get_carry_avx(av, d0, basek_vec_lsh, top_mask_lsh); let d0_lsh: __m256i = _mm256_sllv_epi64(d0, lsh_v); - let s: __m256i = _mm256_add_epi64(d0_lsh, cc_256); + let s: __m256i = _mm256_add_epi64(d0_lsh, cv); let x1: __m256i = get_digit_avx(s, mask, sign); let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask); let cout: __m256i = _mm256_add_epi64(c0, c1); @@ -498,7 +606,7 @@ pub fn znx_normalize_middle_step_avx(basek: usize, lsh: usize, x: &mut [i64], a: use poulpy_hal::reference::znx::znx_normalize_middle_step_ref; znx_normalize_middle_step_ref( - basek, + base2k, lsh, &mut x[span << 2..], &a[span << 2..], @@ -511,11 +619,11 @@ pub fn znx_normalize_middle_step_avx(basek: usize, lsh: usize, x: &mut [i64], a: /// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); /// all inputs must have the same length and must not alias. 
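// [editor note] The three-operand variant above reads limb data from `a` and
// writes renormalized digits into `x`, leaving `a` untouched; the _inplace
// variants are the same computation with `x` as both source and destination.
// The debug-build contract, as asserted at the top of each function:
//
//     assert_eq!(x.len(), a.len());       // source and destination agree
//     assert!(x.len() <= carry.len());    // carry buffer may be oversized
//     assert!(lsh < base2k);              // shift must leave at least one digit bit
//
// Release builds skip these checks, so the obligation falls on the caller.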
#[target_feature(enable = "avx2")] -pub fn znx_normalize_final_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { +pub fn znx_normalize_final_step_inplace_avx(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { #[cfg(debug_assertions)] { - assert_eq!(x.len(), carry.len()); - assert!(lsh < basek); + assert!(x.len() <= carry.len()); + assert!(lsh < base2k); } use std::arch::x86_64::{_mm256_add_epi64, _mm256_loadu_si256, _mm256_sllv_epi64, _mm256_storeu_si256}; @@ -524,7 +632,7 @@ pub fn znx_normalize_final_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i let span: usize = n >> 2; - let (mask, sign, _, _) = normalize_consts_avx(basek); + let (mask, sign, _, _) = normalize_consts_avx(base2k); unsafe { let mut xx: *mut __m256i = x.as_mut_ptr() as *mut __m256i; @@ -547,7 +655,7 @@ pub fn znx_normalize_final_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i } else { use std::arch::x86_64::_mm256_set1_epi64x; - let (mask_lsh, sign_lsh, _, _) = normalize_consts_avx(basek - lsh); + let (mask_lsh, sign_lsh, _, _) = normalize_consts_avx(base2k - lsh); let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64); @@ -573,7 +681,7 @@ pub fn znx_normalize_final_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i if !x.len().is_multiple_of(4) { use poulpy_hal::reference::znx::znx_normalize_final_step_inplace_ref; - znx_normalize_final_step_inplace_ref(basek, lsh, &mut x[span << 2..], &mut carry[span << 2..]); + znx_normalize_final_step_inplace_ref(base2k, lsh, &mut x[span << 2..], &mut carry[span << 2..]); } } @@ -581,12 +689,12 @@ pub fn znx_normalize_final_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i /// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); /// all inputs must have the same length and must not alias. 
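// [editor note] The final (top-limb) kernels only pull the mask/sign constants
// out of normalize_consts_avx: the carry from below is folded in and a digit
// is kept, but no carry-out is computed, so anything overflowing the top limb
// is dropped. Roughly, per lane:
//
//     x_out = digit(x + carry_in)   // digit() = sign-extended low base2k bits
//
// (illustration only; znx_normalize_final_step_ref is authoritative for the
// exact lsh handling, which again splits at base2k - lsh and rescales).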
#[target_feature(enable = "avx2")] -pub fn znx_normalize_final_step_avx(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { +pub fn znx_normalize_final_step_avx(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { #[cfg(debug_assertions)] { - assert_eq!(x.len(), carry.len()); - assert_eq!(a.len(), carry.len()); - assert!(lsh < basek); + assert_eq!(x.len(), a.len()); + assert!(x.len() <= carry.len()); + assert!(lsh < base2k); } use std::arch::x86_64::{_mm256_add_epi64, _mm256_loadu_si256, _mm256_sllv_epi64, _mm256_storeu_si256}; @@ -595,7 +703,7 @@ pub fn znx_normalize_final_step_avx(basek: usize, lsh: usize, x: &mut [i64], a: let span: usize = n >> 2; - let (mask, sign, _, _) = normalize_consts_avx(basek); + let (mask, sign, _, _) = normalize_consts_avx(base2k); unsafe { let mut xx: *mut __m256i = x.as_mut_ptr() as *mut __m256i; @@ -620,7 +728,7 @@ pub fn znx_normalize_final_step_avx(basek: usize, lsh: usize, x: &mut [i64], a: } else { use std::arch::x86_64::_mm256_set1_epi64x; - let (mask_lsh, sign_lsh, _, _) = normalize_consts_avx(basek - lsh); + let (mask_lsh, sign_lsh, _, _) = normalize_consts_avx(base2k - lsh); let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64); @@ -647,7 +755,7 @@ pub fn znx_normalize_final_step_avx(basek: usize, lsh: usize, x: &mut [i64], a: use poulpy_hal::reference::znx::znx_normalize_final_step_ref; znx_normalize_final_step_ref( - basek, + base2k, lsh, &mut x[span << 2..], &a[span << 2..], @@ -658,9 +766,9 @@ pub fn znx_normalize_final_step_avx(basek: usize, lsh: usize, x: &mut [i64], a: mod tests { use poulpy_hal::reference::znx::{ - get_carry, get_digit, znx_normalize_final_step_inplace_ref, znx_normalize_final_step_ref, - znx_normalize_first_step_inplace_ref, znx_normalize_first_step_ref, znx_normalize_middle_step_inplace_ref, - znx_normalize_middle_step_ref, + get_carry_i64, get_digit_i64, znx_extract_digit_addmul_ref, znx_normalize_digit_ref, + znx_normalize_final_step_inplace_ref, znx_normalize_final_step_ref, znx_normalize_first_step_inplace_ref, + znx_normalize_first_step_ref, znx_normalize_middle_step_inplace_ref, znx_normalize_middle_step_ref, }; use super::*; @@ -670,7 +778,7 @@ mod tests { #[allow(dead_code)] #[target_feature(enable = "avx2")] fn test_get_digit_avx_internal() { - let basek: usize = 12; + let base2k: usize = 12; let x: [i64; 4] = [ 7638646372408325293, -61440197422348985, @@ -678,15 +786,15 @@ mod tests { -4835376105455195188, ]; let y0: Vec = vec![ - get_digit(basek, x[0]), - get_digit(basek, x[1]), - get_digit(basek, x[2]), - get_digit(basek, x[3]), + get_digit_i64(base2k, x[0]), + get_digit_i64(base2k, x[1]), + get_digit_i64(base2k, x[2]), + get_digit_i64(base2k, x[3]), ]; let mut y1: Vec = vec![0i64; 4]; unsafe { let x_256: __m256i = _mm256_loadu_si256(x.as_ptr() as *const __m256i); - let (mask, sign, _, _) = normalize_consts_avx(basek); + let (mask, sign, _, _) = normalize_consts_avx(base2k); let digit: __m256i = get_digit_avx(x_256, mask, sign); _mm256_storeu_si256(y1.as_mut_ptr() as *mut __m256i, digit); } @@ -707,7 +815,7 @@ mod tests { #[allow(dead_code)] #[target_feature(enable = "avx2")] fn test_get_carry_avx_internal() { - let basek: usize = 12; + let base2k: usize = 12; let x: [i64; 4] = [ 7638646372408325293, -61440197422348985, @@ -716,16 +824,16 @@ mod tests { ]; let carry: [i64; 4] = [1174467039, -144794816, -1466676977, 513122840]; let y0: Vec = vec![ - get_carry(basek, x[0], carry[0]), - get_carry(basek, x[1], carry[1]), - get_carry(basek, x[2], carry[2]), - get_carry(basek, 
x[3], carry[3]), + get_carry_i64(base2k, x[0], carry[0]), + get_carry_i64(base2k, x[1], carry[1]), + get_carry_i64(base2k, x[2], carry[2]), + get_carry_i64(base2k, x[3], carry[3]), ]; let mut y1: Vec = vec![0i64; 4]; unsafe { let x_256: __m256i = _mm256_loadu_si256(x.as_ptr() as *const __m256i); let d_256: __m256i = _mm256_loadu_si256(carry.as_ptr() as *const __m256i); - let (_, _, basek_vec, top_mask) = normalize_consts_avx(basek); + let (_, _, basek_vec, top_mask) = normalize_consts_avx(base2k); let digit: __m256i = get_carry_avx(x_256, d_256, basek_vec, top_mask); _mm256_storeu_si256(y1.as_mut_ptr() as *mut __m256i, digit); } @@ -762,16 +870,16 @@ mod tests { ]; let mut c1: [i64; 4] = c0; - let basek = 12; + let base2k = 12; - znx_normalize_first_step_inplace_ref(basek, 0, &mut y0, &mut c0); - znx_normalize_first_step_inplace_avx(basek, 0, &mut y1, &mut c1); + znx_normalize_first_step_inplace_ref(base2k, 0, &mut y0, &mut c0); + znx_normalize_first_step_inplace_avx(base2k, 0, &mut y1, &mut c1); assert_eq!(y0, y1); assert_eq!(c0, c1); - znx_normalize_first_step_inplace_ref(basek, basek - 1, &mut y0, &mut c0); - znx_normalize_first_step_inplace_avx(basek, basek - 1, &mut y1, &mut c1); + znx_normalize_first_step_inplace_ref(base2k, base2k - 1, &mut y0, &mut c0); + znx_normalize_first_step_inplace_avx(base2k, base2k - 1, &mut y1, &mut c1); assert_eq!(y0, y1); assert_eq!(c0, c1); @@ -807,16 +915,16 @@ mod tests { ]; let mut c1: [i64; 4] = c0; - let basek = 12; + let base2k = 12; - znx_normalize_middle_step_inplace_ref(basek, 0, &mut y0, &mut c0); - znx_normalize_middle_step_inplace_avx(basek, 0, &mut y1, &mut c1); + znx_normalize_middle_step_inplace_ref(base2k, 0, &mut y0, &mut c0); + znx_normalize_middle_step_inplace_avx(base2k, 0, &mut y1, &mut c1); assert_eq!(y0, y1); assert_eq!(c0, c1); - znx_normalize_middle_step_inplace_ref(basek, basek - 1, &mut y0, &mut c0); - znx_normalize_middle_step_inplace_avx(basek, basek - 1, &mut y1, &mut c1); + znx_normalize_middle_step_inplace_ref(base2k, base2k - 1, &mut y0, &mut c0); + znx_normalize_middle_step_inplace_avx(base2k, base2k - 1, &mut y1, &mut c1); assert_eq!(y0, y1); assert_eq!(c0, c1); @@ -852,16 +960,16 @@ mod tests { ]; let mut c1: [i64; 4] = c0; - let basek = 12; + let base2k = 12; - znx_normalize_final_step_inplace_ref(basek, 0, &mut y0, &mut c0); - znx_normalize_final_step_inplace_avx(basek, 0, &mut y1, &mut c1); + znx_normalize_final_step_inplace_ref(base2k, 0, &mut y0, &mut c0); + znx_normalize_final_step_inplace_avx(base2k, 0, &mut y1, &mut c1); assert_eq!(y0, y1); assert_eq!(c0, c1); - znx_normalize_final_step_inplace_ref(basek, basek - 1, &mut y0, &mut c0); - znx_normalize_final_step_inplace_avx(basek, basek - 1, &mut y1, &mut c1); + znx_normalize_final_step_inplace_ref(base2k, base2k - 1, &mut y0, &mut c0); + znx_normalize_final_step_inplace_avx(base2k, base2k - 1, &mut y1, &mut c1); assert_eq!(y0, y1); assert_eq!(c0, c1); @@ -898,16 +1006,16 @@ mod tests { ]; let mut c1: [i64; 4] = c0; - let basek = 12; + let base2k = 12; - znx_normalize_first_step_ref(basek, 0, &mut y0, &a, &mut c0); - znx_normalize_first_step_avx(basek, 0, &mut y1, &a, &mut c1); + znx_normalize_first_step_ref(base2k, 0, &mut y0, &a, &mut c0); + znx_normalize_first_step_avx(base2k, 0, &mut y1, &a, &mut c1); assert_eq!(y0, y1); assert_eq!(c0, c1); - znx_normalize_first_step_ref(basek, basek - 1, &mut y0, &a, &mut c0); - znx_normalize_first_step_avx(basek, basek - 1, &mut y1, &a, &mut c1); + znx_normalize_first_step_ref(base2k, base2k - 1, &mut y0, &a, &mut c0); + 
znx_normalize_first_step_avx(base2k, base2k - 1, &mut y1, &a, &mut c1); assert_eq!(y0, y1); assert_eq!(c0, c1); @@ -944,16 +1052,16 @@ mod tests { ]; let mut c1: [i64; 4] = c0; - let basek = 12; + let base2k = 12; - znx_normalize_middle_step_ref(basek, 0, &mut y0, &a, &mut c0); - znx_normalize_middle_step_avx(basek, 0, &mut y1, &a, &mut c1); + znx_normalize_middle_step_ref(base2k, 0, &mut y0, &a, &mut c0); + znx_normalize_middle_step_avx(base2k, 0, &mut y1, &a, &mut c1); assert_eq!(y0, y1); assert_eq!(c0, c1); - znx_normalize_middle_step_ref(basek, basek - 1, &mut y0, &a, &mut c0); - znx_normalize_middle_step_avx(basek, basek - 1, &mut y1, &a, &mut c1); + znx_normalize_middle_step_ref(base2k, base2k - 1, &mut y0, &a, &mut c0); + znx_normalize_middle_step_avx(base2k, base2k - 1, &mut y1, &a, &mut c1); assert_eq!(y0, y1); assert_eq!(c0, c1); @@ -990,16 +1098,16 @@ mod tests { ]; let mut c1: [i64; 4] = c0; - let basek = 12; + let base2k = 12; - znx_normalize_final_step_ref(basek, 0, &mut y0, &a, &mut c0); - znx_normalize_final_step_avx(basek, 0, &mut y1, &a, &mut c1); + znx_normalize_final_step_ref(base2k, 0, &mut y0, &a, &mut c0); + znx_normalize_final_step_avx(base2k, 0, &mut y1, &a, &mut c1); assert_eq!(y0, y1); assert_eq!(c0, c1); - znx_normalize_final_step_ref(basek, basek - 1, &mut y0, &a, &mut c0); - znx_normalize_final_step_avx(basek, basek - 1, &mut y1, &a, &mut c1); + znx_normalize_final_step_ref(base2k, base2k - 1, &mut y0, &a, &mut c0); + znx_normalize_final_step_avx(base2k, base2k - 1, &mut y1, &a, &mut c1); assert_eq!(y0, y1); assert_eq!(c0, c1); @@ -1015,4 +1123,86 @@ mod tests { test_znx_normalize_final_step_avx_internal(); } } + + #[target_feature(enable = "avx2")] + fn znx_extract_digit_addmul_internal() { + let mut y0: [i64; 4] = [ + 7638646372408325293, + -61440197422348985, + 6835891051541717957, + -4835376105455195188, + ]; + let mut y1: [i64; 4] = y0; + + let mut c0: [i64; 4] = [ + 621182201135793202, + 9000856573317006236, + 5542252755421113668, + -6036847263131690631, + ]; + let mut c1: [i64; 4] = c0; + + let base2k: usize = 12; + + znx_extract_digit_addmul_ref(base2k, 0, &mut y0, &mut c0); + znx_extract_digit_addmul_avx(base2k, 0, &mut y1, &mut c1); + + assert_eq!(y0, y1); + assert_eq!(c0, c1); + + znx_extract_digit_addmul_ref(base2k, base2k - 1, &mut y0, &mut c0); + znx_extract_digit_addmul_avx(base2k, base2k - 1, &mut y1, &mut c1); + + assert_eq!(y0, y1); + assert_eq!(c0, c1); + } + + #[test] + fn test_znx_extract_digit_addmul_avx() { + if !std::is_x86_feature_detected!("avx2") { + eprintln!("skipping: CPU lacks avx2"); + return; + }; + unsafe { + znx_extract_digit_addmul_internal(); + } + } + + #[target_feature(enable = "avx2")] + fn znx_normalize_digit_internal() { + let mut y0: [i64; 4] = [ + 7638646372408325293, + -61440197422348985, + 6835891051541717957, + -4835376105455195188, + ]; + let mut y1: [i64; 4] = y0; + + let mut c0: [i64; 4] = [ + 621182201135793202, + 9000856573317006236, + 5542252755421113668, + -6036847263131690631, + ]; + let mut c1: [i64; 4] = c0; + + let base2k: usize = 12; + + znx_normalize_digit_ref(base2k, &mut y0, &mut c0); + znx_normalize_digit_avx(base2k, &mut y1, &mut c1); + + assert_eq!(y0, y1); + assert_eq!(c0, c1); + } + + #[test] + fn test_znx_normalize_digit_internal_avx() { + if !std::is_x86_feature_detected!("avx2") { + eprintln!("skipping: CPU lacks avx2"); + return; + }; + unsafe { + znx_normalize_digit_internal(); + } + } } diff --git a/poulpy-backend/src/cpu_fft64_avx/znx_avx/sub.rs 
b/poulpy-backend/src/cpu_fft64_avx/znx_avx/sub.rs index 509149c..ba9d6d0 100644 --- a/poulpy-backend/src/cpu_fft64_avx/znx_avx/sub.rs +++ b/poulpy-backend/src/cpu_fft64_avx/znx_avx/sub.rs @@ -41,7 +41,7 @@ pub fn znx_sub_avx(res: &mut [i64], a: &[i64], b: &[i64]) { /// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); /// all inputs must have the same length and must not alias. #[target_feature(enable = "avx2")] -pub fn znx_sub_ab_inplace_avx(res: &mut [i64], a: &[i64]) { +pub fn znx_sub_inplace_avx(res: &mut [i64], a: &[i64]) { #[cfg(debug_assertions)] { assert_eq!(res.len(), a.len()); @@ -67,9 +67,9 @@ pub fn znx_sub_ab_inplace_avx(res: &mut [i64], a: &[i64]) { // tail if !res.len().is_multiple_of(4) { - use poulpy_hal::reference::znx::znx_sub_ab_inplace_ref; + use poulpy_hal::reference::znx::znx_sub_inplace_ref; - znx_sub_ab_inplace_ref(&mut res[span << 2..], &a[span << 2..]); + znx_sub_inplace_ref(&mut res[span << 2..], &a[span << 2..]); } } @@ -77,7 +77,7 @@ pub fn znx_sub_ab_inplace_avx(res: &mut [i64], a: &[i64]) { /// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); /// all inputs must have the same length and must not alias. #[target_feature(enable = "avx2")] -pub fn znx_sub_ba_inplace_avx(res: &mut [i64], a: &[i64]) { +pub fn znx_sub_negate_inplace_avx(res: &mut [i64], a: &[i64]) { #[cfg(debug_assertions)] { assert_eq!(res.len(), a.len()); @@ -103,8 +103,8 @@ pub fn znx_sub_ba_inplace_avx(res: &mut [i64], a: &[i64]) { // tail if !res.len().is_multiple_of(4) { - use poulpy_hal::reference::znx::znx_sub_ba_inplace_ref; + use poulpy_hal::reference::znx::znx_sub_negate_inplace_ref; - znx_sub_ba_inplace_ref(&mut res[span << 2..], &a[span << 2..]); + znx_sub_negate_inplace_ref(&mut res[span << 2..], &a[span << 2..]); } } diff --git a/poulpy-backend/src/cpu_fft64_ref/reim.rs b/poulpy-backend/src/cpu_fft64_ref/reim.rs index 411ee0a..9ce2164 100644 --- a/poulpy-backend/src/cpu_fft64_ref/reim.rs +++ b/poulpy-backend/src/cpu_fft64_ref/reim.rs @@ -1,10 +1,11 @@ use poulpy_hal::reference::fft64::{ reim::{ ReimAdd, ReimAddInplace, ReimAddMul, ReimCopy, ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimIFFTTable, ReimMul, - ReimMulInplace, ReimNegate, ReimNegateInplace, ReimSub, ReimSubABInplace, ReimSubBAInplace, ReimToZnx, ReimToZnxInplace, - ReimZero, fft_ref, ifft_ref, reim_add_inplace_ref, reim_add_ref, reim_addmul_ref, reim_copy_ref, reim_from_znx_i64_ref, - reim_mul_inplace_ref, reim_mul_ref, reim_negate_inplace_ref, reim_negate_ref, reim_sub_ab_inplace_ref, - reim_sub_ba_inplace_ref, reim_sub_ref, reim_to_znx_i64_inplace_ref, reim_to_znx_i64_ref, reim_zero_ref, + ReimMulInplace, ReimNegate, ReimNegateInplace, ReimSub, ReimSubInplace, ReimSubNegateInplace, ReimToZnx, + ReimToZnxInplace, ReimZero, fft_ref, ifft_ref, reim_add_inplace_ref, reim_add_ref, reim_addmul_ref, reim_copy_ref, + reim_from_znx_i64_ref, reim_mul_inplace_ref, reim_mul_ref, reim_negate_inplace_ref, reim_negate_ref, + reim_sub_inplace_ref, reim_sub_negate_inplace_ref, reim_sub_ref, reim_to_znx_i64_inplace_ref, reim_to_znx_i64_ref, + reim_zero_ref, }, reim4::{ Reim4Extract1Blk, Reim4Mat1ColProd, Reim4Mat2Cols2ndColProd, Reim4Mat2ColsProd, Reim4Save1Blk, Reim4Save2Blks, @@ -69,17 +70,17 @@ impl ReimSub for FFT64Ref { } } -impl ReimSubABInplace for FFT64Ref { +impl ReimSubInplace for FFT64Ref { #[inline(always)] - fn reim_sub_ab_inplace(res: &mut [f64], a: &[f64]) { - reim_sub_ab_inplace_ref(res, a); + fn reim_sub_inplace(res: &mut [f64], a: &[f64]) 
{ + reim_sub_inplace_ref(res, a); } } -impl ReimSubBAInplace for FFT64Ref { +impl ReimSubNegateInplace for FFT64Ref { #[inline(always)] - fn reim_sub_ba_inplace(res: &mut [f64], a: &[f64]) { - reim_sub_ba_inplace_ref(res, a); + fn reim_sub_negate_inplace(res: &mut [f64], a: &[f64]) { + reim_sub_negate_inplace_ref(res, a); } } diff --git a/poulpy-backend/src/cpu_fft64_ref/scratch.rs b/poulpy-backend/src/cpu_fft64_ref/scratch.rs index 41eae29..80b228d 100644 --- a/poulpy-backend/src/cpu_fft64_ref/scratch.rs +++ b/poulpy-backend/src/cpu_fft64_ref/scratch.rs @@ -253,9 +253,6 @@ fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8] (take_slice, rem_slice) } } else { - panic!( - "Attempted to take {} from scratch with {} aligned bytes left", - take_len, aligned_len, - ); + panic!("Attempted to take {take_len} from scratch with {aligned_len} aligned bytes left"); } } diff --git a/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs b/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs index ee213a9..fa88aaa 100644 --- a/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs +++ b/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs @@ -1,7 +1,8 @@ use poulpy_hal::{ api::{ - TakeSlice, VecZnxAutomorphismInplaceTmpBytes, VecZnxMergeRingsTmpBytes, VecZnxMulXpMinusOneInplaceTmpBytes, - VecZnxNormalizeTmpBytes, VecZnxRotateInplaceTmpBytes, VecZnxSplitRingTmpBytes, + TakeSlice, VecZnxAutomorphismInplaceTmpBytes, VecZnxLshTmpBytes, VecZnxMergeRingsTmpBytes, + VecZnxMulXpMinusOneInplaceTmpBytes, VecZnxNormalizeTmpBytes, VecZnxRotateInplaceTmpBytes, VecZnxRshTmpBytes, + VecZnxSplitRingTmpBytes, }, layouts::{Module, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef}, oep::{ @@ -12,7 +13,7 @@ use poulpy_hal::{ VecZnxMulXpMinusOneInplaceTmpBytesImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl, VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl, VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl, - VecZnxSplitRingTmpBytesImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarImpl, + VecZnxSplitRingTmpBytesImpl, VecZnxSubImpl, VecZnxSubInplaceImpl, VecZnxSubNegateInplaceImpl, VecZnxSubScalarImpl, VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl, }, reference::vec_znx::{ @@ -23,7 +24,7 @@ use poulpy_hal::{ vec_znx_mul_xp_minus_one_inplace_tmp_bytes, vec_znx_negate, vec_znx_negate_inplace, vec_znx_normalize, vec_znx_normalize_inplace, vec_znx_normalize_tmp_bytes, vec_znx_rotate, vec_znx_rotate_inplace, vec_znx_rotate_inplace_tmp_bytes, vec_znx_rsh, vec_znx_rsh_inplace, vec_znx_rsh_tmp_bytes, vec_znx_split_ring, - vec_znx_split_ring_tmp_bytes, vec_znx_sub, vec_znx_sub_ab_inplace, vec_znx_sub_ba_inplace, vec_znx_sub_scalar, + vec_znx_split_ring_tmp_bytes, vec_znx_sub, vec_znx_sub_inplace, vec_znx_sub_negate_inplace, vec_znx_sub_scalar, vec_znx_sub_scalar_inplace, vec_znx_switch_ring, }, source::Source, @@ -43,9 +44,10 @@ where { fn vec_znx_normalize_impl( module: &Module, - basek: usize, + res_basek: usize, res: &mut R, res_col: usize, + a_basek: usize, a: &A, a_col: usize, scratch: &mut Scratch, @@ -54,7 +56,7 @@ where A: VecZnxToRef, { let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::()); - vec_znx_normalize::(basek, res, res_col, a, a_col, carry); + vec_znx_normalize::(res_basek, res, res_col, a_basek, a, a_col, carry); } } @@ -64,7 +66,7 @@ where { fn vec_znx_normalize_inplace_impl( module: &Module, - basek: usize, + base2k: 
usize, res: &mut R, res_col: usize, scratch: &mut Scratch, @@ -72,7 +74,7 @@ where R: VecZnxToMut, { let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::()); - vec_znx_normalize_inplace::(basek, res, res_col, carry); + vec_znx_normalize_inplace::(base2k, res, res_col, carry); } } @@ -143,23 +145,23 @@ unsafe impl VecZnxSubImpl for FFT64Ref { } } -unsafe impl VecZnxSubABInplaceImpl for FFT64Ref { - fn vec_znx_sub_ab_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) +unsafe impl VecZnxSubInplaceImpl for FFT64Ref { + fn vec_znx_sub_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, A: VecZnxToRef, { - vec_znx_sub_ab_inplace::(res, res_col, a, a_col); + vec_znx_sub_inplace::(res, res_col, a, a_col); } } -unsafe impl VecZnxSubBAInplaceImpl for FFT64Ref { - fn vec_znx_sub_ba_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) +unsafe impl VecZnxSubNegateInplaceImpl for FFT64Ref { + fn vec_znx_sub_negate_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, A: VecZnxToRef, { - vec_znx_sub_ba_inplace::(res, res_col, a, a_col); + vec_znx_sub_negate_inplace::(res, res_col, a, a_col); } } @@ -234,9 +236,9 @@ where Module: VecZnxNormalizeTmpBytes, Scratch: TakeSlice, { - fn vec_znx_lsh_inplace_impl( + fn vec_znx_lsh_impl( module: &Module, - basek: usize, + base2k: usize, k: usize, res: &mut R, res_col: usize, @@ -247,8 +249,8 @@ where R: VecZnxToMut, A: VecZnxToRef, { - let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::()); - vec_znx_lsh::<_, _, Self>(basek, k, res, res_col, a, a_col, carry); + let (carry, _) = scratch.take_slice(module.vec_znx_lsh_tmp_bytes() / size_of::()); + vec_znx_lsh::<_, _, Self>(base2k, k, res, res_col, a, a_col, carry); } } @@ -259,7 +261,7 @@ where { fn vec_znx_lsh_inplace_impl( module: &Module, - basek: usize, + base2k: usize, k: usize, a: &mut A, a_col: usize, @@ -267,8 +269,8 @@ where ) where A: VecZnxToMut, { - let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::()); - vec_znx_lsh_inplace::<_, Self>(basek, k, a, a_col, carry); + let (carry, _) = scratch.take_slice(module.vec_znx_lsh_tmp_bytes() / size_of::()); + vec_znx_lsh_inplace::<_, Self>(base2k, k, a, a_col, carry); } } @@ -277,9 +279,9 @@ where Module: VecZnxNormalizeTmpBytes, Scratch: TakeSlice, { - fn vec_znx_rsh_inplace_impl( + fn vec_znx_rsh_impl( module: &Module, - basek: usize, + base2k: usize, k: usize, res: &mut R, res_col: usize, @@ -290,8 +292,8 @@ where R: VecZnxToMut, A: VecZnxToRef, { - let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::()); - vec_znx_rsh::<_, _, Self>(basek, k, res, res_col, a, a_col, carry); + let (carry, _) = scratch.take_slice(module.vec_znx_rsh_tmp_bytes() / size_of::()); + vec_znx_rsh::<_, _, Self>(base2k, k, res, res_col, a, a_col, carry); } } @@ -302,7 +304,7 @@ where { fn vec_znx_rsh_inplace_impl( module: &Module, - basek: usize, + base2k: usize, k: usize, a: &mut A, a_col: usize, @@ -310,8 +312,8 @@ where ) where A: VecZnxToMut, { - let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::()); - vec_znx_rsh_inplace::<_, Self>(basek, k, a, a_col, carry); + let (carry, _) = scratch.take_slice(module.vec_znx_rsh_tmp_bytes() / size_of::()); + vec_znx_rsh_inplace::<_, Self>(base2k, k, a, a_col, carry); } } @@ -495,18 +497,18 @@ unsafe impl VecZnxCopyImpl for FFT64Ref 
{ } unsafe impl VecZnxFillUniformImpl for FFT64Ref { - fn vec_znx_fill_uniform_impl(_module: &Module, basek: usize, res: &mut R, res_col: usize, source: &mut Source) + fn vec_znx_fill_uniform_impl(_module: &Module, base2k: usize, res: &mut R, res_col: usize, source: &mut Source) where R: VecZnxToMut, { - vec_znx_fill_uniform_ref(basek, res, res_col, source) + vec_znx_fill_uniform_ref(base2k, res, res_col, source) } } unsafe impl VecZnxFillNormalImpl for FFT64Ref { fn vec_znx_fill_normal_impl( _module: &Module, - basek: usize, + base2k: usize, res: &mut R, res_col: usize, k: usize, @@ -516,14 +518,14 @@ unsafe impl VecZnxFillNormalImpl for FFT64Ref { ) where R: VecZnxToMut, { - vec_znx_fill_normal_ref(basek, res, res_col, k, sigma, bound, source); + vec_znx_fill_normal_ref(base2k, res, res_col, k, sigma, bound, source); } } unsafe impl VecZnxAddNormalImpl for FFT64Ref { fn vec_znx_add_normal_impl( _module: &Module, - basek: usize, + base2k: usize, res: &mut R, res_col: usize, k: usize, @@ -533,6 +535,6 @@ unsafe impl VecZnxAddNormalImpl for FFT64Ref { ) where R: VecZnxToMut, { - vec_znx_add_normal_ref(basek, res, res_col, k, sigma, bound, source); + vec_znx_add_normal_ref(base2k, res, res_col, k, sigma, bound, source); } } diff --git a/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs b/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs index d5c4960..bb75c8f 100644 --- a/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs +++ b/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs @@ -10,15 +10,15 @@ use poulpy_hal::{ VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl, VecZnxBigAutomorphismInplaceImpl, VecZnxBigAutomorphismInplaceTmpBytesImpl, VecZnxBigFromBytesImpl, VecZnxBigFromSmallImpl, VecZnxBigNegateImpl, VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl, - VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubABInplaceImpl, VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl, - VecZnxBigSubSmallAImpl, VecZnxBigSubSmallAInplaceImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl, + VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubImpl, VecZnxBigSubInplaceImpl, VecZnxBigSubNegateInplaceImpl, + VecZnxBigSubSmallAImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallInplaceImpl, VecZnxBigSubSmallNegateInplaceImpl, }, reference::{ fft64::vec_znx_big::{ vec_znx_big_add, vec_znx_big_add_inplace, vec_znx_big_add_normal_ref, vec_znx_big_add_small, vec_znx_big_add_small_inplace, vec_znx_big_automorphism, vec_znx_big_automorphism_inplace, vec_znx_big_automorphism_inplace_tmp_bytes, vec_znx_big_negate, vec_znx_big_negate_inplace, vec_znx_big_normalize, - vec_znx_big_normalize_tmp_bytes, vec_znx_big_sub, vec_znx_big_sub_ab_inplace, vec_znx_big_sub_ba_inplace, + vec_znx_big_normalize_tmp_bytes, vec_znx_big_sub, vec_znx_big_sub_inplace, vec_znx_big_sub_negate_inplace, vec_znx_big_sub_small_a, vec_znx_big_sub_small_a_inplace, vec_znx_big_sub_small_b, vec_znx_big_sub_small_b_inplace, }, znx::{znx_copy_ref, znx_zero_ref}, @@ -76,7 +76,7 @@ unsafe impl VecZnxBigFromSmallImpl for FFT64Ref { unsafe impl VecZnxBigAddNormalImpl for FFT64Ref { fn add_normal_impl>( _module: &Module, - basek: usize, + base2k: usize, res: &mut R, res_col: usize, k: usize, @@ -84,7 +84,7 @@ unsafe impl VecZnxBigAddNormalImpl for FFT64Ref { sigma: f64, bound: f64, ) { - vec_znx_big_add_normal_ref(basek, res, res_col, k, sigma, bound, source); + vec_znx_big_add_normal_ref(base2k, res, res_col, k, sigma, bound, source); } } @@ -167,25 +167,25 @@ unsafe impl VecZnxBigSubImpl for FFT64Ref { } } -unsafe impl 
VecZnxBigSubABInplaceImpl for FFT64Ref { +unsafe impl VecZnxBigSubInplaceImpl for FFT64Ref { /// Subtracts `a` from `b` and stores the result on `b`. - fn vec_znx_big_sub_ab_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_big_sub_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxBigToMut, A: VecZnxBigToRef, { - vec_znx_big_sub_ab_inplace(res, res_col, a, a_col); + vec_znx_big_sub_inplace(res, res_col, a, a_col); } } -unsafe impl VecZnxBigSubBAInplaceImpl for FFT64Ref { +unsafe impl VecZnxBigSubNegateInplaceImpl for FFT64Ref { /// Subtracts `b` from `a` and stores the result on `b`. - fn vec_znx_big_sub_ba_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_big_sub_negate_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxBigToMut, A: VecZnxBigToRef, { - vec_znx_big_sub_ba_inplace(res, res_col, a, a_col); + vec_znx_big_sub_negate_inplace(res, res_col, a, a_col); } } @@ -208,9 +208,9 @@ unsafe impl VecZnxBigSubSmallAImpl for FFT64Ref { } } -unsafe impl VecZnxBigSubSmallAInplaceImpl for FFT64Ref { +unsafe impl VecZnxBigSubSmallInplaceImpl for FFT64Ref { /// Subtracts `a` from `res` and stores the result on `res`. - fn vec_znx_big_sub_small_a_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_big_sub_small_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxBigToMut, A: VecZnxToRef, @@ -238,9 +238,9 @@ unsafe impl VecZnxBigSubSmallBImpl for FFT64Ref { } } -unsafe impl VecZnxBigSubSmallBInplaceImpl for FFT64Ref { +unsafe impl VecZnxBigSubSmallNegateInplaceImpl for FFT64Ref { /// Subtracts `res` from `a` and stores the result on `res`. 
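// [editor note] Rename cheat-sheet for the sub family, with the res/a argument
// convention used by these impls (sketch, not normative):
//
//     vec_znx_big_sub_inplace(res, a)           res <- res - a    (was sub_ab_inplace)
//     vec_znx_big_sub_negate_inplace(res, a)    res <- a - res    (was sub_ba_inplace)
//
// The untouched doc comments above still call the receiver `b`; reading `b`
// as `res` yields exactly the semantics listed here.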
- fn vec_znx_big_sub_small_b_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_big_sub_small_negate_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxBigToMut, A: VecZnxToRef, @@ -280,9 +280,10 @@ where { fn vec_znx_big_normalize_impl( module: &Module, - basek: usize, + res_basek: usize, res: &mut R, res_col: usize, + a_basek: usize, a: &A, a_col: usize, scratch: &mut Scratch, @@ -291,7 +292,7 @@ where A: VecZnxBigToRef, { let (carry, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes() / size_of::()); - vec_znx_big_normalize(basek, res, res_col, a, a_col, carry); + vec_znx_big_normalize(res_basek, res, res_col, a_basek, a, a_col, carry); } } @@ -326,7 +327,7 @@ where ) where R: VecZnxBigToMut, { - let (tmp, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes() / size_of::()); + let (tmp, _) = scratch.take_slice(module.vec_znx_big_automorphism_inplace_tmp_bytes() / size_of::()); vec_znx_big_automorphism_inplace(p, res, res_col, tmp); } } diff --git a/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs b/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs index 646cbca..a08b728 100644 --- a/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs +++ b/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs @@ -5,12 +5,12 @@ use poulpy_hal::{ }, oep::{ VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftApplyImpl, - VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl, VecZnxDftSubImpl, + VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubImpl, VecZnxDftSubInplaceImpl, VecZnxDftSubNegateInplaceImpl, VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl, VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl, }, reference::fft64::vec_znx_dft::{ - vec_znx_dft_add, vec_znx_dft_add_inplace, vec_znx_dft_apply, vec_znx_dft_copy, vec_znx_dft_sub, - vec_znx_dft_sub_ab_inplace, vec_znx_dft_sub_ba_inplace, vec_znx_dft_zero, vec_znx_idft_apply, vec_znx_idft_apply_consume, + vec_znx_dft_add, vec_znx_dft_add_inplace, vec_znx_dft_apply, vec_znx_dft_copy, vec_znx_dft_sub, vec_znx_dft_sub_inplace, + vec_znx_dft_sub_negate_inplace, vec_znx_dft_zero, vec_znx_idft_apply, vec_znx_idft_apply_consume, vec_znx_idft_apply_tmpa, }, }; @@ -139,23 +139,23 @@ unsafe impl VecZnxDftSubImpl for FFT64Ref { } } -unsafe impl VecZnxDftSubABInplaceImpl for FFT64Ref { - fn vec_znx_dft_sub_ab_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) +unsafe impl VecZnxDftSubInplaceImpl for FFT64Ref { + fn vec_znx_dft_sub_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxDftToMut, A: VecZnxDftToRef, { - vec_znx_dft_sub_ab_inplace(res, res_col, a, a_col); + vec_znx_dft_sub_inplace(res, res_col, a, a_col); } } -unsafe impl VecZnxDftSubBAInplaceImpl for FFT64Ref { - fn vec_znx_dft_sub_ba_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) +unsafe impl VecZnxDftSubNegateInplaceImpl for FFT64Ref { + fn vec_znx_dft_sub_negate_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxDftToMut, A: VecZnxDftToRef, { - vec_znx_dft_sub_ba_inplace(res, res_col, a, a_col); + vec_znx_dft_sub_negate_inplace(res, res_col, a, a_col); } } diff --git a/poulpy-backend/src/cpu_fft64_ref/zn.rs b/poulpy-backend/src/cpu_fft64_ref/zn.rs index 995094b..954d559 100644 --- a/poulpy-backend/src/cpu_fft64_ref/zn.rs +++ 
b/poulpy-backend/src/cpu_fft64_ref/zn.rs @@ -18,21 +18,21 @@ unsafe impl ZnNormalizeInplaceImpl for FFT64Ref where Self: TakeSliceImpl, { - fn zn_normalize_inplace_impl(n: usize, basek: usize, res: &mut R, res_col: usize, scratch: &mut Scratch) + fn zn_normalize_inplace_impl(n: usize, base2k: usize, res: &mut R, res_col: usize, scratch: &mut Scratch) where R: ZnToMut, { let (carry, _) = scratch.take_slice(n); - zn_normalize_inplace::(n, basek, res, res_col, carry); + zn_normalize_inplace::(n, base2k, res, res_col, carry); } } unsafe impl ZnFillUniformImpl for FFT64Ref { - fn zn_fill_uniform_impl(n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source) + fn zn_fill_uniform_impl(n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source) where R: ZnToMut, { - zn_fill_uniform(n, basek, res, res_col, source); + zn_fill_uniform(n, base2k, res, res_col, source); } } @@ -40,7 +40,7 @@ unsafe impl ZnFillNormalImpl for FFT64Ref { #[allow(clippy::too_many_arguments)] fn zn_fill_normal_impl( n: usize, - basek: usize, + base2k: usize, res: &mut R, res_col: usize, k: usize, @@ -50,7 +50,7 @@ unsafe impl ZnFillNormalImpl for FFT64Ref { ) where R: ZnToMut, { - zn_fill_normal(n, basek, res, res_col, k, source, sigma, bound); + zn_fill_normal(n, base2k, res, res_col, k, source, sigma, bound); } } @@ -58,7 +58,7 @@ unsafe impl ZnAddNormalImpl for FFT64Ref { #[allow(clippy::too_many_arguments)] fn zn_add_normal_impl( n: usize, - basek: usize, + base2k: usize, res: &mut R, res_col: usize, k: usize, @@ -68,6 +68,6 @@ unsafe impl ZnAddNormalImpl for FFT64Ref { ) where R: ZnToMut, { - zn_add_normal(n, basek, res, res_col, k, source, sigma, bound); + zn_add_normal(n, base2k, res, res_col, k, source, sigma, bound); } } diff --git a/poulpy-backend/src/cpu_fft64_ref/znx.rs b/poulpy-backend/src/cpu_fft64_ref/znx.rs index f248624..84d0e83 100644 --- a/poulpy-backend/src/cpu_fft64_ref/znx.rs +++ b/poulpy-backend/src/cpu_fft64_ref/znx.rs @@ -1,12 +1,14 @@ use poulpy_hal::reference::znx::{ - ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxNormalizeFinalStep, - ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeFirstStepInplace, - ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, ZnxRotate, ZnxSub, ZnxSubABInplace, - ZnxSubBAInplace, ZnxSwitchRing, ZnxZero, znx_add_inplace_ref, znx_add_ref, znx_automorphism_ref, znx_copy_ref, - znx_negate_inplace_ref, znx_negate_ref, znx_normalize_final_step_inplace_ref, znx_normalize_final_step_ref, + ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxExtractDigitAddMul, ZnxMulAddPowerOfTwo, ZnxMulPowerOfTwo, + ZnxMulPowerOfTwoInplace, ZnxNegate, ZnxNegateInplace, ZnxNormalizeDigit, ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace, + ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep, + ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, ZnxRotate, ZnxSub, ZnxSubInplace, ZnxSubNegateInplace, + ZnxSwitchRing, ZnxZero, znx_add_inplace_ref, znx_add_ref, znx_automorphism_ref, znx_copy_ref, znx_extract_digit_addmul_ref, + znx_mul_add_power_of_two_ref, znx_mul_power_of_two_inplace_ref, znx_mul_power_of_two_ref, znx_negate_inplace_ref, + znx_negate_ref, znx_normalize_digit_ref, znx_normalize_final_step_inplace_ref, znx_normalize_final_step_ref, znx_normalize_first_step_carry_only_ref, znx_normalize_first_step_inplace_ref, znx_normalize_first_step_ref, 
znx_normalize_middle_step_carry_only_ref, znx_normalize_middle_step_inplace_ref, znx_normalize_middle_step_ref, znx_rotate, - znx_sub_ab_inplace_ref, znx_sub_ba_inplace_ref, znx_sub_ref, znx_switch_ring_ref, znx_zero_ref, + znx_sub_inplace_ref, znx_sub_negate_inplace_ref, znx_sub_ref, znx_switch_ring_ref, znx_zero_ref, }; use crate::cpu_fft64_ref::FFT64Ref; @@ -32,17 +34,38 @@ impl ZnxSub for FFT64Ref { } } -impl ZnxSubABInplace for FFT64Ref { +impl ZnxSubInplace for FFT64Ref { #[inline(always)] - fn znx_sub_ab_inplace(res: &mut [i64], a: &[i64]) { - znx_sub_ab_inplace_ref(res, a); + fn znx_sub_inplace(res: &mut [i64], a: &[i64]) { + znx_sub_inplace_ref(res, a); } } -impl ZnxSubBAInplace for FFT64Ref { +impl ZnxSubNegateInplace for FFT64Ref { #[inline(always)] - fn znx_sub_ba_inplace(res: &mut [i64], a: &[i64]) { - znx_sub_ba_inplace_ref(res, a); + fn znx_sub_negate_inplace(res: &mut [i64], a: &[i64]) { + znx_sub_negate_inplace_ref(res, a); + } +} + +impl ZnxMulAddPowerOfTwo for FFT64Ref { + #[inline(always)] + fn znx_muladd_power_of_two(k: i64, res: &mut [i64], a: &[i64]) { + znx_mul_add_power_of_two_ref(k, res, a); + } +} + +impl ZnxMulPowerOfTwo for FFT64Ref { + #[inline(always)] + fn znx_mul_power_of_two(k: i64, res: &mut [i64], a: &[i64]) { + znx_mul_power_of_two_ref(k, res, a); + } +} + +impl ZnxMulPowerOfTwoInplace for FFT64Ref { + #[inline(always)] + fn znx_mul_power_of_two_inplace(k: i64, res: &mut [i64]) { + znx_mul_power_of_two_inplace_ref(k, res); } } @@ -97,56 +120,70 @@ impl ZnxSwitchRing for FFT64Ref { impl ZnxNormalizeFinalStep for FFT64Ref { #[inline(always)] - fn znx_normalize_final_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { - znx_normalize_final_step_ref(basek, lsh, x, a, carry); + fn znx_normalize_final_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { + znx_normalize_final_step_ref(base2k, lsh, x, a, carry); } } impl ZnxNormalizeFinalStepInplace for FFT64Ref { #[inline(always)] - fn znx_normalize_final_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { - znx_normalize_final_step_inplace_ref(basek, lsh, x, carry); + fn znx_normalize_final_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { + znx_normalize_final_step_inplace_ref(base2k, lsh, x, carry); } } impl ZnxNormalizeFirstStep for FFT64Ref { #[inline(always)] - fn znx_normalize_first_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { - znx_normalize_first_step_ref(basek, lsh, x, a, carry); + fn znx_normalize_first_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { + znx_normalize_first_step_ref(base2k, lsh, x, a, carry); } } impl ZnxNormalizeFirstStepCarryOnly for FFT64Ref { #[inline(always)] - fn znx_normalize_first_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) { - znx_normalize_first_step_carry_only_ref(basek, lsh, x, carry); + fn znx_normalize_first_step_carry_only(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) { + znx_normalize_first_step_carry_only_ref(base2k, lsh, x, carry); } } impl ZnxNormalizeFirstStepInplace for FFT64Ref { #[inline(always)] - fn znx_normalize_first_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { - znx_normalize_first_step_inplace_ref(basek, lsh, x, carry); + fn znx_normalize_first_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { + znx_normalize_first_step_inplace_ref(base2k, lsh, x, carry); } } impl ZnxNormalizeMiddleStep for FFT64Ref { 
#[inline(always)] - fn znx_normalize_middle_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { - znx_normalize_middle_step_ref(basek, lsh, x, a, carry); + fn znx_normalize_middle_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { + znx_normalize_middle_step_ref(base2k, lsh, x, a, carry); } } impl ZnxNormalizeMiddleStepCarryOnly for FFT64Ref { #[inline(always)] - fn znx_normalize_middle_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) { - znx_normalize_middle_step_carry_only_ref(basek, lsh, x, carry); + fn znx_normalize_middle_step_carry_only(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) { + znx_normalize_middle_step_carry_only_ref(base2k, lsh, x, carry); } } impl ZnxNormalizeMiddleStepInplace for FFT64Ref { #[inline(always)] - fn znx_normalize_middle_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { - znx_normalize_middle_step_inplace_ref(basek, lsh, x, carry); + fn znx_normalize_middle_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { + znx_normalize_middle_step_inplace_ref(base2k, lsh, x, carry); + } +} + +impl ZnxExtractDigitAddMul for FFT64Ref { + #[inline(always)] + fn znx_extract_digit_addmul(base2k: usize, lsh: usize, res: &mut [i64], src: &mut [i64]) { + znx_extract_digit_addmul_ref(base2k, lsh, res, src); + } +} + +impl ZnxNormalizeDigit for FFT64Ref { + #[inline(always)] + fn znx_normalize_digit(base2k: usize, res: &mut [i64], src: &mut [i64]) { + znx_normalize_digit_ref(base2k, res, src); } } diff --git a/poulpy-backend/src/cpu_spqlios/fft64/mod.rs b/poulpy-backend/src/cpu_spqlios/fft64/mod.rs index 6790625..f87b264 100644 --- a/poulpy-backend/src/cpu_spqlios/fft64/mod.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/mod.rs @@ -6,5 +6,6 @@ mod vec_znx_big; mod vec_znx_dft; mod vmp_pmat; mod zn; +mod znx; pub struct FFT64Spqlios; diff --git a/poulpy-backend/src/cpu_spqlios/fft64/module.rs b/poulpy-backend/src/cpu_spqlios/fft64/module.rs index fbb3939..1425cec 100644 --- a/poulpy-backend/src/cpu_spqlios/fft64/module.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/module.rs @@ -3,20 +3,11 @@ use std::ptr::NonNull; use poulpy_hal::{ layouts::{Backend, Module}, oep::ModuleNewImpl, - reference::znx::{ - ZnxCopy, ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, - ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, - ZnxRotate, ZnxSwitchRing, ZnxZero, znx_copy_ref, znx_normalize_final_step_inplace_ref, znx_normalize_final_step_ref, - znx_normalize_first_step_carry_only_ref, znx_normalize_first_step_inplace_ref, znx_normalize_first_step_ref, - znx_normalize_middle_step_carry_only_ref, znx_normalize_middle_step_inplace_ref, znx_normalize_middle_step_ref, - znx_switch_ring_ref, znx_zero_ref, - }, }; use crate::cpu_spqlios::{ FFT64Spqlios, ffi::module::{MODULE, delete_module_info, new_module_info}, - znx::znx_rotate_i64, }; impl Backend for FFT64Spqlios { @@ -41,85 +32,3 @@ unsafe impl ModuleNewImpl for FFT64Spqlios { unsafe { Module::from_raw_parts(new_module_info(n, 0), n) } } } - -impl ZnxCopy for FFT64Spqlios { - fn znx_copy(res: &mut [i64], a: &[i64]) { - znx_copy_ref(res, a); - } -} - -impl ZnxZero for FFT64Spqlios { - fn znx_zero(res: &mut [i64]) { - znx_zero_ref(res); - } -} - -impl ZnxSwitchRing for FFT64Spqlios { - fn znx_switch_ring(res: &mut [i64], a: &[i64]) { - znx_switch_ring_ref(res, a); - } -} - -impl ZnxRotate for 
FFT64Spqlios { - fn znx_rotate(p: i64, res: &mut [i64], src: &[i64]) { - unsafe { - znx_rotate_i64(res.len() as u64, p, res.as_mut_ptr(), src.as_ptr()); - } - } -} - -impl ZnxNormalizeFinalStep for FFT64Spqlios { - #[inline(always)] - fn znx_normalize_final_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { - znx_normalize_final_step_ref(basek, lsh, x, a, carry); - } -} - -impl ZnxNormalizeFinalStepInplace for FFT64Spqlios { - #[inline(always)] - fn znx_normalize_final_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { - znx_normalize_final_step_inplace_ref(basek, lsh, x, carry); - } -} - -impl ZnxNormalizeFirstStep for FFT64Spqlios { - #[inline(always)] - fn znx_normalize_first_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { - znx_normalize_first_step_ref(basek, lsh, x, a, carry); - } -} - -impl ZnxNormalizeFirstStepCarryOnly for FFT64Spqlios { - #[inline(always)] - fn znx_normalize_first_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) { - znx_normalize_first_step_carry_only_ref(basek, lsh, x, carry); - } -} - -impl ZnxNormalizeFirstStepInplace for FFT64Spqlios { - #[inline(always)] - fn znx_normalize_first_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { - znx_normalize_first_step_inplace_ref(basek, lsh, x, carry); - } -} - -impl ZnxNormalizeMiddleStep for FFT64Spqlios { - #[inline(always)] - fn znx_normalize_middle_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { - znx_normalize_middle_step_ref(basek, lsh, x, a, carry); - } -} - -impl ZnxNormalizeMiddleStepCarryOnly for FFT64Spqlios { - #[inline(always)] - fn znx_normalize_middle_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) { - znx_normalize_middle_step_carry_only_ref(basek, lsh, x, carry); - } -} - -impl ZnxNormalizeMiddleStepInplace for FFT64Spqlios { - #[inline(always)] - fn znx_normalize_middle_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { - znx_normalize_middle_step_inplace_ref(basek, lsh, x, carry); - } -} diff --git a/poulpy-backend/src/cpu_spqlios/fft64/scratch.rs b/poulpy-backend/src/cpu_spqlios/fft64/scratch.rs index 1013df8..9bddcb3 100644 --- a/poulpy-backend/src/cpu_spqlios/fft64/scratch.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/scratch.rs @@ -253,9 +253,6 @@ fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8] (take_slice, rem_slice) } } else { - panic!( - "Attempted to take {} from scratch with {} aligned bytes left", - take_len, aligned_len, - ); + panic!("Attempted to take {take_len} from scratch with {aligned_len} aligned bytes left"); } } diff --git a/poulpy-backend/src/cpu_spqlios/fft64/vec_znx.rs b/poulpy-backend/src/cpu_spqlios/fft64/vec_znx.rs index ff46b27..c3a110a 100644 --- a/poulpy-backend/src/cpu_spqlios/fft64/vec_znx.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/vec_znx.rs @@ -1,5 +1,8 @@ use poulpy_hal::{ - api::{TakeSlice, VecZnxMergeRingsTmpBytes, VecZnxNormalizeTmpBytes, VecZnxSplitRingTmpBytes}, + api::{ + TakeSlice, VecZnxLshTmpBytes, VecZnxMergeRingsTmpBytes, VecZnxNormalizeTmpBytes, VecZnxRshTmpBytes, + VecZnxSplitRingTmpBytes, + }, layouts::{ Module, ScalarZnx, ScalarZnxToRef, Scratch, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, }, @@ -11,16 +14,16 @@ use poulpy_hal::{ VecZnxMulXpMinusOneInplaceTmpBytesImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl, VecZnxNormalizeInplaceImpl, 
VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl, VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl, - VecZnxSplitRingTmpBytesImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarImpl, + VecZnxSplitRingTmpBytesImpl, VecZnxSubImpl, VecZnxSubInplaceImpl, VecZnxSubNegateInplaceImpl, VecZnxSubScalarImpl, VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl, }, reference::{ vec_znx::{ vec_znx_add_normal_ref, vec_znx_automorphism_inplace_tmp_bytes, vec_znx_copy, vec_znx_fill_normal_ref, vec_znx_fill_uniform_ref, vec_znx_lsh, vec_znx_lsh_inplace, vec_znx_lsh_tmp_bytes, vec_znx_merge_rings, - vec_znx_merge_rings_tmp_bytes, vec_znx_mul_xp_minus_one_inplace_tmp_bytes, vec_znx_rotate_inplace_tmp_bytes, - vec_znx_rsh, vec_znx_rsh_inplace, vec_znx_rsh_tmp_bytes, vec_znx_split_ring, vec_znx_split_ring_tmp_bytes, - vec_znx_switch_ring, + vec_znx_merge_rings_tmp_bytes, vec_znx_mul_xp_minus_one_inplace_tmp_bytes, vec_znx_normalize_tmp_bytes, + vec_znx_rotate_inplace_tmp_bytes, vec_znx_rsh, vec_znx_rsh_inplace, vec_znx_rsh_tmp_bytes, vec_znx_split_ring, + vec_znx_split_ring_tmp_bytes, vec_znx_switch_ring, }, znx::{znx_copy_ref, znx_zero_ref}, }, @@ -34,7 +37,7 @@ use crate::cpu_spqlios::{ unsafe impl VecZnxNormalizeTmpBytesImpl for FFT64Spqlios { fn vec_znx_normalize_tmp_bytes_impl(module: &Module) -> usize { - unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(module.ptr() as *const module_info_t) as usize } + vec_znx_normalize_tmp_bytes(module.n()) } } @@ -44,9 +47,10 @@ where { fn vec_znx_normalize_impl( module: &Module, - basek: usize, + res_basek: usize, res: &mut R, res_col: usize, + a_basek: usize, a: &A, a_col: usize, scratch: &mut Scratch, @@ -60,6 +64,10 @@ where #[cfg(debug_assertions)] { assert_eq!(res.n(), a.n()); + assert_eq!( + res_basek, a_basek, + "res_basek != a_basek -> base2k conversion is not supported" + ) } let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes()); @@ -67,7 +75,7 @@ where unsafe { vec_znx::vec_znx_normalize_base2k( module.ptr() as *const module_info_t, - basek as u64, + res_basek as u64, res.at_mut_ptr(res_col, 0), res.size() as u64, res.sl() as u64, @@ -86,7 +94,7 @@ where { fn vec_znx_normalize_inplace_impl( module: &Module, - basek: usize, + base2k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch, @@ -100,7 +108,7 @@ where unsafe { vec_znx::vec_znx_normalize_base2k( module.ptr() as *const module_info_t, - basek as u64, + base2k as u64, a.at_mut_ptr(a_col, 0), a.size() as u64, a.sl() as u64, @@ -301,8 +309,8 @@ unsafe impl VecZnxSubImpl for FFT64Spqlios { } } -unsafe impl VecZnxSubABInplaceImpl for FFT64Spqlios { - fn vec_znx_sub_ab_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) +unsafe impl VecZnxSubInplaceImpl for FFT64Spqlios { + fn vec_znx_sub_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, A: VecZnxToRef, @@ -330,8 +338,8 @@ unsafe impl VecZnxSubABInplaceImpl for FFT64Spqlios { } } -unsafe impl VecZnxSubBAInplaceImpl for FFT64Spqlios { - fn vec_znx_sub_ba_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) +unsafe impl VecZnxSubNegateInplaceImpl for FFT64Spqlios { + fn vec_znx_sub_negate_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, A: VecZnxToRef, @@ -512,9 +520,9 @@ where Module: VecZnxNormalizeTmpBytes, Scratch: TakeSlice, { - fn 
vec_znx_lsh_inplace_impl( + fn vec_znx_lsh_impl( module: &Module, - basek: usize, + base2k: usize, k: usize, res: &mut R, res_col: usize, @@ -525,8 +533,8 @@ where R: VecZnxToMut, A: VecZnxToRef, { - let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::()); - vec_znx_lsh::<_, _, FFT64Spqlios>(basek, k, res, res_col, a, a_col, carry) + let (carry, _) = scratch.take_slice(module.vec_znx_lsh_tmp_bytes() / size_of::()); + vec_znx_lsh::<_, _, FFT64Spqlios>(base2k, k, res, res_col, a, a_col, carry) } } @@ -537,7 +545,7 @@ where { fn vec_znx_lsh_inplace_impl( module: &Module, - basek: usize, + base2k: usize, k: usize, a: &mut A, a_col: usize, @@ -545,8 +553,8 @@ where ) where A: VecZnxToMut, { - let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::()); - vec_znx_lsh_inplace::<_, FFT64Spqlios>(basek, k, a, a_col, carry) + let (carry, _) = scratch.take_slice(module.vec_znx_lsh_tmp_bytes() / size_of::()); + vec_znx_lsh_inplace::<_, FFT64Spqlios>(base2k, k, a, a_col, carry) } } @@ -555,9 +563,9 @@ where Module: VecZnxNormalizeTmpBytes, Scratch: TakeSlice, { - fn vec_znx_rsh_inplace_impl( + fn vec_znx_rsh_impl( module: &Module, - basek: usize, + base2k: usize, k: usize, res: &mut R, res_col: usize, @@ -568,8 +576,8 @@ where R: VecZnxToMut, A: VecZnxToRef, { - let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::()); - vec_znx_rsh::<_, _, FFT64Spqlios>(basek, k, res, res_col, a, a_col, carry) + let (carry, _) = scratch.take_slice(module.vec_znx_rsh_tmp_bytes() / size_of::()); + vec_znx_rsh::<_, _, FFT64Spqlios>(base2k, k, res, res_col, a, a_col, carry) } } @@ -580,7 +588,7 @@ where { fn vec_znx_rsh_inplace_impl( module: &Module, - basek: usize, + base2k: usize, k: usize, a: &mut A, a_col: usize, @@ -588,8 +596,8 @@ where ) where A: VecZnxToMut, { - let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::()); - vec_znx_rsh_inplace::<_, FFT64Spqlios>(basek, k, a, a_col, carry) + let (carry, _) = scratch.take_slice(module.vec_znx_rsh_tmp_bytes() / size_of::()); + vec_znx_rsh_inplace::<_, FFT64Spqlios>(base2k, k, a, a_col, carry) } } @@ -690,11 +698,7 @@ unsafe impl VecZnxAutomorphismInplaceImpl for FFT64Spqlios { let mut a: VecZnx<&mut [u8]> = a.to_mut(); #[cfg(debug_assertions)] { - assert!( - k & 1 != 0, - "invalid galois element: must be odd but is {}", - k - ); + assert!(k & 1 != 0, "invalid galois element: must be odd but is {k}"); } unsafe { vec_znx::vec_znx_automorphism( @@ -852,18 +856,18 @@ unsafe impl VecZnxCopyImpl for FFT64Spqlios { } unsafe impl VecZnxFillUniformImpl for FFT64Spqlios { - fn vec_znx_fill_uniform_impl(_module: &Module, basek: usize, res: &mut R, res_col: usize, source: &mut Source) + fn vec_znx_fill_uniform_impl(_module: &Module, base2k: usize, res: &mut R, res_col: usize, source: &mut Source) where R: VecZnxToMut, { - vec_znx_fill_uniform_ref(basek, res, res_col, source) + vec_znx_fill_uniform_ref(base2k, res, res_col, source) } } unsafe impl VecZnxFillNormalImpl for FFT64Spqlios { fn vec_znx_fill_normal_impl( _module: &Module, - basek: usize, + base2k: usize, res: &mut R, res_col: usize, k: usize, @@ -873,14 +877,14 @@ unsafe impl VecZnxFillNormalImpl for FFT64Spqlios { ) where R: VecZnxToMut, { - vec_znx_fill_normal_ref(basek, res, res_col, k, sigma, bound, source); + vec_znx_fill_normal_ref(base2k, res, res_col, k, sigma, bound, source); } } unsafe impl VecZnxAddNormalImpl for FFT64Spqlios { fn vec_znx_add_normal_impl( _module: &Module, - basek: usize, + 
base2k: usize, res: &mut R, res_col: usize, k: usize, @@ -890,6 +894,6 @@ unsafe impl VecZnxAddNormalImpl for FFT64Spqlios { ) where R: VecZnxToMut, { - vec_znx_add_normal_ref(basek, res, res_col, k, sigma, bound, source); + vec_znx_add_normal_ref(base2k, res, res_col, k, sigma, bound, source); } } diff --git a/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_big.rs b/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_big.rs index 5cf8efa..8becaf6 100644 --- a/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_big.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_big.rs @@ -10,11 +10,12 @@ use poulpy_hal::{ VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl, VecZnxBigAutomorphismInplaceImpl, VecZnxBigAutomorphismInplaceTmpBytesImpl, VecZnxBigFromBytesImpl, VecZnxBigFromSmallImpl, VecZnxBigNegateImpl, VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl, - VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubABInplaceImpl, VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl, - VecZnxBigSubSmallAImpl, VecZnxBigSubSmallAInplaceImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl, + VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubImpl, VecZnxBigSubInplaceImpl, VecZnxBigSubNegateInplaceImpl, + VecZnxBigSubSmallAImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallInplaceImpl, VecZnxBigSubSmallNegateInplaceImpl, }, reference::{ - vec_znx::vec_znx_add_normal_ref, + fft64::vec_znx_big::vec_znx_big_normalize, + vec_znx::{vec_znx_add_normal_ref, vec_znx_normalize_tmp_bytes}, znx::{znx_copy_ref, znx_zero_ref}, }, source::Source, @@ -70,7 +71,7 @@ unsafe impl VecZnxBigFromSmallImpl for FFT64Spqlios { unsafe impl VecZnxBigAddNormalImpl for FFT64Spqlios { fn add_normal_impl>( _module: &Module, - basek: usize, + base2k: usize, res: &mut R, res_col: usize, k: usize, @@ -88,7 +89,7 @@ max_size: res.max_size, }; - vec_znx_add_normal_ref(basek, &mut res_znx, res_col, k, sigma, bound, source); + vec_znx_add_normal_ref(base2k, &mut res_znx, res_col, k, sigma, bound, source); } } @@ -266,9 +267,9 @@ unsafe impl VecZnxBigSubImpl for FFT64Spqlios { } } -unsafe impl VecZnxBigSubABInplaceImpl for FFT64Spqlios { +unsafe impl VecZnxBigSubInplaceImpl for FFT64Spqlios { - /// Subtracts `a` from `b` and stores the result on `b`. + /// Subtracts `a` from `res` and stores the result on `res`. - fn vec_znx_big_sub_ab_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_big_sub_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxBigToMut, A: VecZnxBigToRef, @@ -297,9 +298,9 @@ unsafe impl VecZnxBigSubABInplaceImpl for FFT64Spqlios { } } -unsafe impl VecZnxBigSubBAInplaceImpl for FFT64Spqlios { +unsafe impl VecZnxBigSubNegateInplaceImpl for FFT64Spqlios { - /// Subtracts `b` from `a` and stores the result on `b`. + /// Subtracts `res` from `a` and stores the result on `res`. - fn vec_znx_big_sub_ba_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_big_sub_negate_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxBigToMut, A: VecZnxBigToRef, @@ -370,9 +371,9 @@ unsafe impl VecZnxBigSubSmallAImpl for FFT64Spqlios { } } -unsafe impl VecZnxBigSubSmallAInplaceImpl for FFT64Spqlios { +unsafe impl VecZnxBigSubSmallInplaceImpl for FFT64Spqlios { /// Subtracts `a` from `res` and stores the result on `res`.
- fn vec_znx_big_sub_small_a_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_big_sub_small_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxBigToMut, A: VecZnxToRef, @@ -443,9 +444,9 @@ unsafe impl VecZnxBigSubSmallBImpl for FFT64Spqlios { } } -unsafe impl VecZnxBigSubSmallBInplaceImpl for FFT64Spqlios { +unsafe impl VecZnxBigSubSmallNegateInplaceImpl for FFT64Spqlios { /// Subtracts `res` from `a` and stores the result on `res`. - fn vec_znx_big_sub_small_b_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_big_sub_small_negate_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxBigToMut, A: VecZnxToRef, @@ -518,7 +519,7 @@ unsafe impl VecZnxBigNegateInplaceImpl for FFT64Spqlios { unsafe impl VecZnxBigNormalizeTmpBytesImpl for FFT64Spqlios { fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module) -> usize { - unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(module.ptr()) as usize } + vec_znx_normalize_tmp_bytes(module.n()) } } @@ -528,9 +529,10 @@ where { fn vec_znx_big_normalize_impl( module: &Module, - basek: usize, + res_basek: usize, res: &mut R, res_col: usize, + a_basek: usize, a: &A, a_col: usize, scratch: &mut Scratch, @@ -538,28 +540,8 @@ where R: VecZnxToMut, A: VecZnxBigToRef, { - let a: VecZnxBig<&[u8], Self> = a.to_ref(); - let mut res: VecZnx<&mut [u8]> = res.to_mut(); - - #[cfg(debug_assertions)] - { - assert_eq!(res.n(), a.n()); - } - - let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes()); - unsafe { - vec_znx::vec_znx_normalize_base2k( - module.ptr(), - basek as u64, - res.at_mut_ptr(res_col, 0), - res.size() as u64, - res.sl() as u64, - a.at_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - tmp_bytes.as_mut_ptr(), - ); - } + let (carry, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes() / size_of::<i64>()); + vec_znx_big_normalize(res_basek, res, res_col, a_basek, a, a_col, carry); } } diff --git a/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_dft.rs b/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_dft.rs index 8e72bf0..461d327 100644 --- a/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_dft.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_dft.rs @@ -6,7 +6,7 @@ use poulpy_hal::{ }, oep::{ VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftApplyImpl, - VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl, VecZnxDftSubImpl, + VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubImpl, VecZnxDftSubInplaceImpl, VecZnxDftSubNegateInplaceImpl, VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl, VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl, }, reference::{ @@ -336,8 +336,8 @@ unsafe impl VecZnxDftSubImpl for FFT64Spqlios { } } -unsafe impl VecZnxDftSubABInplaceImpl for FFT64Spqlios { - fn vec_znx_dft_sub_ab_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) +unsafe impl VecZnxDftSubInplaceImpl for FFT64Spqlios { + fn vec_znx_dft_sub_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxDftToMut, A: 
VecZnxDftToRef, @@ -363,8 +363,8 @@ unsafe impl VecZnxDftSubABInplaceImpl for FFT64Spqlios { } } -unsafe impl VecZnxDftSubBAInplaceImpl for FFT64Spqlios { - fn vec_znx_dft_sub_ba_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) +unsafe impl VecZnxDftSubNegateInplaceImpl for FFT64Spqlios { + fn vec_znx_dft_sub_negate_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxDftToMut, A: VecZnxDftToRef, diff --git a/poulpy-backend/src/cpu_spqlios/fft64/zn.rs b/poulpy-backend/src/cpu_spqlios/fft64/zn.rs index b2d0f42..adc84fd 100644 --- a/poulpy-backend/src/cpu_spqlios/fft64/zn.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/zn.rs @@ -12,7 +12,7 @@ unsafe impl ZnNormalizeInplaceImpl for FFT64Spqlios where Self: TakeSliceImpl, { - fn zn_normalize_inplace_impl(n: usize, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) + fn zn_normalize_inplace_impl(n: usize, base2k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) where A: ZnToMut, { @@ -23,7 +23,7 @@ where unsafe { zn64::zn64_normalize_base2k_ref( n as u64, - basek as u64, + base2k as u64, a.at_mut_ptr(a_col, 0), a.size() as u64, a.sl() as u64, @@ -37,11 +37,11 @@ where } unsafe impl ZnFillUniformImpl for FFT64Spqlios { - fn zn_fill_uniform_impl(n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source) + fn zn_fill_uniform_impl(n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source) where R: ZnToMut, { - zn_fill_uniform(n, basek, res, res_col, source); + zn_fill_uniform(n, base2k, res, res_col, source); } } @@ -49,7 +49,7 @@ unsafe impl ZnFillNormalImpl for FFT64Spqlios { #[allow(clippy::too_many_arguments)] fn zn_fill_normal_impl( n: usize, - basek: usize, + base2k: usize, res: &mut R, res_col: usize, k: usize, @@ -59,7 +59,7 @@ unsafe impl ZnFillNormalImpl for FFT64Spqlios { ) where R: ZnToMut, { - zn_fill_normal(n, basek, res, res_col, k, source, sigma, bound); + zn_fill_normal(n, base2k, res, res_col, k, source, sigma, bound); } } @@ -67,7 +67,7 @@ unsafe impl ZnAddNormalImpl for FFT64Spqlios { #[allow(clippy::too_many_arguments)] fn zn_add_normal_impl( n: usize, - basek: usize, + base2k: usize, res: &mut R, res_col: usize, k: usize, @@ -77,6 +77,6 @@ unsafe impl ZnAddNormalImpl for FFT64Spqlios { ) where R: ZnToMut, { - zn_add_normal(n, basek, res, res_col, k, source, sigma, bound); + zn_add_normal(n, base2k, res, res_col, k, source, sigma, bound); } } diff --git a/poulpy-backend/src/cpu_spqlios/fft64/znx.rs b/poulpy-backend/src/cpu_spqlios/fft64/znx.rs new file mode 100644 index 0000000..c15c8ff --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/fft64/znx.rs @@ -0,0 +1,189 @@ +use poulpy_hal::reference::znx::{ + ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxExtractDigitAddMul, ZnxMulAddPowerOfTwo, ZnxMulPowerOfTwo, + ZnxMulPowerOfTwoInplace, ZnxNegate, ZnxNegateInplace, ZnxNormalizeDigit, ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace, + ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep, + ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, ZnxRotate, ZnxSub, ZnxSubInplace, ZnxSubNegateInplace, + ZnxSwitchRing, ZnxZero, znx_add_inplace_ref, znx_add_ref, znx_automorphism_ref, znx_copy_ref, znx_extract_digit_addmul_ref, + znx_mul_add_power_of_two_ref, znx_mul_power_of_two_inplace_ref, znx_mul_power_of_two_ref, znx_negate_inplace_ref, + znx_negate_ref, znx_normalize_digit_ref, znx_normalize_final_step_inplace_ref, znx_normalize_final_step_ref, 
+ znx_normalize_first_step_carry_only_ref, znx_normalize_first_step_inplace_ref, znx_normalize_first_step_ref, + znx_normalize_middle_step_carry_only_ref, znx_normalize_middle_step_inplace_ref, znx_normalize_middle_step_ref, znx_rotate, + znx_sub_inplace_ref, znx_sub_negate_inplace_ref, znx_sub_ref, znx_switch_ring_ref, znx_zero_ref, +}; + +use crate::FFT64Spqlios; + +impl ZnxAdd for FFT64Spqlios { + #[inline(always)] + fn znx_add(res: &mut [i64], a: &[i64], b: &[i64]) { + znx_add_ref(res, a, b); + } +} + +impl ZnxAddInplace for FFT64Spqlios { + #[inline(always)] + fn znx_add_inplace(res: &mut [i64], a: &[i64]) { + znx_add_inplace_ref(res, a); + } +} + +impl ZnxSub for FFT64Spqlios { + #[inline(always)] + fn znx_sub(res: &mut [i64], a: &[i64], b: &[i64]) { + znx_sub_ref(res, a, b); + } +} + +impl ZnxSubInplace for FFT64Spqlios { + #[inline(always)] + fn znx_sub_inplace(res: &mut [i64], a: &[i64]) { + znx_sub_inplace_ref(res, a); + } +} + +impl ZnxSubNegateInplace for FFT64Spqlios { + #[inline(always)] + fn znx_sub_negate_inplace(res: &mut [i64], a: &[i64]) { + znx_sub_negate_inplace_ref(res, a); + } +} + +impl ZnxMulAddPowerOfTwo for FFT64Spqlios { + #[inline(always)] + fn znx_muladd_power_of_two(k: i64, res: &mut [i64], a: &[i64]) { + znx_mul_add_power_of_two_ref(k, res, a); + } +} + +impl ZnxMulPowerOfTwo for FFT64Spqlios { + #[inline(always)] + fn znx_mul_power_of_two(k: i64, res: &mut [i64], a: &[i64]) { + znx_mul_power_of_two_ref(k, res, a); + } +} + +impl ZnxMulPowerOfTwoInplace for FFT64Spqlios { + #[inline(always)] + fn znx_mul_power_of_two_inplace(k: i64, res: &mut [i64]) { + znx_mul_power_of_two_inplace_ref(k, res); + } +} + +impl ZnxAutomorphism for FFT64Spqlios { + #[inline(always)] + fn znx_automorphism(p: i64, res: &mut [i64], a: &[i64]) { + znx_automorphism_ref(p, res, a); + } +} + +impl ZnxCopy for FFT64Spqlios { + #[inline(always)] + fn znx_copy(res: &mut [i64], a: &[i64]) { + znx_copy_ref(res, a); + } +} + +impl ZnxNegate for FFT64Spqlios { + #[inline(always)] + fn znx_negate(res: &mut [i64], src: &[i64]) { + znx_negate_ref(res, src); + } +} + +impl ZnxNegateInplace for FFT64Spqlios { + #[inline(always)] + fn znx_negate_inplace(res: &mut [i64]) { + znx_negate_inplace_ref(res); + } +} + +impl ZnxRotate for FFT64Spqlios { + #[inline(always)] + fn znx_rotate(p: i64, res: &mut [i64], src: &[i64]) { + znx_rotate::(p, res, src); + } +} + +impl ZnxZero for FFT64Spqlios { + #[inline(always)] + fn znx_zero(res: &mut [i64]) { + znx_zero_ref(res); + } +} + +impl ZnxSwitchRing for FFT64Spqlios { + #[inline(always)] + fn znx_switch_ring(res: &mut [i64], a: &[i64]) { + znx_switch_ring_ref(res, a); + } +} + +impl ZnxNormalizeFinalStep for FFT64Spqlios { + #[inline(always)] + fn znx_normalize_final_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { + znx_normalize_final_step_ref(base2k, lsh, x, a, carry); + } +} + +impl ZnxNormalizeFinalStepInplace for FFT64Spqlios { + #[inline(always)] + fn znx_normalize_final_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { + znx_normalize_final_step_inplace_ref(base2k, lsh, x, carry); + } +} + +impl ZnxNormalizeFirstStep for FFT64Spqlios { + #[inline(always)] + fn znx_normalize_first_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { + znx_normalize_first_step_ref(base2k, lsh, x, a, carry); + } +} + +impl ZnxNormalizeFirstStepCarryOnly for FFT64Spqlios { + #[inline(always)] + fn znx_normalize_first_step_carry_only(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) 
{ + znx_normalize_first_step_carry_only_ref(base2k, lsh, x, carry); + } +} + +impl ZnxNormalizeFirstStepInplace for FFT64Spqlios { + #[inline(always)] + fn znx_normalize_first_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { + znx_normalize_first_step_inplace_ref(base2k, lsh, x, carry); + } +} + +impl ZnxNormalizeMiddleStep for FFT64Spqlios { + #[inline(always)] + fn znx_normalize_middle_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { + znx_normalize_middle_step_ref(base2k, lsh, x, a, carry); + } +} + +impl ZnxNormalizeMiddleStepCarryOnly for FFT64Spqlios { + #[inline(always)] + fn znx_normalize_middle_step_carry_only(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) { + znx_normalize_middle_step_carry_only_ref(base2k, lsh, x, carry); + } +} + +impl ZnxNormalizeMiddleStepInplace for FFT64Spqlios { + #[inline(always)] + fn znx_normalize_middle_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { + znx_normalize_middle_step_inplace_ref(base2k, lsh, x, carry); + } +} + +impl ZnxExtractDigitAddMul for FFT64Spqlios { + #[inline(always)] + fn znx_extract_digit_addmul(base2k: usize, lsh: usize, res: &mut [i64], src: &mut [i64]) { + znx_extract_digit_addmul_ref(base2k, lsh, res, src); + } +} + +impl ZnxNormalizeDigit for FFT64Spqlios { + #[inline(always)] + fn znx_normalize_digit(base2k: usize, res: &mut [i64], src: &mut [i64]) { + znx_normalize_digit_ref(base2k, res, src); + } +} diff --git a/poulpy-backend/src/cpu_spqlios/tests.rs b/poulpy-backend/src/cpu_spqlios/tests.rs index 3c30b6f..bb1f8a0 100644 --- a/poulpy-backend/src/cpu_spqlios/tests.rs +++ b/poulpy-backend/src/cpu_spqlios/tests.rs @@ -5,15 +5,15 @@ cross_backend_test_suite! { backend_ref = crate::cpu_fft64_ref::FFT64Ref, backend_test = crate::cpu_spqlios::FFT64Spqlios, size = 1 << 5, - basek = 12, + base2k = 12, tests = { test_vec_znx_add => poulpy_hal::test_suite::vec_znx::test_vec_znx_add, test_vec_znx_add_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_inplace, test_vec_znx_add_scalar => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_scalar, test_vec_znx_add_scalar_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_scalar_inplace, test_vec_znx_sub => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub, - test_vec_znx_sub_ab_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_ab_inplace, - test_vec_znx_sub_ba_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_ba_inplace, + test_vec_znx_sub_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_inplace, + test_vec_znx_sub_negate_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_negate_inplace, test_vec_znx_sub_scalar => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_scalar, test_vec_znx_sub_scalar_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_scalar_inplace, test_vec_znx_rsh => poulpy_hal::test_suite::vec_znx::test_vec_znx_rsh, @@ -41,7 +41,7 @@ cross_backend_test_suite! { backend_ref = crate::cpu_fft64_ref::FFT64Ref, backend_test = crate::cpu_spqlios::FFT64Spqlios, size = 1 << 5, - basek = 12, + base2k = 12, tests = { test_svp_apply_dft_to_dft => poulpy_hal::test_suite::svp::test_svp_apply_dft_to_dft, test_svp_apply_dft_to_dft_inplace => poulpy_hal::test_suite::svp::test_svp_apply_dft_to_dft_inplace, @@ -53,20 +53,20 @@ cross_backend_test_suite! 
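
The ZnxNormalize*Step* / ZnxExtractDigitAddMul / ZnxNormalizeDigit impls above all delegate to poulpy_hal's reference routines, which renormalize a base-2^k limb decomposition one limb at a time. For intuition, here is a minimal self-contained sketch of what a full normalization pass computes on the limbs of one coefficient; the function name and the centered digit range are assumptions for illustration, not the poulpy_hal reference code:

// Hypothetical sketch: renormalize the limbs of one coefficient so every
// digit lies in the centered range [-2^(base2k-1), 2^(base2k-1)).
// Limb 0 is taken to be the most significant digit, so the carry walk
// runs over the limbs in reverse order.
fn normalize_base2k_sketch(base2k: usize, limbs: &mut [i64]) {
    let half = 1i64 << (base2k - 1);
    let mask = (1i64 << base2k) - 1;
    let mut carry = 0i64;
    for x in limbs.iter_mut().rev() {
        let v = *x + carry; // assumes no i64 overflow, for the sketch
        let mut digit = v & mask; // v mod 2^base2k, in [0, 2^base2k)
        if digit >= half {
            digit -= 1i64 << base2k; // recenter the digit
        }
        carry = (v - digit) >> base2k; // exact division by 2^base2k
        *x = digit;
    }
    // A leftover carry is a multiple of the modulus on the torus and is dropped.
}

fn main() {
    let mut limbs = [0i64, 37];
    normalize_base2k_sketch(4, &mut limbs);
    assert_eq!(limbs, [2, 5]); // 2 * 16 + 5 == 37, both digits now in [-8, 8)
}

The first/middle/final split in the trait family presumably exists so backends can specialize the boundary limbs, with the _carry_only variants computing the running carry for limbs whose output is not materialized.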
{ backend_ref = crate::cpu_fft64_ref::FFT64Ref, backend_test = crate::cpu_spqlios::FFT64Spqlios, size = 1 << 5, - basek = 12, + base2k = 12, tests = { test_vec_znx_big_add => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add, test_vec_znx_big_add_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add_inplace, test_vec_znx_big_add_small => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add_small, test_vec_znx_big_add_small_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add_small_inplace, test_vec_znx_big_sub => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub, - test_vec_znx_big_sub_ab_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_ab_inplace, + test_vec_znx_big_sub_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_inplace, test_vec_znx_big_automorphism => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_automorphism, test_vec_znx_big_automorphism_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_automorphism_inplace, test_vec_znx_big_negate => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_negate, test_vec_znx_big_negate_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_negate_inplace, test_vec_znx_big_normalize => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_normalize, - test_vec_znx_big_sub_ba_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_ba_inplace, + test_vec_znx_big_sub_negate_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_negate_inplace, test_vec_znx_big_sub_small_a => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_a, test_vec_znx_big_sub_small_a_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_a_inplace, test_vec_znx_big_sub_small_b => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_b, @@ -79,13 +79,13 @@ cross_backend_test_suite! { backend_ref = crate::cpu_fft64_ref::FFT64Ref, backend_test = crate::cpu_spqlios::FFT64Spqlios, size = 1 << 5, - basek = 12, + base2k = 12, tests = { test_vec_znx_dft_add => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_add, test_vec_znx_dft_add_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_add_inplace, test_vec_znx_dft_sub => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub, - test_vec_znx_dft_sub_ab_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub_ab_inplace, - test_vec_znx_dft_sub_ba_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub_ba_inplace, + test_vec_znx_dft_sub_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub_inplace, + test_vec_znx_dft_sub_negate_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub_negate_inplace, test_vec_znx_idft_apply => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_idft_apply, test_vec_znx_idft_apply_consume => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_idft_apply_consume, test_vec_znx_idft_apply_tmpa => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_idft_apply_tmpa, @@ -97,7 +97,7 @@ cross_backend_test_suite! 
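
The test renames above follow the convention adopted throughout this patch: the old sub_ab/sub_ba suffixes become sub (receiver minus argument) and sub_negate (argument minus receiver, i.e. the negated difference). A toy sketch of the intended slice-level semantics; the function names are illustrative stand-ins, not the poulpy_hal API:

// res <- res - a (the *_sub_inplace family)
fn sub_inplace(res: &mut [i64], a: &[i64]) {
    res.iter_mut().zip(a).for_each(|(r, &x)| *r -= x);
}

// res <- a - res (the *_sub_negate_inplace family)
fn sub_negate_inplace(res: &mut [i64], a: &[i64]) {
    res.iter_mut().zip(a).for_each(|(r, &x)| *r = x - *r);
}

fn main() {
    let mut res = [5i64, 7];
    sub_inplace(&mut res, &[2, 3]); // res == [3, 4]
    sub_negate_inplace(&mut res, &[1, 1]); // res == [-2, -3]
    assert_eq!(res, [-2, -3]);
}

This matches the doc comments on the VecZnxBig impls above: sub_inplace subtracts `a` from `res`, sub_negate_inplace subtracts `res` from `a`, both storing into `res`.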
{ backend_ref = crate::cpu_fft64_ref::FFT64Ref, backend_test = crate::cpu_spqlios::FFT64Spqlios, size = 1 << 5, - basek = 12, + base2k = 12, tests = { test_vmp_apply_dft_to_dft => poulpy_hal::test_suite::vmp::test_vmp_apply_dft_to_dft, test_vmp_apply_dft_to_dft_add => poulpy_hal::test_suite::vmp::test_vmp_apply_dft_to_dft_add, diff --git a/poulpy-core/README.md b/poulpy-core/README.md index 64877fb..07d5304 100644 --- a/poulpy-core/README.md +++ b/poulpy-core/README.md @@ -26,13 +26,13 @@ fn main() { let n: usize = 1< = Module::::new(n as u64); // Allocates ciphertext & plaintexts - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_pt); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_pt); + let mut ct: GLWECiphertext> = GLWECiphertext::alloc(n, base2k, k_ct, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, base2k, k_pt); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(n, base2k, k_pt); // CPRNG let mut source_xs: Source = Source::new([0u8; 32]); @@ -52,8 +52,8 @@ fn main() { // Scratch space let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertext::encrypt_sk_scratch_space(&module, n, basek, ct.k()) - | GLWECiphertext::decrypt_scratch_space(&module, n, basek, ct.k()), + GLWECiphertext::encrypt_sk_scratch_space(&module, n, base2k, ct.k()) + | GLWECiphertext::decrypt_scratch_space(&module, n, base2k, ct.k()), ); // Generate secret-key @@ -64,7 +64,7 @@ fn main() { let sk_prepared: GLWESecretPrepared, FFT64> = sk.prepare_alloc(&module, scratch.borrow()); // Uniform plaintext - module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, k_pt, &mut source_xa); + module.vec_znx_fill_uniform(base2k, &mut pt_want.data, 0, k_pt, &mut source_xa); // Encryption ct.encrypt_sk( @@ -83,7 +83,7 @@ fn main() { pt_want.sub_inplace_ab(&module, &pt_have); // Ideal vs. 
actual noise - let noise_have: f64 = pt_want.data.std(basek, 0) * (ct.k() as f64).exp2(); + let noise_have: f64 = pt_want.data.std(base2k, 0) * (ct.k() as f64).exp2(); let noise_want: f64 = SIGMA; // Check diff --git a/poulpy-core/benches/external_product_glwe_fft64.rs b/poulpy-core/benches/external_product_glwe_fft64.rs index b7f6814..1d0ec94 100644 --- a/poulpy-core/benches/external_product_glwe_fft64.rs +++ b/poulpy-core/benches/external_product_glwe_fft64.rs @@ -1,5 +1,6 @@ use poulpy_core::layouts::{ - GGSWCiphertext, GLWECiphertext, GLWESecret, Infos, + Base2K, Degree, Digits, GGSWCiphertext, GGSWCiphertextLayout, GLWECiphertext, GLWECiphertextLayout, GLWESecret, Rank, Rows, + TorusPrecision, prepared::{GGSWCiphertextPrepared, GLWESecretPrepared, PrepareAlloc}, }; use std::hint::black_box; @@ -18,50 +19,65 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) { struct Params { log_n: usize, - basek: usize, - k_ct_in: usize, - k_ct_out: usize, - k_ggsw: usize, - rank: usize, + base2k: Base2K, + k_ct_in: TorusPrecision, + k_ct_out: TorusPrecision, + k_ggsw: TorusPrecision, + rank: Rank, } fn runner(p: Params) -> impl FnMut() { let module: Module = Module::::new(1 << p.log_n); - let n: usize = module.n(); - let basek: usize = p.basek; - let k_ct_in: usize = p.k_ct_in; - let k_ct_out: usize = p.k_ct_out; - let k_ggsw: usize = p.k_ggsw; - let rank: usize = p.rank; - let digits: usize = 1; + let n: Degree = Degree(module.n() as u32); + let base2k: Base2K = p.base2k; + let k_ct_in: TorusPrecision = p.k_ct_in; + let k_ct_out: TorusPrecision = p.k_ct_out; + let k_ggsw: TorusPrecision = p.k_ggsw; + let rank: Rank = p.rank; + let digits: Digits = Digits(1); - let rows: usize = 1; //(p.k_ct_in.div_ceil(p.basek); + let rows: Rows = Rows(1); // p.k_ct_in.div_ceil(p.base2k) - let mut ct_ggsw: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ggsw, rows, digits, rank); - let mut ct_glwe_in: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct_in, rank); - let mut ct_glwe_out: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct_out, rank); - let pt_rgsw: ScalarZnx> = ScalarZnx::alloc(n, 1); + let ggsw_layout: GGSWCiphertextLayout = GGSWCiphertextLayout { + n, + base2k, + k: k_ggsw, + rows, + digits, + rank, + }; + + let glwe_out_layout: GLWECiphertextLayout = GLWECiphertextLayout { + n, + base2k, + k: k_ct_out, + rank, + }; + + let glwe_in_layout: GLWECiphertextLayout = GLWECiphertextLayout { + n, + base2k, + k: k_ct_in, + rank, + }; + + let mut ct_ggsw: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_layout); + let mut ct_glwe_in: GLWECiphertext> = GLWECiphertext::alloc(&glwe_in_layout); + let mut ct_glwe_out: GLWECiphertext> = GLWECiphertext::alloc(&glwe_out_layout); + let pt_rgsw: ScalarZnx> = ScalarZnx::alloc(n.into(), 1); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, ct_ggsw.k(), rank) - | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_glwe_in.k()) - | GLWECiphertext::external_product_scratch_space( - &module, - basek, - ct_glwe_out.k(), - ct_glwe_in.k(), - ct_ggsw.k(), - digits, - rank, - ), + GGSWCiphertext::encrypt_sk_scratch_space(&module, &ggsw_layout) + | GLWECiphertext::encrypt_sk_scratch_space(&module, &glwe_in_layout) + | GLWECiphertext::external_product_scratch_space(&module, &glwe_out_layout, &glwe_in_layout, &ggsw_layout), ); let mut source_xs = Source::new([0u8; 32]); let mut source_xe = Source::new([0u8; 32]); let mut source_xa = Source::new([0u8; 32]); - let mut sk: GLWESecret> = 
GLWESecret::alloc(n, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(&glwe_in_layout); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_dft: GLWESecretPrepared, FFT64Spqlios> = sk.prepare_alloc(&module, scratch.borrow()); @@ -92,11 +108,11 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) { let params_set: Vec = vec![Params { log_n: 11, - basek: 22, - k_ct_in: 44, - k_ct_out: 44, - k_ggsw: 54, - rank: 1, + base2k: 22_u32.into(), + k_ct_in: 44_u32.into(), + k_ct_out: 44_u32.into(), + k_ggsw: 54_u32.into(), + rank: 1_u32.into(), }]; for params in params_set { @@ -113,39 +129,55 @@ fn bench_external_product_glwe_inplace_fft64(c: &mut Criterion) { struct Params { log_n: usize, - basek: usize, - k_ct: usize, - k_ggsw: usize, - rank: usize, + base2k: Base2K, + k_ct: TorusPrecision, + k_ggsw: TorusPrecision, + rank: Rank, } fn runner(p: Params) -> impl FnMut() { let module: Module = Module::::new(1 << p.log_n); - let n = module.n(); - let basek: usize = p.basek; - let k_glwe: usize = p.k_ct; - let k_ggsw: usize = p.k_ggsw; - let rank: usize = p.rank; - let digits: usize = 1; + let n: Degree = Degree(module.n() as u32); + let base2k: Base2K = p.base2k; + let k_glwe: TorusPrecision = p.k_ct; + let k_ggsw: TorusPrecision = p.k_ggsw; + let rank: Rank = p.rank; + let digits: Digits = Digits(1); - let rows: usize = p.k_ct.div_ceil(p.basek); + let rows: Rows = p.k_ct.div_ceil(p.base2k).into(); - let mut ct_ggsw: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ggsw, rows, digits, rank); - let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_glwe, rank); - let pt_rgsw: ScalarZnx> = ScalarZnx::alloc(n, 1); + let ggsw_layout: GGSWCiphertextLayout = GGSWCiphertextLayout { + n, + base2k, + k: k_ggsw, + rows, + digits, + rank, + }; + + let glwe_layout: GLWECiphertextLayout = GLWECiphertextLayout { + n, + base2k, + k: k_glwe, + rank, + }; + + let mut ct_ggsw: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_layout); + let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(&glwe_layout); + let pt_rgsw: ScalarZnx> = ScalarZnx::alloc(n.into(), 1); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, ct_ggsw.k(), rank) - | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_glwe.k()) - | GLWECiphertext::external_product_inplace_scratch_space(&module, basek, ct_glwe.k(), ct_ggsw.k(), digits, rank), + GGSWCiphertext::encrypt_sk_scratch_space(&module, &ggsw_layout) + | GLWECiphertext::encrypt_sk_scratch_space(&module, &glwe_layout) + | GLWECiphertext::external_product_inplace_scratch_space(&module, &glwe_layout, &ggsw_layout), ); - let mut source_xs = Source::new([0u8; 32]); - let mut source_xe = Source::new([0u8; 32]); - let mut source_xa = Source::new([0u8; 32]); + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(&glwe_layout); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_dft: GLWESecretPrepared, FFT64Spqlios> = sk.prepare_alloc(&module, scratch.borrow()); @@ -177,10 +209,10 @@ fn bench_external_product_glwe_inplace_fft64(c: &mut Criterion) { let params_set: Vec = vec![Params { log_n: 12, - basek: 18, - k_ct: 54, - k_ggsw: 54, - rank: 1, + base2k: 18_u32.into(), + k_ct: 54_u32.into(), + k_ggsw: 54_u32.into(), + rank: 1_u32.into(), }]; for params in params_set { diff --git 
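
The benchmark refactor above trades half a dozen positional usize parameters for typed layout structs, so a call like GGSWCiphertext::alloc(&ggsw_layout) can no longer silently swap, say, k and rows. A minimal sketch of the newtype pattern with standalone stand-ins (the real Base2K, TorusPrecision and layout types live in poulpy_core::layouts; these are illustration only):

// Stand-ins for the poulpy_core wrapper types.
#[derive(Clone, Copy, Debug)]
struct Base2K(u32);
#[derive(Clone, Copy, Debug)]
struct TorusPrecision(u32);

impl From<u32> for Base2K {
    fn from(v: u32) -> Self {
        Self(v)
    }
}

impl From<u32> for TorusPrecision {
    fn from(v: u32) -> Self {
        Self(v)
    }
}

struct GLWELayout {
    base2k: Base2K,
    k: TorusPrecision,
}

// With bare usize parameters, swapped arguments still compile; with
// newtypes, mixing up base2k and k is a type error at the call site.
fn limbs(layout: &GLWELayout) -> usize {
    (layout.k.0 as usize).div_ceil(layout.base2k.0 as usize)
}

fn main() {
    let layout = GLWELayout {
        base2k: 18_u32.into(),
        k: 54_u32.into(),
    };
    assert_eq!(limbs(&layout), 3); // 54 bits of precision in 18-bit digits
}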
a/poulpy-core/benches/keyswitch_glwe_fft64.rs b/poulpy-core/benches/keyswitch_glwe_fft64.rs index baa8860..b2e3a3f 100644 --- a/poulpy-core/benches/keyswitch_glwe_fft64.rs +++ b/poulpy-core/benches/keyswitch_glwe_fft64.rs @@ -1,5 +1,6 @@ use poulpy_core::layouts::{ - GGLWEAutomorphismKey, GGLWESwitchingKey, GLWECiphertext, GLWESecret, Infos, + Base2K, Degree, Digits, GGLWEAutomorphismKey, GGLWEAutomorphismKeyLayout, GGLWESwitchingKey, GGLWESwitchingKeyLayout, + GLWECiphertext, GLWECiphertextLayout, GLWESecret, Rank, Rows, TorusPrecision, prepared::{GGLWEAutomorphismKeyPrepared, GGLWESwitchingKeyPrepared, GLWESecretPrepared, PrepareAlloc}, }; use std::{hint::black_box, time::Duration}; @@ -17,59 +18,73 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { struct Params { log_n: usize, - basek: usize, - k_ct_in: usize, - k_ct_out: usize, - k_ksk: usize, - digits: usize, - rank_in: usize, - rank_out: usize, + base2k: Base2K, + k_ct_in: TorusPrecision, + k_ct_out: TorusPrecision, + k_ksk: TorusPrecision, + digits: Digits, + rank: Rank, } fn runner(p: Params) -> impl FnMut() { let module: Module = Module::::new(1 << p.log_n); - let n = module.n(); - let basek: usize = p.basek; - let k_rlwe_in: usize = p.k_ct_in; - let k_rlwe_out: usize = p.k_ct_out; - let k_grlwe: usize = p.k_ksk; - let rank_in: usize = p.rank_in; - let rank_out: usize = p.rank_out; - let digits: usize = p.digits; + let n: Degree = Degree(module.n() as u32); + let base2k: Base2K = p.base2k; + let k_glwe_in: TorusPrecision = p.k_ct_in; + let k_glwe_out: TorusPrecision = p.k_ct_out; + let k_gglwe: TorusPrecision = p.k_ksk; + let rank: Rank = p.rank; + let digits: Digits = p.digits; - let rows: usize = p.k_ct_in.div_ceil(p.basek * digits); + let rows: Rows = p.k_ct_in.div_ceil(p.base2k.0 * digits.0).into(); - let mut ksk: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(n, basek, k_grlwe, rows, digits, rank_out); - let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_rlwe_in, rank_in); - let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_rlwe_out, rank_out); + let gglwe_atk_layout: GGLWEAutomorphismKeyLayout = GGLWEAutomorphismKeyLayout { + n, + base2k, + k: k_gglwe, + rows, + rank, + digits, + }; + + let glwe_in_layout: GLWECiphertextLayout = GLWECiphertextLayout { + n, + base2k, + k: k_glwe_in, + rank, + }; + + let glwe_out_layout: GLWECiphertextLayout = GLWECiphertextLayout { + n, + base2k, + k: k_glwe_out, + rank, + }; + + let mut ksk: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&gglwe_atk_layout); + let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(&glwe_in_layout); + let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(&glwe_out_layout); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ksk.k(), rank_in, rank_out) - | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_in.k()) + GGLWESwitchingKey::encrypt_sk_scratch_space(&module, &gglwe_atk_layout) + | GLWECiphertext::encrypt_sk_scratch_space(&module, &glwe_in_layout) | GLWECiphertext::keyswitch_scratch_space( &module, - basek, - ct_out.k(), - ct_in.k(), - ksk.k(), - digits, - rank_in, - rank_out, + &glwe_out_layout, + &glwe_in_layout, + &gglwe_atk_layout, ), ); - let mut source_xs = Source::new([0u8; 32]); - let mut source_xe = Source::new([0u8; 32]); - let mut source_xa = Source::new([0u8; 32]); + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 
32]); - let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank_in); + let mut sk_in: GLWESecret> = GLWESecret::alloc(&glwe_in_layout); sk_in.fill_ternary_prob(0.5, &mut source_xs); let sk_in_dft: GLWESecretPrepared, FFT64Spqlios> = sk_in.prepare_alloc(&module, scratch.borrow()); - let mut sk_out: GLWESecret> = GLWESecret::alloc(n, rank_out); - sk_out.fill_ternary_prob(0.5, &mut source_xs); - ksk.encrypt_sk( &module, -1, @@ -95,18 +110,17 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { } } - let digits: usize = 1; - let basek: usize = 19; + let base2k: usize = 19; + let digits = 1; let params_set: Vec = vec![Params { log_n: 15, - basek, - k_ct_in: 874 - digits * basek, - k_ct_out: 874 - digits * basek, - k_ksk: 874, - digits, - rank_in: 1, - rank_out: 1, + base2k: base2k.into(), + k_ct_in: (874 - digits * base2k).into(), + k_ct_out: (874 - digits * base2k).into(), + k_ksk: 874_u32.into(), + digits: 1_u32.into(), + rank: 1_u32.into(), }]; for params in params_set { @@ -125,42 +139,59 @@ fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) { struct Params { log_n: usize, - basek: usize, - k_ct: usize, - k_ksk: usize, - rank: usize, + base2k: Base2K, + k_ct: TorusPrecision, + k_ksk: TorusPrecision, + rank: Rank, } fn runner(p: Params) -> impl FnMut() { let module: Module = Module::::new(1 << p.log_n); - let n = module.n(); - let basek: usize = p.basek; - let k_ct: usize = p.k_ct; - let k_ksk: usize = p.k_ksk; - let rank: usize = p.rank; - let digits: usize = 1; + let n: Degree = Degree(module.n() as u32); + let base2k: Base2K = p.base2k; + let k_ct: TorusPrecision = p.k_ct; + let k_ksk: TorusPrecision = p.k_ksk; + let rank: Rank = p.rank; + let digits: Digits = Digits(1); - let rows: usize = p.k_ct.div_ceil(p.basek); + let rows: Rows = p.k_ct.div_ceil(p.base2k).into(); - let mut ksk: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(n, basek, k_ksk, rows, digits, rank, rank); - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct, rank); + let gglwe_layout: GGLWESwitchingKeyLayout = GGLWESwitchingKeyLayout { + n, + base2k, + k: k_ksk, + rows, + digits, + rank_in: rank, + rank_out: rank, + }; + + let glwe_layout: GLWECiphertextLayout = GLWECiphertextLayout { + n, + base2k, + k: k_ct, + rank, + }; + + let mut ksk: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&gglwe_layout); + let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&glwe_layout); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ksk.k(), rank, rank) - | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k()) - | GLWECiphertext::keyswitch_inplace_scratch_space(&module, basek, ct.k(), ksk.k(), digits, rank), + GGLWESwitchingKey::encrypt_sk_scratch_space(&module, &gglwe_layout) + | GLWECiphertext::encrypt_sk_scratch_space(&module, &glwe_layout) + | GLWECiphertext::keyswitch_inplace_scratch_space(&module, &glwe_layout, &gglwe_layout), ); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank); + let mut sk_in: GLWESecret> = GLWESecret::alloc(&glwe_layout); sk_in.fill_ternary_prob(0.5, &mut source_xs); let sk_in_dft: GLWESecretPrepared, FFT64Spqlios> = sk_in.prepare_alloc(&module, scratch.borrow()); - let mut sk_out: GLWESecret> = GLWESecret::alloc(n, rank); + let mut sk_out: GLWESecret> = GLWESecret::alloc(&glwe_layout); sk_out.fill_ternary_prob(0.5, &mut source_xs); ksk.encrypt_sk( @@ 
-190,10 +221,10 @@ fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) { let params_set: Vec = vec![Params { log_n: 9, - basek: 18, - k_ct: 27, - k_ksk: 27, - rank: 1, + base2k: 18_u32.into(), + k_ct: 27_u32.into(), + k_ksk: 27_u32.into(), + rank: 1_u32.into(), }]; for params in params_set { diff --git a/poulpy-core/examples/encryption.rs b/poulpy-core/examples/encryption.rs index 169f6f4..a65b473 100644 --- a/poulpy-core/examples/encryption.rs +++ b/poulpy-core/examples/encryption.rs @@ -2,7 +2,8 @@ use poulpy_backend::cpu_spqlios::FFT64Spqlios; use poulpy_core::{ GLWEOperations, SIGMA, layouts::{ - GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, + Base2K, Degree, GLWECiphertext, GLWECiphertextLayout, GLWEPlaintext, GLWEPlaintextLayout, GLWESecret, LWEInfos, Rank, + TorusPrecision, prepared::{GLWESecretPrepared, PrepareAlloc}, }, }; @@ -16,27 +17,36 @@ fn main() { // Ring degree let log_n: usize = 10; - let n: usize = 1 << log_n; + let n: Degree = Degree(1 << log_n); // Base-2-k (implicit digit decomposition) - let basek: usize = 14; + let base2k: Base2K = Base2K(14); // Ciphertext Torus precision (equivalent to ciphertext modulus) - let k_ct: usize = 27; + let k_ct: TorusPrecision = TorusPrecision(27); // Plaintext Torus precision (equivalent to plaintext modulus) - let k_pt: usize = basek; + let k_pt: TorusPrecision = TorusPrecision(base2k.into()); // GLWE rank - let rank: usize = 1; + let rank: Rank = Rank(1); // Instantiate Module (DFT Tables) - let module: Module = Module::::new(n as u64); + let module: Module = Module::::new(n.0 as u64); + + let glwe_ct_infos: GLWECiphertextLayout = GLWECiphertextLayout { + n, + base2k, + k: k_ct, + rank, + }; + + let glwe_pt_infos: GLWEPlaintextLayout = GLWEPlaintextLayout { n, base2k, k: k_pt }; // Allocates ciphertext & plaintexts - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_pt); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_pt); + let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&glwe_ct_infos); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_pt_infos); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_pt_infos); // CPRNG let mut source_xs: Source = Source::new([0u8; 32]); @@ -45,19 +55,19 @@ fn main() { // Scratch space let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k()) - | GLWECiphertext::decrypt_scratch_space(&module, basek, ct.k()), + GLWECiphertext::encrypt_sk_scratch_space(&module, &glwe_ct_infos) + | GLWECiphertext::decrypt_scratch_space(&module, &glwe_ct_infos), ); // Generate secret-key - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(&glwe_ct_infos); sk.fill_ternary_prob(0.5, &mut source_xs); // Backend-prepared secret let sk_prepared: GLWESecretPrepared, FFT64Spqlios> = sk.prepare_alloc(&module, scratch.borrow()); // Uniform plaintext - module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, &mut source_xa); + module.vec_znx_fill_uniform(base2k.into(), &mut pt_want.data, 0, &mut source_xa); // Encryption ct.encrypt_sk( @@ -76,7 +86,7 @@ fn main() { pt_want.sub_inplace_ab(&module, &pt_have); // Ideal vs. 
actual noise - let noise_have: f64 = pt_want.data.std(basek, 0) * (ct.k() as f64).exp2(); + let noise_have: f64 = pt_want.data.std(base2k.into(), 0) * (ct.k().as_u32() as f64).exp2(); let noise_want: f64 = SIGMA; // Check diff --git a/poulpy-core/src/automorphism/gglwe_atk.rs b/poulpy-core/src/automorphism/gglwe_atk.rs index 30ae17b..72ee643 100644 --- a/poulpy-core/src/automorphism/gglwe_atk.rs +++ b/poulpy-core/src/automorphism/gglwe_atk.rs @@ -1,43 +1,42 @@ use poulpy_hal::{ api::{ - ScratchAvailable, TakeVecZnxDft, VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, + ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, - VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero}, }; -use crate::layouts::{GGLWEAutomorphismKey, GLWECiphertext, Infos, prepared::GGLWEAutomorphismKeyPrepared}; +use crate::layouts::{GGLWEAutomorphismKey, GGLWELayoutInfos, GLWECiphertext, prepared::GGLWEAutomorphismKeyPrepared}; impl GGLWEAutomorphismKey> { - #[allow(clippy::too_many_arguments)] - pub fn automorphism_scratch_space( + pub fn automorphism_scratch_space( module: &Module, - basek: usize, - k_out: usize, - k_in: usize, - k_ksk: usize, - digits: usize, - rank: usize, + out_infos: &OUT, + in_infos: &IN, + key_infos: &KEY, ) -> usize where - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes, + OUT: GGLWELayoutInfos, + IN: GGLWELayoutInfos, + KEY: GGLWELayoutInfos, + Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { - GLWECiphertext::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank, rank) + GLWECiphertext::keyswitch_scratch_space( + module, + &out_infos.glwe_layout(), + &in_infos.glwe_layout(), + key_infos, + ) } - pub fn automorphism_inplace_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_ksk: usize, - digits: usize, - rank: usize, - ) -> usize + pub fn automorphism_inplace_scratch_space(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize where - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes, + OUT: GGLWELayoutInfos, + KEY: GGLWELayoutInfos, + Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { - GGLWEAutomorphismKey::automorphism_scratch_space(module, basek, k_out, k_out, k_ksk, digits, rank) + GGLWEAutomorphismKey::automorphism_scratch_space(module, out_infos, out_infos, key_infos) } } @@ -59,11 +58,15 @@ impl GGLWEAutomorphismKey { + VecZnxBigAddSmallInplace + VecZnxBigNormalize + VecZnxAutomorphism - + VecZnxAutomorphismInplace, - Scratch: ScratchAvailable + TakeVecZnxDft, + + VecZnxAutomorphismInplace + + VecZnxNormalize + + VecZnxNormalizeTmpBytes, + Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnx, { #[cfg(debug_assertions)] { + use crate::layouts::LWEInfos; + assert_eq!( self.rank_in(), lhs.rank_in(), @@ -93,13 +96,13 @@ impl GGLWEAutomorphismKey { ) } - let cols_out: usize = rhs.rank_out() + 1; + let cols_out: usize = (rhs.rank_out() + 1).into(); let p: i64 = lhs.p(); - let p_inv = module.galois_element_inv(p); + let p_inv: i64 = module.galois_element_inv(p); - 
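
On the noise check in the README and encryption example above: pt_want ends up holding m - decrypt(ct), so its standard deviation measured on the torus, rescaled by 2^k, should land near the constant SIGMA used as noise_want. A small sketch of that unit conversion, under the assumption that std(base2k, col) returns the deviation of the decoded column as a real value in [-1/2, 1/2):

// Converts a torus-scale standard deviation into units of the ciphertext
// modulus q = 2^k, the scale on which the Gaussian parameter is defined.
fn noise_in_q_units(std_torus: f64, k: u32) -> f64 {
    std_torus * (k as f64).exp2()
}

fn main() {
    // Example values: a torus-side deviation of 3.2 / 2^27 for k = 27
    // maps back to 3.2 in modulus units.
    let std_torus = 3.2 / (27f64).exp2();
    assert!((noise_in_q_units(std_torus, 27) - 3.2).abs() < 1e-9);
}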
(0..self.rank_in()).for_each(|col_i| { - (0..self.rows()).for_each(|row_j| { + (0..self.rank_in().into()).for_each(|col_i| { + (0..self.rows().into()).for_each(|row_j| { let mut res_ct: GLWECiphertext<&mut [u8]> = self.at_mut(row_j, col_i); let lhs_ct: GLWECiphertext<&[u8]> = lhs.at(row_j, col_i); @@ -118,8 +121,8 @@ impl GGLWEAutomorphismKey { }); }); - (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| { - (0..self.rank_in()).for_each(|col_j| { + (self.rows().min(lhs.rows()).into()..self.rows().into()).for_each(|row_i| { + (0..self.rank_in().into()).for_each(|col_j| { self.at_mut(row_i, col_j).data.zero(); }); }); @@ -143,8 +146,10 @@ impl GGLWEAutomorphismKey { + VecZnxBigAddSmallInplace + VecZnxBigNormalize + VecZnxAutomorphism - + VecZnxAutomorphismInplace, - Scratch: ScratchAvailable + TakeVecZnxDft, + + VecZnxAutomorphismInplace + + VecZnxNormalize + + VecZnxNormalizeTmpBytes, + Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnx, { #[cfg(debug_assertions)] { @@ -164,13 +169,13 @@ impl GGLWEAutomorphismKey { ); } - let cols_out: usize = rhs.rank_out() + 1; + let cols_out: usize = (rhs.rank_out() + 1).into(); let p: i64 = self.p(); let p_inv = module.galois_element_inv(p); - (0..self.rank_in()).for_each(|col_i| { - (0..self.rows()).for_each(|row_j| { + (0..self.rank_in().into()).for_each(|col_i| { + (0..self.rows().into()).for_each(|row_j| { let mut res_ct: GLWECiphertext<&mut [u8]> = self.at_mut(row_j, col_i); // Reverts the automorphism X^{-k}: (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a) diff --git a/poulpy-core/src/automorphism/ggsw_ct.rs b/poulpy-core/src/automorphism/ggsw_ct.rs index d23bc7b..8a7cb54 100644 --- a/poulpy-core/src/automorphism/ggsw_ct.rs +++ b/poulpy-core/src/automorphism/ggsw_ct.rs @@ -1,67 +1,66 @@ use poulpy_hal::{ api::{ - ScratchAvailable, TakeVecZnxBig, TakeVecZnxDft, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, - VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy, - VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, - VmpApplyDftToDftTmpBytes, + ScratchAvailable, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, + VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAddInplace, VecZnxDftAllocBytes, + VecZnxDftApply, VecZnxDftCopy, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeTmpBytes, + VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataRef, Module, Scratch}, }; use crate::layouts::{ - GGSWCiphertext, GLWECiphertext, Infos, + GGLWELayoutInfos, GGSWCiphertext, GGSWInfos, GLWECiphertext, prepared::{GGLWEAutomorphismKeyPrepared, GGLWETensorKeyPrepared}, }; impl GGSWCiphertext> { - #[allow(clippy::too_many_arguments)] - pub fn automorphism_scratch_space( + pub fn automorphism_scratch_space( module: &Module, - basek: usize, - k_out: usize, - k_in: usize, - k_ksk: usize, - digits_ksk: usize, - k_tsk: usize, - digits_tsk: usize, - rank: usize, + out_infos: &OUT, + in_infos: &IN, + key_infos: &KEY, + tsk_infos: &TSK, ) -> usize where + OUT: GGSWInfos, + IN: GGSWInfos, + KEY: GGLWELayoutInfos, + TSK: GGLWELayoutInfos, Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigAllocBytes + VecZnxNormalizeTmpBytes + VecZnxBigNormalizeTmpBytes, { - let out_size: usize = k_out.div_ceil(basek); - let ci_dft: usize = 
module.vec_znx_dft_alloc_bytes(rank + 1, out_size); - let ks_internal: usize = - GLWECiphertext::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits_ksk, rank, rank); - let expand: usize = GGSWCiphertext::expand_row_scratch_space(module, basek, k_out, k_tsk, digits_tsk, rank); + let out_size: usize = out_infos.size(); + let ci_dft: usize = module.vec_znx_dft_alloc_bytes((key_infos.rank_out() + 1).into(), out_size); + let ks_internal: usize = GLWECiphertext::keyswitch_scratch_space( + module, + &out_infos.glwe_layout(), + &in_infos.glwe_layout(), + key_infos, + ); + let expand: usize = GGSWCiphertext::expand_row_scratch_space(module, out_infos, tsk_infos); ci_dft + (ks_internal | expand) } - #[allow(clippy::too_many_arguments)] - pub fn automorphism_inplace_scratch_space( + pub fn automorphism_inplace_scratch_space( module: &Module, - basek: usize, - k_out: usize, - k_ksk: usize, - digits_ksk: usize, - k_tsk: usize, - digits_tsk: usize, - rank: usize, + out_infos: &OUT, + key_infos: &KEY, + tsk_infos: &TSK, ) -> usize where + OUT: GGSWInfos, + KEY: GGLWELayoutInfos, + TSK: GGLWELayoutInfos, Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigAllocBytes + VecZnxNormalizeTmpBytes + VecZnxBigNormalizeTmpBytes, { - GGSWCiphertext::automorphism_scratch_space( - module, basek, k_out, k_out, k_ksk, digits_ksk, k_tsk, digits_tsk, rank, - ) + GGSWCiphertext::automorphism_scratch_space(module, out_infos, out_infos, key_infos, tsk_infos) } } @@ -88,13 +87,18 @@ impl GGSWCiphertext { + VecZnxNormalizeTmpBytes + VecZnxDftCopy + VecZnxDftAddInplace - + VecZnxIdftApplyTmpA, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnxBig, + + VecZnxIdftApplyTmpA + + VecZnxNormalize, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnxBig + TakeVecZnx, { #[cfg(debug_assertions)] { - assert_eq!(self.n(), auto_key.n()); - assert_eq!(lhs.n(), auto_key.n()); + use crate::layouts::{GLWEInfos, LWEInfos}; + + assert_eq!(self.n(), module.n() as u32); + assert_eq!(lhs.n(), module.n() as u32); + assert_eq!(auto_key.n(), module.n() as u32); + assert_eq!(tensor_key.n(), module.n() as u32); assert_eq!( self.rank(), @@ -105,36 +109,23 @@ impl GGSWCiphertext { ); assert_eq!( self.rank(), - auto_key.rank(), + auto_key.rank_out(), "ggsw_in rank: {} != auto_key rank: {}", self.rank(), - auto_key.rank() + auto_key.rank_out() ); assert_eq!( self.rank(), - tensor_key.rank(), + tensor_key.rank_out(), "ggsw_in rank: {} != tensor_key rank: {}", self.rank(), - tensor_key.rank() + tensor_key.rank_out() ); - assert!( - scratch.available() - >= GGSWCiphertext::automorphism_scratch_space( - module, - self.basek(), - self.k(), - lhs.k(), - auto_key.k(), - auto_key.digits(), - tensor_key.k(), - tensor_key.digits(), - self.rank(), - ) - ) + assert!(scratch.available() >= GGSWCiphertext::automorphism_scratch_space(module, self, lhs, auto_key, tensor_key)) }; // Keyswitch the j-th row of the col 0 - (0..lhs.rows()).for_each(|row_i| { + (0..lhs.rows().into()).for_each(|row_i| { // Key-switch column 0, i.e. 
// col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2) self.at_mut(row_i, 0) @@ -164,11 +155,12 @@ impl GGSWCiphertext { + VecZnxNormalizeTmpBytes + VecZnxDftCopy + VecZnxDftAddInplace - + VecZnxIdftApplyTmpA, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnxBig, + + VecZnxIdftApplyTmpA + + VecZnxNormalize, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnxBig + TakeVecZnx, { // Keyswitch the j-th row of the col 0 - (0..self.rows()).for_each(|row_i| { + (0..self.rows().into()).for_each(|row_i| { // Key-switch column 0, i.e. // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2) self.at_mut(row_i, 0) diff --git a/poulpy-core/src/automorphism/glwe_ct.rs b/poulpy-core/src/automorphism/glwe_ct.rs index 1d8077a..f58d46e 100644 --- a/poulpy-core/src/automorphism/glwe_ct.rs +++ b/poulpy-core/src/automorphism/glwe_ct.rs @@ -1,44 +1,38 @@ use poulpy_hal::{ api::{ - ScratchAvailable, TakeVecZnxDft, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigAutomorphismInplace, - VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallAInplace, VecZnxBigSubSmallBInplace, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VmpApplyDftToDft, VmpApplyDftToDftAdd, - VmpApplyDftToDftTmpBytes, + ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, + VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallInplace, + VecZnxBigSubSmallNegateInplace, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalize, + VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnxBig}, }; -use crate::layouts::{GLWECiphertext, Infos, prepared::GGLWEAutomorphismKeyPrepared}; +use crate::layouts::{GGLWELayoutInfos, GLWECiphertext, GLWEInfos, LWEInfos, prepared::GGLWEAutomorphismKeyPrepared}; impl GLWECiphertext> { - #[allow(clippy::too_many_arguments)] - pub fn automorphism_scratch_space( + pub fn automorphism_scratch_space( module: &Module, - basek: usize, - k_out: usize, - k_in: usize, - k_ksk: usize, - digits: usize, - rank: usize, + out_infos: &OUT, + in_infos: &IN, + key_infos: &KEY, ) -> usize where - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes, + OUT: GLWEInfos, + IN: GLWEInfos, + KEY: GGLWELayoutInfos, + Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { - Self::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank, rank) + Self::keyswitch_scratch_space(module, out_infos, in_infos, key_infos) } - pub fn automorphism_inplace_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_ksk: usize, - digits: usize, - rank: usize, - ) -> usize + pub fn automorphism_inplace_scratch_space(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize where - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes, + OUT: GLWEInfos, + KEY: GGLWELayoutInfos, + Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { - Self::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank) + Self::keyswitch_inplace_scratch_space(module, out_infos, key_infos) } } @@ -59,11 +53,13 @@ impl GLWECiphertext { + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace 
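// `vec_znx_automorphism_inplace` applies the Galois map pi_p: X -> X^p on
// Z[X]/(X^n + 1) to each column. A coefficient-level sketch for odd p
// (Galois elements are odd; `apply_galois` is a hypothetical free function,
// not the poulpy API):
fn apply_galois(p: i64, coeffs: &[i64]) -> Vec<i64> {
    let n = coeffs.len() as i64;
    let two_n = 2 * n;
    let mut out = vec![0i64; n as usize];
    for (i, &c) in coeffs.iter().enumerate() {
        // X^i maps to X^(i*p mod 2n); exponents in [n, 2n) wrap back with a
        // sign flip because X^n = -1 in the negacyclic ring.
        let e = ((i as i64 * p) % two_n + two_n) % two_n;
        if e < n {
            out[e as usize] = c;
        } else {
            out[(e - n) as usize] = -c;
        }
    }
    out
}
// For odd p the index map i -> i*p mod 2n is a bijection, so plain `=` is
// enough, and p composed with galois_element_inv(p) is the identity.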
+ VecZnxBigNormalize - + VecZnxAutomorphismInplace, - Scratch: TakeVecZnxDft + ScratchAvailable, + + VecZnxAutomorphismInplace + + VecZnxNormalize + + VecZnxNormalizeTmpBytes, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { self.keyswitch(module, lhs, &rhs.key, scratch); - (0..self.rank() + 1).for_each(|i| { + (0..(self.rank() + 1).into()).for_each(|i| { module.vec_znx_automorphism_inplace(rhs.p(), &mut self.data, i, scratch); }) } @@ -83,11 +79,13 @@ impl GLWECiphertext { + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize - + VecZnxAutomorphismInplace, - Scratch: TakeVecZnxDft + ScratchAvailable, + + VecZnxAutomorphismInplace + + VecZnxNormalize + + VecZnxNormalizeTmpBytes, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { self.keyswitch_inplace(module, &rhs.key, scratch); - (0..self.rank() + 1).for_each(|i| { + (0..(self.rank() + 1).into()).for_each(|i| { module.vec_znx_automorphism_inplace(rhs.p(), &mut self.data, i, scratch); }) } @@ -108,19 +106,29 @@ impl GLWECiphertext { + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize - + VecZnxBigAutomorphismInplace, - Scratch: TakeVecZnxDft + ScratchAvailable, + + VecZnxBigAutomorphismInplace + + VecZnxNormalizeTmpBytes + + VecZnxNormalize, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { #[cfg(debug_assertions)] { self.assert_keyswitch(module, lhs, &rhs.key, scratch); } - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // TODO: optimise size + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); // TODO: optimise size let mut res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, &rhs.key, scratch_1); - (0..self.cols()).for_each(|i| { + (0..(self.rank() + 1).into()).for_each(|i| { module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1); module.vec_znx_big_add_small_inplace(&mut res_big, i, &lhs.data, i); - module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch_1); + module.vec_znx_big_normalize( + self.base2k().into(), + &mut self.data, + i, + rhs.base2k().into(), + &res_big, + i, + scratch_1, + ); }) } @@ -139,19 +147,29 @@ impl GLWECiphertext { + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize - + VecZnxBigAutomorphismInplace, - Scratch: TakeVecZnxDft + ScratchAvailable, + + VecZnxBigAutomorphismInplace + + VecZnxNormalizeTmpBytes + + VecZnxNormalize, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { #[cfg(debug_assertions)] { self.assert_keyswitch_inplace(module, &rhs.key, scratch); } - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // TODO: optimise size + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); // TODO: optimise size let mut res_big: VecZnxBig<_, B> = self.keyswitch_internal(module, res_dft, &rhs.key, scratch_1); - (0..self.cols()).for_each(|i| { + (0..(self.rank() + 1).into()).for_each(|i| { module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1); module.vec_znx_big_add_small_inplace(&mut res_big, i, &self.data, i); - module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch_1); + module.vec_znx_big_normalize( + self.base2k().into(), + &mut self.data, + i, + rhs.base2k().into(), + &res_big, + i, + scratch_1, + ); }) } @@ -172,23 +190,33 @@ impl GLWECiphertext { + VecZnxBigAddSmallInplace + VecZnxBigNormalize + 
VecZnxBigAutomorphismInplace - + VecZnxBigSubSmallAInplace, - Scratch: TakeVecZnxDft + ScratchAvailable, + + VecZnxBigSubSmallInplace + + VecZnxNormalizeTmpBytes + + VecZnxNormalize, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { #[cfg(debug_assertions)] { self.assert_keyswitch(module, lhs, &rhs.key, scratch); } - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // TODO: optimise size + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); // TODO: optimise size let mut res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, &rhs.key, scratch_1); - (0..self.cols()).for_each(|i| { + (0..(self.rank() + 1).into()).for_each(|i| { module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1); - module.vec_znx_big_sub_small_a_inplace(&mut res_big, i, &lhs.data, i); - module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch_1); + module.vec_znx_big_sub_small_inplace(&mut res_big, i, &lhs.data, i); + module.vec_znx_big_normalize( + self.base2k().into(), + &mut self.data, + i, + rhs.base2k().into(), + &res_big, + i, + scratch_1, + ); }) } - pub fn automorphism_sub_ab_inplace( + pub fn automorphism_sub_inplace( &mut self, module: &Module, rhs: &GGLWEAutomorphismKeyPrepared, @@ -204,23 +232,33 @@ impl GLWECiphertext { + VecZnxBigAddSmallInplace + VecZnxBigNormalize + VecZnxBigAutomorphismInplace - + VecZnxBigSubSmallAInplace, - Scratch: TakeVecZnxDft + ScratchAvailable, + + VecZnxBigSubSmallInplace + + VecZnxNormalizeTmpBytes + + VecZnxNormalize, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { #[cfg(debug_assertions)] { self.assert_keyswitch_inplace(module, &rhs.key, scratch); } - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // TODO: optimise size + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); // TODO: optimise size let mut res_big: VecZnxBig<_, B> = self.keyswitch_internal(module, res_dft, &rhs.key, scratch_1); - (0..self.cols()).for_each(|i| { + (0..(self.rank() + 1).into()).for_each(|i| { module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1); - module.vec_znx_big_sub_small_a_inplace(&mut res_big, i, &self.data, i); - module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch_1); + module.vec_znx_big_sub_small_inplace(&mut res_big, i, &self.data, i); + module.vec_znx_big_normalize( + self.base2k().into(), + &mut self.data, + i, + rhs.base2k().into(), + &res_big, + i, + scratch_1, + ); }) } - pub fn automorphism_sub_ba( + pub fn automorphism_sub_negate( &mut self, module: &Module, lhs: &GLWECiphertext, @@ -237,23 +275,33 @@ impl GLWECiphertext { + VecZnxBigAddSmallInplace + VecZnxBigNormalize + VecZnxBigAutomorphismInplace - + VecZnxBigSubSmallBInplace, - Scratch: TakeVecZnxDft + ScratchAvailable, + + VecZnxBigSubSmallNegateInplace + + VecZnxNormalizeTmpBytes + + VecZnxNormalize, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { #[cfg(debug_assertions)] { self.assert_keyswitch(module, lhs, &rhs.key, scratch); } - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // TODO: optimise size + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); // TODO: optimise size let mut res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, &rhs.key, scratch_1); - 
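// The renamed subtraction flavours differ only in operand order:
// `vec_znx_big_sub_small_inplace` (old sub_small_a) computes
// res <- res - small, while `vec_znx_big_sub_small_negate_inplace`
// (old sub_small_b) computes res <- small - res. Limb-level sketch with
// hypothetical helpers, not the poulpy API:
fn sub_small_inplace(res: &mut [i64], small: &[i64]) {
    for (r, s) in res.iter_mut().zip(small.iter()) {
        *r -= *s; // res <- res - small
    }
}
fn sub_small_negate_inplace(res: &mut [i64], small: &[i64]) {
    for (r, s) in res.iter_mut().zip(small.iter()) {
        *r = *s - *r; // res <- small - res
    }
}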
(0..self.cols()).for_each(|i| { + (0..(self.rank() + 1).into()).for_each(|i| { module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1); - module.vec_znx_big_sub_small_b_inplace(&mut res_big, i, &lhs.data, i); - module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch_1); + module.vec_znx_big_sub_small_negate_inplace(&mut res_big, i, &lhs.data, i); + module.vec_znx_big_normalize( + self.base2k().into(), + &mut self.data, + i, + rhs.base2k().into(), + &res_big, + i, + scratch_1, + ); }) } - pub fn automorphism_sub_ba_inplace( + pub fn automorphism_sub_negate_inplace( &mut self, module: &Module, rhs: &GGLWEAutomorphismKeyPrepared, @@ -269,19 +317,29 @@ impl GLWECiphertext { + VecZnxBigAddSmallInplace + VecZnxBigNormalize + VecZnxBigAutomorphismInplace - + VecZnxBigSubSmallBInplace, - Scratch: TakeVecZnxDft + ScratchAvailable, + + VecZnxBigSubSmallNegateInplace + + VecZnxNormalizeTmpBytes + + VecZnxNormalize, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { #[cfg(debug_assertions)] { self.assert_keyswitch_inplace(module, &rhs.key, scratch); } - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // TODO: optimise size + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); // TODO: optimise size let mut res_big: VecZnxBig<_, B> = self.keyswitch_internal(module, res_dft, &rhs.key, scratch_1); - (0..self.cols()).for_each(|i| { + (0..(self.rank() + 1).into()).for_each(|i| { module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1); - module.vec_znx_big_sub_small_b_inplace(&mut res_big, i, &self.data, i); - module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch_1); + module.vec_znx_big_sub_small_negate_inplace(&mut res_big, i, &self.data, i); + module.vec_znx_big_normalize( + self.base2k().into(), + &mut self.data, + i, + rhs.base2k().into(), + &res_big, + i, + scratch_1, + ); }) } } diff --git a/poulpy-core/src/conversion/glwe_to_lwe.rs b/poulpy-core/src/conversion/glwe_to_lwe.rs index 037eb6f..c38a155 100644 --- a/poulpy-core/src/conversion/glwe_to_lwe.rs +++ b/poulpy-core/src/conversion/glwe_to_lwe.rs @@ -1,31 +1,46 @@ use poulpy_hal::{ api::{ - ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VmpApplyDftToDft, VmpApplyDftToDftAdd, - VmpApplyDftToDftTmpBytes, + ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, + VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, + VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut, ZnxZero}, }; use crate::{ TakeGLWECt, - layouts::{GLWECiphertext, Infos, LWECiphertext, prepared::GLWEToLWESwitchingKeyPrepared}, + layouts::{ + GGLWELayoutInfos, GLWECiphertext, GLWECiphertextLayout, GLWEInfos, LWECiphertext, LWEInfos, Rank, + prepared::GLWEToLWESwitchingKeyPrepared, + }, }; impl LWECiphertext> { - pub fn from_glwe_scratch_space( + pub fn from_glwe_scratch_space( module: &Module, - basek: usize, - k_lwe: usize, - k_glwe: usize, - k_ksk: usize, - rank: usize, + lwe_infos: &OUT, + glwe_infos: &IN, + key_infos: &KEY, ) -> usize where - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes, + OUT: LWEInfos, + IN: 
GLWEInfos, + KEY: GGLWELayoutInfos, + Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { - GLWECiphertext::bytes_of(module.n(), basek, k_lwe, 1) - + GLWECiphertext::keyswitch_scratch_space(module, basek, k_lwe, k_glwe, k_ksk, 1, rank, 1) + let glwe_layout: GLWECiphertextLayout = GLWECiphertextLayout { + n: module.n().into(), + base2k: lwe_infos.base2k(), + k: lwe_infos.k(), + rank: Rank(1), + }; + + GLWECiphertext::alloc_bytes_with( + module.n().into(), + lwe_infos.base2k(), + lwe_infos.k(), + 1u32.into(), + ) + GLWECiphertext::keyswitch_scratch_space(module, &glwe_layout, glwe_infos, key_infos) } } @@ -34,10 +49,11 @@ impl LWECiphertext { #[cfg(debug_assertions)] { assert!(self.n() <= a.n()); + assert!(self.base2k() == a.base2k()); } let min_size: usize = self.size().min(a.size()); - let n: usize = self.n(); + let n: usize = self.n().into(); self.data.zero(); (0..min_size).for_each(|i| { @@ -64,15 +80,26 @@ impl LWECiphertext { + VecZnxDftApply + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace - + VecZnxBigNormalize, - Scratch: ScratchAvailable + TakeVecZnxDft + TakeGLWECt, + + VecZnxBigNormalize + + VecZnxNormalize + + VecZnxNormalizeTmpBytes, + Scratch: ScratchAvailable + TakeVecZnxDft + TakeGLWECt + TakeVecZnx, { #[cfg(debug_assertions)] { - assert_eq!(self.basek(), a.basek()); - assert_eq!(a.n(), ks.n()); + assert_eq!(a.n(), module.n() as u32); + assert_eq!(ks.n(), module.n() as u32); + assert!(self.n() <= module.n() as u32); } - let (mut tmp_glwe, scratch_1) = scratch.take_glwe_ct(a.n(), a.basek(), self.k(), 1); + + let glwe_layout: GLWECiphertextLayout = GLWECiphertextLayout { + n: module.n().into(), + base2k: self.base2k(), + k: self.k(), + rank: Rank(1), + }; + + let (mut tmp_glwe, scratch_1) = scratch.take_glwe_ct(&glwe_layout); tmp_glwe.keyswitch(module, a, &ks.0, scratch_1); self.sample_extract(&tmp_glwe); } diff --git a/poulpy-core/src/conversion/lwe_to_glwe.rs b/poulpy-core/src/conversion/lwe_to_glwe.rs index 538256b..a94c574 100644 --- a/poulpy-core/src/conversion/lwe_to_glwe.rs +++ b/poulpy-core/src/conversion/lwe_to_glwe.rs @@ -1,31 +1,46 @@ use poulpy_hal::{ api::{ - ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VmpApplyDftToDft, VmpApplyDftToDftAdd, - VmpApplyDftToDftTmpBytes, + ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, + VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, + VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, - layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut, ZnxZero}, + layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx, ZnxView, ZnxViewMut, ZnxZero}, }; use crate::{ TakeGLWECt, - layouts::{GLWECiphertext, Infos, LWECiphertext, prepared::LWEToGLWESwitchingKeyPrepared}, + layouts::{ + GGLWELayoutInfos, GLWECiphertext, GLWECiphertextLayout, GLWEInfos, LWECiphertext, LWEInfos, + prepared::LWEToGLWESwitchingKeyPrepared, + }, }; impl GLWECiphertext> { - pub fn from_lwe_scratch_space( + pub fn from_lwe_scratch_space( module: &Module, - basek: usize, - k_lwe: usize, - k_glwe: usize, - k_ksk: usize, - rank: usize, + glwe_infos: &OUT, + lwe_infos: &IN, + key_infos: &KEY, ) -> usize where - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes, + OUT: GLWEInfos, + IN: 
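// `sample_extract` reads an LWE ciphertext out of a GLWE ciphertext: under
// the convention phase = b(X) + a(X)*s(X) over Z[X]/(X^n + 1), the constant
// coefficient of a(X)*s(X) is a_0*s_0 - sum_{j>=1} a_{n-j}*s_j, so the
// extracted LWE mask is the reversed, negated tail of a. One-limb, rank-1
// sketch (hypothetical helper; the code above works limb by limb):
fn sample_extract_constant(b: &[i64], a: &[i64]) -> (i64, Vec<i64>) {
    let n = a.len();
    let mut a_lwe = vec![0i64; n];
    a_lwe[0] = a[0];
    for j in 1..n {
        a_lwe[j] = -a[n - j]; // negacyclic wrap: X^n = -1
    }
    (b[0], a_lwe) // LWE phase b[0] + <a_lwe, s> equals the GLWE phase at X^0
}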
LWEInfos, + KEY: GGLWELayoutInfos, + Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { - GLWECiphertext::keyswitch_scratch_space(module, basek, k_glwe, k_lwe, k_ksk, 1, 1, rank) - + GLWECiphertext::bytes_of(module.n(), basek, k_lwe, 1) + let ct: usize = GLWECiphertext::alloc_bytes_with( + module.n().into(), + key_infos.base2k(), + lwe_infos.k().max(glwe_infos.k()), + 1u32.into(), + ); + let ks: usize = GLWECiphertext::keyswitch_inplace_scratch_space(module, glwe_infos, key_infos); + if lwe_infos.base2k() == key_infos.base2k() { + ct + ks + } else { + let a_conv = VecZnx::alloc_bytes(module.n(), 1, lwe_infos.size()) + module.vec_znx_normalize_tmp_bytes(); + ct + a_conv + ks + } } } @@ -47,25 +62,68 @@ impl GLWECiphertext { + VecZnxDftApply + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace - + VecZnxBigNormalize, - Scratch: ScratchAvailable + TakeVecZnxDft + TakeGLWECt, + + VecZnxBigNormalize + + VecZnxNormalize + + VecZnxNormalizeTmpBytes, + Scratch: ScratchAvailable + TakeVecZnxDft + TakeGLWECt + TakeVecZnx, { #[cfg(debug_assertions)] { - assert!(lwe.n() <= self.n()); - assert_eq!(self.basek(), self.basek()); + assert_eq!(self.n(), module.n() as u32); + assert_eq!(ksk.n(), module.n() as u32); + assert!(lwe.n() <= module.n() as u32); } - let (mut glwe, scratch_1) = scratch.take_glwe_ct(ksk.n(), lwe.basek(), lwe.k(), 1); + let (mut glwe, scratch_1) = scratch.take_glwe_ct(&GLWECiphertextLayout { + n: ksk.n(), + base2k: ksk.base2k(), + k: lwe.k(), + rank: 1u32.into(), + }); glwe.data.zero(); - let n_lwe: usize = lwe.n(); + let n_lwe: usize = lwe.n().into(); - (0..lwe.size()).for_each(|i| { - let data_lwe: &[i64] = lwe.data.at(0, i); - glwe.data.at_mut(0, i)[0] = data_lwe[0]; - glwe.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]); - }); + if lwe.base2k() == ksk.base2k() { + for i in 0..lwe.size() { + let data_lwe: &[i64] = lwe.data.at(0, i); + glwe.data.at_mut(0, i)[0] = data_lwe[0]; + glwe.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]); + } + } else { + let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(module.n(), 1, lwe.size()); + a_conv.zero(); + for j in 0..lwe.size() { + let data_lwe: &[i64] = lwe.data.at(0, j); + a_conv.at_mut(0, j)[0] = data_lwe[0] + } + + module.vec_znx_normalize( + ksk.base2k().into(), + &mut glwe.data, + 0, + lwe.base2k().into(), + &a_conv, + 0, + scratch_2, + ); + + a_conv.zero(); + for j in 0..lwe.size() { + let data_lwe: &[i64] = lwe.data.at(0, j); + a_conv.at_mut(0, j)[..n_lwe].copy_from_slice(&data_lwe[1..]); + } + + module.vec_znx_normalize( + ksk.base2k().into(), + &mut glwe.data, + 1, + lwe.base2k().into(), + &a_conv, + 0, + scratch_2, + ); + } self.keyswitch(module, &glwe, &ksk.0, scratch_1); } diff --git a/poulpy-core/src/decryption/glwe_ct.rs b/poulpy-core/src/decryption/glwe_ct.rs index c69fc56..19b4a82 100644 --- a/poulpy-core/src/decryption/glwe_ct.rs +++ b/poulpy-core/src/decryption/glwe_ct.rs @@ -6,14 +6,15 @@ use poulpy_hal::{ layouts::{Backend, DataMut, DataRef, DataViewMut, Module, Scratch}, }; -use crate::layouts::{GLWECiphertext, GLWEPlaintext, Infos, prepared::GLWESecretPrepared}; +use crate::layouts::{GLWECiphertext, GLWEInfos, GLWEPlaintext, LWEInfos, prepared::GLWESecretPrepared}; impl GLWECiphertext> { - pub fn decrypt_scratch_space(module: &Module, basek: usize, k: usize) -> usize + pub fn decrypt_scratch_space(module: &Module, infos: &A) -> usize where + A: GLWEInfos, Module: VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes, { - let 
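// The base2k bookkeeping above matters because a VecZnx coefficient is held
// as a sequence of signed base-2^k digits: the same value has different limb
// decompositions under different k, so a mismatch needs a recompose /
// re-decompose step (what the vec_znx_normalize calls perform here). Toy
// version on a single integer coefficient (hypothetical helpers, overflow
// ignored):
fn decompose(mut x: i128, k: u32, limbs: usize) -> Vec<i64> {
    let mut out = vec![0i64; limbs];
    for i in (0..limbs).rev() {
        let mut d = x & ((1i128 << k) - 1);
        if d >= 1i128 << (k - 1) {
            d -= 1i128 << k; // centre the digit in [-2^(k-1), 2^(k-1))
        }
        out[i] = d as i64; // out[0] ends up most significant
        x = (x - d) >> k;
    }
    out
}
fn recompose(limbs: &[i64], k: u32) -> i128 {
    limbs.iter().fold(0i128, |acc, &d| (acc << k) + d as i128)
}
// Changing base from k1 to k2: decompose(recompose(&digits, k1), k2, new_len).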
size: usize = k.div_ceil(basek); + let size: usize = infos.size(); (module.vec_znx_normalize_tmp_bytes() | module.vec_znx_dft_alloc_bytes(1, size)) + module.vec_znx_dft_alloc_bytes(1, size) } } @@ -41,15 +42,15 @@ impl GLWECiphertext { assert_eq!(pt.n(), sk.n()); } - let cols: usize = self.rank() + 1; + let cols: usize = (self.rank() + 1).into(); - let (mut c0_big, scratch_1) = scratch.take_vec_znx_big(self.n(), 1, self.size()); // TODO optimize size when pt << ct + let (mut c0_big, scratch_1) = scratch.take_vec_znx_big(self.n().into(), 1, self.size()); // TODO optimize size when pt << ct c0_big.data_mut().fill(0); { (1..cols).for_each(|i| { // ci_dft = DFT(a[i]) * DFT(s[i]) - let (mut ci_dft, _) = scratch_1.take_vec_znx_dft(self.n(), 1, self.size()); // TODO optimize size when pt << ct + let (mut ci_dft, _) = scratch_1.take_vec_znx_dft(self.n().into(), 1, self.size()); // TODO optimize size when pt << ct module.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, &self.data, i); module.svp_apply_dft_to_dft_inplace(&mut ci_dft, 0, &sk.data, i - 1); let ci_big = module.vec_znx_idft_apply_consume(ci_dft); @@ -63,9 +64,17 @@ impl GLWECiphertext { module.vec_znx_big_add_small_inplace(&mut c0_big, 0, &self.data, 0); // pt = norm(BIG(m + e)) - module.vec_znx_big_normalize(self.basek(), &mut pt.data, 0, &c0_big, 0, scratch_1); + module.vec_znx_big_normalize( + self.base2k().into(), + &mut pt.data, + 0, + self.base2k().into(), + &c0_big, + 0, + scratch_1, + ); - pt.basek = self.basek(); + pt.base2k = self.base2k(); pt.k = pt.k().min(self.k()); } } diff --git a/poulpy-core/src/decryption/lwe_ct.rs b/poulpy-core/src/decryption/lwe_ct.rs index 7b7f1b7..57abdc6 100644 --- a/poulpy-core/src/decryption/lwe_ct.rs +++ b/poulpy-core/src/decryption/lwe_ct.rs @@ -4,7 +4,7 @@ use poulpy_hal::{ oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl}, }; -use crate::layouts::{Infos, LWECiphertext, LWEPlaintext, LWESecret, SetMetaData}; +use crate::layouts::{LWECiphertext, LWEInfos, LWEPlaintext, LWESecret}; impl LWECiphertext where @@ -31,13 +31,13 @@ where .sum::(); }); module.zn_normalize_inplace( - pt.n(), - self.basek(), + 1, + self.base2k().into(), &mut pt.data, 0, ScratchOwned::alloc(size_of::()).borrow(), ); - pt.set_basek(self.basek()); - pt.set_k(self.k().min(pt.size() * self.basek())); + pt.base2k = self.base2k(); + pt.k = crate::layouts::TorusPrecision(self.k().0.min(pt.size() as u32 * self.base2k().0)); } } diff --git a/poulpy-core/src/encryption/compressed/gglwe_atk.rs b/poulpy-core/src/encryption/compressed/gglwe_atk.rs index 5439a50..cd0da39 100644 --- a/poulpy-core/src/encryption/compressed/gglwe_atk.rs +++ b/poulpy-core/src/encryption/compressed/gglwe_atk.rs @@ -3,7 +3,7 @@ use poulpy_hal::{ ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing, + VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, }, layouts::{Backend, DataMut, DataRef, Module, Scratch}, source::Source, @@ -12,18 +12,20 @@ use poulpy_hal::{ use crate::{ TakeGLWESecret, TakeGLWESecretPrepared, layouts::{ - GLWESecret, + GGLWELayoutInfos, GLWEInfos, GLWESecret, LWEInfos, compressed::{GGLWEAutomorphismKeyCompressed, GGLWESwitchingKeyCompressed}, }, }; impl 
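// Both decrypt paths above compute a phase: the GLWE loop folds
// ct[0] + sum_i ct[i]*s[i] (ring products done in the DFT domain), and the
// LWE variant does the same with scalar products. Plain-coefficient sketch
// with a schoolbook negacyclic product (hypothetical helpers, not the
// poulpy API):
fn negacyclic_mul(a: &[i64], b: &[i64]) -> Vec<i64> {
    let n = a.len();
    let mut out = vec![0i64; n];
    for i in 0..n {
        for j in 0..n {
            if i + j < n {
                out[i + j] += a[i] * b[j];
            } else {
                out[i + j - n] -= a[i] * b[j]; // X^n = -1
            }
        }
    }
    out
}
fn glwe_phase(ct: &[Vec<i64>], sk: &[Vec<i64>]) -> Vec<i64> {
    let mut phase = ct[0].clone(); // body
    for (a_i, s_i) in ct[1..].iter().zip(sk.iter()) {
        for (p, v) in phase.iter_mut().zip(negacyclic_mul(a_i, s_i)) {
            *p += v; // accumulate a_i * s_i
        }
    }
    phase // m + e, up to normalization and rounding
}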
GGLWEAutomorphismKeyCompressed> { - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize where + A: GGLWELayoutInfos, Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes + SvpPPolAllocBytes, { - GGLWESwitchingKeyCompressed::encrypt_sk_scratch_space(module, basek, k, rank, rank) - + GLWESecret::bytes_of(module.n(), rank) + assert_eq!(module.n() as u32, infos.n()); + GGLWESwitchingKeyCompressed::encrypt_sk_scratch_space(module, infos) + + GLWESecret::alloc_bytes_with(infos.n(), infos.rank_out()) } } @@ -49,7 +51,7 @@ impl GGLWEAutomorphismKeyCompressed { + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -60,26 +62,21 @@ impl GGLWEAutomorphismKeyCompressed { { #[cfg(debug_assertions)] { - use crate::layouts::Infos; - assert_eq!(self.n(), sk.n()); assert_eq!(self.rank_out(), self.rank_in()); - assert_eq!(sk.rank(), self.rank()); + assert_eq!(sk.rank(), self.rank_out()); assert!( - scratch.available() - >= GGLWEAutomorphismKeyCompressed::encrypt_sk_scratch_space(module, self.basek(), self.k(), self.rank()), - "scratch.available(): {} < AutomorphismKey::encrypt_sk_scratch_space(module, self.rank()={}, self.size()={}): {}", + scratch.available() >= GGLWEAutomorphismKeyCompressed::encrypt_sk_scratch_space(module, self), + "scratch.available(): {} < AutomorphismKey::encrypt_sk_scratch_space: {}", scratch.available(), - self.rank(), - self.size(), - GGLWEAutomorphismKeyCompressed::encrypt_sk_scratch_space(module, self.basek(), self.k(), self.rank()) + GGLWEAutomorphismKeyCompressed::encrypt_sk_scratch_space(module, self) ) } let (mut sk_out, scratch_1) = scratch.take_glwe_secret(sk.n(), sk.rank()); { - (0..self.rank()).for_each(|i| { + (0..self.rank_out().into()).for_each(|i| { module.vec_znx_automorphism( module.galois_element_inv(p), &mut sk_out.data.as_vec_znx_mut(), diff --git a/poulpy-core/src/encryption/compressed/gglwe_ct.rs b/poulpy-core/src/encryption/compressed/gglwe_ct.rs index 1951bf4..ef58698 100644 --- a/poulpy-core/src/encryption/compressed/gglwe_ct.rs +++ b/poulpy-core/src/encryption/compressed/gglwe_ct.rs @@ -2,7 +2,7 @@ use poulpy_hal::{ api::{ ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, - VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, + VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, }, layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, ZnxZero}, source::Source, @@ -11,15 +11,16 @@ use poulpy_hal::{ use crate::{ TakeGLWEPt, encryption::{SIGMA, glwe_encrypt_sk_internal}, - layouts::{GGLWECiphertext, Infos, compressed::GGLWECiphertextCompressed, prepared::GLWESecretPrepared}, + layouts::{GGLWECiphertext, GGLWELayoutInfos, LWEInfos, compressed::GGLWECiphertextCompressed, prepared::GLWESecretPrepared}, }; impl GGLWECiphertextCompressed> { - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize where + A: GGLWELayoutInfos, Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes, { - 
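// Compressed ciphertexts drop the uniform columns and keep only a 32-byte
// seed per row; decompression re-derives those columns deterministically
// from the seed. Sketch of the idea with std's DefaultHasher as a stand-in
// expander (illustrative only, not cryptographically suitable; poulpy
// branches its own `Source` instead):
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

fn expand_mask(seed: [u8; 32], n: usize) -> Vec<i64> {
    (0..n)
        .map(|i| {
            let mut h = DefaultHasher::new();
            (seed, i).hash(&mut h);
            h.finish() as i64 // pseudo-uniform coefficient derived from (seed, i)
        })
        .collect()
}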
GGLWECiphertext::encrypt_sk_scratch_space(module, basek, k) + GGLWECiphertext::encrypt_sk_scratch_space(module, infos) } } @@ -42,7 +43,7 @@ impl GGLWECiphertextCompressed { + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -56,7 +57,7 @@ impl GGLWECiphertextCompressed { assert_eq!( self.rank_in(), - pt.cols(), + pt.cols() as u32, "self.rank_in(): {} != pt.cols(): {}", self.rank_in(), pt.cols() @@ -69,36 +70,33 @@ impl GGLWECiphertextCompressed { sk.rank() ); assert_eq!(self.n(), sk.n()); - assert_eq!(pt.n(), sk.n()); + assert_eq!(pt.n() as u32, sk.n()); assert!( - scratch.available() >= GGLWECiphertextCompressed::encrypt_sk_scratch_space(module, self.basek(), self.k()), - "scratch.available: {} < GGLWECiphertext::encrypt_sk_scratch_space(module, self.rank()={}, self.size()={}): {}", + scratch.available() >= GGLWECiphertextCompressed::encrypt_sk_scratch_space(module, self), + "scratch.available: {} < GGLWECiphertext::encrypt_sk_scratch_space: {}", scratch.available(), - self.rank(), - self.size(), - GGLWECiphertextCompressed::encrypt_sk_scratch_space(module, self.basek(), self.k()) + GGLWECiphertextCompressed::encrypt_sk_scratch_space(module, self) ); assert!( - self.rows() * self.digits() * self.basek() <= self.k(), - "self.rows() : {} * self.digits() : {} * self.basek() : {} = {} >= self.k() = {}", + self.rows().0 * self.digits().0 * self.base2k().0 <= self.k().0, + "self.rows() : {} * self.digits() : {} * self.base2k() : {} = {} >= self.k() = {}", self.rows(), self.digits(), - self.basek(), - self.rows() * self.digits() * self.basek(), + self.base2k(), + self.rows().0 * self.digits().0 * self.base2k().0, self.k() ); } - let rows: usize = self.rows(); - let digits: usize = self.digits(); - let basek: usize = self.basek(); - let k: usize = self.k(); - let rank_in: usize = self.rank_in(); - let cols: usize = self.rank_out() + 1; + let rows: usize = self.rows().into(); + let digits: usize = self.digits().into(); + let base2k: usize = self.base2k().into(); + let rank_in: usize = self.rank_in().into(); + let cols: usize = (self.rank_out() + 1).into(); let mut source_xa = Source::new(seed); - let (mut tmp_pt, scrach_1) = scratch.take_glwe_pt(sk.n(), basek, k); + let (mut tmp_pt, scrach_1) = scratch.take_glwe_pt(self); (0..rank_in).for_each(|col_i| { (0..rows).for_each(|row_i| { // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt @@ -110,15 +108,15 @@ impl GGLWECiphertextCompressed { pt, col_i, ); - module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scrach_1); + module.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scrach_1); let (seed, mut source_xa_tmp) = source_xa.branch(); self.seed[col_i * rows + row_i] = seed; glwe_encrypt_sk_internal( module, - self.basek(), - self.k(), + self.base2k().into(), + self.k().into(), &mut self.at_mut(row_i, col_i).data, cols, true, diff --git a/poulpy-core/src/encryption/compressed/gglwe_ksk.rs b/poulpy-core/src/encryption/compressed/gglwe_ksk.rs index 7f4d81f..b2d12a3 100644 --- a/poulpy-core/src/encryption/compressed/gglwe_ksk.rs +++ b/poulpy-core/src/encryption/compressed/gglwe_ksk.rs @@ -3,7 +3,7 @@ use poulpy_hal::{ ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, 
VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, - VecZnxSubABInplace, VecZnxSwitchRing, + VecZnxSubInplace, VecZnxSwitchRing, }, layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch}, source::Source, @@ -11,23 +11,21 @@ use poulpy_hal::{ use crate::{ TakeGLWESecretPrepared, - layouts::{GGLWECiphertext, GLWESecret, compressed::GGLWESwitchingKeyCompressed, prepared::GLWESecretPrepared}, + layouts::{ + Degree, GGLWECiphertext, GGLWELayoutInfos, GLWEInfos, GLWESecret, LWEInfos, compressed::GGLWESwitchingKeyCompressed, + prepared::GLWESecretPrepared, + }, }; impl GGLWESwitchingKeyCompressed> { - pub fn encrypt_sk_scratch_space( - module: &Module, - basek: usize, - k: usize, - rank_in: usize, - rank_out: usize, - ) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize where + A: GGLWELayoutInfos, Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes + SvpPPolAllocBytes, { - (GGLWECiphertext::encrypt_sk_scratch_space(module, basek, k) | ScalarZnx::alloc_bytes(module.n(), 1)) - + ScalarZnx::alloc_bytes(module.n(), rank_in) - + GLWESecretPrepared::bytes_of(module, rank_out) + (GGLWECiphertext::encrypt_sk_scratch_space(module, infos) | ScalarZnx::alloc_bytes(module.n(), 1)) + + ScalarZnx::alloc_bytes(module.n(), infos.rank_in().into()) + + GLWESecretPrepared::alloc_bytes_with(module, infos.rank_out()) } } @@ -52,7 +50,7 @@ impl GGLWESwitchingKeyCompressed { + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -63,35 +61,22 @@ impl GGLWESwitchingKeyCompressed { { #[cfg(debug_assertions)] { - use crate::layouts::{GGLWESwitchingKey, Infos}; + use crate::layouts::GGLWESwitchingKey; - assert!(sk_in.n() <= module.n()); - assert!(sk_out.n() <= module.n()); + assert!(sk_in.n().0 <= module.n() as u32); + assert!(sk_out.n().0 <= module.n() as u32); assert!( - scratch.available() - >= GGLWESwitchingKey::encrypt_sk_scratch_space( - module, - self.basek(), - self.k(), - self.rank_in(), - self.rank_out() - ), + scratch.available() >= GGLWESwitchingKey::encrypt_sk_scratch_space(module, self), "scratch.available()={} < GLWESwitchingKey::encrypt_sk_scratch_space={}", scratch.available(), - GGLWESwitchingKey::encrypt_sk_scratch_space( - module, - self.basek(), - self.k(), - self.rank_in(), - self.rank_out() - ) + GGLWESwitchingKey::encrypt_sk_scratch_space(module, self) ) } - let n: usize = sk_in.n().max(sk_out.n()); + let n: usize = sk_in.n().max(sk_out.n()).into(); - let (mut sk_in_tmp, scratch_1) = scratch.take_scalar_znx(n, sk_in.rank()); - (0..sk_in.rank()).for_each(|i| { + let (mut sk_in_tmp, scratch_1) = scratch.take_scalar_znx(n, sk_in.rank().into()); + (0..sk_in.rank().into()).for_each(|i| { module.vec_znx_switch_ring( &mut sk_in_tmp.as_vec_znx_mut(), i, @@ -100,10 +85,10 @@ impl GGLWESwitchingKeyCompressed { ); }); - let (mut sk_out_tmp, scratch_2) = scratch_1.take_glwe_secret_prepared(n, sk_out.rank()); + let (mut sk_out_tmp, scratch_2) = scratch_1.take_glwe_secret_prepared(Degree(n as u32), sk_out.rank()); { let (mut tmp, _) = scratch_2.take_scalar_znx(n, 1); - (0..sk_out.rank()).for_each(|i| { + (0..sk_out.rank().into()).for_each(|i| { module.vec_znx_switch_ring(&mut tmp.as_vec_znx_mut(), 0, &sk_out.data.as_vec_znx(), i); module.svp_prepare(&mut sk_out_tmp.data, i, &tmp, 0); }); @@ -117,7 +102,7 @@ impl GGLWESwitchingKeyCompressed { source_xe, scratch_2, ); - self.sk_in_n = sk_in.n(); - 
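// `vec_znx_switch_ring` above moves key material between power-of-two
// negacyclic rings so both secrets live in the larger one. Going up, the
// natural embedding is X -> X^(n_out/n_in), a ring homomorphism because
// (X^(n_out/n_in))^n_in = X^n_out = -1. Coefficient sketch for that
// direction (presumed behaviour, hypothetical helper):
fn switch_ring_up(coeffs_in: &[i64], n_out: usize) -> Vec<i64> {
    let n_in = coeffs_in.len();
    assert_eq!(n_out % n_in, 0, "power-of-two degrees expected");
    let stride = n_out / n_in;
    let mut out = vec![0i64; n_out];
    for (i, &c) in coeffs_in.iter().enumerate() {
        out[i * stride] = c; // X^i lands on X^(i*stride)
    }
    out
}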
self.sk_out_n = sk_out.n(); + self.sk_in_n = sk_in.n().into(); + self.sk_out_n = sk_out.n().into(); } } diff --git a/poulpy-core/src/encryption/compressed/gglwe_tsk.rs b/poulpy-core/src/encryption/compressed/gglwe_tsk.rs index 52c4ff1..8bedeb4 100644 --- a/poulpy-core/src/encryption/compressed/gglwe_tsk.rs +++ b/poulpy-core/src/encryption/compressed/gglwe_tsk.rs @@ -3,7 +3,7 @@ use poulpy_hal::{ ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, - VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing, + VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, }, layouts::{Backend, DataMut, DataRef, Module, Scratch}, source::Source, @@ -11,16 +11,20 @@ use poulpy_hal::{ use crate::{ TakeGLWESecret, TakeGLWESecretPrepared, - layouts::{GGLWETensorKey, GLWESecret, Infos, compressed::GGLWETensorKeyCompressed, prepared::Prepare}, + layouts::{ + GGLWELayoutInfos, GGLWETensorKey, GLWEInfos, GLWESecret, LWEInfos, Rank, compressed::GGLWETensorKeyCompressed, + prepared::Prepare, + }, }; impl GGLWETensorKeyCompressed> { - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize where + A: GGLWELayoutInfos, Module: SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes + VecZnxBigAllocBytes, { - GGLWETensorKey::encrypt_sk_scratch_space(module, basek, k, rank) + GGLWETensorKey::encrypt_sk_scratch_space(module, infos) } } @@ -42,7 +46,7 @@ impl GGLWETensorKeyCompressed { + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -63,37 +67,38 @@ impl GGLWETensorKeyCompressed { { #[cfg(debug_assertions)] { - assert_eq!(self.rank(), sk.rank()); + assert_eq!(self.rank_out(), sk.rank()); assert_eq!(self.n(), sk.n()); } - let n: usize = sk.n(); - let rank: usize = self.rank(); + let n: usize = sk.n().into(); + let rank: usize = self.rank_out().into(); - let (mut sk_dft_prep, scratch_1) = scratch.take_glwe_secret_prepared(n, rank); + let (mut sk_dft_prep, scratch_1) = scratch.take_glwe_secret_prepared(sk.n(), self.rank_out()); sk_dft_prep.prepare(module, sk, scratch_1); let (mut sk_dft, scratch_2) = scratch_1.take_vec_znx_dft(n, rank, 1); - (0..rank).for_each(|i| { + for i in 0..rank { module.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i); - }); + } let (mut sk_ij_big, scratch_3) = scratch_2.take_vec_znx_big(n, 1, 1); - let (mut sk_ij, scratch_4) = scratch_3.take_glwe_secret(n, 1); + let (mut sk_ij, scratch_4) = scratch_3.take_glwe_secret(sk.n(), Rank(1)); let (mut sk_ij_dft, scratch_5) = scratch_4.take_vec_znx_dft(n, 1, 1); let mut source_xa: Source = Source::new(seed_xa); - (0..rank).for_each(|i| { - (i..rank).for_each(|j| { + for i in 0..rank { + for j in i..rank { module.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_dft_prep.data, j, &sk_dft, i); module.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0); module.vec_znx_big_normalize( - self.basek(), + self.base2k().into(), &mut 
sk_ij.data.as_vec_znx_mut(), 0, + self.base2k().into(), &sk_ij_big, 0, scratch_5, @@ -103,7 +108,7 @@ impl GGLWETensorKeyCompressed { self.at_mut(i, j) .encrypt_sk(module, &sk_ij, sk, seed_xa_tmp, source_xe, scratch_5); - }); - }) + } + } } } diff --git a/poulpy-core/src/encryption/compressed/ggsw_ct.rs b/poulpy-core/src/encryption/compressed/ggsw_ct.rs index 9d4efa2..59d669f 100644 --- a/poulpy-core/src/encryption/compressed/ggsw_ct.rs +++ b/poulpy-core/src/encryption/compressed/ggsw_ct.rs @@ -2,7 +2,7 @@ use poulpy_hal::{ api::{ ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, - VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, + VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, }, layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, ZnxZero}, source::Source, @@ -11,15 +11,18 @@ use poulpy_hal::{ use crate::{ TakeGLWEPt, encryption::{SIGMA, glwe_encrypt_sk_internal}, - layouts::{GGSWCiphertext, Infos, compressed::GGSWCiphertextCompressed, prepared::GLWESecretPrepared}, + layouts::{ + GGSWCiphertext, GGSWInfos, GLWEInfos, LWEInfos, compressed::GGSWCiphertextCompressed, prepared::GLWESecretPrepared, + }, }; impl GGSWCiphertextCompressed> { - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize where + A: GGSWInfos, Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes, { - GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k, rank) + GGSWCiphertext::encrypt_sk_scratch_space(module, infos) } } @@ -42,7 +45,7 @@ impl GGSWCiphertextCompressed { + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -56,27 +59,26 @@ impl GGSWCiphertextCompressed { assert_eq!(self.rank(), sk.rank()); assert_eq!(self.n(), sk.n()); - assert_eq!(pt.n(), sk.n()); + assert_eq!(pt.n() as u32, sk.n()); } - let basek: usize = self.basek(); - let k: usize = self.k(); - let rank: usize = self.rank(); + let base2k: usize = self.base2k().into(); + let rank: usize = self.rank().into(); let cols: usize = rank + 1; - let digits: usize = self.digits(); + let digits: usize = self.digits().into(); - let (mut tmp_pt, scratch_1) = scratch.take_glwe_pt(self.n(), basek, k); + let (mut tmp_pt, scratch_1) = scratch.take_glwe_pt(&self.glwe_layout()); let mut source = Source::new(seed_xa); - self.seed = vec![[0u8; 32]; self.rows() * cols]; + self.seed = vec![[0u8; 32]; self.rows().0 as usize * cols]; - (0..self.rows()).for_each(|row_i| { + (0..self.rows().into()).for_each(|row_i| { tmp_pt.data.zero(); // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (digits - 1) + row_i * digits, pt, 0); - module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scratch_1); + module.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scratch_1); (0..rank + 1).for_each(|col_j| { // rlwe encrypt of vec_znx_pt into vec_znx_ct @@ -87,8 +89,8 @@ impl GGSWCiphertextCompressed { glwe_encrypt_sk_internal( module, - self.basek(), - self.k(), + self.base2k().into(), + self.k().into(), &mut self.at_mut(row_i, col_j).data, cols, true, diff --git 
a/poulpy-core/src/encryption/compressed/glwe_ct.rs b/poulpy-core/src/encryption/compressed/glwe_ct.rs index 6b20eba..834f968 100644 --- a/poulpy-core/src/encryption/compressed/glwe_ct.rs +++ b/poulpy-core/src/encryption/compressed/glwe_ct.rs @@ -2,7 +2,7 @@ use poulpy_hal::{ api::{ ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, - VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, + VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, }, layouts::{Backend, DataMut, DataRef, Module, Scratch}, source::Source, @@ -10,15 +10,18 @@ use poulpy_hal::{ use crate::{ encryption::{SIGMA, glwe_ct::glwe_encrypt_sk_internal}, - layouts::{GLWECiphertext, GLWEPlaintext, Infos, compressed::GLWECiphertextCompressed, prepared::GLWESecretPrepared}, + layouts::{ + GLWECiphertext, GLWEInfos, GLWEPlaintext, LWEInfos, compressed::GLWECiphertextCompressed, prepared::GLWESecretPrepared, + }, }; impl GLWECiphertextCompressed> { - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize where + A: GLWEInfos, Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes, { - GLWECiphertext::encrypt_sk_scratch_space(module, basek, k) + GLWECiphertext::encrypt_sk_scratch_space(module, infos) } } @@ -40,7 +43,7 @@ impl GLWECiphertextCompressed { + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -68,7 +71,7 @@ impl GLWECiphertextCompressed { + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -77,11 +80,11 @@ impl GLWECiphertextCompressed { Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { let mut source_xa = Source::new(seed_xa); - let cols: usize = self.rank() + 1; + let cols: usize = (self.rank() + 1).into(); glwe_encrypt_sk_internal( module, - self.basek(), - self.k(), + self.base2k().into(), + self.k().into(), &mut self.data, cols, true, diff --git a/poulpy-core/src/encryption/gglwe_atk.rs b/poulpy-core/src/encryption/gglwe_atk.rs index 6eab79a..308038c 100644 --- a/poulpy-core/src/encryption/gglwe_atk.rs +++ b/poulpy-core/src/encryption/gglwe_atk.rs @@ -3,7 +3,7 @@ use poulpy_hal::{ ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing, + VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, }, layouts::{Backend, DataMut, DataRef, Module, Scratch}, source::Source, @@ -11,19 +11,33 @@ use poulpy_hal::{ use crate::{ TakeGLWESecret, TakeGLWESecretPrepared, - layouts::{GGLWEAutomorphismKey, GGLWESwitchingKey, GLWESecret}, + layouts::{GGLWEAutomorphismKey, GGLWELayoutInfos, GGLWESwitchingKey, GLWEInfos, GLWESecret, LWEInfos}, }; impl GGLWEAutomorphismKey> { - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize + pub fn 
encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize where + A: GGLWELayoutInfos, Module: SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes, { - GGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, rank, rank) + GLWESecret::bytes_of(module.n(), rank) + assert_eq!( + infos.rank_in(), + infos.rank_out(), + "rank_in != rank_out is not supported for GGLWEAutomorphismKey" + ); + GGLWESwitchingKey::encrypt_sk_scratch_space(module, infos) + GLWESecret::alloc_bytes(&infos.glwe_layout()) } - pub fn encrypt_pk_scratch_space(module: &Module, _basek: usize, _k: usize, _rank: usize) -> usize { - GGLWESwitchingKey::encrypt_pk_scratch_space(module, _basek, _k, _rank, _rank) + pub fn encrypt_pk_scratch_space(module: &Module, _infos: &A) -> usize + where + A: GGLWELayoutInfos, + { + assert_eq!( + _infos.rank_in(), + _infos.rank_out(), + "rank_in != rank_out is not supported for GGLWEAutomorphismKey" + ); + GGLWESwitchingKey::encrypt_pk_scratch_space(module, _infos) } } @@ -46,7 +60,7 @@ impl GGLWEAutomorphismKey { + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -60,26 +74,23 @@ impl GGLWEAutomorphismKey { { #[cfg(debug_assertions)] { - use crate::layouts::Infos; + use crate::layouts::{GLWEInfos, LWEInfos}; assert_eq!(self.n(), sk.n()); assert_eq!(self.rank_out(), self.rank_in()); - assert_eq!(sk.rank(), self.rank()); + assert_eq!(sk.rank(), self.rank_out()); assert!( - scratch.available() - >= GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, self.basek(), self.k(), self.rank()), - "scratch.available(): {} < AutomorphismKey::encrypt_sk_scratch_space(module, self.rank()={}, self.size()={}): {}", + scratch.available() >= GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, self), + "scratch.available(): {} < AutomorphismKey::encrypt_sk_scratch_space: {:?}", scratch.available(), - self.rank(), - self.size(), - GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, self.basek(), self.k(), self.rank()) + GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, self) ) } let (mut sk_out, scratch_1) = scratch.take_glwe_secret(sk.n(), sk.rank()); { - (0..self.rank()).for_each(|i| { + (0..self.rank_out().into()).for_each(|i| { module.vec_znx_automorphism( module.galois_element_inv(p), &mut sk_out.data.as_vec_znx_mut(), diff --git a/poulpy-core/src/encryption/gglwe_ct.rs b/poulpy-core/src/encryption/gglwe_ct.rs index 50dca97..c368892 100644 --- a/poulpy-core/src/encryption/gglwe_ct.rs +++ b/poulpy-core/src/encryption/gglwe_ct.rs @@ -2,7 +2,7 @@ use poulpy_hal::{ api::{ ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, - VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, + VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, }, layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, ZnxZero}, source::Source, @@ -10,19 +10,23 @@ use poulpy_hal::{ use crate::{ TakeGLWEPt, - layouts::{GGLWECiphertext, GLWECiphertext, GLWEPlaintext, Infos, prepared::GLWESecretPrepared}, + layouts::{GGLWECiphertext, GGLWELayoutInfos, GLWECiphertext, GLWEPlaintext, LWEInfos, prepared::GLWESecretPrepared}, }; impl GGLWECiphertext> { - pub fn 
encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize where + A: GGLWELayoutInfos, Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes, { - GLWECiphertext::encrypt_sk_scratch_space(module, basek, k) - + (GLWEPlaintext::byte_of(module.n(), basek, k) | module.vec_znx_normalize_tmp_bytes()) + GLWECiphertext::encrypt_sk_scratch_space(module, &infos.glwe_layout()) + + (GLWEPlaintext::alloc_bytes(&infos.glwe_layout()) | module.vec_znx_normalize_tmp_bytes()) } - pub fn encrypt_pk_scratch_space(_module: &Module, _basek: usize, _k: usize, _rank: usize) -> usize { + pub fn encrypt_pk_scratch_space(_module: &Module, _infos: &A) -> usize + where + A: GGLWELayoutInfos, + { unimplemented!() } } @@ -46,7 +50,7 @@ impl GGLWECiphertext { + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -60,7 +64,7 @@ impl GGLWECiphertext { assert_eq!( self.rank_in(), - pt.cols(), + pt.cols() as u32, "self.rank_in(): {} != pt.cols(): {}", self.rank_in(), pt.cols() @@ -73,33 +77,32 @@ impl GGLWECiphertext { sk.rank() ); assert_eq!(self.n(), sk.n()); - assert_eq!(pt.n(), sk.n()); + assert_eq!(pt.n() as u32, sk.n()); assert!( - scratch.available() >= GGLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k()), + scratch.available() >= GGLWECiphertext::encrypt_sk_scratch_space(module, self), "scratch.available: {} < GGLWECiphertext::encrypt_sk_scratch_space(module, self.rank()={}, self.size()={}): {}", scratch.available(), - self.rank(), + self.rank_out(), self.size(), - GGLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k()) + GGLWECiphertext::encrypt_sk_scratch_space(module, self) ); assert!( - self.rows() * self.digits() * self.basek() <= self.k(), - "self.rows() : {} * self.digits() : {} * self.basek() : {} = {} >= self.k() = {}", + self.rows().0 * self.digits().0 * self.base2k().0 <= self.k().0, + "self.rows() : {} * self.digits() : {} * self.base2k() : {} = {} >= self.k() = {}", self.rows(), self.digits(), - self.basek(), - self.rows() * self.digits() * self.basek(), + self.base2k(), + self.rows().0 * self.digits().0 * self.base2k().0, self.k() ); } - let rows: usize = self.rows(); - let digits: usize = self.digits(); - let basek: usize = self.basek(); - let k: usize = self.k(); - let rank_in: usize = self.rank_in(); + let rows: usize = self.rows().into(); + let digits: usize = self.digits().into(); + let base2k: usize = self.base2k().into(); + let rank_in: usize = self.rank_in().into(); - let (mut tmp_pt, scrach_1) = scratch.take_glwe_pt(sk.n(), basek, k); + let (mut tmp_pt, scrach_1) = scratch.take_glwe_pt(self); // For each input column (i.e. 
rank) produces a GGLWE ciphertext of rank_out+1 columns // // Example for ksk rank 2 to rank 3: @@ -122,7 +125,7 @@ impl GGLWECiphertext { pt, col_i, ); - module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scrach_1); + module.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scrach_1); // rlwe encrypt of vec_znx_pt into vec_znx_ct self.at_mut(row_i, col_i) diff --git a/poulpy-core/src/encryption/gglwe_ksk.rs b/poulpy-core/src/encryption/gglwe_ksk.rs index daf8e2e..25daa00 100644 --- a/poulpy-core/src/encryption/gglwe_ksk.rs +++ b/poulpy-core/src/encryption/gglwe_ksk.rs @@ -3,7 +3,7 @@ use poulpy_hal::{ ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, - VecZnxSubABInplace, VecZnxSwitchRing, + VecZnxSubInplace, VecZnxSwitchRing, }, layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch}, source::Source, @@ -11,33 +11,28 @@ use poulpy_hal::{ use crate::{ TakeGLWESecretPrepared, - layouts::{GGLWECiphertext, GGLWESwitchingKey, GLWESecret, prepared::GLWESecretPrepared}, + layouts::{ + Degree, GGLWECiphertext, GGLWELayoutInfos, GGLWESwitchingKey, GLWEInfos, GLWESecret, LWEInfos, + prepared::GLWESecretPrepared, + }, }; impl GGLWESwitchingKey> { - pub fn encrypt_sk_scratch_space( - module: &Module, - basek: usize, - k: usize, - rank_in: usize, - rank_out: usize, - ) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize where + A: GGLWELayoutInfos, Module: SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes, { - (GGLWECiphertext::encrypt_sk_scratch_space(module, basek, k) | ScalarZnx::alloc_bytes(module.n(), 1)) - + ScalarZnx::alloc_bytes(module.n(), rank_in) - + GLWESecretPrepared::bytes_of(module, rank_out) + (GGLWECiphertext::encrypt_sk_scratch_space(module, infos) | ScalarZnx::alloc_bytes(module.n(), 1)) + + ScalarZnx::alloc_bytes(module.n(), infos.rank_in().into()) + + GLWESecretPrepared::alloc_bytes(module, &infos.glwe_layout()) } - pub fn encrypt_pk_scratch_space( - module: &Module, - _basek: usize, - _k: usize, - _rank_in: usize, - _rank_out: usize, - ) -> usize { - GGLWECiphertext::encrypt_pk_scratch_space(module, _basek, _k, _rank_out) + pub fn encrypt_pk_scratch_space(module: &Module, _infos: &A) -> usize + where + A: GGLWELayoutInfos, + { + GGLWECiphertext::encrypt_pk_scratch_space(module, _infos) } } @@ -60,7 +55,7 @@ impl GGLWESwitchingKey { + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -73,35 +68,20 @@ impl GGLWESwitchingKey { { #[cfg(debug_assertions)] { - use crate::layouts::Infos; - - assert!(sk_in.n() <= module.n()); - assert!(sk_out.n() <= module.n()); + assert!(sk_in.n().0 <= module.n() as u32); + assert!(sk_out.n().0 <= module.n() as u32); assert!( - scratch.available() - >= GGLWESwitchingKey::encrypt_sk_scratch_space( - module, - self.basek(), - self.k(), - self.rank_in(), - self.rank_out() - ), + scratch.available() >= GGLWESwitchingKey::encrypt_sk_scratch_space(module, self), "scratch.available()={} < GLWESwitchingKey::encrypt_sk_scratch_space={}", scratch.available(), - GGLWESwitchingKey::encrypt_sk_scratch_space( - module, - self.basek(), - self.k(), - 
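// Row r of the gadget above encrypts the message placed at limb
// (digits - 1) + r * digits, i.e. scaled by 2^(-base2k * digits * (r + 1)),
// which is exactly why the layout must satisfy
// rows * digits * base2k <= k. Toy placement for a single coefficient
// (hypothetical helper mirroring vec_znx_add_scalar_inplace):
fn place_in_limb(msg: i64, row: usize, digits: usize, total_limbs: usize) -> Vec<i64> {
    let mut limbs = vec![0i64; total_limbs];
    limbs[(digits - 1) + row * digits] += msg; // deepest limb of the r-th digit block
    limbs
}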
self.rank_in(), - self.rank_out() - ) + GGLWESwitchingKey::encrypt_sk_scratch_space(module, self) ) } - let n: usize = sk_in.n().max(sk_out.n()); + let n: usize = sk_in.n().max(sk_out.n()).into(); - let (mut sk_in_tmp, scratch_1) = scratch.take_scalar_znx(n, sk_in.rank()); - (0..sk_in.rank()).for_each(|i| { + let (mut sk_in_tmp, scratch_1) = scratch.take_scalar_znx(n, sk_in.rank().into()); + (0..sk_in.rank().into()).for_each(|i| { module.vec_znx_switch_ring( &mut sk_in_tmp.as_vec_znx_mut(), i, @@ -110,10 +90,10 @@ impl GGLWESwitchingKey { ); }); - let (mut sk_out_tmp, scratch_2) = scratch_1.take_glwe_secret_prepared(n, sk_out.rank()); + let (mut sk_out_tmp, scratch_2) = scratch_1.take_glwe_secret_prepared(Degree(n as u32), sk_out.rank()); { let (mut tmp, _) = scratch_2.take_scalar_znx(n, 1); - (0..sk_out.rank()).for_each(|i| { + (0..sk_out.rank().into()).for_each(|i| { module.vec_znx_switch_ring(&mut tmp.as_vec_znx_mut(), 0, &sk_out.data.as_vec_znx(), i); module.svp_prepare(&mut sk_out_tmp.data, i, &tmp, 0); }); @@ -127,7 +107,7 @@ impl GGLWESwitchingKey { source_xe, scratch_2, ); - self.sk_in_n = sk_in.n(); - self.sk_out_n = sk_out.n(); + self.sk_in_n = sk_in.n().into(); + self.sk_out_n = sk_out.n().into(); } } diff --git a/poulpy-core/src/encryption/gglwe_tsk.rs b/poulpy-core/src/encryption/gglwe_tsk.rs index 2032af5..04ac9cc 100644 --- a/poulpy-core/src/encryption/gglwe_tsk.rs +++ b/poulpy-core/src/encryption/gglwe_tsk.rs @@ -3,7 +3,7 @@ use poulpy_hal::{ ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, - VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing, + VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, }, layouts::{Backend, DataMut, DataRef, Module, Scratch}, source::Source, @@ -12,23 +12,24 @@ use poulpy_hal::{ use crate::{ TakeGLWESecret, TakeGLWESecretPrepared, layouts::{ - GGLWESwitchingKey, GGLWETensorKey, GLWESecret, Infos, + Degree, GGLWELayoutInfos, GGLWESwitchingKey, GGLWETensorKey, GLWEInfos, GLWESecret, LWEInfos, Rank, prepared::{GLWESecretPrepared, Prepare}, }, }; impl GGLWETensorKey> { - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize where + A: GGLWELayoutInfos, Module: SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes + VecZnxBigAllocBytes, { - GLWESecretPrepared::bytes_of(module, rank) - + module.vec_znx_dft_alloc_bytes(rank, 1) + GLWESecretPrepared::alloc_bytes_with(module, infos.rank_out()) + + module.vec_znx_dft_alloc_bytes(infos.rank_out().into(), 1) + module.vec_znx_big_alloc_bytes(1, 1) + module.vec_znx_dft_alloc_bytes(1, 1) - + GLWESecret::bytes_of(module.n(), 1) - + GGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, rank, rank) + + GLWESecret::alloc_bytes_with(Degree(module.n() as u32), Rank(1)) + + GGLWESwitchingKey::encrypt_sk_scratch_space(module, infos) } } @@ -51,7 +52,7 @@ impl GGLWETensorKey { + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + 
VecZnxAddNormal @@ -65,36 +66,36 @@ impl GGLWETensorKey { { #[cfg(debug_assertions)] { - assert_eq!(self.rank(), sk.rank()); + assert_eq!(self.rank_out(), sk.rank()); assert_eq!(self.n(), sk.n()); } - let n: usize = sk.n(); - - let rank: usize = self.rank(); + let n: Degree = sk.n(); + let rank: Rank = self.rank_out(); let (mut sk_dft_prep, scratch_1) = scratch.take_glwe_secret_prepared(n, rank); sk_dft_prep.prepare(module, sk, scratch_1); - let (mut sk_dft, scratch_2) = scratch_1.take_vec_znx_dft(n, rank, 1); + let (mut sk_dft, scratch_2) = scratch_1.take_vec_znx_dft(n.into(), rank.into(), 1); - (0..rank).for_each(|i| { + (0..rank.into()).for_each(|i| { module.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i); }); - let (mut sk_ij_big, scratch_3) = scratch_2.take_vec_znx_big(n, 1, 1); - let (mut sk_ij, scratch_4) = scratch_3.take_glwe_secret(n, 1); - let (mut sk_ij_dft, scratch_5) = scratch_4.take_vec_znx_dft(n, 1, 1); + let (mut sk_ij_big, scratch_3) = scratch_2.take_vec_znx_big(n.into(), 1, 1); + let (mut sk_ij, scratch_4) = scratch_3.take_glwe_secret(n, Rank(1)); + let (mut sk_ij_dft, scratch_5) = scratch_4.take_vec_znx_dft(n.into(), 1, 1); - (0..rank).for_each(|i| { - (i..rank).for_each(|j| { + (0..rank.into()).for_each(|i| { + (i..rank.into()).for_each(|j| { module.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_dft_prep.data, j, &sk_dft, i); module.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0); module.vec_znx_big_normalize( - self.basek(), + self.base2k().into(), &mut sk_ij.data.as_vec_znx_mut(), 0, + self.base2k().into(), &sk_ij_big, 0, scratch_5, diff --git a/poulpy-core/src/encryption/ggsw_ct.rs b/poulpy-core/src/encryption/ggsw_ct.rs index 5995a50..4f0614f 100644 --- a/poulpy-core/src/encryption/ggsw_ct.rs +++ b/poulpy-core/src/encryption/ggsw_ct.rs @@ -2,7 +2,7 @@ use poulpy_hal::{ api::{ ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, - VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, + VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, }, layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, VecZnx, ZnxZero}, source::Source, @@ -10,19 +10,20 @@ use poulpy_hal::{ use crate::{ TakeGLWEPt, - layouts::{GGSWCiphertext, GLWECiphertext, Infos, prepared::GLWESecretPrepared}, + layouts::{GGSWCiphertext, GGSWInfos, GLWECiphertext, GLWEInfos, LWEInfos, prepared::GLWESecretPrepared}, }; impl GGSWCiphertext> { - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize where + A: GGSWInfos, Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes, { - let size = k.div_ceil(basek); - GLWECiphertext::encrypt_sk_scratch_space(module, basek, k) - + VecZnx::alloc_bytes(module.n(), rank + 1, size) + let size = infos.size(); + GLWECiphertext::encrypt_sk_scratch_space(module, &infos.glwe_layout()) + + VecZnx::alloc_bytes(module.n(), (infos.rank() + 1).into(), size) + VecZnx::alloc_bytes(module.n(), 1, size) - + module.vec_znx_dft_alloc_bytes(rank + 1, size) + + module.vec_znx_dft_alloc_bytes((infos.rank() + 1).into(), size) } } @@ -45,7 +46,7 @@ impl GGSWCiphertext { + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform - + VecZnxSubABInplace + + 
VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -59,22 +60,21 @@ impl GGSWCiphertext { assert_eq!(self.rank(), sk.rank()); assert_eq!(self.n(), sk.n()); - assert_eq!(pt.n(), sk.n()); + assert_eq!(pt.n() as u32, sk.n()); } - let basek: usize = self.basek(); - let k: usize = self.k(); - let rank: usize = self.rank(); - let digits: usize = self.digits(); + let base2k: usize = self.base2k().into(); + let rank: usize = self.rank().into(); + let digits: usize = self.digits().into(); - let (mut tmp_pt, scratch_1) = scratch.take_glwe_pt(self.n(), basek, k); + let (mut tmp_pt, scratch_1) = scratch.take_glwe_pt(&self.glwe_layout()); - (0..self.rows()).for_each(|row_i| { + (0..self.rows().into()).for_each(|row_i| { tmp_pt.data.zero(); // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (digits - 1) + row_i * digits, pt, 0); - module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scratch_1); + module.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scratch_1); (0..rank + 1).for_each(|col_j| { // rlwe encrypt of vec_znx_pt into vec_znx_ct diff --git a/poulpy-core/src/encryption/glwe_ct.rs b/poulpy-core/src/encryption/glwe_ct.rs index ddb202e..8ecacc6 100644 --- a/poulpy-core/src/encryption/glwe_ct.rs +++ b/poulpy-core/src/encryption/glwe_ct.rs @@ -3,7 +3,7 @@ use poulpy_hal::{ ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeSvpPPol, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxBigAddNormal, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, - VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, + VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, }, layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, VecZnx, VecZnxBig, ZnxInfos, ZnxZero}, source::Source, @@ -13,26 +13,30 @@ use crate::{ dist::Distribution, encryption::{SIGMA, SIGMA_BOUND}, layouts::{ - GLWECiphertext, GLWEPlaintext, Infos, + GLWECiphertext, GLWEInfos, GLWEPlaintext, LWEInfos, prepared::{GLWEPublicKeyPrepared, GLWESecretPrepared}, }, }; impl GLWECiphertext> { - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize where + A: GLWEInfos, Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes, { - let size: usize = k.div_ceil(basek); + let size: usize = infos.size(); + assert_eq!(module.n() as u32, infos.n()); module.vec_znx_normalize_tmp_bytes() + 2 * VecZnx::alloc_bytes(module.n(), 1, size) + module.vec_znx_dft_alloc_bytes(1, size) } - pub fn encrypt_pk_scratch_space(module: &Module, basek: usize, k: usize) -> usize + pub fn encrypt_pk_scratch_space(module: &Module, infos: &A) -> usize where + A: GLWEInfos, Module: VecZnxDftAllocBytes + SvpPPolAllocBytes + VecZnxBigAllocBytes + VecZnxNormalizeTmpBytes, { - let size: usize = k.div_ceil(basek); + let size: usize = infos.size(); + assert_eq!(module.n() as u32, infos.n()); ((module.vec_znx_dft_alloc_bytes(1, size) + module.vec_znx_big_alloc_bytes(1, size)) | ScalarZnx::alloc_bytes(module.n(), 1)) + module.svp_ppol_alloc_bytes(1) @@ -58,7 +62,7 @@ impl GLWECiphertext { + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + 
VecZnxNormalizeInplace + VecZnxAddNormal @@ -72,10 +76,10 @@ impl GLWECiphertext { assert_eq!(sk.n(), self.n()); assert_eq!(pt.n(), self.n()); assert!( - scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k()), + scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(module, self), "scratch.available(): {} < GLWECiphertext::encrypt_sk_scratch_space: {}", scratch.available(), - GLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k()) + GLWECiphertext::encrypt_sk_scratch_space(module, self) ) } @@ -97,7 +101,7 @@ impl GLWECiphertext { + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -110,10 +114,10 @@ impl GLWECiphertext { assert_eq!(self.rank(), sk.rank()); assert_eq!(sk.n(), self.n()); assert!( - scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k()), + scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(module, self), "scratch.available(): {} < GLWECiphertext::encrypt_sk_scratch_space: {}", scratch.available(), - GLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k()) + GLWECiphertext::encrypt_sk_scratch_space(module, self) ) } self.encrypt_sk_internal( @@ -143,7 +147,7 @@ impl GLWECiphertext { + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -151,11 +155,11 @@ impl GLWECiphertext { + VecZnxSub, Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { - let cols: usize = self.rank() + 1; + let cols: usize = (self.rank() + 1).into(); glwe_encrypt_sk_internal( module, - self.basek(), - self.k(), + self.base2k().into(), + self.k().into(), &mut self.data, cols, false, @@ -235,24 +239,24 @@ impl GLWECiphertext { { #[cfg(debug_assertions)] { - assert_eq!(self.basek(), pk.basek()); + assert_eq!(self.base2k(), pk.base2k()); assert_eq!(self.n(), pk.n()); assert_eq!(self.rank(), pk.rank()); if let Some((pt, _)) = pt { - assert_eq!(pt.basek(), pk.basek()); + assert_eq!(pt.base2k(), pk.base2k()); assert_eq!(pt.n(), pk.n()); } } - let basek: usize = pk.basek(); + let base2k: usize = pk.base2k().into(); let size_pk: usize = pk.size(); - let cols: usize = self.rank() + 1; + let cols: usize = (self.rank() + 1).into(); // Generates u according to the underlying secret distribution. 
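// Editor's note (hedged): the public-key path below computes, per column i,
// ct[i] = u * pk[i] + e_i, with the message m added at column `col`. Assuming the
// usual convention pk = (b, a) with b = -a*s + e_pk (not shown in this diff),
// decryption then yields
//     <ct, (1, s)> = (u*b + e_0 + m) + s*(u*a + e_1)
//                  = m + u*e_pk + e_0 + s*e_1,
// i.e. the message plus small noise; the rank-r case applies the same identity
// columnwise.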
- let (mut u_dft, scratch_1) = scratch.take_svp_ppol(self.n(), 1); + let (mut u_dft, scratch_1) = scratch.take_svp_ppol(self.n().into(), 1); { - let (mut u, _) = scratch_1.take_scalar_znx(self.n(), 1); + let (mut u, _) = scratch_1.take_scalar_znx(self.n().into(), 1); match pk.dist { Distribution::NONE => panic!( "invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through \ @@ -271,7 +275,7 @@ impl GLWECiphertext { // ct[i] = pk[i] * u + ei (+ m if col = i) (0..cols).for_each(|i| { - let (mut ci_dft, scratch_2) = scratch_1.take_vec_znx_dft(self.n(), 1, size_pk); + let (mut ci_dft, scratch_2) = scratch_1.take_vec_znx_dft(self.n().into(), 1, size_pk); // ci_dft = DFT(u) * DFT(pk[i]) module.svp_apply_dft_to_dft(&mut ci_dft, 0, &u_dft, 0, &pk.data, i); @@ -279,7 +283,15 @@ impl GLWECiphertext { let mut ci_big = module.vec_znx_idft_apply_consume(ci_dft); // ci_big = u * pk[i] + e - module.vec_znx_big_add_normal(basek, &mut ci_big, 0, pk.k(), source_xe, SIGMA, SIGMA_BOUND); + module.vec_znx_big_add_normal( + base2k, + &mut ci_big, + 0, + pk.k().into(), + source_xe, + SIGMA, + SIGMA_BOUND, + ); // ci_big = u * pk[i] + e + m (if col = i) if let Some((pt, col)) = pt @@ -289,7 +301,7 @@ impl GLWECiphertext { } // ct[i] = norm(ci_big) - module.vec_znx_big_normalize(basek, &mut self.data, i, &ci_big, 0, scratch_2); + module.vec_znx_big_normalize(base2k, &mut self.data, i, base2k, &ci_big, 0, scratch_2); }); } } @@ -297,7 +309,7 @@ impl GLWECiphertext { #[allow(clippy::too_many_arguments)] pub(crate) fn glwe_encrypt_sk_internal( module: &Module, - basek: usize, + base2k: usize, k: usize, ct: &mut VecZnx, cols: usize, @@ -316,7 +328,7 @@ pub(crate) fn glwe_encrypt_sk_internal + VecZnxNormalizeTmpBytes + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -350,7 +362,7 @@ pub(crate) fn glwe_encrypt_sk_internal = module.vec_znx_idft_apply_consume(ci_dft); // use c[0] as buffer, which is overwritten later by the normalization step - module.vec_znx_big_normalize(basek, &mut ci, 0, &ci_big, 0, scratch_3); + module.vec_znx_big_normalize(base2k, &mut ci, 0, base2k, &ci_big, 0, scratch_3); // c0_tmp = -c[i] * s[i] (use c[0] as buffer) - module.vec_znx_sub_ab_inplace(&mut c0, 0, &ci, 0); + module.vec_znx_sub_inplace(&mut c0, 0, &ci, 0); }); } // c[0] += e - module.vec_znx_add_normal(basek, &mut c0, 0, k, source_xe, sigma, SIGMA_BOUND); + module.vec_znx_add_normal(base2k, &mut c0, 0, k, source_xe, sigma, SIGMA_BOUND); // c[0] += m if col = 0 if let Some((pt, col)) = pt @@ -391,5 +403,5 @@ pub(crate) fn glwe_encrypt_sk_internal GLWEPublicKey { pub fn generate_from_sk( @@ -27,7 +27,7 @@ impl GLWEPublicKey { + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -42,7 +42,7 @@ impl GLWEPublicKey { { #[cfg(debug_assertions)] { - use crate::Distribution; + use crate::{Distribution, layouts::LWEInfos}; assert_eq!(self.n(), sk.n()); @@ -52,13 +52,9 @@ impl GLWEPublicKey { } // Its ok to allocate scratch space here since pk is usually generated only once. 
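// Editor's note (hedged sketch): with the refactored `A: GLWEInfos` parameter, any
// layout-carrying value (including the ciphertext or key itself, as on the next
// line) can size its own scratch space. A hypothetical caller-side use:
//
//     let bytes = GLWECiphertext::encrypt_sk_scratch_space(&module, &ct);
//     let mut scratch = ScratchOwned::alloc(bytes);
//
// Allocating inline is acceptable for one-shot key generation as noted above; hot
// paths should instead reuse a caller-provided `Scratch`.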
- let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWECiphertext::encrypt_sk_scratch_space( - module, - self.basek(), - self.k(), - )); + let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWECiphertext::encrypt_sk_scratch_space(module, self)); - let mut tmp: GLWECiphertext> = GLWECiphertext::alloc(self.n(), self.basek(), self.k(), self.rank()); + let mut tmp: GLWECiphertext> = GLWECiphertext::alloc(self); tmp.encrypt_zero_sk(module, sk, source_xa, source_xe, scratch.borrow()); self.dist = sk.dist; } diff --git a/poulpy-core/src/encryption/glwe_to_lwe_ksk.rs b/poulpy-core/src/encryption/glwe_to_lwe_ksk.rs index 9c72d78..01e2ca9 100644 --- a/poulpy-core/src/encryption/glwe_to_lwe_ksk.rs +++ b/poulpy-core/src/encryption/glwe_to_lwe_ksk.rs @@ -3,7 +3,7 @@ use poulpy_hal::{ ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing, + VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, }, layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut, ZnxZero}, source::Source, @@ -11,17 +11,21 @@ use poulpy_hal::{ use crate::{ TakeGLWESecret, TakeGLWESecretPrepared, - layouts::{GGLWESwitchingKey, GLWESecret, GLWEToLWESwitchingKey, LWESecret, prepared::GLWESecretPrepared}, + layouts::{ + GGLWELayoutInfos, GGLWESwitchingKey, GLWESecret, GLWEToLWESwitchingKey, LWEInfos, LWESecret, Rank, + prepared::GLWESecretPrepared, + }, }; impl GLWEToLWESwitchingKey> { - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank_in: usize) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize where + A: GGLWELayoutInfos, Module: SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes, { - GLWESecretPrepared::bytes_of(module, rank_in) - + (GGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, rank_in, 1) - | GLWESecret::bytes_of(module.n(), rank_in)) + GLWESecretPrepared::alloc_bytes_with(module, infos.rank_in()) + + (GGLWESwitchingKey::encrypt_sk_scratch_space(module, infos) + | GLWESecret::alloc_bytes_with(infos.n(), infos.rank_in())) } } @@ -47,7 +51,7 @@ impl GLWEToLWESwitchingKey { + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -60,12 +64,12 @@ impl GLWEToLWESwitchingKey { { #[cfg(debug_assertions)] { - assert!(sk_lwe.n() <= module.n()); + assert!(sk_lwe.n().0 <= module.n() as u32); } - let (mut sk_lwe_as_glwe, scratch_1) = scratch.take_glwe_secret(sk_glwe.n(), 1); + let (mut sk_lwe_as_glwe, scratch_1) = scratch.take_glwe_secret(sk_glwe.n(), Rank(1)); sk_lwe_as_glwe.data.zero(); - sk_lwe_as_glwe.data.at_mut(0, 0)[..sk_lwe.n()].copy_from_slice(sk_lwe.data.at(0, 0)); + sk_lwe_as_glwe.data.at_mut(0, 0)[..sk_lwe.n().into()].copy_from_slice(sk_lwe.data.at(0, 0)); module.vec_znx_automorphism_inplace(-1, &mut sk_lwe_as_glwe.data.as_vec_znx_mut(), 0, scratch_1); self.0.encrypt_sk( diff --git a/poulpy-core/src/encryption/lwe_ct.rs b/poulpy-core/src/encryption/lwe_ct.rs index 15a9a65..4dd09ac 100644 --- a/poulpy-core/src/encryption/lwe_ct.rs +++ b/poulpy-core/src/encryption/lwe_ct.rs @@ -7,7 +7,7 @@ use 
poulpy_hal::{ use crate::{ encryption::{SIGMA, SIGMA_BOUND}, - layouts::{Infos, LWECiphertext, LWEPlaintext, LWESecret}, + layouts::{LWECiphertext, LWEInfos, LWEPlaintext, LWESecret}, }; impl LWECiphertext { @@ -29,10 +29,10 @@ impl LWECiphertext { assert_eq!(self.n(), sk.n()) } - let basek: usize = self.basek(); - let k: usize = self.k(); + let base2k: usize = self.base2k().into(); + let k: usize = self.k().into(); - module.zn_fill_uniform(self.n() + 1, basek, &mut self.data, 0, source_xa); + module.zn_fill_uniform((self.n() + 1).into(), base2k, &mut self.data, 0, source_xa); let mut tmp_znx: Zn> = Zn::alloc(1, 1, self.size()); @@ -57,7 +57,7 @@ impl LWECiphertext { module.zn_add_normal( 1, - basek, + base2k, &mut self.data, 0, k, @@ -68,7 +68,7 @@ impl LWECiphertext { module.zn_normalize_inplace( 1, - basek, + base2k, &mut tmp_znx, 0, ScratchOwned::alloc(size_of::()).borrow(), diff --git a/poulpy-core/src/encryption/lwe_ksk.rs b/poulpy-core/src/encryption/lwe_ksk.rs index 2df695f..c7bac02 100644 --- a/poulpy-core/src/encryption/lwe_ksk.rs +++ b/poulpy-core/src/encryption/lwe_ksk.rs @@ -3,7 +3,7 @@ use poulpy_hal::{ ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing, + VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, }, layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut}, source::Source, @@ -11,17 +11,36 @@ use poulpy_hal::{ use crate::{ TakeGLWESecret, TakeGLWESecretPrepared, - layouts::{GGLWESwitchingKey, GLWESecret, Infos, LWESecret, LWESwitchingKey, prepared::GLWESecretPrepared}, + layouts::{ + Degree, GGLWELayoutInfos, GGLWESwitchingKey, GLWESecret, LWEInfos, LWESecret, LWESwitchingKey, Rank, + prepared::GLWESecretPrepared, + }, }; impl LWESwitchingKey> { - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize where + A: GGLWELayoutInfos, Module: SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes, { - GLWESecret::bytes_of(module.n(), 1) - + GLWESecretPrepared::bytes_of(module, 1) - + GGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, 1, 1) + debug_assert_eq!( + infos.digits().0, + 1, + "digits > 1 is not supported for LWESwitchingKey" + ); + debug_assert_eq!( + infos.rank_in().0, + 1, + "rank_in > 1 is not supported for LWESwitchingKey" + ); + debug_assert_eq!( + infos.rank_out().0, + 1, + "rank_out > 1 is not supported for LWESwitchingKey" + ); + GLWESecret::alloc_bytes_with(Degree(module.n() as u32), Rank(1)) + + GLWESecretPrepared::alloc_bytes_with(module, Rank(1)) + + GGLWESwitchingKey::encrypt_sk_scratch_space(module, infos) } } @@ -47,7 +66,7 @@ impl LWESwitchingKey { + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -60,20 +79,20 @@ impl LWESwitchingKey { { #[cfg(debug_assertions)] { - assert!(sk_lwe_in.n() <= self.n()); - assert!(sk_lwe_out.n() <= self.n()); - assert!(self.n() <= module.n()); + assert!(sk_lwe_in.n().0 <= self.n().0); + assert!(sk_lwe_out.n().0 <= self.n().0); + 
assert!(self.n().0 <= module.n() as u32); } - let (mut sk_in_glwe, scratch_1) = scratch.take_glwe_secret(self.n(), 1); - let (mut sk_out_glwe, scratch_2) = scratch_1.take_glwe_secret(self.n(), 1); + let (mut sk_in_glwe, scratch_1) = scratch.take_glwe_secret(self.n(), Rank(1)); + let (mut sk_out_glwe, scratch_2) = scratch_1.take_glwe_secret(self.n(), Rank(1)); - sk_out_glwe.data.at_mut(0, 0)[..sk_lwe_out.n()].copy_from_slice(sk_lwe_out.data.at(0, 0)); - sk_out_glwe.data.at_mut(0, 0)[sk_lwe_out.n()..].fill(0); + sk_out_glwe.data.at_mut(0, 0)[..sk_lwe_out.n().into()].copy_from_slice(sk_lwe_out.data.at(0, 0)); + sk_out_glwe.data.at_mut(0, 0)[sk_lwe_out.n().into()..].fill(0); module.vec_znx_automorphism_inplace(-1, &mut sk_out_glwe.data.as_vec_znx_mut(), 0, scratch_2); - sk_in_glwe.data.at_mut(0, 0)[..sk_lwe_in.n()].copy_from_slice(sk_lwe_in.data.at(0, 0)); - sk_in_glwe.data.at_mut(0, 0)[sk_lwe_in.n()..].fill(0); + sk_in_glwe.data.at_mut(0, 0)[..sk_lwe_in.n().into()].copy_from_slice(sk_lwe_in.data.at(0, 0)); + sk_in_glwe.data.at_mut(0, 0)[sk_lwe_in.n().into()..].fill(0); module.vec_znx_automorphism_inplace(-1, &mut sk_in_glwe.data.as_vec_znx_mut(), 0, scratch_2); self.0.encrypt_sk( diff --git a/poulpy-core/src/encryption/lwe_to_glwe_ksk.rs b/poulpy-core/src/encryption/lwe_to_glwe_ksk.rs index 95ea310..c97046a 100644 --- a/poulpy-core/src/encryption/lwe_to_glwe_ksk.rs +++ b/poulpy-core/src/encryption/lwe_to_glwe_ksk.rs @@ -3,7 +3,7 @@ use poulpy_hal::{ ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing, + VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, }, layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut}, source::Source, @@ -11,15 +11,22 @@ use poulpy_hal::{ use crate::{ TakeGLWESecret, TakeGLWESecretPrepared, - layouts::{GGLWESwitchingKey, GLWESecret, LWESecret, LWEToGLWESwitchingKey}, + layouts::{Degree, GGLWELayoutInfos, GGLWESwitchingKey, GLWESecret, LWEInfos, LWESecret, LWEToGLWESwitchingKey, Rank}, }; impl LWEToGLWESwitchingKey> { - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank_out: usize) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize where + A: GGLWELayoutInfos, Module: SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes, { - GGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, 1, rank_out) + GLWESecret::bytes_of(module.n(), 1) + debug_assert_eq!( + infos.rank_in(), + Rank(1), + "rank_in != 1 is not supported for LWEToGLWESwitchingKey" + ); + GGLWESwitchingKey::encrypt_sk_scratch_space(module, infos) + + GLWESecret::alloc_bytes_with(Degree(module.n() as u32), infos.rank_in()) } } @@ -45,7 +52,7 @@ impl LWEToGLWESwitchingKey { + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -58,12 +65,14 @@ impl LWEToGLWESwitchingKey { { #[cfg(debug_assertions)] { - assert!(sk_lwe.n() <= module.n()); + use crate::layouts::LWEInfos; + + assert!(sk_lwe.n().0 <= module.n() as u32); } - let (mut sk_lwe_as_glwe, scratch_1) = 
scratch.take_glwe_secret(sk_glwe.n(), 1); - sk_lwe_as_glwe.data.at_mut(0, 0)[..sk_lwe.n()].copy_from_slice(sk_lwe.data.at(0, 0)); - sk_lwe_as_glwe.data.at_mut(0, 0)[sk_lwe.n()..].fill(0); + let (mut sk_lwe_as_glwe, scratch_1) = scratch.take_glwe_secret(sk_glwe.n(), Rank(1)); + sk_lwe_as_glwe.data.at_mut(0, 0)[..sk_lwe.n().into()].copy_from_slice(sk_lwe.data.at(0, 0)); + sk_lwe_as_glwe.data.at_mut(0, 0)[sk_lwe.n().into()..].fill(0); module.vec_znx_automorphism_inplace(-1, &mut sk_lwe_as_glwe.data.as_vec_znx_mut(), 0, scratch_1); self.0.encrypt_sk( diff --git a/poulpy-core/src/external_product/gglwe_atk.rs b/poulpy-core/src/external_product/gglwe_atk.rs index 23a48c2..611f755 100644 --- a/poulpy-core/src/external_product/gglwe_atk.rs +++ b/poulpy-core/src/external_product/gglwe_atk.rs @@ -1,42 +1,41 @@ use poulpy_hal::{ api::{ - ScratchAvailable, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, - VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, + VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, + VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataRef, Module, Scratch}, }; -use crate::layouts::{GGLWEAutomorphismKey, GGLWESwitchingKey, prepared::GGSWCiphertextPrepared}; +use crate::layouts::{GGLWEAutomorphismKey, GGLWELayoutInfos, GGLWESwitchingKey, GGSWInfos, prepared::GGSWCiphertextPrepared}; impl GGLWEAutomorphismKey> { - #[allow(clippy::too_many_arguments)] - pub fn external_product_scratch_space( + pub fn external_product_scratch_space( module: &Module, - basek: usize, - k_out: usize, - k_in: usize, - ggsw_k: usize, - digits: usize, - rank: usize, + out_infos: &OUT, + in_infos: &IN, + ggsw_infos: &GGSW, ) -> usize where + OUT: GGLWELayoutInfos, + IN: GGLWELayoutInfos, + GGSW: GGSWInfos, Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes, { - GGLWESwitchingKey::external_product_scratch_space(module, basek, k_out, k_in, ggsw_k, digits, rank) + GGLWESwitchingKey::external_product_scratch_space(module, out_infos, in_infos, ggsw_infos) } - pub fn external_product_inplace_scratch_space( + pub fn external_product_inplace_scratch_space( module: &Module, - basek: usize, - k_out: usize, - ggsw_k: usize, - digits: usize, - rank: usize, + out_infos: &OUT, + ggsw_infos: &GGSW, ) -> usize where + OUT: GGLWELayoutInfos, + GGSW: GGSWInfos, Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes, { - GGLWESwitchingKey::external_product_inplace_scratch_space(module, basek, k_out, ggsw_k, digits, rank) + GGLWESwitchingKey::external_product_inplace_scratch_space(module, out_infos, ggsw_infos) } } @@ -55,8 +54,9 @@ impl GGLWEAutomorphismKey { + VmpApplyDftToDft + VmpApplyDftToDftAdd + VecZnxIdftApplyConsume - + VecZnxBigNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable, + + VecZnxBigNormalize + + VecZnxNormalize, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { self.key.external_product(module, &lhs.key, rhs, scratch); } @@ -74,8 +74,9 @@ impl GGLWEAutomorphismKey { + VmpApplyDftToDft + VmpApplyDftToDftAdd + VecZnxIdftApplyConsume - + VecZnxBigNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable, + + VecZnxBigNormalize + + VecZnxNormalize, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { self.key.external_product_inplace(module, rhs, scratch); } diff --git 
a/poulpy-core/src/external_product/gglwe_ksk.rs b/poulpy-core/src/external_product/gglwe_ksk.rs index 8a07977..8e49640 100644 --- a/poulpy-core/src/external_product/gglwe_ksk.rs +++ b/poulpy-core/src/external_product/gglwe_ksk.rs @@ -1,42 +1,46 @@ use poulpy_hal::{ api::{ - ScratchAvailable, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, - VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, + VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, + VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero}, }; -use crate::layouts::{GGLWESwitchingKey, GLWECiphertext, Infos, prepared::GGSWCiphertextPrepared}; +use crate::layouts::{GGLWELayoutInfos, GGLWESwitchingKey, GGSWInfos, GLWECiphertext, prepared::GGSWCiphertextPrepared}; impl GGLWESwitchingKey> { - #[allow(clippy::too_many_arguments)] - pub fn external_product_scratch_space( + pub fn external_product_scratch_space( module: &Module, - basek: usize, - k_out: usize, - k_in: usize, - k_ggsw: usize, - digits: usize, - rank: usize, + out_infos: &OUT, + in_infos: &IN, + ggsw_infos: &GGSW, ) -> usize where + OUT: GGLWELayoutInfos, + IN: GGLWELayoutInfos, + GGSW: GGSWInfos, Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes, { - GLWECiphertext::external_product_scratch_space(module, basek, k_out, k_in, k_ggsw, digits, rank) + GLWECiphertext::external_product_scratch_space( + module, + &out_infos.glwe_layout(), + &in_infos.glwe_layout(), + ggsw_infos, + ) } - pub fn external_product_inplace_scratch_space( + pub fn external_product_inplace_scratch_space( module: &Module, - basek: usize, - k_out: usize, - k_ggsw: usize, - digits: usize, - rank: usize, + out_infos: &OUT, + ggsw_infos: &GGSW, ) -> usize where + OUT: GGLWELayoutInfos, + GGSW: GGSWInfos, Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes, { - GLWECiphertext::external_product_inplace_scratch_space(module, basek, k_out, k_ggsw, digits, rank) + GLWECiphertext::external_product_inplace_scratch_space(module, &out_infos.glwe_layout(), ggsw_infos) } } @@ -55,11 +59,14 @@ impl GGLWESwitchingKey { + VmpApplyDftToDft + VmpApplyDftToDftAdd + VecZnxIdftApplyConsume - + VecZnxBigNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable, + + VecZnxBigNormalize + + VecZnxNormalize, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { #[cfg(debug_assertions)] { + use crate::layouts::GLWEInfos; + assert_eq!( self.rank_in(), lhs.rank_in(), @@ -83,15 +90,15 @@ impl GGLWESwitchingKey { ); } - (0..self.rank_in()).for_each(|col_i| { - (0..self.rows()).for_each(|row_j| { + (0..self.rank_in().into()).for_each(|col_i| { + (0..self.rows().into()).for_each(|row_j| { self.at_mut(row_j, col_i) .external_product(module, &lhs.at(row_j, col_i), rhs, scratch); }); }); - (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| { - (0..self.rank_in()).for_each(|col_j| { + (self.rows().min(lhs.rows()).into()..self.rows().into()).for_each(|row_i| { + (0..self.rank_in().into()).for_each(|col_j| { self.at_mut(row_i, col_j).data.zero(); }); }); @@ -110,11 +117,14 @@ impl GGLWESwitchingKey { + VmpApplyDftToDft + VmpApplyDftToDftAdd + VecZnxIdftApplyConsume - + VecZnxBigNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable, + + VecZnxBigNormalize + + VecZnxNormalize, + Scratch: 
TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { #[cfg(debug_assertions)] { + use crate::layouts::GLWEInfos; + assert_eq!( self.rank_out(), rhs.rank(), @@ -124,8 +134,8 @@ impl GGLWESwitchingKey { ); } - (0..self.rank_in()).for_each(|col_i| { - (0..self.rows()).for_each(|row_j| { + (0..self.rank_in().into()).for_each(|col_i| { + (0..self.rows().into()).for_each(|row_j| { self.at_mut(row_j, col_i) .external_product_inplace(module, rhs, scratch); }); diff --git a/poulpy-core/src/external_product/ggsw_ct.rs b/poulpy-core/src/external_product/ggsw_ct.rs index 0b72877..a327058 100644 --- a/poulpy-core/src/external_product/ggsw_ct.rs +++ b/poulpy-core/src/external_product/ggsw_ct.rs @@ -1,42 +1,47 @@ use poulpy_hal::{ api::{ - ScratchAvailable, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, - VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, + VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, + VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero}, }; -use crate::layouts::{GGSWCiphertext, GLWECiphertext, Infos, prepared::GGSWCiphertextPrepared}; +use crate::layouts::{GGSWCiphertext, GGSWInfos, GLWECiphertext, GLWEInfos, prepared::GGSWCiphertextPrepared}; impl GGSWCiphertext> { #[allow(clippy::too_many_arguments)] - pub fn external_product_scratch_space( + pub fn external_product_scratch_space( module: &Module, - basek: usize, - k_out: usize, - k_in: usize, - k_ggsw: usize, - digits: usize, - rank: usize, + out_infos: &OUT, + in_infos: &IN, + apply_infos: &GGSW, ) -> usize where + OUT: GGSWInfos, + IN: GGSWInfos, + GGSW: GGSWInfos, Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes, { - GLWECiphertext::external_product_scratch_space(module, basek, k_out, k_in, k_ggsw, digits, rank) + GLWECiphertext::external_product_scratch_space( + module, + &out_infos.glwe_layout(), + &in_infos.glwe_layout(), + apply_infos, + ) } - pub fn external_product_inplace_scratch_space( + pub fn external_product_inplace_scratch_space( module: &Module, - basek: usize, - k_out: usize, - k_ggsw: usize, - digits: usize, - rank: usize, + out_infos: &OUT, + apply_infos: &GGSW, ) -> usize where + OUT: GGSWInfos, + GGSW: GGSWInfos, Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes, { - GLWECiphertext::external_product_inplace_scratch_space(module, basek, k_out, k_ggsw, digits, rank) + GLWECiphertext::external_product_inplace_scratch_space(module, &out_infos.glwe_layout(), apply_infos) } } @@ -55,12 +60,13 @@ impl GGSWCiphertext { + VmpApplyDftToDft + VmpApplyDftToDftAdd + VecZnxIdftApplyConsume - + VecZnxBigNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable, + + VecZnxBigNormalize + + VecZnxNormalize, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { #[cfg(debug_assertions)] { - use crate::layouts::Infos; + use crate::layouts::LWEInfos; assert_eq!(lhs.n(), self.n()); assert_eq!(rhs.n(), self.n()); @@ -80,28 +86,17 @@ impl GGSWCiphertext { rhs.rank() ); - assert!( - scratch.available() - >= GGSWCiphertext::external_product_scratch_space( - module, - self.basek(), - self.k(), - lhs.k(), - rhs.k(), - rhs.digits(), - rhs.rank() - ) - ) + assert!(scratch.available() >= GGSWCiphertext::external_product_scratch_space(module, self, lhs, rhs)) } - let min_rows: usize = 
self.rows().min(lhs.rows()); + let min_rows: usize = self.rows().min(lhs.rows()).into(); - (0..self.rank() + 1).for_each(|col_i| { + (0..(self.rank() + 1).into()).for_each(|col_i| { (0..min_rows).for_each(|row_j| { self.at_mut(row_j, col_i) .external_product(module, &lhs.at(row_j, col_i), rhs, scratch); }); - (min_rows..self.rows()).for_each(|row_i| { + (min_rows..self.rows().into()).for_each(|row_i| { self.at_mut(row_i, col_i).data.zero(); }); }); @@ -120,11 +115,14 @@ impl GGSWCiphertext { + VmpApplyDftToDft + VmpApplyDftToDftAdd + VecZnxIdftApplyConsume - + VecZnxBigNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable, + + VecZnxBigNormalize + + VecZnxNormalize, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { #[cfg(debug_assertions)] { + use crate::layouts::LWEInfos; + assert_eq!(rhs.n(), self.n()); assert_eq!( self.rank(), @@ -135,8 +133,8 @@ impl GGSWCiphertext { ); } - (0..self.rank() + 1).for_each(|col_i| { - (0..self.rows()).for_each(|row_j| { + (0..(self.rank() + 1).into()).for_each(|col_i| { + (0..self.rows().into()).for_each(|row_j| { self.at_mut(row_j, col_i) .external_product_inplace(module, rhs, scratch); }); diff --git a/poulpy-core/src/external_product/glwe_ct.rs b/poulpy-core/src/external_product/glwe_ct.rs index 9164b96..a225976 100644 --- a/poulpy-core/src/external_product/glwe_ct.rs +++ b/poulpy-core/src/external_product/glwe_ct.rs @@ -1,56 +1,65 @@ use poulpy_hal::{ api::{ - ScratchAvailable, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, - VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, + VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, + VmpApplyDftToDftTmpBytes, }, - layouts::{Backend, DataMut, DataRef, DataViewMut, Module, Scratch, VecZnxBig}, + layouts::{Backend, DataMut, DataRef, DataViewMut, Module, Scratch, VecZnx, VecZnxBig}, }; -use crate::layouts::{GLWECiphertext, Infos, prepared::GGSWCiphertextPrepared}; +use crate::layouts::{GGSWInfos, GLWECiphertext, GLWEInfos, LWEInfos, prepared::GGSWCiphertextPrepared}; impl GLWECiphertext> { #[allow(clippy::too_many_arguments)] - pub fn external_product_scratch_space( + pub fn external_product_scratch_space( module: &Module, - basek: usize, - k_out: usize, - k_in: usize, - k_ggsw: usize, - digits: usize, - rank: usize, + out_infos: &OUT, + in_infos: &IN, + apply_infos: &GGSW, ) -> usize where + OUT: GLWEInfos, + IN: GLWEInfos, + GGSW: GGSWInfos, Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes, { - let in_size: usize = k_in.div_ceil(basek).div_ceil(digits); - let out_size: usize = k_out.div_ceil(basek); - let ggsw_size: usize = k_ggsw.div_ceil(basek); - let res_dft: usize = module.vec_znx_dft_alloc_bytes(rank + 1, ggsw_size); - let a_dft: usize = module.vec_znx_dft_alloc_bytes(rank + 1, in_size); + let in_size: usize = in_infos + .k() + .div_ceil(apply_infos.base2k()) + .div_ceil(apply_infos.digits().into()) as usize; + let out_size: usize = out_infos.size(); + let ggsw_size: usize = apply_infos.size(); + let res_dft: usize = module.vec_znx_dft_alloc_bytes((apply_infos.rank() + 1).into(), ggsw_size); + let a_dft: usize = module.vec_znx_dft_alloc_bytes((apply_infos.rank() + 1).into(), in_size); let vmp: usize = module.vmp_apply_dft_to_dft_tmp_bytes( out_size, in_size, - in_size, // rows - rank + 1, // cols in - rank + 1, // cols 
out + in_size, // rows + (apply_infos.rank() + 1).into(), // cols in + (apply_infos.rank() + 1).into(), // cols out ggsw_size, ); - let normalize: usize = module.vec_znx_normalize_tmp_bytes(); - res_dft + a_dft + (vmp | normalize) + let normalize_big: usize = module.vec_znx_normalize_tmp_bytes(); + + if in_infos.base2k() == apply_infos.base2k() { + res_dft + a_dft + (vmp | normalize_big) + } else { + let normalize_conv: usize = VecZnx::alloc_bytes(module.n(), (apply_infos.rank() + 1).into(), in_size); + res_dft + ((a_dft + normalize_conv + (module.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big) + } } - pub fn external_product_inplace_scratch_space( + pub fn external_product_inplace_scratch_space( module: &Module, - basek: usize, - k_out: usize, - k_ggsw: usize, - digits: usize, - rank: usize, + out_infos: &OUT, + apply_infos: &GGSW, ) -> usize where + OUT: GLWEInfos, + GGSW: GGSWInfos, Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes, { - Self::external_product_scratch_space(module, basek, k_out, k_out, k_ggsw, digits, rank) + Self::external_product_scratch_space(module, out_infos, out_infos, apply_infos) } } @@ -69,10 +78,13 @@ impl GLWECiphertext { + VmpApplyDftToDft + VmpApplyDftToDftAdd + VecZnxIdftApplyConsume - + VecZnxBigNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable, + + VecZnxBigNormalize + + VecZnxNormalize, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { - let basek: usize = self.basek(); + let basek_in: usize = lhs.base2k().into(); + let basek_ggsw: usize = rhs.base2k().into(); + let basek_out: usize = self.base2k().into(); #[cfg(debug_assertions)] { @@ -80,34 +92,22 @@ impl GLWECiphertext { assert_eq!(rhs.rank(), lhs.rank()); assert_eq!(rhs.rank(), self.rank()); - assert_eq!(self.basek(), basek); - assert_eq!(lhs.basek(), basek); assert_eq!(rhs.n(), self.n()); assert_eq!(lhs.n(), self.n()); - assert!( - scratch.available() - >= GLWECiphertext::external_product_scratch_space( - module, - self.basek(), - self.k(), - lhs.k(), - rhs.k(), - rhs.digits(), - rhs.rank(), - ) - ); + assert!(scratch.available() >= GLWECiphertext::external_product_scratch_space(module, self, lhs, rhs)); } - let cols: usize = rhs.rank() + 1; - let digits: usize = rhs.digits(); + let cols: usize = (rhs.rank() + 1).into(); + let digits: usize = rhs.digits().into(); - let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), cols, rhs.size()); // Todo optimise - let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self.n(), cols, lhs.size().div_ceil(digits)); + let a_size: usize = (lhs.size() * basek_in).div_ceil(basek_ggsw); + let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), cols, rhs.size()); // Todo optimise + let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self.n().into(), cols, a_size.div_ceil(digits)); a_dft.data_mut().fill(0); - { - (0..digits).for_each(|di| { + if basek_in == basek_ggsw { + for di in 0..digits { // (lhs.size() + di) / digits = (a - (digit - di - 1)).div_ceil(digits) a_dft.set_size((lhs.size() + di) / digits); @@ -120,22 +120,68 @@ impl GLWECiphertext { // noise is kept with respect to the ideal functionality. 
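// Editor's note (hedged): writing B = base2k and d = digits, the loop below
// aggregates
//     sum_{di=0}^{d-1} (VMP_di + e_di) * 2^{di*B},
// so the dominant error term is e_{d-1} * 2^{(d-1)*B}. Assuming output limb j
// carries weight 2^{-B*(j+1)}, the last d-2 limbs of the aggregate fall below
// that error floor, which is exactly what the res_dft.set_size(...) truncation
// exploits.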
res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize); - (0..cols).for_each(|col_i| { - module.vec_znx_dft_apply(digits, digits - 1 - di, &mut a_dft, col_i, &lhs.data, col_i); - }); + for j in 0..cols { + module.vec_znx_dft_apply(digits, digits - 1 - di, &mut a_dft, j, &lhs.data, j); + } if di == 0 { module.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_2); } else { module.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_2); } - }); + } + } else { + let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(module.n(), cols, a_size); + + for j in 0..cols { + module.vec_znx_normalize( + basek_ggsw, + &mut a_conv, + j, + basek_in, + &lhs.data, + j, + scratch_3, + ); + } + + for di in 0..digits { + // (a_size + di) / digits = (a - (digit - di - 1)).div_ceil(digits) + a_dft.set_size((a_size + di) / digits); + + // Small optimization for digits > 2 + // VMP produces some error e, and since we aggregate vmp * 2^{di * B}, then + // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}. + // As such we can safely ignore the last digits-2 limbs of the sum of vmp products. + // It is possible to further ignore the last digits-1 limbs, but this introduces + // ~0.5 to 1 bit of additional noise, and is thus not chosen here to ensure that the same + // noise is kept with respect to the ideal functionality. + res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize); + + for j in 0..cols { + module.vec_znx_dft_apply(digits, digits - 1 - di, &mut a_dft, j, &a_conv, j); + } + + if di == 0 { + module.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_3); + } else { + module.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_3); + } + } } let res_big: VecZnxBig<&mut [u8], B> = module.vec_znx_idft_apply_consume(res_dft); (0..cols).for_each(|i| { - module.vec_znx_big_normalize(basek, &mut self.data, i, &res_big, i, scratch_1); + module.vec_znx_big_normalize( + basek_out, + &mut self.data, + i, + basek_ggsw, + &res_big, + i, + scratch_1, + ); }); } @@ -152,42 +198,32 @@ impl GLWECiphertext { + VmpApplyDftToDft + VmpApplyDftToDftAdd + VecZnxIdftApplyConsume - + VecZnxBigNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable, + + VecZnxBigNormalize + + VecZnxNormalize, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { - let basek: usize = self.basek(); + let basek_in: usize = self.base2k().into(); + let basek_ggsw: usize = rhs.base2k().into(); #[cfg(debug_assertions)] { use poulpy_hal::api::ScratchAvailable; assert_eq!(rhs.rank(), self.rank()); - assert_eq!(self.basek(), basek); assert_eq!(rhs.n(), self.n()); - assert!( - scratch.available() - >= GLWECiphertext::external_product_scratch_space( - module, - self.basek(), - self.k(), - self.k(), - rhs.k(), - rhs.digits(), - rhs.rank(), - ) - ); + assert!(scratch.available() >= GLWECiphertext::external_product_inplace_scratch_space(module, self, rhs)); } - let cols: usize = rhs.rank() + 1; - let digits: usize = rhs.digits(); - - let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), cols, rhs.size()); // Todo optimise - let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self.n(), cols, self.size().div_ceil(digits)); + let cols: usize = (rhs.rank() + 1).into(); + let digits: usize = rhs.digits().into(); + let a_size: usize = (self.size() * basek_in).div_ceil(basek_ggsw); + let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), cols, rhs.size()); // Todo optimise + let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self.n().into(), cols, a_size.div_ceil(digits)); a_dft.data_mut().fill(0); - { - (0..digits).for_each(|di| { + if basek_in == basek_ggsw { + for di in 0..digits { // (lhs.size() + di) / digits = (a - (digit - di - 1)).div_ceil(digits) a_dft.set_size((self.size() + di) / digits); @@ -200,29 +236,68 @@ impl GLWECiphertext { // noise is kept with respect to the ideal functionality. res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize); - (0..cols).for_each(|col_i| { - module.vec_znx_dft_apply( - digits, - digits - 1 - di, - &mut a_dft, - col_i, - &self.data, - col_i, - ); - }); + for j in 0..cols { + module.vec_znx_dft_apply(digits, digits - 1 - di, &mut a_dft, j, &self.data, j); + } if di == 0 { module.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_2); } else { module.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_2); } - }); + } + } else { + let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(module.n(), cols, a_size); + + for j in 0..cols { + module.vec_znx_normalize( + basek_ggsw, + &mut a_conv, + j, + basek_in, + &self.data, + j, + scratch_3, + ); + } + + for di in 0..digits { + // (a_size + di) / digits = (a - (digit - di - 1)).div_ceil(digits) + a_dft.set_size((a_size + di) / digits); + + // Small optimization for digits > 2 + // VMP produces some error e, and since we aggregate vmp * 2^{di * B}, then + // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}. + // As such we can safely ignore the last digits-2 limbs of the sum of vmp products. + // It is possible to further ignore the last digits-1 limbs, but this introduces + // ~0.5 to 1 bit of additional noise, and is thus not chosen here to ensure that the same + // noise is kept with respect to the ideal functionality. + res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize); + + for j in 0..cols { + module.vec_znx_dft_apply(digits, digits - 1 - di, &mut a_dft, j, &a_conv, j); + } + + if di == 0 { + module.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_3); + } else { + module.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_3); + } + } } let res_big: VecZnxBig<&mut [u8], B> = module.vec_znx_idft_apply_consume(res_dft); - (0..cols).for_each(|i| { - module.vec_znx_big_normalize(basek, &mut self.data, i, &res_big, i, scratch_1); - }); + for j in 0..cols { + module.vec_znx_big_normalize( + basek_in, + &mut self.data, + j, + basek_ggsw, + &res_big, + j, + scratch_1, + ); + } } } diff --git a/poulpy-core/src/glwe_packing.rs b/poulpy-core/src/glwe_packing.rs index c7c1f64..93f84aa 100644 --- a/poulpy-core/src/glwe_packing.rs +++ b/poulpy-core/src/glwe_packing.rs @@ -3,17 +3,17 @@ use std::collections::HashMap; use poulpy_hal::{ api::{ ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, - VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace, VecZnxCopy, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNegateInplace, VecZnxNormalizeInplace, VecZnxRotate, - VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubABInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd, - VmpApplyDftToDftTmpBytes, + VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallNegateInplace, VecZnxCopy, + VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNegateInplace, VecZnxNormalize, + VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, + VecZnxSubInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataRef, Module, Scratch}, }; use crate::{ GLWEOperations, TakeGLWECt, - layouts::{GLWECiphertext, Infos, prepared::GGLWEAutomorphismKeyPrepared}, + layouts::{GGLWELayoutInfos, GLWECiphertext, GLWEInfos, LWEInfos, prepared::GGLWEAutomorphismKeyPrepared}, }; /// [GLWEPacker] enables on-the-fly GLWE packing @@ -40,12 +40,15 @@ impl Accumulator { /// # Arguments /// /// * `module`: static backend FFT tables. - /// * `basek`: base 2 logarithm of the GLWE ciphertext in memory digit representation. + /// * `base2k`: base 2 logarithm of the GLWE ciphertext in memory digit representation. /// * `k`: base 2 precision of the GLWE ciphertext over the Torus. /// * `rank`: rank of the GLWE ciphertext. - pub fn alloc(n: usize, basek: usize, k: usize, rank: usize) -> Self { + pub fn alloc(infos: &A) -> Self + where + A: GLWEInfos, + { Self { - data: GLWECiphertext::alloc(n, basek, k, rank), + data: GLWECiphertext::alloc(infos), value: false, control: false, } @@ -63,13 +66,13 @@ impl GLWEPacker { /// and N GLWE ciphertexts can be packed. With `log_batch=2` all coefficients /// which are multiples of X^{N/4} are packed. Meaning that N/4 ciphertexts /// can be packed. - /// * `basek`: base 2 logarithm of the GLWE ciphertext in memory digit representation. - /// * `k`: base 2 precision of the GLWE ciphertext over the Torus. - /// * `rank`: rank of the GLWE ciphertext.
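// Editor's note (hedged sketch): a typical construction with the refactored
// signature, reusing the `GLWECiphertextLayout` introduced elsewhere in this PR.
// The numeric values and the `Base2K`/`TorusPrecision` wrappers are illustrative
// assumptions, not confirmed by this diff:
//
//     let layout = GLWECiphertextLayout {
//         n: module.n().into(),
//         base2k: Base2K(18),
//         k: TorusPrecision(54),
//         rank: Rank(2),
//     };
//     // log_batch = 0: coefficient-wise packing, up to N ciphertexts.
//     let mut packer = GLWEPacker::new(&layout, 0);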
- pub fn new(n: usize, log_batch: usize, basek: usize, k: usize, rank: usize) -> Self { + pub fn new(infos: &A, log_batch: usize) -> Self + where + A: GLWEInfos, + { let mut accumulators: Vec = Vec::::new(); - let log_n: usize = (usize::BITS - (n - 1).leading_zeros()) as _; - (0..log_n - log_batch).for_each(|_| accumulators.push(Accumulator::alloc(n, basek, k, rank))); + let log_n: usize = infos.n().log2(); + (0..log_n - log_batch).for_each(|_| accumulators.push(Accumulator::alloc(infos))); Self { accumulators, log_batch, @@ -87,18 +90,13 @@ impl GLWEPacker { } /// Number of scratch space bytes required to call [Self::add]. - pub fn scratch_space( - module: &Module, - basek: usize, - ct_k: usize, - k_ksk: usize, - digits: usize, - rank: usize, - ) -> usize + pub fn scratch_space(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize where - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes, + OUT: GLWEInfos, + KEY: GGLWELayoutInfos, + Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { - pack_core_scratch_space(module, basek, ct_k, k_ksk, digits, rank) + pack_core_scratch_space(module, out_infos, key_infos) } pub fn galois_elements(module: &Module) -> Vec { @@ -137,17 +135,19 @@ impl GLWEPacker { + VecZnxRshInplace + VecZnxAddInplace + VecZnxNormalizeInplace - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxRotate + VecZnxAutomorphismInplace - + VecZnxBigSubSmallBInplace - + VecZnxBigAutomorphismInplace, + + VecZnxBigSubSmallNegateInplace + + VecZnxBigAutomorphismInplace + + VecZnxNormalize + + VecZnxNormalizeTmpBytes, Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { assert!( - self.counter < self.accumulators[0].data.n(), + (self.counter as u32) < self.accumulators[0].data.n(), "Packing limit of {} reached", - self.accumulators[0].data.n() >> self.log_batch + self.accumulators[0].data.n().0 as usize >> self.log_batch ); pack_core( @@ -166,7 +166,7 @@ impl GLWEPacker { where Module: VecZnxCopy, { - assert!(self.counter == self.accumulators[0].data.n()); + assert!(self.counter as u32 == self.accumulators[0].data.n()); // Copy result GLWE into res GLWE res.copy( module, @@ -177,18 +177,13 @@ impl GLWEPacker { } } -fn pack_core_scratch_space( - module: &Module, - basek: usize, - ct_k: usize, - k_ksk: usize, - digits: usize, - rank: usize, -) -> usize +fn pack_core_scratch_space(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize where - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes, + OUT: GLWEInfos, + KEY: GGLWELayoutInfos, + Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { - combine_scratch_space(module, basek, ct_k, k_ksk, digits, rank) + combine_scratch_space(module, out_infos, key_infos) } fn pack_core( @@ -215,11 +210,13 @@ fn pack_core( + VecZnxRshInplace + VecZnxAddInplace + VecZnxNormalizeInplace - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxRotate + VecZnxAutomorphismInplace - + VecZnxBigSubSmallBInplace - + VecZnxBigAutomorphismInplace, + + VecZnxBigSubSmallNegateInplace + + VecZnxBigAutomorphismInplace + + VecZnxNormalize + + VecZnxNormalizeTmpBytes, Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { let log_n: usize = module.log_n(); @@ -271,20 +268,15 @@ fn pack_core( } } -fn combine_scratch_space( - module: &Module, - basek: usize, - ct_k: usize, - k_ksk: usize, - digits: usize, - rank: usize, -) -> usize +fn combine_scratch_space(module: &Module, 
out_infos: &OUT, key_infos: &KEY) -> usize where - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes, + OUT: GLWEInfos, + KEY: GGLWELayoutInfos, + Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { - GLWECiphertext::bytes_of(module.n(), basek, ct_k, rank) + GLWECiphertext::alloc_bytes(out_infos) + (GLWECiphertext::rsh_scratch_space(module.n()) - | GLWECiphertext::automorphism_scratch_space(module, basek, ct_k, ct_k, k_ksk, digits, rank)) + | GLWECiphertext::automorphism_inplace_scratch_space(module, out_infos, key_infos)) } /// [combine] merges two ciphertexts together. @@ -312,19 +304,17 @@ fn combine( + VecZnxRshInplace + VecZnxAddInplace + VecZnxNormalizeInplace - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxRotate + VecZnxAutomorphismInplace - + VecZnxBigSubSmallBInplace - + VecZnxBigAutomorphismInplace, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + + VecZnxBigSubSmallNegateInplace + + VecZnxBigAutomorphismInplace + + VecZnxNormalize + + VecZnxNormalizeTmpBytes, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeGLWECt, { - let n: usize = acc.data.n(); - let log_n: usize = (u64::BITS - (n - 1).leading_zeros()) as _; + let log_n: usize = acc.data.n().log2(); let a: &mut GLWECiphertext> = &mut acc.data; - let basek: usize = a.basek(); - let k: usize = a.k(); - let rank: usize = a.rank(); let gal_el: i64 = if i == 0 { -1 @@ -346,7 +336,7 @@ fn combine( // since 2*(I(X) * Q/2) = I(X) * Q = 0 mod Q. if acc.value { if let Some(b) = b { - let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(n, basek, k, rank); + let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(a); // a = a * X^-t a.rotate_inplace(module, -t, scratch_1); @@ -365,7 +355,7 @@ fn combine( if let Some(key) = auto_keys.get(&gal_el) { tmp_b.automorphism_inplace(module, key, scratch_1); } else { - panic!("auto_key[{}] not found", gal_el); + panic!("auto_key[{gal_el}] not found"); } // a = a * X^-t + b - phi(a * X^-t - b) @@ -382,19 +372,19 @@ fn combine( if let Some(key) = auto_keys.get(&gal_el) { a.automorphism_add_inplace(module, key, scratch); } else { - panic!("auto_key[{}] not found", gal_el); + panic!("auto_key[{gal_el}] not found"); } } } else if let Some(b) = b { - let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(n, basek, k, rank); + let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(a); tmp_b.rotate(module, 1 << (log_n - i - 1), b); tmp_b.rsh(module, 1, scratch_1); // a = (b* X^t - phi(b* X^t)) if let Some(key) = auto_keys.get(&gal_el) { - a.automorphism_sub_ba(module, &tmp_b, key, scratch_1); + a.automorphism_sub_negate(module, &tmp_b, key, scratch_1); } else { - panic!("auto_key[{}] not found", gal_el); + panic!("auto_key[{gal_el}] not found"); } acc.value = true; diff --git a/poulpy-core/src/glwe_trace.rs b/poulpy-core/src/glwe_trace.rs index 1c5c428..48471ff 100644 --- a/poulpy-core/src/glwe_trace.rs +++ b/poulpy-core/src/glwe_trace.rs @@ -2,15 +2,19 @@ use std::collections::HashMap; use poulpy_hal::{ api::{ - ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigAutomorphismInplace, VecZnxBigNormalize, - VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxRshInplace, - VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigAutomorphismInplace, VecZnxBigNormalize, + VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAllocBytes, 
VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalize, + VecZnxNormalizeTmpBytes, VecZnxRshInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, - layouts::{Backend, DataMut, DataRef, Module, Scratch}, + layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx}, }; use crate::{ - layouts::{GLWECiphertext, prepared::GGLWEAutomorphismKeyPrepared}, + TakeGLWECt, + layouts::{ + Base2K, GGLWELayoutInfos, GLWECiphertext, GLWECiphertextLayout, GLWEInfos, LWEInfos, + prepared::GGLWEAutomorphismKeyPrepared, + }, operations::GLWEOperations, }; @@ -27,34 +31,38 @@ impl GLWECiphertext> { gal_els } - #[allow(clippy::too_many_arguments)] - pub fn trace_scratch_space( + pub fn trace_scratch_space( module: &Module, - basek: usize, - out_k: usize, - in_k: usize, - ksk_k: usize, - digits: usize, - rank: usize, + out_infos: &OUT, + in_infos: &IN, + key_infos: &KEY, ) -> usize where - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes, + OUT: GLWEInfos, + IN: GLWEInfos, + KEY: GGLWELayoutInfos, + Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { - Self::automorphism_inplace_scratch_space(module, basek, out_k.min(in_k), ksk_k, digits, rank) + let trace: usize = Self::automorphism_inplace_scratch_space(module, out_infos, key_infos); + if in_infos.base2k() != key_infos.base2k() { + let glwe_conv: usize = VecZnx::alloc_bytes( + module.n(), + (key_infos.rank_out() + 1).into(), + out_infos.k().min(in_infos.k()).div_ceil(key_infos.base2k()) as usize, + ) + module.vec_znx_normalize_tmp_bytes(); + return glwe_conv + trace; + } + + trace } - pub fn trace_inplace_scratch_space( - module: &Module, - basek: usize, - out_k: usize, - ksk_k: usize, - digits: usize, - rank: usize, - ) -> usize + pub fn trace_inplace_scratch_space(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize where - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes, + OUT: GLWEInfos, + KEY: GGLWELayoutInfos, + Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { - Self::automorphism_inplace_scratch_space(module, basek, out_k, ksk_k, digits, rank) + Self::trace_scratch_space(module, out_infos, out_infos, key_infos) } } @@ -79,8 +87,10 @@ impl GLWECiphertext { + VecZnxBigNormalize + VecZnxBigAutomorphismInplace + VecZnxRshInplace - + VecZnxCopy, - Scratch: TakeVecZnxDft + ScratchAvailable, + + VecZnxCopy + + VecZnxNormalizeTmpBytes + + VecZnxNormalize, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { self.copy(module, lhs); self.trace_inplace(module, start, end, auto_keys, scratch); @@ -104,23 +114,92 @@ impl GLWECiphertext { + VecZnxBigAddSmallInplace + VecZnxBigNormalize + VecZnxBigAutomorphismInplace - + VecZnxRshInplace, - Scratch: TakeVecZnxDft + ScratchAvailable, + + VecZnxRshInplace + + VecZnxNormalizeTmpBytes + + VecZnxNormalize, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { - (start..end).for_each(|i| { - self.rsh(module, 1, scratch); + let basek_ksk: Base2K = auto_keys + .get(auto_keys.keys().next().unwrap()) + .unwrap() + .base2k(); - let p: i64 = if i == 0 { - -1 - } else { - module.galois_element(1 << (i - 1)) - }; - - if let Some(key) = auto_keys.get(&p) { - self.automorphism_add_inplace(module, key, scratch); - } else { - panic!("auto_keys[{}] is empty", p) + #[cfg(debug_assertions)] + { + assert_eq!(self.n(), module.n() as u32); + assert!(start < end); + assert!(end <= module.log_n()); + for 
key in auto_keys.values() { + assert_eq!(key.n(), module.n() as u32); + assert_eq!(key.base2k(), basek_ksk); + assert_eq!(key.rank_in(), self.rank()); + assert_eq!(key.rank_out(), self.rank()); } - }); + } + + if self.base2k() != basek_ksk { + let (mut self_conv, scratch_1) = scratch.take_glwe_ct(&GLWECiphertextLayout { + n: module.n().into(), + base2k: basek_ksk, + k: self.k(), + rank: self.rank(), + }); + + for j in 0..(self.rank() + 1).into() { + module.vec_znx_normalize( + basek_ksk.into(), + &mut self_conv.data, + j, + basek_ksk.into(), + &self.data, + j, + scratch_1, + ); + } + + for i in start..end { + self_conv.rsh(module, 1, scratch_1); + + let p: i64 = if i == 0 { + -1 + } else { + module.galois_element(1 << (i - 1)) + }; + + if let Some(key) = auto_keys.get(&p) { + self_conv.automorphism_add_inplace(module, key, scratch_1); + } else { + panic!("auto_keys[{p}] is empty") + } + } + + for j in 0..(self.rank() + 1).into() { + module.vec_znx_normalize( + self.base2k().into(), + &mut self.data, + j, + basek_ksk.into(), + &self_conv.data, + j, + scratch_1, + ); + } + } else { + for i in start..end { + self.rsh(module, 1, scratch); + + let p: i64 = if i == 0 { + -1 + } else { + module.galois_element(1 << (i - 1)) + }; + + if let Some(key) = auto_keys.get(&p) { + self.automorphism_add_inplace(module, key, scratch); + } else { + panic!("auto_keys[{p}] is empty") + } + } + } } } diff --git a/poulpy-core/src/keyswitching/gglwe_ct.rs b/poulpy-core/src/keyswitching/gglwe_ct.rs index 21ea399..346dbdf 100644 --- a/poulpy-core/src/keyswitching/gglwe_ct.rs +++ b/poulpy-core/src/keyswitching/gglwe_ct.rs @@ -1,46 +1,40 @@ use poulpy_hal::{ api::{ - ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VmpApplyDftToDft, VmpApplyDftToDftAdd, - VmpApplyDftToDftTmpBytes, + ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, + VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, + VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero}, }; use crate::layouts::{ - GGLWEAutomorphismKey, GGLWESwitchingKey, GLWECiphertext, Infos, + GGLWEAutomorphismKey, GGLWELayoutInfos, GGLWESwitchingKey, GLWECiphertext, GLWEInfos, prepared::{GGLWEAutomorphismKeyPrepared, GGLWESwitchingKeyPrepared}, }; impl GGLWEAutomorphismKey> { - #[allow(clippy::too_many_arguments)] - pub fn keyswitch_scratch_space( + pub fn keyswitch_scratch_space( module: &Module, - basek: usize, - k_out: usize, - k_in: usize, - k_ksk: usize, - digits: usize, - rank: usize, + out_infos: &OUT, + in_infos: &IN, + key_infos: &KEY, ) -> usize where - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes, + OUT: GGLWELayoutInfos, + IN: GGLWELayoutInfos, + KEY: GGLWELayoutInfos, + Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { - GGLWESwitchingKey::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank, rank) + GGLWESwitchingKey::keyswitch_scratch_space(module, out_infos, in_infos, key_infos) } - pub fn keyswitch_inplace_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_ksk: usize, - digits: usize, - rank: usize, - ) -> usize + pub fn keyswitch_inplace_scratch_space(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize 
where - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes, + OUT: GGLWELayoutInfos, + KEY: GGLWELayoutInfos, + Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { - GGLWESwitchingKey::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank) + GGLWESwitchingKey::keyswitch_inplace_scratch_space(module, out_infos, key_infos) } } @@ -60,8 +54,10 @@ impl GGLWEAutomorphismKey { + VecZnxDftApply + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace - + VecZnxBigNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable, + + VecZnxBigNormalize + + VecZnxNormalize + + VecZnxNormalizeTmpBytes, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { self.key.keyswitch(module, &lhs.key, rhs, scratch); } @@ -80,43 +76,38 @@ impl GGLWEAutomorphismKey { + VecZnxDftApply + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace - + VecZnxBigNormalize, - Scratch: TakeVecZnxDft + ScratchAvailable, + + VecZnxBigNormalize + + VecZnxNormalize + + VecZnxNormalizeTmpBytes, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { self.key.keyswitch_inplace(module, &rhs.key, scratch); } } impl GGLWESwitchingKey> { - #[allow(clippy::too_many_arguments)] - pub fn keyswitch_scratch_space( + pub fn keyswitch_scratch_space( module: &Module, - basek: usize, - k_out: usize, - k_in: usize, - k_ksk: usize, - digits: usize, - rank_in: usize, - rank_out: usize, + out_infos: &OUT, + in_infos: &IN, + key_apply: &KEY, ) -> usize where - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes, + OUT: GGLWELayoutInfos, + IN: GGLWELayoutInfos, + KEY: GGLWELayoutInfos, + Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { - GLWECiphertext::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank_in, rank_out) + GLWECiphertext::keyswitch_scratch_space(module, out_infos, in_infos, key_apply) } - pub fn keyswitch_inplace_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_ksk: usize, - digits: usize, - rank: usize, - ) -> usize + pub fn keyswitch_inplace_scratch_space(module: &Module, out_infos: &OUT, key_apply: &KEY) -> usize where - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes, + OUT: GGLWELayoutInfos + GLWEInfos, + KEY: GGLWELayoutInfos + GLWEInfos, + Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { - GLWECiphertext::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank) + GLWECiphertext::keyswitch_inplace_scratch_space(module, out_infos, key_apply) } } @@ -136,8 +127,10 @@ impl GGLWESwitchingKey { + VecZnxDftApply + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace - + VecZnxBigNormalize, - Scratch: ScratchAvailable + TakeVecZnxDft, + + VecZnxBigNormalize + + VecZnxNormalize + + VecZnxNormalizeTmpBytes, + Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnx, { #[cfg(debug_assertions)] { @@ -168,17 +161,24 @@ impl GGLWESwitchingKey { self.rows(), lhs.rows() ); + assert_eq!( + self.digits(), + lhs.digits(), + "ksk_out digits: {} != ksk_in digits: {}", + self.digits(), + lhs.digits() + ) } - (0..self.rank_in()).for_each(|col_i| { - (0..self.rows()).for_each(|row_j| { + (0..self.rank_in().into()).for_each(|col_i| { + (0..self.rows().into()).for_each(|row_j| { self.at_mut(row_j, col_i) .keyswitch(module, &lhs.at(row_j, col_i), rhs, scratch); }); }); - 
(self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| { - (0..self.rank_in()).for_each(|col_j| { + (self.rows().min(lhs.rows()).into()..self.rows().into()).for_each(|row_i| { + (0..self.rank_in().into()).for_each(|col_j| { self.at_mut(row_i, col_j).data.zero(); }); }); @@ -198,8 +198,10 @@ impl GGLWESwitchingKey { + VecZnxDftApply + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace - + VecZnxBigNormalize, - Scratch: ScratchAvailable + TakeVecZnxDft, + + VecZnxBigNormalize + + VecZnxNormalize + + VecZnxNormalizeTmpBytes, + Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnx, { #[cfg(debug_assertions)] { @@ -212,8 +214,8 @@ impl GGLWESwitchingKey { ); } - (0..self.rank_in()).for_each(|col_i| { - (0..self.rows()).for_each(|row_j| { + (0..self.rank_in().into()).for_each(|col_i| { + (0..self.rows().into()).for_each(|row_j| { self.at_mut(row_j, col_i) .keyswitch_inplace(module, rhs, scratch) }); diff --git a/poulpy-core/src/keyswitching/ggsw_ct.rs b/poulpy-core/src/keyswitching/ggsw_ct.rs index 6e9d52e..0f05fd1 100644 --- a/poulpy-core/src/keyswitching/ggsw_ct.rs +++ b/poulpy-core/src/keyswitching/ggsw_ct.rs @@ -1,101 +1,115 @@ use poulpy_hal::{ api::{ - ScratchAvailable, TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, - VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy, - VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, - VmpApplyDftToDftTmpBytes, + ScratchAvailable, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, + VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftApply, + VecZnxDftCopy, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, + VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx, VmpPMat, ZnxInfos}, }; use crate::{ layouts::{ - GGLWECiphertext, GGSWCiphertext, GLWECiphertext, Infos, + GGLWECiphertext, GGLWELayoutInfos, GGSWCiphertext, GGSWInfos, GLWECiphertext, GLWEInfos, LWEInfos, prepared::{GGLWESwitchingKeyPrepared, GGLWETensorKeyPrepared}, }, operations::GLWEOperations, }; impl GGSWCiphertext> { - pub(crate) fn expand_row_scratch_space( - module: &Module, - basek: usize, - self_k: usize, - k_tsk: usize, - digits: usize, - rank: usize, - ) -> usize + pub(crate) fn expand_row_scratch_space(module: &Module, out_infos: &OUT, tsk_infos: &TSK) -> usize where + OUT: GGSWInfos, + TSK: GGLWELayoutInfos, Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigAllocBytes + VecZnxNormalizeTmpBytes, { - let tsk_size: usize = k_tsk.div_ceil(basek); - let self_size_out: usize = self_k.div_ceil(basek); - let self_size_in: usize = self_size_out.div_ceil(digits); + let tsk_size: usize = tsk_infos.k().div_ceil(tsk_infos.base2k()) as usize; + let size_in: usize = out_infos + .k() + .div_ceil(tsk_infos.base2k()) + .div_ceil(tsk_infos.digits().into()) as usize; - let tmp_dft_i: usize = module.vec_znx_dft_alloc_bytes(rank + 1, tsk_size); - let tmp_a: usize = module.vec_znx_dft_alloc_bytes(1, self_size_in); + let tmp_dft_i: usize = module.vec_znx_dft_alloc_bytes((tsk_infos.rank_out() + 1).into(), tsk_size); + let tmp_a: usize = module.vec_znx_dft_alloc_bytes(1, size_in); let vmp: usize = module.vmp_apply_dft_to_dft_tmp_bytes( - self_size_out, - self_size_in, - self_size_in, - rank, - rank, + tsk_size, + 
size_in, + size_in, + (tsk_infos.rank_in()).into(), // Verify if rank+1 + (tsk_infos.rank_out()).into(), // Verify if rank+1 tsk_size, ); let tmp_idft: usize = module.vec_znx_big_alloc_bytes(1, tsk_size); let norm: usize = module.vec_znx_normalize_tmp_bytes(); + tmp_dft_i + ((tmp_a + vmp) | (tmp_idft + norm)) } #[allow(clippy::too_many_arguments)] - pub fn keyswitch_scratch_space( + pub fn keyswitch_scratch_space( module: &Module, - basek: usize, - k_out: usize, - k_in: usize, - k_ksk: usize, - digits_ksk: usize, - k_tsk: usize, - digits_tsk: usize, - rank: usize, + out_infos: &OUT, + in_infos: &IN, + apply_infos: &KEY, + tsk_infos: &TSK, ) -> usize where + OUT: GGSWInfos, + IN: GGSWInfos, + KEY: GGLWELayoutInfos, + TSK: GGLWELayoutInfos, Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigAllocBytes + VecZnxNormalizeTmpBytes + VecZnxBigNormalizeTmpBytes, { - let out_size: usize = k_out.div_ceil(basek); - let res_znx: usize = VecZnx::alloc_bytes(module.n(), rank + 1, out_size); - let ci_dft: usize = module.vec_znx_dft_alloc_bytes(rank + 1, out_size); - let ks: usize = GLWECiphertext::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits_ksk, rank, rank); - let expand_rows: usize = GGSWCiphertext::expand_row_scratch_space(module, basek, k_out, k_tsk, digits_tsk, rank); - let res_dft: usize = module.vec_znx_dft_alloc_bytes(rank + 1, out_size); - res_znx + ci_dft + (ks | expand_rows | res_dft) + #[cfg(debug_assertions)] + { + assert_eq!(apply_infos.rank_in(), apply_infos.rank_out()); + assert_eq!(tsk_infos.rank_in(), tsk_infos.rank_out()); + assert_eq!(apply_infos.rank_in(), tsk_infos.rank_in()); + } + + let rank: usize = apply_infos.rank_out().into(); + + let size_out: usize = out_infos.k().div_ceil(out_infos.base2k()) as usize; + let res_znx: usize = VecZnx::alloc_bytes(module.n(), rank + 1, size_out); + let ci_dft: usize = module.vec_znx_dft_alloc_bytes(rank + 1, size_out); + let ks: usize = GLWECiphertext::keyswitch_scratch_space(module, out_infos, in_infos, apply_infos); + let expand_rows: usize = GGSWCiphertext::expand_row_scratch_space(module, out_infos, tsk_infos); + let res_dft: usize = module.vec_znx_dft_alloc_bytes(rank + 1, size_out); + + if in_infos.base2k() == tsk_infos.base2k() { + res_znx + ci_dft + (ks | expand_rows | res_dft) + } else { + let a_conv: usize = VecZnx::alloc_bytes( + module.n(), + 1, + out_infos.k().div_ceil(tsk_infos.base2k()) as usize, + ) + module.vec_znx_normalize_tmp_bytes(); + res_znx + ci_dft + (a_conv | ks | expand_rows | res_dft) + } } #[allow(clippy::too_many_arguments)] - pub fn keyswitch_inplace_scratch_space( + pub fn keyswitch_inplace_scratch_space( module: &Module, - basek: usize, - k_out: usize, - k_ksk: usize, - digits_ksk: usize, - k_tsk: usize, - digits_tsk: usize, - rank: usize, + out_infos: &OUT, + apply_infos: &KEY, + tsk_infos: &TSK, ) -> usize where + OUT: GGSWInfos, + KEY: GGLWELayoutInfos, + TSK: GGLWELayoutInfos, Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigAllocBytes + VecZnxNormalizeTmpBytes + VecZnxBigNormalizeTmpBytes, { - GGSWCiphertext::keyswitch_scratch_space( - module, basek, k_out, k_out, k_ksk, digits_ksk, k_tsk, digits_tsk, rank, - ) + GGSWCiphertext::keyswitch_scratch_space(module, out_infos, out_infos, apply_infos, tsk_infos) } } @@ -120,18 +134,21 @@ impl GGSWCiphertext { + VmpApplyDftToDftAdd + VecZnxDftAddInplace + VecZnxBigNormalize - + VecZnxIdftApplyTmpA, - Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnxBig, + + VecZnxIdftApplyTmpA + + VecZnxNormalize, + Scratch: 
ScratchAvailable + TakeVecZnxDft + TakeVecZnxBig + TakeVecZnx, { #[cfg(debug_assertions)] { - assert_eq!(self.rank(), a.rank()); + use crate::layouts::{GLWEInfos, LWEInfos}; + + assert_eq!(self.rank(), a.rank_out()); assert_eq!(self.rows(), a.rows()); - assert_eq!(self.n(), module.n()); - assert_eq!(a.n(), module.n()); - assert_eq!(tsk.n(), module.n()); + assert_eq!(self.n(), module.n() as u32); + assert_eq!(a.n(), module.n() as u32); + assert_eq!(tsk.n(), module.n() as u32); } - (0..self.rows()).for_each(|row_i| { + (0..self.rows().into()).for_each(|row_i| { self.at_mut(row_i, 0).copy(module, &a.at(row_i, 0)); }); self.expand_row(module, tsk, scratch); @@ -159,10 +176,11 @@ impl GGSWCiphertext { + VecZnxNormalizeTmpBytes + VecZnxDftCopy + VecZnxDftAddInplace - + VecZnxIdftApplyTmpA, - Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnxBig, + + VecZnxIdftApplyTmpA + + VecZnxNormalize, + Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnxBig + TakeVecZnx, { - (0..lhs.rows()).for_each(|row_i| { + (0..lhs.rows().into()).for_each(|row_i| { // Key-switch column 0, i.e. // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2) self.at_mut(row_i, 0) @@ -192,10 +210,11 @@ impl GGSWCiphertext { + VecZnxNormalizeTmpBytes + VecZnxDftCopy + VecZnxDftAddInplace - + VecZnxIdftApplyTmpA, - Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnxBig, + + VecZnxIdftApplyTmpA + + VecZnxNormalize, + Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnxBig + TakeVecZnx, { - (0..self.rows()).for_each(|row_i| { + (0..self.rows().into()).for_each(|row_i| { // Key-switch column 0, i.e. // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2) self.at_mut(row_i, 0) @@ -220,34 +239,41 @@ impl GGSWCiphertext { + VmpApplyDftToDftAdd + VecZnxDftAddInplace + VecZnxBigNormalize - + VecZnxIdftApplyTmpA, - Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnxBig, + + VecZnxIdftApplyTmpA + + VecZnxNormalize, + Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnxBig + TakeVecZnx, { - assert!( - scratch.available() - >= GGSWCiphertext::expand_row_scratch_space( - module, - self.basek(), - self.k(), - tsk.k(), - tsk.digits(), - tsk.rank() - ) - ); + let basek_in: usize = self.base2k().into(); + let basek_tsk: usize = tsk.base2k().into(); - let n: usize = self.n(); - let rank: usize = self.rank(); + assert!(scratch.available() >= GGSWCiphertext::expand_row_scratch_space(module, self, tsk)); + + let n: usize = self.n().into(); + let rank: usize = self.rank().into(); let cols: usize = rank + 1; - // Keyswitch the j-th row of the col 0 - (0..self.rows()).for_each(|row_i| { - // Pre-compute DFT of (a0, a1, a2) - let (mut ci_dft, scratch_1) = scratch.take_vec_znx_dft(n, cols, self.size()); - (0..cols).for_each(|i| { - module.vec_znx_dft_apply(1, 0, &mut ci_dft, i, &self.at(row_i, 0).data, i); - }); + let a_size: usize = (self.size() * basek_in).div_ceil(basek_tsk); - (1..cols).for_each(|col_j| { + // Keyswitch the j-th row of the col 0 + for row_i in 0..self.rows().into() { + let a = &self.at(row_i, 0).data; + + // Pre-compute DFT of (a0, a1, a2) + let (mut ci_dft, scratch_1) = scratch.take_vec_znx_dft(n, cols, a_size); + + if basek_in == basek_tsk { + for i in 0..cols { + module.vec_znx_dft_apply(1, 0, &mut ci_dft, i, a, i); + } + } else { + let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(n, 1, a_size); + for i in 0..cols { + module.vec_znx_normalize(basek_tsk, &mut a_conv, 0, basek_in, a, i, scratch_2); + 
module.vec_znx_dft_apply(1, 0, &mut ci_dft, i, &a_conv, 0); + } + } + + for col_j in 1..cols { // Example for rank 3: // // Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column is @@ -268,7 +294,7 @@ impl GGSWCiphertext { // col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + M[i], c2 ) // col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + M[i]) - let digits: usize = tsk.digits(); + let digits: usize = tsk.digits().into(); let (mut tmp_dft_i, scratch_2) = scratch_1.take_vec_znx_dft(n, cols, tsk.size()); let (mut tmp_a, scratch_3) = scratch_2.take_vec_znx_dft(n, 1, ci_dft.size().div_ceil(digits)); @@ -285,11 +311,11 @@ impl GGSWCiphertext { // a2 * (-(h0s0 + h1s1 + h1s2) + s0s2, h0, h1, h2) = (-(a2h0s0 + a2h1s1 + a2h1s2) + a2s0s2, a2h0, a2h1, a2h2) // = // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2) - (1..cols).for_each(|col_i| { + for col_i in 1..cols { let pmat: &VmpPMat = &tsk.at(col_i - 1, col_j - 1).key.data; // Selects Enc(s[i]s[j]) // Extracts a[i] and multipies with Enc(s[i]s[j]) - (0..digits).for_each(|di| { + for di in 0..digits { tmp_a.set_size((ci_dft.size() + di) / digits); // Small optimization for digits > 2 @@ -307,8 +333,8 @@ impl GGSWCiphertext { } else { module.vmp_apply_dft_to_dft_add(&mut tmp_dft_i, &tmp_a, pmat, di, scratch_3); } - }); - }); + } + } } // Adds -(sum a[i] * s[i]) + m) on the i-th column of tmp_idft_i @@ -322,18 +348,19 @@ impl GGSWCiphertext { // (-(x0s0 + x1s1 + x2s2), x0 + M[i], x1, x2) module.vec_znx_dft_add_inplace(&mut tmp_dft_i, col_j, &ci_dft, 0); let (mut tmp_idft, scratch_3) = scratch_2.take_vec_znx_big(n, 1, tsk.size()); - (0..cols).for_each(|i| { + for i in 0..cols { module.vec_znx_idft_apply_tmpa(&mut tmp_idft, 0, &mut tmp_dft_i, i); module.vec_znx_big_normalize( - self.basek(), + basek_in, &mut self.at_mut(row_i, col_j).data, i, + basek_tsk, &tmp_idft, 0, scratch_3, ); - }); - }) - }) + } + } + } } } diff --git a/poulpy-core/src/keyswitching/glwe_ct.rs b/poulpy-core/src/keyswitching/glwe_ct.rs index 14e23e0..9174f84 100644 --- a/poulpy-core/src/keyswitching/glwe_ct.rs +++ b/poulpy-core/src/keyswitching/glwe_ct.rs @@ -1,52 +1,64 @@ use poulpy_hal::{ api::{ - ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VmpApplyDftToDft, VmpApplyDftToDftAdd, - VmpApplyDftToDftTmpBytes, + ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, + VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, + VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataRef, DataViewMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, ZnxInfos}, }; -use crate::layouts::{GLWECiphertext, Infos, prepared::GGLWESwitchingKeyPrepared}; +use crate::layouts::{GGLWELayoutInfos, GLWECiphertext, GLWEInfos, LWEInfos, prepared::GGLWESwitchingKeyPrepared}; impl GLWECiphertext> { - #[allow(clippy::too_many_arguments)] - pub fn keyswitch_scratch_space( + pub fn keyswitch_scratch_space( module: &Module, - basek: usize, - k_out: usize, - k_in: usize, - k_ksk: usize, - digits: usize, - rank_in: usize, - rank_out: usize, + out_infos: &OUT, + in_infos: &IN, + key_apply: &KEY, ) -> usize where - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes, + OUT: GLWEInfos, + IN: GLWEInfos, + KEY: GGLWELayoutInfos, + Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + 
VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { - let in_size: usize = k_in.div_ceil(basek).div_ceil(digits); - let out_size: usize = k_out.div_ceil(basek); - let ksk_size: usize = k_ksk.div_ceil(basek); - let res_dft: usize = module.vec_znx_dft_alloc_bytes(rank_out + 1, ksk_size); // TODO OPTIMIZE - let ai_dft: usize = module.vec_znx_dft_alloc_bytes(rank_in, in_size); - let vmp: usize = module.vmp_apply_dft_to_dft_tmp_bytes(out_size, in_size, in_size, rank_in, rank_out + 1, ksk_size) - + module.vec_znx_dft_alloc_bytes(rank_in, in_size); - let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(); - res_dft + ((ai_dft + vmp) | normalize) + let in_size: usize = in_infos + .k() + .div_ceil(key_apply.base2k()) + .div_ceil(key_apply.digits().into()) as usize; + let out_size: usize = out_infos.size(); + let ksk_size: usize = key_apply.size(); + let res_dft: usize = module.vec_znx_dft_alloc_bytes((key_apply.rank_out() + 1).into(), ksk_size); // TODO OPTIMIZE + let ai_dft: usize = module.vec_znx_dft_alloc_bytes((key_apply.rank_in()).into(), in_size); + let vmp: usize = module.vmp_apply_dft_to_dft_tmp_bytes( + out_size, + in_size, + in_size, + (key_apply.rank_in()).into(), + (key_apply.rank_out() + 1).into(), + ksk_size, + ) + module.vec_znx_dft_alloc_bytes((key_apply.rank_in()).into(), in_size); + let normalize_big: usize = module.vec_znx_big_normalize_tmp_bytes(); + if in_infos.base2k() == key_apply.base2k() { + res_dft + ((ai_dft + vmp) | normalize_big) + } else if key_apply.digits() == 1 { + // In this case we only need a single temporary column, which we can drop once a_dft is computed. + let normalize_conv: usize = VecZnx::alloc_bytes(module.n(), 1, in_size) + module.vec_znx_normalize_tmp_bytes(); + res_dft + (((ai_dft + normalize_conv) | vmp) | normalize_big) + } else { + // Since we stride over a to get a_dft when digits > 1, we need to store the full columns of a during the base conversion. 
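+ // Hedged worked example (illustrative numbers, not from a real parameter set): with rank_in = 2 and in_size = 4, + // the converted columns must outlive the whole digit loop, so normalize_conv below cannot alias ai_dft; only the + // one-shot normalize temporary and the vmp temporary may share space, hence (vec_znx_normalize_tmp_bytes() | vmp). 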
+ let normalize_conv: usize = VecZnx::alloc_bytes(module.n(), (key_apply.rank_in()).into(), in_size); + res_dft + ((ai_dft + normalize_conv + (module.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big) + } } - pub fn keyswitch_inplace_scratch_space( - module: &Module, - basek: usize, - k_out: usize, - k_ksk: usize, - digits: usize, - rank: usize, - ) -> usize + pub fn keyswitch_inplace_scratch_space(module: &Module, out_infos: &OUT, key_apply: &KEY) -> usize where - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes, + OUT: GLWEInfos, + KEY: GGLWELayoutInfos, + Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, { - Self::keyswitch_scratch_space(module, basek, k_out, k_out, k_ksk, digits, rank, rank) + Self::keyswitch_scratch_space(module, out_infos, out_infos, key_apply) } } @@ -61,10 +73,9 @@ impl GLWECiphertext { ) where DataLhs: DataRef, DataRhs: DataRef, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes, + Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, Scratch: ScratchAvailable, { - let basek: usize = self.basek(); assert_eq!( lhs.rank(), rhs.rank_in(), @@ -79,43 +90,26 @@ impl GLWECiphertext { self.rank(), rhs.rank_out() ); - assert_eq!(self.basek(), basek); - assert_eq!(lhs.basek(), basek); assert_eq!(rhs.n(), self.n()); assert_eq!(lhs.n(), self.n()); + + let scratch_needed: usize = GLWECiphertext::keyswitch_scratch_space(module, self, lhs, rhs); + assert!( - scratch.available() - >= GLWECiphertext::keyswitch_scratch_space( - module, - self.basek(), - self.k(), - lhs.k(), - rhs.k(), - rhs.digits(), - rhs.rank_in(), - rhs.rank_out(), - ), + scratch.available() >= scratch_needed, "scratch.available()={} < GLWECiphertext::keyswitch_scratch_space( module, - self.basek(), + self.base2k(), self.k(), + lhs.base2k(), lhs.k(), + rhs.base2k(), rhs.k(), rhs.digits(), rhs.rank_in(), rhs.rank_out(), - )={}", + )={scratch_needed}", scratch.available(), - GLWECiphertext::keyswitch_scratch_space( - module, - self.basek(), - self.k(), - lhs.k(), - rhs.k(), - rhs.digits(), - rhs.rank_in(), - rhs.rank_out(), - ) ); } @@ -127,10 +121,9 @@ impl GLWECiphertext { scratch: &Scratch, ) where DataRhs: DataRef, - Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes, + Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, Scratch: ScratchAvailable, { - let basek: usize = self.basek(); assert_eq!( self.rank(), rhs.rank_out(), @@ -138,41 +131,15 @@ impl GLWECiphertext { self.rank(), rhs.rank_out() ); - assert_eq!(self.basek(), basek); + assert_eq!(rhs.n(), self.n()); + + let scratch_needed: usize = GLWECiphertext::keyswitch_inplace_scratch_space(module, self, rhs); + assert!( - scratch.available() - >= GLWECiphertext::keyswitch_scratch_space( - module, - self.basek(), - self.k(), - self.k(), - rhs.k(), - rhs.digits(), - rhs.rank_in(), - rhs.rank_out(), - ), - "scratch.available()={} < GLWECiphertext::keyswitch_scratch_space( - module, - self.basek(), - self.k(), - self.k(), - rhs.k(), - rhs.digits(), - rhs.rank_in(), - rhs.rank_out(), - )={}", + scratch.available() >= scratch_needed, + "scratch.available()={} < GLWECiphertext::keyswitch_scratch_space()={scratch_needed}", scratch.available(), - GLWECiphertext::keyswitch_scratch_space( - module, - self.basek(), - self.k(), - self.k(), - rhs.k(), - rhs.digits(), - rhs.rank_in(), - 
rhs.rank_out(), - ) ); } } @@ -181,7 +148,7 @@ impl GLWECiphertext { pub fn keyswitch( &mut self, module: &Module, - lhs: &GLWECiphertext, + glwe_in: &GLWECiphertext, rhs: &GGLWESwitchingKeyPrepared, scratch: &mut Scratch, ) where @@ -193,17 +160,31 @@ impl GLWECiphertext { + VecZnxDftApply + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace - + VecZnxBigNormalize, - Scratch: ScratchAvailable + TakeVecZnxDft, + + VecZnxBigNormalize + + VecZnxNormalize + + VecZnxNormalizeTmpBytes, + Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnx, { #[cfg(debug_assertions)] { - self.assert_keyswitch(module, lhs, rhs, scratch); + self.assert_keyswitch(module, glwe_in, rhs, scratch); } - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // Todo optimise - let res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, rhs, scratch_1); - (0..self.cols()).for_each(|i| { - module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch_1); + + let basek_out: usize = self.base2k().into(); + let basek_ksk: usize = rhs.base2k().into(); + + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); // Todo optimise + let res_big: VecZnxBig<_, B> = glwe_in.keyswitch_internal(module, res_dft, rhs, scratch_1); + (0..(self.rank() + 1).into()).for_each(|i| { + module.vec_znx_big_normalize( + basek_out, + &mut self.data, + i, + basek_ksk, + &res_big, + i, + scratch_1, + ); }) } @@ -222,17 +203,31 @@ impl GLWECiphertext { + VecZnxDftApply + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace - + VecZnxBigNormalize, - Scratch: ScratchAvailable + TakeVecZnxDft, + + VecZnxBigNormalize + + VecZnxNormalize + + VecZnxNormalizeTmpBytes, + Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnx, { #[cfg(debug_assertions)] { self.assert_keyswitch_inplace(module, rhs, scratch); } - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // Todo optimise + + let basek_in: usize = self.base2k().into(); + let basek_ksk: usize = rhs.base2k().into(); + + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); // Todo optimise let res_big: VecZnxBig<_, B> = self.keyswitch_internal(module, res_dft, rhs, scratch_1); - (0..self.cols()).for_each(|i| { - module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch_1); + (0..(self.rank() + 1).into()).for_each(|i| { + module.vec_znx_big_normalize( + basek_in, + &mut self.data, + i, + basek_ksk, + &res_big, + i, + scratch_1, + ); }) } } @@ -257,19 +252,30 @@ impl GLWECiphertext { + VecZnxDftApply + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace - + VecZnxBigNormalize, - Scratch: TakeVecZnxDft, + + VecZnxBigNormalize + + VecZnxNormalize, + Scratch: TakeVecZnxDft + TakeVecZnx, { if rhs.digits() == 1 { - return keyswitch_vmp_one_digit(module, res_dft, &self.data, &rhs.key.data, scratch); + return keyswitch_vmp_one_digit( + module, + self.base2k().into(), + rhs.base2k().into(), + res_dft, + &self.data, + &rhs.key.data, + scratch, + ); } keyswitch_vmp_multiple_digits( module, + self.base2k().into(), + rhs.base2k().into(), res_dft, &self.data, &rhs.key.data, - rhs.digits(), + rhs.digits().into(), scratch, ) } @@ -277,6 +283,8 @@ impl GLWECiphertext { fn keyswitch_vmp_one_digit( module: &Module, + basek_in: usize, + basek_ksk: usize, mut res_dft: VecZnxDft, a: &VecZnx, mat: &VmpPMat, @@ -286,23 +294,42 @@ where DataRes: DataMut, DataIn: DataRef, DataVmp: DataRef, - Module: - 
VecZnxDftAllocBytes + VecZnxDftApply + VmpApplyDftToDft + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace, - Scratch: TakeVecZnxDft, + Module: VecZnxDftAllocBytes + + VecZnxDftApply + + VmpApplyDftToDft + + VecZnxIdftApplyConsume + + VecZnxBigAddSmallInplace + + VecZnxNormalize, + Scratch: TakeVecZnxDft + TakeVecZnx, { let cols: usize = a.cols(); + + let a_size: usize = (a.size() * basek_in).div_ceil(basek_ksk); let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(a.n(), cols - 1, a.size()); - (0..cols - 1).for_each(|col_i| { - module.vec_znx_dft_apply(1, 0, &mut ai_dft, col_i, a, col_i + 1); - }); + + if basek_in == basek_ksk { + (0..cols - 1).for_each(|col_i| { + module.vec_znx_dft_apply(1, 0, &mut ai_dft, col_i, a, col_i + 1); + }); + } else { + let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(a.n(), 1, a_size); + (0..cols - 1).for_each(|col_i| { + module.vec_znx_normalize(basek_ksk, &mut a_conv, 0, basek_in, a, col_i + 1, scratch_2); + module.vec_znx_dft_apply(1, 0, &mut ai_dft, col_i, &a_conv, 0); + }); + } + module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch_1); let mut res_big: VecZnxBig = module.vec_znx_idft_apply_consume(res_dft); module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0); res_big } +#[allow(clippy::too_many_arguments)] fn keyswitch_vmp_multiple_digits( module: &Module, + basek_in: usize, + basek_ksk: usize, mut res_dft: VecZnxDft, a: &VecZnx, mat: &VmpPMat, @@ -318,37 +345,67 @@ where + VmpApplyDftToDft + VmpApplyDftToDftAdd + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace, - Scratch: TakeVecZnxDft, + + VecZnxBigAddSmallInplace + + VecZnxNormalize, + Scratch: TakeVecZnxDft + TakeVecZnx, { let cols: usize = a.cols(); - let size: usize = a.size(); - let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(a.n(), cols - 1, size.div_ceil(digits)); - + let a_size: usize = (a.size() * basek_in).div_ceil(basek_ksk); + let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(a.n(), cols - 1, a_size.div_ceil(digits)); ai_dft.data_mut().fill(0); - (0..digits).for_each(|di| { - ai_dft.set_size((size + di) / digits); + if basek_in == basek_ksk { + for di in 0..digits { + ai_dft.set_size((a_size + di) / digits); - // Small optimization for digits > 2 - // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then - // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}. - // As such we can ignore the last digits-2 limbs safely of the sum of vmp products. - // It is possible to further ignore the last digits-1 limbs, but this introduce - // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same - // noise is kept with respect to the ideal functionality. - res_dft.set_size(mat.size() - ((digits - di) as isize - 2).max(0) as usize); + // Small optimization for digits > 2 + // VMP produces some error e, and since we aggregate vmp * 2^{di * B}, then + // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}. + // As such we can safely ignore the last digits-2 limbs of the sum of vmp products. + // It is possible to further ignore the last digits-1 limbs, but this introduces + // ~0.5 to 1 bit of additional noise, and is thus not chosen here to ensure that the same + // noise is kept with respect to the ideal functionality. 
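+ // Hedged numeric example (digits = 3 chosen only for illustration): at di = 0 we drop + // (3 - 0) - 2 = 1 trailing limb, at di = 1 we drop (3 - 1) - 2 = 0, and at di = 2 the + // .max(0) clamp keeps the full mat.size(), matching the set_size call below. 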
+ res_dft.set_size(mat.size() - ((digits - di) as isize - 2).max(0) as usize); - (0..cols - 1).for_each(|col_i| { - module.vec_znx_dft_apply(digits, digits - di - 1, &mut ai_dft, col_i, a, col_i + 1); - }); + for j in 0..cols - 1 { + module.vec_znx_dft_apply(digits, digits - di - 1, &mut ai_dft, j, a, j + 1); + } - if di == 0 { - module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch_1); - } else { - module.vmp_apply_dft_to_dft_add(&mut res_dft, &ai_dft, mat, di, scratch_1); + if di == 0 { + module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch_1); + } else { + module.vmp_apply_dft_to_dft_add(&mut res_dft, &ai_dft, mat, di, scratch_1); + } } - }); + } else { + let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(a.n(), cols - 1, a_size); + for j in 0..cols - 1 { + module.vec_znx_normalize(basek_ksk, &mut a_conv, j, basek_in, a, j + 1, scratch_2); + } + + for di in 0..digits { + ai_dft.set_size((a_size + di) / digits); + + // Small optimization for digits > 2 + // VMP produces some error e, and since we aggregate vmp * 2^{di * B}, then + // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}. + // As such we can safely ignore the last digits-2 limbs of the sum of vmp products. + // It is possible to further ignore the last digits-1 limbs, but this introduces + // ~0.5 to 1 bit of additional noise, and is thus not chosen here to ensure that the same + // noise is kept with respect to the ideal functionality. + res_dft.set_size(mat.size() - ((digits - di) as isize - 2).max(0) as usize); + + for j in 0..cols - 1 { + module.vec_znx_dft_apply(digits, digits - di - 1, &mut ai_dft, j, &a_conv, j); + } + + if di == 0 { + module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch_2); + } else { + module.vmp_apply_dft_to_dft_add(&mut res_dft, &ai_dft, mat, di, scratch_2); + } + } + } res_dft.set_size(res_dft.max_size()); let mut res_big: VecZnxBig = module.vec_znx_idft_apply_consume(res_dft); diff --git a/poulpy-core/src/keyswitching/lwe_ct.rs b/poulpy-core/src/keyswitching/lwe_ct.rs index 7588cb7..2dc9364 100644 --- a/poulpy-core/src/keyswitching/lwe_ct.rs +++ b/poulpy-core/src/keyswitching/lwe_ct.rs @@ -1,26 +1,31 @@ use poulpy_hal::{ api::{ ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VmpApplyDftToDft, VmpApplyDftToDftAdd, - VmpApplyDftToDftTmpBytes, + VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, + VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut, ZnxZero}, }; use crate::{ TakeGLWECt, - layouts::{GLWECiphertext, Infos, LWECiphertext, prepared::LWESwitchingKeyPrepared}, + layouts::{ + GGLWELayoutInfos, GLWECiphertext, GLWECiphertextLayout, LWECiphertext, LWEInfos, Rank, TorusPrecision, + prepared::LWESwitchingKeyPrepared, + }, }; impl LWECiphertext> { - pub fn keyswitch_scratch_space( + pub fn keyswitch_scratch_space( module: &Module, - basek: usize, - k_lwe_out: usize, - k_lwe_in: usize, - k_ksk: usize, + out_infos: &OUT, + in_infos: &IN, + key_infos: &KEY, ) -> usize where + OUT: LWEInfos, + IN: LWEInfos, + KEY: GGLWELayoutInfos, Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes @@ -30,10 +35,30 @@ impl LWECiphertext> { + VecZnxDftApply + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace - + 
VecZnxBigNormalize + + VecZnxNormalizeTmpBytes, { - GLWECiphertext::bytes_of(module.n(), basek, k_lwe_out.max(k_lwe_in), 1) - + GLWECiphertext::keyswitch_inplace_scratch_space(module, basek, k_lwe_out, k_ksk, 1, 1) + let max_k: TorusPrecision = in_infos.k().max(out_infos.k()); + + let glwe_in_infos: GLWECiphertextLayout = GLWECiphertextLayout { + n: module.n().into(), + base2k: in_infos.base2k(), + k: max_k, + rank: Rank(1), + }; + + let glwe_out_infos: GLWECiphertextLayout = GLWECiphertextLayout { + n: module.n().into(), + base2k: out_infos.base2k(), + k: max_k, + rank: Rank(1), + }; + + let glwe_in: usize = GLWECiphertext::alloc_bytes(&glwe_in_infos); + let glwe_out: usize = GLWECiphertext::alloc_bytes(&glwe_out_infos); + let ks: usize = GLWECiphertext::keyswitch_scratch_space(module, &glwe_out_infos, &glwe_in_infos, key_infos); + + glwe_in + glwe_out + ks } } @@ -55,32 +80,47 @@ impl LWECiphertext { + VecZnxDftApply + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace - + VecZnxBigNormalize, + + VecZnxBigNormalize + + VecZnxNormalize + + VecZnxNormalizeTmpBytes + + VecZnxCopy, Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { #[cfg(debug_assertions)] { - assert!(self.n() <= module.n()); - assert!(a.n() <= module.n()); - assert_eq!(self.basek(), a.basek()); + assert!(self.n() <= module.n() as u32); + assert!(a.n() <= module.n() as u32); + assert!(scratch.available() >= LWECiphertext::keyswitch_scratch_space(module, self, a, ksk)); } - let max_k: usize = self.k().max(a.k()); - let basek: usize = self.basek(); + let max_k: TorusPrecision = self.k().max(a.k()); - let (mut glwe, scratch_1) = scratch.take_glwe_ct(ksk.n(), basek, max_k, 1); - glwe.data.zero(); + let a_size: usize = a.k().div_ceil(ksk.base2k()) as usize; - let n_lwe: usize = a.n(); + let (mut glwe_in, scratch_1) = scratch.take_glwe_ct(&GLWECiphertextLayout { + n: ksk.n(), + base2k: a.base2k(), + k: max_k, + rank: Rank(1), + }); + glwe_in.data.zero(); - (0..a.size()).for_each(|i| { - let data_lwe: &[i64] = a.data.at(0, i); - glwe.data.at_mut(0, i)[0] = data_lwe[0]; - glwe.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]); + let (mut glwe_out, scratch_1) = scratch_1.take_glwe_ct(&GLWECiphertextLayout { + n: ksk.n(), + base2k: self.base2k(), + k: max_k, + rank: Rank(1), }); - glwe.keyswitch_inplace(module, &ksk.0, scratch_1); + let n_lwe: usize = a.n().into(); - self.sample_extract(&glwe); + for i in 0..a_size { + let data_lwe: &[i64] = a.data.at(0, i); + glwe_in.data.at_mut(0, i)[0] = data_lwe[0]; + glwe_in.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]); + } + + glwe_out.keyswitch(module, &glwe_in, &ksk.0, scratch_1); + self.sample_extract(&glwe_out); } } diff --git a/poulpy-core/src/layouts/compressed/gglwe_atk.rs b/poulpy-core/src/layouts/compressed/gglwe_atk.rs index 8c4c8c9..4ad55b5 100644 --- a/poulpy-core/src/layouts/compressed/gglwe_atk.rs +++ b/poulpy-core/src/layouts/compressed/gglwe_atk.rs @@ -1,11 +1,11 @@ use poulpy_hal::{ api::{VecZnxCopy, VecZnxFillUniform}, - layouts::{Backend, Data, DataMut, DataRef, FillUniform, MatZnx, Module, ReaderFrom, Reset, WriterTo}, + layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo}, source::Source, }; use crate::layouts::{ - GGLWEAutomorphismKey, Infos, + Base2K, Degree, Digits, GGLWEAutomorphismKey, GGLWELayoutInfos, GLWEInfos, LWEInfos, Rank, Rows, TorusPrecision, compressed::{Decompress, GGLWESwitchingKeyCompressed}, }; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; @@ -17,9 +17,50 @@ pub struct 
GGLWEAutomorphismKeyCompressed { pub(crate) p: i64, } +impl LWEInfos for GGLWEAutomorphismKeyCompressed { + fn n(&self) -> Degree { + self.key.n() + } + + fn base2k(&self) -> Base2K { + self.key.base2k() + } + + fn k(&self) -> TorusPrecision { + self.key.k() + } + + fn size(&self) -> usize { + self.key.size() + } +} +impl GLWEInfos for GGLWEAutomorphismKeyCompressed { + fn rank(&self) -> Rank { + self.rank_out() + } +} + +impl GGLWELayoutInfos for GGLWEAutomorphismKeyCompressed { + fn rank_in(&self) -> Rank { + self.key.rank_in() + } + + fn rank_out(&self) -> Rank { + self.key.rank_out() + } + + fn digits(&self) -> Digits { + self.key.digits() + } + + fn rows(&self) -> Rows { + self.key.rows() + } +} + impl fmt::Debug for GGLWEAutomorphismKeyCompressed { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self) + write!(f, "{self}") } } @@ -29,16 +70,6 @@ impl FillUniform for GGLWEAutomorphismKeyCompressed { } } -impl Reset for GGLWEAutomorphismKeyCompressed -where - MatZnx: Reset, -{ - fn reset(&mut self) { - self.key.reset(); - self.p = 0; - } -} - impl fmt::Display for GGLWEAutomorphismKeyCompressed { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "(AutomorphismKeyCompressed: p={}) {}", self.p, self.key) @@ -46,49 +77,34 @@ impl fmt::Display for GGLWEAutomorphismKeyCompressed { } impl GGLWEAutomorphismKeyCompressed> { - pub fn alloc(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self { - GGLWEAutomorphismKeyCompressed { - key: GGLWESwitchingKeyCompressed::alloc(n, basek, k, rows, digits, rank, rank), + pub fn alloc(infos: &A) -> Self + where + A: GGLWELayoutInfos, + { + debug_assert_eq!(infos.rank_in(), infos.rank_out()); + Self { + key: GGLWESwitchingKeyCompressed::alloc(infos), p: 0, } } - pub fn bytes_of(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize { - GGLWESwitchingKeyCompressed::>::bytes_of(n, basek, k, rows, digits, rank) - } -} - -impl Infos for GGLWEAutomorphismKeyCompressed { - type Inner = MatZnx; - - fn inner(&self) -> &Self::Inner { - self.key.inner() + pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rows: Rows, digits: Digits, rank: Rank) -> Self { + Self { + key: GGLWESwitchingKeyCompressed::alloc_with(n, base2k, k, rows, digits, rank, rank), + p: 0, + } } - fn basek(&self) -> usize { - self.key.basek() + pub fn alloc_bytes(infos: &A) -> usize + where + A: GGLWELayoutInfos, + { + debug_assert_eq!(infos.rank_in(), infos.rank_out()); + GGLWESwitchingKeyCompressed::alloc_bytes(infos) } - fn k(&self) -> usize { - self.key.k() - } -} - -impl GGLWEAutomorphismKeyCompressed { - pub fn rank(&self) -> usize { - self.key.rank() - } - - pub fn digits(&self) -> usize { - self.key.digits() - } - - pub fn rank_in(&self) -> usize { - self.key.rank_in() - } - - pub fn rank_out(&self) -> usize { - self.key.rank_out() + pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rows: Rows, digits: Digits, rank: Rank) -> usize { + GGLWESwitchingKeyCompressed::alloc_bytes_with(n, base2k, k, rows, digits, rank, rank) } } diff --git a/poulpy-core/src/layouts/compressed/gglwe_ct.rs b/poulpy-core/src/layouts/compressed/gglwe_ct.rs index 0f6db41..a3dd256 100644 --- a/poulpy-core/src/layouts/compressed/gglwe_ct.rs +++ b/poulpy-core/src/layouts/compressed/gglwe_ct.rs @@ -1,11 +1,11 @@ use poulpy_hal::{ api::{VecZnxCopy, VecZnxFillUniform}, - layouts::{Backend, Data, DataMut, DataRef, FillUniform, MatZnx, Module, ReaderFrom, Reset, WriterTo}, + layouts::{Backend, 
Data, DataMut, DataRef, FillUniform, MatZnx, Module, ReaderFrom, WriterTo, ZnxInfos}, source::Source, }; use crate::layouts::{ - GGLWECiphertext, Infos, + Base2K, Degree, Digits, GGLWECiphertext, GGLWELayoutInfos, GLWEInfos, LWEInfos, Rank, Rows, TorusPrecision, compressed::{Decompress, GLWECiphertextCompressed}, }; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; @@ -14,16 +14,57 @@ use std::fmt; #[derive(PartialEq, Eq, Clone)] pub struct GGLWECiphertextCompressed { pub(crate) data: MatZnx, - pub(crate) basek: usize, - pub(crate) k: usize, - pub(crate) rank_out: usize, - pub(crate) digits: usize, + pub(crate) base2k: Base2K, + pub(crate) k: TorusPrecision, + pub(crate) rank_out: Rank, + pub(crate) digits: Digits, pub(crate) seed: Vec<[u8; 32]>, } +impl LWEInfos for GGLWECiphertextCompressed { + fn n(&self) -> Degree { + Degree(self.data.n() as u32) + } + + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } + + fn size(&self) -> usize { + self.data.size() + } +} +impl GLWEInfos for GGLWECiphertextCompressed { + fn rank(&self) -> Rank { + self.rank_out() + } +} + +impl GGLWELayoutInfos for GGLWECiphertextCompressed { + fn rank_in(&self) -> Rank { + Rank(self.data.cols_in() as u32) + } + + fn rank_out(&self) -> Rank { + self.rank_out + } + + fn digits(&self) -> Digits { + self.digits + } + + fn rows(&self) -> Rows { + Rows(self.data.rows() as u32) + } +} + impl fmt::Debug for GGLWECiphertextCompressed { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self) + write!(f, "{self}") } } @@ -33,133 +74,140 @@ impl FillUniform for GGLWECiphertextCompressed { } } -impl Reset for GGLWECiphertextCompressed -where - MatZnx: Reset, -{ - fn reset(&mut self) { - self.data.reset(); - self.basek = 0; - self.k = 0; - self.digits = 0; - self.rank_out = 0; - self.seed = Vec::new(); - } -} - impl fmt::Display for GGLWECiphertextCompressed { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "(GGLWECiphertextCompressed: basek={} k={} digits={}) {}", - self.basek, self.k, self.digits, self.data + "(GGLWECiphertextCompressed: base2k={} k={} digits={}) {}", + self.base2k.0, self.k.0, self.digits.0, self.data ) } } impl GGLWECiphertextCompressed> { - pub fn alloc(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize, rank_out: usize) -> Self { - let size: usize = k.div_ceil(basek); + pub fn alloc(infos: &A) -> Self + where + A: GGLWELayoutInfos, + { + Self::alloc_with( + infos.n(), + infos.base2k(), + infos.k(), + infos.rows(), + infos.digits(), + infos.rank_in(), + infos.rank_out(), + ) + } + + pub fn alloc_with( + n: Degree, + base2k: Base2K, + k: TorusPrecision, + rows: Rows, + digits: Digits, + rank_in: Rank, + rank_out: Rank, + ) -> Self { + let size: usize = k.0.div_ceil(base2k.0) as usize; debug_assert!( - size > digits, - "invalid gglwe: ceil(k/basek): {} <= digits: {}", - size, - digits + size as u32 > digits.0, + "invalid gglwe: ceil(k/base2k): {size} <= digits: {}", + digits.0 ); assert!( - rows * digits <= size, - "invalid gglwe: rows: {} * digits:{} > ceil(k/basek): {}", - rows, - digits, - size + rows.0 * digits.0 <= size as u32, + "invalid gglwe: rows: {} * digits:{} > ceil(k/base2k): {size}", + rows.0, + digits.0, ); Self { - data: MatZnx::alloc(n, rows, rank_in, 1, size), - basek, + data: MatZnx::alloc( + n.into(), + rows.into(), + rank_in.into(), + 1, + k.0.div_ceil(base2k.0) as usize, + ), k, - rank_out, + base2k, digits, - seed: vec![[0u8; 32]; rows * rank_in], + rank_out, 
+ seed: vec![[0u8; 32]; (rows.0 * rank_in.0) as usize], } } - pub fn bytes_of(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize) -> usize { - let size: usize = k.div_ceil(basek); + pub fn alloc_bytes(infos: &A) -> usize + where + A: GGLWELayoutInfos, + { + Self::alloc_bytes_with( + infos.n(), + infos.base2k(), + infos.k(), + infos.rows(), + infos.digits(), + infos.rank_in(), + infos.rank_out(), + ) + } + + pub fn alloc_bytes_with( + n: Degree, + base2k: Base2K, + k: TorusPrecision, + rows: Rows, + digits: Digits, + rank_in: Rank, + _rank_out: Rank, + ) -> usize { + let size: usize = k.0.div_ceil(base2k.0) as usize; debug_assert!( - size > digits, - "invalid gglwe: ceil(k/basek): {} <= digits: {}", - size, - digits + size as u32 > digits.0, + "invalid gglwe: ceil(k/base2k): {size} <= digits: {}", + digits.0 ); assert!( - rows * digits <= size, - "invalid gglwe: rows: {} * digits:{} > ceil(k/basek): {}", - rows, - digits, - size + rows.0 * digits.0 <= size as u32, + "invalid gglwe: rows: {} * digits:{} > ceil(k/base2k): {size}", + rows.0, + digits.0, ); - MatZnx::alloc_bytes(n, rows, rank_in, 1, rows) - } -} - -impl Infos for GGLWECiphertextCompressed { - type Inner = MatZnx; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn basek(&self) -> usize { - self.basek - } - - fn k(&self) -> usize { - self.k - } -} - -impl GGLWECiphertextCompressed { - pub fn rank(&self) -> usize { - self.rank_out - } - - pub fn digits(&self) -> usize { - self.digits - } - - pub fn rank_in(&self) -> usize { - self.data.cols_in() - } - - pub fn rank_out(&self) -> usize { - self.rank_out + MatZnx::alloc_bytes( + n.into(), + rows.into(), + rank_in.into(), + 1, + k.0.div_ceil(base2k.0) as usize, + ) } } impl GGLWECiphertextCompressed { pub(crate) fn at(&self, row: usize, col: usize) -> GLWECiphertextCompressed<&[u8]> { + let rank_in: usize = self.rank_in().into(); GLWECiphertextCompressed { data: self.data.at(row, col), - basek: self.basek, k: self.k, + base2k: self.base2k, rank: self.rank_out, - seed: self.seed[self.rank_in() * row + col], + seed: self.seed[rank_in * row + col], } } } impl GGLWECiphertextCompressed { pub(crate) fn at_mut(&mut self, row: usize, col: usize) -> GLWECiphertextCompressed<&mut [u8]> { - let rank_in: usize = self.rank_in(); + let rank_in: usize = self.rank_in().into(); GLWECiphertextCompressed { - data: self.data.at_mut(row, col), - basek: self.basek, k: self.k, + base2k: self.base2k, rank: self.rank_out, + data: self.data.at_mut(row, col), seed: self.seed[rank_in * row + col], // Warning: value is copied and not borrow mut } } @@ -167,12 +215,12 @@ impl GGLWECiphertextCompressed { impl ReaderFrom for GGLWECiphertextCompressed { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { - self.k = reader.read_u64::()? as usize; - self.basek = reader.read_u64::()? as usize; - self.digits = reader.read_u64::()? as usize; - self.rank_out = reader.read_u64::()? as usize; - let seed_len = reader.read_u64::()? 
as usize; - self.seed = vec![[0u8; 32]; seed_len]; + self.k = TorusPrecision(reader.read_u32::()?); + self.base2k = Base2K(reader.read_u32::()?); + self.digits = Digits(reader.read_u32::()?); + self.rank_out = Rank(reader.read_u32::()?); + let seed_len: u32 = reader.read_u32::()?; + self.seed = vec![[0u8; 32]; seed_len as usize]; for s in &mut self.seed { reader.read_exact(s)?; } @@ -182,11 +230,11 @@ impl ReaderFrom for GGLWECiphertextCompressed { impl WriterTo for GGLWECiphertextCompressed { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { - writer.write_u64::(self.k as u64)?; - writer.write_u64::(self.basek as u64)?; - writer.write_u64::(self.digits as u64)?; - writer.write_u64::(self.rank_out as u64)?; - writer.write_u64::(self.seed.len() as u64)?; + writer.write_u32::(self.k.into())?; + writer.write_u32::(self.base2k.into())?; + writer.write_u32::(self.digits.into())?; + writer.write_u32::(self.rank_out.into())?; + writer.write_u32::(self.seed.len() as u32)?; for s in &self.seed { writer.write_all(s)?; } @@ -201,14 +249,12 @@ where fn decompress(&mut self, module: &Module, other: &GGLWECiphertextCompressed) { #[cfg(debug_assertions)] { - use poulpy_hal::layouts::ZnxInfos; - assert_eq!( self.n(), - other.data.n(), + other.n(), "invalid receiver: self.n()={} != other.n()={}", self.n(), - other.data.n() + other.n() ); assert_eq!( self.size(), @@ -241,8 +287,8 @@ where ); } - let rank_in: usize = self.rank_in(); - let rows: usize = self.rows(); + let rank_in: usize = self.rank_in().into(); + let rows: usize = self.rows().into(); (0..rank_in).for_each(|col_i| { (0..rows).for_each(|row_i| { diff --git a/poulpy-core/src/layouts/compressed/gglwe_ksk.rs b/poulpy-core/src/layouts/compressed/gglwe_ksk.rs index 72070b3..fd25dcd 100644 --- a/poulpy-core/src/layouts/compressed/gglwe_ksk.rs +++ b/poulpy-core/src/layouts/compressed/gglwe_ksk.rs @@ -1,11 +1,11 @@ use poulpy_hal::{ api::{VecZnxCopy, VecZnxFillUniform}, - layouts::{Backend, Data, DataMut, DataRef, FillUniform, MatZnx, Module, ReaderFrom, Reset, WriterTo}, + layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo}, source::Source, }; use crate::layouts::{ - GGLWESwitchingKey, Infos, + Base2K, Degree, Digits, GGLWELayoutInfos, GGLWESwitchingKey, GLWEInfos, LWEInfos, Rank, Rows, TorusPrecision, compressed::{Decompress, GGLWECiphertextCompressed}, }; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; @@ -18,9 +18,50 @@ pub struct GGLWESwitchingKeyCompressed { pub(crate) sk_out_n: usize, // Degree of sk_out } +impl LWEInfos for GGLWESwitchingKeyCompressed { + fn n(&self) -> Degree { + self.key.n() + } + + fn base2k(&self) -> Base2K { + self.key.base2k() + } + + fn k(&self) -> TorusPrecision { + self.key.k() + } + + fn size(&self) -> usize { + self.key.size() + } +} +impl GLWEInfos for GGLWESwitchingKeyCompressed { + fn rank(&self) -> Rank { + self.rank_out() + } +} + +impl GGLWELayoutInfos for GGLWESwitchingKeyCompressed { + fn rank_in(&self) -> Rank { + self.key.rank_in() + } + + fn rank_out(&self) -> Rank { + self.key.rank_out() + } + + fn digits(&self) -> Digits { + self.key.digits() + } + + fn rows(&self) -> Rows { + self.key.rows() + } +} + impl fmt::Debug for GGLWESwitchingKeyCompressed { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self) + write!(f, "{self}") } } @@ -30,17 +71,6 @@ impl FillUniform for GGLWESwitchingKeyCompressed { } } -impl Reset for GGLWESwitchingKeyCompressed -where - MatZnx: Reset, -{ - fn reset(&mut self) { - self.key.reset(); - 
self.sk_in_n = 0; - self.sk_out_n = 0; - } -} - impl fmt::Display for GGLWESwitchingKeyCompressed { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( @@ -51,51 +81,51 @@ impl fmt::Display for GGLWESwitchingKeyCompressed { } } -impl Infos for GGLWESwitchingKeyCompressed { - type Inner = MatZnx; - - fn inner(&self) -> &Self::Inner { - self.key.inner() - } - - fn basek(&self) -> usize { - self.key.basek() - } - - fn k(&self) -> usize { - self.key.k() - } -} - -impl GGLWESwitchingKeyCompressed { - pub fn rank(&self) -> usize { - self.key.rank() - } - - pub fn digits(&self) -> usize { - self.key.digits() - } - - pub fn rank_in(&self) -> usize { - self.key.rank_in() - } - - pub fn rank_out(&self) -> usize { - self.key.rank_out() - } -} - impl GGLWESwitchingKeyCompressed> { - pub fn alloc(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize, rank_out: usize) -> Self { + pub fn alloc(infos: &A) -> Self + where + A: GGLWELayoutInfos, + { GGLWESwitchingKeyCompressed { - key: GGLWECiphertextCompressed::alloc(n, basek, k, rows, digits, rank_in, rank_out), + key: GGLWECiphertextCompressed::alloc(infos), sk_in_n: 0, sk_out_n: 0, } } - pub fn bytes_of(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize) -> usize { - GGLWECiphertextCompressed::bytes_of(n, basek, k, rows, digits, rank_in) + pub fn alloc_with( + n: Degree, + base2k: Base2K, + k: TorusPrecision, + rows: Rows, + digits: Digits, + rank_in: Rank, + rank_out: Rank, + ) -> Self { + GGLWESwitchingKeyCompressed { + key: GGLWECiphertextCompressed::alloc_with(n, base2k, k, rows, digits, rank_in, rank_out), + sk_in_n: 0, + sk_out_n: 0, + } + } + + pub fn alloc_bytes(infos: &A) -> usize + where + A: GGLWELayoutInfos, + { + GGLWECiphertextCompressed::alloc_bytes(infos) + } + + pub fn alloc_bytes_with( + n: Degree, + base2k: Base2K, + k: TorusPrecision, + rows: Rows, + digits: Digits, + rank_in: Rank, + rank_out: Rank, + ) -> usize { + GGLWECiphertextCompressed::alloc_bytes_with(n, base2k, k, rows, digits, rank_in, rank_out) } } diff --git a/poulpy-core/src/layouts/compressed/gglwe_tsk.rs b/poulpy-core/src/layouts/compressed/gglwe_tsk.rs index 08917cd..2c5e102 100644 --- a/poulpy-core/src/layouts/compressed/gglwe_tsk.rs +++ b/poulpy-core/src/layouts/compressed/gglwe_tsk.rs @@ -1,11 +1,11 @@ use poulpy_hal::{ api::{VecZnxCopy, VecZnxFillUniform}, - layouts::{Backend, Data, DataMut, DataRef, FillUniform, MatZnx, Module, ReaderFrom, Reset, WriterTo}, + layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo}, source::Source, }; use crate::layouts::{ - GGLWETensorKey, Infos, + Base2K, Degree, Digits, GGLWELayoutInfos, GGLWETensorKey, GLWEInfos, LWEInfos, Rank, Rows, TorusPrecision, compressed::{Decompress, GGLWESwitchingKeyCompressed}, }; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; @@ -16,9 +16,49 @@ pub struct GGLWETensorKeyCompressed { pub(crate) keys: Vec>, } +impl LWEInfos for GGLWETensorKeyCompressed { + fn n(&self) -> Degree { + self.keys[0].n() + } + + fn base2k(&self) -> Base2K { + self.keys[0].base2k() + } + + fn k(&self) -> TorusPrecision { + self.keys[0].k() + } + fn size(&self) -> usize { + self.keys[0].size() + } +} +impl GLWEInfos for GGLWETensorKeyCompressed { + fn rank(&self) -> Rank { + self.rank_out() + } +} + +impl GGLWELayoutInfos for GGLWETensorKeyCompressed { + fn rank_in(&self) -> Rank { + self.rank_out() + } + + fn rank_out(&self) -> Rank { + self.keys[0].rank_out() + } + + fn digits(&self) -> Digits { + self.keys[0].digits() 
+ } + + fn rows(&self) -> Rows { + self.keys[0].rows() + } +} + impl fmt::Debug for GGLWETensorKeyCompressed { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self) + write!(f, "{self}") } } @@ -30,76 +70,79 @@ impl FillUniform for GGLWETensorKeyCompressed { } } -impl Reset for GGLWETensorKeyCompressed -where - MatZnx: Reset, -{ - fn reset(&mut self) { - self.keys - .iter_mut() - .for_each(|key: &mut GGLWESwitchingKeyCompressed| key.reset()) - } -} - impl fmt::Display for GGLWETensorKeyCompressed { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!(f, "(GLWETensorKeyCompressed)",)?; for (i, key) in self.keys.iter().enumerate() { - write!(f, "{}: {}", i, key)?; + write!(f, "{i}: {key}")?; } Ok(()) } } impl GGLWETensorKeyCompressed> { - pub fn alloc(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self { + pub fn alloc(infos: &A) -> Self + where + A: GGLWELayoutInfos, + { + assert_eq!( + infos.rank_in(), + infos.rank_out(), + "rank_in != rank_out is not supported for GGLWETensorKeyCompressed" + ); + Self::alloc_with( + infos.n(), + infos.base2k(), + infos.k(), + infos.rows(), + infos.digits(), + infos.rank_out(), + ) + } + + pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rows: Rows, digits: Digits, rank: Rank) -> Self { let mut keys: Vec>> = Vec::new(); - let pairs: usize = (((rank + 1) * rank) >> 1).max(1); + let pairs: u32 = (((rank.0 + 1) * rank.0) >> 1).max(1); (0..pairs).for_each(|_| { - keys.push(GGLWESwitchingKeyCompressed::alloc( - n, basek, k, rows, digits, 1, rank, + keys.push(GGLWESwitchingKeyCompressed::alloc_with( + n, + base2k, + k, + rows, + digits, + Rank(1), + rank, )); }); Self { keys } } - pub fn bytes_of(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize { - let pairs: usize = (((rank + 1) * rank) >> 1).max(1); - pairs * GGLWESwitchingKeyCompressed::bytes_of(n, basek, k, rows, digits, 1) - } -} - -impl Infos for GGLWETensorKeyCompressed { - type Inner = MatZnx; - - fn inner(&self) -> &Self::Inner { - self.keys[0].inner() + pub fn alloc_bytes(infos: &A) -> usize + where + A: GGLWELayoutInfos, + { + assert_eq!( + infos.rank_in(), + infos.rank_out(), + "rank_in != rank_out is not supported for GGLWETensorKeyCompressed" + ); + let rank_out: usize = infos.rank_out().into(); + let pairs: usize = (((rank_out + 1) * rank_out) >> 1).max(1); + pairs + * GGLWESwitchingKeyCompressed::alloc_bytes_with( + infos.n(), + infos.base2k(), + infos.k(), + infos.rows(), + infos.digits(), + Rank(1), + infos.rank_out(), + ) } - fn basek(&self) -> usize { - self.keys[0].basek() - } - - fn k(&self) -> usize { - self.keys[0].k() - } -} - -impl GGLWETensorKeyCompressed { - pub fn rank(&self) -> usize { - self.keys[0].rank() - } - - pub fn digits(&self) -> usize { - self.keys[0].digits() - } - - pub fn rank_in(&self) -> usize { - self.keys[0].rank_in() - } - - pub fn rank_out(&self) -> usize { - self.keys[0].rank_out() + pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rows: Rows, digits: Digits, rank: Rank) -> usize { + let pairs: usize = (((rank.0 + 1) * rank.0) >> 1).max(1) as usize; + pairs * GGLWESwitchingKeyCompressed::alloc_bytes_with(n, base2k, k, rows, digits, Rank(1), rank) } } @@ -134,7 +177,7 @@ impl GGLWETensorKeyCompressed { if i > j { std::mem::swap(&mut i, &mut j); }; - let rank: usize = self.rank(); + let rank: usize = self.rank_out().into(); &mut self.keys[i * rank + j - (i * (i + 1) / 2)] } } diff --git 
a/poulpy-core/src/layouts/compressed/ggsw_ct.rs b/poulpy-core/src/layouts/compressed/ggsw_ct.rs index 5cb9d2d..93443e4 100644 --- a/poulpy-core/src/layouts/compressed/ggsw_ct.rs +++ b/poulpy-core/src/layouts/compressed/ggsw_ct.rs @@ -1,11 +1,11 @@ use poulpy_hal::{ api::{VecZnxCopy, VecZnxFillUniform}, - layouts::{Backend, Data, DataMut, DataRef, FillUniform, MatZnx, Module, ReaderFrom, Reset, WriterTo}, + layouts::{Backend, Data, DataMut, DataRef, FillUniform, MatZnx, Module, ReaderFrom, WriterTo, ZnxInfos}, source::Source, }; use crate::layouts::{ - GGSWCiphertext, Infos, + Base2K, Degree, Digits, GGSWCiphertext, GGSWInfos, GLWEInfos, LWEInfos, Rank, Rows, TorusPrecision, compressed::{Decompress, GLWECiphertextCompressed}, }; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; @@ -14,13 +14,45 @@ use std::fmt; #[derive(PartialEq, Eq, Clone)] pub struct GGSWCiphertextCompressed { pub(crate) data: MatZnx, - pub(crate) basek: usize, - pub(crate) k: usize, - pub(crate) digits: usize, - pub(crate) rank: usize, + pub(crate) k: TorusPrecision, + pub(crate) base2k: Base2K, + pub(crate) digits: Digits, + pub(crate) rank: Rank, pub(crate) seed: Vec<[u8; 32]>, } +impl LWEInfos for GGSWCiphertextCompressed { + fn n(&self) -> Degree { + Degree(self.data.n() as u32) + } + + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } + fn size(&self) -> usize { + self.data.size() + } +} +impl GLWEInfos for GGSWCiphertextCompressed { + fn rank(&self) -> Rank { + self.rank + } +} + +impl GGSWInfos for GGSWCiphertextCompressed { + fn digits(&self) -> Digits { + self.digits + } + + fn rows(&self) -> Rows { + Rows(self.data.rows() as u32) + } +} + impl fmt::Debug for GGSWCiphertextCompressed { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}", self.data) @@ -31,23 +63,12 @@ impl fmt::Display for GGSWCiphertextCompressed { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "(GGSWCiphertextCompressed: basek={} k={} digits={}) {}", - self.basek, self.k, self.digits, self.data + "(GGSWCiphertextCompressed: base2k={} k={} digits={}) {}", + self.base2k, self.k, self.digits, self.data ) } } -impl Reset for GGSWCiphertextCompressed { - fn reset(&mut self) { - self.data.reset(); - self.basek = 0; - self.k = 0; - self.digits = 0; - self.rank = 0; - self.seed = Vec::new(); - } -} - impl FillUniform for GGSWCiphertextCompressed { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { self.data.fill_uniform(log_bound, source); @@ -55,114 +76,123 @@ impl FillUniform for GGSWCiphertextCompressed { } impl GGSWCiphertextCompressed> { - pub fn alloc(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self { - let size: usize = k.div_ceil(basek); - debug_assert!(digits > 0, "invalid ggsw: `digits` == 0"); + pub fn alloc(infos: &A) -> Self + where + A: GGSWInfos, + { + Self::alloc_with( + infos.n(), + infos.base2k(), + infos.k(), + infos.rows(), + infos.digits(), + infos.rank(), + ) + } + pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rows: Rows, digits: Digits, rank: Rank) -> Self { + let size: usize = k.0.div_ceil(base2k.0) as usize; debug_assert!( - size > digits, - "invalid ggsw: ceil(k/basek): {} <= digits: {}", - size, - digits + size as u32 > digits.0, + "invalid ggsw: ceil(k/base2k): {size} <= digits: {}", + digits.0 ); assert!( - rows * digits <= size, - "invalid ggsw: rows: {} * digits:{} > ceil(k/basek): {}", - rows, - digits, - size + rows.0 * digits.0 <= size as u32, + 
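// Worked example of the two invariants checked here, using the illustrative
// base2k = 18 from the example program: for k = 54, size = ceil(54/18) = 3
// limbs, so `digits` must be strictly below 3 and `rows * digits` must stay
// <= 3, e.g. rows = 3 with digits = 1, or rows = 1 with digits = 2.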
"invalid ggsw: rows: {} * digits:{} > ceil(k/base2k): {size}", + rows.0, + digits.0, ); Self { - data: MatZnx::alloc(n, rows, rank + 1, 1, k.div_ceil(basek)), - basek, + data: MatZnx::alloc( + n.into(), + rows.into(), + (rank + 1).into(), + 1, + k.0.div_ceil(base2k.0) as usize, + ), k, + base2k, digits, rank, seed: Vec::new(), } } - pub fn bytes_of(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize { - let size: usize = k.div_ceil(basek); + pub fn alloc_bytes(infos: &A) -> usize + where + A: GGSWInfos, + { + Self::alloc_bytes_with( + infos.n(), + infos.base2k(), + infos.k(), + infos.rows(), + infos.digits(), + infos.rank(), + ) + } + + pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rows: Rows, digits: Digits, rank: Rank) -> usize { + let size: usize = k.0.div_ceil(base2k.0) as usize; debug_assert!( - size > digits, - "invalid ggsw: ceil(k/basek): {} <= digits: {}", - size, - digits + size as u32 > digits.0, + "invalid ggsw: ceil(k/base2k): {size} <= digits: {}", + digits.0 ); assert!( - rows * digits <= size, - "invalid ggsw: rows: {} * digits:{} > ceil(k/basek): {}", - rows, - digits, - size + rows.0 * digits.0 <= size as u32, + "invalid ggsw: rows: {} * digits:{} > ceil(k/base2k): {size}", + rows.0, + digits.0, ); - MatZnx::alloc_bytes(n, rows, rank + 1, 1, size) + MatZnx::alloc_bytes( + n.into(), + rows.into(), + (rank + 1).into(), + 1, + k.0.div_ceil(base2k.0) as usize, + ) } } impl GGSWCiphertextCompressed { pub fn at(&self, row: usize, col: usize) -> GLWECiphertextCompressed<&[u8]> { + let rank: usize = self.rank().into(); GLWECiphertextCompressed { data: self.data.at(row, col), - basek: self.basek, k: self.k, - rank: self.rank(), - seed: self.seed[row * (self.rank() + 1) + col], + base2k: self.base2k, + rank: self.rank, + seed: self.seed[row * (rank + 1) + col], } } } impl GGSWCiphertextCompressed { pub fn at_mut(&mut self, row: usize, col: usize) -> GLWECiphertextCompressed<&mut [u8]> { - let rank: usize = self.rank(); + let rank: usize = self.rank().into(); GLWECiphertextCompressed { data: self.data.at_mut(row, col), - basek: self.basek, k: self.k, - rank, + base2k: self.base2k, + rank: self.rank, seed: self.seed[row * (rank + 1) + col], } } } -impl Infos for GGSWCiphertextCompressed { - type Inner = MatZnx; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn basek(&self) -> usize { - self.basek - } - - fn k(&self) -> usize { - self.k - } -} - -impl GGSWCiphertextCompressed { - pub fn rank(&self) -> usize { - self.rank - } - - pub fn digits(&self) -> usize { - self.digits - } -} - impl ReaderFrom for GGSWCiphertextCompressed { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { - self.k = reader.read_u64::()? as usize; - self.basek = reader.read_u64::()? as usize; - self.digits = reader.read_u64::()? as usize; - self.rank = reader.read_u64::()? as usize; - let seed_len = reader.read_u64::()? as usize; + self.k = TorusPrecision(reader.read_u32::()?); + self.base2k = Base2K(reader.read_u32::()?); + self.digits = Digits(reader.read_u32::()?); + self.rank = Rank(reader.read_u32::()?); + let seed_len: usize = reader.read_u32::()? 
as usize; self.seed = vec![[0u8; 32]; seed_len]; for s in &mut self.seed { reader.read_exact(s)?; @@ -173,11 +203,11 @@ impl ReaderFrom for GGSWCiphertextCompressed { impl WriterTo for GGSWCiphertextCompressed { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { - writer.write_u64::(self.k as u64)?; - writer.write_u64::(self.basek as u64)?; - writer.write_u64::(self.digits as u64)?; - writer.write_u64::(self.rank as u64)?; - writer.write_u64::(self.seed.len() as u64)?; + writer.write_u32::(self.k.into())?; + writer.write_u32::(self.base2k.into())?; + writer.write_u32::(self.digits.into())?; + writer.write_u32::(self.rank.into())?; + writer.write_u32::(self.seed.len() as u32)?; for s in &self.seed { writer.write_all(s)?; } @@ -195,8 +225,8 @@ where assert_eq!(self.rank(), other.rank()) } - let rows: usize = self.rows(); - let rank: usize = self.rank(); + let rows: usize = self.rows().into(); + let rank: usize = self.rank().into(); (0..rows).for_each(|row_i| { (0..rank + 1).for_each(|col_j| { self.at_mut(row_i, col_j) diff --git a/poulpy-core/src/layouts/compressed/glwe_ct.rs b/poulpy-core/src/layouts/compressed/glwe_ct.rs index 8c8eaf9..30a3733 100644 --- a/poulpy-core/src/layouts/compressed/glwe_ct.rs +++ b/poulpy-core/src/layouts/compressed/glwe_ct.rs @@ -1,25 +1,48 @@ use poulpy_hal::{ api::{VecZnxCopy, VecZnxFillUniform}, - layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, Reset, VecZnx, WriterTo}, + layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, VecZnx, WriterTo, ZnxInfos}, source::Source, }; -use crate::layouts::{GLWECiphertext, Infos, compressed::Decompress}; +use crate::layouts::{Base2K, Degree, GLWECiphertext, GLWEInfos, LWEInfos, Rank, TorusPrecision, compressed::Decompress}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use std::fmt; #[derive(PartialEq, Eq, Clone)] pub struct GLWECiphertextCompressed { pub(crate) data: VecZnx, - pub(crate) basek: usize, - pub(crate) k: usize, - pub(crate) rank: usize, + pub(crate) base2k: Base2K, + pub(crate) k: TorusPrecision, + pub(crate) rank: Rank, pub(crate) seed: [u8; 32], } +impl LWEInfos for GLWECiphertextCompressed { + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } + + fn size(&self) -> usize { + self.data.size() + } + + fn n(&self) -> Degree { + Degree(self.data.n() as u32) + } +} +impl GLWEInfos for GLWECiphertextCompressed { + fn rank(&self) -> Rank { + self.rank + } +} + impl fmt::Debug for GLWECiphertextCompressed { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self) + write!(f, "{self}") } } @@ -27,75 +50,57 @@ impl fmt::Display for GLWECiphertextCompressed { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "GLWECiphertextCompressed: basek={} k={} rank={} seed={:?}: {}", - self.basek(), + "GLWECiphertextCompressed: base2k={} k={} rank={} seed={:?}: {}", + self.base2k(), self.k(), - self.rank, + self.rank(), self.seed, self.data ) } } -impl Reset for GLWECiphertextCompressed { - fn reset(&mut self) { - self.data.reset(); - self.basek = 0; - self.k = 0; - self.rank = 0; - self.seed = [0u8; 32]; - } -} - impl FillUniform for GLWECiphertextCompressed { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { self.data.fill_uniform(log_bound, source); } } -impl Infos for GLWECiphertextCompressed { - type Inner = VecZnx; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn basek(&self) -> usize { - self.basek - } - - fn k(&self) -> usize { - 
self.k - } -} - -impl GLWECiphertextCompressed { - pub fn rank(&self) -> usize { - self.rank - } -} - impl GLWECiphertextCompressed> { - pub fn alloc(n: usize, basek: usize, k: usize, rank: usize) -> Self { + pub fn alloc(infos: &A) -> Self + where + A: GLWEInfos, + { + Self::alloc_with(infos.n(), infos.base2k(), infos.k(), infos.rank()) + } + + pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> Self { Self { - data: VecZnx::alloc(n, 1, k.div_ceil(basek)), - basek, + data: VecZnx::alloc(n.into(), 1, k.0.div_ceil(base2k.0) as usize), + base2k, k, rank, seed: [0u8; 32], } } - pub fn bytes_of(n: usize, basek: usize, k: usize) -> usize { - GLWECiphertext::bytes_of(n, basek, k, 1) + pub fn alloc_bytes(infos: &A) -> usize + where + A: GLWEInfos, + { + Self::alloc_bytes_with(infos.n(), infos.base2k(), infos.k()) + } + + pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision) -> usize { + VecZnx::alloc_bytes(n.into(), 1, k.0.div_ceil(base2k.0) as usize) } } impl ReaderFrom for GLWECiphertextCompressed { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { - self.k = reader.read_u64::()? as usize; - self.basek = reader.read_u64::()? as usize; - self.rank = reader.read_u64::()? as usize; + self.k = TorusPrecision(reader.read_u32::()?); + self.base2k = Base2K(reader.read_u32::()?); + self.rank = Rank(reader.read_u32::()?); reader.read_exact(&mut self.seed)?; self.data.read_from(reader) } @@ -103,9 +108,9 @@ impl ReaderFrom for GLWECiphertextCompressed { impl WriterTo for GLWECiphertextCompressed { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { - writer.write_u64::(self.k as u64)?; - writer.write_u64::(self.basek as u64)?; - writer.write_u64::(self.rank as u64)?; + writer.write_u32::(self.k.into())?; + writer.write_u32::(self.base2k.into())?; + writer.write_u32::(self.rank.into())?; writer.write_all(&self.seed)?; self.data.write_to(writer) } @@ -118,14 +123,12 @@ where fn decompress(&mut self, module: &Module, other: &GLWECiphertextCompressed) { #[cfg(debug_assertions)] { - use poulpy_hal::layouts::ZnxInfos; - assert_eq!( self.n(), - other.data.n(), + other.n(), "invalid receiver: self.n()={} != other.n()={}", self.n(), - other.data.n() + other.n() ); assert_eq!( self.size(), @@ -164,15 +167,12 @@ impl GLWECiphertext { debug_assert_eq!(self.size(), other.size()); } - let k: usize = other.k; - let basek: usize = other.basek; - let cols: usize = other.rank() + 1; module.vec_znx_copy(&mut self.data, 0, &other.data, 0); - (1..cols).for_each(|i| { - module.vec_znx_fill_uniform(basek, &mut self.data, i, source); + (1..(other.rank() + 1).into()).for_each(|i| { + module.vec_znx_fill_uniform(other.base2k.into(), &mut self.data, i, source); }); - self.basek = basek; - self.k = k; + self.base2k = other.base2k; + self.k = other.k; } } diff --git a/poulpy-core/src/layouts/compressed/glwe_to_lwe_ksk.rs b/poulpy-core/src/layouts/compressed/glwe_to_lwe_ksk.rs index fb7f959..b69667f 100644 --- a/poulpy-core/src/layouts/compressed/glwe_to_lwe_ksk.rs +++ b/poulpy-core/src/layouts/compressed/glwe_to_lwe_ksk.rs @@ -1,23 +1,62 @@ use std::fmt; use poulpy_hal::{ - api::{ - SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, - VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, - VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, - }, - layouts::{Backend, Data, DataMut, DataRef, FillUniform, MatZnx, Module, 
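// Recap of the scheme implemented by the preceding glwe_ct.rs hunk: a
// compressed GLWE ciphertext stores only column 0 (the body) plus a 32-byte
// seed, and decompression copies the body then re-derives mask columns
// 1..=rank from the seed. Commented sketch (assumes GLWECiphertext::alloc
// accepts any GLWEInfos implementor, as the other alloc paths in this patch
// do, and a backend module `m`):
//
//     let mut full = GLWECiphertext::alloc(&compressed);
//     full.decompress(&m, &compressed); // col 0 copied, cols 1.. re-seeded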
ReaderFrom, Reset, WriterTo}, + layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, WriterTo}, source::Source, }; -use crate::layouts::{GLWEToLWESwitchingKey, Infos, compressed::GGLWESwitchingKeyCompressed}; +use crate::layouts::{ + Base2K, Degree, Digits, GGLWELayoutInfos, GLWEInfos, LWEInfos, Rank, Rows, TorusPrecision, + compressed::GGLWESwitchingKeyCompressed, +}; #[derive(PartialEq, Eq, Clone)] pub struct GLWEToLWESwitchingKeyCompressed(pub(crate) GGLWESwitchingKeyCompressed); +impl LWEInfos for GLWEToLWESwitchingKeyCompressed { + fn base2k(&self) -> Base2K { + self.0.base2k() + } + + fn k(&self) -> TorusPrecision { + self.0.k() + } + + fn n(&self) -> Degree { + self.0.n() + } + fn size(&self) -> usize { + self.0.size() + } +} + +impl GLWEInfos for GLWEToLWESwitchingKeyCompressed { + fn rank(&self) -> Rank { + self.rank_out() + } +} + +impl GGLWELayoutInfos for GLWEToLWESwitchingKeyCompressed { + fn rank_in(&self) -> Rank { + self.0.rank_in() + } + + fn digits(&self) -> Digits { + self.0.digits() + } + + fn rank_out(&self) -> Rank { + self.0.rank_out() + } + + fn rows(&self) -> Rows { + self.0.rows() + } +} + impl fmt::Debug for GLWEToLWESwitchingKeyCompressed { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self) + write!(f, "{self}") } } @@ -27,52 +66,12 @@ impl FillUniform for GLWEToLWESwitchingKeyCompressed { } } -impl Reset for GLWEToLWESwitchingKeyCompressed { - fn reset(&mut self) { - self.0.reset(); - } -} - impl fmt::Display for GLWEToLWESwitchingKeyCompressed { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "(GLWEToLWESwitchingKeyCompressed) {}", self.0) } } -impl Infos for GLWEToLWESwitchingKeyCompressed { - type Inner = MatZnx; - - fn inner(&self) -> &Self::Inner { - self.0.inner() - } - - fn basek(&self) -> usize { - self.0.basek() - } - - fn k(&self) -> usize { - self.0.k() - } -} - -impl GLWEToLWESwitchingKeyCompressed { - pub fn digits(&self) -> usize { - self.0.digits() - } - - pub fn rank(&self) -> usize { - self.0.rank() - } - - pub fn rank_in(&self) -> usize { - self.0.rank_in() - } - - pub fn rank_out(&self) -> usize { - self.0.rank_out() - } -} - impl ReaderFrom for GLWEToLWESwitchingKeyCompressed { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { self.0.read_from(reader) @@ -86,31 +85,53 @@ impl WriterTo for GLWEToLWESwitchingKeyCompressed { } impl GLWEToLWESwitchingKeyCompressed> { - pub fn alloc(n: usize, basek: usize, k: usize, rows: usize, rank_in: usize) -> Self { - Self(GGLWESwitchingKeyCompressed::alloc( - n, basek, k, rows, 1, rank_in, 1, + pub fn alloc(infos: &A) -> Self + where + A: GGLWELayoutInfos, + { + debug_assert_eq!( + infos.rank_out().0, + 1, + "rank_out > 1 is unsupported for GLWEToLWESwitchingKeyCompressed" + ); + debug_assert_eq!( + infos.digits().0, + 1, + "digits > 1 is unsupported for GLWEToLWESwitchingKeyCompressed" + ); + Self(GGLWESwitchingKeyCompressed::alloc(infos)) + } + + pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rows: Rows, rank_in: Rank) -> Self { + Self(GGLWESwitchingKeyCompressed::alloc_with( + n, + base2k, + k, + rows, + Digits(1), + rank_in, + Rank(1), )) } - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank_in: usize) -> usize + pub fn alloc_bytes(infos: &A) -> usize where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubABInplace - + VecZnxAddInplace - + 
VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc, + A: GGLWELayoutInfos, { - GLWEToLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, rank_in) + debug_assert_eq!( + infos.rank_out().0, + 1, + "rank_out > 1 is unsupported for GLWEToLWESwitchingKeyCompressed" + ); + debug_assert_eq!( + infos.digits().0, + 1, + "digits > 1 is unsupported for GLWEToLWESwitchingKeyCompressed" + ); + GGLWESwitchingKeyCompressed::alloc_bytes(infos) + } + + pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rows: Rows, rank_in: Rank) -> usize { + GGLWESwitchingKeyCompressed::alloc_bytes_with(n, base2k, k, rows, Digits(1), rank_in, Rank(1)) } } diff --git a/poulpy-core/src/layouts/compressed/lwe_ct.rs b/poulpy-core/src/layouts/compressed/lwe_ct.rs index 159b107..e11b3f3 100644 --- a/poulpy-core/src/layouts/compressed/lwe_ct.rs +++ b/poulpy-core/src/layouts/compressed/lwe_ct.rs @@ -2,25 +2,41 @@ use std::fmt; use poulpy_hal::{ api::ZnFillUniform, - layouts::{ - Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, Reset, VecZnx, WriterTo, ZnxInfos, ZnxView, ZnxViewMut, - }, + layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo, Zn, ZnxInfos, ZnxView, ZnxViewMut}, source::Source, }; -use crate::layouts::{Infos, LWECiphertext, SetMetaData, compressed::Decompress}; +use crate::layouts::{Base2K, Degree, LWECiphertext, LWEInfos, TorusPrecision, compressed::Decompress}; #[derive(PartialEq, Eq, Clone)] pub struct LWECiphertextCompressed { - pub(crate) data: VecZnx, - pub(crate) k: usize, - pub(crate) basek: usize, + pub(crate) data: Zn, + pub(crate) k: TorusPrecision, + pub(crate) base2k: Base2K, pub(crate) seed: [u8; 32], } +impl LWEInfos for LWECiphertextCompressed { + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } + + fn n(&self) -> Degree { + Degree(self.data.n() as u32) + } + + fn size(&self) -> usize { + self.data.size() + } +} + impl fmt::Debug for LWECiphertextCompressed { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self) + write!(f, "{self}") } } @@ -28,8 +44,8 @@ impl fmt::Display for LWECiphertextCompressed { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "LWECiphertextCompressed: basek={} k={} seed={:?}: {}", - self.basek(), + "LWECiphertextCompressed: base2k={} k={} seed={:?}: {}", + self.base2k(), self.k(), self.seed, self.data @@ -37,18 +53,6 @@ impl fmt::Display for LWECiphertextCompressed { } } -impl Reset for LWECiphertextCompressed -where - VecZnx: Reset, -{ - fn reset(&mut self) { - self.data.reset(); - self.basek = 0; - self.k = 0; - self.seed = [0u8; 32]; - } -} - impl FillUniform for LWECiphertextCompressed { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { self.data.fill_uniform(log_bound, source); @@ -56,46 +60,31 @@ impl FillUniform for LWECiphertextCompressed { } impl LWECiphertextCompressed> { - pub fn alloc(basek: usize, k: usize) -> Self { + pub fn alloc(infos: &A) -> Self + where + A: LWEInfos, + { + Self::alloc_with(infos.base2k(), infos.k()) + } + + pub fn alloc_with(base2k: Base2K, k: TorusPrecision) -> Self { Self { - data: VecZnx::alloc(1, 1, k.div_ceil(basek)), + data: Zn::alloc(1, 1, k.0.div_ceil(base2k.0) as usize), k, - basek, + base2k, seed: [0u8; 32], } } -} -impl Infos for LWECiphertextCompressed -where - VecZnx: ZnxInfos, -{ - type Inner = VecZnx; - - fn n(&self) -> usize { - &self.inner().n() - 1 + pub fn 
alloc_bytes(infos: &A) -> usize + where + A: LWEInfos, + { + Self::alloc_bytes_with(infos.base2k(), infos.k()) } - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn basek(&self) -> usize { - self.basek - } - - fn k(&self) -> usize { - self.k - } -} - -impl SetMetaData for LWECiphertextCompressed { - fn set_k(&mut self, k: usize) { - self.k = k - } - - fn set_basek(&mut self, basek: usize) { - self.basek = basek + pub fn alloc_bytes_with(base2k: Base2K, k: TorusPrecision) -> usize { + Zn::alloc_bytes(1, 1, k.0.div_ceil(base2k.0) as usize) } } @@ -103,8 +92,8 @@ use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; impl ReaderFrom for LWECiphertextCompressed { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { - self.k = reader.read_u64::()? as usize; - self.basek = reader.read_u64::()? as usize; + self.k = TorusPrecision(reader.read_u32::()?); + self.base2k = Base2K(reader.read_u32::()?); reader.read_exact(&mut self.seed)?; self.data.read_from(reader) } @@ -112,8 +101,8 @@ impl ReaderFrom for LWECiphertextCompressed { impl WriterTo for LWECiphertextCompressed { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { - writer.write_u64::(self.k as u64)?; - writer.write_u64::(self.basek as u64)?; + writer.write_u32::(self.k.into())?; + writer.write_u32::(self.base2k.into())?; writer.write_all(&self.seed)?; self.data.write_to(writer) } @@ -126,7 +115,13 @@ where fn decompress(&mut self, module: &Module, other: &LWECiphertextCompressed) { debug_assert_eq!(self.size(), other.size()); let mut source: Source = Source::new(other.seed); - module.zn_fill_uniform(self.n(), other.basek(), &mut self.data, 0, &mut source); + module.zn_fill_uniform( + self.n().into(), + other.base2k().into(), + &mut self.data, + 0, + &mut source, + ); (0..self.size()).for_each(|i| { self.data.at_mut(0, i)[0] = other.data.at(0, i)[0]; }); diff --git a/poulpy-core/src/layouts/compressed/lwe_ksk.rs b/poulpy-core/src/layouts/compressed/lwe_ksk.rs index 23ee722..c37fbdd 100644 --- a/poulpy-core/src/layouts/compressed/lwe_ksk.rs +++ b/poulpy-core/src/layouts/compressed/lwe_ksk.rs @@ -1,15 +1,11 @@ use poulpy_hal::{ - api::{ - SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, - VecZnxBigNormalize, VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, - VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, - }, - layouts::{Backend, Data, DataMut, DataRef, FillUniform, MatZnx, Module, ReaderFrom, Reset, WriterTo}, + api::{VecZnxCopy, VecZnxFillUniform}, + layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo}, source::Source, }; use crate::layouts::{ - Infos, LWESwitchingKey, + Base2K, Degree, Digits, GGLWELayoutInfos, GLWEInfos, LWEInfos, LWESwitchingKey, Rank, Rows, TorusPrecision, compressed::{Decompress, GGLWESwitchingKeyCompressed}, }; use std::fmt; @@ -17,9 +13,49 @@ use std::fmt; #[derive(PartialEq, Eq, Clone)] pub struct LWESwitchingKeyCompressed(pub(crate) GGLWESwitchingKeyCompressed); +impl LWEInfos for LWESwitchingKeyCompressed { + fn base2k(&self) -> Base2K { + self.0.base2k() + } + + fn k(&self) -> TorusPrecision { + self.0.k() + } + + fn n(&self) -> Degree { + self.0.n() + } + fn size(&self) -> usize { + self.0.size() + } +} +impl GLWEInfos for LWESwitchingKeyCompressed { + fn rank(&self) -> Rank { + self.rank_out() + } +} + +impl GGLWELayoutInfos for LWESwitchingKeyCompressed { + fn digits(&self) -> Digits { + 
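// An LWE switching key is the degenerate GGLWE case: rank_in, rank_out and
// digits are all pinned to 1, which the alloc/alloc_bytes paths below
// debug_assert. Commented sketch of allocating one from an explicit layout
// (field values illustrative only):
//
//     let infos = GGLWECiphertextLayout {
//         n: Degree(512), base2k: Base2K(18), k: TorusPrecision(36),
//         rows: Rows(2), digits: Digits(1), rank_in: Rank(1), rank_out: Rank(1),
//     };
//     let ksk = LWESwitchingKeyCompressed::alloc(&infos);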
self.0.digits() + } + + fn rank_in(&self) -> Rank { + self.0.rank_in() + } + + fn rank_out(&self) -> Rank { + self.0.rank_out() + } + + fn rows(&self) -> Rows { + self.0.rows() + } +} + impl fmt::Debug for LWESwitchingKeyCompressed { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self) + write!(f, "{self}") } } @@ -29,52 +65,12 @@ impl FillUniform for LWESwitchingKeyCompressed { } } -impl Reset for LWESwitchingKeyCompressed { - fn reset(&mut self) { - self.0.reset(); - } -} - impl fmt::Display for LWESwitchingKeyCompressed { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "(LWESwitchingKeyCompressed) {}", self.0) } } -impl Infos for LWESwitchingKeyCompressed { - type Inner = MatZnx; - - fn inner(&self) -> &Self::Inner { - self.0.inner() - } - - fn basek(&self) -> usize { - self.0.basek() - } - - fn k(&self) -> usize { - self.0.k() - } -} - -impl LWESwitchingKeyCompressed { - pub fn digits(&self) -> usize { - self.0.digits() - } - - pub fn rank(&self) -> usize { - self.0.rank() - } - - pub fn rank_in(&self) -> usize { - self.0.rank_in() - } - - pub fn rank_out(&self) -> usize { - self.0.rank_out() - } -} - impl ReaderFrom for LWESwitchingKeyCompressed { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { self.0.read_from(reader) @@ -88,32 +84,64 @@ impl WriterTo for LWESwitchingKeyCompressed { } impl LWESwitchingKeyCompressed> { - pub fn alloc(n: usize, basek: usize, k: usize, rows: usize) -> Self { - Self(GGLWESwitchingKeyCompressed::alloc( - n, basek, k, rows, 1, 1, 1, + pub fn alloc(infos: &A) -> Self + where + A: GGLWELayoutInfos, + { + debug_assert_eq!( + infos.digits().0, + 1, + "digits > 1 is not supported for LWESwitchingKeyCompressed" + ); + debug_assert_eq!( + infos.rank_in().0, + 1, + "rank_in > 1 is not supported for LWESwitchingKeyCompressed" + ); + debug_assert_eq!( + infos.rank_out().0, + 1, + "rank_out > 1 is not supported for LWESwitchingKeyCompressed" + ); + Self(GGLWESwitchingKeyCompressed::alloc(infos)) + } + + pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rows: Rows) -> Self { + Self(GGLWESwitchingKeyCompressed::alloc_with( + n, + base2k, + k, + rows, + Digits(1), + Rank(1), + Rank(1), )) } - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize) -> usize + pub fn alloc_bytes(infos: &A) -> usize where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubABInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc, + A: GGLWELayoutInfos, { - LWESwitchingKey::encrypt_sk_scratch_space(module, basek, k) + debug_assert_eq!( + infos.digits().0, + 1, + "digits > 1 is not supported for LWESwitchingKey" + ); + debug_assert_eq!( + infos.rank_in().0, + 1, + "rank_in > 1 is not supported for LWESwitchingKey" + ); + debug_assert_eq!( + infos.rank_out().0, + 1, + "rank_out > 1 is not supported for LWESwitchingKey" + ); + GGLWESwitchingKeyCompressed::alloc_bytes(infos) + } + + pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rows: Rows) -> usize { + GGLWESwitchingKeyCompressed::alloc_bytes_with(n, base2k, k, rows, Digits(1), Rank(1), Rank(1)) } } diff --git a/poulpy-core/src/layouts/compressed/lwe_to_glwe_ksk.rs b/poulpy-core/src/layouts/compressed/lwe_to_glwe_ksk.rs index e9023c8..7b86fb0 100644 --- 
a/poulpy-core/src/layouts/compressed/lwe_to_glwe_ksk.rs +++ b/poulpy-core/src/layouts/compressed/lwe_to_glwe_ksk.rs @@ -1,15 +1,11 @@ use poulpy_hal::{ - api::{ - SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, - VecZnxBigNormalize, VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, - VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, - }, - layouts::{Backend, Data, DataMut, DataRef, FillUniform, MatZnx, Module, ReaderFrom, Reset, WriterTo}, + api::{VecZnxCopy, VecZnxFillUniform}, + layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo}, source::Source, }; use crate::layouts::{ - Infos, LWEToGLWESwitchingKey, + Base2K, Degree, Digits, GGLWELayoutInfos, GLWEInfos, LWEInfos, LWEToGLWESwitchingKey, Rank, Rows, TorusPrecision, compressed::{Decompress, GGLWESwitchingKeyCompressed}, }; use std::fmt; @@ -17,9 +13,49 @@ use std::fmt; #[derive(PartialEq, Eq, Clone)] pub struct LWEToGLWESwitchingKeyCompressed(pub(crate) GGLWESwitchingKeyCompressed); +impl LWEInfos for LWEToGLWESwitchingKeyCompressed { + fn n(&self) -> Degree { + self.0.n() + } + + fn base2k(&self) -> Base2K { + self.0.base2k() + } + + fn k(&self) -> TorusPrecision { + self.0.k() + } + fn size(&self) -> usize { + self.0.size() + } +} +impl GLWEInfos for LWEToGLWESwitchingKeyCompressed { + fn rank(&self) -> Rank { + self.rank_out() + } +} + +impl GGLWELayoutInfos for LWEToGLWESwitchingKeyCompressed { + fn digits(&self) -> Digits { + self.0.digits() + } + + fn rank_in(&self) -> Rank { + self.0.rank_in() + } + + fn rank_out(&self) -> Rank { + self.0.rank_out() + } + + fn rows(&self) -> Rows { + self.0.rows() + } +} + impl fmt::Debug for LWEToGLWESwitchingKeyCompressed { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self) + write!(f, "{self}") } } @@ -29,52 +65,12 @@ impl FillUniform for LWEToGLWESwitchingKeyCompressed { } } -impl Reset for LWEToGLWESwitchingKeyCompressed { - fn reset(&mut self) { - self.0.reset(); - } -} - impl fmt::Display for LWEToGLWESwitchingKeyCompressed { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "(LWEToGLWESwitchingKeyCompressed) {}", self.0) } } -impl Infos for LWEToGLWESwitchingKeyCompressed { - type Inner = MatZnx; - - fn inner(&self) -> &Self::Inner { - self.0.inner() - } - - fn basek(&self) -> usize { - self.0.basek() - } - - fn k(&self) -> usize { - self.0.k() - } -} - -impl LWEToGLWESwitchingKeyCompressed { - pub fn digits(&self) -> usize { - self.0.digits() - } - - pub fn rank(&self) -> usize { - self.0.rank() - } - - pub fn rank_in(&self) -> usize { - self.0.rank_in() - } - - pub fn rank_out(&self) -> usize { - self.0.rank_out() - } -} - impl ReaderFrom for LWEToGLWESwitchingKeyCompressed { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { self.0.read_from(reader) @@ -88,32 +84,54 @@ impl WriterTo for LWEToGLWESwitchingKeyCompressed { } impl LWEToGLWESwitchingKeyCompressed> { - pub fn alloc(n: usize, basek: usize, k: usize, rows: usize, rank_out: usize) -> Self { - Self(GGLWESwitchingKeyCompressed::alloc( - n, basek, k, rows, 1, 1, rank_out, + pub fn alloc(infos: &A) -> Self + where + A: GGLWELayoutInfos, + { + debug_assert_eq!( + infos.digits().0, + 1, + "digits > 1 is not supported for LWEToGLWESwitchingKeyCompressed" + ); + debug_assert_eq!( + infos.rank_in().0, + 1, + "rank_in > 1 is not supported for LWEToGLWESwitchingKeyCompressed" + ); + 
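// The LWE -> GLWE key fixes rank_in = 1 and digits = 1 but keeps rank_out
// free, matching the GLWE secret it targets. Commented usage sketch of the
// explicit-parameter constructor below (values illustrative):
//
//     let ksk = LWEToGLWESwitchingKeyCompressed::alloc_with(
//         Degree(1024), Base2K(18), TorusPrecision(54), Rows(3), Rank(2),
//     );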
Self(GGLWESwitchingKeyCompressed::alloc(infos)) + } + + pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rows: Rows, rank_out: Rank) -> Self { + Self(GGLWESwitchingKeyCompressed::alloc_with( + n, + base2k, + k, + rows, + Digits(1), + Rank(1), + rank_out, )) } - pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank_out: usize) -> usize + pub fn alloc_bytes(infos: &A) -> usize where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubABInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc, + A: GGLWELayoutInfos, { - LWEToGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, rank_out) + debug_assert_eq!( + infos.rank_in().0, + 1, + "rank_in > 1 is not supported for LWEToGLWESwitchingKey" + ); + debug_assert_eq!( + infos.digits().0, + 1, + "digits > 1 is not supported for LWEToGLWESwitchingKey" + ); + GGLWESwitchingKeyCompressed::alloc_bytes(infos) + } + + pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rows: Rows, rank_out: Rank) -> usize { + GGLWESwitchingKeyCompressed::alloc_bytes_with(n, base2k, k, rows, Digits(1), Rank(1), rank_out) } } diff --git a/poulpy-core/src/layouts/gglwe_atk.rs b/poulpy-core/src/layouts/gglwe_atk.rs index 785e854..2177bca 100644 --- a/poulpy-core/src/layouts/gglwe_atk.rs +++ b/poulpy-core/src/layouts/gglwe_atk.rs @@ -1,22 +1,120 @@ use poulpy_hal::{ - layouts::{Data, DataMut, DataRef, FillUniform, MatZnx, ReaderFrom, Reset, WriterTo}, + layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, WriterTo}, source::Source, }; -use crate::layouts::{GGLWESwitchingKey, GLWECiphertext, Infos}; +use crate::layouts::{ + Base2K, Degree, Digits, GGLWELayoutInfos, GGLWESwitchingKey, GLWECiphertext, GLWEInfos, LWEInfos, Rank, Rows, TorusPrecision, +}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use std::fmt; +#[derive(PartialEq, Eq, Copy, Clone, Debug)] +pub struct GGLWEAutomorphismKeyLayout { + pub n: Degree, + pub base2k: Base2K, + pub k: TorusPrecision, + pub rows: Rows, + pub digits: Digits, + pub rank: Rank, +} + #[derive(PartialEq, Eq, Clone)] pub struct GGLWEAutomorphismKey { pub(crate) key: GGLWESwitchingKey, pub(crate) p: i64, } +impl GGLWEAutomorphismKey { + pub fn p(&self) -> i64 { + self.p + } +} + +impl LWEInfos for GGLWEAutomorphismKey { + fn n(&self) -> Degree { + self.key.n() + } + + fn base2k(&self) -> Base2K { + self.key.base2k() + } + + fn k(&self) -> TorusPrecision { + self.key.k() + } + + fn size(&self) -> usize { + self.key.size() + } +} + +impl GLWEInfos for GGLWEAutomorphismKey { + fn rank(&self) -> Rank { + self.rank_out() + } +} + +impl GGLWELayoutInfos for GGLWEAutomorphismKey { + fn rank_in(&self) -> Rank { + self.key.rank_in() + } + + fn rank_out(&self) -> Rank { + self.key.rank_out() + } + + fn digits(&self) -> Digits { + self.key.digits() + } + + fn rows(&self) -> Rows { + self.key.rows() + } +} + +impl LWEInfos for GGLWEAutomorphismKeyLayout { + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } + + fn n(&self) -> Degree { + self.n + } +} + +impl GLWEInfos for GGLWEAutomorphismKeyLayout { + fn rank(&self) -> Rank { + self.rank + } +} + +impl GGLWELayoutInfos for GGLWEAutomorphismKeyLayout { + fn rank_in(&self) -> Rank { + self.rank + } + + fn digits(&self) -> Digits { + 
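// GGLWEAutomorphismKeyLayout reuses a single `rank` for both rank_in and
// rank_out, since an automorphism key switches a secret back to itself under
// X -> X^p. Commented allocation sketch (values illustrative):
//
//     let layout = GGLWEAutomorphismKeyLayout {
//         n: Degree(1024), base2k: Base2K(18), k: TorusPrecision(54),
//         rows: Rows(3), digits: Digits(1), rank: Rank(2),
//     };
//     let atk = GGLWEAutomorphismKey::alloc(&layout); // p() starts at 0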
self.digits + } + + fn rank_out(&self) -> Rank { + self.rank + } + + fn rows(&self) -> Rows { + self.rows + } +} + impl fmt::Debug for GGLWEAutomorphismKey { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self) + write!(f, "{self}") } } @@ -26,16 +124,6 @@ impl FillUniform for GGLWEAutomorphismKey { } } -impl Reset for GGLWEAutomorphismKey -where - MatZnx: Reset, -{ - fn reset(&mut self) { - self.key.reset(); - self.p = 0; - } -} - impl fmt::Display for GGLWEAutomorphismKey { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "(AutomorphismKey: p={}) {}", self.p, self.key) @@ -43,53 +131,42 @@ impl fmt::Display for GGLWEAutomorphismKey { } impl GGLWEAutomorphismKey> { - pub fn alloc(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self { + pub fn alloc(infos: &A) -> Self + where + A: GGLWELayoutInfos, + { + assert_eq!( + infos.rank_in(), + infos.rank_out(), + "rank_in != rank_out is not supported for GGLWEAutomorphismKey" + ); GGLWEAutomorphismKey { - key: GGLWESwitchingKey::alloc(n, basek, k, rows, digits, rank, rank), + key: GGLWESwitchingKey::alloc(infos), p: 0, } } - pub fn bytes_of(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize { - GGLWESwitchingKey::bytes_of(n, basek, k, rows, digits, rank, rank) - } -} - -impl Infos for GGLWEAutomorphismKey { - type Inner = MatZnx; - - fn inner(&self) -> &Self::Inner { - self.key.inner() + pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rows: Rows, digits: Digits, rank: Rank) -> Self { + GGLWEAutomorphismKey { + key: GGLWESwitchingKey::alloc_with(n, base2k, k, rows, digits, rank, rank), + p: 0, + } } - fn basek(&self) -> usize { - self.key.basek() + pub fn alloc_bytes(infos: &A) -> usize + where + A: GGLWELayoutInfos, + { + assert_eq!( + infos.rank_in(), + infos.rank_out(), + "rank_in != rank_out is not supported for GGLWEAutomorphismKey" + ); + GGLWESwitchingKey::alloc_bytes(infos) } - fn k(&self) -> usize { - self.key.k() - } -} - -impl GGLWEAutomorphismKey { - pub fn p(&self) -> i64 { - self.p - } - - pub fn digits(&self) -> usize { - self.key.digits() - } - - pub fn rank(&self) -> usize { - self.key.rank() - } - - pub fn rank_in(&self) -> usize { - self.key.rank_in() - } - - pub fn rank_out(&self) -> usize { - self.key.rank_out() + pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rows: Rows, digits: Digits, rank: Rank) -> usize { + GGLWESwitchingKey::alloc_bytes_with(n, base2k, k, rows, digits, rank, rank) } } diff --git a/poulpy-core/src/layouts/gglwe_ct.rs b/poulpy-core/src/layouts/gglwe_ct.rs index 5e64873..54efd4a 100644 --- a/poulpy-core/src/layouts/gglwe_ct.rs +++ b/poulpy-core/src/layouts/gglwe_ct.rs @@ -1,24 +1,249 @@ use poulpy_hal::{ - layouts::{Data, DataMut, DataRef, FillUniform, MatZnx, ReaderFrom, Reset, WriterTo}, + layouts::{Data, DataMut, DataRef, FillUniform, MatZnx, ReaderFrom, WriterTo, ZnxInfos}, source::Source, }; -use crate::layouts::{GLWECiphertext, Infos}; +use crate::layouts::{Base2K, BuildError, Degree, Digits, GLWECiphertext, GLWEInfos, LWEInfos, Rank, Rows, TorusPrecision}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use std::fmt; +pub trait GGLWELayoutInfos +where + Self: GLWEInfos, +{ + fn rows(&self) -> Rows; + fn digits(&self) -> Digits; + fn rank_in(&self) -> Rank; + fn rank_out(&self) -> Rank; + fn layout(&self) -> GGLWECiphertextLayout { + GGLWECiphertextLayout { + n: self.n(), + base2k: self.base2k(), + k: self.k(), + rank_in: self.rank_in(), + rank_out: 
self.rank_out(),
+            digits: self.digits(),
+            rows: self.rows(),
+        }
+    }
+}
+
+#[derive(PartialEq, Eq, Copy, Clone, Debug)]
+pub struct GGLWECiphertextLayout {
+    pub n: Degree,
+    pub base2k: Base2K,
+    pub k: TorusPrecision,
+    pub rows: Rows,
+    pub digits: Digits,
+    pub rank_in: Rank,
+    pub rank_out: Rank,
+}
+
+impl LWEInfos for GGLWECiphertextLayout {
+    fn base2k(&self) -> Base2K {
+        self.base2k
+    }
+
+    fn k(&self) -> TorusPrecision {
+        self.k
+    }
+
+    fn n(&self) -> Degree {
+        self.n
+    }
+}
+
+impl GLWEInfos for GGLWECiphertextLayout {
+    fn rank(&self) -> Rank {
+        self.rank_out
+    }
+}
+
+impl GGLWELayoutInfos for GGLWECiphertextLayout {
+    fn rank_in(&self) -> Rank {
+        self.rank_in
+    }
+
+    fn digits(&self) -> Digits {
+        self.digits
+    }
+
+    fn rank_out(&self) -> Rank {
+        self.rank_out
+    }
+
+    fn rows(&self) -> Rows {
+        self.rows
+    }
+}
+
 #[derive(PartialEq, Eq, Clone)]
 pub struct GGLWECiphertext<D: Data> {
     pub(crate) data: MatZnx<D>,
-    pub(crate) basek: usize,
-    pub(crate) k: usize,
-    pub(crate) digits: usize,
+    pub(crate) k: TorusPrecision,
+    pub(crate) base2k: Base2K,
+    pub(crate) digits: Digits,
+}
+
+impl<D: Data> LWEInfos for GGLWECiphertext<D> {
+    fn base2k(&self) -> Base2K {
+        self.base2k
+    }
+
+    fn k(&self) -> TorusPrecision {
+        self.k
+    }
+
+    fn n(&self) -> Degree {
+        Degree(self.data.n() as u32)
+    }
+
+    fn size(&self) -> usize {
+        self.data.size()
+    }
+}
+
+impl<D: Data> GLWEInfos for GGLWECiphertext<D> {
+    fn rank(&self) -> Rank {
+        self.rank_out()
+    }
+}
+
+impl<D: Data> GGLWELayoutInfos for GGLWECiphertext<D> {
+    fn rank_in(&self) -> Rank {
+        Rank(self.data.cols_in() as u32)
+    }
+
+    fn rank_out(&self) -> Rank {
+        Rank(self.data.cols_out() as u32 - 1)
+    }
+
+    fn digits(&self) -> Digits {
+        self.digits
+    }
+
+    fn rows(&self) -> Rows {
+        Rows(self.data.rows() as u32)
+    }
+}
+
+pub struct GGLWECiphertextBuilder<D: Data> {
+    data: Option<MatZnx<D>>,
+    base2k: Option<Base2K>,
+    k: Option<TorusPrecision>,
+    digits: Option<Digits>,
+}
+
+impl<D: Data> GGLWECiphertext<D> {
+    #[inline]
+    pub fn builder() -> GGLWECiphertextBuilder<D> {
+        GGLWECiphertextBuilder {
+            data: None,
+            base2k: None,
+            k: None,
+            digits: None,
+        }
+    }
+}
+
+impl GGLWECiphertextBuilder<Vec<u8>> {
+    #[inline]
+    pub fn layout<A>(mut self, infos: &A) -> Self
+    where
+        A: GGLWELayoutInfos,
+    {
+        self.data = Some(MatZnx::alloc(
+            infos.n().into(),
+            infos.rows().into(),
+            infos.rank_in().into(),
+            (infos.rank_out() + 1).into(),
+            infos.size(),
+        ));
+        self.base2k = Some(infos.base2k());
+        self.k = Some(infos.k());
+        self.digits = Some(infos.digits());
+        self
+    }
+}
+
+impl<D: Data> GGLWECiphertextBuilder<D> {
+    #[inline]
+    pub fn data(mut self, data: MatZnx<D>) -> Self {
+        self.data = Some(data);
+        self
+    }
+
+    #[inline]
+    pub fn base2k(mut self, base2k: Base2K) -> Self {
+        self.base2k = Some(base2k);
+        self
+    }
+
+    #[inline]
+    pub fn k(mut self, k: TorusPrecision) -> Self {
+        self.k = Some(k);
+        self
+    }
+
+    #[inline]
+    pub fn digits(mut self, digits: Digits) -> Self {
+        self.digits = Some(digits);
+        self
+    }
+
+    pub fn build(self) -> Result<GGLWECiphertext<D>, BuildError> {
+        let data: MatZnx<D> = self.data.ok_or(BuildError::MissingData)?;
+        let base2k: Base2K = self.base2k.ok_or(BuildError::MissingBase2K)?;
+        let k: TorusPrecision = self.k.ok_or(BuildError::MissingK)?;
+        let digits: Digits = self.digits.ok_or(BuildError::MissingDigits)?;
+
+        if base2k == 0_u32 {
+            return Err(BuildError::ZeroBase2K);
+        }
+
+        if digits == 0_u32 {
+            // assumes a dedicated ZeroDigits variant, so the error names the
+            // field that was actually zero rather than ZeroBase2K
+            return Err(BuildError::ZeroDigits);
+        }
+
+        if k == 0_u32 {
+            return Err(BuildError::ZeroTorusPrecision);
+        }
+
+        if data.n() == 0 {
+            return Err(BuildError::ZeroDegree);
+        }
+
+        if data.cols() == 0 {
+            return Err(BuildError::ZeroCols);
+        }
+
+        if
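// Usage sketch of the builder assembled above (commented; `layout` is any
// GGLWELayoutInfos implementor, such as a GGLWECiphertextLayout value):
//
//     let ct = GGLWECiphertext::builder()
//         .layout(&layout) // allocates the MatZnx, sets base2k/k/digits
//         .build()
//         .expect("non-zero layout fields");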
data.size() == 0 { + return Err(BuildError::ZeroLimbs); + } + + Ok(GGLWECiphertext { + data, + base2k, + k, + digits, + }) + } +} + +impl GGLWECiphertext { + pub fn data(&self) -> &MatZnx { + &self.data + } +} + +impl GGLWECiphertext { + pub fn data_mut(&mut self) -> &mut MatZnx { + &mut self.data + } } impl fmt::Debug for GGLWECiphertext { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self) + write!(f, "{self}") } } @@ -28,140 +253,156 @@ impl FillUniform for GGLWECiphertext { } } -impl Reset for GGLWECiphertext { - fn reset(&mut self) { - self.data.reset(); - self.basek = 0; - self.k = 0; - self.digits = 0; - } -} - impl fmt::Display for GGLWECiphertext { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "(GGLWECiphertext: basek={} k={} digits={}) {}", - self.basek, self.k, self.digits, self.data + "(GGLWECiphertext: k={} base2k={} digits={}) {}", + self.k().0, + self.base2k().0, + self.digits().0, + self.data ) } } impl GGLWECiphertext { pub fn at(&self, row: usize, col: usize) -> GLWECiphertext<&[u8]> { - GLWECiphertext { - data: self.data.at(row, col), - basek: self.basek, - k: self.k, - } + GLWECiphertext::builder() + .data(self.data.at(row, col)) + .base2k(self.base2k()) + .k(self.k()) + .build() + .unwrap() } } impl GGLWECiphertext { pub fn at_mut(&mut self, row: usize, col: usize) -> GLWECiphertext<&mut [u8]> { - GLWECiphertext { - data: self.data.at_mut(row, col), - basek: self.basek, - k: self.k, - } + GLWECiphertext::builder() + .base2k(self.base2k()) + .k(self.k()) + .data(self.data.at_mut(row, col)) + .build() + .unwrap() } } impl GGLWECiphertext> { - pub fn alloc(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize, rank_out: usize) -> Self { - let size: usize = k.div_ceil(basek); + pub fn alloc(infos: &A) -> Self + where + A: GGLWELayoutInfos, + { + Self::alloc_with( + infos.n(), + infos.base2k(), + infos.k(), + infos.rows(), + infos.digits(), + infos.rank_in(), + infos.rank_out(), + ) + } + + pub fn alloc_with( + n: Degree, + base2k: Base2K, + k: TorusPrecision, + rows: Rows, + digits: Digits, + rank_in: Rank, + rank_out: Rank, + ) -> Self { + let size: usize = k.0.div_ceil(base2k.0) as usize; debug_assert!( - size > digits, - "invalid gglwe: ceil(k/basek): {} <= digits: {}", - size, - digits + size as u32 > digits.0, + "invalid gglwe: ceil(k/base2k): {size} <= digits: {}", + digits.0 ); assert!( - rows * digits <= size, - "invalid gglwe: rows: {} * digits:{} > ceil(k/basek): {}", - rows, - digits, - size + rows.0 * digits.0 <= size as u32, + "invalid gglwe: rows: {} * digits:{} > ceil(k/base2k): {size}", + rows.0, + digits.0, ); Self { - data: MatZnx::alloc(n, rows, rank_in, rank_out + 1, size), - basek, + data: MatZnx::alloc( + n.into(), + rows.into(), + rank_in.into(), + (rank_out + 1).into(), + k.0.div_ceil(base2k.0) as usize, + ), k, + base2k, digits, } } - pub fn bytes_of(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize, rank_out: usize) -> usize { - let size: usize = k.div_ceil(basek); + pub fn alloc_bytes(infos: &A) -> usize + where + A: GGLWELayoutInfos, + { + Self::alloc_bytes_with( + infos.n(), + infos.base2k(), + infos.k(), + infos.rows(), + infos.digits(), + infos.rank_in(), + infos.rank_out(), + ) + } + + pub fn alloc_bytes_with( + n: Degree, + base2k: Base2K, + k: TorusPrecision, + rows: Rows, + digits: Digits, + rank_in: Rank, + rank_out: Rank, + ) -> usize { + let size: usize = k.0.div_ceil(base2k.0) as usize; debug_assert!( - size > digits, - "invalid 
gglwe: ceil(k/basek): {} <= digits: {}", - size, - digits + size as u32 > digits.0, + "invalid gglwe: ceil(k/base2k): {size} <= digits: {}", + digits.0 ); assert!( - rows * digits <= size, - "invalid gglwe: rows: {} * digits:{} > ceil(k/basek): {}", - rows, - digits, - size + rows.0 * digits.0 <= size as u32, + "invalid gglwe: rows: {} * digits:{} > ceil(k/base2k): {size}", + rows.0, + digits.0, ); - MatZnx::alloc_bytes(n, rows, rank_in, rank_out + 1, rows) - } -} - -impl Infos for GGLWECiphertext { - type Inner = MatZnx; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn basek(&self) -> usize { - self.basek - } - - fn k(&self) -> usize { - self.k - } -} - -impl GGLWECiphertext { - pub fn rank(&self) -> usize { - self.data.cols_out() - 1 - } - - pub fn digits(&self) -> usize { - self.digits - } - - pub fn rank_in(&self) -> usize { - self.data.cols_in() - } - - pub fn rank_out(&self) -> usize { - self.data.cols_out() - 1 + MatZnx::alloc_bytes( + n.into(), + rows.into(), + rank_in.into(), + (rank_out + 1).into(), + k.0.div_ceil(base2k.0) as usize, + ) } } impl ReaderFrom for GGLWECiphertext { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { - self.k = reader.read_u64::()? as usize; - self.basek = reader.read_u64::()? as usize; - self.digits = reader.read_u64::()? as usize; + self.k = TorusPrecision(reader.read_u32::()?); + self.base2k = Base2K(reader.read_u32::()?); + self.digits = Digits(reader.read_u32::()?); self.data.read_from(reader) } } impl WriterTo for GGLWECiphertext { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { - writer.write_u64::(self.k as u64)?; - writer.write_u64::(self.basek as u64)?; - writer.write_u64::(self.digits as u64)?; + writer.write_u32::(self.k.0)?; + writer.write_u32::(self.base2k.0)?; + writer.write_u32::(self.digits.0)?; self.data.write_to(writer) } } diff --git a/poulpy-core/src/layouts/gglwe_ksk.rs b/poulpy-core/src/layouts/gglwe_ksk.rs index 5652904..3f3e8b0 100644 --- a/poulpy-core/src/layouts/gglwe_ksk.rs +++ b/poulpy-core/src/layouts/gglwe_ksk.rs @@ -1,13 +1,64 @@ use poulpy_hal::{ - layouts::{Data, DataMut, DataRef, FillUniform, MatZnx, ReaderFrom, Reset, WriterTo}, + layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, WriterTo}, source::Source, }; -use crate::layouts::{GGLWECiphertext, GLWECiphertext, Infos}; +use crate::layouts::{ + Base2K, Degree, Digits, GGLWECiphertext, GGLWELayoutInfos, GLWECiphertext, GLWEInfos, LWEInfos, Rank, Rows, TorusPrecision, +}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use std::fmt; +#[derive(PartialEq, Eq, Copy, Clone, Debug)] +pub struct GGLWESwitchingKeyLayout { + pub n: Degree, + pub base2k: Base2K, + pub k: TorusPrecision, + pub rows: Rows, + pub digits: Digits, + pub rank_in: Rank, + pub rank_out: Rank, +} + +impl LWEInfos for GGLWESwitchingKeyLayout { + fn n(&self) -> Degree { + self.n + } + + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } +} + +impl GLWEInfos for GGLWESwitchingKeyLayout { + fn rank(&self) -> Rank { + self.rank_out() + } +} + +impl GGLWELayoutInfos for GGLWESwitchingKeyLayout { + fn rank_in(&self) -> Rank { + self.rank_in + } + + fn rank_out(&self) -> Rank { + self.rank_out + } + + fn digits(&self) -> Digits { + self.digits + } + + fn rows(&self) -> Rows { + self.rows + } +} + #[derive(PartialEq, Eq, Clone)] pub struct GGLWESwitchingKey { pub(crate) key: GGLWECiphertext, @@ -15,9 +66,51 @@ pub struct GGLWESwitchingKey { pub(crate) sk_out_n: usize, // Degree of sk_out } +impl LWEInfos 
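// Shape recap for the GGLWE layouts above: the backing MatZnx is allocated
// as (n, rows, rank_in, rank_out + 1, size) with size = ceil(k / base2k).
// For example n = 1024, base2k = 18, k = 54, rows = 3, rank_in = rank_out = 2
// gives size = 3 and a 3 x 2 x 3 grid of degree-1024 polynomials, each with
// three base-2^18 limbs.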
for GGLWESwitchingKey { + fn n(&self) -> Degree { + self.key.n() + } + + fn base2k(&self) -> Base2K { + self.key.base2k() + } + + fn k(&self) -> TorusPrecision { + self.key.k() + } + + fn size(&self) -> usize { + self.key.size() + } +} + +impl GLWEInfos for GGLWESwitchingKey { + fn rank(&self) -> Rank { + self.rank_out() + } +} + +impl GGLWELayoutInfos for GGLWESwitchingKey { + fn rank_in(&self) -> Rank { + self.key.rank_in() + } + + fn rank_out(&self) -> Rank { + self.key.rank_out() + } + + fn digits(&self) -> Digits { + self.key.digits() + } + + fn rows(&self) -> Rows { + self.key.rows() + } +} + impl fmt::Debug for GGLWESwitchingKey { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self) + write!(f, "{self}") } } @@ -26,7 +119,9 @@ impl fmt::Display for GGLWESwitchingKey { write!( f, "(GLWESwitchingKey: sk_in_n={} sk_out_n={}) {}", - self.sk_in_n, self.sk_out_n, self.key.data + self.sk_in_n, + self.sk_out_n, + self.key.data() ) } } @@ -37,70 +132,51 @@ impl FillUniform for GGLWESwitchingKey { } } -impl Reset for GGLWESwitchingKey -where - MatZnx: Reset, -{ - fn reset(&mut self) { - self.key.reset(); - self.sk_in_n = 0; - self.sk_out_n = 0; - } -} - impl GGLWESwitchingKey> { - pub fn alloc(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize, rank_out: usize) -> Self { + pub fn alloc(infos: &A) -> Self + where + A: GGLWELayoutInfos, + { GGLWESwitchingKey { - key: GGLWECiphertext::alloc(n, basek, k, rows, digits, rank_in, rank_out), + key: GGLWECiphertext::alloc(infos), sk_in_n: 0, sk_out_n: 0, } } - pub fn bytes_of(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize, rank_out: usize) -> usize { - GGLWECiphertext::>::bytes_of(n, basek, k, rows, digits, rank_in, rank_out) - } -} - -impl Infos for GGLWESwitchingKey { - type Inner = MatZnx; - - fn inner(&self) -> &Self::Inner { - self.key.inner() + pub fn alloc_with( + n: Degree, + base2k: Base2K, + k: TorusPrecision, + rows: Rows, + digits: Digits, + rank_in: Rank, + rank_out: Rank, + ) -> Self { + GGLWESwitchingKey { + key: GGLWECiphertext::alloc_with(n, base2k, k, rows, digits, rank_in, rank_out), + sk_in_n: 0, + sk_out_n: 0, + } } - fn basek(&self) -> usize { - self.key.basek() + pub fn alloc_bytes(infos: &A) -> usize + where + A: GGLWELayoutInfos, + { + GGLWECiphertext::alloc_bytes(infos) } - fn k(&self) -> usize { - self.key.k() - } -} - -impl GGLWESwitchingKey { - pub fn rank(&self) -> usize { - self.key.data.cols_out() - 1 - } - - pub fn rank_in(&self) -> usize { - self.key.data.cols_in() - } - - pub fn rank_out(&self) -> usize { - self.key.data.cols_out() - 1 - } - - pub fn digits(&self) -> usize { - self.key.digits() - } - - pub fn sk_degree_in(&self) -> usize { - self.sk_in_n - } - - pub fn sk_degree_out(&self) -> usize { - self.sk_out_n + pub fn alloc_bytes_with( + n: Degree, + base2k: Base2K, + k: TorusPrecision, + rows: Rows, + digits: Digits, + rank_in: Rank, + rank_out: Rank, + ) -> usize { + GGLWECiphertext::alloc_bytes_with(n, base2k, k, rows, digits, rank_in, rank_out) } } diff --git a/poulpy-core/src/layouts/gglwe_tsk.rs b/poulpy-core/src/layouts/gglwe_tsk.rs index 47c2993..b6a7450 100644 --- a/poulpy-core/src/layouts/gglwe_tsk.rs +++ b/poulpy-core/src/layouts/gglwe_tsk.rs @@ -1,21 +1,113 @@ use poulpy_hal::{ - layouts::{Data, DataMut, DataRef, FillUniform, MatZnx, ReaderFrom, Reset, WriterTo}, + layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, WriterTo}, source::Source, }; -use crate::layouts::{GGLWESwitchingKey, Infos}; +use 
crate::layouts::{ + Base2K, Degree, Digits, GGLWELayoutInfos, GGLWESwitchingKey, GLWEInfos, LWEInfos, Rank, Rows, TorusPrecision, +}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use std::fmt; +#[derive(PartialEq, Eq, Copy, Clone, Debug)] +pub struct GGLWETensorKeyLayout { + pub n: Degree, + pub base2k: Base2K, + pub k: TorusPrecision, + pub rows: Rows, + pub digits: Digits, + pub rank: Rank, +} + #[derive(PartialEq, Eq, Clone)] pub struct GGLWETensorKey { pub(crate) keys: Vec>, } +impl LWEInfos for GGLWETensorKey { + fn n(&self) -> Degree { + self.keys[0].n() + } + + fn base2k(&self) -> Base2K { + self.keys[0].base2k() + } + + fn k(&self) -> TorusPrecision { + self.keys[0].k() + } + + fn size(&self) -> usize { + self.keys[0].size() + } +} + +impl GLWEInfos for GGLWETensorKey { + fn rank(&self) -> Rank { + self.keys[0].rank_out() + } +} + +impl GGLWELayoutInfos for GGLWETensorKey { + fn rank_in(&self) -> Rank { + self.rank_out() + } + + fn rank_out(&self) -> Rank { + self.keys[0].rank_out() + } + + fn digits(&self) -> Digits { + self.keys[0].digits() + } + + fn rows(&self) -> Rows { + self.keys[0].rows() + } +} + +impl LWEInfos for GGLWETensorKeyLayout { + fn n(&self) -> Degree { + self.n + } + + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } +} + +impl GLWEInfos for GGLWETensorKeyLayout { + fn rank(&self) -> Rank { + self.rank_out() + } +} + +impl GGLWELayoutInfos for GGLWETensorKeyLayout { + fn rank_in(&self) -> Rank { + self.rank + } + + fn digits(&self) -> Digits { + self.digits + } + + fn rank_out(&self) -> Rank { + self.rank + } + + fn rows(&self) -> Rows { + self.rows + } +} + impl fmt::Debug for GGLWETensorKey { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self) + write!(f, "{self}") } } @@ -27,74 +119,79 @@ impl FillUniform for GGLWETensorKey { } } -impl Reset for GGLWETensorKey -where - MatZnx: Reset, -{ - fn reset(&mut self) { - self.keys - .iter_mut() - .for_each(|key: &mut GGLWESwitchingKey| key.reset()) - } -} - impl fmt::Display for GGLWETensorKey { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!(f, "(GLWETensorKey)",)?; for (i, key) in self.keys.iter().enumerate() { - write!(f, "{}: {}", i, key)?; + write!(f, "{i}: {key}")?; } Ok(()) } } impl GGLWETensorKey> { - pub fn alloc(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self { + pub fn alloc(infos: &A) -> Self + where + A: GGLWELayoutInfos, + { + assert_eq!( + infos.rank_in(), + infos.rank_out(), + "rank_in != rank_out is not supported for GGLWETensorKey" + ); + Self::alloc_with( + infos.n(), + infos.base2k(), + infos.k(), + infos.rows(), + infos.digits(), + infos.rank_out(), + ) + } + + pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rows: Rows, digits: Digits, rank: Rank) -> Self { let mut keys: Vec>> = Vec::new(); - let pairs: usize = (((rank + 1) * rank) >> 1).max(1); + let pairs: u32 = (((rank.0 + 1) * rank.0) >> 1).max(1); (0..pairs).for_each(|_| { - keys.push(GGLWESwitchingKey::alloc(n, basek, k, rows, digits, 1, rank)); + keys.push(GGLWESwitchingKey::alloc_with( + n, + base2k, + k, + rows, + digits, + Rank(1), + rank, + )); }); Self { keys } } - pub fn bytes_of(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize { - let pairs: usize = (((rank + 1) * rank) >> 1).max(1); - pairs * GGLWESwitchingKey::>::bytes_of(n, basek, k, rows, digits, 1, rank) - } -} - -impl Infos for GGLWETensorKey { - type Inner = MatZnx; - - fn inner(&self) 
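// The tensor key holds one switching key per unordered pair (i, j) of GLWE
// columns, hence pairs = max(rank * (rank + 1) / 2, 1) in alloc_with above:
// rank = 1 -> 1 key, rank = 2 -> 3 keys, rank = 3 -> 6 keys.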
-> &Self::Inner { - self.keys[0].inner() + pub fn alloc_bytes(infos: &A) -> usize + where + A: GGLWELayoutInfos, + { + assert_eq!( + infos.rank_in(), + infos.rank_out(), + "rank_in != rank_out is not supported for GGLWETensorKey" + ); + let rank_out: usize = infos.rank_out().into(); + let pairs: usize = (((rank_out + 1) * rank_out) >> 1).max(1); + pairs + * GGLWESwitchingKey::alloc_bytes_with( + infos.n(), + infos.base2k(), + infos.k(), + infos.rows(), + infos.digits(), + Rank(1), + infos.rank_out(), + ) } - fn basek(&self) -> usize { - self.keys[0].basek() - } - - fn k(&self) -> usize { - self.keys[0].k() - } -} - -impl GGLWETensorKey { - pub fn rank(&self) -> usize { - self.keys[0].rank() - } - - pub fn rank_in(&self) -> usize { - self.keys[0].rank_in() - } - - pub fn rank_out(&self) -> usize { - self.keys[0].rank_out() - } - - pub fn digits(&self) -> usize { - self.keys[0].digits() + pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rows: Rows, digits: Digits, rank: Rank) -> usize { + let pairs: usize = (((rank.0 + 1) * rank.0) >> 1).max(1) as usize; + pairs * GGLWESwitchingKey::alloc_bytes_with(n, base2k, k, rows, digits, Rank(1), rank) } } @@ -104,7 +201,7 @@ impl GGLWETensorKey { if i > j { std::mem::swap(&mut i, &mut j); }; - let rank: usize = self.rank(); + let rank: usize = self.rank_out().into(); &mut self.keys[i * rank + j - (i * (i + 1) / 2)] } } @@ -115,7 +212,7 @@ impl GGLWETensorKey { if i > j { std::mem::swap(&mut i, &mut j); }; - let rank: usize = self.rank(); + let rank: usize = self.rank_out().into(); &self.keys[i * rank + j - (i * (i + 1) / 2)] } } diff --git a/poulpy-core/src/layouts/ggsw_ct.rs b/poulpy-core/src/layouts/ggsw_ct.rs index 39e3cfc..7ccdc7d 100644 --- a/poulpy-core/src/layouts/ggsw_ct.rs +++ b/poulpy-core/src/layouts/ggsw_ct.rs @@ -1,17 +1,224 @@ use poulpy_hal::{ - layouts::{Data, DataMut, DataRef, FillUniform, MatZnx, ReaderFrom, Reset, WriterTo}, + layouts::{Data, DataMut, DataRef, FillUniform, MatZnx, ReaderFrom, WriterTo, ZnxInfos}, source::Source, }; use std::fmt; -use crate::layouts::{GLWECiphertext, Infos}; +use crate::layouts::{Base2K, BuildError, Degree, Digits, GLWECiphertext, GLWEInfos, LWEInfos, Rank, Rows, TorusPrecision}; + +pub trait GGSWInfos +where + Self: GLWEInfos, +{ + fn rows(&self) -> Rows; + fn digits(&self) -> Digits; + fn layout(&self) -> GGSWCiphertextLayout { + GGSWCiphertextLayout { + n: self.n(), + base2k: self.base2k(), + k: self.k(), + rank: self.rank(), + rows: self.rows(), + digits: self.digits(), + } + } +} + +#[derive(PartialEq, Eq, Copy, Clone, Debug)] +pub struct GGSWCiphertextLayout { + pub n: Degree, + pub base2k: Base2K, + pub k: TorusPrecision, + pub rows: Rows, + pub digits: Digits, + pub rank: Rank, +} + +impl LWEInfos for GGSWCiphertextLayout { + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } + + fn n(&self) -> Degree { + self.n + } +} +impl GLWEInfos for GGSWCiphertextLayout { + fn rank(&self) -> Rank { + self.rank + } +} + +impl GGSWInfos for GGSWCiphertextLayout { + fn digits(&self) -> Digits { + self.digits + } + + fn rows(&self) -> Rows { + self.rows + } +} #[derive(PartialEq, Eq, Clone)] pub struct GGSWCiphertext { pub(crate) data: MatZnx, - pub(crate) basek: usize, - pub(crate) k: usize, - pub(crate) digits: usize, + pub(crate) k: TorusPrecision, + pub(crate) base2k: Base2K, + pub(crate) digits: Digits, +} + +impl LWEInfos for GGSWCiphertext { + fn n(&self) -> Degree { + Degree(self.data.n() as u32) + } + + fn base2k(&self) -> Base2K { + 
self.base2k
+    }
+
+    fn k(&self) -> TorusPrecision {
+        self.k
+    }
+
+    fn size(&self) -> usize {
+        self.data.size()
+    }
+}
+
+impl GLWEInfos for GGSWCiphertext {
+    fn rank(&self) -> Rank {
+        Rank(self.data.cols_out() as u32 - 1)
+    }
+}
+
+impl GGSWInfos for GGSWCiphertext {
+    fn digits(&self) -> Digits {
+        self.digits
+    }
+
+    fn rows(&self) -> Rows {
+        Rows(self.data.rows() as u32)
+    }
+}
+
+pub struct GGSWCiphertextBuilder {
+    data: Option>,
+    base2k: Option,
+    k: Option,
+    digits: Option,
+}
+
+impl GGSWCiphertext {
+    #[inline]
+    pub fn builder() -> GGSWCiphertextBuilder {
+        GGSWCiphertextBuilder {
+            data: None,
+            base2k: None,
+            k: None,
+            digits: None,
+        }
+    }
+}
+
+impl GGSWCiphertextBuilder> {
+    #[inline]
+    pub fn layout(mut self, infos: &A) -> Self
+    where
+        A: GGSWInfos,
+    {
+        debug_assert!(
+            infos.size() as u32 > infos.digits().0,
+            "invalid ggsw: ceil(k/base2k): {} <= digits: {}",
+            infos.size(),
+            infos.digits()
+        );
+
+        assert!(
+            infos.rows().0 * infos.digits().0 <= infos.size() as u32,
+            "invalid ggsw: rows: {} * digits:{} > ceil(k/base2k): {}",
+            infos.rows(),
+            infos.digits(),
+            infos.size(),
+        );
+
+        self.data = Some(MatZnx::alloc(
+            infos.n().into(),
+            infos.rows().into(),
+            (infos.rank() + 1).into(),
+            (infos.rank() + 1).into(),
+            infos.size(),
+        ));
+        self.base2k = Some(infos.base2k());
+        self.k = Some(infos.k());
+        self.digits = Some(infos.digits());
+        self
+    }
+}
+
+impl GGSWCiphertextBuilder {
+    #[inline]
+    pub fn data(mut self, data: MatZnx) -> Self {
+        self.data = Some(data);
+        self
+    }
+    #[inline]
+    pub fn base2k(mut self, base2k: Base2K) -> Self {
+        self.base2k = Some(base2k);
+        self
+    }
+    #[inline]
+    pub fn k(mut self, k: TorusPrecision) -> Self {
+        self.k = Some(k);
+        self
+    }
+
+    #[inline]
+    pub fn digits(mut self, digits: Digits) -> Self {
+        self.digits = Some(digits);
+        self
+    }
+
+    pub fn build(self) -> Result, BuildError> {
+        let data: MatZnx = self.data.ok_or(BuildError::MissingData)?;
+        let base2k: Base2K = self.base2k.ok_or(BuildError::MissingBase2K)?;
+        let k: TorusPrecision = self.k.ok_or(BuildError::MissingK)?;
+        let digits: Digits = self.digits.ok_or(BuildError::MissingDigits)?;
+
+        if base2k == 0_u32 {
+            return Err(BuildError::ZeroBase2K);
+        }
+
+        if digits == 0_u32 {
+            return Err(BuildError::ZeroDigits);
+        }
+
+        if k == 0_u32 {
+            return Err(BuildError::ZeroTorusPrecision);
+        }
+
+        if data.n() == 0 {
+            return Err(BuildError::ZeroDegree);
+        }
+
+        if data.cols() == 0 {
+            return Err(BuildError::ZeroCols);
+        }
+
+        if data.size() == 0 {
+            return Err(BuildError::ZeroLimbs);
+        }
+
+        Ok(GGSWCiphertext {
+            data,
+            base2k,
+            k,
+            digits,
+        })
+    }
}

impl fmt::Debug for GGSWCiphertext {
@@ -24,21 +231,15 @@ impl fmt::Display for GGSWCiphertext {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
-            "(GGSWCiphertext: basek={} k={} digits={}) {}",
-            self.basek, self.k, self.digits, self.data
+            "(GGSWCiphertext: k: {} base2k: {} digits: {}) {}",
+            self.k().0,
+            self.base2k().0,
+            self.digits().0,
+            self.data
        )
    }
}

-impl Reset for GGSWCiphertext {
-    fn reset(&mut self) {
-        self.data.reset();
-        self.basek = 0;
-        self.k = 0;
-        self.digits = 0;
-    }
-}
-
impl FillUniform for GGSWCiphertext {
    fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
        self.data.fill_uniform(log_bound, source);
@@ -47,96 +248,106 @@ impl GGSWCiphertext {
    pub fn at(&self, row: usize, col: usize) -> GLWECiphertext<&[u8]> {
-        GLWECiphertext {
-            data: self.data.at(row, col),
-            basek: self.basek,
-            k: self.k,
-        }
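+        // Each (row, col) entry of the GGSW matrix is itself a GLWE ciphertext:
+        // rebuild a borrowed view over that cell and re-attach the parent's
+        // (base2k, k) metadata through the builder.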
GLWECiphertext::builder() + .data(self.data.at(row, col)) + .base2k(self.base2k()) + .k(self.k()) + .build() + .unwrap() } } impl GGSWCiphertext { pub fn at_mut(&mut self, row: usize, col: usize) -> GLWECiphertext<&mut [u8]> { - GLWECiphertext { - data: self.data.at_mut(row, col), - basek: self.basek, - k: self.k, - } + GLWECiphertext::builder() + .base2k(self.base2k()) + .k(self.k()) + .data(self.data.at_mut(row, col)) + .build() + .unwrap() } } impl GGSWCiphertext> { - pub fn alloc(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self { - let size: usize = k.div_ceil(basek); - debug_assert!(digits > 0, "invalid ggsw: `digits` == 0"); + pub fn alloc(infos: &A) -> Self + where + A: GGSWInfos, + { + Self::alloc_with( + infos.n(), + infos.base2k(), + infos.k(), + infos.rows(), + infos.digits(), + infos.rank(), + ) + } + pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rows: Rows, digits: Digits, rank: Rank) -> Self { + let size: usize = k.0.div_ceil(base2k.0) as usize; debug_assert!( - size > digits, - "invalid ggsw: ceil(k/basek): {} <= digits: {}", - size, - digits + size as u32 > digits.0, + "invalid ggsw: ceil(k/base2k): {size} <= digits: {}", + digits.0 ); assert!( - rows * digits <= size, - "invalid ggsw: rows: {} * digits:{} > ceil(k/basek): {}", - rows, - digits, - size + rows.0 * digits.0 <= size as u32, + "invalid ggsw: rows: {} * digits:{} > ceil(k/base2k): {size}", + rows.0, + digits.0, ); Self { - data: MatZnx::alloc(n, rows, rank + 1, rank + 1, k.div_ceil(basek)), - basek, + data: MatZnx::alloc( + n.into(), + rows.into(), + (rank + 1).into(), + (rank + 1).into(), + k.0.div_ceil(base2k.0) as usize, + ), k, + base2k, digits, } } - pub fn bytes_of(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize { - let size: usize = k.div_ceil(basek); + pub fn alloc_bytes(infos: &A) -> usize + where + A: GGSWInfos, + { + Self::alloc_bytes_with( + infos.n(), + infos.base2k(), + infos.k(), + infos.rows(), + infos.digits(), + infos.rank(), + ) + } + + pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rows: Rows, digits: Digits, rank: Rank) -> usize { + let size: usize = k.0.div_ceil(base2k.0) as usize; debug_assert!( - size > digits, - "invalid ggsw: ceil(k/basek): {} <= digits: {}", - size, - digits + size as u32 > digits.0, + "invalid ggsw: ceil(k/base2k): {size} <= digits: {}", + digits.0 ); assert!( - rows * digits <= size, - "invalid ggsw: rows: {} * digits:{} > ceil(k/basek): {}", - rows, - digits, - size + rows.0 * digits.0 <= size as u32, + "invalid ggsw: rows: {} * digits:{} > ceil(k/base2k): {size}", + rows.0, + digits.0, ); - MatZnx::alloc_bytes(n, rows, rank + 1, rank + 1, size) - } -} - -impl Infos for GGSWCiphertext { - type Inner = MatZnx; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn basek(&self) -> usize { - self.basek - } - - fn k(&self) -> usize { - self.k - } -} - -impl GGSWCiphertext { - pub fn rank(&self) -> usize { - self.data.cols_out() - 1 - } - - pub fn digits(&self) -> usize { - self.digits + MatZnx::alloc_bytes( + n.into(), + rows.into(), + (rank + 1).into(), + (rank + 1).into(), + k.0.div_ceil(base2k.0) as usize, + ) } } @@ -144,18 +355,18 @@ use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; impl ReaderFrom for GGSWCiphertext { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { - self.k = reader.read_u64::()? as usize; - self.basek = reader.read_u64::()? as usize; - self.digits = reader.read_u64::()? 
as usize; + self.k = TorusPrecision(reader.read_u32::()?); + self.base2k = Base2K(reader.read_u32::()?); + self.digits = Digits(reader.read_u32::()?); self.data.read_from(reader) } } impl WriterTo for GGSWCiphertext { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { - writer.write_u64::(self.k as u64)?; - writer.write_u64::(self.basek as u64)?; - writer.write_u64::(self.digits as u64)?; + writer.write_u32::(self.k.into())?; + writer.write_u32::(self.base2k.into())?; + writer.write_u32::(self.digits.into())?; self.data.write_to(writer) } } diff --git a/poulpy-core/src/layouts/glwe_ct.rs b/poulpy-core/src/layouts/glwe_ct.rs index a19deb9..23b6ef9 100644 --- a/poulpy-core/src/layouts/glwe_ct.rs +++ b/poulpy-core/src/layouts/glwe_ct.rs @@ -1,17 +1,193 @@ use poulpy_hal::{ - layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, Reset, ToOwnedDeep, VecZnx, VecZnxToMut, VecZnxToRef, WriterTo}, + layouts::{ + Data, DataMut, DataRef, FillUniform, ReaderFrom, ToOwnedDeep, VecZnx, VecZnxToMut, VecZnxToRef, WriterTo, ZnxInfos, + }, source::Source, }; -use crate::layouts::{Infos, SetMetaData}; +use crate::layouts::{Base2K, BuildError, Degree, LWEInfos, Rank, TorusPrecision}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use std::fmt; +pub trait GLWEInfos +where + Self: LWEInfos, +{ + fn rank(&self) -> Rank; + fn glwe_layout(&self) -> GLWECiphertextLayout { + GLWECiphertextLayout { + n: self.n(), + base2k: self.base2k(), + k: self.k(), + rank: self.rank(), + } + } +} + +pub trait GLWELayoutSet { + fn set_k(&mut self, k: TorusPrecision); + fn set_basek(&mut self, base2k: Base2K); +} + +#[derive(PartialEq, Eq, Copy, Clone, Debug)] +pub struct GLWECiphertextLayout { + pub n: Degree, + pub base2k: Base2K, + pub k: TorusPrecision, + pub rank: Rank, +} + +impl LWEInfos for GLWECiphertextLayout { + fn n(&self) -> Degree { + self.n + } + + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } +} + +impl GLWEInfos for GLWECiphertextLayout { + fn rank(&self) -> Rank { + self.rank + } +} + #[derive(PartialEq, Eq, Clone)] pub struct GLWECiphertext { - pub data: VecZnx, - pub basek: usize, - pub k: usize, + pub(crate) data: VecZnx, + pub(crate) base2k: Base2K, + pub(crate) k: TorusPrecision, +} + +impl GLWELayoutSet for GLWECiphertext { + fn set_basek(&mut self, base2k: Base2K) { + self.base2k = base2k + } + + fn set_k(&mut self, k: TorusPrecision) { + self.k = k + } +} + +impl GLWECiphertext { + pub fn data(&self) -> &VecZnx { + &self.data + } +} + +impl GLWECiphertext { + pub fn data_mut(&mut self) -> &mut VecZnx { + &mut self.data + } +} + +pub struct GLWECiphertextBuilder { + data: Option>, + base2k: Option, + k: Option, +} + +impl GLWECiphertext { + #[inline] + pub fn builder() -> GLWECiphertextBuilder { + GLWECiphertextBuilder { + data: None, + base2k: None, + k: None, + } + } +} + +impl GLWECiphertextBuilder> { + #[inline] + pub fn layout(mut self, layout: &A) -> Self + where + A: GLWEInfos, + { + self.data = Some(VecZnx::alloc( + layout.n().into(), + (layout.rank() + 1).into(), + layout.size(), + )); + self.base2k = Some(layout.base2k()); + self.k = Some(layout.k()); + self + } +} + +impl GLWECiphertextBuilder { + #[inline] + pub fn data(mut self, data: VecZnx) -> Self { + self.data = Some(data); + self + } + #[inline] + pub fn base2k(mut self, base2k: Base2K) -> Self { + self.base2k = Some(base2k); + self + } + #[inline] + pub fn k(mut self, k: TorusPrecision) -> Self { + self.k = Some(k); + self + } + + pub fn build(self) -> Result, 
BuildError> { + let data: VecZnx = self.data.ok_or(BuildError::MissingData)?; + let base2k: Base2K = self.base2k.ok_or(BuildError::MissingBase2K)?; + let k: TorusPrecision = self.k.ok_or(BuildError::MissingK)?; + + if base2k == 0_u32 { + return Err(BuildError::ZeroBase2K); + } + + if k == 0_u32 { + return Err(BuildError::ZeroTorusPrecision); + } + + if data.n() == 0 { + return Err(BuildError::ZeroDegree); + } + + if data.cols() == 0 { + return Err(BuildError::ZeroCols); + } + + if data.size() == 0 { + return Err(BuildError::ZeroLimbs); + } + + Ok(GLWECiphertext { data, base2k, k }) + } +} + +impl LWEInfos for GLWECiphertext { + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } + + fn n(&self) -> Degree { + Degree(self.data.n() as u32) + } + + fn size(&self) -> usize { + self.data.size() + } +} + +impl GLWEInfos for GLWECiphertext { + fn rank(&self) -> Rank { + Rank(self.data.cols() as u32 - 1) + } } impl ToOwnedDeep for GLWECiphertext { @@ -19,15 +195,15 @@ impl ToOwnedDeep for GLWECiphertext { fn to_owned_deep(&self) -> Self::Owned { GLWECiphertext { data: self.data.to_owned_deep(), - basek: self.basek, k: self.k, + base2k: self.base2k, } } } impl fmt::Debug for GLWECiphertext { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self) + write!(f, "{self}") } } @@ -35,25 +211,14 @@ impl fmt::Display for GLWECiphertext { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "GLWECiphertext: basek={} k={}: {}", - self.basek(), - self.k(), + "GLWECiphertext: base2k={} k={}: {}", + self.base2k().0, + self.k().0, self.data ) } } -impl Reset for GLWECiphertext -where - VecZnx: Reset, -{ - fn reset(&mut self) { - self.data.reset(); - self.basek = 0; - self.k = 0; - } -} - impl FillUniform for GLWECiphertext { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { self.data.fill_uniform(log_bound, source); @@ -61,91 +226,75 @@ impl FillUniform for GLWECiphertext { } impl GLWECiphertext> { - pub fn alloc(n: usize, basek: usize, k: usize, rank: usize) -> Self { + pub fn alloc(infos: &A) -> Self + where + A: GLWEInfos, + { + Self::alloc_with(infos.n(), infos.base2k(), infos.k(), infos.rank()) + } + + pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> Self { Self { - data: VecZnx::alloc(n, rank + 1, k.div_ceil(basek)), - basek, + data: VecZnx::alloc(n.into(), (rank + 1).into(), k.0.div_ceil(base2k.0) as usize), + base2k, k, } } - pub fn bytes_of(n: usize, basek: usize, k: usize, rank: usize) -> usize { - VecZnx::alloc_bytes(n, rank + 1, k.div_ceil(basek)) + pub fn alloc_bytes(infos: &A) -> usize + where + A: GLWEInfos, + { + Self::alloc_bytes_with(infos.n(), infos.base2k(), infos.k(), infos.rank()) + } + + pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> usize { + VecZnx::alloc_bytes(n.into(), (rank + 1).into(), k.0.div_ceil(base2k.0) as usize) } } -impl Infos for GLWECiphertext { - type Inner = VecZnx; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn basek(&self) -> usize { - self.basek - } - - fn k(&self) -> usize { - self.k - } -} - -impl GLWECiphertext { - pub fn rank(&self) -> usize { - self.cols() - 1 - } -} - -impl SetMetaData for GLWECiphertext { - fn set_k(&mut self, k: usize) { - self.k = k - } - - fn set_basek(&mut self, basek: usize) { - self.basek = basek - } -} - -pub trait GLWECiphertextToRef: Infos { +pub trait GLWECiphertextToRef { fn to_ref(&self) -> GLWECiphertext<&[u8]>; } impl GLWECiphertextToRef for 
GLWECiphertext { fn to_ref(&self) -> GLWECiphertext<&[u8]> { - GLWECiphertext { - data: self.data.to_ref(), - basek: self.basek, - k: self.k, - } + GLWECiphertext::builder() + .k(self.k()) + .base2k(self.base2k()) + .data(self.data.to_ref()) + .build() + .unwrap() } } -pub trait GLWECiphertextToMut: Infos { +pub trait GLWECiphertextToMut { fn to_mut(&mut self) -> GLWECiphertext<&mut [u8]>; } impl GLWECiphertextToMut for GLWECiphertext { fn to_mut(&mut self) -> GLWECiphertext<&mut [u8]> { - GLWECiphertext { - data: self.data.to_mut(), - basek: self.basek, - k: self.k, - } + GLWECiphertext::builder() + .k(self.k()) + .base2k(self.base2k()) + .data(self.data.to_mut()) + .build() + .unwrap() } } impl ReaderFrom for GLWECiphertext { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { - self.k = reader.read_u64::()? as usize; - self.basek = reader.read_u64::()? as usize; + self.k = TorusPrecision(reader.read_u32::()?); + self.base2k = Base2K(reader.read_u32::()?); self.data.read_from(reader) } } impl WriterTo for GLWECiphertext { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { - writer.write_u64::(self.k as u64)?; - writer.write_u64::(self.basek as u64)?; + writer.write_u32::(self.k.0)?; + writer.write_u32::(self.base2k.0)?; self.data.write_to(writer) } } diff --git a/poulpy-core/src/layouts/glwe_pk.rs b/poulpy-core/src/layouts/glwe_pk.rs index 1e9fbb6..fc4b0fa 100644 --- a/poulpy-core/src/layouts/glwe_pk.rs +++ b/poulpy-core/src/layouts/glwe_pk.rs @@ -1,57 +1,193 @@ -use poulpy_hal::layouts::{Data, DataMut, DataRef, ReaderFrom, VecZnx, WriterTo}; +use poulpy_hal::layouts::{Data, DataMut, DataRef, ReaderFrom, VecZnx, WriterTo, ZnxInfos}; -use crate::{dist::Distribution, layouts::Infos}; +use crate::{ + dist::Distribution, + layouts::{Base2K, BuildError, Degree, GLWEInfos, LWEInfos, Rank, TorusPrecision}, +}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; #[derive(PartialEq, Eq)] pub struct GLWEPublicKey { pub(crate) data: VecZnx, - pub(crate) basek: usize, - pub(crate) k: usize, + pub(crate) base2k: Base2K, + pub(crate) k: TorusPrecision, pub(crate) dist: Distribution, } +#[derive(PartialEq, Eq, Copy, Clone, Debug)] +pub struct GLWEPublicKeyLayout { + pub n: Degree, + pub base2k: Base2K, + pub k: TorusPrecision, + pub rank: Rank, +} + +impl LWEInfos for GLWEPublicKey { + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } + + fn n(&self) -> Degree { + Degree(self.data.n() as u32) + } + + fn size(&self) -> usize { + self.data.size() + } +} + +impl GLWEInfos for GLWEPublicKey { + fn rank(&self) -> Rank { + Rank(self.data.cols() as u32 - 1) + } +} + +impl LWEInfos for GLWEPublicKeyLayout { + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } + + fn n(&self) -> Degree { + self.n + } + + fn size(&self) -> usize { + self.k.0.div_ceil(self.base2k.0) as usize + } +} + +impl GLWEInfos for GLWEPublicKeyLayout { + fn rank(&self) -> Rank { + self.rank + } +} + +pub struct GLWEPublicKeyBuilder { + data: Option>, + base2k: Option, + k: Option, +} + +impl GLWEPublicKey { + #[inline] + pub fn builder() -> GLWEPublicKeyBuilder { + GLWEPublicKeyBuilder { + data: None, + base2k: None, + k: None, + } + } +} + +impl GLWEPublicKeyBuilder> { + #[inline] + pub fn layout(mut self, layout: &A) -> Self + where + A: GLWEInfos, + { + self.data = Some(VecZnx::alloc( + layout.n().into(), + (layout.rank() + 1).into(), + layout.size(), + )); + self.base2k = Some(layout.base2k()); + self.k = 
Some(layout.k()); + self + } +} + +impl GLWEPublicKeyBuilder { + #[inline] + pub fn data(mut self, data: VecZnx) -> Self { + self.data = Some(data); + self + } + #[inline] + pub fn base2k(mut self, base2k: Base2K) -> Self { + self.base2k = Some(base2k); + self + } + #[inline] + pub fn k(mut self, k: TorusPrecision) -> Self { + self.k = Some(k); + self + } + + pub fn build(self) -> Result, BuildError> { + let data: VecZnx = self.data.ok_or(BuildError::MissingData)?; + let base2k: Base2K = self.base2k.ok_or(BuildError::MissingBase2K)?; + let k: TorusPrecision = self.k.ok_or(BuildError::MissingK)?; + + if base2k == 0_u32 { + return Err(BuildError::ZeroBase2K); + } + + if k == 0_u32 { + return Err(BuildError::ZeroTorusPrecision); + } + + if data.n() == 0 { + return Err(BuildError::ZeroDegree); + } + + if data.cols() == 0 { + return Err(BuildError::ZeroCols); + } + + if data.size() == 0 { + return Err(BuildError::ZeroLimbs); + } + + Ok(GLWEPublicKey { + data, + base2k, + k, + dist: Distribution::NONE, + }) + } +} + impl GLWEPublicKey> { - pub fn alloc(n: usize, basek: usize, k: usize, rank: usize) -> Self { + pub fn alloc(infos: &A) -> Self + where + A: GLWEInfos, + { + Self::alloc_with(infos.n(), infos.base2k(), infos.k(), infos.rank()) + } + + pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> Self { Self { - data: VecZnx::alloc(n, rank + 1, k.div_ceil(basek)), - basek, + data: VecZnx::alloc(n.into(), (rank + 1).into(), k.0.div_ceil(base2k.0) as usize), + base2k, k, dist: Distribution::NONE, } } - pub fn bytes_of(n: usize, basek: usize, k: usize, rank: usize) -> usize { - VecZnx::alloc_bytes(n, rank + 1, k.div_ceil(basek)) - } -} - -impl Infos for GLWEPublicKey { - type Inner = VecZnx; - - fn inner(&self) -> &Self::Inner { - &self.data + pub fn alloc_bytes(infos: &A) -> usize + where + A: GLWEInfos, + { + Self::alloc_bytes_with(infos.n(), infos.base2k(), infos.k(), infos.rank()) } - fn basek(&self) -> usize { - self.basek - } - - fn k(&self) -> usize { - self.k - } -} - -impl GLWEPublicKey { - pub fn rank(&self) -> usize { - self.cols() - 1 + pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> usize { + VecZnx::alloc_bytes(n.into(), (rank + 1).into(), k.0.div_ceil(base2k.0) as usize) } } impl ReaderFrom for GLWEPublicKey { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { - self.k = reader.read_u64::()? as usize; - self.basek = reader.read_u64::()? 
as usize; + self.k = TorusPrecision(reader.read_u32::()?); + self.base2k = Base2K(reader.read_u32::()?); match Distribution::read_from(reader) { Ok(dist) => self.dist = dist, Err(e) => return Err(e), @@ -62,8 +198,8 @@ impl ReaderFrom for GLWEPublicKey { impl WriterTo for GLWEPublicKey { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { - writer.write_u64::(self.k as u64)?; - writer.write_u64::(self.basek as u64)?; + writer.write_u32::(self.k.0)?; + writer.write_u32::(self.base2k.0)?; match self.dist.write_to(writer) { Ok(()) => {} Err(e) => return Err(e), diff --git a/poulpy-core/src/layouts/glwe_pt.rs b/poulpy-core/src/layouts/glwe_pt.rs index 58fb398..b565055 100644 --- a/poulpy-core/src/layouts/glwe_pt.rs +++ b/poulpy-core/src/layouts/glwe_pt.rs @@ -1,83 +1,202 @@ use std::fmt; -use poulpy_hal::layouts::{Data, DataMut, DataRef, VecZnx, VecZnxToMut, VecZnxToRef}; +use poulpy_hal::layouts::{Data, DataMut, DataRef, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos}; -use crate::layouts::{GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, Infos, SetMetaData}; +use crate::layouts::{ + Base2K, BuildError, Degree, GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, GLWEInfos, GLWELayoutSet, LWEInfos, + Rank, TorusPrecision, +}; + +#[derive(PartialEq, Eq, Copy, Clone, Debug)] +pub struct GLWEPlaintextLayout { + pub n: Degree, + pub base2k: Base2K, + pub k: TorusPrecision, +} + +impl LWEInfos for GLWEPlaintextLayout { + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } + + fn n(&self) -> Degree { + self.n + } +} + +impl GLWEInfos for GLWEPlaintextLayout { + fn rank(&self) -> Rank { + Rank(0) + } +} pub struct GLWEPlaintext { pub data: VecZnx, - pub basek: usize, - pub k: usize, + pub base2k: Base2K, + pub k: TorusPrecision, +} + +impl GLWELayoutSet for GLWEPlaintext { + fn set_basek(&mut self, base2k: Base2K) { + self.base2k = base2k + } + + fn set_k(&mut self, k: TorusPrecision) { + self.k = k + } +} + +impl LWEInfos for GLWEPlaintext { + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } + + fn size(&self) -> usize { + self.data.size() + } + + fn n(&self) -> Degree { + Degree(self.data.n() as u32) + } +} + +impl GLWEInfos for GLWEPlaintext { + fn rank(&self) -> Rank { + Rank(self.data.cols() as u32 - 1) + } +} + +pub struct GLWEPlaintextBuilder { + data: Option>, + base2k: Option, + k: Option, +} + +impl GLWEPlaintext { + #[inline] + pub fn builder() -> GLWEPlaintextBuilder { + GLWEPlaintextBuilder { + data: None, + base2k: None, + k: None, + } + } +} + +impl GLWEPlaintextBuilder { + #[inline] + pub fn data(mut self, data: VecZnx) -> Self { + self.data = Some(data); + self + } + #[inline] + pub fn base2k(mut self, base2k: Base2K) -> Self { + self.base2k = Some(base2k); + self + } + #[inline] + pub fn k(mut self, k: TorusPrecision) -> Self { + self.k = Some(k); + self + } + + pub fn build(self) -> Result, BuildError> { + let data: VecZnx = self.data.ok_or(BuildError::MissingData)?; + let base2k: Base2K = self.base2k.ok_or(BuildError::MissingBase2K)?; + let k: TorusPrecision = self.k.ok_or(BuildError::MissingK)?; + + if base2k.0 == 0 { + return Err(BuildError::ZeroBase2K); + } + + if k.0 == 0 { + return Err(BuildError::ZeroTorusPrecision); + } + + if data.n() == 0 { + return Err(BuildError::ZeroDegree); + } + + if data.cols() != 1 { + return Err(BuildError::ZeroCols); + } + + if data.size() == 0 { + return Err(BuildError::ZeroLimbs); + } + + Ok(GLWEPlaintext { data, base2k, k }) + } } impl 
fmt::Display for GLWEPlaintext { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "GLWEPlaintext: basek={} k={}: {}", - self.basek(), - self.k(), + "GLWEPlaintext: base2k={} k={}: {}", + self.base2k().0, + self.k().0, self.data ) } } -impl Infos for GLWEPlaintext { - type Inner = VecZnx; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn basek(&self) -> usize { - self.basek - } - - fn k(&self) -> usize { - self.k - } -} - -impl SetMetaData for GLWEPlaintext { - fn set_k(&mut self, k: usize) { - self.k = k - } - - fn set_basek(&mut self, basek: usize) { - self.basek = basek - } -} - impl GLWEPlaintext> { - pub fn alloc(n: usize, basek: usize, k: usize) -> Self { + pub fn alloc(infos: &A) -> Self + where + A: GLWEInfos, + { + Self::alloc_with(infos.n(), infos.base2k(), infos.k(), Rank(0)) + } + + pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> Self { + debug_assert!(rank.0 == 0); Self { - data: VecZnx::alloc(n, 1, k.div_ceil(basek)), - basek, + data: VecZnx::alloc(n.into(), (rank + 1).into(), k.0.div_ceil(base2k.0) as usize), + base2k, k, } } - pub fn byte_of(n: usize, basek: usize, k: usize) -> usize { - VecZnx::alloc_bytes(n, 1, k.div_ceil(basek)) + pub fn alloc_bytes(infos: &A) -> usize + where + A: GLWEInfos, + { + Self::alloc_bytes_with(infos.n(), infos.base2k(), infos.k(), Rank(0)) + } + + pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> usize { + debug_assert!(rank.0 == 0); + VecZnx::alloc_bytes(n.into(), (rank + 1).into(), k.0.div_ceil(base2k.0) as usize) } } impl GLWECiphertextToRef for GLWEPlaintext { fn to_ref(&self) -> GLWECiphertext<&[u8]> { - GLWECiphertext { - data: self.data.to_ref(), - basek: self.basek, - k: self.k, - } + GLWECiphertext::builder() + .data(self.data.to_ref()) + .k(self.k()) + .base2k(self.base2k()) + .build() + .unwrap() } } impl GLWECiphertextToMut for GLWEPlaintext { fn to_mut(&mut self) -> GLWECiphertext<&mut [u8]> { - GLWECiphertext { - data: self.data.to_mut(), - basek: self.basek, - k: self.k, - } + GLWECiphertext::builder() + .k(self.k()) + .base2k(self.base2k()) + .data(self.data.to_mut()) + .build() + .unwrap() } } diff --git a/poulpy-core/src/layouts/glwe_sk.rs b/poulpy-core/src/layouts/glwe_sk.rs index 8f8b10f..8870d35 100644 --- a/poulpy-core/src/layouts/glwe_sk.rs +++ b/poulpy-core/src/layouts/glwe_sk.rs @@ -3,7 +3,39 @@ use poulpy_hal::{ source::Source, }; -use crate::dist::Distribution; +use crate::{ + dist::Distribution, + layouts::{Base2K, Degree, GLWEInfos, LWEInfos, Rank, TorusPrecision}, +}; + +#[derive(PartialEq, Eq, Copy, Clone, Debug)] +pub struct GLWESecretLayout { + pub n: Degree, + pub rank: Rank, +} + +impl LWEInfos for GLWESecretLayout { + fn base2k(&self) -> Base2K { + Base2K(0) + } + + fn k(&self) -> TorusPrecision { + TorusPrecision(0) + } + + fn n(&self) -> Degree { + self.n + } + + fn size(&self) -> usize { + 1 + } +} +impl GLWEInfos for GLWESecretLayout { + fn rank(&self) -> Rank { + self.rank + } +} #[derive(PartialEq, Eq, Clone)] pub struct GLWESecret { @@ -11,64 +43,88 @@ pub struct GLWESecret { pub(crate) dist: Distribution, } +impl LWEInfos for GLWESecret { + fn base2k(&self) -> Base2K { + Base2K(0) + } + + fn k(&self) -> TorusPrecision { + TorusPrecision(0) + } + + fn n(&self) -> Degree { + Degree(self.data.n() as u32) + } + + fn size(&self) -> usize { + 1 + } +} + +impl GLWEInfos for GLWESecret { + fn rank(&self) -> Rank { + Rank(self.data.cols() as u32) + } +} + impl GLWESecret> { - pub fn alloc(n: usize, rank: usize) 
-> Self {
+    pub fn alloc(infos: &A) -> Self
+    where
+        A: GLWEInfos,
+    {
+        Self::alloc_with(infos.n(), infos.rank())
+    }
+
+    pub fn alloc_with(n: Degree, rank: Rank) -> Self {
        Self {
-            data: ScalarZnx::alloc(n, rank),
+            data: ScalarZnx::alloc(n.into(), rank.into()),
            dist: Distribution::NONE,
        }
    }

-    pub fn bytes_of(n: usize, rank: usize) -> usize {
-        ScalarZnx::alloc_bytes(n, rank)
-    }
-}
-
-impl GLWESecret {
-    pub fn n(&self) -> usize {
-        self.data.n()
+    pub fn alloc_bytes(infos: &A) -> usize
+    where
+        A: GLWEInfos,
+    {
+        Self::alloc_bytes_with(infos.n(), infos.rank())
    }

-    pub fn log_n(&self) -> usize {
-        self.data.log_n()
-    }
-
-    pub fn rank(&self) -> usize {
-        self.data.cols()
+    pub fn alloc_bytes_with(n: Degree, rank: Rank) -> usize {
+        ScalarZnx::alloc_bytes(n.into(), rank.into())
    }
}

impl GLWESecret {
    pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) {
-        (0..self.rank()).for_each(|i| {
+        (0..self.rank().into()).for_each(|i| {
            self.data.fill_ternary_prob(i, prob, source);
        });
        self.dist = Distribution::TernaryProb(prob);
    }

    pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) {
-        (0..self.rank()).for_each(|i| {
+        (0..self.rank().into()).for_each(|i| {
            self.data.fill_ternary_hw(i, hw, source);
        });
        self.dist = Distribution::TernaryFixed(hw);
    }

    pub fn fill_binary_prob(&mut self, prob: f64, source: &mut Source) {
-        (0..self.rank()).for_each(|i| {
+        (0..self.rank().into()).for_each(|i| {
            self.data.fill_binary_prob(i, prob, source);
        });
        self.dist = Distribution::BinaryProb(prob);
    }

    pub fn fill_binary_hw(&mut self, hw: usize, source: &mut Source) {
-        (0..self.rank()).for_each(|i| {
+        (0..self.rank().into()).for_each(|i| {
            self.data.fill_binary_hw(i, hw, source);
        });
        self.dist = Distribution::BinaryFixed(hw);
    }

    pub fn fill_binary_block(&mut self, block_size: usize, source: &mut Source) {
-        (0..self.rank()).for_each(|i| {
+        (0..self.rank().into()).for_each(|i| {
            self.data.fill_binary_block(i, block_size, source);
        });
        self.dist = Distribution::BinaryBlock(block_size);
diff --git a/poulpy-core/src/layouts/glwe_to_lwe_ksk.rs b/poulpy-core/src/layouts/glwe_to_lwe_ksk.rs
index 8194a0a..00d5115 100644
--- a/poulpy-core/src/layouts/glwe_to_lwe_ksk.rs
+++ b/poulpy-core/src/layouts/glwe_to_lwe_ksk.rs
@@ -1,19 +1,109 @@
use poulpy_hal::{
-    layouts::{Data, DataMut, DataRef, FillUniform, MatZnx, ReaderFrom, Reset, WriterTo},
+    layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, WriterTo},
    source::Source,
};

-use crate::layouts::{GGLWESwitchingKey, Infos};
+use crate::layouts::{
+    Base2K, Degree, Digits, GGLWELayoutInfos, GGLWESwitchingKey, GLWEInfos, LWEInfos, Rank, Rows, TorusPrecision,
+};

use std::fmt;

+#[derive(PartialEq, Eq, Copy, Clone, Debug)]
+pub struct GLWEToLWESwitchingKeyLayout {
+    pub n: Degree,
+    pub base2k: Base2K,
+    pub k: TorusPrecision,
+    pub rows: Rows,
+    pub rank_in: Rank,
+}
+
+impl LWEInfos for GLWEToLWESwitchingKeyLayout {
+    fn n(&self) -> Degree {
+        self.n
+    }
+
+    fn base2k(&self) -> Base2K {
+        self.base2k
+    }
+
+    fn k(&self) -> TorusPrecision {
+        self.k
+    }
+}
+
+impl GLWEInfos for GLWEToLWESwitchingKeyLayout {
+    fn rank(&self) -> Rank {
+        self.rank_out()
+    }
+}
+
+impl GGLWELayoutInfos for GLWEToLWESwitchingKeyLayout {
+    fn rank_in(&self) -> Rank {
+        self.rank_in
+    }
+
+    fn digits(&self) -> Digits {
+        Digits(1)
+    }
+
+    fn rank_out(&self) -> Rank {
+        Rank(1)
+    }
+
+    fn rows(&self) -> Rows {
+        self.rows
+    }
+}
+
/// A special [GLWESwitchingKey] required for the conversion from [GLWECiphertext] to [LWECiphertext].
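+///
+/// Allocation goes through a layout descriptor implementing [GGLWELayoutInfos];
+/// a minimal sketch (the parameter values below are illustrative only, not
+/// recommended choices):
+///
+/// ```ignore
+/// let infos = GLWEToLWESwitchingKeyLayout {
+///     n: Degree(1024),
+///     base2k: Base2K(18),
+///     k: TorusPrecision(54),
+///     rows: Rows(3),
+///     rank_in: Rank(2),
+/// };
+/// let ksk = GLWEToLWESwitchingKey::alloc(&infos);
+/// ```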
#[derive(PartialEq, Eq, Clone)] pub struct GLWEToLWESwitchingKey(pub(crate) GGLWESwitchingKey); +impl LWEInfos for GLWEToLWESwitchingKey { + fn base2k(&self) -> Base2K { + self.0.base2k() + } + + fn k(&self) -> TorusPrecision { + self.0.k() + } + + fn n(&self) -> Degree { + self.0.n() + } + + fn size(&self) -> usize { + self.0.size() + } +} + +impl GLWEInfos for GLWEToLWESwitchingKey { + fn rank(&self) -> Rank { + self.rank_out() + } +} +impl GGLWELayoutInfos for GLWEToLWESwitchingKey { + fn rank_in(&self) -> Rank { + self.0.rank_in() + } + + fn digits(&self) -> Digits { + self.0.digits() + } + + fn rank_out(&self) -> Rank { + self.0.rank_out() + } + + fn rows(&self) -> Rows { + self.0.rows() + } +} + impl fmt::Debug for GLWEToLWESwitchingKey { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self) + write!(f, "{self}") } } @@ -23,52 +113,12 @@ impl FillUniform for GLWEToLWESwitchingKey { } } -impl Reset for GLWEToLWESwitchingKey { - fn reset(&mut self) { - self.0.reset(); - } -} - impl fmt::Display for GLWEToLWESwitchingKey { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "(GLWEToLWESwitchingKey) {}", self.0) } } -impl Infos for GLWEToLWESwitchingKey { - type Inner = MatZnx; - - fn inner(&self) -> &Self::Inner { - self.0.inner() - } - - fn basek(&self) -> usize { - self.0.basek() - } - - fn k(&self) -> usize { - self.0.k() - } -} - -impl GLWEToLWESwitchingKey { - pub fn digits(&self) -> usize { - self.0.digits() - } - - pub fn rank(&self) -> usize { - self.0.rank() - } - - pub fn rank_in(&self) -> usize { - self.0.rank_in() - } - - pub fn rank_out(&self) -> usize { - self.0.rank_out() - } -} - impl ReaderFrom for GLWEToLWESwitchingKey { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { self.0.read_from(reader) @@ -82,7 +132,53 @@ impl WriterTo for GLWEToLWESwitchingKey { } impl GLWEToLWESwitchingKey> { - pub fn alloc(n: usize, basek: usize, k: usize, rows: usize, rank_in: usize) -> Self { - Self(GGLWESwitchingKey::alloc(n, basek, k, rows, 1, rank_in, 1)) + pub fn alloc(infos: &A) -> Self + where + A: GGLWELayoutInfos, + { + debug_assert_eq!( + infos.rank_out().0, + 1, + "rank_out > 1 is not supported for GLWEToLWESwitchingKey" + ); + debug_assert_eq!( + infos.digits().0, + 1, + "digits > 1 is not supported for GLWEToLWESwitchingKey" + ); + Self(GGLWESwitchingKey::alloc(infos)) + } + + pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rows: Rows, rank_in: Rank) -> Self { + Self(GGLWESwitchingKey::alloc_with( + n, + base2k, + k, + rows, + Digits(1), + rank_in, + Rank(1), + )) + } + + pub fn alloc_bytes(infos: &A) -> usize + where + A: GGLWELayoutInfos, + { + debug_assert_eq!( + infos.rank_out().0, + 1, + "rank_out > 1 is not supported for GLWEToLWESwitchingKey" + ); + debug_assert_eq!( + infos.digits().0, + 1, + "digits > 1 is not supported for GLWEToLWESwitchingKey" + ); + GGLWESwitchingKey::alloc_bytes(infos) + } + + pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rows: Rows, rank_in: Rank) -> usize { + GGLWESwitchingKey::alloc_bytes_with(n, base2k, k, rows, Digits(1), rank_in, Rank(1)) } } diff --git a/poulpy-core/src/layouts/infos.rs b/poulpy-core/src/layouts/infos.rs deleted file mode 100644 index 5473b81..0000000 --- a/poulpy-core/src/layouts/infos.rs +++ /dev/null @@ -1,54 +0,0 @@ -use poulpy_hal::layouts::ZnxInfos; - -pub trait Infos { - type Inner: ZnxInfos; - - fn inner(&self) -> &Self::Inner; - - /// Returns the ring degree of the polynomials. 
- fn n(&self) -> usize { - self.inner().n() - } - - /// Returns the base two logarithm of the ring dimension of the polynomials. - fn log_n(&self) -> usize { - self.inner().log_n() - } - - /// Returns the number of rows. - fn rows(&self) -> usize { - self.inner().rows() - } - - /// Returns the number of polynomials in each row. - fn cols(&self) -> usize { - self.inner().cols() - } - - fn rank(&self) -> usize { - self.cols() - 1 - } - - /// Returns the number of size per polynomial. - fn size(&self) -> usize { - let size: usize = self.inner().size(); - debug_assert!(size >= self.k().div_ceil(self.basek())); - size - } - - /// Returns the total number of small polynomials. - fn poly_count(&self) -> usize { - self.rows() * self.cols() * self.size() - } - - /// Returns the base 2 logarithm of the ciphertext base. - fn basek(&self) -> usize; - - /// Returns the bit precision of the ciphertext. - fn k(&self) -> usize; -} - -pub trait SetMetaData { - fn set_basek(&mut self, basek: usize); - fn set_k(&mut self, k: usize); -} diff --git a/poulpy-core/src/layouts/lwe_ct.rs b/poulpy-core/src/layouts/lwe_ct.rs index 3b7c48c..1560ea4 100644 --- a/poulpy-core/src/layouts/lwe_ct.rs +++ b/poulpy-core/src/layouts/lwe_ct.rs @@ -1,15 +1,75 @@ use std::fmt; use poulpy_hal::{ - layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, Reset, WriterTo, Zn, ZnToMut, ZnToRef, ZnxInfos}, + layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, WriterTo, Zn, ZnToMut, ZnToRef, ZnxInfos}, source::Source, }; +use crate::layouts::{Base2K, BuildError, Degree, TorusPrecision}; +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; + +pub trait LWEInfos { + fn n(&self) -> Degree; + fn k(&self) -> TorusPrecision; + fn max_k(&self) -> TorusPrecision { + TorusPrecision(self.k().0 * self.size() as u32) + } + fn base2k(&self) -> Base2K; + fn size(&self) -> usize { + self.k().0.div_ceil(self.base2k().0) as usize + } + fn lwe_layout(&self) -> LWECiphertextLayout { + LWECiphertextLayout { + n: self.n(), + k: self.k(), + base2k: self.base2k(), + } + } +} + +#[derive(PartialEq, Eq, Copy, Clone, Debug)] +pub struct LWECiphertextLayout { + pub n: Degree, + pub k: TorusPrecision, + pub base2k: Base2K, +} + +impl LWEInfos for LWECiphertextLayout { + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } + + fn n(&self) -> Degree { + self.n + } +} + #[derive(PartialEq, Eq, Clone)] pub struct LWECiphertext { pub(crate) data: Zn, - pub(crate) k: usize, - pub(crate) basek: usize, + pub(crate) k: TorusPrecision, + pub(crate) base2k: Base2K, +} + +impl LWEInfos for LWECiphertext { + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } + fn n(&self) -> Degree { + Degree(self.data.n() as u32 - 1) + } + + fn size(&self) -> usize { + self.data.size() + } } impl LWECiphertext { @@ -26,7 +86,7 @@ impl LWECiphertext { impl fmt::Debug for LWECiphertext { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self) + write!(f, "{self}") } } @@ -34,22 +94,14 @@ impl fmt::Display for LWECiphertext { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "LWECiphertext: basek={} k={}: {}", - self.basek(), - self.k(), + "LWECiphertext: base2k={} k={}: {}", + self.base2k().0, + self.k().0, self.data ) } } -impl Reset for LWECiphertext { - fn reset(&mut self) { - self.data.reset(); - self.basek = 0; - self.k = 0; - } -} - impl FillUniform for LWECiphertext where Zn: FillUniform, @@ -60,45 +112,106 @@ where } impl LWECiphertext> 
{ - pub fn alloc(n: usize, basek: usize, k: usize) -> Self { + pub fn alloc(infos: &A) -> Self + where + A: LWEInfos, + { + Self::alloc_with(infos.n(), infos.base2k(), infos.k()) + } + + pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision) -> Self { Self { - data: Zn::alloc(n + 1, 1, k.div_ceil(basek)), + data: Zn::alloc((n + 1).into(), 1, k.0.div_ceil(base2k.0) as usize), k, - basek, + base2k, + } + } + + pub fn alloc_bytes(infos: &A) -> usize + where + A: LWEInfos, + { + Self::alloc_bytes_with(infos.n(), infos.base2k(), infos.k()) + } + + pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision) -> usize { + Zn::alloc_bytes((n + 1).into(), 1, k.0.div_ceil(base2k.0) as usize) + } +} + +impl LWECiphertextBuilder> { + #[inline] + pub fn layout(mut self, layout: A) -> Self + where + A: LWEInfos, + { + self.data = Some(Zn::alloc((layout.n() + 1).into(), 1, layout.size())); + self.base2k = Some(layout.base2k()); + self.k = Some(layout.k()); + self + } +} + +pub struct LWECiphertextBuilder { + data: Option>, + base2k: Option, + k: Option, +} + +impl LWECiphertext { + #[inline] + pub fn builder() -> LWECiphertextBuilder { + LWECiphertextBuilder { + data: None, + base2k: None, + k: None, } } } -impl Infos for LWECiphertext -where - Zn: ZnxInfos, -{ - type Inner = Zn; - - fn n(&self) -> usize { - &self.inner().n() - 1 +impl LWECiphertextBuilder { + #[inline] + pub fn data(mut self, data: Zn) -> Self { + self.data = Some(data); + self + } + #[inline] + pub fn base2k(mut self, base2k: Base2K) -> Self { + self.base2k = Some(base2k); + self + } + #[inline] + pub fn k(mut self, k: TorusPrecision) -> Self { + self.k = Some(k); + self } - fn inner(&self) -> &Self::Inner { - &self.data - } + pub fn build(self) -> Result, BuildError> { + let data: Zn = self.data.ok_or(BuildError::MissingData)?; + let base2k: Base2K = self.base2k.ok_or(BuildError::MissingBase2K)?; + let k: TorusPrecision = self.k.ok_or(BuildError::MissingK)?; - fn basek(&self) -> usize { - self.basek - } + if base2k.0 == 0 { + return Err(BuildError::ZeroBase2K); + } - fn k(&self) -> usize { - self.k - } -} + if k.0 == 0 { + return Err(BuildError::ZeroTorusPrecision); + } -impl SetMetaData for LWECiphertext { - fn set_k(&mut self, k: usize) { - self.k = k - } + if data.n() == 0 { + return Err(BuildError::ZeroDegree); + } - fn set_basek(&mut self, basek: usize) { - self.basek = basek + if data.cols() == 0 { + return Err(BuildError::ZeroCols); + } + + if data.size() == 0 { + return Err(BuildError::ZeroLimbs); + } + + Ok(LWECiphertext { data, base2k, k }) } } @@ -108,11 +221,12 @@ pub trait LWECiphertextToRef { impl LWECiphertextToRef for LWECiphertext { fn to_ref(&self) -> LWECiphertext<&[u8]> { - LWECiphertext { - data: self.data.to_ref(), - basek: self.basek, - k: self.k, - } + LWECiphertext::builder() + .base2k(self.base2k()) + .k(self.k()) + .data(self.data.to_ref()) + .build() + .unwrap() } } @@ -123,30 +237,27 @@ pub trait LWECiphertextToMut { impl LWECiphertextToMut for LWECiphertext { fn to_mut(&mut self) -> LWECiphertext<&mut [u8]> { - LWECiphertext { - data: self.data.to_mut(), - basek: self.basek, - k: self.k, - } + LWECiphertext::builder() + .base2k(self.base2k()) + .k(self.k()) + .data(self.data.to_mut()) + .build() + .unwrap() } } -use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; - -use crate::layouts::{Infos, SetMetaData}; - impl ReaderFrom for LWECiphertext { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { - self.k = reader.read_u64::()? 
as usize; - self.basek = reader.read_u64::()? as usize; + self.k = TorusPrecision(reader.read_u32::()?); + self.base2k = Base2K(reader.read_u32::()?); self.data.read_from(reader) } } impl WriterTo for LWECiphertext { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { - writer.write_u64::(self.k as u64)?; - writer.write_u64::(self.basek as u64)?; + writer.write_u32::(self.k.into())?; + writer.write_u32::(self.base2k.into())?; self.data.write_to(writer) } } diff --git a/poulpy-core/src/layouts/lwe_ksk.rs b/poulpy-core/src/layouts/lwe_ksk.rs index 632b43f..abf36cb 100644 --- a/poulpy-core/src/layouts/lwe_ksk.rs +++ b/poulpy-core/src/layouts/lwe_ksk.rs @@ -1,24 +1,170 @@ use std::fmt; use poulpy_hal::{ - layouts::{Data, DataMut, DataRef, FillUniform, MatZnx, ReaderFrom, Reset, WriterTo}, + layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, WriterTo}, source::Source, }; -use crate::layouts::{GGLWESwitchingKey, Infos}; +use crate::layouts::{ + Base2K, Degree, Digits, GGLWELayoutInfos, GGLWESwitchingKey, GLWEInfos, LWEInfos, Rank, Rows, TorusPrecision, +}; + +#[derive(PartialEq, Eq, Copy, Clone, Debug)] +pub struct LWESwitchingKeyLayout { + pub n: Degree, + pub base2k: Base2K, + pub k: TorusPrecision, + pub rows: Rows, +} + +impl LWEInfos for LWESwitchingKeyLayout { + fn n(&self) -> Degree { + self.n + } + + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } +} + +impl GLWEInfos for LWESwitchingKeyLayout { + fn rank(&self) -> Rank { + self.rank_out() + } +} + +impl GGLWELayoutInfos for LWESwitchingKeyLayout { + fn rank_in(&self) -> Rank { + Rank(1) + } + + fn digits(&self) -> Digits { + Digits(1) + } + + fn rank_out(&self) -> Rank { + Rank(1) + } + + fn rows(&self) -> Rows { + self.rows + } +} #[derive(PartialEq, Eq, Clone)] pub struct LWESwitchingKey(pub(crate) GGLWESwitchingKey); +impl LWEInfos for LWESwitchingKey { + fn base2k(&self) -> Base2K { + self.0.base2k() + } + + fn k(&self) -> TorusPrecision { + self.0.k() + } + + fn n(&self) -> Degree { + self.0.n() + } + + fn size(&self) -> usize { + self.0.size() + } +} + +impl GLWEInfos for LWESwitchingKey { + fn rank(&self) -> Rank { + self.rank_out() + } +} + +impl GGLWELayoutInfos for LWESwitchingKey { + fn digits(&self) -> Digits { + self.0.digits() + } + + fn rank_in(&self) -> Rank { + self.0.rank_in() + } + + fn rank_out(&self) -> Rank { + self.0.rank_out() + } + + fn rows(&self) -> Rows { + self.0.rows() + } +} + impl LWESwitchingKey> { - pub fn alloc(n: usize, basek: usize, k: usize, rows: usize) -> Self { - Self(GGLWESwitchingKey::alloc(n, basek, k, rows, 1, 1, 1)) + pub fn alloc(infos: &A) -> Self + where + A: GGLWELayoutInfos, + { + debug_assert_eq!( + infos.digits().0, + 1, + "digits > 1 is not supported for LWESwitchingKey" + ); + debug_assert_eq!( + infos.rank_in().0, + 1, + "rank_in > 1 is not supported for LWESwitchingKey" + ); + debug_assert_eq!( + infos.rank_out().0, + 1, + "rank_out > 1 is not supported for LWESwitchingKey" + ); + Self(GGLWESwitchingKey::alloc(infos)) + } + + pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rows: Rows) -> Self { + Self(GGLWESwitchingKey::alloc_with( + n, + base2k, + k, + rows, + Digits(1), + Rank(1), + Rank(1), + )) + } + + pub fn alloc_bytes(infos: &A) -> usize + where + A: GGLWELayoutInfos, + { + debug_assert_eq!( + infos.digits().0, + 1, + "digits > 1 is not supported for LWESwitchingKey" + ); + debug_assert_eq!( + infos.rank_in().0, + 1, + "rank_in > 1 is not supported for LWESwitchingKey" + ); + debug_assert_eq!( + 
infos.rank_out().0, + 1, + "rank_out > 1 is not supported for LWESwitchingKey" + ); + GGLWESwitchingKey::alloc_bytes(infos) + } + + pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rows: Rows) -> usize { + GGLWESwitchingKey::alloc_bytes_with(n, base2k, k, rows, Digits(1), Rank(1), Rank(1)) } } impl fmt::Debug for LWESwitchingKey { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self) + write!(f, "{self}") } } @@ -28,52 +174,12 @@ impl FillUniform for LWESwitchingKey { } } -impl Reset for LWESwitchingKey { - fn reset(&mut self) { - self.0.reset(); - } -} - impl fmt::Display for LWESwitchingKey { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "(LWESwitchingKey) {}", self.0) } } -impl Infos for LWESwitchingKey { - type Inner = MatZnx; - - fn inner(&self) -> &Self::Inner { - self.0.inner() - } - - fn basek(&self) -> usize { - self.0.basek() - } - - fn k(&self) -> usize { - self.0.k() - } -} - -impl LWESwitchingKey { - pub fn digits(&self) -> usize { - self.0.digits() - } - - pub fn rank(&self) -> usize { - self.0.rank() - } - - pub fn rank_in(&self) -> usize { - self.0.rank_in() - } - - pub fn rank_out(&self) -> usize { - self.0.rank_out() - } -} - impl ReaderFrom for LWESwitchingKey { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { self.0.read_from(reader) diff --git a/poulpy-core/src/layouts/lwe_pt.rs b/poulpy-core/src/layouts/lwe_pt.rs index f7a5cba..e739722 100644 --- a/poulpy-core/src/layouts/lwe_pt.rs +++ b/poulpy-core/src/layouts/lwe_pt.rs @@ -1,21 +1,70 @@ use std::fmt; -use poulpy_hal::layouts::{Data, DataMut, DataRef, Zn, ZnToMut, ZnToRef}; +use poulpy_hal::layouts::{Data, DataMut, DataRef, Zn, ZnToMut, ZnToRef, ZnxInfos}; -use crate::layouts::{Infos, SetMetaData}; +use crate::layouts::{Base2K, Degree, LWEInfos, TorusPrecision}; + +#[derive(PartialEq, Eq, Copy, Clone, Debug)] +pub struct LWEPlaintextLayout { + k: TorusPrecision, + base2k: Base2K, +} + +impl LWEInfos for LWEPlaintextLayout { + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } + + fn n(&self) -> Degree { + Degree(0) + } + + fn size(&self) -> usize { + self.k.0.div_ceil(self.base2k.0) as usize + } +} pub struct LWEPlaintext { pub(crate) data: Zn, - pub(crate) k: usize, - pub(crate) basek: usize, + pub(crate) k: TorusPrecision, + pub(crate) base2k: Base2K, +} + +impl LWEInfos for LWEPlaintext { + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } + + fn n(&self) -> Degree { + Degree(self.data.n() as u32 - 1) + } + + fn size(&self) -> usize { + self.data.size() + } } impl LWEPlaintext> { - pub fn alloc(basek: usize, k: usize) -> Self { + pub fn alloc(infos: &A) -> Self + where + A: LWEInfos, + { + Self::alloc_with(infos.base2k(), infos.k()) + } + + pub fn alloc_with(base2k: Base2K, k: TorusPrecision) -> Self { Self { - data: Zn::alloc(1, 1, k.div_ceil(basek)), + data: Zn::alloc(1, 1, k.0.div_ceil(base2k.0) as usize), k, - basek, + base2k, } } } @@ -24,40 +73,14 @@ impl fmt::Display for LWEPlaintext { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "LWEPlaintext: basek={} k={}: {}", - self.basek(), - self.k(), + "LWEPlaintext: base2k={} k={}: {}", + self.base2k().0, + self.k().0, self.data ) } } -impl Infos for LWEPlaintext { - type Inner = Zn; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn basek(&self) -> usize { - self.basek - } - - fn k(&self) -> usize { - self.k - } -} - -impl SetMetaData for LWEPlaintext { 
- fn set_k(&mut self, k: usize) { - self.k = k - } - - fn set_basek(&mut self, basek: usize) { - self.basek = basek - } -} - pub trait LWEPlaintextToRef { #[allow(dead_code)] fn to_ref(&self) -> LWEPlaintext<&[u8]>; @@ -67,7 +90,7 @@ impl LWEPlaintextToRef for LWEPlaintext { fn to_ref(&self) -> LWEPlaintext<&[u8]> { LWEPlaintext { data: self.data.to_ref(), - basek: self.basek, + base2k: self.base2k, k: self.k, } } @@ -82,7 +105,7 @@ impl LWEPlaintextToMut for LWEPlaintext { fn to_mut(&mut self) -> LWEPlaintext<&mut [u8]> { LWEPlaintext { data: self.data.to_mut(), - basek: self.basek, + base2k: self.base2k, k: self.k, } } diff --git a/poulpy-core/src/layouts/lwe_sk.rs b/poulpy-core/src/layouts/lwe_sk.rs index 4502a22..a5b7d4e 100644 --- a/poulpy-core/src/layouts/lwe_sk.rs +++ b/poulpy-core/src/layouts/lwe_sk.rs @@ -3,7 +3,10 @@ use poulpy_hal::{ source::Source, }; -use crate::dist::Distribution; +use crate::{ + dist::Distribution, + layouts::{Base2K, Degree, LWEInfos, TorusPrecision}, +}; pub struct LWESecret { pub(crate) data: ScalarZnx, @@ -11,9 +14,9 @@ pub struct LWESecret { } impl LWESecret> { - pub fn alloc(n: usize) -> Self { + pub fn alloc(n: Degree) -> Self { Self { - data: ScalarZnx::alloc(n, 1), + data: ScalarZnx::alloc(n.into(), 1), dist: Distribution::NONE, } } @@ -33,17 +36,20 @@ impl LWESecret { } } -impl LWESecret { - pub fn n(&self) -> usize { - self.data.n() +impl LWEInfos for LWESecret { + fn base2k(&self) -> Base2K { + Base2K(0) + } + fn k(&self) -> TorusPrecision { + TorusPrecision(0) } - pub fn log_n(&self) -> usize { - self.data.log_n() + fn n(&self) -> Degree { + Degree(self.data.n() as u32) } - pub fn rank(&self) -> usize { - self.data.cols() + fn size(&self) -> usize { + 1 } } diff --git a/poulpy-core/src/layouts/lwe_to_glwe_ksk.rs b/poulpy-core/src/layouts/lwe_to_glwe_ksk.rs index af27cda..92bac13 100644 --- a/poulpy-core/src/layouts/lwe_to_glwe_ksk.rs +++ b/poulpy-core/src/layouts/lwe_to_glwe_ksk.rs @@ -1,18 +1,108 @@ use std::fmt; use poulpy_hal::{ - layouts::{Data, DataMut, DataRef, FillUniform, MatZnx, ReaderFrom, Reset, WriterTo}, + layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, WriterTo}, source::Source, }; -use crate::layouts::{GGLWESwitchingKey, Infos}; +use crate::layouts::{ + Base2K, Degree, Digits, GGLWELayoutInfos, GGLWESwitchingKey, GLWEInfos, LWEInfos, Rank, Rows, TorusPrecision, +}; + +#[derive(PartialEq, Eq, Copy, Clone, Debug)] +pub struct LWEToGLWESwitchingKeyLayout { + pub n: Degree, + pub base2k: Base2K, + pub k: TorusPrecision, + pub rows: Rows, + pub rank_out: Rank, +} + +impl LWEInfos for LWEToGLWESwitchingKeyLayout { + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } + + fn n(&self) -> Degree { + self.n + } +} + +impl GLWEInfos for LWEToGLWESwitchingKeyLayout { + fn rank(&self) -> Rank { + self.rank_out() + } +} + +impl GGLWELayoutInfos for LWEToGLWESwitchingKeyLayout { + fn rank_in(&self) -> Rank { + Rank(1) + } + + fn digits(&self) -> Digits { + Digits(1) + } + + fn rank_out(&self) -> Rank { + self.rank_out + } + + fn rows(&self) -> Rows { + self.rows + } +} #[derive(PartialEq, Eq, Clone)] pub struct LWEToGLWESwitchingKey(pub(crate) GGLWESwitchingKey); +impl LWEInfos for LWEToGLWESwitchingKey { + fn base2k(&self) -> Base2K { + self.0.base2k() + } + + fn k(&self) -> TorusPrecision { + self.0.k() + } + + fn n(&self) -> Degree { + self.0.n() + } + + fn size(&self) -> usize { + self.0.size() + } +} + +impl GLWEInfos for LWEToGLWESwitchingKey { + fn rank(&self) -> Rank { + 
self.rank_out() + } +} +impl GGLWELayoutInfos for LWEToGLWESwitchingKey { + fn digits(&self) -> Digits { + self.0.digits() + } + + fn rank_in(&self) -> Rank { + self.0.rank_in() + } + + fn rank_out(&self) -> Rank { + self.0.rank_out() + } + + fn rows(&self) -> Rows { + self.0.rows() + } +} + impl fmt::Debug for LWEToGLWESwitchingKey { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self) + write!(f, "{self}") } } @@ -22,52 +112,12 @@ impl FillUniform for LWEToGLWESwitchingKey { } } -impl Reset for LWEToGLWESwitchingKey { - fn reset(&mut self) { - self.0.reset(); - } -} - impl fmt::Display for LWEToGLWESwitchingKey { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "(LWEToGLWESwitchingKey) {}", self.0) } } -impl Infos for LWEToGLWESwitchingKey { - type Inner = MatZnx; - - fn inner(&self) -> &Self::Inner { - self.0.inner() - } - - fn basek(&self) -> usize { - self.0.basek() - } - - fn k(&self) -> usize { - self.0.k() - } -} - -impl LWEToGLWESwitchingKey { - pub fn digits(&self) -> usize { - self.0.digits() - } - - pub fn rank(&self) -> usize { - self.0.rank() - } - - pub fn rank_in(&self) -> usize { - self.0.rank_in() - } - - pub fn rank_out(&self) -> usize { - self.0.rank_out() - } -} - impl ReaderFrom for LWEToGLWESwitchingKey { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { self.0.read_from(reader) @@ -81,7 +131,53 @@ impl WriterTo for LWEToGLWESwitchingKey { } impl LWEToGLWESwitchingKey> { - pub fn alloc(n: usize, basek: usize, k: usize, rows: usize, rank_out: usize) -> Self { - Self(GGLWESwitchingKey::alloc(n, basek, k, rows, 1, 1, rank_out)) + pub fn alloc(infos: &A) -> Self + where + A: GGLWELayoutInfos, + { + debug_assert_eq!( + infos.rank_in().0, + 1, + "rank_in > 1 is not supported for LWEToGLWESwitchingKey" + ); + debug_assert_eq!( + infos.digits().0, + 1, + "digits > 1 is not supported for LWEToGLWESwitchingKey" + ); + Self(GGLWESwitchingKey::alloc(infos)) + } + + pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rows: Rows, rank_out: Rank) -> Self { + Self(GGLWESwitchingKey::alloc_with( + n, + base2k, + k, + rows, + Digits(1), + Rank(1), + rank_out, + )) + } + + pub fn alloc_bytes(infos: &A) -> usize + where + A: GGLWELayoutInfos, + { + debug_assert_eq!( + infos.rank_in().0, + 1, + "rank_in > 1 is not supported for LWEToGLWESwitchingKey" + ); + debug_assert_eq!( + infos.digits().0, + 1, + "digits > 1 is not supported for LWEToGLWESwitchingKey" + ); + GGLWESwitchingKey::alloc_bytes(infos) + } + + pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rows: Rows, rank_out: Rank) -> usize { + GGLWESwitchingKey::alloc_bytes_with(n, base2k, k, rows, Digits(1), Rank(1), rank_out) } } diff --git a/poulpy-core/src/layouts/mod.rs b/poulpy-core/src/layouts/mod.rs index 81b6936..a286201 100644 --- a/poulpy-core/src/layouts/mod.rs +++ b/poulpy-core/src/layouts/mod.rs @@ -8,7 +8,6 @@ mod glwe_pk; mod glwe_pt; mod glwe_sk; mod glwe_to_lwe_ksk; -mod infos; mod lwe_ct; mod lwe_ksk; mod lwe_pt; @@ -28,9 +27,195 @@ pub use glwe_pk::*; pub use glwe_pt::*; pub use glwe_sk::*; pub use glwe_to_lwe_ksk::*; -pub use infos::*; pub use lwe_ct::*; pub use lwe_ksk::*; pub use lwe_pt::*; pub use lwe_sk::*; pub use lwe_to_glwe_ksk::*; + +#[derive(Debug)] +pub enum BuildError { + MissingData, + MissingBase2K, + MissingK, + MissingDigits, + ZeroDegree, + NonPowerOfTwoDegree, + ZeroBase2K, + ZeroTorusPrecision, + ZeroCols, + ZeroLimbs, + ZeroRank, + ZeroDigits, +} + +/// Newtype over `u32` with arithmetic and comparisons 
against same type and `u32`.
+/// Arithmetic is **saturating** (add/sub/mul) to avoid debug-overflow panics.
+macro_rules! newtype_u32 {
+    ($name:ident) => {
+        #[repr(transparent)]
+        #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
+        pub struct $name(pub u32);
+
+        // ----- Conversions -----
+        impl From<$name> for u32 {
+            #[inline]
+            fn from(v: $name) -> u32 {
+                v.0
+            }
+        }
+        impl From<$name> for usize {
+            #[inline]
+            fn from(v: $name) -> usize {
+                v.0 as usize
+            }
+        }
+
+        impl From<u32> for $name {
+            #[inline]
+            fn from(v: u32) -> $name {
+                $name(v)
+            }
+        }
+        impl From<usize> for $name {
+            #[inline]
+            fn from(v: usize) -> $name {
+                $name(v as u32)
+            }
+        }
+
+        // ----- Display -----
+        impl ::core::fmt::Display for $name {
+            #[inline]
+            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
+                write!(f, "{}", self.0)
+            }
+        }
+
+        // ===== Arithmetic (same type) =====
+        impl ::core::ops::Add for $name {
+            type Output = $name;
+            #[inline]
+            fn add(self, rhs: $name) -> $name {
+                $name(self.0.saturating_add(rhs.0))
+            }
+        }
+        impl ::core::ops::Sub for $name {
+            type Output = $name;
+            #[inline]
+            fn sub(self, rhs: $name) -> $name {
+                $name(self.0.saturating_sub(rhs.0))
+            }
+        }
+        impl ::core::ops::Mul for $name {
+            type Output = $name;
+            #[inline]
+            fn mul(self, rhs: $name) -> $name {
+                $name(self.0.saturating_mul(rhs.0))
+            }
+        }
+
+        // ===== Arithmetic (with u32) =====
+        impl ::core::ops::Add<u32> for $name {
+            type Output = $name;
+            #[inline]
+            fn add(self, rhs: u32) -> $name {
+                $name(self.0.saturating_add(rhs))
+            }
+        }
+        impl ::core::ops::Sub<u32> for $name {
+            type Output = $name;
+            #[inline]
+            fn sub(self, rhs: u32) -> $name {
+                $name(self.0.saturating_sub(rhs))
+            }
+        }
+        impl ::core::ops::Mul<u32> for $name {
+            type Output = $name;
+            #[inline]
+            fn mul(self, rhs: u32) -> $name {
+                $name(self.0.saturating_mul(rhs))
+            }
+        }
+
+        impl $name {
+            #[inline]
+            pub const fn as_u32(self) -> u32 {
+                self.0
+            }
+            #[inline]
+            pub const fn as_usize(self) -> usize {
+                self.0 as usize
+            }
+
+            #[inline]
+            pub fn div_ceil<T: Into<u32>>(self, rhs: T) -> u32 {
+                self.0.div_ceil(rhs.into())
+            }
+        }
+
+        // Optional symmetric forms: u32 (+|-|*) $name -> $name
+        impl ::core::ops::Add<$name> for u32 {
+            type Output = $name;
+            #[inline]
+            fn add(self, rhs: $name) -> $name {
+                $name(self.saturating_add(rhs.0))
+            }
+        }
+        impl ::core::ops::Sub<$name> for u32 {
+            type Output = $name;
+            #[inline]
+            fn sub(self, rhs: $name) -> $name {
+                $name(self.saturating_sub(rhs.0))
+            }
+        }
+        impl ::core::ops::Mul<$name> for u32 {
+            type Output = $name;
+            #[inline]
+            fn mul(self, rhs: $name) -> $name {
+                $name(self.saturating_mul(rhs.0))
+            }
+        }
+
+        // ===== Cross-type comparisons with u32 (both directions) =====
+        impl ::core::cmp::PartialEq<u32> for $name {
+            #[inline]
+            fn eq(&self, other: &u32) -> bool {
+                self.0 == *other
+            }
+        }
+        impl ::core::cmp::PartialEq<$name> for u32 {
+            #[inline]
+            fn eq(&self, other: &$name) -> bool {
+                *self == other.0
+            }
+        }
+
+        impl ::core::cmp::PartialOrd<u32> for $name {
+            #[inline]
+            fn partial_cmp(&self, other: &u32) -> Option<::core::cmp::Ordering> {
+                self.0.partial_cmp(other)
+            }
+        }
+        impl ::core::cmp::PartialOrd<$name> for u32 {
+            #[inline]
+            fn partial_cmp(&self, other: &$name) -> Option<::core::cmp::Ordering> {
+                self.partial_cmp(&other.0)
+            }
+        }
+    };
+}
+
+newtype_u32!(Degree);
+newtype_u32!(TorusPrecision);
+newtype_u32!(Base2K);
+newtype_u32!(Rows);
+newtype_u32!(Rank);
+newtype_u32!(Digits);
+
+impl Degree {
+    pub fn log2(&self) -> usize {
+        let n: usize = self.0 as usize;
+        (usize::BITS - (n - 
1).leading_zeros()) as _ + } +} diff --git a/poulpy-core/src/layouts/prepared/gglwe_atk.rs b/poulpy-core/src/layouts/prepared/gglwe_atk.rs index 2470075..2a0fb3c 100644 --- a/poulpy-core/src/layouts/prepared/gglwe_atk.rs +++ b/poulpy-core/src/layouts/prepared/gglwe_atk.rs @@ -1,10 +1,10 @@ use poulpy_hal::{ api::{VmpPMatAlloc, VmpPMatAllocBytes, VmpPrepare}, - layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VmpPMat}, + layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}, }; use crate::layouts::{ - GGLWEAutomorphismKey, Infos, + Base2K, Degree, Digits, GGLWEAutomorphismKey, GGLWELayoutInfos, GLWEInfos, LWEInfos, Rank, Rows, TorusPrecision, prepared::{GGLWESwitchingKeyPrepared, Prepare, PrepareAlloc}, }; @@ -14,61 +14,107 @@ pub struct GGLWEAutomorphismKeyPrepared { pub(crate) p: i64, } -impl GGLWEAutomorphismKeyPrepared, B> { - pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self - where - Module: VmpPMatAlloc, - { - GGLWEAutomorphismKeyPrepared::, B> { - key: GGLWESwitchingKeyPrepared::alloc(module, basek, k, rows, digits, rank, rank), - p: 0, - } - } - - pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize - where - Module: VmpPMatAllocBytes, - { - GGLWESwitchingKeyPrepared::bytes_of(module, basek, k, rows, digits, rank, rank) - } -} - -impl Infos for GGLWEAutomorphismKeyPrepared { - type Inner = VmpPMat; - - fn inner(&self) -> &Self::Inner { - self.key.inner() - } - - fn basek(&self) -> usize { - self.key.basek() - } - - fn k(&self) -> usize { - self.key.k() - } -} - impl GGLWEAutomorphismKeyPrepared { pub fn p(&self) -> i64 { self.p } +} - pub fn digits(&self) -> usize { - self.key.digits() +impl LWEInfos for GGLWEAutomorphismKeyPrepared { + fn n(&self) -> Degree { + self.key.n() } - pub fn rank(&self) -> usize { - self.key.rank() + fn base2k(&self) -> Base2K { + self.key.base2k() } - pub fn rank_in(&self) -> usize { + fn k(&self) -> TorusPrecision { + self.key.k() + } + + fn size(&self) -> usize { + self.key.size() + } +} + +impl GLWEInfos for GGLWEAutomorphismKeyPrepared { + fn rank(&self) -> Rank { + self.rank_out() + } +} + +impl GGLWELayoutInfos for GGLWEAutomorphismKeyPrepared { + fn rank_in(&self) -> Rank { self.key.rank_in() } - pub fn rank_out(&self) -> usize { + fn rank_out(&self) -> Rank { self.key.rank_out() } + + fn digits(&self) -> Digits { + self.key.digits() + } + + fn rows(&self) -> Rows { + self.key.rows() + } +} + +impl GGLWEAutomorphismKeyPrepared, B> { + pub fn alloc(module: &Module, infos: &A) -> Self + where + A: GGLWELayoutInfos, + Module: VmpPMatAlloc, + { + assert_eq!( + infos.rank_in(), + infos.rank_out(), + "rank_in != rank_out is not supported for GGLWEAutomorphismKeyPrepared" + ); + GGLWEAutomorphismKeyPrepared::, B> { + key: GGLWESwitchingKeyPrepared::alloc(module, infos), + p: 0, + } + } + + pub fn alloc_with(module: &Module, base2k: Base2K, k: TorusPrecision, rows: Rows, digits: Digits, rank: Rank) -> Self + where + Module: VmpPMatAlloc, + { + GGLWEAutomorphismKeyPrepared { + key: GGLWESwitchingKeyPrepared::alloc_with(module, base2k, k, rows, digits, rank, rank), + p: 0, + } + } + + pub fn alloc_bytes(module: &Module, infos: &A) -> usize + where + A: GGLWELayoutInfos, + Module: VmpPMatAllocBytes, + { + assert_eq!( + infos.rank_in(), + infos.rank_out(), + "rank_in != rank_out is not supported for GGLWEAutomorphismKeyPrepared" + ); + GGLWESwitchingKeyPrepared::alloc_bytes(module, infos) + } + + pub fn alloc_bytes_with( + module: 
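Taken together, the newtype wrappers introduced in layouts/mod.rs above behave as follows. A standalone sketch, not part of the diff; the import path and the concrete values are illustrative assumptions:

    use poulpy_core::layouts::{Base2K, Degree, TorusPrecision};

    fn main() {
        let k = TorusPrecision(44);
        let base2k = Base2K(13);

        // Cross-type comparisons with `u32` work in both directions.
        assert!(k > 13_u32);
        assert!(13_u32 < k);

        // Add/sub/mul saturate instead of panicking on debug overflow.
        assert_eq!(TorusPrecision(u32::MAX) + 1, u32::MAX);

        // Limb count of a layout: ceil(k / base2k). `div_ceil` accepts any
        // `Into<u32>` argument, so another newtype can be passed directly.
        assert_eq!(k.div_ceil(base2k), 4);

        // Ceil-log2 of the (power-of-two) ring degree.
        assert_eq!(Degree(1024).log2(), 10);
    }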
&Module, + base2k: Base2K, + k: TorusPrecision, + rows: Rows, + digits: Digits, + rank: Rank, + ) -> usize + where + Module: VmpPMatAllocBytes, + { + GGLWESwitchingKeyPrepared::alloc_bytes_with(module, base2k, k, rows, digits, rank, rank) + } } impl Prepare> for GGLWEAutomorphismKeyPrepared @@ -86,14 +132,7 @@ where Module: VmpPMatAlloc + VmpPrepare, { fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> GGLWEAutomorphismKeyPrepared, B> { - let mut atk_prepared: GGLWEAutomorphismKeyPrepared, B> = GGLWEAutomorphismKeyPrepared::alloc( - module, - self.basek(), - self.k(), - self.rows(), - self.digits(), - self.rank(), - ); + let mut atk_prepared: GGLWEAutomorphismKeyPrepared, B> = GGLWEAutomorphismKeyPrepared::alloc(module, self); atk_prepared.prepare(module, self, scratch); atk_prepared } diff --git a/poulpy-core/src/layouts/prepared/gglwe_ct.rs b/poulpy-core/src/layouts/prepared/gglwe_ct.rs index 51fe2c0..70e183d 100644 --- a/poulpy-core/src/layouts/prepared/gglwe_ct.rs +++ b/poulpy-core/src/layouts/prepared/gglwe_ct.rs @@ -1,115 +1,262 @@ use poulpy_hal::{ api::{VmpPMatAlloc, VmpPMatAllocBytes, VmpPrepare}, - layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VmpPMat}, + layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VmpPMat, ZnxInfos}, + oep::VmpPMatAllocBytesImpl, }; use crate::layouts::{ - GGLWECiphertext, Infos, + Base2K, BuildError, Degree, Digits, GGLWECiphertext, GGLWELayoutInfos, GLWEInfos, LWEInfos, Rank, Rows, TorusPrecision, prepared::{Prepare, PrepareAlloc}, }; #[derive(PartialEq, Eq)] pub struct GGLWECiphertextPrepared { pub(crate) data: VmpPMat, - pub(crate) basek: usize, - pub(crate) k: usize, - pub(crate) digits: usize, + pub(crate) k: TorusPrecision, + pub(crate) base2k: Base2K, + pub(crate) digits: Digits, +} + +impl LWEInfos for GGLWECiphertextPrepared { + fn n(&self) -> Degree { + Degree(self.data.n() as u32) + } + + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } + + fn size(&self) -> usize { + self.data.size() + } +} + +impl GLWEInfos for GGLWECiphertextPrepared { + fn rank(&self) -> Rank { + self.rank_out() + } +} + +impl GGLWELayoutInfos for GGLWECiphertextPrepared { + fn rank_in(&self) -> Rank { + Rank(self.data.cols_in() as u32) + } + + fn rank_out(&self) -> Rank { + Rank(self.data.cols_out() as u32 - 1) + } + + fn digits(&self) -> Digits { + self.digits + } + + fn rows(&self) -> Rows { + Rows(self.data.rows() as u32) + } +} + +pub struct GGLWECiphertextPreparedBuilder { + data: Option>, + base2k: Option, + k: Option, + digits: Option, +} + +impl GGLWECiphertextPrepared { + #[inline] + pub fn builder() -> GGLWECiphertextPreparedBuilder { + GGLWECiphertextPreparedBuilder { + data: None, + base2k: None, + k: None, + digits: None, + } + } +} + +impl GGLWECiphertextPreparedBuilder, B> { + #[inline] + pub fn layout(mut self, infos: &A) -> Self + where + A: GGLWELayoutInfos, + B: VmpPMatAllocBytesImpl, + { + self.data = Some(VmpPMat::alloc( + infos.n().into(), + infos.rows().into(), + infos.rank_in().into(), + (infos.rank_out() + 1).into(), + infos.size(), + )); + self.base2k = Some(infos.base2k()); + self.k = Some(infos.k()); + self.digits = Some(infos.digits()); + self + } +} + +impl GGLWECiphertextPreparedBuilder { + #[inline] + pub fn data(mut self, data: VmpPMat) -> Self { + self.data = Some(data); + self + } + #[inline] + pub fn base2k(mut self, base2k: Base2K) -> Self { + self.base2k = Some(base2k); + self + } + #[inline] + pub fn k(mut self, k: TorusPrecision) -> Self { + self.k 
= Some(k);
+        self
+    }
+
+    #[inline]
+    pub fn digits(mut self, digits: Digits) -> Self {
+        self.digits = Some(digits);
+        self
+    }
+
+    pub fn build(self) -> Result<GGLWECiphertextPrepared<D, B>, BuildError> {
+        let data: VmpPMat<D, B> = self.data.ok_or(BuildError::MissingData)?;
+        let base2k: Base2K = self.base2k.ok_or(BuildError::MissingBase2K)?;
+        let k: TorusPrecision = self.k.ok_or(BuildError::MissingK)?;
+        let digits: Digits = self.digits.ok_or(BuildError::MissingDigits)?;
+
+        if base2k == 0_u32 {
+            return Err(BuildError::ZeroBase2K);
+        }
+
+        if digits == 0_u32 {
+            return Err(BuildError::ZeroDigits);
+        }
+
+        if k == 0_u32 {
+            return Err(BuildError::ZeroTorusPrecision);
+        }
+
+        if data.n() == 0 {
+            return Err(BuildError::ZeroDegree);
+        }
+
+        if data.cols() == 0 {
+            return Err(BuildError::ZeroCols);
+        }
+
+        if data.size() == 0 {
+            return Err(BuildError::ZeroLimbs);
+        }
+
+        Ok(GGLWECiphertextPrepared {
+            data,
+            base2k,
+            k,
+            digits,
+        })
+    }
+}

impl<B: Backend> GGLWECiphertextPrepared<Vec<u8>, B> {
-    #[allow(clippy::too_many_arguments)]
-    pub fn alloc(module: &Module<B>, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize, rank_out: usize) -> Self
+    pub fn alloc<A>(module: &Module<B>, infos: &A) -> Self
+    where
+        A: GGLWELayoutInfos,
+        Module<B>: VmpPMatAlloc<B>,
+    {
+        debug_assert_eq!(module.n(), infos.n().0 as usize, "module.n() != infos.n()");
+        Self::alloc_with(
+            module,
+            infos.base2k(),
+            infos.k(),
+            infos.rows(),
+            infos.digits(),
+            infos.rank_in(),
+            infos.rank_out(),
+        )
+    }
+
+    pub fn alloc_with(
+        module: &Module<B>,
+        base2k: Base2K,
+        k: TorusPrecision,
+        rows: Rows,
+        digits: Digits,
+        rank_in: Rank,
+        rank_out: Rank,
+    ) -> Self
    where
        Module<B>: VmpPMatAlloc<B>,
    {
-        let size: usize = k.div_ceil(basek);
+        let size: usize = k.0.div_ceil(base2k.0) as usize;
        debug_assert!(
-            size > digits,
-            "invalid gglwe: ceil(k/basek): {} <= digits: {}",
-            size,
-            digits
+            size as u32 > digits.0,
+            "invalid gglwe: ceil(k/base2k): {size} <= digits: {}",
+            digits.0
        );
        assert!(
-            rows * digits <= size,
-            "invalid gglwe: rows: {} * digits:{} > ceil(k/basek): {}",
-            rows,
-            digits,
-            size
+            rows.0 * digits.0 <= size as u32,
+            "invalid gglwe: rows: {} * digits:{} > ceil(k/base2k): {size}",
+            rows.0,
+            digits.0,
        );
        Self {
-            data: module.vmp_pmat_alloc(rows, rank_in, rank_out + 1, size),
-            basek,
+            data: module.vmp_pmat_alloc(rows.into(), rank_in.into(), (rank_out + 1).into(), size),
            k,
+            base2k,
            digits,
        }
    }

-    #[allow(clippy::too_many_arguments)]
-    pub fn bytes_of(
+    pub fn alloc_bytes<A>(module: &Module<B>, infos: &A) -> usize
+    where
+        A: GGLWELayoutInfos,
+        Module<B>: VmpPMatAllocBytes<B>,
+    {
+        debug_assert_eq!(module.n(), infos.n().0 as usize, "module.n() != infos.n()");
+        Self::alloc_bytes_with(
+            module,
+            infos.base2k(),
+            infos.k(),
+            infos.rows(),
+            infos.digits(),
+            infos.rank_in(),
+            infos.rank_out(),
+        )
+    }
+
+    pub fn alloc_bytes_with(
        module: &Module<B>,
-        basek: usize,
-        k: usize,
-        rows: usize,
-        digits: usize,
-        rank_in: usize,
-        rank_out: usize,
+        base2k: Base2K,
+        k: TorusPrecision,
+        rows: Rows,
+        digits: Digits,
+        rank_in: Rank,
+        rank_out: Rank,
    ) -> usize
    where
        Module<B>: VmpPMatAllocBytes<B>,
    {
-        let size: usize = k.div_ceil(basek);
+        let size: usize = k.0.div_ceil(base2k.0) as usize;
        debug_assert!(
-            size > digits,
-            "invalid gglwe: ceil(k/basek): {} <= digits: {}",
-            size,
-            digits
+            size as u32 > digits.0,
+            "invalid gglwe: ceil(k/base2k): {size} <= digits: {}",
+            digits.0
        );
        assert!(
-            rows * digits <= size,
-            "invalid gglwe: rows: {} * digits:{} > ceil(k/basek): {}",
-            rows,
-            digits,
-            size
+            rows.0 * digits.0 <= size as u32,
+            "invalid gglwe: rows: 
{} * digits:{} > ceil(k/base2k): {size}", + rows.0, + digits.0, ); - module.vmp_pmat_alloc_bytes(rows, rank_in, rank_out + 1, rows) - } -} - -impl Infos for GGLWECiphertextPrepared { - type Inner = VmpPMat; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn basek(&self) -> usize { - self.basek - } - - fn k(&self) -> usize { - self.k - } -} - -impl GGLWECiphertextPrepared { - pub fn rank(&self) -> usize { - self.data.cols_out() - 1 - } - - pub fn digits(&self) -> usize { - self.digits - } - - pub fn rank_in(&self) -> usize { - self.data.cols_in() - } - - pub fn rank_out(&self) -> usize { - self.data.cols_out() - 1 + module.vmp_pmat_alloc_bytes(rows.into(), rank_in.into(), (rank_out + 1).into(), size) } } @@ -119,8 +266,8 @@ where { fn prepare(&mut self, module: &Module, other: &GGLWECiphertext, scratch: &mut Scratch) { module.vmp_prepare(&mut self.data, &other.data, scratch); - self.basek = other.basek; self.k = other.k; + self.base2k = other.base2k; self.digits = other.digits; } } @@ -130,15 +277,7 @@ where Module: VmpPMatAlloc + VmpPrepare, { fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> GGLWECiphertextPrepared, B> { - let mut atk_prepared: GGLWECiphertextPrepared, B> = GGLWECiphertextPrepared::alloc( - module, - self.basek(), - self.k(), - self.rows(), - self.digits(), - self.rank_in(), - self.rank_out(), - ); + let mut atk_prepared: GGLWECiphertextPrepared, B> = GGLWECiphertextPrepared::alloc(module, self); atk_prepared.prepare(module, self, scratch); atk_prepared } diff --git a/poulpy-core/src/layouts/prepared/gglwe_ksk.rs b/poulpy-core/src/layouts/prepared/gglwe_ksk.rs index f5174d0..66198a6 100644 --- a/poulpy-core/src/layouts/prepared/gglwe_ksk.rs +++ b/poulpy-core/src/layouts/prepared/gglwe_ksk.rs @@ -1,10 +1,10 @@ use poulpy_hal::{ api::{VmpPMatAlloc, VmpPMatAllocBytes, VmpPrepare}, - layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VmpPMat}, + layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}, }; use crate::layouts::{ - GGLWESwitchingKey, Infos, + Base2K, Degree, Digits, GGLWELayoutInfos, GGLWESwitchingKey, GLWEInfos, LWEInfos, Rank, Rows, TorusPrecision, prepared::{GGLWECiphertextPrepared, Prepare, PrepareAlloc}, }; @@ -15,75 +15,103 @@ pub struct GGLWESwitchingKeyPrepared { pub(crate) sk_out_n: usize, // Degree of sk_out } +impl LWEInfos for GGLWESwitchingKeyPrepared { + fn n(&self) -> Degree { + self.key.n() + } + + fn base2k(&self) -> Base2K { + self.key.base2k() + } + + fn k(&self) -> TorusPrecision { + self.key.k() + } + + fn size(&self) -> usize { + self.key.size() + } +} + +impl GLWEInfos for GGLWESwitchingKeyPrepared { + fn rank(&self) -> Rank { + self.rank_out() + } +} + +impl GGLWELayoutInfos for GGLWESwitchingKeyPrepared { + fn rank_in(&self) -> Rank { + self.key.rank_in() + } + + fn rank_out(&self) -> Rank { + self.key.rank_out() + } + + fn digits(&self) -> Digits { + self.key.digits() + } + + fn rows(&self) -> Rows { + self.key.rows() + } +} + impl GGLWESwitchingKeyPrepared, B> { - #[allow(clippy::too_many_arguments)] - pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize, rank_out: usize) -> Self + pub fn alloc(module: &Module, infos: &A) -> Self where + A: GGLWELayoutInfos, Module: VmpPMatAlloc, { + debug_assert_eq!(module.n() as u32, infos.n(), "module.n() != infos.n()"); GGLWESwitchingKeyPrepared::, B> { - key: GGLWECiphertextPrepared::alloc(module, basek, k, rows, digits, rank_in, rank_out), + key: GGLWECiphertextPrepared::alloc(module, infos), sk_in_n: 0, sk_out_n: 
0, } } - #[allow(clippy::too_many_arguments)] - pub fn bytes_of( + pub fn alloc_with( module: &Module, - basek: usize, - k: usize, - rows: usize, - digits: usize, - rank_in: usize, - rank_out: usize, + base2k: Base2K, + k: TorusPrecision, + rows: Rows, + digits: Digits, + rank_in: Rank, + rank_out: Rank, + ) -> Self + where + Module: VmpPMatAlloc, + { + GGLWESwitchingKeyPrepared::, B> { + key: GGLWECiphertextPrepared::alloc_with(module, base2k, k, rows, digits, rank_in, rank_out), + sk_in_n: 0, + sk_out_n: 0, + } + } + + pub fn alloc_bytes(module: &Module, infos: &A) -> usize + where + A: GGLWELayoutInfos, + Module: VmpPMatAllocBytes, + { + debug_assert_eq!(module.n() as u32, infos.n(), "module.n() != infos.n()"); + GGLWECiphertextPrepared::alloc_bytes(module, infos) + } + + pub fn alloc_bytes_with( + module: &Module, + base2k: Base2K, + k: TorusPrecision, + rows: Rows, + digits: Digits, + rank_in: Rank, + rank_out: Rank, ) -> usize where Module: VmpPMatAllocBytes, { - GGLWECiphertextPrepared::bytes_of(module, basek, k, rows, digits, rank_in, rank_out) - } -} - -impl Infos for GGLWESwitchingKeyPrepared { - type Inner = VmpPMat; - - fn inner(&self) -> &Self::Inner { - self.key.inner() - } - - fn basek(&self) -> usize { - self.key.basek() - } - - fn k(&self) -> usize { - self.key.k() - } -} - -impl GGLWESwitchingKeyPrepared { - pub fn rank(&self) -> usize { - self.key.data.cols_out() - 1 - } - - pub fn rank_in(&self) -> usize { - self.key.data.cols_in() - } - - pub fn rank_out(&self) -> usize { - self.key.data.cols_out() - 1 - } - - pub fn digits(&self) -> usize { - self.key.digits() - } - - pub fn sk_degree_in(&self) -> usize { - self.sk_in_n - } - - pub fn sk_degree_out(&self) -> usize { - self.sk_out_n + GGLWECiphertextPrepared::alloc_bytes_with(module, base2k, k, rows, digits, rank_in, rank_out) } } @@ -103,15 +131,7 @@ where Module: VmpPMatAlloc + VmpPrepare, { fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> GGLWESwitchingKeyPrepared, B> { - let mut atk_prepared: GGLWESwitchingKeyPrepared, B> = GGLWESwitchingKeyPrepared::alloc( - module, - self.basek(), - self.k(), - self.rows(), - self.digits(), - self.rank_in(), - self.rank_out(), - ); + let mut atk_prepared: GGLWESwitchingKeyPrepared, B> = GGLWESwitchingKeyPrepared::alloc(module, self); atk_prepared.prepare(module, self, scratch); atk_prepared } diff --git a/poulpy-core/src/layouts/prepared/gglwe_tsk.rs b/poulpy-core/src/layouts/prepared/gglwe_tsk.rs index 0e00702..1d376ba 100644 --- a/poulpy-core/src/layouts/prepared/gglwe_tsk.rs +++ b/poulpy-core/src/layouts/prepared/gglwe_tsk.rs @@ -1,10 +1,10 @@ use poulpy_hal::{ api::{VmpPMatAlloc, VmpPMatAllocBytes, VmpPrepare}, - layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VmpPMat}, + layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}, }; use crate::layouts::{ - GGLWETensorKey, Infos, + Base2K, Degree, Digits, GGLWELayoutInfos, GGLWETensorKey, GLWEInfos, LWEInfos, Rank, Rows, TorusPrecision, prepared::{GGLWESwitchingKeyPrepared, Prepare, PrepareAlloc}, }; @@ -13,61 +13,126 @@ pub struct GGLWETensorKeyPrepared { pub(crate) keys: Vec>, } +impl LWEInfos for GGLWETensorKeyPrepared { + fn n(&self) -> Degree { + self.keys[0].n() + } + + fn base2k(&self) -> Base2K { + self.keys[0].base2k() + } + + fn k(&self) -> TorusPrecision { + self.keys[0].k() + } + + fn size(&self) -> usize { + self.keys[0].size() + } +} + +impl GLWEInfos for GGLWETensorKeyPrepared { + fn rank(&self) -> Rank { + self.rank_out() + } +} + +impl GGLWELayoutInfos for GGLWETensorKeyPrepared 
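With allocation now driven by the *Infos traits, a plain parameter struct can describe a layout once and be reused across every alloc/alloc_bytes entry point. A minimal sketch, to be read alongside the trait impls in this diff; the struct itself is illustrative, and size() is assumed to be a required method equal to ceil(k/base2k):

    use poulpy_core::layouts::{
        Base2K, Degree, Digits, GGLWELayoutInfos, GLWEInfos, LWEInfos, Rank, Rows, TorusPrecision,
    };

    pub struct GGLWEKeyLayout {
        pub n: Degree,
        pub base2k: Base2K,
        pub k: TorusPrecision,
        pub rows: Rows,
        pub digits: Digits,
        pub rank_in: Rank,
        pub rank_out: Rank,
    }

    impl LWEInfos for GGLWEKeyLayout {
        fn n(&self) -> Degree { self.n }
        fn base2k(&self) -> Base2K { self.base2k }
        fn k(&self) -> TorusPrecision { self.k }
        fn size(&self) -> usize { self.k.0.div_ceil(self.base2k.0) as usize }
    }

    impl GLWEInfos for GGLWEKeyLayout {
        fn rank(&self) -> Rank { self.rank_out }
    }

    impl GGLWELayoutInfos for GGLWEKeyLayout {
        fn rank_in(&self) -> Rank { self.rank_in }
        fn rank_out(&self) -> Rank { self.rank_out }
        fn digits(&self) -> Digits { self.digits }
        fn rows(&self) -> Rows { self.rows }
    }

Such a value is then passed as the infos argument of, e.g., GGLWESwitchingKeyPrepared::alloc(&module, &layout), replacing the former positional-usize calls.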
{ + fn rank_in(&self) -> Rank { + self.rank_out() + } + + fn rank_out(&self) -> Rank { + self.keys[0].rank_out() + } + + fn digits(&self) -> Digits { + self.keys[0].digits() + } + + fn rows(&self) -> Rows { + self.keys[0].rows() + } +} + impl GGLWETensorKeyPrepared, B> { - pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self + pub fn alloc(module: &Module, infos: &A) -> Self + where + A: GGLWELayoutInfos, + Module: VmpPMatAlloc, + { + assert_eq!( + infos.rank_in(), + infos.rank_out(), + "rank_in != rank_out is not supported for GGLWETensorKeyPrepared" + ); + Self::alloc_with( + module, + infos.base2k(), + infos.k(), + infos.rows(), + infos.digits(), + infos.rank_out(), + ) + } + + pub fn alloc_with(module: &Module, base2k: Base2K, k: TorusPrecision, rows: Rows, digits: Digits, rank: Rank) -> Self where Module: VmpPMatAlloc, { let mut keys: Vec, B>> = Vec::new(); - let pairs: usize = (((rank + 1) * rank) >> 1).max(1); + let pairs: u32 = (((rank.0 + 1) * rank.0) >> 1).max(1); (0..pairs).for_each(|_| { - keys.push(GGLWESwitchingKeyPrepared::alloc( - module, basek, k, rows, digits, 1, rank, + keys.push(GGLWESwitchingKeyPrepared::alloc_with( + module, + base2k, + k, + rows, + digits, + Rank(1), + rank, )); }); Self { keys } } - pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize + pub fn alloc_bytes(module: &Module, infos: &A) -> usize + where + A: GGLWELayoutInfos, + Module: VmpPMatAllocBytes, + { + assert_eq!( + infos.rank_in(), + infos.rank_out(), + "rank_in != rank_out is not supported for GGLWETensorKey" + ); + let rank_out: usize = infos.rank_out().into(); + let pairs: usize = (((rank_out + 1) * rank_out) >> 1).max(1); + pairs + * GGLWESwitchingKeyPrepared::alloc_bytes_with( + module, + infos.base2k(), + infos.k(), + infos.rows(), + infos.digits(), + Rank(1), + infos.rank_out(), + ) + } + + pub fn alloc_bytes_with( + module: &Module, + base2k: Base2K, + k: TorusPrecision, + rows: Rows, + digits: Digits, + rank: Rank, + ) -> usize where Module: VmpPMatAllocBytes, { - let pairs: usize = (((rank + 1) * rank) >> 1).max(1); - pairs * GGLWESwitchingKeyPrepared::bytes_of(module, basek, k, rows, digits, 1, rank) - } -} - -impl Infos for GGLWETensorKeyPrepared { - type Inner = VmpPMat; - - fn inner(&self) -> &Self::Inner { - self.keys[0].inner() - } - - fn basek(&self) -> usize { - self.keys[0].basek() - } - - fn k(&self) -> usize { - self.keys[0].k() - } -} - -impl GGLWETensorKeyPrepared { - pub fn rank(&self) -> usize { - self.keys[0].rank() - } - - pub fn rank_in(&self) -> usize { - self.keys[0].rank_in() - } - - pub fn rank_out(&self) -> usize { - self.keys[0].rank_out() - } - - pub fn digits(&self) -> usize { - self.keys[0].digits() + let pairs: usize = (((rank.0 + 1) * rank.0) >> 1).max(1) as usize; + pairs * GGLWESwitchingKeyPrepared::alloc_bytes_with(module, base2k, k, rows, digits, Rank(1), rank) } } @@ -77,7 +142,7 @@ impl GGLWETensorKeyPrepared { if i > j { std::mem::swap(&mut i, &mut j); }; - let rank: usize = self.rank(); + let rank: usize = self.rank_out().into(); &mut self.keys[i * rank + j - (i * (i + 1) / 2)] } } @@ -88,7 +153,7 @@ impl GGLWETensorKeyPrepared { if i > j { std::mem::swap(&mut i, &mut j); }; - let rank: usize = self.rank(); + let rank: usize = self.rank_out().into(); &self.keys[i * rank + j - (i * (i + 1) / 2)] } } @@ -116,14 +181,7 @@ where Module: VmpPMatAlloc + VmpPrepare, { fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> 
GGLWETensorKeyPrepared, B> { - let mut tsk_prepared: GGLWETensorKeyPrepared, B> = GGLWETensorKeyPrepared::alloc( - module, - self.basek(), - self.k(), - self.rows(), - self.digits(), - self.rank(), - ); + let mut tsk_prepared: GGLWETensorKeyPrepared, B> = GGLWETensorKeyPrepared::alloc(module, self); tsk_prepared.prepare(module, self, scratch); tsk_prepared } diff --git a/poulpy-core/src/layouts/prepared/ggsw_ct.rs b/poulpy-core/src/layouts/prepared/ggsw_ct.rs index 09f06da..1199ec2 100644 --- a/poulpy-core/src/layouts/prepared/ggsw_ct.rs +++ b/poulpy-core/src/layouts/prepared/ggsw_ct.rs @@ -1,99 +1,261 @@ use poulpy_hal::{ api::{VmpPMatAlloc, VmpPMatAllocBytes, VmpPrepare}, - layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VmpPMat}, + layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VmpPMat, ZnxInfos}, + oep::VmpPMatAllocBytesImpl, }; use crate::layouts::{ - GGSWCiphertext, Infos, + Base2K, BuildError, Degree, Digits, GGSWCiphertext, GGSWInfos, GLWEInfos, LWEInfos, Rank, Rows, TorusPrecision, prepared::{Prepare, PrepareAlloc}, }; #[derive(PartialEq, Eq)] pub struct GGSWCiphertextPrepared { pub(crate) data: VmpPMat, - pub(crate) basek: usize, - pub(crate) k: usize, - pub(crate) digits: usize, + pub(crate) k: TorusPrecision, + pub(crate) base2k: Base2K, + pub(crate) digits: Digits, } -impl GGSWCiphertextPrepared, B> { - pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self - where - Module: VmpPMatAlloc, - { - let size: usize = k.div_ceil(basek); - debug_assert!(digits > 0, "invalid ggsw: `digits` == 0"); +impl LWEInfos for GGSWCiphertextPrepared { + fn n(&self) -> Degree { + Degree(self.data.n() as u32) + } + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } + + fn size(&self) -> usize { + self.data.size() + } +} + +impl GLWEInfos for GGSWCiphertextPrepared { + fn rank(&self) -> Rank { + Rank(self.data.cols_out() as u32 - 1) + } +} + +impl GGSWInfos for GGSWCiphertextPrepared { + fn digits(&self) -> Digits { + self.digits + } + + fn rows(&self) -> Rows { + Rows(self.data.rows() as u32) + } +} + +pub struct GGSWCiphertextPreparedBuilder { + data: Option>, + base2k: Option, + k: Option, + digits: Option, +} + +impl GGSWCiphertextPrepared { + #[inline] + pub fn builder() -> GGSWCiphertextPreparedBuilder { + GGSWCiphertextPreparedBuilder { + data: None, + base2k: None, + k: None, + digits: None, + } + } +} + +impl GGSWCiphertextPreparedBuilder, B> { + #[inline] + pub fn layout(mut self, infos: &A) -> Self + where + A: GGSWInfos, + B: VmpPMatAllocBytesImpl, + { debug_assert!( - size > digits, - "invalid ggsw: ceil(k/basek): {} <= digits: {}", - size, - digits + infos.size() as u32 > infos.digits().0, + "invalid ggsw: ceil(k/base2k): {} <= digits: {}", + infos.size(), + infos.digits() ); assert!( - rows * digits <= size, - "invalid ggsw: rows: {} * digits:{} > ceil(k/basek): {}", - rows, + infos.rows().0 * infos.digits().0 <= infos.size() as u32, + "invalid ggsw: rows: {} * digits:{} > ceil(k/base2k): {}", + infos.rows(), + infos.digits(), + infos.size(), + ); + + self.data = Some(VmpPMat::alloc( + infos.n().into(), + infos.rows().into(), + (infos.rank() + 1).into(), + (infos.rank() + 1).into(), + infos.size(), + )); + self.base2k = Some(infos.base2k()); + self.k = Some(infos.k()); + self.digits = Some(infos.digits()); + self + } +} + +impl GGSWCiphertextPreparedBuilder { + #[inline] + pub fn data(mut self, data: VmpPMat) -> Self { + self.data = Some(data); + self + } + #[inline] 
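The explicit-setter path of the GGSW builder introduced just below composes as follows. A sketch; import paths are assumptions, and a VmpPMat with a shape consistent with base2k/k/digits is assumed to have been allocated elsewhere:

    use poulpy_core::layouts::{Base2K, BuildError, Digits, TorusPrecision};
    use poulpy_core::layouts::prepared::GGSWCiphertextPrepared;
    use poulpy_hal::layouts::{Backend, VmpPMat};

    fn build_ggsw<B: Backend>(
        data: VmpPMat<Vec<u8>, B>,
    ) -> Result<GGSWCiphertextPrepared<Vec<u8>, B>, BuildError> {
        GGSWCiphertextPrepared::builder()
            .data(data)
            .base2k(Base2K(13))
            .k(TorusPrecision(44))
            .digits(Digits(1))
            .build() // rejects missing fields and zero-sized dimensions
    }

The layout(&infos) shortcut fills the same fields from a layout description in a single call.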
+    pub fn base2k(mut self, base2k: Base2K) -> Self {
+        self.base2k = Some(base2k);
+        self
+    }
+    #[inline]
+    pub fn k(mut self, k: TorusPrecision) -> Self {
+        self.k = Some(k);
+        self
+    }
+
+    #[inline]
+    pub fn digits(mut self, digits: Digits) -> Self {
+        self.digits = Some(digits);
+        self
+    }
+
+    pub fn build(self) -> Result<GGSWCiphertextPrepared<D, B>, BuildError> {
+        let data: VmpPMat<D, B> = self.data.ok_or(BuildError::MissingData)?;
+        let base2k: Base2K = self.base2k.ok_or(BuildError::MissingBase2K)?;
+        let k: TorusPrecision = self.k.ok_or(BuildError::MissingK)?;
+        let digits: Digits = self.digits.ok_or(BuildError::MissingDigits)?;
+
+        if base2k == 0_u32 {
+            return Err(BuildError::ZeroBase2K);
+        }
+
+        if digits == 0_u32 {
+            return Err(BuildError::ZeroDigits);
+        }
+
+        if k == 0_u32 {
+            return Err(BuildError::ZeroTorusPrecision);
+        }
+
+        if data.n() == 0 {
+            return Err(BuildError::ZeroDegree);
+        }
+
+        if data.cols() == 0 {
+            return Err(BuildError::ZeroCols);
+        }
+
+        if data.size() == 0 {
+            return Err(BuildError::ZeroLimbs);
+        }
+
+        Ok(GGSWCiphertextPrepared {
+            data,
+            base2k,
+            k,
            digits,
-            size
+        })
+    }
+}
+
+impl<B: Backend> GGSWCiphertextPrepared<Vec<u8>, B> {
+    pub fn alloc<A>(module: &Module<B>, infos: &A) -> Self
+    where
+        A: GGSWInfos,
+        Module<B>: VmpPMatAlloc<B>,
+    {
+        Self::alloc_with(
+            module,
+            infos.base2k(),
+            infos.k(),
+            infos.rows(),
+            infos.digits(),
+            infos.rank(),
+        )
+    }
+
+    pub fn alloc_with(module: &Module<B>, base2k: Base2K, k: TorusPrecision, rows: Rows, digits: Digits, rank: Rank) -> Self
+    where
+        Module<B>: VmpPMatAlloc<B>,
+    {
+        let size: usize = k.0.div_ceil(base2k.0) as usize;
+        debug_assert!(
+            size as u32 > digits.0,
+            "invalid ggsw: ceil(k/base2k): {size} <= digits: {}",
+            digits.0
+        );
+
+        assert!(
+            rows.0 * digits.0 <= size as u32,
+            "invalid ggsw: rows: {} * digits:{} > ceil(k/base2k): {size}",
+            rows.0,
+            digits.0,
        );
        Self {
-            data: module.vmp_pmat_alloc(rows, rank + 1, rank + 1, k.div_ceil(basek)),
-            basek,
+            data: module.vmp_pmat_alloc(
+                rows.into(),
+                (rank + 1).into(),
+                (rank + 1).into(),
+                k.0.div_ceil(base2k.0) as usize,
+            ),
            k,
+            base2k,
            digits,
        }
    }

-    pub fn bytes_of(module: &Module<B>, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize
+    pub fn alloc_bytes<A>(module: &Module<B>, infos: &A) -> usize
+    where
+        A: GGSWInfos,
+        Module<B>: VmpPMatAllocBytes<B>,
+    {
+        Self::alloc_bytes_with(
+            module,
+            infos.base2k(),
+            infos.k(),
+            infos.rows(),
+            infos.digits(),
+            infos.rank(),
+        )
+    }
+
+    pub fn alloc_bytes_with(
+        module: &Module<B>,
+        base2k: Base2K,
+        k: TorusPrecision,
+        rows: Rows,
+        digits: Digits,
+        rank: Rank,
+    ) -> usize
    where
        Module<B>: VmpPMatAllocBytes<B>,
    {
-        let size: usize = k.div_ceil(basek);
+        let size: usize = k.0.div_ceil(base2k.0) as usize;
        debug_assert!(
-            size > digits,
-            "invalid ggsw: ceil(k/basek): {} <= digits: {}",
-            size,
-            digits
+            size as u32 > digits.0,
+            "invalid ggsw: ceil(k/base2k): {size} <= digits: {}",
+            digits.0
        );
        assert!(
-            rows * digits <= size,
-            "invalid ggsw: rows: {} * digits:{} > ceil(k/basek): {}",
-            rows,
-            digits,
-            size
+            rows.0 * digits.0 <= size as u32,
+            "invalid ggsw: rows: {} * digits:{} > ceil(k/base2k): {size}",
+            rows.0,
+            digits.0,
        );
-        module.vmp_pmat_alloc_bytes(rows, rank + 1, rank + 1, size)
-    }
-}
-
-impl<D: Data, B: Backend> Infos for GGSWCiphertextPrepared<D, B> {
-    type Inner = VmpPMat<D, B>;
-
-    fn inner(&self) -> &Self::Inner {
-        &self.data
-    }
-
-    fn basek(&self) -> usize {
-        self.basek
-    }
-
-    fn k(&self) -> usize {
-        self.k
-    }
-}
-
-impl<D: DataRef, B: Backend> GGSWCiphertextPrepared<D, B> {
-    pub fn rank(&self) -> usize {
-        self.data.cols_out() - 1
-    }
-
-    pub fn digits(&self) -> usize {
-        self.digits
+        
module.vmp_pmat_alloc_bytes(rows.into(), (rank + 1).into(), (rank + 1).into(), size) } } @@ -110,7 +272,7 @@ where fn prepare(&mut self, module: &Module, other: &GGSWCiphertext, scratch: &mut Scratch) { module.vmp_prepare(&mut self.data, &other.data, scratch); self.k = other.k; - self.basek = other.basek; + self.base2k = other.base2k; self.digits = other.digits; } } @@ -120,14 +282,7 @@ where Module: VmpPMatAlloc + VmpPrepare, { fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> GGSWCiphertextPrepared, B> { - let mut ggsw_prepared: GGSWCiphertextPrepared, B> = GGSWCiphertextPrepared::alloc( - module, - self.basek(), - self.k(), - self.rows(), - self.digits(), - self.rank(), - ); + let mut ggsw_prepared: GGSWCiphertextPrepared, B> = GGSWCiphertextPrepared::alloc(module, self); ggsw_prepared.prepare(module, self, scratch); ggsw_prepared } diff --git a/poulpy-core/src/layouts/prepared/glwe_ct.rs b/poulpy-core/src/layouts/prepared/glwe_ct.rs deleted file mode 100644 index 89ca3ef..0000000 --- a/poulpy-core/src/layouts/prepared/glwe_ct.rs +++ /dev/null @@ -1,177 +0,0 @@ -use poulpy_hal::{ - api::{FillUniform, Reset, VecZnxCopy, VecZnxFillUniform}, - layouts::{Backend, Data, DataMut, DataRef, Module, ReaderFrom, VecZnx, WriterTo}, -}; - - -use crate::layouts::{GLWECiphertext, Infos, compressed::Decompress}; -use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; -use std::fmt; - -#[derive(PartialEq, Eq, Clone)] -pub struct GLWECiphertextCompressed { - pub(crate) data: VecZnx, - pub(crate) basek: usize, - pub(crate) k: usize, - pub(crate) rank: usize, - pub(crate) seed: [u8; 32], -} - -impl fmt::Debug for GLWECiphertextCompressed { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self) - } -} - -impl fmt::Display for GLWECiphertextCompressed { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "GLWECiphertextCompressed: basek={} k={} rank={} seed={:?}: {}", - self.basek(), - self.k(), - self.rank, - self.seed, - self.data - ) - } -} - -impl Reset for GLWECiphertextCompressed { - fn reset(&mut self) { - self.data.reset(); - self.basek = 0; - self.k = 0; - self.rank = 0; - self.seed = [0u8; 32]; - } -} - -impl FillUniform for GLWECiphertextCompressed { - fn fill_uniform(&mut self, source: &mut Source) { - self.data.fill_uniform(source); - } -} - -impl Infos for GLWECiphertextCompressed { - type Inner = VecZnx; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn basek(&self) -> usize { - self.basek - } - - fn k(&self) -> usize { - self.k - } -} - -impl GLWECiphertextCompressed { - pub fn rank(&self) -> usize { - self.rank - } -} - -impl GLWECiphertextCompressed> { - pub fn alloc(n: usize, basek: usize, k: usize, rank: usize) -> Self { - Self { - data: VecZnx::alloc(n, 1, k.div_ceil(basek)), - basek, - k, - rank, - seed: [0u8; 32], - } - } - - pub fn bytes_of(n: usize, basek: usize, k: usize) -> usize { - GLWECiphertext::bytes_of(n, basek, k, 1) - } -} - -impl ReaderFrom for GLWECiphertextCompressed { - fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { - self.k = reader.read_u64::()? as usize; - self.basek = reader.read_u64::()? as usize; - self.rank = reader.read_u64::()? 
as usize; - reader.read_exact(&mut self.seed)?; - self.data.read_from(reader) - } -} - -impl WriterTo for GLWECiphertextCompressed { - fn write_to(&self, writer: &mut W) -> std::io::Result<()> { - writer.write_u64::(self.k as u64)?; - writer.write_u64::(self.basek as u64)?; - writer.write_u64::(self.rank as u64)?; - writer.write_all(&self.seed)?; - self.data.write_to(writer) - } -} - -impl Decompress> for GLWECiphertext { - fn decompress(&mut self, module: &Module, other: &GLWECiphertextCompressed) - where - Module: VecZnxCopy + VecZnxFillUniform, - { - #[cfg(debug_assertions)] - { - use poulpy_hal::api::ZnxInfos; - - assert_eq!( - self.n(), - other.data.n(), - "invalid receiver: self.n()={} != other.n()={}", - self.n(), - other.data.n() - ); - assert_eq!( - self.size(), - other.size(), - "invalid receiver: self.size()={} != other.size()={}", - self.size(), - other.size() - ); - assert_eq!( - self.rank(), - other.rank(), - "invalid receiver: self.rank()={} != other.rank()={}", - self.rank(), - other.rank() - ); - } - - let mut source: Source = Source::new(other.seed); - self.decompress_internal(module, other, &mut source); - } -} - -impl GLWECiphertext { - pub(crate) fn decompress_internal( - &mut self, - module: &Module, - other: &GLWECiphertextCompressed, - source: &mut Source, - ) where - DataOther: DataRef, - Module: VecZnxCopy + VecZnxFillUniform, - { - #[cfg(debug_assertions)] - { - assert_eq!(self.rank(), other.rank()) - } - - let k: usize = other.k; - let basek: usize = other.basek; - let cols: usize = other.rank() + 1; - module.vec_znx_copy(&mut self.data, 0, &other.data, 0); - (1..cols).for_each(|i| { - module.vec_znx_fill_uniform(basek, &mut self.data, i, k, source); - }); - - self.basek = basek; - self.k = k; - } -} diff --git a/poulpy-core/src/layouts/prepared/glwe_pk.rs b/poulpy-core/src/layouts/prepared/glwe_pk.rs index 9b976c7..1ccf982 100644 --- a/poulpy-core/src/layouts/prepared/glwe_pk.rs +++ b/poulpy-core/src/layouts/prepared/glwe_pk.rs @@ -1,12 +1,13 @@ use poulpy_hal::{ api::{VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply}, - layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VecZnxDft}, + layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VecZnxDft, ZnxInfos}, + oep::VecZnxDftAllocBytesImpl, }; use crate::{ dist::Distribution, layouts::{ - GLWEPublicKey, Infos, + Base2K, BuildError, Degree, GLWEInfos, GLWEPublicKey, LWEInfos, Rank, TorusPrecision, prepared::{Prepare, PrepareAlloc}, }, }; @@ -14,51 +15,157 @@ use crate::{ #[derive(PartialEq, Eq)] pub struct GLWEPublicKeyPrepared { pub(crate) data: VecZnxDft, - pub(crate) basek: usize, - pub(crate) k: usize, + pub(crate) base2k: Base2K, + pub(crate) k: TorusPrecision, pub(crate) dist: Distribution, } -impl Infos for GLWEPublicKeyPrepared { - type Inner = VecZnxDft; - - fn inner(&self) -> &Self::Inner { - &self.data +impl LWEInfos for GLWEPublicKeyPrepared { + fn base2k(&self) -> Base2K { + self.base2k } - fn basek(&self) -> usize { - self.basek - } - - fn k(&self) -> usize { + fn k(&self) -> TorusPrecision { self.k } + + fn size(&self) -> usize { + self.data.size() + } + + fn n(&self) -> Degree { + Degree(self.data.n() as u32) + } +} + +impl GLWEInfos for GLWEPublicKeyPrepared { + fn rank(&self) -> Rank { + Rank(self.data.cols() as u32 - 1) + } +} + +pub struct GLWEPublicKeyPreparedBuilder { + data: Option>, + base2k: Option, + k: Option, } impl GLWEPublicKeyPrepared { - pub fn rank(&self) -> usize { - self.cols() - 1 + #[inline] + pub fn builder() -> GLWEPublicKeyPreparedBuilder { + 
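For consumers, the builder plus the layout(...) shortcut reduce allocation of an uninitialized prepared public key to one chain. A sketch; the VecZnxDftAllocBytesImpl<B> spelling of the backend bound is reconstructed, since the generics are garbled in this extract:

    use poulpy_core::layouts::GLWEInfos;
    use poulpy_core::layouts::prepared::GLWEPublicKeyPrepared;
    use poulpy_hal::{layouts::Backend, oep::VecZnxDftAllocBytesImpl};

    fn alloc_pk_prepared<A, B>(infos: &A) -> GLWEPublicKeyPrepared<Vec<u8>, B>
    where
        A: GLWEInfos,
        B: Backend + VecZnxDftAllocBytesImpl<B>,
    {
        GLWEPublicKeyPrepared::builder()
            .layout(infos)
            .build()
            .expect("layout() fills data, base2k and k, so only a zero dimension can fail")
    }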
GLWEPublicKeyPreparedBuilder { + data: None, + base2k: None, + k: None, + } + } +} + +impl GLWEPublicKeyPreparedBuilder, B> { + #[inline] + pub fn layout(mut self, layout: &A) -> Self + where + A: GLWEInfos, + B: VecZnxDftAllocBytesImpl, + { + self.data = Some(VecZnxDft::alloc( + layout.n().into(), + (layout.rank() + 1).into(), + layout.size(), + )); + self.base2k = Some(layout.base2k()); + self.k = Some(layout.k()); + self + } +} + +impl GLWEPublicKeyPreparedBuilder { + #[inline] + pub fn data(mut self, data: VecZnxDft) -> Self { + self.data = Some(data); + self + } + #[inline] + pub fn base2k(mut self, base2k: Base2K) -> Self { + self.base2k = Some(base2k); + self + } + #[inline] + pub fn k(mut self, k: TorusPrecision) -> Self { + self.k = Some(k); + self + } + + pub fn build(self) -> Result, BuildError> { + let data: VecZnxDft = self.data.ok_or(BuildError::MissingData)?; + let base2k: Base2K = self.base2k.ok_or(BuildError::MissingBase2K)?; + let k: TorusPrecision = self.k.ok_or(BuildError::MissingK)?; + + if base2k == 0_u32 { + return Err(BuildError::ZeroBase2K); + } + + if k == 0_u32 { + return Err(BuildError::ZeroTorusPrecision); + } + + if data.n() == 0 { + return Err(BuildError::ZeroDegree); + } + + if data.cols() == 0 { + return Err(BuildError::ZeroCols); + } + + if data.size() == 0 { + return Err(BuildError::ZeroLimbs); + } + + Ok(GLWEPublicKeyPrepared { + data, + base2k, + k, + dist: Distribution::NONE, + }) } } impl GLWEPublicKeyPrepared, B> { - pub fn alloc(module: &Module, basek: usize, k: usize, rank: usize) -> Self + pub fn alloc(module: &Module, infos: &A) -> Self + where + A: GLWEInfos, + Module: VecZnxDftAlloc, + { + debug_assert_eq!(module.n(), infos.n().0 as usize, "module.n() != infos.n()"); + Self::alloc_with(module, infos.base2k(), infos.k(), infos.rank()) + } + + pub fn alloc_with(module: &Module, base2k: Base2K, k: TorusPrecision, rank: Rank) -> Self where Module: VecZnxDftAlloc, { Self { - data: module.vec_znx_dft_alloc(rank + 1, k.div_ceil(basek)), - basek, + data: module.vec_znx_dft_alloc((rank + 1).into(), k.0.div_ceil(base2k.0) as usize), + base2k, k, dist: Distribution::NONE, } } - pub fn bytes_of(module: &Module, basek: usize, k: usize, rank: usize) -> usize + pub fn alloc_bytes(module: &Module, infos: &A) -> usize + where + A: GLWEInfos, + Module: VecZnxDftAllocBytes, + { + debug_assert_eq!(module.n(), infos.n().0 as usize, "module.n() != infos.n()"); + Self::alloc_bytes_with(module, infos.base2k(), infos.k(), infos.rank()) + } + + pub fn alloc_bytes_with(module: &Module, base2k: Base2K, k: TorusPrecision, rank: Rank) -> usize where Module: VecZnxDftAllocBytes, { - module.vec_znx_dft_alloc_bytes(rank + 1, k.div_ceil(basek)) + module.vec_znx_dft_alloc_bytes((rank + 1).into(), k.0.div_ceil(base2k.0) as usize) } } @@ -67,8 +174,7 @@ where Module: VecZnxDftAlloc + VecZnxDftApply, { fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> GLWEPublicKeyPrepared, B> { - let mut pk_prepared: GLWEPublicKeyPrepared, B> = - GLWEPublicKeyPrepared::alloc(module, self.basek(), self.k(), self.rank()); + let mut pk_prepared: GLWEPublicKeyPrepared, B> = GLWEPublicKeyPrepared::alloc(module, self); pk_prepared.prepare(module, self, scratch); pk_prepared } @@ -85,11 +191,11 @@ where assert_eq!(self.size(), other.size()); } - (0..self.cols()).for_each(|i| { + (0..(self.rank() + 1).into()).for_each(|i| { module.vec_znx_dft_apply(1, 0, &mut self.data, i, &other.data, i); }); - self.k = other.k; - self.basek = other.basek; + self.k = other.k(); + self.base2k = 
other.base2k(); self.dist = other.dist; } } diff --git a/poulpy-core/src/layouts/prepared/glwe_sk.rs b/poulpy-core/src/layouts/prepared/glwe_sk.rs index c6d96a7..604d617 100644 --- a/poulpy-core/src/layouts/prepared/glwe_sk.rs +++ b/poulpy-core/src/layouts/prepared/glwe_sk.rs @@ -6,7 +6,7 @@ use poulpy_hal::{ use crate::{ dist::Distribution, layouts::{ - GLWESecret, + Base2K, Degree, GLWEInfos, GLWESecret, LWEInfos, Rank, TorusPrecision, prepared::{Prepare, PrepareAlloc}, }, }; @@ -16,36 +16,72 @@ pub struct GLWESecretPrepared { pub(crate) dist: Distribution, } +impl LWEInfos for GLWESecretPrepared { + fn base2k(&self) -> Base2K { + Base2K(0) + } + + fn k(&self) -> TorusPrecision { + TorusPrecision(0) + } + + fn n(&self) -> Degree { + Degree(self.data.n() as u32) + } + + fn size(&self) -> usize { + self.data.size() + } +} +impl GLWEInfos for GLWESecretPrepared { + fn rank(&self) -> Rank { + Rank(self.data.cols() as u32) + } +} impl GLWESecretPrepared, B> { - pub fn alloc(module: &Module, rank: usize) -> Self + pub fn alloc(module: &Module, infos: &A) -> Self + where + A: GLWEInfos, + Module: SvpPPolAlloc, + { + assert_eq!(module.n() as u32, infos.n()); + Self::alloc_with(module, infos.rank()) + } + + pub fn alloc_with(module: &Module, rank: Rank) -> Self where Module: SvpPPolAlloc, { Self { - data: module.svp_ppol_alloc(rank), + data: module.svp_ppol_alloc(rank.into()), dist: Distribution::NONE, } } - pub fn bytes_of(module: &Module, rank: usize) -> usize + pub fn alloc_bytes(module: &Module, infos: &A) -> usize + where + A: GLWEInfos, + Module: SvpPPolAllocBytes, + { + assert_eq!(module.n() as u32, infos.n()); + Self::alloc_bytes_with(module, infos.rank()) + } + + pub fn alloc_bytes_with(module: &Module, rank: Rank) -> usize where Module: SvpPPolAllocBytes, { - module.svp_ppol_alloc_bytes(rank) + module.svp_ppol_alloc_bytes(rank.into()) } } impl GLWESecretPrepared { - pub fn n(&self) -> usize { - self.data.n() + pub fn n(&self) -> Degree { + Degree(self.data.n() as u32) } - pub fn log_n(&self) -> usize { - self.data.log_n() - } - - pub fn rank(&self) -> usize { - self.data.cols() + pub fn rank(&self) -> Rank { + Rank(self.data.cols() as u32) } } @@ -54,7 +90,7 @@ where Module: SvpPrepare + SvpPPolAlloc, { fn prepare_alloc(&self, module: &Module, scratch: &mut poulpy_hal::layouts::Scratch) -> GLWESecretPrepared, B> { - let mut sk_dft: GLWESecretPrepared, B> = GLWESecretPrepared::alloc(module, self.rank()); + let mut sk_dft: GLWESecretPrepared, B> = GLWESecretPrepared::alloc(module, self); sk_dft.prepare(module, self, scratch); sk_dft } @@ -65,7 +101,7 @@ where Module: SvpPrepare, { fn prepare(&mut self, module: &Module, other: &GLWESecret, _scratch: &mut poulpy_hal::layouts::Scratch) { - (0..self.rank()).for_each(|i| { + (0..self.rank().into()).for_each(|i| { module.svp_prepare(&mut self.data, i, &other.data, i); }); self.dist = other.dist diff --git a/poulpy-core/src/layouts/prepared/glwe_to_lwe_ksk.rs b/poulpy-core/src/layouts/prepared/glwe_to_lwe_ksk.rs index 199ffcf..8fa19d6 100644 --- a/poulpy-core/src/layouts/prepared/glwe_to_lwe_ksk.rs +++ b/poulpy-core/src/layouts/prepared/glwe_to_lwe_ksk.rs @@ -1,65 +1,115 @@ use poulpy_hal::{ api::{VmpPMatAlloc, VmpPMatAllocBytes, VmpPrepare}, - layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VmpPMat}, + layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}, }; use crate::layouts::{ - GLWEToLWESwitchingKey, Infos, + Base2K, Degree, Digits, GGLWELayoutInfos, GLWEInfos, GLWEToLWESwitchingKey, LWEInfos, Rank, Rows, TorusPrecision, 
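The prepare_alloc pattern above is the same for every prepared type: allocate from the source object's own layout infos, then run the backend prepare. For a secret key the whole flow is one call. A sketch; module, sk and scratch are assumed to exist, with B the active backend:

    // One-time conversion of a GLWESecret into the backend's SVP domain.
    // The prepared key is what the encryption/decryption routines consume.
    let sk_prepared: GLWESecretPrepared<Vec<u8>, B> = sk.prepare_alloc(&module, scratch);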
prepared::{GGLWESwitchingKeyPrepared, Prepare, PrepareAlloc}, }; #[derive(PartialEq, Eq)] pub struct GLWEToLWESwitchingKeyPrepared(pub(crate) GGLWESwitchingKeyPrepared); -impl Infos for GLWEToLWESwitchingKeyPrepared { - type Inner = VmpPMat; - - fn inner(&self) -> &Self::Inner { - self.0.inner() +impl LWEInfos for GLWEToLWESwitchingKeyPrepared { + fn base2k(&self) -> Base2K { + self.0.base2k() } - fn basek(&self) -> usize { - self.0.basek() - } - - fn k(&self) -> usize { + fn k(&self) -> TorusPrecision { self.0.k() } + + fn n(&self) -> Degree { + self.0.n() + } + + fn size(&self) -> usize { + self.0.size() + } } -impl GLWEToLWESwitchingKeyPrepared { - pub fn digits(&self) -> usize { - self.0.digits() +impl GLWEInfos for GLWEToLWESwitchingKeyPrepared { + fn rank(&self) -> Rank { + self.rank_out() } +} - pub fn rank(&self) -> usize { - self.0.rank() - } - - pub fn rank_in(&self) -> usize { +impl GGLWELayoutInfos for GLWEToLWESwitchingKeyPrepared { + fn rank_in(&self) -> Rank { self.0.rank_in() } - pub fn rank_out(&self) -> usize { + fn digits(&self) -> Digits { + self.0.digits() + } + + fn rank_out(&self) -> Rank { self.0.rank_out() } + + fn rows(&self) -> Rows { + self.0.rows() + } } impl GLWEToLWESwitchingKeyPrepared, B> { - pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, rank_in: usize) -> Self + pub fn alloc(module: &Module, infos: &A) -> Self + where + A: GGLWELayoutInfos, + Module: VmpPMatAlloc, + { + debug_assert_eq!( + infos.rank_out().0, + 1, + "rank_out > 1 is not supported for GLWEToLWESwitchingKeyPrepared" + ); + debug_assert_eq!( + infos.digits().0, + 1, + "digits > 1 is not supported for GLWEToLWESwitchingKeyPrepared" + ); + Self(GGLWESwitchingKeyPrepared::alloc(module, infos)) + } + + pub fn alloc_with(module: &Module, base2k: Base2K, k: TorusPrecision, rows: Rows, rank_in: Rank) -> Self where Module: VmpPMatAlloc, { - Self(GGLWESwitchingKeyPrepared::alloc( - module, basek, k, rows, 1, rank_in, 1, + Self(GGLWESwitchingKeyPrepared::alloc_with( + module, + base2k, + k, + rows, + Digits(1), + rank_in, + Rank(1), )) } - pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize) -> usize + pub fn alloc_bytes(module: &Module, infos: &A) -> usize + where + A: GGLWELayoutInfos, + Module: VmpPMatAllocBytes, + { + debug_assert_eq!( + infos.rank_out().0, + 1, + "rank_out > 1 is not supported for GLWEToLWESwitchingKeyPrepared" + ); + debug_assert_eq!( + infos.digits().0, + 1, + "digits > 1 is not supported for GLWEToLWESwitchingKeyPrepared" + ); + GGLWESwitchingKeyPrepared::alloc_bytes(module, infos) + } + + pub fn alloc_bytes_with(module: &Module, base2k: Base2K, k: TorusPrecision, rows: Rows, rank_in: Rank) -> usize where Module: VmpPMatAllocBytes, { - GGLWESwitchingKeyPrepared::, B>::bytes_of(module, basek, k, rows, digits, rank_in, 1) + GGLWESwitchingKeyPrepared::alloc_bytes_with(module, base2k, k, rows, Digits(1), rank_in, Rank(1)) } } @@ -68,13 +118,7 @@ where Module: VmpPrepare + VmpPMatAlloc, { fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> GLWEToLWESwitchingKeyPrepared, B> { - let mut ksk_prepared: GLWEToLWESwitchingKeyPrepared, B> = GLWEToLWESwitchingKeyPrepared::alloc( - module, - self.0.basek(), - self.0.k(), - self.0.rows(), - self.0.rank_in(), - ); + let mut ksk_prepared: GLWEToLWESwitchingKeyPrepared, B> = GLWEToLWESwitchingKeyPrepared::alloc(module, self); ksk_prepared.prepare(module, self, scratch); ksk_prepared } diff --git a/poulpy-core/src/layouts/prepared/lwe_ksk.rs 
b/poulpy-core/src/layouts/prepared/lwe_ksk.rs index 72e6be7..32c4fc3 100644 --- a/poulpy-core/src/layouts/prepared/lwe_ksk.rs +++ b/poulpy-core/src/layouts/prepared/lwe_ksk.rs @@ -1,65 +1,124 @@ use poulpy_hal::{ api::{VmpPMatAlloc, VmpPMatAllocBytes, VmpPrepare}, - layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VmpPMat}, + layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}, }; use crate::layouts::{ - Infos, LWESwitchingKey, + Base2K, Degree, Digits, GGLWELayoutInfos, GLWEInfos, LWEInfos, LWESwitchingKey, Rank, Rows, TorusPrecision, prepared::{GGLWESwitchingKeyPrepared, Prepare, PrepareAlloc}, }; #[derive(PartialEq, Eq)] pub struct LWESwitchingKeyPrepared(pub(crate) GGLWESwitchingKeyPrepared); -impl Infos for LWESwitchingKeyPrepared { - type Inner = VmpPMat; - - fn inner(&self) -> &Self::Inner { - self.0.inner() +impl LWEInfos for LWESwitchingKeyPrepared { + fn base2k(&self) -> Base2K { + self.0.base2k() } - fn basek(&self) -> usize { - self.0.basek() - } - - fn k(&self) -> usize { + fn k(&self) -> TorusPrecision { self.0.k() } + + fn n(&self) -> Degree { + self.0.n() + } + + fn size(&self) -> usize { + self.0.size() + } +} +impl GLWEInfos for LWESwitchingKeyPrepared { + fn rank(&self) -> Rank { + self.rank_out() + } } -impl LWESwitchingKeyPrepared { - pub fn digits(&self) -> usize { +impl GGLWELayoutInfos for LWESwitchingKeyPrepared { + fn digits(&self) -> Digits { self.0.digits() } - pub fn rank(&self) -> usize { - self.0.rank() - } - - pub fn rank_in(&self) -> usize { + fn rank_in(&self) -> Rank { self.0.rank_in() } - pub fn rank_out(&self) -> usize { + fn rank_out(&self) -> Rank { self.0.rank_out() } + + fn rows(&self) -> Rows { + self.0.rows() + } } impl LWESwitchingKeyPrepared, B> { - pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize) -> Self + pub fn alloc(module: &Module, infos: &A) -> Self + where + A: GGLWELayoutInfos, + Module: VmpPMatAlloc, + { + debug_assert_eq!( + infos.digits().0, + 1, + "digits > 1 is not supported for LWESwitchingKey" + ); + debug_assert_eq!( + infos.rank_in().0, + 1, + "rank_in > 1 is not supported for LWESwitchingKey" + ); + debug_assert_eq!( + infos.rank_out().0, + 1, + "rank_out > 1 is not supported for LWESwitchingKey" + ); + Self(GGLWESwitchingKeyPrepared::alloc(module, infos)) + } + + pub fn alloc_with(module: &Module, base2k: Base2K, k: TorusPrecision, rows: Rows) -> Self where Module: VmpPMatAlloc, { - Self(GGLWESwitchingKeyPrepared::alloc( - module, basek, k, rows, 1, 1, 1, + Self(GGLWESwitchingKeyPrepared::alloc_with( + module, + base2k, + k, + rows, + Digits(1), + Rank(1), + Rank(1), )) } - pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize) -> usize + pub fn alloc_bytes(module: &Module, infos: &A) -> usize + where + A: GGLWELayoutInfos, + Module: VmpPMatAllocBytes, + { + debug_assert_eq!( + infos.digits().0, + 1, + "digits > 1 is not supported for LWESwitchingKey" + ); + debug_assert_eq!( + infos.rank_in().0, + 1, + "rank_in > 1 is not supported for LWESwitchingKey" + ); + debug_assert_eq!( + infos.rank_out().0, + 1, + "rank_out > 1 is not supported for LWESwitchingKey" + ); + GGLWESwitchingKeyPrepared::alloc_bytes(module, infos) + } + + pub fn alloc_bytes_with(module: &Module, base2k: Base2K, k: TorusPrecision, rows: Rows) -> usize where Module: VmpPMatAllocBytes, { - GGLWESwitchingKeyPrepared::, B>::bytes_of(module, basek, k, rows, digits, 1, 1) + GGLWESwitchingKeyPrepared::alloc_bytes_with(module, base2k, k, rows, Digits(1), Rank(1), Rank(1)) } } @@ -68,8 +127,7 @@ where 
Module: VmpPrepare + VmpPMatAlloc, { fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> LWESwitchingKeyPrepared, B> { - let mut ksk_prepared: LWESwitchingKeyPrepared, B> = - LWESwitchingKeyPrepared::alloc(module, self.0.basek(), self.0.k(), self.0.rows()); + let mut ksk_prepared: LWESwitchingKeyPrepared, B> = LWESwitchingKeyPrepared::alloc(module, self); ksk_prepared.prepare(module, self, scratch); ksk_prepared } diff --git a/poulpy-core/src/layouts/prepared/lwe_to_glwe_ksk.rs b/poulpy-core/src/layouts/prepared/lwe_to_glwe_ksk.rs index de5c5e1..75c0e79 100644 --- a/poulpy-core/src/layouts/prepared/lwe_to_glwe_ksk.rs +++ b/poulpy-core/src/layouts/prepared/lwe_to_glwe_ksk.rs @@ -1,10 +1,10 @@ use poulpy_hal::{ api::{VmpPMatAlloc, VmpPMatAllocBytes, VmpPrepare}, - layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VmpPMat}, + layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}, }; use crate::layouts::{ - Infos, LWEToGLWESwitchingKey, + Base2K, Degree, Digits, GGLWELayoutInfos, GLWEInfos, LWEInfos, LWEToGLWESwitchingKey, Rank, Rows, TorusPrecision, prepared::{GGLWESwitchingKeyPrepared, Prepare, PrepareAlloc}, }; @@ -12,55 +12,105 @@ use crate::layouts::{ #[derive(PartialEq, Eq)] pub struct LWEToGLWESwitchingKeyPrepared(pub(crate) GGLWESwitchingKeyPrepared); -impl Infos for LWEToGLWESwitchingKeyPrepared { - type Inner = VmpPMat; - - fn inner(&self) -> &Self::Inner { - self.0.inner() +impl LWEInfos for LWEToGLWESwitchingKeyPrepared { + fn base2k(&self) -> Base2K { + self.0.base2k() } - fn basek(&self) -> usize { - self.0.basek() - } - - fn k(&self) -> usize { + fn k(&self) -> TorusPrecision { self.0.k() } + + fn n(&self) -> Degree { + self.0.n() + } + + fn size(&self) -> usize { + self.0.size() + } } -impl LWEToGLWESwitchingKeyPrepared { - pub fn digits(&self) -> usize { +impl GLWEInfos for LWEToGLWESwitchingKeyPrepared { + fn rank(&self) -> Rank { + self.rank_out() + } +} + +impl GGLWELayoutInfos for LWEToGLWESwitchingKeyPrepared { + fn digits(&self) -> Digits { self.0.digits() } - pub fn rank(&self) -> usize { - self.0.rank() - } - - pub fn rank_in(&self) -> usize { + fn rank_in(&self) -> Rank { self.0.rank_in() } - pub fn rank_out(&self) -> usize { + fn rank_out(&self) -> Rank { self.0.rank_out() } + + fn rows(&self) -> Rows { + self.0.rows() + } } impl LWEToGLWESwitchingKeyPrepared, B> { - pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, rank_out: usize) -> Self + pub fn alloc(module: &Module, infos: &A) -> Self + where + A: GGLWELayoutInfos, + Module: VmpPMatAlloc, + { + debug_assert_eq!( + infos.rank_in().0, + 1, + "rank_in > 1 is not supported for LWEToGLWESwitchingKey" + ); + debug_assert_eq!( + infos.digits().0, + 1, + "digits > 1 is not supported for LWEToGLWESwitchingKey" + ); + Self(GGLWESwitchingKeyPrepared::alloc(module, infos)) + } + + pub fn alloc_with(module: &Module, base2k: Base2K, k: TorusPrecision, rows: Rows, rank_out: Rank) -> Self where Module: VmpPMatAlloc, { - Self(GGLWESwitchingKeyPrepared::alloc( - module, basek, k, rows, 1, 1, rank_out, + Self(GGLWESwitchingKeyPrepared::alloc_with( + module, + base2k, + k, + rows, + Digits(1), + Rank(1), + rank_out, )) } - pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank_out: usize) -> usize + pub fn alloc_bytes(module: &Module, infos: &A) -> usize + where + A: GGLWELayoutInfos, + Module: VmpPMatAllocBytes, + { + debug_assert_eq!( + infos.rank_in().0, + 1, + "rank_in > 1 is not supported for LWEToGLWESwitchingKey" + ); + debug_assert_eq!( + 
infos.digits().0, + 1, + "digits > 1 is not supported for LWEToGLWESwitchingKey" + ); + GGLWESwitchingKeyPrepared::alloc_bytes(module, infos) + } + + pub fn alloc_bytes_with(module: &Module, base2k: Base2K, k: TorusPrecision, rows: Rows, rank_out: Rank) -> usize where Module: VmpPMatAllocBytes, { - GGLWESwitchingKeyPrepared::, B>::bytes_of(module, basek, k, rows, digits, 1, rank_out) + GGLWESwitchingKeyPrepared::alloc_bytes_with(module, base2k, k, rows, Digits(1), Rank(1), rank_out) } } @@ -69,13 +119,7 @@ where Module: VmpPrepare + VmpPMatAlloc, { fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> LWEToGLWESwitchingKeyPrepared, B> { - let mut ksk_prepared: LWEToGLWESwitchingKeyPrepared, B> = LWEToGLWESwitchingKeyPrepared::alloc( - module, - self.0.basek(), - self.0.k(), - self.0.rows(), - self.0.rank_out(), - ); + let mut ksk_prepared: LWEToGLWESwitchingKeyPrepared, B> = LWEToGLWESwitchingKeyPrepared::alloc(module, self); ksk_prepared.prepare(module, self, scratch); ksk_prepared } diff --git a/poulpy-core/src/noise/gglwe_ct.rs b/poulpy-core/src/noise/gglwe_ct.rs index 8de8716..dc0593e 100644 --- a/poulpy-core/src/noise/gglwe_ct.rs +++ b/poulpy-core/src/noise/gglwe_ct.rs @@ -8,11 +8,11 @@ use poulpy_hal::{ oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl}, }; -use crate::layouts::{GGLWECiphertext, GLWECiphertext, GLWEPlaintext, Infos, prepared::GLWESecretPrepared}; +use crate::layouts::{GGLWECiphertext, GGLWELayoutInfos, GLWECiphertext, GLWEPlaintext, LWEInfos, prepared::GLWESecretPrepared}; impl GGLWECiphertext { pub fn assert_noise( - self, + &self, module: &Module, sk: &GLWESecretPrepared, pt_want: &ScalarZnx, @@ -32,15 +32,14 @@ impl GGLWECiphertext { + VecZnxSubScalarInplace, B: Backend + TakeVecZnxDftImpl + TakeVecZnxBigImpl + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, { - let digits: usize = self.digits(); - let basek: usize = self.basek(); - let k: usize = self.k(); + let digits: usize = self.digits().into(); + let base2k: usize = self.base2k().into(); - let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space(module, basek, k)); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(self.n(), basek, k); + let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space(module, self)); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(self); - (0..self.rank_in()).for_each(|col_i| { - (0..self.rows()).for_each(|row_i| { + (0..self.rank_in().into()).for_each(|col_i| { + (0..self.rows().into()).for_each(|row_i| { self.at(row_i, col_i) .decrypt(module, &mut pt, sk, scratch.borrow()); @@ -52,13 +51,13 @@ impl GGLWECiphertext { col_i, ); - let noise_have: f64 = pt.data.std(basek, 0).log2(); + let noise_have: f64 = pt.data.std(base2k, 0).log2(); + + println!("noise_have: {noise_have}"); assert!( noise_have <= max_noise, - "noise_have: {} > max_noise: {}", - noise_have, - max_noise + "noise_have: {noise_have} > max_noise: {max_noise}" ); pt.data.zero(); diff --git a/poulpy-core/src/noise/ggsw_ct.rs b/poulpy-core/src/noise/ggsw_ct.rs index b1b51e4..dc156d2 100644 --- a/poulpy-core/src/noise/ggsw_ct.rs +++ b/poulpy-core/src/noise/ggsw_ct.rs @@ -3,13 +3,15 @@ use poulpy_hal::{ ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, 
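At call sites, the net effect of the whole refactor is visible in one line: the former positional usize arguments, where two transposed parameters still compiled, become distinct newtypes. A before/after sketch with illustrative values; module is assumed to exist:

    // Before: GGSWCiphertextPrepared::bytes_of(&module, 13, 44, 3, 1, 2);
    // After: each parameter carries its own type, so a swapped argument is a
    // compile error instead of a silent mis-allocation.
    let bytes: usize = GGSWCiphertextPrepared::<Vec<u8>, _>::alloc_bytes_with(
        &module,
        Base2K(13),
        TorusPrecision(44),
        Rows(3),
        Digits(1),
        Rank(2),
    );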
diff --git a/poulpy-core/src/noise/ggsw_ct.rs b/poulpy-core/src/noise/ggsw_ct.rs
index b1b51e4..dc156d2 100644
--- a/poulpy-core/src/noise/ggsw_ct.rs
+++ b/poulpy-core/src/noise/ggsw_ct.rs
@@ -3,13 +3,15 @@ use poulpy_hal::{
     api::{
         ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, VecZnxAddScalarInplace, VecZnxBigAddInplace,
         VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
         VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA,
-        VecZnxNormalizeTmpBytes, VecZnxSubABInplace,
+        VecZnxNormalizeTmpBytes, VecZnxSubInplace,
     },
     layouts::{Backend, DataRef, Module, ScalarZnx, ScratchOwned, VecZnxBig, VecZnxDft, ZnxZero},
     oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl},
 };
 
-use crate::layouts::{GGSWCiphertext, GLWECiphertext, GLWEPlaintext, Infos, prepared::GLWESecretPrepared};
+use crate::layouts::{
+    GGSWCiphertext, GGSWInfos, GLWECiphertext, GLWEInfos, GLWEPlaintext, LWEInfos, prepared::GLWESecretPrepared,
+};
 
 impl<D: DataRef> GGSWCiphertext<D> {
     pub fn assert_noise<B, DataSk, DataPt, F>(
@@ -35,24 +37,23 @@ impl<D: DataRef> GGSWCiphertext<D> {
             + VecZnxBigNormalizeTmpBytes
             + VecZnxIdftApplyTmpA
             + VecZnxAddScalarInplace
-            + VecZnxSubABInplace,
+            + VecZnxSubInplace,
         B: Backend + TakeVecZnxDftImpl<B> + TakeVecZnxBigImpl<B> + ScratchOwnedAllocImpl<B> + ScratchOwnedBorrowImpl<B>,
         F: Fn(usize) -> f64,
     {
-        let basek: usize = self.basek();
-        let k: usize = self.k();
-        let digits: usize = self.digits();
+        let base2k: usize = self.base2k().into();
+        let digits: usize = self.digits().into();
 
-        let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(self.n(), basek, k);
-        let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(self.n(), basek, k);
+        let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(self);
+        let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(self);
         let mut pt_dft: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(1, self.size());
         let mut pt_big: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(1, self.size());
 
         let mut scratch: ScratchOwned<B> =
-            ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space(module, basek, k) | module.vec_znx_normalize_tmp_bytes());
+            ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space(module, self) | module.vec_znx_normalize_tmp_bytes());
 
-        (0..self.rank() + 1).for_each(|col_j| {
-            (0..self.rows()).for_each(|row_i| {
+        (0..(self.rank() + 1).into()).for_each(|col_j| {
+            (0..self.rows().into()).for_each(|row_i| {
                 module.vec_znx_add_scalar_inplace(&mut pt.data, 0, (digits - 1) + row_i * digits, pt_want, 0);
 
                 // mul with sk[col_j-1]
@@ -60,17 +61,25 @@ impl<D: DataRef> GGSWCiphertext<D> {
                     module.vec_znx_dft_apply(1, 0, &mut pt_dft, 0, &pt.data, 0);
                     module.svp_apply_dft_to_dft_inplace(&mut pt_dft, 0, &sk_prepared.data, col_j - 1);
                     module.vec_znx_idft_apply_tmpa(&mut pt_big, 0, &mut pt_dft, 0);
-                    module.vec_znx_big_normalize(basek, &mut pt.data, 0, &pt_big, 0, scratch.borrow());
+                    module.vec_znx_big_normalize(
+                        base2k,
+                        &mut pt.data,
+                        0,
+                        base2k,
+                        &pt_big,
+                        0,
+                        scratch.borrow(),
+                    );
                 }
 
                 self.at(row_i, col_j)
                     .decrypt(module, &mut pt_have, sk_prepared, scratch.borrow());
 
-                module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt.data, 0);
+                module.vec_znx_sub_inplace(&mut pt_have.data, 0, &pt.data, 0);
 
-                let std_pt: f64 = pt_have.data.std(basek, 0).log2();
+                let std_pt: f64 = pt_have.data.std(base2k, 0).log2();
                 let noise: f64 = max_noise(col_j);
 
-                assert!(std_pt <= noise, "{} > {}", std_pt, noise);
+                assert!(std_pt <= noise, "{std_pt} > {noise}");
 
                 pt.data.zero();
             });
@@ -101,23 +110,22 @@ impl<D: DataRef> GGSWCiphertext<D> {
             + VecZnxBigNormalizeTmpBytes
             + VecZnxIdftApplyTmpA
             + VecZnxAddScalarInplace
-            + VecZnxSubABInplace,
+            + VecZnxSubInplace,
         B: Backend + TakeVecZnxDftImpl<B> + TakeVecZnxBigImpl<B> + ScratchOwnedAllocImpl<B> + ScratchOwnedBorrowImpl<B>,
     {
-        let basek: usize = self.basek();
-        let k: usize = self.k();
-        let digits: usize = self.digits();
+        let base2k: usize = self.base2k().into();
+        let digits: usize = self.digits().into();
 
-        let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(self.n(), basek, k);
-        let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(self.n(), basek, k);
+        let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(self);
+        let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(self);
         let mut pt_dft: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(1, self.size());
         let mut pt_big: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(1, self.size());
 
         let mut scratch: ScratchOwned<B> =
-            ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space(module, basek, k) | module.vec_znx_normalize_tmp_bytes());
+            ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space(module, self) | module.vec_znx_normalize_tmp_bytes());
 
-        (0..self.rank() + 1).for_each(|col_j| {
-            (0..self.rows()).for_each(|row_i| {
+        (0..(self.rank() + 1).into()).for_each(|col_j| {
+            (0..self.rows().into()).for_each(|row_i| {
                 module.vec_znx_add_scalar_inplace(&mut pt.data, 0, (digits - 1) + row_i * digits, pt_want, 0);
 
                 // mul with sk[col_j-1]
@@ -125,16 +133,24 @@ impl<D: DataRef> GGSWCiphertext<D> {
                     module.vec_znx_dft_apply(1, 0, &mut pt_dft, 0, &pt.data, 0);
                     module.svp_apply_dft_to_dft_inplace(&mut pt_dft, 0, &sk_prepared.data, col_j - 1);
                     module.vec_znx_idft_apply_tmpa(&mut pt_big, 0, &mut pt_dft, 0);
-                    module.vec_znx_big_normalize(basek, &mut pt.data, 0, &pt_big, 0, scratch.borrow());
+                    module.vec_znx_big_normalize(
+                        base2k,
+                        &mut pt.data,
+                        0,
+                        base2k,
+                        &pt_big,
+                        0,
+                        scratch.borrow(),
+                    );
                 }
 
                 self.at(row_i, col_j)
                     .decrypt(module, &mut pt_have, sk_prepared, scratch.borrow());
 
-                module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt.data, 0);
+                module.vec_znx_sub_inplace(&mut pt_have.data, 0, &pt.data, 0);
 
-                let std_pt: f64 = pt_have.data.std(basek, 0).log2();
-                println!("col: {} row: {}: {}", col_j, row_i, std_pt);
+                let std_pt: f64 = pt_have.data.std(base2k, 0).log2();
+                println!("col: {col_j} row: {row_i}: {std_pt}");
 
                 pt.data.zero();
             });
         });
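
Review note: both `assert_noise` bodies above switch from the single-base `vec_znx_big_normalize(basek, ...)` to a form that passes a base twice; reading the call sites, the two arguments appear to be the output base2k and the input base2k (identical here). What normalization means on a single coefficient, as a hedged sketch — balanced base-2^base2k digit decomposition (conventions assumed, not taken from poulpy-hal):

    // Balanced base-2^base2k digit decomposition of one coefficient
    // (assumed convention: most-significant limb first, digits in
    // [-2^(base2k-1), 2^(base2k-1))). vec_znx_big_normalize does this
    // per coefficient over a whole polynomial vector.
    fn normalize_base2k(mut x: i64, base2k: usize, limbs: usize) -> Vec<i64> {
        let base = 1i64 << base2k;
        let half = base >> 1;
        let mut digits = vec![0i64; limbs];
        for d in digits.iter_mut().rev() {
            let mut r = x % base; // truncated remainder, then rebalance
            x /= base;
            if r >= half {
                r -= base;
                x += 1;
            } else if r < -half {
                r += base;
                x -= 1;
            }
            *d = r;
        }
        digits
    }

    #[test]
    fn round_trips() {
        // 7 = 1*16 + (-2)*4 + (-1) in balanced base 4.
        assert_eq!(normalize_base2k(7, 2, 3), vec![1, -2, -1]);
    }
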
diff --git a/poulpy-core/src/noise/glwe_ct.rs b/poulpy-core/src/noise/glwe_ct.rs
index bde5b15..c1e4f3a 100644
--- a/poulpy-core/src/noise/glwe_ct.rs
+++ b/poulpy-core/src/noise/glwe_ct.rs
@@ -2,17 +2,13 @@ use poulpy_hal::{
     api::{
         ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace,
         VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume,
-        VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSubABInplace,
+        VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSubInplace,
     },
     layouts::{Backend, DataRef, Module, ScratchOwned},
     oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl},
 };
 
-use crate::{
-    layouts::GLWEPlaintext,
-    layouts::prepared::GLWESecretPrepared,
-    layouts::{GLWECiphertext, Infos},
-};
+use crate::layouts::{GLWECiphertext, GLWEPlaintext, LWEInfos, prepared::GLWESecretPrepared};
 
 impl<D: DataRef> GLWECiphertext<D> {
     pub fn assert_noise<B, DataSk, DataPt>(
@@ -33,24 +29,20 @@ impl<D: DataRef> GLWECiphertext<D> {
             + VecZnxBigAddSmallInplace
             + VecZnxBigNormalize
             + VecZnxNormalizeTmpBytes
-            + VecZnxSubABInplace
+            + VecZnxSubInplace
             + VecZnxNormalizeInplace,
         B: Backend + TakeVecZnxDftImpl<B> + TakeVecZnxBigImpl<B> + ScratchOwnedAllocImpl<B> + ScratchOwnedBorrowImpl<B>,
     {
-        let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(self.n(), self.basek(), self.k());
+        let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(self);
 
-        let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space(
-            module,
-            self.basek(),
-            self.k(),
-        ));
+        let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space(module, self));
 
         self.decrypt(module, &mut pt_have, sk_prepared, scratch.borrow());
 
-        module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0);
-        module.vec_znx_normalize_inplace(self.basek(), &mut pt_have.data, 0, scratch.borrow());
+        module.vec_znx_sub_inplace(&mut pt_have.data, 0, &pt_want.data, 0);
+        module.vec_znx_normalize_inplace(self.base2k().into(), &mut pt_have.data, 0, scratch.borrow());
 
-        let noise_have: f64 = pt_have.data.std(self.basek(), 0).log2();
-        assert!(noise_have <= max_noise, "{} {}", noise_have, max_noise);
+        let noise_have: f64 = pt_have.data.std(self.base2k().into(), 0).log2();
+        assert!(noise_have <= max_noise, "{noise_have} {max_noise}");
     }
 }
diff --git a/poulpy-core/src/noise/mod.rs b/poulpy-core/src/noise/mod.rs
index 0592ab0..25e2b9b 100644
--- a/poulpy-core/src/noise/mod.rs
+++ b/poulpy-core/src/noise/mod.rs
@@ -6,7 +6,7 @@ mod glwe_ct;
 #[allow(dead_code)]
 pub(crate) fn var_noise_gglwe_product(
     n: f64,
-    basek: usize,
+    base2k: usize,
     var_xs: f64,
     var_msg: f64,
     var_a_err: f64,
@@ -17,12 +17,12 @@ pub(crate) fn var_noise_gglwe_product(
     b_logq: usize,
 ) -> f64 {
     let a_logq: usize = a_logq.min(b_logq);
-    let a_cols: usize = a_logq.div_ceil(basek);
+    let a_cols: usize = a_logq.div_ceil(base2k);
 
     let b_scale: f64 = (b_logq as f64).exp2();
     let a_scale: f64 = ((b_logq - a_logq) as f64).exp2();
 
-    let base: f64 = (basek as f64).exp2();
+    let base: f64 = (base2k as f64).exp2();
     let var_base: f64 = base * base / 12f64;
 
     // lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2)
@@ -38,7 +38,7 @@ pub(crate) fn var_noise_gglwe_product(
 #[allow(dead_code)]
 pub(crate) fn log2_std_noise_gglwe_product(
     n: f64,
-    basek: usize,
+    base2k: usize,
     var_xs: f64,
     var_msg: f64,
     var_a_err: f64,
@@ -50,7 +50,7 @@ pub(crate) fn log2_std_noise_gglwe_product(
 ) -> f64 {
     let mut noise: f64 = var_noise_gglwe_product(
         n,
-        basek,
+        base2k,
         var_xs,
         var_msg,
         var_a_err,
@@ -68,7 +68,7 @@ pub(crate) fn log2_std_noise_gglwe_product(
 #[allow(dead_code)]
 pub(crate) fn noise_ggsw_product(
     n: f64,
-    basek: usize,
+    base2k: usize,
     var_xs: f64,
     var_msg: f64,
     var_a0_err: f64,
@@ -80,12 +80,12 @@ pub(crate) fn noise_ggsw_product(
     k_ggsw: usize,
 ) -> f64 {
     let a_logq: usize = k_in.min(k_ggsw);
-    let a_cols: usize = a_logq.div_ceil(basek);
+    let a_cols: usize = a_logq.div_ceil(base2k);
 
     let b_scale: f64 = (k_ggsw as f64).exp2();
     let a_scale: f64 = ((k_ggsw - a_logq) as f64).exp2();
 
-    let base: f64 = (basek as f64).exp2();
+    let base: f64 = (base2k as f64).exp2();
     let var_base: f64 = base * base / 12f64;
 
     // lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2)
@@ -102,7 +102,7 @@ pub(crate) fn noise_ggsw_product(
 #[allow(dead_code)]
 pub(crate) fn noise_ggsw_keyswitch(
     n: f64,
-    basek: usize,
+    base2k: usize,
     col: usize,
     var_xs: f64,
     var_a_err: f64,
@@ -118,7 +118,7 @@ pub(crate) fn noise_ggsw_keyswitch(
     // Initial KS for col = 0
     let mut noise: f64 = var_noise_gglwe_product(
         n,
-        basek,
+        base2k,
         var_xs,
         var_xs,
         var_a_err,
@@ -133,7 +133,7 @@ pub(crate) fn noise_ggsw_keyswitch(
     if col > 0 {
         noise += var_noise_gglwe_product(
             n,
-            basek,
+            base2k,
             var_xs,
             var_si_x_sj,
             var_a_err + 1f64 / 12.0,
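
Review note: the `basek` → `base2k` rename above is purely mechanical, but the quantity it feeds is worth spelling out: a balanced base-2^base2k digit is roughly uniform over [-2^(base2k-1), 2^(base2k-1)), so its variance is base^2 / 12 — the `var_base` term used in every noise formula in this file. A tiny check of the numbers:

    // var_base = (2^base2k)^2 / 12: variance of a digit uniform over
    // [-2^(base2k-1), 2^(base2k-1)), as used by var_noise_gglwe_product.
    fn var_base(base2k: usize) -> f64 {
        let base: f64 = (base2k as f64).exp2();
        base * base / 12f64
    }

    fn main() {
        // For base2k = 12: log2(std) = 12 - 0.5 * log2(12) ≈ 10.21 bits.
        let std_log2 = var_base(12).sqrt().log2();
        println!("log2 std of a balanced base-2^12 digit: {std_log2:.2}");
    }
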
diff --git a/poulpy-core/src/operations/glwe.rs b/poulpy-core/src/operations/glwe.rs
index e977762..b0ee4f6 100644
--- a/poulpy-core/src/operations/glwe.rs
+++ b/poulpy-core/src/operations/glwe.rs
@@ -2,40 +2,42 @@ use poulpy_hal::{
     api::{
         VecZnxAdd, VecZnxAddInplace, VecZnxCopy, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegateInplace,
         VecZnxNormalize, VecZnxNormalizeInplace, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub,
-        VecZnxSubABInplace, VecZnxSubBAInplace,
+        VecZnxSubInplace, VecZnxSubNegateInplace,
     },
     layouts::{Backend, DataMut, Module, Scratch, VecZnx, ZnxZero},
 };
 
-use crate::layouts::{GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, GLWEPlaintext, Infos, SetMetaData};
+use crate::layouts::{
+    GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, GLWEInfos, GLWELayoutSet, GLWEPlaintext, LWEInfos, TorusPrecision,
+};
 
 impl<D> GLWEOperations for GLWEPlaintext<D>
 where
     D: DataMut,
-    GLWEPlaintext<D>: GLWECiphertextToMut + Infos + SetMetaData,
+    GLWEPlaintext<D>: GLWECiphertextToMut + GLWEInfos,
 {
 }
 
-impl<D: DataMut> GLWEOperations for GLWECiphertext<D> where GLWECiphertext<D>: GLWECiphertextToMut + Infos + SetMetaData {}
+impl<D: DataMut> GLWEOperations for GLWECiphertext<D> where GLWECiphertext<D>: GLWECiphertextToMut + GLWEInfos {}
 
-pub trait GLWEOperations: GLWECiphertextToMut + SetMetaData + Sized {
+pub trait GLWEOperations: GLWECiphertextToMut + GLWEInfos + GLWELayoutSet + Sized {
     fn add<A, B, BE: Backend>(&mut self, module: &Module<BE>, a: &A, b: &B)
     where
-        A: GLWECiphertextToRef,
-        B: GLWECiphertextToRef,
+        A: GLWECiphertextToRef + GLWEInfos,
+        B: GLWECiphertextToRef + GLWEInfos,
         Module<BE>: VecZnxAdd + VecZnxCopy,
     {
         #[cfg(debug_assertions)]
         {
             assert_eq!(a.n(), self.n());
             assert_eq!(b.n(), self.n());
-            assert_eq!(a.basek(), b.basek());
+            assert_eq!(a.base2k(), b.base2k());
             assert!(self.rank() >= a.rank().max(b.rank()));
         }
 
-        let min_col: usize = a.rank().min(b.rank()) + 1;
-        let max_col: usize = a.rank().max(b.rank() + 1);
-        let self_col: usize = self.rank() + 1;
+        let min_col: usize = (a.rank().min(b.rank()) + 1).into();
+        let max_col: usize = (a.rank().max(b.rank() + 1)).into();
+        let self_col: usize = (self.rank() + 1).into();
 
         let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
         let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
@@ -62,26 +64,26 @@ pub trait GLWEOperations: GLWECiphertextToMut + SetMetaData + Sized {
             });
         });
 
-        self.set_basek(a.basek());
+        self.set_basek(a.base2k());
         self.set_k(set_k_binary(self, a, b));
     }
 
     fn add_inplace<A, BE: Backend>(&mut self, module: &Module<BE>, a: &A)
     where
-        A: GLWECiphertextToRef + Infos,
+        A: GLWECiphertextToRef + GLWEInfos,
         Module<BE>: VecZnxAddInplace,
     {
         #[cfg(debug_assertions)]
         {
             assert_eq!(a.n(), self.n());
-            assert_eq!(self.basek(), a.basek());
+            assert_eq!(self.base2k(), a.base2k());
             assert!(self.rank() >= a.rank())
         }
 
         let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
         let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
 
-        (0..a.rank() + 1).for_each(|i| {
+        (0..(a.rank() + 1).into()).for_each(|i| {
             module.vec_znx_add_inplace(&mut self_mut.data, i, &a_ref.data, i);
         });
 
@@ -90,21 +92,21 @@ pub trait GLWEOperations: GLWECiphertextToMut + SetMetaData + Sized {
 
     fn sub<A, B, BE: Backend>(&mut self, module: &Module<BE>, a: &A, b: &B)
     where
-        A: GLWECiphertextToRef,
-        B: GLWECiphertextToRef,
+        A: GLWECiphertextToRef + GLWEInfos,
+        B: GLWECiphertextToRef + GLWEInfos,
         Module<BE>: VecZnxSub + VecZnxCopy + VecZnxNegateInplace,
     {
         #[cfg(debug_assertions)]
         {
             assert_eq!(a.n(), self.n());
             assert_eq!(b.n(), self.n());
-            assert_eq!(a.basek(), b.basek());
+            assert_eq!(a.base2k(), b.base2k());
             assert!(self.rank() >= a.rank().max(b.rank()));
         }
 
-        let min_col: usize = a.rank().min(b.rank()) + 1;
-        let max_col: usize = a.rank().max(b.rank() + 1);
-        let self_col: usize = self.rank() + 1;
+        let min_col: usize = (a.rank().min(b.rank()) + 1).into();
+        let max_col: usize = (a.rank().max(b.rank() + 1)).into();
+        let self_col: usize = (self.rank() + 1).into();
 
         let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
         let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
@@ -132,27 +134,27 @@ pub trait GLWEOperations: GLWECiphertextToMut + SetMetaData + Sized {
             });
         });
 
-        self.set_basek(a.basek());
+        self.set_basek(a.base2k());
         self.set_k(set_k_binary(self, a, b));
     }
 
     fn sub_inplace_ab<A, BE: Backend>(&mut self, module: &Module<BE>, a: &A)
     where
-        A: GLWECiphertextToRef + Infos,
-        Module<BE>: VecZnxSubABInplace,
+        A: GLWECiphertextToRef + GLWEInfos,
+        Module<BE>: VecZnxSubInplace,
     {
         #[cfg(debug_assertions)]
         {
             assert_eq!(a.n(), self.n());
-            assert_eq!(self.basek(), a.basek());
+            assert_eq!(self.base2k(), a.base2k());
             assert!(self.rank() >= a.rank())
         }
 
         let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
         let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
 
-        (0..a.rank() + 1).for_each(|i| {
-            module.vec_znx_sub_ab_inplace(&mut self_mut.data, i, &a_ref.data, i);
+        (0..(a.rank() + 1).into()).for_each(|i| {
+            module.vec_znx_sub_inplace(&mut self_mut.data, i, &a_ref.data, i);
         });
 
         self.set_k(set_k_unary(self, a))
@@ -160,21 +162,21 @@ pub trait GLWEOperations: GLWECiphertextToMut + SetMetaData + Sized {
 
     fn sub_inplace_ba<A, BE: Backend>(&mut self, module: &Module<BE>, a: &A)
     where
-        A: GLWECiphertextToRef + Infos,
-        Module<BE>: VecZnxSubBAInplace,
+        A: GLWECiphertextToRef + GLWEInfos,
+        Module<BE>: VecZnxSubNegateInplace,
     {
         #[cfg(debug_assertions)]
         {
             assert_eq!(a.n(), self.n());
-            assert_eq!(self.basek(), a.basek());
+            assert_eq!(self.base2k(), a.base2k());
             assert!(self.rank() >= a.rank())
         }
 
         let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
         let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
 
-        (0..a.rank() + 1).for_each(|i| {
-            module.vec_znx_sub_ba_inplace(&mut self_mut.data, i, &a_ref.data, i);
+        (0..(a.rank() + 1).into()).for_each(|i| {
+            module.vec_znx_sub_negate_inplace(&mut self_mut.data, i, &a_ref.data, i);
         });
 
         self.set_k(set_k_unary(self, a))
@@ -182,7 +184,7 @@ pub trait GLWEOperations: GLWECiphertextToMut + SetMetaData + Sized {
 
     fn rotate<A, BE: Backend>(&mut self, module: &Module<BE>, k: i64, a: &A)
     where
-        A: GLWECiphertextToRef + Infos,
+        A: GLWECiphertextToRef + GLWEInfos,
         Module<BE>: VecZnxRotate,
     {
         #[cfg(debug_assertions)]
@@ -194,11 +196,11 @@ pub trait GLWEOperations: GLWECiphertextToMut + SetMetaData + Sized {
         let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
         let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
 
-        (0..a.rank() + 1).for_each(|i| {
+        (0..(a.rank() + 1).into()).for_each(|i| {
             module.vec_znx_rotate(k, &mut self_mut.data, i, &a_ref.data, i);
         });
 
-        self.set_basek(a.basek());
+        self.set_basek(a.base2k());
         self.set_k(set_k_unary(self, a))
     }
 
@@ -208,14 +210,14 @@ pub trait GLWEOperations: GLWECiphertextToMut + SetMetaData + Sized {
     {
         let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
 
-        (0..self_mut.rank() + 1).for_each(|i| {
+        (0..(self_mut.rank() + 1).into()).for_each(|i| {
             module.vec_znx_rotate_inplace(k, &mut self_mut.data, i, scratch);
         });
     }
 
     fn mul_xp_minus_one<A, BE: Backend>(&mut self, module: &Module<BE>, k: i64, a: &A)
     where
-        A: GLWECiphertextToRef + Infos,
+        A: GLWECiphertextToRef + GLWEInfos,
         Module<BE>: VecZnxMulXpMinusOne,
     {
         #[cfg(debug_assertions)]
@@ -227,11 +229,11 @@ pub trait GLWEOperations: GLWECiphertextToMut + SetMetaData + Sized {
         let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
         let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
 
-        (0..a.rank() + 1).for_each(|i| {
+        (0..(a.rank() + 1).into()).for_each(|i| {
             module.vec_znx_mul_xp_minus_one(k, &mut self_mut.data, i, &a_ref.data, i);
         });
 
-        self.set_basek(a.basek());
+        self.set_basek(a.base2k());
         self.set_k(set_k_unary(self, a))
     }
 
@@ -241,14 +243,14 @@ pub trait GLWEOperations: GLWECiphertextToMut + SetMetaData + Sized {
     {
         let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
 
-        (0..self_mut.rank() + 1).for_each(|i| {
+        (0..(self_mut.rank() + 1).into()).for_each(|i| {
             module.vec_znx_mul_xp_minus_one_inplace(k, &mut self_mut.data, i, scratch);
         });
     }
 
     fn copy<A, BE: Backend>(&mut self, module: &Module<BE>, a: &A)
     where
-        A: GLWECiphertextToRef + Infos,
+        A: GLWECiphertextToRef + GLWEInfos,
         Module<BE>: VecZnxCopy,
     {
         #[cfg(debug_assertions)]
@@ -260,27 +262,27 @@ pub trait GLWEOperations: GLWECiphertextToMut + SetMetaData + Sized {
         let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
         let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
 
-        (0..self_mut.rank() + 1).for_each(|i| {
+        (0..(self_mut.rank() + 1).into()).for_each(|i| {
             module.vec_znx_copy(&mut self_mut.data, i, &a_ref.data, i);
         });
 
-        self.set_k(a.k().min(self.size() * self.basek()));
-        self.set_basek(a.basek());
+        self.set_k(a.k().min(self.max_k()));
+        self.set_basek(a.base2k());
     }
 
     fn rsh<BE: Backend>(&mut self, module: &Module<BE>, k: usize, scratch: &mut Scratch<BE>)
     where
         Module<BE>: VecZnxRshInplace,
     {
-        let basek: usize = self.basek();
-        (0..self.cols()).for_each(|i| {
-            module.vec_znx_rsh_inplace(basek, k, &mut self.to_mut().data, i, scratch);
+        let base2k: usize = self.base2k().into();
+        (0..(self.rank() + 1).into()).for_each(|i| {
+            module.vec_znx_rsh_inplace(base2k, k, &mut self.to_mut().data, i, scratch);
         })
     }
 
     fn normalize<A, BE: Backend>(&mut self, module: &Module<BE>, a: &A, scratch: &mut Scratch<BE>)
     where
-        A: GLWECiphertextToRef,
+        A: GLWECiphertextToRef + GLWEInfos,
         Module<BE>: VecZnxNormalize,
     {
         #[cfg(debug_assertions)]
@@ -292,10 +294,18 @@ pub trait GLWEOperations: GLWECiphertextToMut + SetMetaData + Sized {
         let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
         let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
 
-        (0..self_mut.rank() + 1).for_each(|i| {
-            module.vec_znx_normalize(a.basek(), &mut self_mut.data, i, &a_ref.data, i, scratch);
+        (0..(self_mut.rank() + 1).into()).for_each(|i| {
+            module.vec_znx_normalize(
+                a.base2k().into(),
+                &mut self_mut.data,
+                i,
+                a.base2k().into(),
+                &a_ref.data,
+                i,
+                scratch,
+            );
         });
 
-        self.set_basek(a.basek());
+        self.set_basek(a.base2k());
         self.set_k(a.k().min(self.k()));
     }
 
@@ -304,8 +314,8 @@ pub trait GLWEOperations: GLWECiphertextToMut + SetMetaData + Sized {
         Module<BE>: VecZnxNormalizeInplace,
     {
         let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
-        (0..self_mut.rank() + 1).for_each(|i| {
-            module.vec_znx_normalize_inplace(self_mut.basek(), &mut self_mut.data, i, scratch);
+        (0..(self_mut.rank() + 1).into()).for_each(|i| {
+            module.vec_znx_normalize_inplace(self_mut.base2k().into(), &mut self_mut.data, i, scratch);
         });
     }
 }
@@ -317,7 +327,7 @@ impl GLWECiphertext<Vec<u8>> {
 }
 
 // c = op(a, b)
-fn set_k_binary(c: &impl Infos, a: &impl Infos, b: &impl Infos) -> usize {
+fn set_k_binary(c: &impl GLWEInfos, a: &impl GLWEInfos, b: &impl GLWEInfos) -> TorusPrecision {
     // If either operands is a ciphertext
     if a.rank() != 0 || b.rank() != 0 {
         // If a is a plaintext (but b ciphertext)
@@ -338,7 +348,7 @@ fn set_k_binary(c: &impl Infos, a: &impl Infos, b: &impl Infos) -> usize {
 }
 
 // a = op(a, b)
-fn set_k_unary(a: &impl Infos, b: &impl Infos) -> usize {
+fn set_k_unary(a: &impl GLWEInfos, b: &impl GLWEInfos) -> TorusPrecision {
     if a.rank() != 0 || b.rank() != 0 {
         a.k().min(b.k())
     } else {
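
Review note: the renames applied throughout this file drop the operand-order suffixes — `vec_znx_sub_ab_inplace` becomes `vec_znx_sub_inplace` and `vec_znx_sub_ba_inplace` becomes `vec_znx_sub_negate_inplace` (mirroring `VecZnxBigSubSmallBInplace` → `VecZnxBigSubSmallNegateInplace` elsewhere in this patch). The semantics below are inferred from the call sites, not from a documented poulpy-hal contract:

    // Assumed semantics, inferred from the call sites in this file:
    //   vec_znx_sub_inplace(a, b):        a <- a - b   (old sub_ab)
    //   vec_znx_sub_negate_inplace(a, b): a <- b - a   (old sub_ba)
    // Scalar model of the two operations:
    fn sub_inplace(a: &mut [i64], b: &[i64]) {
        for (ai, bi) in a.iter_mut().zip(b) {
            *ai -= bi;
        }
    }

    fn sub_negate_inplace(a: &mut [i64], b: &[i64]) {
        for (ai, bi) in a.iter_mut().zip(b) {
            *ai = bi - *ai;
        }
    }
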
diff --git a/poulpy-core/src/scratch.rs b/poulpy-core/src/scratch.rs
index cc1dda8..a570e66 100644
--- a/poulpy-core/src/scratch.rs
+++ b/poulpy-core/src/scratch.rs
@@ -1,14 +1,13 @@
 use poulpy_hal::{
     api::{TakeMatZnx, TakeScalarZnx, TakeSvpPPol, TakeVecZnx, TakeVecZnxDft, TakeVmpPMat},
-    layouts::{Backend, DataRef, Scratch},
-    oep::{TakeMatZnxImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, TakeVmpPMatImpl},
+    layouts::{Backend, Scratch},
 };
 
 use crate::{
     dist::Distribution,
     layouts::{
-        GGLWEAutomorphismKey, GGLWECiphertext, GGLWESwitchingKey, GGLWETensorKey, GGSWCiphertext, GLWECiphertext, GLWEPlaintext,
-        GLWEPublicKey, GLWESecret, Infos,
+        Degree, GGLWEAutomorphismKey, GGLWECiphertext, GGLWELayoutInfos, GGLWESwitchingKey, GGLWETensorKey, GGSWCiphertext,
+        GGSWInfos, GLWECiphertext, GLWEInfos, GLWEPlaintext, GLWEPublicKey, GLWESecret, Rank,
         prepared::{
             GGLWEAutomorphismKeyPrepared, GGLWECiphertextPrepared, GGLWESwitchingKeyPrepared, GGLWETensorKeyPrepared,
             GGSWCiphertextPrepared, GLWEPublicKeyPrepared, GLWESecretPrepared,
@@ -16,205 +15,120 @@ use crate::{
     },
 };
 
-pub trait TakeLike<'a, B: Backend, T> {
-    type Output;
-    fn take_like(&'a mut self, template: &T) -> (Self::Output, &'a mut Self);
-}
-
 pub trait TakeGLWECt {
-    fn take_glwe_ct(&mut self, n: usize, basek: usize, k: usize, rank: usize) -> (GLWECiphertext<&mut [u8]>, &mut Self);
+    fn take_glwe_ct<A>(&mut self, infos: &A) -> (GLWECiphertext<&mut [u8]>, &mut Self)
+    where
+        A: GLWEInfos;
 }
 
 pub trait TakeGLWECtSlice {
-    fn take_glwe_ct_slice(
-        &mut self,
-        size: usize,
-        n: usize,
-        basek: usize,
-        k: usize,
-        rank: usize,
-    ) -> (Vec<GLWECiphertext<&mut [u8]>>, &mut Self);
+    fn take_glwe_ct_slice<A>(&mut self, size: usize, infos: &A) -> (Vec<GLWECiphertext<&mut [u8]>>, &mut Self)
+    where
+        A: GLWEInfos;
 }
 
 pub trait TakeGLWEPt {
-    fn take_glwe_pt(&mut self, n: usize, basek: usize, k: usize) -> (GLWEPlaintext<&mut [u8]>, &mut Self);
+    fn take_glwe_pt<A>(&mut self, infos: &A) -> (GLWEPlaintext<&mut [u8]>, &mut Self)
+    where
+        A: GLWEInfos;
 }
 
 pub trait TakeGGLWE {
-    #[allow(clippy::too_many_arguments)]
-    fn take_gglwe(
-        &mut self,
-        n: usize,
-        basek: usize,
-        k: usize,
-        rows: usize,
-        digits: usize,
-        rank_in: usize,
-        rank_out: usize,
-    ) -> (GGLWECiphertext<&mut [u8]>, &mut Self);
+    fn take_gglwe<A>(&mut self, infos: &A) -> (GGLWECiphertext<&mut [u8]>, &mut Self)
+    where
+        A: GGLWELayoutInfos;
 }
 
 pub trait TakeGGLWEPrepared<B: Backend> {
-    #[allow(clippy::too_many_arguments)]
-    fn take_gglwe_prepared(
-        &mut self,
-        n: usize,
-        basek: usize,
-        k: usize,
-        rows: usize,
-        digits: usize,
-        rank_in: usize,
-        rank_out: usize,
-    ) -> (GGLWECiphertextPrepared<&mut [u8], B>, &mut Self);
+    fn take_gglwe_prepared<A>(&mut self, infos: &A) -> (GGLWECiphertextPrepared<&mut [u8], B>, &mut Self)
+    where
+        A: GGLWELayoutInfos;
 }
 
 pub trait TakeGGSW {
-    fn take_ggsw(
-        &mut self,
-        n: usize,
-        basek: usize,
-        k: usize,
-        rows: usize,
-        digits: usize,
-        rank: usize,
-    ) -> (GGSWCiphertext<&mut [u8]>, &mut Self);
+    fn take_ggsw<A>(&mut self, infos: &A) -> (GGSWCiphertext<&mut [u8]>, &mut Self)
+    where
+        A: GGSWInfos;
 }
 
 pub trait TakeGGSWPrepared<B: Backend> {
-    fn take_ggsw_prepared(
-        &mut self,
-        n: usize,
-        basek: usize,
-        k: usize,
-        rows: usize,
-        digits: usize,
-        rank: usize,
-    ) -> (GGSWCiphertextPrepared<&mut [u8], B>, &mut Self);
+    fn take_ggsw_prepared<A>(&mut self, infos: &A) -> (GGSWCiphertextPrepared<&mut [u8], B>, &mut Self)
+    where
+        A: GGSWInfos;
 }
 
 pub trait TakeGLWESecret {
-    fn take_glwe_secret(&mut self, n: usize, rank: usize) -> (GLWESecret<&mut [u8]>, &mut Self);
+    fn take_glwe_secret(&mut self, n: Degree, rank: Rank) -> (GLWESecret<&mut [u8]>, &mut Self);
 }
 
 pub trait TakeGLWESecretPrepared<B: Backend> {
-    fn take_glwe_secret_prepared(&mut self, n: usize, rank: usize) -> (GLWESecretPrepared<&mut [u8], B>, &mut Self);
+    fn take_glwe_secret_prepared(&mut self, n: Degree, rank: Rank) -> (GLWESecretPrepared<&mut [u8], B>, &mut Self);
 }
 
 pub trait TakeGLWEPk {
-    fn take_glwe_pk(&mut self, n: usize, basek: usize, k: usize, rank: usize) -> (GLWEPublicKey<&mut [u8]>, &mut Self);
+    fn take_glwe_pk<A>(&mut self, infos: &A) -> (GLWEPublicKey<&mut [u8]>, &mut Self)
+    where
+        A: GLWEInfos;
 }
 
 pub trait TakeGLWEPkPrepared<B: Backend> {
-    fn take_glwe_pk_prepared(
-        &mut self,
-        n: usize,
-        basek: usize,
-        k: usize,
-        rank: usize,
-    ) -> (GLWEPublicKeyPrepared<&mut [u8], B>, &mut Self);
+    fn take_glwe_pk_prepared<A>(&mut self, infos: &A) -> (GLWEPublicKeyPrepared<&mut [u8], B>, &mut Self)
+    where
+        A: GLWEInfos;
 }
 
 pub trait TakeGLWESwitchingKey {
-    #[allow(clippy::too_many_arguments)]
-    fn take_glwe_switching_key(
-        &mut self,
-        n: usize,
-        basek: usize,
-        k: usize,
-        rows: usize,
-        digits: usize,
-        rank_in: usize,
-        rank_out: usize,
-    ) -> (GGLWESwitchingKey<&mut [u8]>, &mut Self);
+    fn take_glwe_switching_key<A>(&mut self, infos: &A) -> (GGLWESwitchingKey<&mut [u8]>, &mut Self)
+    where
+        A: GGLWELayoutInfos;
 }
 
-pub trait TakeGLWESwitchingKeyPrepared<B: Backend> {
-    #[allow(clippy::too_many_arguments)]
-    fn take_glwe_switching_key_prepared(
-        &mut self,
-        n: usize,
-        basek: usize,
-        k: usize,
-        rows: usize,
-        digits: usize,
-        rank_in: usize,
-        rank_out: usize,
-    ) -> (GGLWESwitchingKeyPrepared<&mut [u8], B>, &mut Self);
+pub trait TakeGGLWESwitchingKeyPrepared<B: Backend> {
+    fn take_gglwe_switching_key_prepared<A>(&mut self, infos: &A) -> (GGLWESwitchingKeyPrepared<&mut [u8], B>, &mut Self)
+    where
+        A: GGLWELayoutInfos;
 }
 
 pub trait TakeTensorKey {
-    fn take_tensor_key(
-        &mut self,
-        n: usize,
-        basek: usize,
-        k: usize,
-        rows: usize,
-        digits: usize,
-        rank: usize,
-    ) -> (GGLWETensorKey<&mut [u8]>, &mut Self);
+    fn take_tensor_key<A>(&mut self, infos: &A) -> (GGLWETensorKey<&mut [u8]>, &mut Self)
+    where
+        A: GGLWELayoutInfos;
 }
 
-pub trait TakeTensorKeyPrepared<B: Backend> {
-    fn take_tensor_key_prepared(
-        &mut self,
-        n: usize,
-        basek: usize,
-        k: usize,
-        rows: usize,
-        digits: usize,
-        rank: usize,
-    ) -> (GGLWETensorKeyPrepared<&mut [u8], B>, &mut Self);
+pub trait TakeGGLWETensorKeyPrepared<B: Backend> {
+    fn take_gglwe_tensor_key_prepared<A>(&mut self, infos: &A) -> (GGLWETensorKeyPrepared<&mut [u8], B>, &mut Self)
+    where
+        A: GGLWELayoutInfos;
 }
 
-pub trait TakeAutomorphismKey {
-    fn take_automorphism_key(
-        &mut self,
-        n: usize,
-        basek: usize,
-        k: usize,
-        rows: usize,
-        digits: usize,
-        rank: usize,
-    ) -> (GGLWEAutomorphismKey<&mut [u8]>, &mut Self);
+pub trait TakeGGLWEAutomorphismKey {
+    fn take_gglwe_automorphism_key<A>(&mut self, infos: &A) -> (GGLWEAutomorphismKey<&mut [u8]>, &mut Self)
+    where
+        A: GGLWELayoutInfos;
 }
 
-pub trait TakeAutomorphismKeyPrepared<B: Backend> {
-    fn take_automorphism_key_prepared(
-        &mut self,
-        n: usize,
-        basek: usize,
-        k: usize,
-        rows: usize,
-        digits: usize,
-        rank: usize,
-    ) -> (GGLWEAutomorphismKeyPrepared<&mut [u8], B>, &mut Self);
+pub trait TakeGGLWEAutomorphismKeyPrepared<B: Backend> {
+    fn take_gglwe_automorphism_key_prepared<A>(&mut self, infos: &A) -> (GGLWEAutomorphismKeyPrepared<&mut [u8], B>, &mut Self)
+    where
+        A: GGLWELayoutInfos;
 }
 
 impl<B: Backend> TakeGLWECt for Scratch<B>
 where
     Scratch<B>: TakeVecZnx<B>,
 {
-    fn take_glwe_ct(&mut self, n: usize, basek: usize, k: usize, rank: usize) -> (GLWECiphertext<&mut [u8]>, &mut Self) {
-        let (data, scratch) = self.take_vec_znx(n, rank + 1, k.div_ceil(basek));
-        (GLWECiphertext { data, basek, k }, scratch)
-    }
-}
-
-impl<'a, B, D> TakeLike<'a, B, GLWECiphertext<D>> for Scratch<B>
-where
-    B: Backend + TakeVecZnxImpl<B>,
-    D: DataRef,
-{
-    type Output = GLWECiphertext<&'a mut [u8]>;
-
-    fn take_like(&'a mut self, template: &GLWECiphertext<D>) -> (Self::Output, &'a mut Self) {
-        let (data, scratch) = B::take_vec_znx_impl(self, template.n(), template.cols(), template.size());
+    fn take_glwe_ct<A>(&mut self, infos: &A) -> (GLWECiphertext<&mut [u8]>, &mut Self)
+    where
+        A: GLWEInfos,
+    {
+        let (data, scratch) = self.take_vec_znx(infos.n().into(), (infos.rank() + 1).into(), infos.size());
         (
-            GLWECiphertext {
-                data,
-                basek: template.basek(),
-                k: template.k(),
-            },
+            GLWECiphertext::builder()
+                .base2k(infos.base2k())
+                .k(infos.k())
+                .data(data)
+                .build()
+                .unwrap(),
             scratch,
         )
     }
@@ -224,18 +138,14 @@ impl<B: Backend> TakeGLWECtSlice for Scratch<B>
 where
     Scratch<B>: TakeVecZnx<B>,
 {
-    fn take_glwe_ct_slice(
-        &mut self,
-        size: usize,
-        n: usize,
-        basek: usize,
-        k: usize,
-        rank: usize,
-    ) -> (Vec<GLWECiphertext<&mut [u8]>>, &mut Self) {
+    fn take_glwe_ct_slice<A>(&mut self, size: usize, infos: &A) -> (Vec<GLWECiphertext<&mut [u8]>>, &mut Self)
+    where
+        A: GLWEInfos,
+    {
         let mut scratch: &mut Scratch<B> = self;
         let mut cts: Vec<GLWECiphertext<&mut [u8]>> = Vec::with_capacity(size);
         for _ in 0..size {
-            let (ct, new_scratch) = scratch.take_glwe_ct(n, basek, k, rank);
+            let (ct, new_scratch) = scratch.take_glwe_ct(infos);
             scratch = new_scratch;
             cts.push(ct);
         }
@@ -247,27 +157,18 @@ impl<B: Backend> TakeGLWEPt for Scratch<B>
 where
     Scratch<B>: TakeVecZnx<B>,
 {
-    fn take_glwe_pt(&mut self, n: usize, basek: usize, k: usize) -> (GLWEPlaintext<&mut [u8]>, &mut Self) {
-        let (data, scratch) = self.take_vec_znx(n, 1, k.div_ceil(basek));
-        (GLWEPlaintext { data, basek, k }, scratch)
-    }
-}
-
-impl<'a, B, D> TakeLike<'a, B, GLWEPlaintext<D>> for Scratch<B>
-where
-    B: Backend + TakeVecZnxImpl<B>,
-    D: DataRef,
-{
-    type Output = GLWEPlaintext<&'a mut [u8]>;
-
-    fn take_like(&'a mut self, template: &GLWEPlaintext<D>) -> (Self::Output, &'a mut Self) {
-        let (data, scratch) = B::take_vec_znx_impl(self, template.n(), template.cols(), template.size());
+    fn take_glwe_pt<A>(&mut self, infos: &A) -> (GLWEPlaintext<&mut [u8]>, &mut Self)
+    where
+        A: GLWEInfos,
+    {
+        let (data, scratch) = self.take_vec_znx(infos.n().into(), 1, infos.size());
         (
-            GLWEPlaintext {
-                data,
-                basek: template.basek(),
-                k: template.k(),
-            },
+            GLWEPlaintext::builder()
+                .base2k(infos.base2k())
+                .k(infos.k())
+                .data(data)
+                .build()
+                .unwrap(),
             scratch,
         )
     }
@@ -277,58 +178,25 @@ impl<B: Backend> TakeGGLWE for Scratch<B>
 where
     Scratch<B>: TakeMatZnx<B>,
 {
-    fn take_gglwe(
-        &mut self,
-        n: usize,
-        basek: usize,
-        k: usize,
-        rows: usize,
-        digits: usize,
-        rank_in: usize,
-        rank_out: usize,
-    ) -> (GGLWECiphertext<&mut [u8]>, &mut Self) {
+    fn take_gglwe<A>(&mut self, infos: &A) -> (GGLWECiphertext<&mut [u8]>, &mut Self)
+    where
+        A: GGLWELayoutInfos,
+    {
         let (data, scratch) = self.take_mat_znx(
-            n,
-            rows.div_ceil(digits),
-            rank_in,
-            rank_out + 1,
-            k.div_ceil(basek),
+            infos.n().into(),
+            infos.rows().0.div_ceil(infos.digits().0) as usize,
+            infos.rank_in().into(),
+            (infos.rank_out() + 1).into(),
+            infos.size(),
         );
         (
-            GGLWECiphertext {
-                data,
-                basek,
-                k,
-                digits,
-            },
-            scratch,
-        )
-    }
-}
-
-impl<'a, B, D> TakeLike<'a, B, GGLWECiphertext<D>> for Scratch<B>
-where
-    B: Backend + TakeMatZnxImpl<B>,
-    D: DataRef,
-{
-    type Output = GGLWECiphertext<&'a mut [u8]>;
-
-    fn take_like(&'a mut self, template: &GGLWECiphertext<D>) -> (Self::Output, &'a mut Self) {
-        let (data, scratch) = B::take_mat_znx_impl(
-            self,
-            template.n(),
-            template.rows(),
-            template.data.cols_in(),
-            template.data.cols_out(),
-            template.size(),
-        );
-        (
-            GGLWECiphertext {
-                data,
-                basek: template.basek(),
-                k: template.k(),
-                digits: template.digits(),
-            },
+            GGLWECiphertext::builder()
+                .base2k(infos.base2k())
+                .k(infos.k())
+                .digits(infos.digits())
+                .data(data)
+                .build()
+                .unwrap(),
             scratch,
         )
     }
@@ -338,58 +206,25 @@ impl<B: Backend> TakeGGLWEPrepared<B> for Scratch<B>
 where
     Scratch<B>: TakeVmpPMat<B>,
 {
-    fn take_gglwe_prepared(
-        &mut self,
-        n: usize,
-        basek: usize,
-        k: usize,
-        rows: usize,
-        digits: usize,
-        rank_in: usize,
-        rank_out: usize,
-    ) -> (GGLWECiphertextPrepared<&mut [u8], B>, &mut Self) {
+    fn take_gglwe_prepared<A>(&mut self, infos: &A) -> (GGLWECiphertextPrepared<&mut [u8], B>, &mut Self)
+    where
+        A: GGLWELayoutInfos,
+    {
         let (data, scratch) = self.take_vmp_pmat(
-            n,
-            rows.div_ceil(digits),
-            rank_in,
-            rank_out + 1,
-            k.div_ceil(basek),
+            infos.n().into(),
+            infos.rows().into(),
+            infos.rank_in().into(),
+            (infos.rank_out() + 1).into(),
+            infos.size(),
        );
         (
-            GGLWECiphertextPrepared {
-                data,
-                basek,
-                k,
-                digits,
-            },
-            scratch,
-        )
-    }
-}
-
-impl<'a, B, D> TakeLike<'a, B, GGLWECiphertextPrepared<D, B>> for Scratch<B>
-where
-    B: Backend + TakeVmpPMatImpl<B>,
-    D: DataRef,
-{
-    type Output = GGLWECiphertextPrepared<&'a mut [u8], B>;
-
-    fn take_like(&'a mut self, template: &GGLWECiphertextPrepared<D, B>) -> (Self::Output, &'a mut Self) {
-        let (data, scratch) = B::take_vmp_pmat_impl(
-            self,
-            template.n(),
-            template.rows(),
-            template.data.cols_in(),
-            template.data.cols_out(),
-            template.size(),
-        );
-        (
-            GGLWECiphertextPrepared {
-                data,
-                basek: template.basek(),
-                k: template.k(),
-                digits: template.digits(),
-            },
+            GGLWECiphertextPrepared::builder()
+                .base2k(infos.base2k())
+                .digits(infos.digits())
+                .k(infos.k())
+                .data(data)
+                .build()
+                .unwrap(),
             scratch,
         )
     }
@@ -399,57 +234,25 @@ impl<B: Backend> TakeGGSW for Scratch<B>
 where
     Scratch<B>: TakeMatZnx<B>,
 {
-    fn take_ggsw(
-        &mut self,
-        n: usize,
-        basek: usize,
-        k: usize,
-        rows: usize,
-        digits: usize,
-        rank: usize,
-    ) -> (GGSWCiphertext<&mut [u8]>, &mut Self) {
+    fn take_ggsw<A>(&mut self, infos: &A) -> (GGSWCiphertext<&mut [u8]>, &mut Self)
+    where
+        A: GGSWInfos,
+    {
        let (data, scratch) = self.take_mat_znx(
-            n,
-            rows.div_ceil(digits),
-            rank + 1,
-            rank + 1,
-            k.div_ceil(basek),
+            infos.n().into(),
+            infos.rows().into(),
+            (infos.rank() + 1).into(),
+            (infos.rank() + 1).into(),
+            infos.size(),
         );
         (
-            GGSWCiphertext {
-                data,
-                basek,
-                k,
-                digits,
-            },
-            scratch,
-        )
-    }
-}
-
-impl<'a, B, D> TakeLike<'a, B, GGSWCiphertext<D>> for Scratch<B>
-where
-    B: Backend + TakeMatZnxImpl<B>,
-    D: DataRef,
-{
-    type Output = GGSWCiphertext<&'a mut [u8]>;
-
-    fn take_like(&'a mut self, template: &GGSWCiphertext<D>) -> (Self::Output, &'a mut Self) {
-        let (data, scratch) = B::take_mat_znx_impl(
-            self,
-            template.n(),
-            template.rows(),
-            template.data.cols_in(),
-            template.data.cols_out(),
-            template.size(),
-        );
-        (
-            GGSWCiphertext {
-                data,
-                basek: template.basek(),
-                k: template.k(),
-                digits: template.digits(),
-            },
+            GGSWCiphertext::builder()
+                .base2k(infos.base2k())
+                .digits(infos.digits())
+                .k(infos.k())
+                .data(data)
+                .build()
+                .unwrap(),
             scratch,
         )
     }
@@ -459,57 +262,25 @@ impl<B: Backend> TakeGGSWPrepared<B> for Scratch<B>
 where
     Scratch<B>: TakeVmpPMat<B>,
 {
-    fn take_ggsw_prepared(
-        &mut self,
-        n: usize,
-        basek: usize,
-        k: usize,
-        rows: usize,
-        digits: usize,
-        rank: usize,
-    ) -> (GGSWCiphertextPrepared<&mut [u8], B>, &mut Self) {
+    fn take_ggsw_prepared<A>(&mut self, infos: &A) -> (GGSWCiphertextPrepared<&mut [u8], B>, &mut Self)
+    where
+        A: GGSWInfos,
+    {
         let (data, scratch) = self.take_vmp_pmat(
-            n,
-            rows.div_ceil(digits),
-            rank + 1,
-            rank + 1,
-            k.div_ceil(basek),
+            infos.n().into(),
+            infos.rows().into(),
+            (infos.rank() + 1).into(),
+            (infos.rank() + 1).into(),
+            infos.size(),
         );
         (
-            GGSWCiphertextPrepared {
-                data,
-                basek,
-                k,
-                digits,
-            },
-            scratch,
-        )
-    }
-}
-
-impl<'a, B, D> TakeLike<'a, B, GGSWCiphertextPrepared<D, B>> for Scratch<B>
-where
-    B: Backend + TakeVmpPMatImpl<B>,
-    D: DataRef,
-{
-    type Output = GGSWCiphertextPrepared<&'a mut [u8], B>;
-
-    fn take_like(&'a mut self, template: &GGSWCiphertextPrepared<D, B>) -> (Self::Output, &'a mut Self) {
-        let (data, scratch) = B::take_vmp_pmat_impl(
-            self,
-            template.n(),
-            template.rows(),
-            template.data.cols_in(),
-            template.data.cols_out(),
-            template.size(),
-        );
-        (
-            GGSWCiphertextPrepared {
-                data,
-                basek: template.basek(),
-                k: template.k(),
-                digits: template.digits(),
-            },
+            GGSWCiphertextPrepared::builder()
+                .base2k(infos.base2k())
+                .digits(infos.digits())
+                .k(infos.k())
+                .data(data)
+                .build()
+                .unwrap(),
             scratch,
         )
     }
@@ -519,36 +290,19 @@ impl<B: Backend> TakeGLWEPk for Scratch<B>
 where
     Scratch<B>: TakeVecZnx<B>,
 {
-    fn take_glwe_pk(&mut self, n: usize, basek: usize, k: usize, rank: usize) -> (GLWEPublicKey<&mut [u8]>, &mut Self) {
-        let (data, scratch) = self.take_vec_znx(n, rank + 1, k.div_ceil(basek));
+    fn take_glwe_pk<A>(&mut self, infos: &A) -> (GLWEPublicKey<&mut [u8]>, &mut Self)
+    where
+        A: GLWEInfos,
+    {
+        let (data, scratch) = self.take_vec_znx(infos.n().into(), (infos.rank() + 1).into(), infos.size());
         (
-            GLWEPublicKey {
-                data,
-                k,
-                basek,
-                dist: Distribution::NONE,
-            },
-            scratch,
-        )
-    }
-}
-
-impl<'a, B, D> TakeLike<'a, B, GLWEPublicKey<D>> for Scratch<B>
-where
-    B: Backend + TakeVecZnxImpl<B>,
-    D: DataRef,
-{
-    type Output = GLWEPublicKey<&'a mut [u8]>;
-
-    fn take_like(&'a mut self, template: &GLWEPublicKey<D>) -> (Self::Output, &'a mut Self) {
-        let (data, scratch) = B::take_vec_znx_impl(self, template.n(), template.cols(), template.size());
-        (
-            GLWEPublicKey {
-                data,
-                basek: template.basek(),
-                k: template.k(),
-                dist: template.dist,
-            },
+            GLWEPublicKey::builder()
+                .base2k(infos.base2k())
+                .k(infos.k())
+                .data(data)
+                .build()
+                .unwrap(),
             scratch,
         )
     }
@@ -558,42 +312,18 @@ impl<B: Backend> TakeGLWEPkPrepared<B> for Scratch<B>
 where
     Scratch<B>: TakeVecZnxDft<B>,
 {
-    fn take_glwe_pk_prepared(
-        &mut self,
-        n: usize,
-        basek: usize,
-        k: usize,
-        rank: usize,
-    ) -> (GLWEPublicKeyPrepared<&mut [u8], B>, &mut Self) {
-        let (data, scratch) = self.take_vec_znx_dft(n, rank + 1, k.div_ceil(basek));
+    fn take_glwe_pk_prepared<A>(&mut self, infos: &A) -> (GLWEPublicKeyPrepared<&mut [u8], B>, &mut Self)
+    where
+        A: GLWEInfos,
+    {
+        let (data, scratch) = self.take_vec_znx_dft(infos.n().into(), (infos.rank() + 1).into(), infos.size());
         (
-            GLWEPublicKeyPrepared {
-                data,
-                k,
-                basek,
-                dist: Distribution::NONE,
-            },
-            scratch,
-        )
-    }
-}
-
-impl<'a, B, D> TakeLike<'a, B, GLWEPublicKeyPrepared<D, B>> for Scratch<B>
-where
-    B: Backend + TakeVecZnxDftImpl<B>,
-    D: DataRef,
-{
-    type Output = GLWEPublicKeyPrepared<&'a mut [u8], B>;
-
-    fn take_like(&'a mut self, template: &GLWEPublicKeyPrepared<D, B>) -> (Self::Output, &'a mut Self) {
-        let (data, scratch) = B::take_vec_znx_dft_impl(self, template.n(), template.cols(), template.size());
-        (
-            GLWEPublicKeyPrepared {
-                data,
-                basek: template.basek(),
-                k: template.k(),
-                dist: template.dist,
-            },
+            GLWEPublicKeyPrepared::builder()
+                .base2k(infos.base2k())
+                .k(infos.k())
+                .data(data)
+                .build()
+                .unwrap(),
             scratch,
         )
     }
@@ -603,8 +333,8 @@ impl<B: Backend> TakeGLWESecret for Scratch<B>
 where
     Scratch<B>: TakeScalarZnx<B>,
 {
-    fn take_glwe_secret(&mut self, n: usize, rank: usize) -> (GLWESecret<&mut [u8]>, &mut Self) {
-        let (data, scratch) = self.take_scalar_znx(n, rank);
+    fn take_glwe_secret(&mut self, n: Degree, rank: Rank) -> (GLWESecret<&mut [u8]>, &mut Self) {
+        let (data, scratch) = self.take_scalar_znx(n.into(), rank.into());
         (
             GLWESecret {
                 data,
@@ -615,31 +345,12 @@ where
     }
 }
 
-impl<'a, B, D> TakeLike<'a, B, GLWESecret<D>> for Scratch<B>
-where
-    B: Backend + TakeScalarZnxImpl<B>,
-    D: DataRef,
-{
-    type Output = GLWESecret<&'a mut [u8]>;
-
-    fn take_like(&'a mut self, template: &GLWESecret<D>) -> (Self::Output, &'a mut Self) {
-        let (data, scratch) = B::take_scalar_znx_impl(self, template.n(), template.rank());
-        (
-            GLWESecret {
-                data,
-                dist: template.dist,
-            },
-            scratch,
-        )
-    }
-}
-
 impl<B: Backend> TakeGLWESecretPrepared<B> for Scratch<B>
 where
     Scratch<B>: TakeSvpPPol<B>,
 {
-    fn take_glwe_secret_prepared(&mut self, n: usize, rank: usize) -> (GLWESecretPrepared<&mut [u8], B>, &mut Self) {
-        let (data, scratch) = self.take_svp_ppol(n, rank);
+    fn take_glwe_secret_prepared(&mut self, n: Degree, rank: Rank) -> (GLWESecretPrepared<&mut [u8], B>, &mut Self) {
+        let (data, scratch) = self.take_svp_ppol(n.into(), rank.into());
         (
             GLWESecretPrepared {
                 data,
@@ -650,40 +361,15 @@ where
     }
 }
 
-impl<'a, B, D> TakeLike<'a, B, GLWESecretPrepared<D, B>> for Scratch<B>
-where
-    B: Backend + TakeSvpPPolImpl<B>,
-    D: DataRef,
-{
-    type Output = GLWESecretPrepared<&'a mut [u8], B>;
-
-    fn take_like(&'a mut self, template: &GLWESecretPrepared<D, B>) -> (Self::Output, &'a mut Self) {
-        let (data, scratch) = B::take_svp_ppol_impl(self, template.n(), template.rank());
-        (
-            GLWESecretPrepared {
-                data,
-                dist: template.dist,
-            },
-            scratch,
-        )
-    }
-}
-
 impl<B: Backend> TakeGLWESwitchingKey for Scratch<B>
 where
     Scratch<B>: TakeMatZnx<B>,
 {
-    fn take_glwe_switching_key(
-        &mut self,
-        n: usize,
-        basek: usize,
-        k: usize,
-        rows: usize,
-        digits: usize,
-        rank_in: usize,
-        rank_out: usize,
-    ) -> (GGLWESwitchingKey<&mut [u8]>, &mut Self) {
-        let (data, scratch) = self.take_gglwe(n, basek, k, rows, digits, rank_in, rank_out);
+    fn take_glwe_switching_key<A>(&mut self, infos: &A) -> (GGLWESwitchingKey<&mut [u8]>, &mut Self)
+    where
+        A: GGLWELayoutInfos,
+    {
+        let (data, scratch) = self.take_gglwe(infos);
         (
             GGLWESwitchingKey {
                 key: data,
@@ -695,42 +381,15 @@ where
     }
 }
 
-impl<'a, B, D> TakeLike<'a, B, GGLWESwitchingKey<D>> for Scratch<B>
-where
-    Scratch<B>: TakeLike<'a, B, GGLWECiphertext<D>, Output = GGLWECiphertext<&'a mut [u8]>>,
-    B: Backend + TakeMatZnxImpl<B>,
-    D: DataRef,
-{
-    type Output = GGLWESwitchingKey<&'a mut [u8]>;
-
-    fn take_like(&'a mut self, template: &GGLWESwitchingKey<D>) -> (Self::Output, &'a mut Self) {
-        let (key, scratch) = self.take_like(&template.key);
-        (
-            GGLWESwitchingKey {
-                key,
-                sk_in_n: template.sk_in_n,
-                sk_out_n: template.sk_out_n,
-            },
-            scratch,
-        )
-    }
-}
-
-impl<B: Backend> TakeGLWESwitchingKeyPrepared<B> for Scratch<B>
+impl<B: Backend> TakeGGLWESwitchingKeyPrepared<B> for Scratch<B>
 where
     Scratch<B>: TakeGGLWEPrepared<B>,
 {
-    fn take_glwe_switching_key_prepared(
-        &mut self,
-        n: usize,
-        basek: usize,
-        k: usize,
-        rows: usize,
-        digits: usize,
-        rank_in: usize,
-        rank_out: usize,
-    ) -> (GGLWESwitchingKeyPrepared<&mut [u8], B>, &mut Self) {
-        let (data, scratch) = self.take_gglwe_prepared(n, basek, k, rows, digits, rank_in, rank_out);
+    fn take_gglwe_switching_key_prepared<A>(&mut self, infos: &A) -> (GGLWESwitchingKeyPrepared<&mut [u8], B>, &mut Self)
+    where
+        A: GGLWELayoutInfos,
+    {
+        let (data, scratch) = self.take_gglwe_prepared(infos);
         (
             GGLWESwitchingKeyPrepared {
                 key: data,
@@ -742,116 +401,60 @@ where
     }
 }
 
-impl<'a, B, D> TakeLike<'a, B, GGLWESwitchingKeyPrepared<D, B>> for Scratch<B>
-where
-    Scratch<B>: TakeLike<'a, B, GGLWECiphertextPrepared<D, B>, Output = GGLWECiphertextPrepared<&'a mut [u8], B>>,
-    B: Backend + TakeMatZnxImpl<B>,
-    D: DataRef,
-{
-    type Output = GGLWESwitchingKeyPrepared<&'a mut [u8], B>;
-
-    fn take_like(&'a mut self, template: &GGLWESwitchingKeyPrepared<D, B>) -> (Self::Output, &'a mut Self) {
-        let (key, scratch) = self.take_like(&template.key);
-        (
-            GGLWESwitchingKeyPrepared {
-                key,
-                sk_in_n: template.sk_in_n,
-                sk_out_n: template.sk_out_n,
-            },
-            scratch,
-        )
-    }
-}
-
-impl<B: Backend> TakeAutomorphismKey for Scratch<B>
+impl<B: Backend> TakeGGLWEAutomorphismKey for Scratch<B>
 where
     Scratch<B>: TakeMatZnx<B>,
 {
-    fn take_automorphism_key(
-        &mut self,
-        n: usize,
-        basek: usize,
-        k: usize,
-        rows: usize,
-        digits: usize,
-        rank: usize,
-    ) -> (GGLWEAutomorphismKey<&mut [u8]>, &mut Self) {
-        let (data, scratch) = self.take_glwe_switching_key(n, basek, k, rows, digits, rank, rank);
+    fn take_gglwe_automorphism_key<A>(&mut self, infos: &A) -> (GGLWEAutomorphismKey<&mut [u8]>, &mut Self)
+    where
+        A: GGLWELayoutInfos,
+    {
+        let (data, scratch) = self.take_glwe_switching_key(infos);
         (GGLWEAutomorphismKey { key: data, p: 0 }, scratch)
     }
 }
 
-impl<'a, B, D> TakeLike<'a, B, GGLWEAutomorphismKey<D>> for Scratch<B>
+impl<B: Backend> TakeGGLWEAutomorphismKeyPrepared<B> for Scratch<B>
 where
-    Scratch<B>: TakeLike<'a, B, GGLWESwitchingKey<D>, Output = GGLWESwitchingKey<&'a mut [u8]>>,
-    B: Backend + TakeMatZnxImpl<B>,
-    D: DataRef,
+    Scratch<B>: TakeGGLWESwitchingKeyPrepared<B>,
 {
-    type Output = GGLWEAutomorphismKey<&'a mut [u8]>;
-
-    fn take_like(&'a mut self, template: &GGLWEAutomorphismKey<D>) -> (Self::Output, &'a mut Self) {
-        let (key, scratch) = self.take_like(&template.key);
-        (GGLWEAutomorphismKey { key, p: template.p }, scratch)
-    }
-}
-
-impl<B: Backend> TakeAutomorphismKeyPrepared<B> for Scratch<B>
-where
-    Scratch<B>: TakeGLWESwitchingKeyPrepared<B>,
-{
-    fn take_automorphism_key_prepared(
-        &mut self,
-        n: usize,
-        basek: usize,
-        k: usize,
-        rows: usize,
-        digits: usize,
-        rank: usize,
-    ) -> (GGLWEAutomorphismKeyPrepared<&mut [u8], B>, &mut Self) {
-        let (data, scratch) = self.take_glwe_switching_key_prepared(n, basek, k, rows, digits, rank, rank);
+    fn take_gglwe_automorphism_key_prepared<A>(&mut self, infos: &A) -> (GGLWEAutomorphismKeyPrepared<&mut [u8], B>, &mut Self)
+    where
+        A: GGLWELayoutInfos,
+    {
+        let (data, scratch) = self.take_gglwe_switching_key_prepared(infos);
         (GGLWEAutomorphismKeyPrepared { key: data, p: 0 }, scratch)
     }
 }
 
-impl<'a, B, D> TakeLike<'a, B, GGLWEAutomorphismKeyPrepared<D, B>> for Scratch<B>
-where
-    Scratch<B>: TakeLike<'a, B, GGLWESwitchingKeyPrepared<D, B>, Output = GGLWESwitchingKeyPrepared<&'a mut [u8], B>>,
-    B: Backend + TakeMatZnxImpl<B>,
-    D: DataRef,
-{
-    type Output = GGLWEAutomorphismKeyPrepared<&'a mut [u8], B>;
-
-    fn take_like(&'a mut self, template: &GGLWEAutomorphismKeyPrepared<D, B>) -> (Self::Output, &'a mut Self) {
-        let (key, scratch) = self.take_like(&template.key);
-        (GGLWEAutomorphismKeyPrepared { key, p: template.p }, scratch)
-    }
-}
-
 impl<B: Backend> TakeTensorKey for Scratch<B>
 where
     Scratch<B>: TakeMatZnx<B>,
 {
-    fn take_tensor_key(
-        &mut self,
-        n: usize,
-        basek: usize,
-        k: usize,
-        rows: usize,
-        digits: usize,
-        rank: usize,
-    ) -> (GGLWETensorKey<&mut [u8]>, &mut Self) {
+    fn take_tensor_key<A>(&mut self, infos: &A) -> (GGLWETensorKey<&mut [u8]>, &mut Self)
+    where
+        A: GGLWELayoutInfos,
+    {
+        assert_eq!(
+            infos.rank_in(),
+            infos.rank_out(),
+            "rank_in != rank_out is not supported for GGLWETensorKey"
+        );
         let mut keys: Vec<GGLWESwitchingKey<&mut [u8]>> = Vec::new();
-        let pairs: usize = (((rank + 1) * rank) >> 1).max(1);
+        let pairs: usize = (((infos.rank_out().0 + 1) * infos.rank_out().0) >> 1).max(1) as usize;
 
         let mut scratch: &mut Scratch<B> = self;
 
+        let mut ksk_infos: crate::layouts::GGLWECiphertextLayout = infos.layout();
+        ksk_infos.rank_in = Rank(1);
+
         if pairs != 0 {
-            let (gglwe, s) = scratch.take_glwe_switching_key(n, basek, k, rows, digits, 1, rank);
+            let (gglwe, s) = scratch.take_glwe_switching_key(&ksk_infos);
             scratch = s;
             keys.push(gglwe);
         }
         for _ in 1..pairs {
-            let (gglwe, s) = scratch.take_glwe_switching_key(n, basek, k, rows, digits, 1, rank);
+            let (gglwe, s) = scratch.take_glwe_switching_key(&ksk_infos);
             scratch = s;
             keys.push(gglwe);
         }
@@ -859,92 +462,38 @@ where
     }
 }
 
-impl<'a, B, D> TakeLike<'a, B, GGLWETensorKey<D>> for Scratch<B>
-where
-    Scratch<B>: TakeLike<'a, B, GGLWESwitchingKey<D>, Output = GGLWESwitchingKey<&'a mut [u8]>>,
-    B: Backend + TakeMatZnxImpl<B>,
-    D: DataRef,
-{
-    type Output = GGLWETensorKey<&'a mut [u8]>;
-
-    fn take_like(&'a mut self, template: &GGLWETensorKey<D>) -> (Self::Output, &'a mut Self) {
-        let mut keys: Vec<GGLWESwitchingKey<&'a mut [u8]>> = Vec::new();
-        let pairs: usize = template.keys.len();
-
-        let mut scratch: &mut Scratch<B> = self;
-
-        if pairs != 0 {
-            let (gglwe, s) = scratch.take_like(template.at(0, 0));
-            scratch = s;
-            keys.push(gglwe);
-        }
-        for _ in 1..pairs {
-            let (gglwe, s) = scratch.take_like(template.at(0, 0));
-            scratch = s;
-            keys.push(gglwe);
-        }
-
-        (GGLWETensorKey { keys }, scratch)
-    }
-}
-
-impl<B: Backend> TakeTensorKeyPrepared<B> for Scratch<B>
+impl<B: Backend> TakeGGLWETensorKeyPrepared<B> for Scratch<B>
 where
     Scratch<B>: TakeVmpPMat<B>,
 {
-    fn take_tensor_key_prepared(
-        &mut self,
-        n: usize,
-        basek: usize,
-        k: usize,
-        rows: usize,
-        digits: usize,
-        rank: usize,
-    ) -> (GGLWETensorKeyPrepared<&mut [u8], B>, &mut Self) {
+    fn take_gglwe_tensor_key_prepared<A>(&mut self, infos: &A) -> (GGLWETensorKeyPrepared<&mut [u8], B>, &mut Self)
+    where
+        A: GGLWELayoutInfos,
+    {
+        assert_eq!(
+            infos.rank_in(),
+            infos.rank_out(),
+            "rank_in != rank_out is not supported for GGLWETensorKeyPrepared"
+        );
+
         let mut keys: Vec<GGLWESwitchingKeyPrepared<&mut [u8], B>> = Vec::new();
-        let pairs: usize = (((rank + 1) * rank) >> 1).max(1);
+        let pairs: usize = (((infos.rank_out().0 + 1) * infos.rank_out().0) >> 1).max(1) as usize;
 
         let mut scratch: &mut Scratch<B> = self;
 
+        let mut ksk_infos: crate::layouts::GGLWECiphertextLayout = infos.layout();
+        ksk_infos.rank_in = Rank(1);
+
         if pairs != 0 {
-            let (gglwe, s) = scratch.take_glwe_switching_key_prepared(n, basek, k, rows, digits, 1, rank);
+            let (gglwe, s) = scratch.take_gglwe_switching_key_prepared(&ksk_infos);
             scratch = s;
             keys.push(gglwe);
         }
         for _ in 1..pairs {
-            let (gglwe, s) = scratch.take_glwe_switching_key_prepared(n, basek, k, rows, digits, 1, rank);
+            let (gglwe, s) = scratch.take_gglwe_switching_key_prepared(&ksk_infos);
            scratch = s;
             keys.push(gglwe);
         }
 
         (GGLWETensorKeyPrepared { keys }, scratch)
     }
 }
-
-impl<'a, B, D> TakeLike<'a, B, GGLWETensorKeyPrepared<D, B>> for Scratch<B>
-where
-    Scratch<B>: TakeLike<'a, B, GGLWESwitchingKeyPrepared<D, B>, Output = GGLWESwitchingKeyPrepared<&'a mut [u8], B>>,
-    B: Backend + TakeMatZnxImpl<B>,
-    D: DataRef,
-{
-    type Output = GGLWETensorKeyPrepared<&'a mut [u8], B>;
-
-    fn take_like(&'a mut self, template: &GGLWETensorKeyPrepared<D, B>) -> (Self::Output, &'a mut Self) {
-        let mut keys: Vec<GGLWESwitchingKeyPrepared<&'a mut [u8], B>> = Vec::new();
-        let pairs: usize = template.keys.len();
-
-        let mut scratch: &mut Scratch<B> = self;
-
-        if pairs != 0 {
-            let (gglwe, s) = scratch.take_like(template.at(0, 0));
-            scratch = s;
-            keys.push(gglwe);
-        }
-        for _ in 1..pairs {
-            let (gglwe, s) = scratch.take_like(template.at(0, 0));
-            scratch = s;
-            keys.push(gglwe);
-        }
-
-        (GGLWETensorKeyPrepared { keys }, scratch)
-    }
-}
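
Review note: scratch.rs is the largest win for the infos-based API — every `take_*` helper loses its six-argument signature, and the whole `TakeLike` machinery becomes redundant because a template object now is the layout. A self-contained sketch of the shape of such a take (the byte math and `Scratch` internals here are illustrative only, not poulpy's aligned-arena implementation):

    // Minimal layout carrier for the sketch.
    struct GlweLayout {
        n: usize,
        base2k: usize,
        k: usize,
        rank: usize,
    }

    impl GlweLayout {
        // Limbs needed to reach k bits of precision in base 2^base2k.
        fn size(&self) -> usize {
            self.k.div_ceil(self.base2k)
        }
    }

    struct Scratch {
        buf: Vec<u8>,
        offset: usize,
    }

    impl Scratch {
        fn take_bytes(&mut self, len: usize) -> &mut [u8] {
            let start = self.offset;
            self.offset += len;
            &mut self.buf[start..start + len]
        }

        // One layout argument replaces (n, basek, k, rank): the quantities
        // can no longer drift out of sync at the call site.
        fn take_glwe_ct(&mut self, infos: &GlweLayout) -> &mut [u8] {
            let words = infos.n * (infos.rank + 1) * infos.size();
            self.take_bytes(words * core::mem::size_of::<i64>())
        }
    }
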
diff --git a/poulpy-core/src/tests/serialization.rs b/poulpy-core/src/tests/serialization.rs
index afc4c5a..fc0ee7d 100644
--- a/poulpy-core/src/tests/serialization.rs
+++ b/poulpy-core/src/tests/serialization.rs
@@ -1,8 +1,8 @@
 use poulpy_hal::test_suite::serialization::test_reader_writer_interface;
 
 use crate::layouts::{
-    GGLWEAutomorphismKey, GGLWECiphertext, GGLWESwitchingKey, GGLWETensorKey, GGSWCiphertext, GLWECiphertext,
-    GLWEToLWESwitchingKey, LWECiphertext, LWESwitchingKey, LWEToGLWESwitchingKey,
+    Base2K, Degree, Digits, GGLWEAutomorphismKey, GGLWECiphertext, GGLWESwitchingKey, GGLWETensorKey, GGSWCiphertext,
+    GLWECiphertext, GLWEToLWESwitchingKey, LWECiphertext, LWESwitchingKey, LWEToGLWESwitchingKey, Rank, Rows, TorusPrecision,
     compressed::{
         GGLWEAutomorphismKeyCompressed, GGLWECiphertextCompressed, GGLWESwitchingKeyCompressed, GGLWETensorKeyCompressed,
         GGSWCiphertextCompressed, GLWECiphertextCompressed, GLWEToLWESwitchingKeyCompressed, LWECiphertextCompressed,
@@ -10,130 +10,135 @@ use crate::layouts::{
     },
 };
 
-const N_GLWE: usize = 64;
-const N_LWE: usize = 32;
-const BASEK: usize = 12;
-const K: usize = 33;
-const ROWS: usize = 2;
-const RANK: usize = 2;
-const DIGITS: usize = 1;
+const N_GLWE: Degree = Degree(64);
+const N_LWE: Degree = Degree(32);
+const BASE2K: Base2K = Base2K(12);
+const K: TorusPrecision = TorusPrecision(33);
+const ROWS: Rows = Rows(3);
+const RANK: Rank = Rank(2);
+const DIGITS: Digits = Digits(1);
 
 #[test]
 fn glwe_serialization() {
-    let original: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(N_GLWE, BASEK, K, RANK);
+    let original: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc_with(N_GLWE, BASE2K, K, RANK);
     poulpy_hal::test_suite::serialization::test_reader_writer_interface(original);
 }
 
 #[test]
 fn glwe_compressed_serialization() {
-    let original: GLWECiphertextCompressed<Vec<u8>> = GLWECiphertextCompressed::alloc(N_GLWE, BASEK, K, RANK);
+    let original: GLWECiphertextCompressed<Vec<u8>> = GLWECiphertextCompressed::alloc_with(N_GLWE, BASE2K, K, RANK);
     test_reader_writer_interface(original);
 }
 
 #[test]
 fn lwe_serialization() {
-    let original: LWECiphertext<Vec<u8>> = LWECiphertext::alloc(N_LWE, BASEK, K);
+    let original: LWECiphertext<Vec<u8>> = LWECiphertext::alloc_with(N_LWE, BASE2K, K);
     test_reader_writer_interface(original);
 }
 
 #[test]
 fn lwe_compressed_serialization() {
-    let original: LWECiphertextCompressed<Vec<u8>> = LWECiphertextCompressed::alloc(BASEK, K);
+    let original: LWECiphertextCompressed<Vec<u8>> = LWECiphertextCompressed::alloc_with(BASE2K, K);
     test_reader_writer_interface(original);
 }
 
 #[test]
 fn test_gglwe_serialization() {
-    let original: GGLWECiphertext<Vec<u8>> = GGLWECiphertext::alloc(1024, 12, 54, 3, 1, 2, 2);
+    let original: GGLWECiphertext<Vec<u8>> = GGLWECiphertext::alloc_with(N_GLWE, BASE2K, K, ROWS, DIGITS, RANK, RANK);
     test_reader_writer_interface(original);
 }
 
 #[test]
 fn test_gglwe_compressed_serialization() {
-    let original: GGLWECiphertextCompressed<Vec<u8>> = GGLWECiphertextCompressed::alloc(1024, 12, 54, 3, 1, 2, 2);
+    let original: GGLWECiphertextCompressed<Vec<u8>> =
+        GGLWECiphertextCompressed::alloc_with(N_GLWE, BASE2K, K, ROWS, DIGITS, RANK, RANK);
     test_reader_writer_interface(original);
 }
 
 #[test]
 fn test_glwe_switching_key_serialization() {
-    let original: GGLWESwitchingKey<Vec<u8>> = GGLWESwitchingKey::alloc(1024, 12, 54, 3, 1, 2, 2);
+    let original: GGLWESwitchingKey<Vec<u8>> = GGLWESwitchingKey::alloc_with(N_GLWE, BASE2K, K, ROWS, DIGITS, RANK, RANK);
     test_reader_writer_interface(original);
 }
 
 #[test]
 fn test_glwe_switching_key_compressed_serialization() {
-    let original: GGLWESwitchingKeyCompressed<Vec<u8>> = GGLWESwitchingKeyCompressed::alloc(1024, 12, 54, 3, 1, 2, 2);
+    let original: GGLWESwitchingKeyCompressed<Vec<u8>> =
+        GGLWESwitchingKeyCompressed::alloc_with(N_GLWE, BASE2K, K, ROWS, DIGITS, RANK, RANK);
     test_reader_writer_interface(original);
 }
 
 #[test]
 fn test_automorphism_key_serialization() {
-    let original: GGLWEAutomorphismKey<Vec<u8>> = GGLWEAutomorphismKey::alloc(1024, 12, 54, 3, 1, 2);
+    let original: GGLWEAutomorphismKey<Vec<u8>> = GGLWEAutomorphismKey::alloc_with(N_GLWE, BASE2K, K, ROWS, DIGITS, RANK);
     test_reader_writer_interface(original);
 }
 
 #[test]
 fn test_automorphism_key_compressed_serialization() {
-    let original: GGLWEAutomorphismKeyCompressed<Vec<u8>> = GGLWEAutomorphismKeyCompressed::alloc(1024, 12, 54, 3, 1, 2);
+    let original: GGLWEAutomorphismKeyCompressed<Vec<u8>> =
+        GGLWEAutomorphismKeyCompressed::alloc_with(N_GLWE, BASE2K, K, ROWS, DIGITS, RANK);
     test_reader_writer_interface(original);
 }
 
 #[test]
 fn test_tensor_key_serialization() {
-    let original: GGLWETensorKey<Vec<u8>> = GGLWETensorKey::alloc(1024, 12, 54, 3, 1, 2);
+    let original: GGLWETensorKey<Vec<u8>> = GGLWETensorKey::alloc_with(N_GLWE, BASE2K, K, ROWS, DIGITS, RANK);
     test_reader_writer_interface(original);
 }
 
 #[test]
 fn test_tensor_key_compressed_serialization() {
-    let original: GGLWETensorKeyCompressed<Vec<u8>> = GGLWETensorKeyCompressed::alloc(1024, 12, 54, 3, 1, 2);
+    let original: GGLWETensorKeyCompressed<Vec<u8>> = GGLWETensorKeyCompressed::alloc_with(N_GLWE, BASE2K, K, ROWS, DIGITS, RANK);
     test_reader_writer_interface(original);
 }
 
 #[test]
 fn glwe_to_lwe_switching_key_serialization() {
-    let original: GLWEToLWESwitchingKey<Vec<u8>> = GLWEToLWESwitchingKey::alloc(N_GLWE, BASEK, K, ROWS, RANK);
+    let original: GLWEToLWESwitchingKey<Vec<u8>> = GLWEToLWESwitchingKey::alloc_with(N_GLWE, BASE2K, K, ROWS, RANK);
     test_reader_writer_interface(original);
 }
 
 #[test]
 fn glwe_to_lwe_switching_key_compressed_serialization() {
-    let original: GLWEToLWESwitchingKeyCompressed<Vec<u8>> = GLWEToLWESwitchingKeyCompressed::alloc(N_GLWE, BASEK, K, ROWS, RANK);
+    let original: GLWEToLWESwitchingKeyCompressed<Vec<u8>> =
+        GLWEToLWESwitchingKeyCompressed::alloc_with(N_GLWE, BASE2K, K, ROWS, RANK);
     test_reader_writer_interface(original);
 }
 
 #[test]
 fn lwe_to_glwe_switching_key_serialization() {
-    let original: LWEToGLWESwitchingKey<Vec<u8>> = LWEToGLWESwitchingKey::alloc(N_GLWE, BASEK, K, ROWS, RANK);
+    let original: LWEToGLWESwitchingKey<Vec<u8>> = LWEToGLWESwitchingKey::alloc_with(N_GLWE, BASE2K, K, ROWS, RANK);
     test_reader_writer_interface(original);
 }
 
 #[test]
 fn lwe_to_glwe_switching_key_compressed_serialization() {
-    let original: LWEToGLWESwitchingKeyCompressed<Vec<u8>> = LWEToGLWESwitchingKeyCompressed::alloc(N_GLWE, BASEK, K, ROWS, RANK);
+    let original: LWEToGLWESwitchingKeyCompressed<Vec<u8>> =
+        LWEToGLWESwitchingKeyCompressed::alloc_with(N_GLWE, BASE2K, K, ROWS, RANK);
     test_reader_writer_interface(original);
 }
 
 #[test]
 fn lwe_switching_key_serialization() {
-    let original: LWESwitchingKey<Vec<u8>> = LWESwitchingKey::alloc(N_GLWE, BASEK, K, ROWS);
+    let original: LWESwitchingKey<Vec<u8>> = LWESwitchingKey::alloc_with(N_GLWE, BASE2K, K, ROWS);
     test_reader_writer_interface(original);
 }
 
 #[test]
 fn lwe_switching_key_compressed_serialization() {
-    let original: LWESwitchingKeyCompressed<Vec<u8>> = LWESwitchingKeyCompressed::alloc(N_GLWE, BASEK, K, ROWS);
+    let original: LWESwitchingKeyCompressed<Vec<u8>> = LWESwitchingKeyCompressed::alloc_with(N_GLWE, BASE2K, K, ROWS);
     test_reader_writer_interface(original);
 }
 
 #[test]
 fn ggsw_serialization() {
-    let original: GGSWCiphertext<Vec<u8>> = GGSWCiphertext::alloc(N_GLWE, BASEK, K, ROWS, DIGITS, RANK);
+    let original: GGSWCiphertext<Vec<u8>> = GGSWCiphertext::alloc_with(N_GLWE, BASE2K, K, ROWS, DIGITS, RANK);
     test_reader_writer_interface(original);
 }
 
 #[test]
 fn ggsw_compressed_serialization() {
-    let original: GGSWCiphertextCompressed<Vec<u8>> = GGSWCiphertextCompressed::alloc(N_GLWE, BASEK, K, ROWS, DIGITS, RANK);
+    let original: GGSWCiphertextCompressed<Vec<u8>> = GGSWCiphertextCompressed::alloc_with(N_GLWE, BASE2K, K, ROWS, DIGITS, RANK);
     test_reader_writer_interface(original);
 }
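
Review note: the test constants above switch from bare `usize` to the unit newtypes (note that `ROWS` also changes value, from 2 to 3), and the automorphism-key test below fills layout structs with `n.into()`, `base2k.into()`, and so on. That ergonomics relies on `From<usize>` impls for the newtypes — a minimal sketch under that assumption, with illustrative definitions (the real inner types may differ):

    // Illustrative newtype + From<usize> impls behind `n.into()` etc. in the
    // layout literals below.
    #[derive(Clone, Copy, Debug)]
    pub struct Degree(pub usize);
    #[derive(Clone, Copy, Debug)]
    pub struct Base2K(pub usize);

    impl From<usize> for Degree {
        fn from(n: usize) -> Self {
            Degree(n)
        }
    }

    impl From<usize> for Base2K {
        fn from(b: usize) -> Self {
            Base2K(b)
        }
    }

    pub struct AutoKeyLayout {
        pub n: Degree,
        pub base2k: Base2K,
    }

    fn demo(n: usize, base2k: usize) -> AutoKeyLayout {
        AutoKeyLayout {
            n: n.into(),
            base2k: base2k.into(),
        }
    }

diff --git a/poulpy-core/src/tests/test_suite/automorphism/gglwe_atk.rs b/poulpy-core/src/tests/test_suite/automorphism/gglwe_atk.rs
index 415c353..0da77c4 100644
--- a/poulpy-core/src/tests/test_suite/automorphism/gglwe_atk.rs
+++ b/poulpy-core/src/tests/test_suite/automorphism/gglwe_atk.rs
@@ -4,7 +4,7 @@ use poulpy_hal::{
         VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace,
         VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
         VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize,
-        VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSubScalarInplace, VecZnxSwitchRing,
+        VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSubScalarInplace, VecZnxSwitchRing,
         VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare,
     },
     layouts::{Backend, Module, ScratchOwned},
@@ -18,7 +18,7 @@ use poulpy_hal::{
 use crate::{
     encryption::SIGMA,
     layouts::{
-        GGLWEAutomorphismKey, GLWEPlaintext, GLWESecret, Infos,
+        GGLWEAutomorphismKey, GGLWEAutomorphismKeyLayout, GGLWELayoutInfos, GLWEPlaintext, GLWESecret,
         prepared::{GGLWEAutomorphismKeyPrepared, GLWESecretPrepared, Prepare, PrepareAlloc},
     },
     noise::log2_std_noise_gglwe_product,
@@ -47,7 +47,7 @@ where
         + SvpApplyDftToDftInplace
         + VecZnxAddScalarInplace
         + VecZnxFillUniform
-        + VecZnxSubABInplace
+        + VecZnxSubInplace
         + VecZnxAddInplace
         + VecZnxNormalizeInplace
         + VecZnxAddNormal
@@ -67,40 +67,70 @@ where
         + TakeSvpPPolImpl<B>
        + TakeVecZnxBigImpl<B>,
 {
-    let basek: usize = 12;
+    let base2k: usize = 12;
     let k_in: usize = 60;
     let k_out: usize = 40;
-    let digits: usize = k_in.div_ceil(basek);
-    let p0 = -1;
-    let p1 = -5;
-    (1..3).for_each(|rank| {
-        (1..digits + 1).for_each(|di| {
-            let k_apply: usize = (digits + di) * basek;
+    let digits: usize = k_in.div_ceil(base2k);
+    let p0: i64 = -1;
+    let p1: i64 = -5;
+    for rank in 1_usize..3 {
+        for di in 1..digits + 1 {
+            let k_apply: usize = (digits + di) * base2k;
 
             let n: usize = module.n();
             let digits_in: usize = 1;
-            let rows_in: usize = k_in / (basek * di);
-            let rows_out: usize = k_out / (basek * di);
-            let rows_apply: usize = k_in.div_ceil(basek * di);
+            let rows_in: usize = k_in / (base2k * di);
+            let rows_out: usize = k_out / (base2k * di);
+            let rows_apply: usize = k_in.div_ceil(base2k * di);
 
-            let mut auto_key_in: GGLWEAutomorphismKey<Vec<u8>> =
-                GGLWEAutomorphismKey::alloc(n, basek, k_in, rows_in, digits_in, rank);
-            let mut auto_key_out: GGLWEAutomorphismKey<Vec<u8>> =
-                GGLWEAutomorphismKey::alloc(n, basek, k_out, rows_out, digits_in, rank);
-            let mut auto_key_apply: GGLWEAutomorphismKey<Vec<u8>> =
-                GGLWEAutomorphismKey::alloc(n, basek, k_apply, rows_apply, di, rank);
+            let auto_key_in_infos: GGLWEAutomorphismKeyLayout = GGLWEAutomorphismKeyLayout {
+                n: n.into(),
+                base2k: base2k.into(),
+                k: k_in.into(),
+                rows: rows_in.into(),
+                digits: digits_in.into(),
+                rank: rank.into(),
+            };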
GGLWEAutomorphismKeyLayout { + n: n.into(), + base2k: base2k.into(), + k: k_out.into(), + rows: rows_out.into(), + digits: digits_in.into(), + rank: rank.into(), + }; + + let auto_key_apply_infos: GGLWEAutomorphismKeyLayout = GGLWEAutomorphismKeyLayout { + n: n.into(), + base2k: base2k.into(), + k: k_apply.into(), + rows: rows_apply.into(), + digits: di.into(), + rank: rank.into(), + }; + + let mut auto_key_in: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&auto_key_in_infos); + let mut auto_key_out: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&auto_key_out_infos); + let mut auto_key_apply: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&auto_key_apply_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, basek, k_apply, rank) - | GGLWEAutomorphismKey::automorphism_scratch_space(module, basek, k_out, k_in, k_apply, di, rank), + GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, &auto_key_in_infos) + | GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, &auto_key_apply_infos) + | GGLWEAutomorphismKey::automorphism_scratch_space( + module, + &auto_key_out_infos, + &auto_key_in_infos, + &auto_key_apply_infos, + ), ); - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(&auto_key_in); sk.fill_ternary_prob(0.5, &mut source_xs); // gglwe_{s1}(s0) = s0 -> s1 @@ -124,7 +154,7 @@ where ); let mut auto_key_apply_prepared: GGLWEAutomorphismKeyPrepared, B> = - GGLWEAutomorphismKeyPrepared::alloc(module, basek, k_apply, rows_apply, di, rank); + GGLWEAutomorphismKeyPrepared::alloc(module, &auto_key_apply_infos); auto_key_apply_prepared.prepare(module, &auto_key_apply, scratch.borrow()); @@ -136,11 +166,11 @@ where scratch.borrow(), ); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_out); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&auto_key_out_infos); - let mut sk_auto: GLWESecret> = GLWESecret::alloc(n, rank); + let mut sk_auto: GLWESecret> = GLWESecret::alloc(&auto_key_out_infos); sk_auto.fill_zero(); // Necessary to avoid panic of unfilled sk - (0..rank).for_each(|i| { + for i in 0..rank { module.vec_znx_automorphism( module.galois_element_inv(p0 * p1), &mut sk_auto.data.as_vec_znx_mut(), @@ -148,12 +178,12 @@ where &sk.data.as_vec_znx(), i, ); - }); + } let sk_auto_dft: GLWESecretPrepared, B> = sk_auto.prepare_alloc(module, scratch.borrow()); - (0..auto_key_out.rank_in()).for_each(|col_i| { - (0..auto_key_out.rows()).for_each(|row_i| { + (0..auto_key_out.rank_in().into()).for_each(|col_i| { + (0..auto_key_out.rows().into()).for_each(|row_i| { auto_key_out .at(row_i, col_i) .decrypt(module, &mut pt, &sk_auto_dft, scratch.borrow()); @@ -166,10 +196,10 @@ where col_i, ); - let noise_have: f64 = pt.data.std(basek, 0).log2(); + let noise_have: f64 = pt.data.std(base2k, 0).log2(); let noise_want: f64 = log2_std_noise_gglwe_product( n as f64, - basek * di, + base2k * di, 0.5, 0.5, 0f64, @@ -182,14 +212,13 @@ where assert!( noise_have < noise_want + 0.5, - "{} {}", - noise_have, - noise_want + "{noise_have} {}", + noise_want + 0.5 ); }); }); - }); - }); + } + } } #[allow(clippy::too_many_arguments)] @@ -202,7 +231,7 @@ where + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + 
VecZnxAddNormal @@ -249,36 +278,53 @@ where + TakeSvpPPolImpl + TakeVecZnxBigImpl, { - let basek: usize = 12; + let base2k: usize = 12; let k_in: usize = 60; - let digits: usize = k_in.div_ceil(basek); + let digits: usize = k_in.div_ceil(base2k); let p0: i64 = -1; let p1: i64 = -5; - (1..3).for_each(|rank| { - (1..digits + 1).for_each(|di| { - let k_apply: usize = (digits + di) * basek; + for rank in 1_usize..3 { + for di in 1..digits + 1 { + let k_apply: usize = (digits + di) * base2k; let n: usize = module.n(); let digits_in: usize = 1; - let rows_in: usize = k_in / (basek * di); - let rows_apply: usize = k_in.div_ceil(basek * di); + let rows_in: usize = k_in / (base2k * di); + let rows_apply: usize = k_in.div_ceil(base2k * di); - let mut auto_key: GGLWEAutomorphismKey> = - GGLWEAutomorphismKey::alloc(n, basek, k_in, rows_in, digits_in, rank); - let mut auto_key_apply: GGLWEAutomorphismKey> = - GGLWEAutomorphismKey::alloc(n, basek, k_apply, rows_apply, di, rank); + let auto_key_layout: GGLWEAutomorphismKeyLayout = GGLWEAutomorphismKeyLayout { + n: n.into(), + base2k: base2k.into(), + k: k_in.into(), + rows: rows_in.into(), + digits: digits_in.into(), + rank: rank.into(), + }; + + let auto_key_apply_layout: GGLWEAutomorphismKeyLayout = GGLWEAutomorphismKeyLayout { + n: n.into(), + base2k: base2k.into(), + k: k_apply.into(), + rows: rows_apply.into(), + digits: di.into(), + rank: rank.into(), + }; + + let mut auto_key: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&auto_key_layout); + let mut auto_key_apply: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&auto_key_apply_layout); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, basek, k_apply, rank) - | GGLWEAutomorphismKey::automorphism_inplace_scratch_space(module, basek, k_in, k_apply, di, rank), + GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, &auto_key) + | GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, &auto_key_apply) + | GGLWEAutomorphismKey::automorphism_inplace_scratch_space(module, &auto_key, &auto_key_apply), ); - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(&auto_key); sk.fill_ternary_prob(0.5, &mut source_xs); // gglwe_{s1}(s0) = s0 -> s1 @@ -302,19 +348,19 @@ where ); let mut auto_key_apply_prepared: GGLWEAutomorphismKeyPrepared, B> = - GGLWEAutomorphismKeyPrepared::alloc(module, basek, k_apply, rows_apply, di, rank); + GGLWEAutomorphismKeyPrepared::alloc(module, &auto_key_apply_layout); auto_key_apply_prepared.prepare(module, &auto_key_apply, scratch.borrow()); // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) auto_key.automorphism_inplace(module, &auto_key_apply_prepared, scratch.borrow()); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_in); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&auto_key); - let mut sk_auto: GLWESecret> = GLWESecret::alloc(n, rank); + let mut sk_auto: GLWESecret> = GLWESecret::alloc(&auto_key); sk_auto.fill_zero(); // Necessary to avoid panic of unfilled sk - (0..rank).for_each(|i| { + for i in 0..rank { module.vec_znx_automorphism( module.galois_element_inv(p0 * p1), &mut sk_auto.data.as_vec_znx_mut(), @@ -322,12 +368,12 @@ where &sk.data.as_vec_znx(), i, ); - }); + } let sk_auto_dft: GLWESecretPrepared, B> = sk_auto.prepare_alloc(module, scratch.borrow()); - 
(0..auto_key.rank_in()).for_each(|col_i| { - (0..auto_key.rows()).for_each(|row_i| { + (0..auto_key.rank_in().into()).for_each(|col_i| { + (0..auto_key.rows().into()).for_each(|row_i| { auto_key .at(row_i, col_i) .decrypt(module, &mut pt, &sk_auto_dft, scratch.borrow()); @@ -339,10 +385,10 @@ where col_i, ); - let noise_have: f64 = pt.data.std(basek, 0).log2(); + let noise_have: f64 = pt.data.std(base2k, 0).log2(); let noise_want: f64 = log2_std_noise_gglwe_product( n as f64, - basek * di, + base2k * di, 0.5, 0.5, 0f64, @@ -355,12 +401,11 @@ where assert!( noise_have < noise_want + 0.5, - "{} {}", - noise_have, - noise_want + "{noise_have} {}", + noise_want + 0.5 ); }); }); - }); - }); + } + } } diff --git a/poulpy-core/src/tests/test_suite/automorphism/ggsw_ct.rs b/poulpy-core/src/tests/test_suite/automorphism/ggsw_ct.rs index a342cd3..90c53d6 100644 --- a/poulpy-core/src/tests/test_suite/automorphism/ggsw_ct.rs +++ b/poulpy-core/src/tests/test_suite/automorphism/ggsw_ct.rs @@ -5,7 +5,7 @@ use poulpy_hal::{ VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, + VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, }, layouts::{Backend, Module, ScalarZnx, ScratchOwned}, @@ -19,13 +19,12 @@ use poulpy_hal::{ use crate::{ encryption::SIGMA, layouts::{ - GGLWEAutomorphismKey, GGLWETensorKey, GGSWCiphertext, GLWESecret, + GGLWEAutomorphismKey, GGLWETensorKey, GGLWETensorKeyLayout, GGSWCiphertext, GGSWCiphertextLayout, GLWESecret, prepared::{GGLWEAutomorphismKeyPrepared, GGLWETensorKeyPrepared, GLWESecretPrepared, Prepare, PrepareAlloc}, }, noise::noise_ggsw_keyswitch, }; -#[allow(clippy::too_many_arguments)] pub fn test_ggsw_automorphism(module: &Module) where Module: VecZnxDftAllocBytes @@ -46,7 +45,7 @@ where + SvpPPolAlloc + VecZnxAddScalarInplace + VecZnxCopy - + VecZnxSubABInplace + + VecZnxSubInplace + VmpPMatAlloc + VmpPrepare + VmpApplyDftToDftTmpBytes @@ -76,26 +75,63 @@ where + VecZnxBigAllocBytesImpl + TakeSvpPPolImpl, { - let basek: usize = 12; + let base2k: usize = 12; let k_in: usize = 54; - let digits: usize = k_in.div_ceil(basek); + let digits: usize = k_in.div_ceil(base2k); let p: i64 = -5; - (1..3).for_each(|rank| { - (1..digits + 1).for_each(|di| { - let k_ksk: usize = k_in + basek * di; + + for rank in 1_usize..3 { + for di in 1..digits + 1 { + let k_ksk: usize = k_in + base2k * di; let k_tsk: usize = k_ksk; let k_out: usize = k_ksk; // Better capture noise. 
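Review note (not part of the patch): the loop bounds above recompute the same limb arithmetic in every test. A self-contained sketch of that arithmetic, using the constants of test_ggsw_automorphism (base2k = 12, k_in = 54) and nothing beyond std Rust:

```rust
// Sketch of the decomposition arithmetic used by test_ggsw_automorphism:
// with base2k = 12 and k_in = 54 there are ceil(54 / 12) = 5 limbs, and for
// each digit count `di` the switching-key precision grows by di limbs.
fn main() {
    let (base2k, k_in): (usize, usize) = (12, 54);
    let digits = k_in.div_ceil(base2k);
    assert_eq!(digits, 5);
    for di in 1..digits + 1 {
        let k_ksk = k_in + base2k * di; // key precision, in bits
        let rows = k_in.div_ceil(base2k * di); // gadget rows, rounded up
        let rows_in = k_in.div_euclid(base2k * di); // input rows, rounded down
        println!("di={di}: k_ksk={k_ksk}, rows={rows}, rows_in={rows_in}");
    }
}
```

For di = 2 this gives k_ksk = 78, rows = 3 and rows_in = 2, which matches the layouts constructed below.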
             let n: usize = module.n();
-            let rows: usize = k_in.div_ceil(basek * di);
-            let rows_in: usize = k_in.div_euclid(basek * di);
+            let rows: usize = k_in.div_ceil(base2k * di);
+            let rows_in: usize = k_in.div_euclid(base2k * di);
             let digits_in: usize = 1;

-            let mut ct_in: GGSWCiphertext<Vec<u8>> = GGSWCiphertext::alloc(n, basek, k_in, rows_in, digits_in, rank);
-            let mut ct_out: GGSWCiphertext<Vec<u8>> = GGSWCiphertext::alloc(n, basek, k_out, rows_in, digits_in, rank);
-            let mut tensor_key: GGLWETensorKey<Vec<u8>> = GGLWETensorKey::alloc(n, basek, k_tsk, rows, di, rank);
-            let mut auto_key: GGLWEAutomorphismKey<Vec<u8>> = GGLWEAutomorphismKey::alloc(n, basek, k_ksk, rows, di, rank);
+            let ggsw_in_layout: GGSWCiphertextLayout = GGSWCiphertextLayout {
+                n: n.into(),
+                base2k: base2k.into(),
+                k: k_in.into(),
+                rows: rows_in.into(),
+                digits: digits_in.into(),
+                rank: rank.into(),
+            };
+
+            let ggsw_out_layout: GGSWCiphertextLayout = GGSWCiphertextLayout {
+                n: n.into(),
+                base2k: base2k.into(),
+                k: k_out.into(),
+                rows: rows_in.into(),
+                digits: digits_in.into(),
+                rank: rank.into(),
+            };
+
+            let tensor_key_layout: GGLWETensorKeyLayout = GGLWETensorKeyLayout {
+                n: n.into(),
+                base2k: base2k.into(),
+                k: k_tsk.into(),
+                rows: rows.into(),
+                digits: di.into(),
+                rank: rank.into(),
+            };
+
+            let auto_key_layout: GGLWEAutomorphismKeyLayout = GGLWEAutomorphismKeyLayout {
+                n: n.into(),
+                base2k: base2k.into(),
+                k: k_ksk.into(),
+                rows: rows.into(),
+                digits: di.into(),
+                rank: rank.into(),
+            };
+
+            let mut ct_in: GGSWCiphertext<Vec<u8>> = GGSWCiphertext::alloc(&ggsw_in_layout);
+            let mut ct_out: GGSWCiphertext<Vec<u8>> = GGSWCiphertext::alloc(&ggsw_out_layout);
+            let mut tensor_key: GGLWETensorKey<Vec<u8>> = GGLWETensorKey::alloc(&tensor_key_layout);
+            let mut auto_key: GGLWEAutomorphismKey<Vec<u8>> = GGLWEAutomorphismKey::alloc(&auto_key_layout);
             let mut pt_scalar: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, 1);

             let mut source_xs: Source = Source::new([0u8; 32]);
@@ -103,15 +139,15 @@ where
             let mut source_xa: Source = Source::new([0u8; 32]);

             let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(
-                GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_in, rank)
-                    | GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, basek, k_ksk, rank)
-                    | GGLWETensorKey::encrypt_sk_scratch_space(module, basek, k_tsk, rank)
-                    | GGSWCiphertext::automorphism_scratch_space(module, basek, k_out, k_in, k_ksk, di, k_tsk, di, rank),
+                GGSWCiphertext::encrypt_sk_scratch_space(module, &ct_in)
+                    | GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, &auto_key)
+                    | GGLWETensorKey::encrypt_sk_scratch_space(module, &tensor_key)
+                    | GGSWCiphertext::automorphism_scratch_space(module, &ct_out, &ct_in, &auto_key, &tensor_key),
             );

             let var_xs: f64 = 0.5;

-            let mut sk: GLWESecret<Vec<u8>> = GLWESecret::alloc(n, rank);
+            let mut sk: GLWESecret<Vec<u8>> = GLWESecret::alloc(&ct_out);
             sk.fill_ternary_prob(var_xs, &mut source_xs);
             let sk_prepared: GLWESecretPrepared<Vec<u8>, B> = sk.prepare_alloc(module, scratch.borrow());
@@ -143,11 +179,10 @@ where
             );

             let mut auto_key_prepared: GGLWEAutomorphismKeyPrepared<Vec<u8>, B> =
-                GGLWEAutomorphismKeyPrepared::alloc(module, basek, k_ksk, rows, di, rank);
+                GGLWEAutomorphismKeyPrepared::alloc(module, &auto_key_layout);
             auto_key_prepared.prepare(module, &auto_key, scratch.borrow());

-            let mut tsk_prepared: GGLWETensorKeyPrepared<Vec<u8>, B> =
-                GGLWETensorKeyPrepared::alloc(module, basek, k_tsk, rows, di, rank);
+            let mut tsk_prepared: GGLWETensorKeyPrepared<Vec<u8>, B> = GGLWETensorKeyPrepared::alloc(module, &tensor_key_layout);
             tsk_prepared.prepare(module, &tensor_key, scratch.borrow());

             ct_out.automorphism(
@@ -163,7 +198,7 @@ where
             let max_noise = |col_j: usize| -> f64 {
                 noise_ggsw_keyswitch(
                     n as f64,
-                    basek * di,
+                    base2k * di,
                     col_j,
                     var_xs,
                     0f64,
@@ -177,8 +212,8 @@ where
             };

             ct_out.assert_noise(module, &sk_prepared, &pt_scalar, max_noise);
-        });
-    });
+        }
+    }
 }

 #[allow(clippy::too_many_arguments)]
@@ -202,7 +237,7 @@ where
     + SvpPPolAlloc
     + VecZnxAddScalarInplace
     + VecZnxCopy
-    + VecZnxSubABInplace
+    + VecZnxSubInplace
     + VmpPMatAlloc
     + VmpPrepare
     + VmpApplyDftToDftTmpBytes
@@ -233,23 +268,50 @@ where
     + VecZnxBigAllocBytesImpl
     + TakeSvpPPolImpl,
 {
-    let basek: usize = 12;
-    let k_ct: usize = 54;
-    let digits: usize = k_ct.div_ceil(basek);
-    let p = -1;
-    (1..3).for_each(|rank| {
-        (1..digits + 1).for_each(|di| {
-            let k_ksk: usize = k_ct + basek * di;
+    let base2k: usize = 12;
+    let k_out: usize = 54;
+    let digits: usize = k_out.div_ceil(base2k);
+    let p: i64 = -1;
+    for rank in 1_usize..3 {
+        for di in 1..digits + 1 {
+            let k_ksk: usize = k_out + base2k * di;
             let k_tsk: usize = k_ksk;
             let n: usize = module.n();
-            let rows: usize = k_ct.div_ceil(di * basek);
-            let rows_in: usize = k_ct.div_euclid(basek * di);
+            let rows: usize = k_out.div_ceil(di * base2k);
+            let rows_in: usize = k_out.div_euclid(base2k * di);
             let digits_in: usize = 1;

-            let mut ct: GGSWCiphertext<Vec<u8>> = GGSWCiphertext::alloc(n, basek, k_ct, rows_in, digits_in, rank);
-            let mut tensor_key: GGLWETensorKey<Vec<u8>> = GGLWETensorKey::alloc(n, basek, k_tsk, rows, di, rank);
-            let mut auto_key: GGLWEAutomorphismKey<Vec<u8>> = GGLWEAutomorphismKey::alloc(n, basek, k_ksk, rows, di, rank);
+            let ggsw_out_layout: GGSWCiphertextLayout = GGSWCiphertextLayout {
+                n: n.into(),
+                base2k: base2k.into(),
+                k: k_out.into(),
+                rows: rows_in.into(),
+                digits: digits_in.into(),
+                rank: rank.into(),
+            };
+
+            let tensor_key_layout: GGLWETensorKeyLayout = GGLWETensorKeyLayout {
+                n: n.into(),
+                base2k: base2k.into(),
+                k: k_tsk.into(),
+                rows: rows.into(),
+                digits: di.into(),
+                rank: rank.into(),
+            };
+
+            let auto_key_layout: GGLWEAutomorphismKeyLayout = GGLWEAutomorphismKeyLayout {
+                n: n.into(),
+                base2k: base2k.into(),
+                k: k_ksk.into(),
+                rows: rows.into(),
+                digits: di.into(),
+                rank: rank.into(),
+            };
+
+            let mut ct: GGSWCiphertext<Vec<u8>> = GGSWCiphertext::alloc(&ggsw_out_layout);
+            let mut tensor_key: GGLWETensorKey<Vec<u8>> = GGLWETensorKey::alloc(&tensor_key_layout);
+            let mut auto_key: GGLWEAutomorphismKey<Vec<u8>> = GGLWEAutomorphismKey::alloc(&auto_key_layout);
             let mut pt_scalar: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, 1);

             let mut source_xs: Source = Source::new([0u8; 32]);
@@ -257,15 +319,15 @@ where
             let mut source_xa: Source = Source::new([0u8; 32]);

             let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(
-                GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_ct, rank)
-                    | GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, basek, k_ksk, rank)
-                    | GGLWETensorKey::encrypt_sk_scratch_space(module, basek, k_tsk, rank)
-                    | GGSWCiphertext::automorphism_inplace_scratch_space(module, basek, k_ct, k_ksk, di, k_tsk, di, rank),
+                GGSWCiphertext::encrypt_sk_scratch_space(module, &ct)
+                    | GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, &auto_key)
+                    | GGLWETensorKey::encrypt_sk_scratch_space(module, &tensor_key)
+                    | GGSWCiphertext::automorphism_inplace_scratch_space(module, &ct, &auto_key, &tensor_key),
             );

             let var_xs: f64 = 0.5;

-            let mut sk: GLWESecret<Vec<u8>> = GLWESecret::alloc(n, rank);
+            let mut sk: GLWESecret<Vec<u8>> = GLWESecret::alloc(&ct);
             sk.fill_ternary_prob(var_xs, &mut source_xs);
             let sk_prepared: GLWESecretPrepared<Vec<u8>, B> = sk.prepare_alloc(module, scratch.borrow());
@@ -297,11 +359,10 @@ where
             );

             let mut auto_key_prepared: GGLWEAutomorphismKeyPrepared<Vec<u8>, B> =
-                GGLWEAutomorphismKeyPrepared::alloc(module, basek, k_ksk, rows, di, rank);
+                GGLWEAutomorphismKeyPrepared::alloc(module, &auto_key_layout);
             auto_key_prepared.prepare(module, &auto_key, scratch.borrow());

-            let mut tsk_prepared: GGLWETensorKeyPrepared<Vec<u8>, B> =
-                GGLWETensorKeyPrepared::alloc(module, basek, k_tsk, rows, di, rank);
+            let mut tsk_prepared: GGLWETensorKeyPrepared<Vec<u8>, B> = GGLWETensorKeyPrepared::alloc(module, &tensor_key_layout);
             tsk_prepared.prepare(module, &tensor_key, scratch.borrow());

             ct.automorphism_inplace(module, &auto_key_prepared, &tsk_prepared, scratch.borrow());
@@ -311,20 +372,20 @@ where
             let max_noise = |col_j: usize| -> f64 {
                 noise_ggsw_keyswitch(
                     n as f64,
-                    basek * di,
+                    base2k * di,
                     col_j,
                     var_xs,
                     0f64,
                     SIGMA * SIGMA,
                     0f64,
                     rank as f64,
-                    k_ct,
+                    k_out,
                     k_ksk,
                     k_tsk,
                 ) + 0.5
             };

             ct.assert_noise(module, &sk_prepared, &pt_scalar, max_noise);
-        });
-    });
+        }
+    }
 }
diff --git a/poulpy-core/src/tests/test_suite/automorphism/glwe_ct.rs b/poulpy-core/src/tests/test_suite/automorphism/glwe_ct.rs
index 4eda5ff..1a52bdd 100644
--- a/poulpy-core/src/tests/test_suite/automorphism/glwe_ct.rs
+++ b/poulpy-core/src/tests/test_suite/automorphism/glwe_ct.rs
@@ -4,7 +4,7 @@ use poulpy_hal::{
         VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace,
         VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
         VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace,
-        VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd,
+        VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd,
         VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare,
     },
     layouts::{Backend, Module, ScratchOwned},
@@ -18,13 +18,12 @@ use poulpy_hal::{
 use crate::{
     encryption::SIGMA,
     layouts::{
-        GGLWEAutomorphismKey, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos,
+        GGLWEAutomorphismKey, GGLWEAutomorphismKeyLayout, GLWECiphertext, GLWECiphertextLayout, GLWEPlaintext, GLWESecret,
         prepared::{GGLWEAutomorphismKeyPrepared, GLWESecretPrepared, Prepare, PrepareAlloc},
     },
     noise::log2_std_noise_gglwe_product,
 };

-#[allow(clippy::too_many_arguments)]
 pub fn test_glwe_automorphism<B: Backend>(module: &Module<B>)
 where
     Module<B>: VecZnxDftAllocBytes
@@ -34,7 +33,7 @@ where
     + VecZnxIdftApplyConsume
     + VecZnxNormalizeTmpBytes
     + VecZnxFillUniform
-    + VecZnxSubABInplace
+    + VecZnxSubInplace
     + VecZnxAddInplace
     + VecZnxNormalizeInplace
     + VecZnxAddNormal
@@ -66,45 +65,60 @@ where
-    let basek: usize = 12;
+    let base2k: usize = 12;
     let k_in: usize = 60;
-    let digits: usize = k_in.div_ceil(basek);
+    let digits: usize = k_in.div_ceil(base2k);
     let p: i64 = -5;
-    (1..4).for_each(|rank| {
-        (1..digits + 1).for_each(|di| {
-            let k_ksk: usize = k_in + basek * di;
+    for rank in 1_usize..3 {
+        for di in 1..digits + 1 {
+            let k_ksk: usize = k_in + base2k * di;
             let k_out: usize = k_ksk; // Better capture noise.
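Review note (editor's sketch, not poulpy's actual definitions): the `n: n.into(), base2k: base2k.into(), ...` literals in these layouts rely on newtype wrappers with `From<usize>` conversions. A minimal standalone model of that pattern follows; only Degree's inner u32 is visible in the diff (`Degree(module.n() as u32)`), the other inner types are assumptions:

```rust
// Simplified stand-ins for the layout newtypes this patch introduces.
#[derive(Clone, Copy, Debug)]
struct Degree(u32);
#[derive(Clone, Copy, Debug)]
struct Base2K(u32);
#[derive(Clone, Copy, Debug)]
struct TorusPrecision(u32);
#[derive(Clone, Copy, Debug)]
struct Rank(u32);

impl From<usize> for Degree {
    fn from(v: usize) -> Self {
        Degree(v as u32)
    }
}
impl From<usize> for Base2K {
    fn from(v: usize) -> Self {
        Base2K(v as u32)
    }
}
impl From<usize> for TorusPrecision {
    fn from(v: usize) -> Self {
        TorusPrecision(v as u32)
    }
}
impl From<usize> for Rank {
    fn from(v: usize) -> Self {
        Rank(v as u32)
    }
}

// A layout names every parameter, so call sites can no longer silently
// transpose the positional usize arguments of the old alloc(n, basek, k, ...).
struct GLWECiphertextLayout {
    n: Degree,
    base2k: Base2K,
    k: TorusPrecision,
    rank: Rank,
}

fn main() {
    let (n, base2k, k, rank): (usize, usize, usize, usize) = (1024, 12, 60, 2);
    let layout = GLWECiphertextLayout {
        n: n.into(),
        base2k: base2k.into(),
        k: k.into(),
        rank: rank.into(),
    };
    println!("{:?} {:?} {:?} {:?}", layout.n, layout.base2k, layout.k, layout.rank);
}
```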
let n: usize = module.n(); - let rows: usize = k_in.div_ceil(basek * digits); + let rows: usize = k_in.div_ceil(base2k * digits); - let mut autokey: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(n, basek, k_ksk, rows, digits, rank); - let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_in, rank); - let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_out, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_in); + let ct_in_infos: GLWECiphertextLayout = GLWECiphertextLayout { + n: n.into(), + base2k: base2k.into(), + k: k_in.into(), + rank: rank.into(), + }; + + let ct_out_infos: GLWECiphertextLayout = GLWECiphertextLayout { + n: n.into(), + base2k: base2k.into(), + k: k_out.into(), + rank: rank.into(), + }; + + let autokey_infos: GGLWEAutomorphismKeyLayout = GGLWEAutomorphismKeyLayout { + n: n.into(), + base2k: base2k.into(), + k: k_out.into(), + rank: rank.into(), + rows: rows.into(), + digits: di.into(), + }; + + let mut autokey: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&autokey_infos); + let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(&ct_in_infos); + let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(&ct_out_infos); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&ct_out_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, &mut source_xa); + module.vec_znx_fill_uniform(base2k, &mut pt_want.data, 0, &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, basek, autokey.k(), rank) - | GLWECiphertext::decrypt_scratch_space(module, basek, ct_out.k()) - | GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct_in.k()) - | GLWECiphertext::automorphism_scratch_space( - module, - basek, - ct_out.k(), - ct_in.k(), - autokey.k(), - digits, - rank, - ), + GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, &autokey) + | GLWECiphertext::decrypt_scratch_space(module, &ct_out) + | GLWECiphertext::encrypt_sk_scratch_space(module, &ct_in) + | GLWECiphertext::automorphism_scratch_space(module, &ct_out, &ct_in, &autokey), ); - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(&ct_out); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); @@ -127,14 +141,14 @@ where ); let mut autokey_prepared: GGLWEAutomorphismKeyPrepared, B> = - GGLWEAutomorphismKeyPrepared::alloc(module, basek, k_ksk, rows, digits, rank); + GGLWEAutomorphismKeyPrepared::alloc(module, &autokey_infos); autokey_prepared.prepare(module, &autokey, scratch.borrow()); ct_out.automorphism(module, &ct_in, &autokey_prepared, scratch.borrow()); let max_noise: f64 = log2_std_noise_gglwe_product( module.n() as f64, - basek * digits, + base2k * digits, 0.5, 0.5, 0f64, @@ -148,8 +162,8 @@ where module.vec_znx_automorphism_inplace(p, &mut pt_want.data, 0, scratch.borrow()); ct_out.assert_noise(module, &sk_prepared, &pt_want, max_noise + 1.0); - }) - }); + } + } } #[allow(clippy::too_many_arguments)] @@ -162,7 +176,7 @@ where + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -194,35 +208,51 @@ where + TakeScalarZnxImpl + TakeVecZnxImpl, { - let basek: usize = 12; - 
let k_ct: usize = 60; - let digits: usize = k_ct.div_ceil(basek); + let base2k: usize = 12; + let k_out: usize = 60; + let digits: usize = k_out.div_ceil(base2k); let p = -5; - (1..4).for_each(|rank| { - (1..digits + 1).for_each(|di| { - let k_ksk: usize = k_ct + basek * di; + for rank in 1_usize..3 { + for di in 1..digits + 1 { + let k_ksk: usize = k_out + base2k * di; let n: usize = module.n(); - let rows: usize = k_ct.div_ceil(basek * digits); + let rows: usize = k_out.div_ceil(base2k * digits); - let mut autokey: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(n, basek, k_ksk, rows, digits, rank); - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_ct); + let ct_out_infos: GLWECiphertextLayout = GLWECiphertextLayout { + n: n.into(), + base2k: base2k.into(), + k: k_out.into(), + rank: rank.into(), + }; + + let autokey_infos: GGLWEAutomorphismKeyLayout = GGLWEAutomorphismKeyLayout { + n: n.into(), + base2k: base2k.into(), + k: k_ksk.into(), + rank: rank.into(), + rows: rows.into(), + digits: di.into(), + }; + + let mut autokey: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&autokey_infos); + let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&ct_out_infos); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&ct_out_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, &mut source_xa); + module.vec_znx_fill_uniform(base2k, &mut pt_want.data, 0, &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, basek, autokey.k(), rank) - | GLWECiphertext::decrypt_scratch_space(module, basek, ct.k()) - | GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct.k()) - | GLWECiphertext::automorphism_inplace_scratch_space(module, basek, ct.k(), autokey.k(), digits, rank), + GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, &autokey) + | GLWECiphertext::decrypt_scratch_space(module, &ct) + | GLWECiphertext::encrypt_sk_scratch_space(module, &ct) + | GLWECiphertext::automorphism_inplace_scratch_space(module, &ct, &autokey), ); - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(&ct); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); @@ -245,27 +275,27 @@ where ); let mut autokey_prepared: GGLWEAutomorphismKeyPrepared, B> = - GGLWEAutomorphismKeyPrepared::alloc(module, basek, k_ksk, rows, digits, rank); + GGLWEAutomorphismKeyPrepared::alloc(module, &autokey); autokey_prepared.prepare(module, &autokey, scratch.borrow()); ct.automorphism_inplace(module, &autokey_prepared, scratch.borrow()); let max_noise: f64 = log2_std_noise_gglwe_product( module.n() as f64, - basek * digits, + base2k * digits, 0.5, 0.5, 0f64, SIGMA * SIGMA, 0f64, rank as f64, - k_ct, + k_out, k_ksk, ); module.vec_znx_automorphism_inplace(p, &mut pt_want.data, 0, scratch.borrow()); ct.assert_noise(module, &sk_prepared, &pt_want, max_noise + 1.0); - }); - }); + } + } } diff --git a/poulpy-core/src/tests/test_suite/conversion.rs b/poulpy-core/src/tests/test_suite/conversion.rs index 5d4dcbb..6e82cd7 100644 --- a/poulpy-core/src/tests/test_suite/conversion.rs +++ b/poulpy-core/src/tests/test_suite/conversion.rs @@ -4,7 +4,7 @@ use poulpy_hal::{ VecZnxAddInplace, 
VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, + VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, ZnAddNormal, ZnFillUniform, ZnNormalizeInplace, }, layouts::{Backend, Module, ScratchOwned, ZnxView}, @@ -16,8 +16,9 @@ use poulpy_hal::{ }; use crate::layouts::{ - GLWECiphertext, GLWEPlaintext, GLWESecret, GLWEToLWESwitchingKey, Infos, LWECiphertext, LWEPlaintext, LWESecret, - LWEToGLWESwitchingKey, + Base2K, Degree, GLWECiphertext, GLWECiphertextLayout, GLWEPlaintext, GLWESecret, GLWEToLWESwitchingKey, + GLWEToLWESwitchingKeyLayout, LWECiphertext, LWECiphertextLayout, LWEPlaintext, LWESecret, LWEToGLWESwitchingKey, + LWEToGLWESwitchingKeyLayout, Rank, Rows, TorusPrecision, prepared::{GLWESecretPrepared, GLWEToLWESwitchingKeyPrepared, LWEToGLWESwitchingKeyPrepared, PrepareAlloc}, }; @@ -29,7 +30,7 @@ where + SvpApplyDftToDftInplace + VecZnxIdftApplyConsume + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -64,30 +65,44 @@ where + TakeScalarZnxImpl + TakeVecZnxImpl, { - let n: usize = module.n(); - let basek: usize = 17; + let n_glwe: Degree = Degree(module.n() as u32); + let n_lwe: Degree = Degree(22); - let rank: usize = 2; - - let n_lwe: usize = 22; - let k_lwe_ct: usize = 2 * basek; - let k_lwe_pt: usize = 8; - - let k_glwe_ct: usize = 3 * basek; - - let k_ksk: usize = k_lwe_ct + basek; + let rank: Rank = Rank(2); + let k_lwe_pt: TorusPrecision = TorusPrecision(8); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); + let lwe_to_glwe_infos: LWEToGLWESwitchingKeyLayout = LWEToGLWESwitchingKeyLayout { + n: n_glwe, + base2k: Base2K(17), + k: TorusPrecision(51), + rows: Rows(2), + rank_out: rank, + }; + + let glwe_infos: GLWECiphertextLayout = GLWECiphertextLayout { + n: n_glwe, + base2k: Base2K(17), + k: TorusPrecision(34), + rank, + }; + + let lwe_infos: LWECiphertextLayout = LWECiphertextLayout { + n: n_lwe, + base2k: Base2K(17), + k: TorusPrecision(34), + }; + let mut scratch: ScratchOwned = ScratchOwned::alloc( - LWEToGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k_ksk, rank) - | GLWECiphertext::from_lwe_scratch_space(module, basek, k_lwe_ct, k_glwe_ct, k_ksk, rank) - | GLWECiphertext::decrypt_scratch_space(module, basek, k_glwe_ct), + LWEToGLWESwitchingKey::encrypt_sk_scratch_space(module, &lwe_to_glwe_infos) + | GLWECiphertext::from_lwe_scratch_space(module, &glwe_infos, &lwe_infos, &lwe_to_glwe_infos) + | GLWECiphertext::decrypt_scratch_space(module, &glwe_infos), ); - let mut sk_glwe: GLWESecret> = GLWESecret::alloc(n, rank); + let mut sk_glwe: GLWESecret> = GLWESecret::alloc(&glwe_infos); sk_glwe.fill_ternary_prob(0.5, &mut source_xs); let sk_glwe_prepared: GLWESecretPrepared, B> = sk_glwe.prepare_alloc(module, scratch.borrow()); @@ -97,13 +112,13 @@ where let data: i64 = 17; - let mut lwe_pt: LWEPlaintext> = LWEPlaintext::alloc(basek, k_lwe_pt); + let mut lwe_pt: LWEPlaintext> = LWEPlaintext::alloc(&lwe_infos); 
lwe_pt.encode_i64(data, k_lwe_pt); - let mut lwe_ct: LWECiphertext> = LWECiphertext::alloc(n_lwe, basek, k_lwe_ct); + let mut lwe_ct: LWECiphertext> = LWECiphertext::alloc(&lwe_infos); lwe_ct.encrypt_sk(module, &lwe_pt, &sk_lwe, &mut source_xa, &mut source_xe); - let mut ksk: LWEToGLWESwitchingKey> = LWEToGLWESwitchingKey::alloc(n, basek, k_ksk, lwe_ct.size(), rank); + let mut ksk: LWEToGLWESwitchingKey> = LWEToGLWESwitchingKey::alloc(&lwe_to_glwe_infos); ksk.encrypt_sk( module, @@ -114,13 +129,13 @@ where scratch.borrow(), ); - let mut glwe_ct: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_glwe_ct, rank); + let mut glwe_ct: GLWECiphertext> = GLWECiphertext::alloc(&glwe_infos); let ksk_prepared: LWEToGLWESwitchingKeyPrepared, B> = ksk.prepare_alloc(module, scratch.borrow()); glwe_ct.from_lwe(module, &lwe_ct, &ksk_prepared, scratch.borrow()); - let mut glwe_pt: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_glwe_ct); + let mut glwe_pt: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_infos); glwe_ct.decrypt(module, &mut glwe_pt, &sk_glwe_prepared, scratch.borrow()); assert_eq!(glwe_pt.data.at(0, 0)[0], lwe_pt.data.at(0, 0)[0]); @@ -134,7 +149,7 @@ where + SvpApplyDftToDftInplace + VecZnxIdftApplyConsume + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -167,42 +182,56 @@ where + TakeScalarZnxImpl + TakeVecZnxImpl, { - let n: usize = module.n(); - let basek: usize = 17; + let n_glwe: Degree = Degree(module.n() as u32); + let n_lwe: Degree = Degree(22); - let rank: usize = 2; + let rank: Rank = Rank(2); + let k_lwe_pt: TorusPrecision = TorusPrecision(8); - let n_lwe: usize = 22; - let k_lwe_ct: usize = 2 * basek; - let k_lwe_pt: usize = 8; + let glwe_to_lwe_infos: GLWEToLWESwitchingKeyLayout = GLWEToLWESwitchingKeyLayout { + n: n_glwe, + base2k: Base2K(17), + k: TorusPrecision(51), + rows: Rows(2), + rank_in: rank, + }; - let k_glwe_ct: usize = 3 * basek; + let glwe_infos: GLWECiphertextLayout = GLWECiphertextLayout { + n: n_glwe, + base2k: Base2K(17), + k: TorusPrecision(34), + rank, + }; - let k_ksk: usize = k_lwe_ct + basek; + let lwe_infos: LWECiphertextLayout = LWECiphertextLayout { + n: n_lwe, + base2k: Base2K(17), + k: TorusPrecision(34), + }; let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - LWEToGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k_ksk, rank) - | LWECiphertext::from_glwe_scratch_space(module, basek, k_lwe_ct, k_glwe_ct, k_ksk, rank) - | GLWECiphertext::decrypt_scratch_space(module, basek, k_glwe_ct), + GLWEToLWESwitchingKey::encrypt_sk_scratch_space(module, &glwe_to_lwe_infos) + | LWECiphertext::from_glwe_scratch_space(module, &lwe_infos, &glwe_infos, &glwe_to_lwe_infos) + | GLWECiphertext::decrypt_scratch_space(module, &glwe_infos), ); - let mut sk_glwe: GLWESecret> = GLWESecret::alloc(n, rank); + let mut sk_glwe: GLWESecret> = GLWESecret::alloc(&glwe_infos); sk_glwe.fill_ternary_prob(0.5, &mut source_xs); let sk_glwe_prepared: GLWESecretPrepared, B> = sk_glwe.prepare_alloc(module, scratch.borrow()); - let mut sk_lwe = LWESecret::alloc(n_lwe); + let mut sk_lwe: LWESecret> = LWESecret::alloc(n_lwe); sk_lwe.fill_ternary_prob(0.5, &mut source_xs); let data: i64 = 17; - let mut glwe_pt: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_glwe_ct); + let mut glwe_pt: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_infos); 
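Review note (not part of the patch): the hard-coded `TorusPrecision(34)` and `TorusPrecision(51)` in these conversion tests are the evaluated forms of the expressions the patch removes. A quick check:

```rust
// With Base2K(17), the new literals match the old derived precisions.
fn main() {
    let base2k: usize = 17;
    let k_lwe_ct = 2 * base2k; // 34, formerly `k_lwe_ct = 2 * basek`
    let k_ksk = k_lwe_ct + base2k; // 51, formerly `k_ksk = k_lwe_ct + basek`
    let limbs = k_lwe_ct.div_ceil(base2k); // two 17-bit limbs per ciphertext
    assert_eq!((k_lwe_ct, k_ksk, limbs), (34, 51, 2));
}
```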
glwe_pt.encode_coeff_i64(data, k_lwe_pt, 0); - let mut glwe_ct = GLWECiphertext::alloc(n, basek, k_glwe_ct, rank); + let mut glwe_ct = GLWECiphertext::alloc(&glwe_infos); glwe_ct.encrypt_sk( module, &glwe_pt, @@ -212,7 +241,7 @@ where scratch.borrow(), ); - let mut ksk: GLWEToLWESwitchingKey> = GLWEToLWESwitchingKey::alloc(n, basek, k_ksk, glwe_ct.size(), rank); + let mut ksk: GLWEToLWESwitchingKey> = GLWEToLWESwitchingKey::alloc(&glwe_to_lwe_infos); ksk.encrypt_sk( module, @@ -223,13 +252,13 @@ where scratch.borrow(), ); - let mut lwe_ct: LWECiphertext> = LWECiphertext::alloc(n_lwe, basek, k_lwe_ct); + let mut lwe_ct: LWECiphertext> = LWECiphertext::alloc(&lwe_infos); let ksk_prepared: GLWEToLWESwitchingKeyPrepared, B> = ksk.prepare_alloc(module, scratch.borrow()); lwe_ct.from_glwe(module, &glwe_ct, &ksk_prepared, scratch.borrow()); - let mut lwe_pt: LWEPlaintext> = LWEPlaintext::alloc(basek, k_lwe_ct); + let mut lwe_pt: LWEPlaintext> = LWEPlaintext::alloc(&lwe_infos); lwe_ct.decrypt(module, &mut lwe_pt, &sk_lwe); assert_eq!(glwe_pt.data.at(0, 0)[0], lwe_pt.data.at(0, 0)[0]); diff --git a/poulpy-core/src/tests/test_suite/encryption/gglwe_atk.rs b/poulpy-core/src/tests/test_suite/encryption/gglwe_atk.rs index c263faf..717a639 100644 --- a/poulpy-core/src/tests/test_suite/encryption/gglwe_atk.rs +++ b/poulpy-core/src/tests/test_suite/encryption/gglwe_atk.rs @@ -4,7 +4,7 @@ use poulpy_hal::{ VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, - VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSubScalarInplace, VecZnxSwitchRing, + VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSubScalarInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, }, layouts::{Backend, Module, ScratchOwned}, @@ -18,7 +18,7 @@ use poulpy_hal::{ use crate::{ encryption::SIGMA, layouts::{ - GGLWEAutomorphismKey, GLWESecret, + GGLWEAutomorphismKey, GGLWEAutomorphismKeyLayout, GLWEInfos, GLWESecret, compressed::{Decompress, GGLWEAutomorphismKeyCompressed}, prepared::{GLWESecretPrepared, PrepareAlloc}, }, @@ -33,7 +33,7 @@ where + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -67,25 +67,34 @@ where + TakeSvpPPolImpl + TakeVecZnxBigImpl, { - let basek: usize = 12; + let base2k: usize = 12; let k_ksk: usize = 60; - let digits: usize = k_ksk.div_ceil(basek) - 1; - (1..3).for_each(|rank| { - (1..digits + 1).for_each(|di| { + let digits: usize = k_ksk.div_ceil(base2k) - 1; + for rank in 1_usize..3 { + for di in 1..digits + 1 { let n: usize = module.n(); - let rows: usize = (k_ksk - di * basek) / (di * basek); + let rows: usize = (k_ksk - di * base2k) / (di * base2k); - let mut atk: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(n, basek, k_ksk, rows, di, rank); + let atk_infos: GGLWEAutomorphismKeyLayout = GGLWEAutomorphismKeyLayout { + n: n.into(), + base2k: base2k.into(), + k: k_ksk.into(), + rows: rows.into(), + digits: di.into(), + rank: rank.into(), + }; + + let mut atk: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&atk_infos); let mut source_xs: Source = Source::new([0u8; 
32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWEAutomorphismKey::encrypt_sk_scratch_space( - module, basek, k_ksk, rank, + module, &atk_infos, )); - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(&atk_infos); sk.fill_ternary_prob(0.5, &mut source_xs); let p = -5; @@ -100,7 +109,7 @@ where ); let mut sk_out: GLWESecret> = sk.clone(); - (0..atk.rank()).for_each(|i| { + (0..atk.rank().into()).for_each(|i| { module.vec_znx_automorphism( module.galois_element_inv(p), &mut sk_out.data.as_vec_znx_mut(), @@ -114,8 +123,8 @@ where atk.key .key .assert_noise(module, &sk_out_prepared, &sk.data, SIGMA); - }); - }); + } + } } pub fn test_gglwe_automorphisk_key_compressed_encrypt_sk(module: &Module) @@ -127,7 +136,7 @@ where + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -161,25 +170,33 @@ where + TakeSvpPPolImpl + TakeVecZnxBigImpl, { - let basek: usize = 12; + let base2k: usize = 12; let k_ksk: usize = 60; - let digits: usize = k_ksk.div_ceil(basek) - 1; - (1..3).for_each(|rank| { - (1..digits + 1).for_each(|di| { + let digits: usize = k_ksk.div_ceil(base2k) - 1; + for rank in 1_usize..3 { + for di in 1..digits + 1 { let n: usize = module.n(); - let rows: usize = (k_ksk - di * basek) / (di * basek); + let rows: usize = (k_ksk - di * base2k) / (di * base2k); - let mut atk_compressed: GGLWEAutomorphismKeyCompressed> = - GGLWEAutomorphismKeyCompressed::alloc(n, basek, k_ksk, rows, di, rank); + let atk_infos: GGLWEAutomorphismKeyLayout = GGLWEAutomorphismKeyLayout { + n: n.into(), + base2k: base2k.into(), + k: k_ksk.into(), + rows: rows.into(), + digits: di.into(), + rank: rank.into(), + }; + + let mut atk_compressed: GGLWEAutomorphismKeyCompressed> = GGLWEAutomorphismKeyCompressed::alloc(&atk_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWEAutomorphismKey::encrypt_sk_scratch_space( - module, basek, k_ksk, rank, + module, &atk_infos, )); - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(&atk_infos); sk.fill_ternary_prob(0.5, &mut source_xs); let p = -5; @@ -189,7 +206,7 @@ where atk_compressed.encrypt_sk(module, p, &sk, seed_xa, &mut source_xe, scratch.borrow()); let mut sk_out: GLWESecret> = sk.clone(); - (0..atk_compressed.rank()).for_each(|i| { + (0..atk_compressed.rank().into()).for_each(|i| { module.vec_znx_automorphism( module.galois_element_inv(p), &mut sk_out.data.as_vec_znx_mut(), @@ -200,12 +217,12 @@ where }); let sk_out_prepared = sk_out.prepare_alloc(module, scratch.borrow()); - let mut atk: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(n, basek, k_ksk, rows, di, rank); + let mut atk: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&atk_infos); atk.decompress(module, &atk_compressed); atk.key .key .assert_noise(module, &sk_out_prepared, &sk.data, SIGMA); - }); - }); + } + } } diff --git a/poulpy-core/src/tests/test_suite/encryption/gglwe_ct.rs b/poulpy-core/src/tests/test_suite/encryption/gglwe_ct.rs index ca4fb02..d8ba73a 100644 --- a/poulpy-core/src/tests/test_suite/encryption/gglwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/encryption/gglwe_ct.rs @@ -3,7 +3,7 @@ use poulpy_hal::{ ScratchOwnedAlloc, 
ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, - VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, + VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSubScalarInplace, VecZnxSwitchRing, VmpPMatAlloc, VmpPrepare, }, layouts::{Backend, Module, ScratchOwned}, @@ -17,7 +17,7 @@ use poulpy_hal::{ use crate::{ encryption::SIGMA, layouts::{ - GGLWESwitchingKey, GLWESecret, + GGLWECiphertextLayout, GGLWESwitchingKey, GLWESecret, compressed::{Decompress, GGLWESwitchingKeyCompressed}, prepared::{GLWESecretPrepared, PrepareAlloc}, }, @@ -32,7 +32,7 @@ where + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -62,29 +62,40 @@ where + VecZnxBigAllocBytesImpl + TakeSvpPPolImpl, { - let basek: usize = 12; + let base2k: usize = 12; let k_ksk: usize = 54; - let digits: usize = k_ksk / basek; - (1..3).for_each(|rank_in| { - (1..3).for_each(|rank_out| { - (1..digits + 1).for_each(|di| { + let digits: usize = k_ksk / base2k; + for rank_in in 1_usize..3 { + for rank_out in 1_usize..3 { + for di in 1_usize..digits + 1 { let n: usize = module.n(); - let rows: usize = (k_ksk - di * basek) / (di * basek); + let rows: usize = (k_ksk - di * base2k) / (di * base2k); - let mut ksk: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(n, basek, k_ksk, rows, di, rank_in, rank_out); + let gglwe_infos: GGLWECiphertextLayout = GGLWECiphertextLayout { + n: n.into(), + base2k: base2k.into(), + k: k_ksk.into(), + rows: rows.into(), + digits: di.into(), + rank_in: rank_in.into(), + rank_out: rank_out.into(), + }; + + let mut ksk: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&gglwe_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWESwitchingKey::encrypt_sk_scratch_space( - module, basek, k_ksk, rank_in, rank_out, + module, + &gglwe_infos, )); - let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank_in); + let mut sk_in: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_in.into()); sk_in.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_out: GLWESecret> = GLWESecret::alloc(n, rank_out); + let mut sk_out: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_out.into()); sk_out.fill_ternary_prob(0.5, &mut source_xs); let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); @@ -99,9 +110,9 @@ where ksk.key .assert_noise(module, &sk_out_prepared, &sk_in.data, SIGMA); - }); - }); - }); + } + } + } } pub fn test_gglwe_switching_key_compressed_encrypt_sk(module: &Module) @@ -113,7 +124,7 @@ where + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -143,29 +154,39 @@ where + VecZnxBigAllocBytesImpl + TakeSvpPPolImpl, { - let basek: usize = 12; + let base2k: usize = 12; let k_ksk: usize = 54; - let digits: usize = k_ksk / basek; - (1..3).for_each(|rank_in| { - (1..3).for_each(|rank_out| { 
- (1..digits + 1).for_each(|di| { + let digits: usize = k_ksk / base2k; + for rank_in in 1_usize..3 { + for rank_out in 1_usize..3 { + for di in 1_usize..digits + 1 { let n: usize = module.n(); - let rows: usize = (k_ksk - di * basek) / (di * basek); + let rows: usize = (k_ksk - di * base2k) / (di * base2k); - let mut ksk_compressed: GGLWESwitchingKeyCompressed> = - GGLWESwitchingKeyCompressed::alloc(n, basek, k_ksk, rows, di, rank_in, rank_out); + let gglwe_infos: GGLWECiphertextLayout = GGLWECiphertextLayout { + n: n.into(), + base2k: base2k.into(), + k: k_ksk.into(), + rows: rows.into(), + digits: di.into(), + rank_in: rank_in.into(), + rank_out: rank_out.into(), + }; + + let mut ksk_compressed: GGLWESwitchingKeyCompressed> = GGLWESwitchingKeyCompressed::alloc(&gglwe_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWESwitchingKeyCompressed::encrypt_sk_scratch_space( - module, basek, k_ksk, rank_in, rank_out, + module, + &gglwe_infos, )); - let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank_in); + let mut sk_in: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_in.into()); sk_in.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_out: GLWESecret> = GLWESecret::alloc(n, rank_out); + let mut sk_out: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_out.into()); sk_out.fill_ternary_prob(0.5, &mut source_xs); let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); @@ -180,12 +201,12 @@ where scratch.borrow(), ); - let mut ksk: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(n, basek, k_ksk, rows, di, rank_in, rank_out); + let mut ksk: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&gglwe_infos); ksk.decompress(module, &ksk_compressed); ksk.key .assert_noise(module, &sk_out_prepared, &sk_in.data, SIGMA); - }); - }); - }); + } + } + } } diff --git a/poulpy-core/src/tests/test_suite/encryption/ggsw_ct.rs b/poulpy-core/src/tests/test_suite/encryption/ggsw_ct.rs index 1739ffb..8c29bda 100644 --- a/poulpy-core/src/tests/test_suite/encryption/ggsw_ct.rs +++ b/poulpy-core/src/tests/test_suite/encryption/ggsw_ct.rs @@ -4,7 +4,7 @@ use poulpy_hal::{ VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VmpPMatAlloc, VmpPrepare, + VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VmpPMatAlloc, VmpPrepare, }, layouts::{Backend, Module, ScalarZnx, ScratchOwned}, oep::{ @@ -17,7 +17,7 @@ use poulpy_hal::{ use crate::{ encryption::SIGMA, layouts::{ - GGSWCiphertext, GLWESecret, + GGSWCiphertext, GGSWCiphertextLayout, GLWESecret, compressed::{Decompress, GGSWCiphertextCompressed}, prepared::{GLWESecretPrepared, PrepareAlloc}, }, @@ -32,7 +32,7 @@ where + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -65,15 +65,24 @@ where + VecZnxBigAllocBytesImpl + TakeSvpPPolImpl, { - let basek: usize = 12; + let base2k: usize = 12; let k: usize = 54; - let digits: usize = k / basek; - (1..3).for_each(|rank| { - (1..digits + 1).for_each(|di| { + let digits: usize = k / 
base2k; + for rank in 1_usize..3 { + for di in 1..digits + 1 { let n: usize = module.n(); - let rows: usize = (k - di * basek) / (di * basek); + let rows: usize = (k - di * base2k) / (di * base2k); - let mut ct: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k, rows, di, rank); + let ggsw_infos: GGSWCiphertextLayout = GGSWCiphertextLayout { + n: n.into(), + base2k: base2k.into(), + k: k.into(), + rows: rows.into(), + digits: di.into(), + rank: rank.into(), + }; + + let mut ct: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_infos); let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); @@ -84,10 +93,11 @@ where pt_scalar.fill_ternary_hw(0, n, &mut source_xs); let mut scratch: ScratchOwned = ScratchOwned::alloc(GGSWCiphertext::encrypt_sk_scratch_space( - module, basek, k, rank, + module, + &ggsw_infos, )); - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(&ggsw_infos); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); @@ -103,8 +113,8 @@ where let noise_f = |_col_i: usize| -(k as f64) + SIGMA.log2() + 0.5; ct.assert_noise(module, &sk_prepared, &pt_scalar, noise_f); - }); - }); + } + } } pub fn test_ggsw_compressed_encrypt_sk(module: &Module) @@ -116,7 +126,7 @@ where + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -148,16 +158,24 @@ where + VecZnxBigAllocBytesImpl + TakeSvpPPolImpl, { - let basek: usize = 12; + let base2k: usize = 12; let k: usize = 54; - let digits: usize = k / basek; - (1..3).for_each(|rank| { - (1..digits + 1).for_each(|di| { + let digits: usize = k / base2k; + for rank in 1_usize..3 { + for di in 1..digits + 1 { let n: usize = module.n(); - let rows: usize = (k - di * basek) / (di * basek); + let rows: usize = (k - di * base2k) / (di * base2k); - let mut ct_compressed: GGSWCiphertextCompressed> = - GGSWCiphertextCompressed::alloc(n, basek, k, rows, di, rank); + let ggsw_infos: GGSWCiphertextLayout = GGSWCiphertextLayout { + n: n.into(), + base2k: base2k.into(), + k: k.into(), + rows: rows.into(), + digits: di.into(), + rank: rank.into(), + }; + + let mut ct_compressed: GGSWCiphertextCompressed> = GGSWCiphertextCompressed::alloc(&ggsw_infos); let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); @@ -167,10 +185,11 @@ where pt_scalar.fill_ternary_hw(0, n, &mut source_xs); let mut scratch: ScratchOwned = ScratchOwned::alloc(GGSWCiphertextCompressed::encrypt_sk_scratch_space( - module, basek, k, rank, + module, + &ggsw_infos, )); - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(&ggsw_infos); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); @@ -187,10 +206,10 @@ where let noise_f = |_col_i: usize| -(k as f64) + SIGMA.log2() + 0.5; - let mut ct: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k, rows, di, rank); + let mut ct: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_infos); ct.decompress(module, &ct_compressed); ct.assert_noise(module, &sk_prepared, &pt_scalar, noise_f); - }); - }); + } + } } diff --git a/poulpy-core/src/tests/test_suite/encryption/glwe_ct.rs b/poulpy-core/src/tests/test_suite/encryption/glwe_ct.rs index 9f77baa..a1169f6 100644 --- a/poulpy-core/src/tests/test_suite/encryption/glwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/encryption/glwe_ct.rs @@ -4,7 +4,7 
@@ use poulpy_hal::{ SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxBigAddInplace, VecZnxBigAddNormal, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxCopy, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, - VecZnxSubABInplace, + VecZnxSubInplace, }, layouts::{Backend, Module, ScratchOwned}, oep::{ @@ -17,7 +17,7 @@ use poulpy_hal::{ use crate::{ encryption::SIGMA, layouts::{ - GLWECiphertext, GLWEPlaintext, GLWEPublicKey, GLWESecret, Infos, + GLWECiphertext, GLWECiphertextLayout, GLWEPlaintext, GLWEPlaintextLayout, GLWEPublicKey, GLWESecret, LWEInfos, compressed::{Decompress, GLWECiphertextCompressed}, prepared::{GLWEPublicKeyPrepared, GLWESecretPrepared, PrepareAlloc}, }, @@ -38,24 +38,10 @@ where + SvpPrepare + SvpPPolAllocBytes + SvpPPolAlloc - + VecZnxDftAllocBytes - + VecZnxBigAllocBytes - + SvpPPolAllocBytes - + SvpPrepare + SvpApplyDftToDft - + VecZnxIdftApplyConsume + VecZnxBigAddNormal - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalizeTmpBytes - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -71,30 +57,44 @@ where + TakeScalarZnxImpl + TakeVecZnxImpl, { - let basek: usize = 8; + let base2k: usize = 8; let k_ct: usize = 54; let k_pt: usize = 30; - for rank in 1..3 { - let n = module.n(); - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_pt); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_pt); + for rank in 1_usize..3 { + let n: usize = module.n(); + + let glwe_infos: GLWECiphertextLayout = GLWECiphertextLayout { + n: n.into(), + base2k: base2k.into(), + k: k_ct.into(), + rank: rank.into(), + }; + + let pt_infos: GLWEPlaintextLayout = GLWEPlaintextLayout { + n: n.into(), + base2k: base2k.into(), + k: k_pt.into(), + }; + + let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&glwe_infos); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&pt_infos); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&pt_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct.k()) - | GLWECiphertext::decrypt_scratch_space(module, basek, ct.k()), + GLWECiphertext::encrypt_sk_scratch_space(module, &glwe_infos) + | GLWECiphertext::decrypt_scratch_space(module, &glwe_infos), ); - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(&glwe_infos); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); - module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, &mut source_xa); + module.vec_znx_fill_uniform(base2k, &mut pt_want.data, 0, &mut source_xa); ct.encrypt_sk( module, @@ -109,7 +109,7 @@ where pt_want.sub_inplace_ab(module, &pt_have); - let noise_have: f64 = pt_want.data.std(basek, 0) * (ct.k() as f64).exp2(); + let noise_have: f64 = pt_want.data.std(base2k, 0) * (ct.k().as_u32() as f64).exp2(); let noise_want: f64 = SIGMA; 
assert!(noise_have <= noise_want + 0.2); @@ -130,24 +130,10 @@ where + SvpPrepare + SvpPPolAllocBytes + SvpPPolAlloc - + VecZnxDftAllocBytes - + VecZnxBigAllocBytes - + SvpPPolAllocBytes - + SvpPrepare + SvpApplyDftToDft - + VecZnxIdftApplyConsume + VecZnxBigAddNormal - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalizeTmpBytes - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -164,31 +150,45 @@ where + TakeScalarZnxImpl + TakeVecZnxImpl, { - let basek: usize = 8; + let base2k: usize = 8; let k_ct: usize = 54; let k_pt: usize = 30; - for rank in 1..3 { - let n = module.n(); - let mut ct_compressed: GLWECiphertextCompressed> = GLWECiphertextCompressed::alloc(n, basek, k_ct, rank); + for rank in 1_usize..3 { + let n: usize = module.n(); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_pt); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_pt); + let glwe_infos: GLWECiphertextLayout = GLWECiphertextLayout { + n: n.into(), + base2k: base2k.into(), + k: k_ct.into(), + rank: rank.into(), + }; + + let pt_infos: GLWEPlaintextLayout = GLWEPlaintextLayout { + n: n.into(), + base2k: base2k.into(), + k: k_pt.into(), + }; + + let mut ct_compressed: GLWECiphertextCompressed> = GLWECiphertextCompressed::alloc(&glwe_infos); + + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&pt_infos); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&pt_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertextCompressed::encrypt_sk_scratch_space(module, basek, k_ct) - | GLWECiphertext::decrypt_scratch_space(module, basek, k_ct), + GLWECiphertextCompressed::encrypt_sk_scratch_space(module, &glwe_infos) + | GLWECiphertext::decrypt_scratch_space(module, &glwe_infos), ); - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(&glwe_infos); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); - module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, &mut source_xa); + module.vec_znx_fill_uniform(base2k, &mut pt_want.data, 0, &mut source_xa); let seed_xa: [u8; 32] = [1u8; 32]; @@ -201,20 +201,19 @@ where scratch.borrow(), ); - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct, rank); + let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&glwe_infos); ct.decompress(module, &ct_compressed); ct.decrypt(module, &mut pt_have, &sk_prepared, scratch.borrow()); pt_want.sub_inplace_ab(module, &pt_have); - let noise_have: f64 = pt_want.data.std(basek, 0) * (ct.k() as f64).exp2(); + let noise_have: f64 = pt_want.data.std(base2k, 0) * (ct.k().as_u32() as f64).exp2(); let noise_want: f64 = SIGMA; assert!( noise_have <= noise_want + 0.2, - "{} <= {}", - noise_have, + "{noise_have} <= {}", noise_want + 0.2 ); } @@ -234,24 +233,10 @@ where + SvpPrepare + SvpPPolAllocBytes + SvpPPolAlloc - + VecZnxDftAllocBytes - + VecZnxBigAllocBytes - + SvpPPolAllocBytes - + SvpPrepare + SvpApplyDftToDft - + VecZnxIdftApplyConsume + VecZnxBigAddNormal - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalizeTmpBytes - + 
VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -267,27 +252,35 @@ where + TakeScalarZnxImpl + TakeVecZnxImpl, { - let basek: usize = 8; + let base2k: usize = 8; let k_ct: usize = 54; - for rank in 1..3 { - let n = module.n(); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_ct); + for rank in 1_usize..3 { + let n: usize = module.n(); + + let glwe_infos: GLWECiphertextLayout = GLWECiphertextLayout { + n: n.into(), + base2k: base2k.into(), + k: k_ct.into(), + rank: rank.into(), + }; + + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([1u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertext::decrypt_scratch_space(module, basek, k_ct) - | GLWECiphertext::encrypt_sk_scratch_space(module, basek, k_ct), + GLWECiphertext::decrypt_scratch_space(module, &glwe_infos) + | GLWECiphertext::encrypt_sk_scratch_space(module, &glwe_infos), ); - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(&glwe_infos); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct, rank); + let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&glwe_infos); ct.encrypt_zero_sk( module, @@ -298,7 +291,7 @@ where ); ct.decrypt(module, &mut pt, &sk_prepared, scratch.borrow()); - assert!((SIGMA - pt.data.std(basek, 0) * (k_ct as f64).exp2()) <= 0.2); + assert!((SIGMA - pt.data.std(base2k, 0) * (k_ct as f64).exp2()) <= 0.2); } } @@ -311,7 +304,7 @@ where + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -337,15 +330,22 @@ where + TakeScalarZnxImpl + TakeVecZnxImpl, { - let basek: usize = 8; + let base2k: usize = 8; let k_ct: usize = 54; - let k_pk: usize = 54; - for rank in 1..3 { + for rank in 1_usize..3 { let n: usize = module.n(); - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct, rank); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_ct); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_ct); + + let glwe_infos: GLWECiphertextLayout = GLWECiphertextLayout { + n: n.into(), + base2k: base2k.into(), + k: k_ct.into(), + rank: rank.into(), + }; + + let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&glwe_infos); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_infos); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -353,19 +353,19 @@ where let mut source_xu: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct.k()) - | GLWECiphertext::decrypt_scratch_space(module, basek, ct.k()) - | GLWECiphertext::encrypt_pk_scratch_space(module, basek, k_pk), + GLWECiphertext::encrypt_sk_scratch_space(module, &glwe_infos) + | GLWECiphertext::decrypt_scratch_space(module, &glwe_infos) + | 
GLWECiphertext::encrypt_pk_scratch_space(module, &glwe_infos), ); - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(&glwe_infos); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); - let mut pk: GLWEPublicKey> = GLWEPublicKey::alloc(n, basek, k_pk, rank); + let mut pk: GLWEPublicKey> = GLWEPublicKey::alloc(&glwe_infos); pk.generate_from_sk(module, &sk_prepared, &mut source_xa, &mut source_xe); - module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, &mut source_xa); + module.vec_znx_fill_uniform(base2k, &mut pt_want.data, 0, &mut source_xa); let pk_prepared: GLWEPublicKeyPrepared, B> = pk.prepare_alloc(module, scratch.borrow()); @@ -382,14 +382,13 @@ where pt_want.sub_inplace_ab(module, &pt_have); - let noise_have: f64 = pt_want.data.std(basek, 0).log2(); + let noise_have: f64 = pt_want.data.std(base2k, 0).log2(); let noise_want: f64 = ((((rank as f64) + 1.0) * n as f64 * 0.5 * SIGMA * SIGMA).sqrt()).log2() - (k_ct as f64); assert!( noise_have <= noise_want + 0.2, - "{} {}", - noise_have, - noise_want + "{noise_have} <= {}", + noise_want + 0.2 ); } } diff --git a/poulpy-core/src/tests/test_suite/encryption/glwe_tsk.rs b/poulpy-core/src/tests/test_suite/encryption/glwe_tsk.rs index d653b17..d373278 100644 --- a/poulpy-core/src/tests/test_suite/encryption/glwe_tsk.rs +++ b/poulpy-core/src/tests/test_suite/encryption/glwe_tsk.rs @@ -4,7 +4,7 @@ use poulpy_hal::{ SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxCopy, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSubScalarInplace, VecZnxSwitchRing, + VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSubScalarInplace, VecZnxSwitchRing, }, layouts::{Backend, Module, ScratchOwned, VecZnxDft}, oep::{ @@ -17,7 +17,7 @@ use poulpy_hal::{ use crate::{ encryption::SIGMA, layouts::{ - GGLWETensorKey, GLWEPlaintext, GLWESecret, Infos, + Digits, GGLWETensorKey, GGLWETensorKeyLayout, GLWEPlaintext, GLWESecret, compressed::{Decompress, GGLWETensorKeyCompressed}, prepared::{GLWESecretPrepared, PrepareAlloc}, }, @@ -32,7 +32,7 @@ where + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -64,14 +64,23 @@ where + VecZnxBigAllocBytesImpl + TakeSvpPPolImpl, { - let basek: usize = 8; + let base2k: usize = 8; let k: usize = 54; - (1..3).for_each(|rank| { + for rank in 1_usize..3 { let n: usize = module.n(); - let rows: usize = k / basek; + let rows: usize = k / base2k; - let mut tensor_key: GGLWETensorKey> = GGLWETensorKey::alloc(n, basek, k, rows, 1, rank); + let tensor_key_infos = GGLWETensorKeyLayout { + n: n.into(), + base2k: base2k.into(), + k: k.into(), + rows: rows.into(), + digits: Digits(1), + rank: rank.into(), + }; + + let mut tensor_key: GGLWETensorKey> = GGLWETensorKey::alloc(&tensor_key_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -79,12 +88,10 @@ where let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWETensorKey::encrypt_sk_scratch_space( module, - basek, - tensor_key.k(), - rank, + &tensor_key_infos, )); - 
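// A minimal sketch of why the scratch allocations in these tests fold several
// requirements with `|`: assuming the `*_scratch_space` helpers return plain
// byte counts (as the `tmp_bytes` naming elsewhere in the crate suggests),
// bitwise-or of two usizes is always >= the larger one, so one buffer sized
// this way is big enough for whichever operation runs. It may over-allocate,
// but it can never under-allocate.
fn main() {
    let sizes: [(usize, usize); 3] = [(1 << 10, 3 << 9), (4096, 4096), (24, 1000)];
    for (a, b) in sizes {
        let combined = a | b; // same operator as the ScratchOwned::alloc(..) call sites
        assert!(combined >= a.max(b)); // or-ing sets a superset of either value's bits
        println!("{a} | {b} = {combined} (max = {})", a.max(b));
    }
}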
let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(&tensor_key_infos); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); @@ -96,45 +103,44 @@ where scratch.borrow(), ); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&tensor_key_infos); let mut sk_ij_dft = module.vec_znx_dft_alloc(1, 1); let mut sk_ij_big = module.vec_znx_big_alloc(1, 1); - let mut sk_ij: GLWESecret> = GLWESecret::alloc(n, 1); + let mut sk_ij: GLWESecret> = GLWESecret::alloc_with(n.into(), 1_u32.into()); let mut sk_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(rank, 1); - (0..rank).for_each(|i| { + for i in 0..rank { module.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i); - }); + } - (0..rank).for_each(|i| { - (0..rank).for_each(|j| { + for i in 0..rank { + for j in 0..rank { module.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_prepared.data, j, &sk_dft, i); module.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0); module.vec_znx_big_normalize( - basek, + base2k, &mut sk_ij.data.as_vec_znx_mut(), 0, + base2k, &sk_ij_big, 0, scratch.borrow(), ); - (0..tensor_key.rank_in()).for_each(|col_i| { - (0..tensor_key.rows()).for_each(|row_i| { - tensor_key - .at(i, j) - .at(row_i, col_i) - .decrypt(module, &mut pt, &sk_prepared, scratch.borrow()); + for row_i in 0..rows { + tensor_key + .at(i, j) + .at(row_i, 0) + .decrypt(module, &mut pt, &sk_prepared, scratch.borrow()); - module.vec_znx_sub_scalar_inplace(&mut pt.data, 0, row_i, &sk_ij.data, col_i); + module.vec_znx_sub_scalar_inplace(&mut pt.data, 0, row_i, &sk_ij.data, 0); - let std_pt: f64 = pt.data.std(basek, 0) * (k as f64).exp2(); - assert!((SIGMA - std_pt).abs() <= 0.5, "{} {}", SIGMA, std_pt); - }); - }); - }); - }); - }); + let std_pt: f64 = pt.data.std(base2k, 0) * (k as f64).exp2(); + assert!((SIGMA - std_pt).abs() <= 0.5, "{SIGMA} {std_pt}"); + } + } + } + } } pub fn test_gglwe_tensor_key_compressed_encrypt_sk(module: &Module) @@ -146,7 +152,7 @@ where + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -178,26 +184,32 @@ where + VecZnxBigAllocBytesImpl + TakeSvpPPolImpl, { - let basek = 8; + let base2k = 8; let k = 54; - (1..3).for_each(|rank| { + for rank in 1_usize..3 { let n: usize = module.n(); - let rows: usize = k / basek; + let rows: usize = k / base2k; - let mut tensor_key_compressed: GGLWETensorKeyCompressed> = - GGLWETensorKeyCompressed::alloc(n, basek, k, rows, 1, rank); + let tensor_key_infos: GGLWETensorKeyLayout = GGLWETensorKeyLayout { + n: n.into(), + base2k: base2k.into(), + k: k.into(), + rows: rows.into(), + digits: Digits(1), + rank: rank.into(), + }; + + let mut tensor_key_compressed: GGLWETensorKeyCompressed> = GGLWETensorKeyCompressed::alloc(&tensor_key_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWETensorKeyCompressed::encrypt_sk_scratch_space( module, - basek, - tensor_key_compressed.k(), - rank, + &tensor_key_infos, )); - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(&tensor_key_infos); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); @@ 
-205,46 +217,45 @@ where tensor_key_compressed.encrypt_sk(module, &sk, seed_xa, &mut source_xe, scratch.borrow()); - let mut tensor_key: GGLWETensorKey> = GGLWETensorKey::alloc(n, basek, k, rows, 1, rank); + let mut tensor_key: GGLWETensorKey> = GGLWETensorKey::alloc(&tensor_key_infos); tensor_key.decompress(module, &tensor_key_compressed); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&tensor_key_infos); let mut sk_ij_dft = module.vec_znx_dft_alloc(1, 1); let mut sk_ij_big = module.vec_znx_big_alloc(1, 1); - let mut sk_ij: GLWESecret> = GLWESecret::alloc(n, 1); + let mut sk_ij: GLWESecret> = GLWESecret::alloc_with(n.into(), 1_u32.into()); let mut sk_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(rank, 1); - (0..rank).for_each(|i| { + for i in 0..rank { module.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i); - }); + } - (0..rank).for_each(|i| { - (0..rank).for_each(|j| { + for i in 0..rank { + for j in 0..rank { module.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_prepared.data, j, &sk_dft, i); module.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0); module.vec_znx_big_normalize( - basek, + base2k, &mut sk_ij.data.as_vec_znx_mut(), 0, + base2k, &sk_ij_big, 0, scratch.borrow(), ); - (0..tensor_key.rank_in()).for_each(|col_i| { - (0..tensor_key.rows()).for_each(|row_i| { - tensor_key - .at(i, j) - .at(row_i, col_i) - .decrypt(module, &mut pt, &sk_prepared, scratch.borrow()); + for row_i in 0..rows { + tensor_key + .at(i, j) + .at(row_i, 0) + .decrypt(module, &mut pt, &sk_prepared, scratch.borrow()); - module.vec_znx_sub_scalar_inplace(&mut pt.data, 0, row_i, &sk_ij.data, col_i); + module.vec_znx_sub_scalar_inplace(&mut pt.data, 0, row_i, &sk_ij.data, 0); - let std_pt: f64 = pt.data.std(basek, 0) * (k as f64).exp2(); - assert!((SIGMA - std_pt).abs() <= 0.5, "{} {}", SIGMA, std_pt); - }); - }); - }); - }); - }); + let std_pt: f64 = pt.data.std(base2k, 0) * (k as f64).exp2(); + assert!((SIGMA - std_pt).abs() <= 0.5, "{SIGMA} {std_pt}"); + } + } + } + } } diff --git a/poulpy-core/src/tests/test_suite/external_product/gglwe_ksk.rs b/poulpy-core/src/tests/test_suite/external_product/gglwe_ksk.rs index 86f5c28..b2f8f91 100644 --- a/poulpy-core/src/tests/test_suite/external_product/gglwe_ksk.rs +++ b/poulpy-core/src/tests/test_suite/external_product/gglwe_ksk.rs @@ -4,7 +4,7 @@ use poulpy_hal::{ VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotateInplace, VecZnxSub, - VecZnxSubABInplace, VecZnxSubScalarInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, + VecZnxSubInplace, VecZnxSubScalarInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, }, layouts::{Backend, Module, ScalarZnx, ScalarZnxToMut, ScratchOwned, ZnxViewMut}, @@ -18,7 +18,7 @@ use poulpy_hal::{ use crate::{ encryption::SIGMA, layouts::{ - GGLWESwitchingKey, GGSWCiphertext, GLWESecret, + GGLWESwitchingKey, GGLWESwitchingKeyLayout, GGSWCiphertext, GGSWCiphertextLayout, GLWESecret, prepared::{GGSWCiphertextPrepared, GLWESecretPrepared, PrepareAlloc}, }, noise::noise_ggsw_product, @@ -34,7 +34,7 @@ where + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform - + VecZnxSubABInplace + + 
VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -68,24 +68,51 @@ where + VecZnxBigAllocBytesImpl + TakeSvpPPolImpl, { - let basek: usize = 12; + let base2k: usize = 12; let k_in: usize = 60; - let digits: usize = k_in.div_ceil(basek); - (1..3).for_each(|rank_in| { - (1..3).for_each(|rank_out| { - (1..digits + 1).for_each(|di| { - let k_ggsw: usize = k_in + basek * di; + let digits: usize = k_in.div_ceil(base2k); + for rank_in in 1_usize..3 { + for rank_out in 1_usize..3 { + for di in 1_usize..digits + 1 { + let k_ggsw: usize = k_in + base2k * di; let k_out: usize = k_in; // Better capture noise. let n: usize = module.n(); - let rows: usize = k_in.div_ceil(basek * di); + let rows: usize = k_in.div_ceil(base2k * di); let digits_in: usize = 1; - let mut ct_gglwe_in: GGLWESwitchingKey> = - GGLWESwitchingKey::alloc(n, basek, k_in, rows, digits_in, rank_in, rank_out); - let mut ct_gglwe_out: GGLWESwitchingKey> = - GGLWESwitchingKey::alloc(n, basek, k_out, rows, digits_in, rank_in, rank_out); - let mut ct_rgsw: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ggsw, rows, di, rank_out); + let gglwe_in_infos: GGLWESwitchingKeyLayout = GGLWESwitchingKeyLayout { + n: n.into(), + base2k: base2k.into(), + k: k_in.into(), + rows: rows.into(), + digits: digits_in.into(), + rank_in: rank_in.into(), + rank_out: rank_out.into(), + }; + + let gglwe_out_infos: GGLWESwitchingKeyLayout = GGLWESwitchingKeyLayout { + n: n.into(), + base2k: base2k.into(), + k: k_out.into(), + rows: rows.into(), + digits: digits_in.into(), + rank_in: rank_in.into(), + rank_out: rank_out.into(), + }; + + let ggsw_infos: GGSWCiphertextLayout = GGSWCiphertextLayout { + n: n.into(), + base2k: base2k.into(), + k: k_ggsw.into(), + rows: rows.into(), + digits: di.into(), + rank: rank_out.into(), + }; + + let mut ct_gglwe_in: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&gglwe_in_infos); + let mut ct_gglwe_out: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&gglwe_out_infos); + let mut ct_rgsw: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_infos); let mut pt_rgsw: ScalarZnx> = ScalarZnx::alloc(n, 1); @@ -94,9 +121,14 @@ where let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k_in, rank_in, rank_out) - | GGLWESwitchingKey::external_product_scratch_space(module, basek, k_out, k_in, k_ggsw, di, rank_out) - | GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_ggsw, rank_out), + GGLWESwitchingKey::encrypt_sk_scratch_space(module, &gglwe_in_infos) + | GGLWESwitchingKey::external_product_scratch_space( + module, + &gglwe_out_infos, + &gglwe_in_infos, + &ggsw_infos, + ) + | GGSWCiphertext::encrypt_sk_scratch_space(module, &ggsw_infos), ); let r: usize = 1; @@ -105,10 +137,10 @@ where let var_xs: f64 = 0.5; - let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank_in); + let mut sk_in: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_in.into()); sk_in.fill_ternary_prob(var_xs, &mut source_xs); - let mut sk_out: GLWESecret> = GLWESecret::alloc(n, rank_out); + let mut sk_out: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_out.into()); sk_out.fill_ternary_prob(var_xs, &mut source_xs); let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); @@ -154,7 +186,7 @@ where let max_noise: f64 = noise_ggsw_product( n as f64, - basek * di, + base2k * di, var_xs, var_msg, var_a0_err, @@ -169,9 +201,9 @@ where ct_gglwe_out .key .assert_noise(module, 
&sk_out_prepared, &sk_in.data, max_noise + 0.5); - }); - }); - }); + } + } + } } #[allow(clippy::too_many_arguments)] @@ -184,7 +216,7 @@ where + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -218,22 +250,40 @@ where + VecZnxBigAllocBytesImpl + TakeSvpPPolImpl, { - let basek: usize = 12; - let k_ct: usize = 60; - let digits: usize = k_ct.div_ceil(basek); - (1..3).for_each(|rank_in| { - (1..3).for_each(|rank_out| { - (1..digits).for_each(|di| { - let k_ggsw: usize = k_ct + basek * di; + let base2k: usize = 12; + let k_out: usize = 60; + let digits: usize = k_out.div_ceil(base2k); + for rank_in in 1_usize..3 { + for rank_out in 1_usize..3 { + for di in 1_usize..digits + 1 { + let k_ggsw: usize = k_out + base2k * di; let n: usize = module.n(); - let rows: usize = k_ct.div_ceil(basek * di); + let rows: usize = k_out.div_ceil(base2k * di); let digits_in: usize = 1; - let mut ct_gglwe: GGLWESwitchingKey> = - GGLWESwitchingKey::alloc(n, basek, k_ct, rows, digits_in, rank_in, rank_out); - let mut ct_rgsw: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ggsw, rows, di, rank_out); + let gglwe_out_infos: GGLWESwitchingKeyLayout = GGLWESwitchingKeyLayout { + n: n.into(), + base2k: base2k.into(), + k: k_out.into(), + rows: rows.into(), + digits: digits_in.into(), + rank_in: rank_in.into(), + rank_out: rank_out.into(), + }; + + let ggsw_infos: GGSWCiphertextLayout = GGSWCiphertextLayout { + n: n.into(), + base2k: base2k.into(), + k: k_ggsw.into(), + rows: rows.into(), + digits: di.into(), + rank: rank_out.into(), + }; + + let mut ct_gglwe: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&gglwe_out_infos); + let mut ct_rgsw: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_infos); let mut pt_rgsw: ScalarZnx> = ScalarZnx::alloc(n, 1); @@ -242,9 +292,9 @@ where let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k_ct, rank_in, rank_out) - | GGLWESwitchingKey::external_product_inplace_scratch_space(module, basek, k_ct, k_ggsw, di, rank_out) - | GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_ggsw, rank_out), + GGLWESwitchingKey::encrypt_sk_scratch_space(module, &gglwe_out_infos) + | GGLWESwitchingKey::external_product_inplace_scratch_space(module, &gglwe_out_infos, &ggsw_infos) + | GGSWCiphertext::encrypt_sk_scratch_space(module, &ggsw_infos), ); let r: usize = 1; @@ -253,10 +303,10 @@ where let var_xs: f64 = 0.5; - let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank_in); + let mut sk_in: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_in.into()); sk_in.fill_ternary_prob(var_xs, &mut source_xs); - let mut sk_out: GLWESecret> = GLWESecret::alloc(n, rank_out); + let mut sk_out: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_out.into()); sk_out.fill_ternary_prob(var_xs, &mut source_xs); let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); @@ -302,7 +352,7 @@ where let max_noise: f64 = noise_ggsw_product( n as f64, - basek * di, + base2k * di, var_xs, var_msg, var_a0_err, @@ -310,14 +360,14 @@ where var_gct_err_lhs, var_gct_err_rhs, rank_out as f64, - k_ct, + k_out, k_ggsw, ); ct_gglwe .key .assert_noise(module, &sk_out_prepared, &sk_in.data, max_noise + 0.5); - }); - }); - }); + } + } + } } diff --git a/poulpy-core/src/tests/test_suite/external_product/ggsw_ct.rs 
b/poulpy-core/src/tests/test_suite/external_product/ggsw_ct.rs index ceb95f9..bf5ceb5 100644 --- a/poulpy-core/src/tests/test_suite/external_product/ggsw_ct.rs +++ b/poulpy-core/src/tests/test_suite/external_product/ggsw_ct.rs @@ -4,7 +4,7 @@ use poulpy_hal::{ VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxRotateInplace, VecZnxSub, VecZnxSubABInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd, + VecZnxNormalizeTmpBytes, VecZnxRotateInplace, VecZnxSub, VecZnxSubInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, }, layouts::{Backend, Module, ScalarZnx, ScalarZnxToMut, ScratchOwned, ZnxViewMut}, @@ -18,7 +18,7 @@ use poulpy_hal::{ use crate::{ encryption::SIGMA, layouts::{ - GGSWCiphertext, GLWESecret, + GGSWCiphertext, GGSWCiphertextLayout, GLWESecret, prepared::{GGSWCiphertextPrepared, GLWESecretPrepared, PrepareAlloc}, }, noise::noise_ggsw_product, @@ -34,7 +34,7 @@ where + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -70,73 +70,96 @@ where + VecZnxBigAllocBytesImpl + TakeSvpPPolImpl, { - let basek: usize = 12; + let base2k: usize = 12; let k_in: usize = 60; - let digits: usize = k_in.div_ceil(basek); - (1..3).for_each(|rank| { - (1..digits + 1).for_each(|di| { - let k_ggsw: usize = k_in + basek * di; + let digits: usize = k_in.div_ceil(base2k); + for rank in 1_usize..3 { + for di in 1..digits + 1 { + let k_apply: usize = k_in + base2k * di; let k_out: usize = k_in; // Better capture noise. 
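// The parameter grid these GGSW external-product tests sweep, reproduced with
// the patch's own arithmetic (base2k = 12, k_in = 60). A standalone sketch;
// the printed grid is illustrative, not an output of the test suite.
fn main() {
    let base2k: usize = 12;
    let k_in: usize = 60;
    let digits = k_in.div_ceil(base2k); // 5 digits cover all 60 bits
    for di in 1..=digits {
        // One gadget digit spans base2k * di bits, so ceil-division counts the
        // rows needed to cover all k_in bits, while floor-division
        // (div_euclid) keeps only the fully usable rows of the input.
        let rows = k_in.div_ceil(base2k * di);
        let rows_in = k_in.div_euclid(base2k * di);
        // The operand being applied gets one extra digit of precision:
        let k_apply = k_in + base2k * di;
        println!("di={di}: rows={rows}, rows_in={rows_in}, k_apply={k_apply}");
    }
}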
let n: usize = module.n(); - let rows: usize = k_in.div_ceil(basek * di); - let rows_in: usize = k_in.div_euclid(basek * di); + let rows: usize = k_in.div_ceil(base2k * di); + let rows_in: usize = k_in.div_euclid(base2k * di); let digits_in: usize = 1; - let mut ct_ggsw_lhs_in: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_in, rows_in, digits_in, rank); - let mut ct_ggsw_lhs_out: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_out, rows_in, digits_in, rank); - let mut ct_ggsw_rhs: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ggsw, rows, di, rank); - let mut pt_ggsw_lhs: ScalarZnx> = ScalarZnx::alloc(n, 1); - let mut pt_ggsw_rhs: ScalarZnx> = ScalarZnx::alloc(n, 1); + let ggsw_in_infos: GGSWCiphertextLayout = GGSWCiphertextLayout { + n: n.into(), + base2k: base2k.into(), + k: k_in.into(), + rows: rows_in.into(), + digits: digits_in.into(), + rank: rank.into(), + }; + + let ggsw_out_infos: GGSWCiphertextLayout = GGSWCiphertextLayout { + n: n.into(), + base2k: base2k.into(), + k: k_out.into(), + rows: rows_in.into(), + digits: digits_in.into(), + rank: rank.into(), + }; + + let ggsw_apply_infos: GGSWCiphertextLayout = GGSWCiphertextLayout { + n: n.into(), + base2k: base2k.into(), + k: k_apply.into(), + rows: rows.into(), + digits: di.into(), + rank: rank.into(), + }; + + let mut ggsw_in: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_in_infos); + let mut ggsw_out: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_out_infos); + let mut ggsw_apply: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_apply_infos); + let mut pt_in: ScalarZnx> = ScalarZnx::alloc(n, 1); + let mut pt_apply: ScalarZnx> = ScalarZnx::alloc(n, 1); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - pt_ggsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs); + pt_in.fill_ternary_prob(0, 0.5, &mut source_xs); let k: usize = 1; - pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} + pt_apply.to_mut().raw_mut()[k] = 1; //X^{k} let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_ggsw, rank) - | GGSWCiphertext::external_product_scratch_space(module, basek, k_out, k_in, k_ggsw, di, rank), + GGSWCiphertext::encrypt_sk_scratch_space(module, &ggsw_apply_infos) + | GGSWCiphertext::encrypt_sk_scratch_space(module, &ggsw_in_infos) + | GGSWCiphertext::external_product_scratch_space(module, &ggsw_out_infos, &ggsw_in_infos, &ggsw_apply_infos), ); - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + let mut sk: GLWESecret> = GLWESecret::alloc_with(n.into(), rank.into()); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); - ct_ggsw_rhs.encrypt_sk( + ggsw_apply.encrypt_sk( module, - &pt_ggsw_rhs, + &pt_apply, &sk_prepared, &mut source_xa, &mut source_xe, scratch.borrow(), ); - ct_ggsw_lhs_in.encrypt_sk( + ggsw_in.encrypt_sk( module, - &pt_ggsw_lhs, + &pt_in, &sk_prepared, &mut source_xa, &mut source_xe, scratch.borrow(), ); - let ct_rhs_prepared: GGSWCiphertextPrepared, B> = ct_ggsw_rhs.prepare_alloc(module, scratch.borrow()); + let ct_rhs_prepared: GGSWCiphertextPrepared, B> = ggsw_apply.prepare_alloc(module, scratch.borrow()); - ct_ggsw_lhs_out.external_product(module, &ct_ggsw_lhs_in, &ct_rhs_prepared, scratch.borrow()); + ggsw_out.external_product(module, &ggsw_in, &ct_rhs_prepared, scratch.borrow()); - module.vec_znx_rotate_inplace( - k as i64, - &mut pt_ggsw_lhs.as_vec_znx_mut(), - 0, - 
scratch.borrow(), - ); + module.vec_znx_rotate_inplace(k as i64, &mut pt_in.as_vec_znx_mut(), 0, scratch.borrow()); let var_gct_err_lhs: f64 = SIGMA * SIGMA; let var_gct_err_rhs: f64 = 0f64; @@ -148,7 +171,7 @@ where let max_noise = |_col_j: usize| -> f64 { noise_ggsw_product( n as f64, - basek * di, + base2k * di, 0.5, var_msg, var_a0_err, @@ -157,13 +180,13 @@ where var_gct_err_rhs, rank as f64, k_in, - k_ggsw, + k_apply, ) + 0.5 }; - ct_ggsw_lhs_out.assert_noise(module, &sk_prepared, &pt_ggsw_lhs, max_noise); - }); - }); + ggsw_out.assert_noise(module, &sk_prepared, &pt_in, max_noise); + } + } } #[allow(clippy::too_many_arguments)] @@ -176,7 +199,7 @@ where + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -212,71 +235,85 @@ where + VecZnxBigAllocBytesImpl + TakeSvpPPolImpl, { - let basek: usize = 12; - let k_ct: usize = 60; - let digits: usize = k_ct.div_ceil(basek); - (1..3).for_each(|rank| { - (1..digits).for_each(|di| { - let k_ggsw: usize = k_ct + basek * di; + let base2k: usize = 12; + let k_out: usize = 60; + let digits: usize = k_out.div_ceil(base2k); + for rank in 1_usize..3 { + for di in 1..digits + 1 { + let k_apply: usize = k_out + base2k * di; let n: usize = module.n(); - let rows: usize = k_ct.div_ceil(di * basek); - let rows_in: usize = k_ct.div_euclid(basek * di); + let rows: usize = k_out.div_ceil(di * base2k); + let rows_in: usize = k_out.div_euclid(base2k * di); let digits_in: usize = 1; - let mut ct_ggsw_lhs: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ct, rows_in, digits_in, rank); - let mut ct_ggsw_rhs: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ggsw, rows, di, rank); + let ggsw_out_infos: GGSWCiphertextLayout = GGSWCiphertextLayout { + n: n.into(), + base2k: base2k.into(), + k: k_out.into(), + rows: rows_in.into(), + digits: digits_in.into(), + rank: rank.into(), + }; - let mut pt_ggsw_lhs: ScalarZnx> = ScalarZnx::alloc(n, 1); - let mut pt_ggsw_rhs: ScalarZnx> = ScalarZnx::alloc(n, 1); + let ggsw_apply_infos: GGSWCiphertextLayout = GGSWCiphertextLayout { + n: n.into(), + base2k: base2k.into(), + k: k_apply.into(), + rows: rows.into(), + digits: di.into(), + rank: rank.into(), + }; + + let mut ggsw_out: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_out_infos); + let mut ggsw_apply: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_apply_infos); + + let mut pt_in: ScalarZnx> = ScalarZnx::alloc(n, 1); + let mut pt_apply: ScalarZnx> = ScalarZnx::alloc(n, 1); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - pt_ggsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs); + pt_in.fill_ternary_prob(0, 0.5, &mut source_xs); let k: usize = 1; - pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} + pt_apply.to_mut().raw_mut()[k] = 1; //X^{k} let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_ggsw, rank) - | GGSWCiphertext::external_product_inplace_scratch_space(module, basek, k_ct, k_ggsw, di, rank), + GGSWCiphertext::encrypt_sk_scratch_space(module, &ggsw_apply_infos) + | GGSWCiphertext::encrypt_sk_scratch_space(module, &ggsw_out_infos) + | GGSWCiphertext::external_product_inplace_scratch_space(module, &ggsw_out_infos, &ggsw_apply_infos), ); - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + let mut sk: GLWESecret> = GLWESecret::alloc_with(n.into(), rank.into()); 
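// A standalone mirror of the allocation pattern this patch migrates to:
// positional `alloc(n, basek, k, ...)` calls become layout structs filled via
// `.into()`, and secrets use `alloc_with(n.into(), rank.into())`. The wrapper
// names below (Degree, Base2K, TorusPrecision, Rank) are hypothetical
// stand-ins; only the shape of the API is meant to match poulpy's.
#[derive(Clone, Copy, Debug)]
struct Degree(u32);
#[derive(Clone, Copy, Debug)]
struct Base2K(u32);
#[derive(Clone, Copy, Debug)]
struct TorusPrecision(u32);
#[derive(Clone, Copy, Debug)]
struct Rank(u32);

impl From<usize> for Degree { fn from(v: usize) -> Self { Degree(v as u32) } }
impl From<usize> for Base2K { fn from(v: usize) -> Self { Base2K(v as u32) } }
impl From<usize> for TorusPrecision { fn from(v: usize) -> Self { TorusPrecision(v as u32) } }
impl From<usize> for Rank { fn from(v: usize) -> Self { Rank(v as u32) } }

// Stand-in for GGSWCiphertextLayout { n, base2k, k, rows, digits, rank },
// trimmed to four fields to keep the sketch short.
struct Layout { n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank }

fn alloc(layout: &Layout) -> String {
    // A real allocator would size buffers from these fields; the point is that
    // consumers read named, typed fields instead of guessing what the third
    // positional usize meant.
    format!("n={} base2k={} k={} rank={}", layout.n.0, layout.base2k.0, layout.k.0, layout.rank.0)
}

fn main() {
    let (n, base2k, k_out, rank): (usize, usize, usize, usize) = (16, 12, 54, 2);
    let layout = Layout { n: n.into(), base2k: base2k.into(), k: k_out.into(), rank: rank.into() };
    println!("{}", alloc(&layout));
}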
sk.fill_ternary_prob(0.5, &mut source_xs); let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); - ct_ggsw_rhs.encrypt_sk( + ggsw_apply.encrypt_sk( module, - &pt_ggsw_rhs, + &pt_apply, &sk_prepared, &mut source_xa, &mut source_xe, scratch.borrow(), ); - ct_ggsw_lhs.encrypt_sk( + ggsw_out.encrypt_sk( module, - &pt_ggsw_lhs, + &pt_in, &sk_prepared, &mut source_xa, &mut source_xe, scratch.borrow(), ); - let ct_rhs_prepared: GGSWCiphertextPrepared, B> = ct_ggsw_rhs.prepare_alloc(module, scratch.borrow()); + let ct_rhs_prepared: GGSWCiphertextPrepared, B> = ggsw_apply.prepare_alloc(module, scratch.borrow()); - ct_ggsw_lhs.external_product_inplace(module, &ct_rhs_prepared, scratch.borrow()); + ggsw_out.external_product_inplace(module, &ct_rhs_prepared, scratch.borrow()); - module.vec_znx_rotate_inplace( - k as i64, - &mut pt_ggsw_lhs.as_vec_znx_mut(), - 0, - scratch.borrow(), - ); + module.vec_znx_rotate_inplace(k as i64, &mut pt_in.as_vec_znx_mut(), 0, scratch.borrow()); let var_gct_err_lhs: f64 = SIGMA * SIGMA; let var_gct_err_rhs: f64 = 0f64; @@ -288,7 +325,7 @@ where let max_noise = |_col_j: usize| -> f64 { noise_ggsw_product( n as f64, - basek * di, + base2k * di, 0.5, var_msg, var_a0_err, @@ -296,12 +333,12 @@ where var_gct_err_lhs, var_gct_err_rhs, rank as f64, - k_ct, - k_ggsw, + k_out, + k_apply, ) + 0.5 }; - ct_ggsw_lhs.assert_noise(module, &sk_prepared, &pt_ggsw_lhs, max_noise); - }); - }); + ggsw_out.assert_noise(module, &sk_prepared, &pt_in, max_noise); + } + } } diff --git a/poulpy-core/src/tests/test_suite/external_product/glwe_ct.rs b/poulpy-core/src/tests/test_suite/external_product/glwe_ct.rs index 93842be..62a3d7c 100644 --- a/poulpy-core/src/tests/test_suite/external_product/glwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/external_product/glwe_ct.rs @@ -3,7 +3,7 @@ use poulpy_hal::{ ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, - VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotateInplace, VecZnxSub, VecZnxSubABInplace, + VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotateInplace, VecZnxSub, VecZnxSubInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, }, layouts::{Backend, Module, ScalarZnx, ScratchOwned, ZnxViewMut}, @@ -17,7 +17,7 @@ use poulpy_hal::{ use crate::{ encryption::SIGMA, layouts::{ - GGSWCiphertext, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, + GGSWCiphertext, GGSWCiphertextLayout, GLWECiphertext, GLWECiphertextLayout, GLWEPlaintext, GLWESecret, prepared::{GGSWCiphertextPrepared, GLWESecretPrepared, PrepareAlloc}, }, noise::noise_ggsw_product, @@ -32,7 +32,7 @@ where + SvpApplyDftToDftInplace + VecZnxIdftApplyConsume + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -62,64 +62,79 @@ where + TakeScalarZnxImpl + TakeVecZnxImpl, { - let basek: usize = 12; + let base2k: usize = 12; let k_in: usize = 45; - let digits: usize = k_in.div_ceil(basek); - (1..3).for_each(|rank| { - (1..digits + 1).for_each(|di| { - let k_ggsw: usize = k_in + basek * di; + let digits: usize = k_in.div_ceil(base2k); + for rank in 1_usize..3 { + for di in 1..digits + 1 { + let k_ggsw: usize = 
k_in + base2k * di; let k_out: usize = k_ggsw; // Better capture noise let n: usize = module.n(); - let rows: usize = k_in.div_ceil(basek * digits); + let rows: usize = k_in.div_ceil(base2k * digits); - let mut ct_ggsw: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ggsw, rows, digits, rank); - let mut ct_glwe_in: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_in, rank); - let mut ct_glwe_out: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_out, rank); - let mut pt_rgsw: ScalarZnx> = ScalarZnx::alloc(n, 1); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_in); + let glwe_in_infos: GLWECiphertextLayout = GLWECiphertextLayout { + n: n.into(), + base2k: base2k.into(), + k: k_in.into(), + rank: rank.into(), + }; + + let glwe_out_infos: GLWECiphertextLayout = GLWECiphertextLayout { + n: n.into(), + base2k: base2k.into(), + k: k_out.into(), + rank: rank.into(), + }; + + let ggsw_apply_infos: GGSWCiphertextLayout = GGSWCiphertextLayout { + n: n.into(), + base2k: base2k.into(), + k: k_ggsw.into(), + rows: rows.into(), + digits: di.into(), + rank: rank.into(), + }; + + let mut ggsw_apply: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_apply_infos); + let mut glwe_in: GLWECiphertext> = GLWECiphertext::alloc(&glwe_in_infos); + let mut glwe_out: GLWECiphertext> = GLWECiphertext::alloc(&glwe_out_infos); + let mut pt_ggsw: ScalarZnx> = ScalarZnx::alloc(n, 1); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_in_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); // Random input plaintext - module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, &mut source_xa); + module.vec_znx_fill_uniform(base2k, &mut pt_want.data, 0, &mut source_xa); pt_want.data.at_mut(0, 0)[1] = 1; let k: usize = 1; - pt_rgsw.raw_mut()[k] = 1; // X^{k} + pt_ggsw.raw_mut()[k] = 1; // X^{k} let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, basek, ct_ggsw.k(), rank) - | GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct_glwe_in.k()) - | GLWECiphertext::external_product_scratch_space( - module, - basek, - ct_glwe_out.k(), - ct_glwe_in.k(), - ct_ggsw.k(), - digits, - rank, - ), + GGSWCiphertext::encrypt_sk_scratch_space(module, &ggsw_apply_infos) + | GLWECiphertext::encrypt_sk_scratch_space(module, &glwe_in_infos) + | GLWECiphertext::external_product_scratch_space(module, &glwe_out_infos, &glwe_in_infos, &ggsw_apply_infos), ); - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + let mut sk: GLWESecret> = GLWESecret::alloc_with(n.into(), rank.into()); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); - ct_ggsw.encrypt_sk( + ggsw_apply.encrypt_sk( module, - &pt_rgsw, + &pt_ggsw, &sk_prepared, &mut source_xa, &mut source_xe, scratch.borrow(), ); - ct_glwe_in.encrypt_sk( + glwe_in.encrypt_sk( module, &pt_want, &sk_prepared, @@ -128,9 +143,9 @@ where scratch.borrow(), ); - let ct_ggsw_prepared: GGSWCiphertextPrepared, B> = ct_ggsw.prepare_alloc(module, scratch.borrow()); + let ct_ggsw_prepared: GGSWCiphertextPrepared, B> = ggsw_apply.prepare_alloc(module, scratch.borrow()); - ct_glwe_out.external_product(module, &ct_glwe_in, &ct_ggsw_prepared, scratch.borrow()); + glwe_out.external_product(module, &glwe_in, &ct_ggsw_prepared, scratch.borrow()); module.vec_znx_rotate_inplace(k as i64, &mut pt_want.data, 0, scratch.borrow()); @@ -143,7 +158,7 
@@ where let max_noise: f64 = noise_ggsw_product( n as f64, - basek * digits, + base2k * digits, 0.5, var_msg, var_a0_err, @@ -155,9 +170,9 @@ where k_ggsw, ); - ct_glwe_out.assert_noise(module, &sk_prepared, &pt_want, max_noise + 0.5); - }); - }); + glwe_out.assert_noise(module, &sk_prepared, &pt_want, max_noise + 0.5); + } + } } #[allow(clippy::too_many_arguments)] @@ -169,7 +184,7 @@ where + SvpApplyDftToDftInplace + VecZnxIdftApplyConsume + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -199,61 +214,70 @@ where + TakeScalarZnxImpl + TakeVecZnxImpl, { - let basek: usize = 12; - let k_ct: usize = 60; - let digits: usize = k_ct.div_ceil(basek); - (1..3).for_each(|rank| { - (1..digits + 1).for_each(|di| { - let k_ggsw: usize = k_ct + basek * di; + let base2k: usize = 12; + let k_out: usize = 60; + let digits: usize = k_out.div_ceil(base2k); + for rank in 1_usize..3 { + for di in 1..digits + 1 { + let k_ggsw: usize = k_out + base2k * di; let n: usize = module.n(); - let rows: usize = k_ct.div_ceil(basek * digits); + let rows: usize = k_out.div_ceil(base2k * digits); - let mut ct_ggsw: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ggsw, rows, digits, rank); - let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct, rank); - let mut pt_rgsw: ScalarZnx> = ScalarZnx::alloc(n, 1); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_ct); + let glwe_out_infos: GLWECiphertextLayout = GLWECiphertextLayout { + n: n.into(), + base2k: base2k.into(), + k: k_out.into(), + rank: rank.into(), + }; + + let ggsw_apply_infos: GGSWCiphertextLayout = GGSWCiphertextLayout { + n: n.into(), + base2k: base2k.into(), + k: k_ggsw.into(), + rows: rows.into(), + digits: di.into(), + rank: rank.into(), + }; + + let mut ggsw_apply: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_apply_infos); + let mut glwe_out: GLWECiphertext> = GLWECiphertext::alloc(&glwe_out_infos); + let mut pt_ggsw: ScalarZnx> = ScalarZnx::alloc(n, 1); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_out_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); // Random input plaintext - module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, &mut source_xa); + module.vec_znx_fill_uniform(base2k, &mut pt_want.data, 0, &mut source_xa); pt_want.data.at_mut(0, 0)[1] = 1; let k: usize = 1; - pt_rgsw.raw_mut()[k] = 1; // X^{k} + pt_ggsw.raw_mut()[k] = 1; // X^{k} let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, basek, ct_ggsw.k(), rank) - | GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct_glwe.k()) - | GLWECiphertext::external_product_inplace_scratch_space( - module, - basek, - ct_glwe.k(), - ct_ggsw.k(), - digits, - rank, - ), + GGSWCiphertext::encrypt_sk_scratch_space(module, &ggsw_apply_infos) + | GLWECiphertext::encrypt_sk_scratch_space(module, &glwe_out_infos) + | GLWECiphertext::external_product_inplace_scratch_space(module, &glwe_out_infos, &ggsw_apply_infos), ); - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + let mut sk: GLWESecret> = GLWESecret::alloc_with(n.into(), rank.into()); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); - ct_ggsw.encrypt_sk( + ggsw_apply.encrypt_sk( module, - &pt_rgsw, + &pt_ggsw, &sk_prepared, &mut source_xa, &mut source_xe, 
scratch.borrow(), ); - ct_glwe.encrypt_sk( + glwe_out.encrypt_sk( module, &pt_want, &sk_prepared, @@ -262,9 +286,9 @@ where scratch.borrow(), ); - let ct_ggsw_prepared: GGSWCiphertextPrepared, B> = ct_ggsw.prepare_alloc(module, scratch.borrow()); + let ct_ggsw_prepared: GGSWCiphertextPrepared, B> = ggsw_apply.prepare_alloc(module, scratch.borrow()); - ct_glwe.external_product_inplace(module, &ct_ggsw_prepared, scratch.borrow()); + glwe_out.external_product_inplace(module, &ct_ggsw_prepared, scratch.borrow()); module.vec_znx_rotate_inplace(k as i64, &mut pt_want.data, 0, scratch.borrow()); @@ -277,7 +301,7 @@ where let max_noise: f64 = noise_ggsw_product( n as f64, - basek * digits, + base2k * digits, 0.5, var_msg, var_a0_err, @@ -285,11 +309,11 @@ where var_gct_err_lhs, var_gct_err_rhs, rank as f64, - k_ct, + k_out, k_ggsw, ); - ct_glwe.assert_noise(module, &sk_prepared, &pt_want, max_noise + 0.5); - }); - }); + glwe_out.assert_noise(module, &sk_prepared, &pt_want, max_noise + 0.5); + } + } } diff --git a/poulpy-core/src/tests/test_suite/keyswitch/gglwe_ct.rs b/poulpy-core/src/tests/test_suite/keyswitch/gglwe_ct.rs index a0b35bc..637c5a6 100644 --- a/poulpy-core/src/tests/test_suite/keyswitch/gglwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/keyswitch/gglwe_ct.rs @@ -4,7 +4,7 @@ use poulpy_hal::{ VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, - VecZnxSubABInplace, VecZnxSubScalarInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, + VecZnxSubInplace, VecZnxSubScalarInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, }, layouts::{Backend, Module, ScratchOwned}, @@ -18,13 +18,12 @@ use poulpy_hal::{ use crate::{ encryption::SIGMA, layouts::{ - GGLWESwitchingKey, GLWESecret, + GGLWESwitchingKey, GGLWESwitchingKeyLayout, GLWESecret, prepared::{GGLWESwitchingKeyPrepared, GLWESecretPrepared, PrepareAlloc}, }, noise::log2_std_noise_gglwe_product, }; -#[allow(clippy::too_many_arguments)] pub fn test_gglwe_switching_key_keyswitch(module: &Module) where Module: VecZnxDftAllocBytes @@ -33,7 +32,7 @@ where + SvpApplyDftToDftInplace + VecZnxIdftApplyConsume + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -65,77 +64,84 @@ where + TakeScalarZnxImpl + TakeVecZnxImpl, { - let basek: usize = 12; + let base2k: usize = 12; let k_in: usize = 60; - let digits: usize = k_in.div_ceil(basek); + let digits: usize = k_in.div_ceil(base2k); - (1..3).for_each(|rank_in_s0s1| { - (1..3).for_each(|rank_out_s0s1| { - (1..3).for_each(|rank_out_s1s2| { - (1..digits + 1).for_each(|di| { - let k_ksk: usize = k_in + basek * di; + for rank_in_s0s1 in 1_usize..3 { + for rank_out_s0s1 in 1_usize..3 { + for rank_out_s1s2 in 1_usize..3 { + for di in 1_usize..digits + 1 { + let k_ksk: usize = k_in + base2k * di; let k_out: usize = k_ksk; // Better capture noise. 
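// Composition law exercised by the three-key test that follows, in the diff's
// own notation: gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0).
// Key-switching a switching key re-encrypts its payload under the next
// secret, so the chain s0 -> s1 -> s2 collapses to a direct s0 -> s2 key; the
// test then bounds the accumulated noise with log2_std_noise_gglwe_product.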
let n: usize = module.n(); - let rows: usize = k_in / basek; - let rows_apply: usize = k_in.div_ceil(basek * di); + let rows: usize = k_in / base2k; + let rows_apply: usize = k_in.div_ceil(base2k * di); let digits_in: usize = 1; - let mut ct_gglwe_s0s1: GGLWESwitchingKey> = - GGLWESwitchingKey::alloc(n, basek, k_in, rows, digits_in, rank_in_s0s1, rank_out_s0s1); - let mut ct_gglwe_s1s2: GGLWESwitchingKey> = GGLWESwitchingKey::alloc( - n, - basek, - k_ksk, - rows_apply, - di, - rank_out_s0s1, - rank_out_s1s2, - ); - let mut ct_gglwe_s0s2: GGLWESwitchingKey> = GGLWESwitchingKey::alloc( - n, - basek, - k_out, - rows, - digits_in, - rank_in_s0s1, - rank_out_s1s2, - ); + let gglwe_s0s1_infos: GGLWESwitchingKeyLayout = GGLWESwitchingKeyLayout { + n: n.into(), + base2k: base2k.into(), + k: k_in.into(), + rows: rows.into(), + digits: digits_in.into(), + rank_in: rank_in_s0s1.into(), + rank_out: rank_out_s0s1.into(), + }; + + let gglwe_s1s2_infos: GGLWESwitchingKeyLayout = GGLWESwitchingKeyLayout { + n: n.into(), + base2k: base2k.into(), + k: k_ksk.into(), + rows: rows_apply.into(), + digits: di.into(), + rank_in: rank_out_s0s1.into(), + rank_out: rank_out_s1s2.into(), + }; + + let gglwe_s0s2_infos: GGLWESwitchingKeyLayout = GGLWESwitchingKeyLayout { + n: n.into(), + base2k: base2k.into(), + k: k_out.into(), + rows: rows_apply.into(), + digits: digits_in.into(), + rank_in: rank_in_s0s1.into(), + rank_out: rank_out_s1s2.into(), + }; + + let mut gglwe_s0s1: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&gglwe_s0s1_infos); + let mut gglwe_s1s2: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&gglwe_s1s2_infos); + let mut gglwe_s0s2: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&gglwe_s0s2_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut scratch_enc: ScratchOwned = ScratchOwned::alloc(GGLWESwitchingKey::encrypt_sk_scratch_space( - module, - basek, - k_ksk, - rank_in_s0s1 | rank_out_s0s1, - rank_out_s0s1 | rank_out_s1s2, - )); + let mut scratch_enc: ScratchOwned = ScratchOwned::alloc( + GGLWESwitchingKey::encrypt_sk_scratch_space(module, &gglwe_s0s1_infos) + | GGLWESwitchingKey::encrypt_sk_scratch_space(module, &gglwe_s1s2_infos) + | GGLWESwitchingKey::encrypt_sk_scratch_space(module, &gglwe_s0s2_infos), + ); let mut scratch_apply: ScratchOwned = ScratchOwned::alloc(GGLWESwitchingKey::keyswitch_scratch_space( module, - basek, - k_out, - k_in, - k_ksk, - di, - ct_gglwe_s1s2.rank_in(), - ct_gglwe_s1s2.rank_out(), + &gglwe_s0s1_infos, + &gglwe_s0s2_infos, + &gglwe_s1s2_infos, )); - let mut sk0: GLWESecret> = GLWESecret::alloc(n, rank_in_s0s1); + let mut sk0: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_in_s0s1.into()); sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1: GLWESecret> = GLWESecret::alloc(n, rank_out_s0s1); + let mut sk1: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_out_s0s1.into()); sk1.fill_ternary_prob(0.5, &mut source_xs); - let mut sk2: GLWESecret> = GLWESecret::alloc(n, rank_out_s1s2); + let mut sk2: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_out_s1s2.into()); sk2.fill_ternary_prob(0.5, &mut source_xs); let sk2_prepared: GLWESecretPrepared, B> = sk2.prepare_alloc(module, scratch_apply.borrow()); // gglwe_{s1}(s0) = s0 -> s1 - ct_gglwe_s0s1.encrypt_sk( + gglwe_s0s1.encrypt_sk( module, &sk0, &sk1, @@ -145,7 +151,7 @@ where ); // gglwe_{s2}(s1) -> s1 -> s2 - ct_gglwe_s1s2.encrypt_sk( + gglwe_s1s2.encrypt_sk( module, &sk1, &sk2, @@ -154,20 
+160,20 @@ where scratch_enc.borrow(), ); - let ct_gglwe_s1s2_prepared: GGLWESwitchingKeyPrepared, B> = - ct_gglwe_s1s2.prepare_alloc(module, scratch_apply.borrow()); + let gglwe_s1s2_prepared: GGLWESwitchingKeyPrepared, B> = + gglwe_s1s2.prepare_alloc(module, scratch_apply.borrow()); // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) - ct_gglwe_s0s2.keyswitch( + gglwe_s0s2.keyswitch( module, - &ct_gglwe_s0s1, - &ct_gglwe_s1s2_prepared, + &gglwe_s0s1, + &gglwe_s1s2_prepared, scratch_apply.borrow(), ); let max_noise: f64 = log2_std_noise_gglwe_product( n as f64, - basek * di, + base2k * di, 0.5, 0.5, 0f64, @@ -178,13 +184,13 @@ where k_ksk, ); - ct_gglwe_s0s2 + gglwe_s0s2 .key .assert_noise(module, &sk2_prepared, &sk0.data, max_noise + 0.5); - }); - }); - }); - }); + } + } + } + } } #[allow(clippy::too_many_arguments)] @@ -196,7 +202,7 @@ where + SvpApplyDftToDftInplace + VecZnxIdftApplyConsume + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -228,52 +234,69 @@ where + TakeScalarZnxImpl + TakeVecZnxImpl, { - let basek: usize = 12; - let k_ct: usize = 60; - let digits: usize = k_ct.div_ceil(basek); - (1..3).for_each(|rank_in| { - (1..3).for_each(|rank_out| { - (1..digits + 1).for_each(|di| { - let k_ksk: usize = k_ct + basek * di; + let base2k: usize = 12; + let k_out: usize = 60; + let digits: usize = k_out.div_ceil(base2k); + for rank_in in 1_usize..3 { + for rank_out in 1_usize..3 { + for di in 1_usize..digits + 1 { + let k_ksk: usize = k_out + base2k * di; let n: usize = module.n(); - let rows: usize = k_ct.div_ceil(basek * di); + let rows: usize = k_out.div_ceil(base2k * di); let digits_in: usize = 1; - let mut ct_gglwe_s0s1: GGLWESwitchingKey> = - GGLWESwitchingKey::alloc(n, basek, k_ct, rows, digits_in, rank_in, rank_out); - let mut ct_gglwe_s1s2: GGLWESwitchingKey> = - GGLWESwitchingKey::alloc(n, basek, k_ksk, rows, di, rank_out, rank_out); + let gglwe_s0s1_infos: GGLWESwitchingKeyLayout = GGLWESwitchingKeyLayout { + n: n.into(), + base2k: base2k.into(), + k: k_out.into(), + rows: rows.into(), + digits: digits_in.into(), + rank_in: rank_in.into(), + rank_out: rank_out.into(), + }; + + let gglwe_s1s2_infos: GGLWESwitchingKeyLayout = GGLWESwitchingKeyLayout { + n: n.into(), + base2k: base2k.into(), + k: k_ksk.into(), + rows: rows.into(), + digits: di.into(), + rank_in: rank_out.into(), + rank_out: rank_out.into(), + }; + + let mut gglwe_s0s1: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&gglwe_s0s1_infos); + let mut gglwe_s1s2: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&gglwe_s1s2_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut scratch_enc: ScratchOwned = ScratchOwned::alloc(GGLWESwitchingKey::encrypt_sk_scratch_space( - module, - basek, - k_ksk, - rank_in | rank_out, - rank_out, - )); + let mut scratch_enc: ScratchOwned = ScratchOwned::alloc( + GGLWESwitchingKey::encrypt_sk_scratch_space(module, &gglwe_s0s1_infos) + | GGLWESwitchingKey::encrypt_sk_scratch_space(module, &gglwe_s1s2_infos), + ); let mut scratch_apply: ScratchOwned = ScratchOwned::alloc(GGLWESwitchingKey::keyswitch_inplace_scratch_space( - module, basek, k_ct, k_ksk, di, rank_out, + module, + &gglwe_s0s1_infos, + &gglwe_s1s2_infos, )); let var_xs: f64 = 0.5; - let mut sk0: GLWESecret> = GLWESecret::alloc(n, rank_in); + let mut sk0: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_in.into()); 
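// A minimal sketch of the var_xs = 0.5 figure fed to the noise formulas:
// assuming fill_ternary_prob(0.5, ..) draws each coefficient from {-1, 0, +1}
// with Pr[nonzero] = 0.5 (the natural reading of the name, not confirmed by
// this patch), a zero-mean ternary coefficient has variance Pr[nonzero],
// i.e. 0.5. Empirical check, with a toy xorshift generator standing in for
// poulpy's Source:
fn main() {
    let mut state: u64 = 0x9E37_79B9_7F4A_7C15;
    let mut next = move || {
        state ^= state << 13;
        state ^= state >> 7;
        state ^= state << 17;
        state
    };
    let n = 1_000_000_u32;
    let mut sum_sq = 0.0_f64;
    for _ in 0..n {
        let c: i64 = match next() % 4 {
            0 => -1, // Pr = 1/4
            1 => 1,  // Pr = 1/4
            _ => 0,  // Pr = 1/2
        };
        sum_sq += (c * c) as f64;
    }
    // The mean is 0 by symmetry, so the variance is just E[c^2]:
    println!("empirical variance ~ {}", sum_sq / n as f64); // ~ 0.5
}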
sk0.fill_ternary_prob(var_xs, &mut source_xs); - let mut sk1: GLWESecret> = GLWESecret::alloc(n, rank_out); + let mut sk1: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_out.into()); sk1.fill_ternary_prob(var_xs, &mut source_xs); - let mut sk2: GLWESecret> = GLWESecret::alloc(n, rank_out); + let mut sk2: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_out.into()); sk2.fill_ternary_prob(var_xs, &mut source_xs); let sk2_prepared: GLWESecretPrepared, B> = sk2.prepare_alloc(module, scratch_apply.borrow()); // gglwe_{s1}(s0) = s0 -> s1 - ct_gglwe_s0s1.encrypt_sk( + gglwe_s0s1.encrypt_sk( module, &sk0, &sk1, @@ -283,7 +306,7 @@ where ); // gglwe_{s2}(s1) -> s1 -> s2 - ct_gglwe_s1s2.encrypt_sk( + gglwe_s1s2.encrypt_sk( module, &sk1, &sk2, @@ -292,31 +315,31 @@ where scratch_enc.borrow(), ); - let ct_gglwe_s1s2_prepared: GGLWESwitchingKeyPrepared, B> = - ct_gglwe_s1s2.prepare_alloc(module, scratch_apply.borrow()); + let gglwe_s1s2_prepared: GGLWESwitchingKeyPrepared, B> = + gglwe_s1s2.prepare_alloc(module, scratch_apply.borrow()); // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) - ct_gglwe_s0s1.keyswitch_inplace(module, &ct_gglwe_s1s2_prepared, scratch_apply.borrow()); + gglwe_s0s1.keyswitch_inplace(module, &gglwe_s1s2_prepared, scratch_apply.borrow()); - let ct_gglwe_s0s2: GGLWESwitchingKey> = ct_gglwe_s0s1; + let gglwe_s0s2: GGLWESwitchingKey> = gglwe_s0s1; let max_noise: f64 = log2_std_noise_gglwe_product( n as f64, - basek * di, + base2k * di, var_xs, var_xs, 0f64, SIGMA * SIGMA, 0f64, rank_out as f64, - k_ct, + k_out, k_ksk, ); - ct_gglwe_s0s2 + gglwe_s0s2 .key .assert_noise(module, &sk2_prepared, &sk0.data, max_noise + 0.5); - }); - }); - }); + } + } + } } diff --git a/poulpy-core/src/tests/test_suite/keyswitch/ggsw_ct.rs b/poulpy-core/src/tests/test_suite/keyswitch/ggsw_ct.rs index 9b82ba9..21be8e7 100644 --- a/poulpy-core/src/tests/test_suite/keyswitch/ggsw_ct.rs +++ b/poulpy-core/src/tests/test_suite/keyswitch/ggsw_ct.rs @@ -4,7 +4,7 @@ use poulpy_hal::{ SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, - VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing, + VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, }, layouts::{Backend, Module, ScalarZnx, ScratchOwned}, @@ -18,7 +18,8 @@ use poulpy_hal::{ use crate::{ encryption::SIGMA, layouts::{ - GGLWESwitchingKey, GGLWETensorKey, GGSWCiphertext, GLWESecret, + GGLWESwitchingKey, GGLWESwitchingKeyLayout, GGLWETensorKey, GGLWETensorKeyLayout, GGSWCiphertext, GGSWCiphertextLayout, + GLWESecret, prepared::{GGLWESwitchingKeyPrepared, GGLWETensorKeyPrepared, GLWESecretPrepared, PrepareAlloc}, }, noise::noise_ggsw_keyswitch, @@ -33,7 +34,7 @@ where + SvpApplyDftToDftInplace + VecZnxIdftApplyConsume + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -70,24 +71,61 @@ where + TakeScalarZnxImpl + TakeVecZnxImpl, { - let basek: usize = 12; + let base2k: usize = 12; let k_in: usize = 54; - let digits: usize = k_in.div_ceil(basek); - 
(1..4).for_each(|rank| { - (1..digits + 1).for_each(|di| { - let k_ksk: usize = k_in + basek * di; + let digits: usize = k_in.div_ceil(base2k); + for rank in 1_usize..3 { + for di in 1..digits + 1 { + let k_ksk: usize = k_in + base2k * di; let k_tsk: usize = k_ksk; let k_out: usize = k_ksk; // Better capture noise. let n: usize = module.n(); - let rows: usize = k_in.div_ceil(di * basek); + let rows: usize = k_in.div_ceil(di * base2k); let digits_in: usize = 1; - let mut ct_in: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_in, rows, digits_in, rank); - let mut ct_out: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_out, rows, digits_in, rank); - let mut tsk: GGLWETensorKey> = GGLWETensorKey::alloc(n, basek, k_ksk, rows, di, rank); - let mut ksk: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(n, basek, k_ksk, rows, di, rank, rank); + let ggsw_in_infos: GGSWCiphertextLayout = GGSWCiphertextLayout { + n: n.into(), + base2k: base2k.into(), + k: k_in.into(), + rows: rows.into(), + digits: digits_in.into(), + rank: rank.into(), + }; + + let ggsw_out_infos: GGSWCiphertextLayout = GGSWCiphertextLayout { + n: n.into(), + base2k: base2k.into(), + k: k_out.into(), + rows: rows.into(), + digits: digits_in.into(), + rank: rank.into(), + }; + + let tsk_infos: GGLWETensorKeyLayout = GGLWETensorKeyLayout { + n: n.into(), + base2k: base2k.into(), + k: k_tsk.into(), + rows: rows.into(), + digits: di.into(), + rank: rank.into(), + }; + + let ksk_apply_infos: GGLWESwitchingKeyLayout = GGLWESwitchingKeyLayout { + n: n.into(), + base2k: base2k.into(), + k: k_ksk.into(), + rows: rows.into(), + digits: di.into(), + rank_in: rank.into(), + rank_out: rank.into(), + }; + + let mut ggsw_in: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_in_infos); + let mut ggsw_out: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_out_infos); + let mut tsk: GGLWETensorKey> = GGLWETensorKey::alloc(&tsk_infos); + let mut ksk: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&ksk_apply_infos); let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); let mut source_xs: Source = Source::new([0u8; 32]); @@ -95,19 +133,25 @@ where let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_in, rank) - | GGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k_ksk, rank, rank) - | GGLWETensorKey::encrypt_sk_scratch_space(module, basek, k_tsk, rank) - | GGSWCiphertext::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, di, k_tsk, di, rank), + GGSWCiphertext::encrypt_sk_scratch_space(module, &ggsw_in_infos) + | GGLWESwitchingKey::encrypt_sk_scratch_space(module, &ksk_apply_infos) + | GGLWETensorKey::encrypt_sk_scratch_space(module, &tsk_infos) + | GGSWCiphertext::keyswitch_scratch_space( + module, + &ggsw_out_infos, + &ggsw_in_infos, + &ksk_apply_infos, + &tsk_infos, + ), ); let var_xs: f64 = 0.5; - let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank); + let mut sk_in: GLWESecret> = GLWESecret::alloc_with(n.into(), rank.into()); sk_in.fill_ternary_prob(var_xs, &mut source_xs); let sk_in_dft: GLWESecretPrepared, B> = sk_in.prepare_alloc(module, scratch.borrow()); - let mut sk_out: GLWESecret> = GLWESecret::alloc(n, rank); + let mut sk_out: GLWESecret> = GLWESecret::alloc_with(n.into(), rank.into()); sk_out.fill_ternary_prob(var_xs, &mut source_xs); let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); @@ -129,7 +173,7 @@ where pt_scalar.fill_ternary_hw(0, n, &mut source_xs); - 
ct_in.encrypt_sk( + ggsw_in.encrypt_sk( module, &pt_scalar, &sk_in_dft, @@ -141,9 +185,9 @@ where let ksk_prepared: GGLWESwitchingKeyPrepared, B> = ksk.prepare_alloc(module, scratch.borrow()); let tsk_prepared: GGLWETensorKeyPrepared, B> = tsk.prepare_alloc(module, scratch.borrow()); - ct_out.keyswitch( + ggsw_out.keyswitch( module, - &ct_in, + &ggsw_in, &ksk_prepared, &tsk_prepared, scratch.borrow(), @@ -152,7 +196,7 @@ where let max_noise = |col_j: usize| -> f64 { noise_ggsw_keyswitch( n as f64, - basek * di, + base2k * di, col_j, var_xs, 0f64, @@ -165,9 +209,9 @@ where ) + 0.5 }; - ct_out.assert_noise(module, &sk_out_prepared, &pt_scalar, max_noise); - }); - }); + ggsw_out.assert_noise(module, &sk_out_prepared, &pt_scalar, max_noise); + } + } } #[allow(clippy::too_many_arguments)] @@ -179,7 +223,7 @@ where + SvpApplyDftToDftInplace + VecZnxIdftApplyConsume + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -216,22 +260,50 @@ where + TakeScalarZnxImpl + TakeVecZnxImpl, { - let basek: usize = 12; - let k_ct: usize = 54; - let digits: usize = k_ct.div_ceil(basek); - (1..4).for_each(|rank| { - (1..digits + 1).for_each(|di| { - let k_ksk: usize = k_ct + basek * di; + let base2k: usize = 12; + let k_out: usize = 54; + let digits: usize = k_out.div_ceil(base2k); + for rank in 1_usize..3 { + for di in 1..digits + 1 { + let k_ksk: usize = k_out + base2k * di; let k_tsk: usize = k_ksk; let n: usize = module.n(); - let rows: usize = k_ct.div_ceil(di * basek); + let rows: usize = k_out.div_ceil(di * base2k); let digits_in: usize = 1; - let mut ct: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ct, rows, digits_in, rank); - let mut tsk: GGLWETensorKey> = GGLWETensorKey::alloc(n, basek, k_tsk, rows, di, rank); - let mut ksk: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(n, basek, k_ksk, rows, di, rank, rank); + let ggsw_out_infos: GGSWCiphertextLayout = GGSWCiphertextLayout { + n: n.into(), + base2k: base2k.into(), + k: k_out.into(), + rows: rows.into(), + digits: digits_in.into(), + rank: rank.into(), + }; + + let tsk_infos: GGLWETensorKeyLayout = GGLWETensorKeyLayout { + n: n.into(), + base2k: base2k.into(), + k: k_tsk.into(), + rows: rows.into(), + digits: di.into(), + rank: rank.into(), + }; + + let ksk_apply_infos: GGLWESwitchingKeyLayout = GGLWESwitchingKeyLayout { + n: n.into(), + base2k: base2k.into(), + k: k_ksk.into(), + rows: rows.into(), + digits: di.into(), + rank_in: rank.into(), + rank_out: rank.into(), + }; + + let mut ggsw_out: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_out_infos); + let mut tsk: GGLWETensorKey> = GGLWETensorKey::alloc(&tsk_infos); + let mut ksk: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&ksk_apply_infos); let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); let mut source_xs: Source = Source::new([0u8; 32]); @@ -239,19 +311,19 @@ where let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_ct, rank) - | GGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k_ksk, rank, rank) - | GGLWETensorKey::encrypt_sk_scratch_space(module, basek, k_tsk, rank) - | GGSWCiphertext::keyswitch_inplace_scratch_space(module, basek, k_ct, k_ksk, di, k_tsk, di, rank), + GGSWCiphertext::encrypt_sk_scratch_space(module, &ggsw_out_infos) + | GGLWESwitchingKey::encrypt_sk_scratch_space(module, &ksk_apply_infos) + | GGLWETensorKey::encrypt_sk_scratch_space(module, &tsk_infos) + | 
GGSWCiphertext::keyswitch_inplace_scratch_space(module, &ggsw_out_infos, &ksk_apply_infos, &tsk_infos), ); let var_xs: f64 = 0.5; - let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank); + let mut sk_in: GLWESecret> = GLWESecret::alloc_with(n.into(), rank.into()); sk_in.fill_ternary_prob(var_xs, &mut source_xs); let sk_in_dft: GLWESecretPrepared, B> = sk_in.prepare_alloc(module, scratch.borrow()); - let mut sk_out: GLWESecret> = GLWESecret::alloc(n, rank); + let mut sk_out: GLWESecret> = GLWESecret::alloc_with(n.into(), rank.into()); sk_out.fill_ternary_prob(var_xs, &mut source_xs); let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); @@ -273,7 +345,7 @@ where pt_scalar.fill_ternary_hw(0, n, &mut source_xs); - ct.encrypt_sk( + ggsw_out.encrypt_sk( module, &pt_scalar, &sk_in_dft, @@ -285,25 +357,25 @@ where let ksk_prepared: GGLWESwitchingKeyPrepared, B> = ksk.prepare_alloc(module, scratch.borrow()); let tsk_prepared: GGLWETensorKeyPrepared, B> = tsk.prepare_alloc(module, scratch.borrow()); - ct.keyswitch_inplace(module, &ksk_prepared, &tsk_prepared, scratch.borrow()); + ggsw_out.keyswitch_inplace(module, &ksk_prepared, &tsk_prepared, scratch.borrow()); let max_noise = |col_j: usize| -> f64 { noise_ggsw_keyswitch( n as f64, - basek * di, + base2k * di, col_j, var_xs, 0f64, SIGMA * SIGMA, 0f64, rank as f64, - k_ct, + k_out, k_ksk, k_tsk, ) + 0.5 }; - ct.assert_noise(module, &sk_out_prepared, &pt_scalar, max_noise); - }); - }); + ggsw_out.assert_noise(module, &sk_out_prepared, &pt_scalar, max_noise); + } + } } diff --git a/poulpy-core/src/tests/test_suite/keyswitch/glwe_ct.rs b/poulpy-core/src/tests/test_suite/keyswitch/glwe_ct.rs index 456828b..ca8c914 100644 --- a/poulpy-core/src/tests/test_suite/keyswitch/glwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/keyswitch/glwe_ct.rs @@ -4,7 +4,7 @@ use poulpy_hal::{ VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, - VecZnxSubABInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, + VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, }, layouts::{Backend, Module, ScratchOwned}, @@ -18,7 +18,7 @@ use poulpy_hal::{ use crate::{ encryption::SIGMA, layouts::{ - GGLWESwitchingKey, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, + GGLWESwitchingKey, GGLWESwitchingKeyLayout, GLWECiphertext, GLWECiphertextLayout, GLWEPlaintext, GLWESecret, prepared::{GGLWESwitchingKeyPrepared, GLWESecretPrepared, PrepareAlloc}, }, noise::log2_std_noise_gglwe_product, @@ -33,7 +33,7 @@ where + SvpApplyDftToDftInplace + VecZnxIdftApplyConsume + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -64,51 +64,65 @@ where + TakeScalarZnxImpl + TakeVecZnxImpl, { - let basek: usize = 12; + let base2k: usize = 12; let k_in: usize = 45; - let digits: usize = k_in.div_ceil(basek); + let digits: usize = k_in.div_ceil(base2k); - (1..3).for_each(|rank_in| { - (1..3).for_each(|rank_out| { - (1..digits + 1).for_each(|di| { - let k_ksk: usize = k_in + basek * di; + for rank_in in 1_usize..3 { + for rank_out in 1_usize..3 { + for di in 1_usize..digits + 1 { + let k_ksk: 
usize = k_in + base2k * di; let k_out: usize = k_ksk; // better capture noise let n: usize = module.n(); - let rows: usize = k_in.div_ceil(basek * digits); + let rows: usize = k_in.div_ceil(base2k * digits); - let mut ksk: GGLWESwitchingKey> = - GGLWESwitchingKey::alloc(n, basek, k_ksk, rows, digits, rank_in, rank_out); - let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_in, rank_in); - let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_out, rank_out); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_in); + let glwe_in_infos: GLWECiphertextLayout = GLWECiphertextLayout { + n: n.into(), + base2k: base2k.into(), + k: k_in.into(), + rank: rank_in.into(), + }; + + let glwe_out_infos: GLWECiphertextLayout = GLWECiphertextLayout { + n: n.into(), + base2k: base2k.into(), + k: k_out.into(), + rank: rank_out.into(), + }; + + let key_apply: GGLWESwitchingKeyLayout = GGLWESwitchingKeyLayout { + n: n.into(), + base2k: base2k.into(), + k: k_ksk.into(), + rows: rows.into(), + digits: di.into(), + rank_in: rank_in.into(), + rank_out: rank_out.into(), + }; + + let mut ksk: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&key_apply); + let mut glwe_in: GLWECiphertext> = GLWECiphertext::alloc(&glwe_in_infos); + let mut glwe_out: GLWECiphertext> = GLWECiphertext::alloc(&glwe_out_infos); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_in_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, &mut source_xa); + module.vec_znx_fill_uniform(base2k, &mut pt_want.data, 0, &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, ksk.k(), rank_in, rank_out) - | GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct_in.k()) - | GLWECiphertext::keyswitch_scratch_space( - module, - basek, - ct_out.k(), - ct_in.k(), - ksk.k(), - digits, - rank_in, - rank_out, - ), + GGLWESwitchingKey::encrypt_sk_scratch_space(module, &key_apply) + | GLWECiphertext::encrypt_sk_scratch_space(module, &glwe_in_infos) + | GLWECiphertext::keyswitch_scratch_space(module, &glwe_out_infos, &glwe_in_infos, &key_apply), ); - let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank_in); + let mut sk_in: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_in.into()); sk_in.fill_ternary_prob(0.5, &mut source_xs); let sk_in_prepared: GLWESecretPrepared, B> = sk_in.prepare_alloc(module, scratch.borrow()); - let mut sk_out: GLWESecret> = GLWESecret::alloc(n, rank_out); + let mut sk_out: GLWESecret> = GLWESecret::alloc_with(n.into(), rank_out.into()); sk_out.fill_ternary_prob(0.5, &mut source_xs); let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); @@ -121,7 +135,7 @@ where scratch.borrow(), ); - ct_in.encrypt_sk( + glwe_in.encrypt_sk( module, &pt_want, &sk_in_prepared, @@ -132,11 +146,11 @@ where let ksk_prepared: GGLWESwitchingKeyPrepared, B> = ksk.prepare_alloc(module, scratch.borrow()); - ct_out.keyswitch(module, &ct_in, &ksk_prepared, scratch.borrow()); + glwe_out.keyswitch(module, &glwe_in, &ksk_prepared, scratch.borrow()); let max_noise: f64 = log2_std_noise_gglwe_product( module.n() as f64, - basek * digits, + base2k * digits, 0.5, 0.5, 0f64, @@ -147,10 +161,10 @@ where k_ksk, ); - ct_out.assert_noise(module, &sk_out_prepared, &pt_want, max_noise + 0.5); - }) - }); - }); + 
glwe_out.assert_noise(module, &sk_out_prepared, &pt_want, max_noise + 0.5); + } + } + } } pub fn test_glwe_keyswitch_inplace(module: &Module) @@ -161,7 +175,7 @@ where + SvpApplyDftToDftInplace + VecZnxIdftApplyConsume + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -192,42 +206,59 @@ where + TakeScalarZnxImpl + TakeVecZnxImpl, { - let basek: usize = 12; - let k_ct: usize = 45; - let digits: usize = k_ct.div_ceil(basek); + let base2k: usize = 12; + let k_out: usize = 45; + let digits: usize = k_out.div_ceil(base2k); - (1..3).for_each(|rank| { - (1..digits + 1).for_each(|di| { - let k_ksk: usize = k_ct + basek * di; + for rank in 1_usize..3 { + for di in 1..digits + 1 { + let k_ksk: usize = k_out + base2k * di; let n: usize = module.n(); - let rows: usize = k_ct.div_ceil(basek * digits); + let rows: usize = k_out.div_ceil(base2k * digits); - let mut ksk: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(n, basek, k_ksk, rows, digits, rank, rank); - let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_ct); + let glwe_out_infos: GLWECiphertextLayout = GLWECiphertextLayout { + n: n.into(), + base2k: base2k.into(), + k: k_out.into(), + rank: rank.into(), + }; + + let key_apply_infos: GGLWESwitchingKeyLayout = GGLWESwitchingKeyLayout { + n: n.into(), + base2k: base2k.into(), + k: k_ksk.into(), + rows: rows.into(), + digits: di.into(), + rank_in: rank.into(), + rank_out: rank.into(), + }; + + let mut key_apply: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(&key_apply_infos); + let mut glwe_out: GLWECiphertext> = GLWECiphertext::alloc(&glwe_out_infos); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_out_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, &mut source_xa); + module.vec_znx_fill_uniform(base2k, &mut pt_want.data, 0, &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, ksk.k(), rank, rank) - | GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct_glwe.k()) - | GLWECiphertext::keyswitch_inplace_scratch_space(module, basek, ct_glwe.k(), ksk.k(), digits, rank), + GGLWESwitchingKey::encrypt_sk_scratch_space(module, &key_apply_infos) + | GLWECiphertext::encrypt_sk_scratch_space(module, &glwe_out_infos) + | GLWECiphertext::keyswitch_inplace_scratch_space(module, &glwe_out_infos, &key_apply_infos), ); - let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank); + let mut sk_in: GLWESecret> = GLWESecret::alloc_with(n.into(), rank.into()); sk_in.fill_ternary_prob(0.5, &mut source_xs); let sk_in_prepared: GLWESecretPrepared, B> = sk_in.prepare_alloc(module, scratch.borrow()); - let mut sk_out: GLWESecret> = GLWESecret::alloc(n, rank); + let mut sk_out: GLWESecret> = GLWESecret::alloc_with(n.into(), rank.into()); sk_out.fill_ternary_prob(0.5, &mut source_xs); let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); - ksk.encrypt_sk( + key_apply.encrypt_sk( module, &sk_in, &sk_out, @@ -236,7 +267,7 @@ where scratch.borrow(), ); - ct_glwe.encrypt_sk( + glwe_out.encrypt_sk( module, &pt_want, &sk_in_prepared, @@ -245,24 +276,24 @@ where scratch.borrow(), ); - let ksk_prepared: GGLWESwitchingKeyPrepared, B> = 
ksk.prepare_alloc(module, scratch.borrow()); + let ksk_prepared: GGLWESwitchingKeyPrepared, B> = key_apply.prepare_alloc(module, scratch.borrow()); - ct_glwe.keyswitch_inplace(module, &ksk_prepared, scratch.borrow()); + glwe_out.keyswitch_inplace(module, &ksk_prepared, scratch.borrow()); let max_noise: f64 = log2_std_noise_gglwe_product( module.n() as f64, - basek * digits, + base2k * digits, 0.5, 0.5, 0f64, SIGMA * SIGMA, 0f64, rank as f64, - k_ct, + k_out, k_ksk, ); - ct_glwe.assert_noise(module, &sk_out_prepared, &pt_want, max_noise + 0.5); - }); - }); + glwe_out.assert_noise(module, &sk_out_prepared, &pt_want, max_noise + 0.5); + } + } } diff --git a/poulpy-core/src/tests/test_suite/keyswitch/lwe_ct.rs b/poulpy-core/src/tests/test_suite/keyswitch/lwe_ct.rs index 9432ee8..0badc86 100644 --- a/poulpy-core/src/tests/test_suite/keyswitch/lwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/keyswitch/lwe_ct.rs @@ -2,9 +2,9 @@ use poulpy_hal::{ api::{ ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxBigAddInplace, - VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, - VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, + VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, + VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, + VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, ZnAddNormal, ZnFillUniform, ZnNormalizeInplace, }, layouts::{Backend, Module, ScratchOwned, ZnxView}, @@ -16,7 +16,7 @@ use poulpy_hal::{ }; use crate::layouts::{ - Infos, LWECiphertext, LWEPlaintext, LWESecret, LWESwitchingKey, + LWECiphertext, LWECiphertextLayout, LWEPlaintext, LWESecret, LWESwitchingKey, LWESwitchingKeyLayout, prepared::{LWESwitchingKeyPrepared, PrepareAlloc}, }; @@ -28,7 +28,7 @@ where + SvpApplyDftToDftInplace + VecZnxIdftApplyConsume + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -52,7 +52,8 @@ where + VecZnxAutomorphismInplace + ZnNormalizeInplace + ZnFillUniform - + ZnAddNormal, + + ZnAddNormal + + VecZnxCopy, B: Backend + TakeVecZnxDftImpl + TakeVecZnxBigImpl @@ -64,37 +65,56 @@ where + TakeVecZnxImpl, { let n: usize = module.n(); - let basek: usize = 17; + let base2k: usize = 17; let n_lwe_in: usize = 22; let n_lwe_out: usize = 30; - let k_lwe_ct: usize = 2 * basek; + let k_lwe_ct: usize = 2 * base2k; let k_lwe_pt: usize = 8; - let k_ksk: usize = k_lwe_ct + basek; + let k_ksk: usize = k_lwe_ct + base2k; + let rows: usize = k_lwe_ct.div_ceil(base2k); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); + let key_apply_infos: LWESwitchingKeyLayout = LWESwitchingKeyLayout { + n: n.into(), + base2k: base2k.into(), + k: k_ksk.into(), + rows: rows.into(), + }; + + let lwe_in_infos: LWECiphertextLayout = LWECiphertextLayout { + n: n_lwe_in.into(), + base2k: base2k.into(), + k: k_lwe_ct.into(), + }; + + let 
lwe_out_infos: LWECiphertextLayout = LWECiphertextLayout { + n: n_lwe_out.into(), + k: k_lwe_ct.into(), + base2k: base2k.into(), + }; + let mut scratch: ScratchOwned = ScratchOwned::alloc( - LWESwitchingKey::encrypt_sk_scratch_space(module, basek, k_ksk) - | LWECiphertext::keyswitch_scratch_space(module, basek, k_lwe_ct, k_lwe_ct, k_ksk), + LWESwitchingKey::encrypt_sk_scratch_space(module, &key_apply_infos) + | LWECiphertext::keyswitch_scratch_space(module, &lwe_out_infos, &lwe_in_infos, &key_apply_infos), ); - let mut sk_lwe_in: LWESecret> = LWESecret::alloc(n_lwe_in); + let mut sk_lwe_in: LWESecret> = LWESecret::alloc(n_lwe_in.into()); sk_lwe_in.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_lwe_out: LWESecret> = LWESecret::alloc(n_lwe_out); + let mut sk_lwe_out: LWESecret> = LWESecret::alloc(n_lwe_out.into()); sk_lwe_out.fill_ternary_prob(0.5, &mut source_xs); let data: i64 = 17; - let mut lwe_pt_in: LWEPlaintext> = LWEPlaintext::alloc(basek, k_lwe_pt); + let mut lwe_pt_in: LWEPlaintext> = LWEPlaintext::alloc_with(base2k.into(), k_lwe_pt.into()); + lwe_pt_in.encode_i64(data, k_lwe_pt.into()); - lwe_pt_in.encode_i64(data, k_lwe_pt); - - let mut lwe_ct_in: LWECiphertext> = LWECiphertext::alloc(n_lwe_in, basek, k_lwe_ct); + let mut lwe_ct_in: LWECiphertext> = LWECiphertext::alloc(&lwe_in_infos); lwe_ct_in.encrypt_sk( module, &lwe_pt_in, @@ -103,7 +123,7 @@ where &mut source_xe, ); - let mut ksk: LWESwitchingKey> = LWESwitchingKey::alloc(n, basek, k_ksk, lwe_ct_in.size()); + let mut ksk: LWESwitchingKey> = LWESwitchingKey::alloc(&key_apply_infos); ksk.encrypt_sk( module, @@ -114,13 +134,13 @@ where scratch.borrow(), ); - let mut lwe_ct_out: LWECiphertext> = LWECiphertext::alloc(n_lwe_out, basek, k_lwe_ct); + let mut lwe_ct_out: LWECiphertext> = LWECiphertext::alloc(&lwe_out_infos); let ksk_prepared: LWESwitchingKeyPrepared, B> = ksk.prepare_alloc(module, scratch.borrow()); lwe_ct_out.keyswitch(module, &lwe_ct_in, &ksk_prepared, scratch.borrow()); - let mut lwe_pt_out: LWEPlaintext> = LWEPlaintext::alloc(basek, k_lwe_ct); + let mut lwe_pt_out: LWEPlaintext> = LWEPlaintext::alloc(&lwe_out_infos); lwe_ct_out.decrypt(module, &mut lwe_pt_out, &sk_lwe_out); assert_eq!(lwe_pt_in.data.at(0, 0)[0], lwe_pt_out.data.at(0, 0)[0]); diff --git a/poulpy-core/src/tests/test_suite/packing.rs b/poulpy-core/src/tests/test_suite/packing.rs index f1cbe15..248f95d 100644 --- a/poulpy-core/src/tests/test_suite/packing.rs +++ b/poulpy-core/src/tests/test_suite/packing.rs @@ -5,9 +5,9 @@ use poulpy_hal::{ ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigAutomorphismInplace, VecZnxBigNormalize, - VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace, VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, + VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallNegateInplace, VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubABInplace, + VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, }, layouts::{Backend, Module, 
ScratchOwned}, @@ -21,7 +21,7 @@ use poulpy_hal::{ use crate::{ GLWEOperations, GLWEPacker, layouts::{ - GGLWEAutomorphismKey, GLWECiphertext, GLWEPlaintext, GLWESecret, + GGLWEAutomorphismKey, GGLWEAutomorphismKeyLayout, GLWECiphertext, GLWECiphertextLayout, GLWEPlaintext, GLWESecret, prepared::{GGLWEAutomorphismKeyPrepared, GLWESecretPrepared, PrepareAlloc}, }, }; @@ -31,7 +31,7 @@ where Module: VecZnxDftAllocBytes + VecZnxAutomorphism + VecZnxBigAutomorphismInplace - + VecZnxBigSubSmallBInplace + + VecZnxBigSubSmallNegateInplace + VecZnxNegateInplace + VecZnxRshInplace + VecZnxRotateInplace @@ -41,7 +41,7 @@ where + SvpApplyDftToDftInplace + VecZnxIdftApplyConsume + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -79,37 +79,53 @@ where let mut source_xa: Source = Source::new([0u8; 32]); let n: usize = module.n(); - let basek: usize = 18; + let base2k: usize = 18; let k_ct: usize = 36; let pt_k: usize = 18; let rank: usize = 3; let digits: usize = 1; - let k_ksk: usize = k_ct + basek * digits; + let k_ksk: usize = k_ct + base2k * digits; - let rows: usize = k_ct.div_ceil(basek * digits); + let rows: usize = k_ct.div_ceil(base2k * digits); + + let glwe_out_infos: GLWECiphertextLayout = GLWECiphertextLayout { + n: n.into(), + base2k: base2k.into(), + k: k_ct.into(), + rank: rank.into(), + }; + + let key_infos: GGLWEAutomorphismKeyLayout = GGLWEAutomorphismKeyLayout { + n: n.into(), + base2k: base2k.into(), + k: k_ksk.into(), + rank: rank.into(), + digits: digits.into(), + rows: rows.into(), + }; let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertext::encrypt_sk_scratch_space(module, basek, k_ct) - | GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, basek, k_ksk, rank) - | GLWEPacker::scratch_space(module, basek, k_ct, k_ksk, digits, rank), + GLWECiphertext::encrypt_sk_scratch_space(module, &glwe_out_infos) + | GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, &key_infos) + | GLWEPacker::scratch_space(module, &glwe_out_infos, &key_infos), ); - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(&glwe_out_infos); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_dft: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_ct); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_out_infos); let mut data: Vec = vec![0i64; n]; data.iter_mut().enumerate().for_each(|(i, x)| { *x = i as i64; }); - pt.encode_vec_i64(&data, pt_k); + pt.encode_vec_i64(&data, pt_k.into()); let gal_els: Vec = GLWEPacker::galois_elements(module); let mut auto_keys: HashMap, B>> = HashMap::new(); - let mut tmp: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(n, basek, k_ksk, rows, digits, rank); + let mut tmp: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&key_infos); gal_els.iter().for_each(|gal_el| { tmp.encrypt_sk( module, @@ -125,9 +141,9 @@ where let log_batch: usize = 0; - let mut packer: GLWEPacker = GLWEPacker::new(n, log_batch, basek, k_ct, rank); + let mut packer: GLWEPacker = GLWEPacker::new(&glwe_out_infos, log_batch); - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct, rank); + let mut ct: GLWECiphertext> = GLWECiphertext::alloc(&glwe_out_infos); ct.encrypt_sk( module, @@ -164,10 +180,10 @@ where } }); - let mut res = GLWECiphertext::alloc(n, basek, k_ct, rank); + let mut res: GLWECiphertext> = GLWECiphertext::alloc(&glwe_out_infos); 
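This `*Layout`-driven allocation is the convention the whole patch migrates to: a layout value describes the object once, and the same reference is then handed to `alloc`, to the `*_scratch_space` helpers, and to constructors such as `GLWEPacker::new`. A minimal sketch of the pattern, assuming the owned-data type parameter (dropped in this rendering of the diff) is `Vec<u8>`:

```rust
use poulpy_core::layouts::{GLWECiphertext, GLWECiphertextLayout};

// One layout value describes the ciphertext shape once; passing it by
// reference replaces the old positional (n, basek, k, rank) argument lists.
fn alloc_glwe(n: usize, base2k: usize, k: usize, rank: usize) -> GLWECiphertext<Vec<u8>> {
    let infos: GLWECiphertextLayout = GLWECiphertextLayout {
        n: n.into(),
        base2k: base2k.into(),
        k: k.into(),
        rank: rank.into(),
    };
    // The same `&infos` would also size scratch, e.g. via
    // `GLWECiphertext::encrypt_sk_scratch_space(module, &infos)`.
    GLWECiphertext::alloc(&infos)
}
```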
packer.flush(module, &mut res); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_ct); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_out_infos); let mut data: Vec = vec![0i64; n]; data.iter_mut().enumerate().for_each(|(i, x)| { if i.is_multiple_of(5) { @@ -175,7 +191,7 @@ where } }); - pt_want.encode_vec_i64(&data, pt_k); + pt_want.encode_vec_i64(&data, pt_k.into()); res.decrypt(module, &mut pt, &sk_dft, scratch.borrow()); @@ -184,9 +200,8 @@ where let noise_have: f64 = pt.std().log2(); assert!( - noise_have < -((k_ct - basek) as f64), - "noise: {}", - noise_have + noise_have < -((k_ct - base2k) as f64), + "noise: {noise_have}" ); } diff --git a/poulpy-core/src/tests/test_suite/trace.rs b/poulpy-core/src/tests/test_suite/trace.rs index b369626..20dc5e7 100644 --- a/poulpy-core/src/tests/test_suite/trace.rs +++ b/poulpy-core/src/tests/test_suite/trace.rs @@ -5,9 +5,9 @@ use poulpy_hal::{ ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigAutomorphismInplace, VecZnxBigNormalize, - VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace, VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, + VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallNegateInplace, VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, - VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing, VmpApplyDftToDft, + VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, }, layouts::{Backend, Module, ScratchOwned, ZnxView, ZnxViewMut}, @@ -21,7 +21,8 @@ use poulpy_hal::{ use crate::{ encryption::SIGMA, layouts::{ - GGLWEAutomorphismKey, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, + GGLWEAutomorphismKey, GGLWEAutomorphismKeyLayout, GLWECiphertext, GLWECiphertextLayout, GLWEPlaintext, GLWESecret, + LWEInfos, prepared::{GGLWEAutomorphismKeyPrepared, GLWESecretPrepared, PrepareAlloc}, }, noise::var_noise_gglwe_product, @@ -32,7 +33,7 @@ where Module: VecZnxDftAllocBytes + VecZnxAutomorphism + VecZnxBigAutomorphismInplace - + VecZnxBigSubSmallBInplace + + VecZnxBigSubSmallNegateInplace + VecZnxRshInplace + VecZnxRotateInplace + VecZnxBigNormalize @@ -40,7 +41,7 @@ where + SvpApplyDftToDftInplace + VecZnxIdftApplyConsume + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -72,32 +73,48 @@ where + TakeScalarZnxImpl + TakeVecZnxImpl, { - let basek: usize = 8; + let base2k: usize = 8; let k: usize = 54; - (1..3).for_each(|rank| { + for rank in 1_usize..3 { let n: usize = module.n(); - let k_autokey: usize = k + basek; + let k_autokey: usize = k + base2k; let digits: usize = 1; - let rows: usize = k.div_ceil(basek * digits); + let rows: usize = k.div_ceil(base2k * digits); - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k); + let glwe_out_infos: GLWECiphertextLayout = GLWECiphertextLayout { + n: n.into(), + base2k: base2k.into(), + k: k.into(), + rank: rank.into(), + }; + + let key_infos: 
GGLWEAutomorphismKeyLayout = GGLWEAutomorphismKeyLayout { + n: n.into(), + base2k: base2k.into(), + k: k_autokey.into(), + rank: rank.into(), + digits: digits.into(), + rows: rows.into(), + }; + + let mut glwe_out: GLWECiphertext> = GLWECiphertext::alloc(&glwe_out_infos); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_out_infos); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_out_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct.k()) - | GLWECiphertext::decrypt_scratch_space(module, basek, ct.k()) - | GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, basek, k_autokey, rank) - | GLWECiphertext::trace_inplace_scratch_space(module, basek, ct.k(), k_autokey, digits, rank), + GLWECiphertext::encrypt_sk_scratch_space(module, &glwe_out_infos) + | GLWECiphertext::decrypt_scratch_space(module, &glwe_out_infos) + | GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, &key_infos) + | GLWECiphertext::trace_inplace_scratch_space(module, &glwe_out_infos, &key_infos), ); - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + let mut sk: GLWESecret> = GLWESecret::alloc(&glwe_out_infos); sk.fill_ternary_prob(0.5, &mut source_xs); let sk_dft: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); @@ -107,9 +124,9 @@ where .iter_mut() .for_each(|x| *x = source_xa.next_i64() & 0xFF); - module.vec_znx_fill_uniform(basek, &mut pt_have.data, 0, &mut source_xa); + module.vec_znx_fill_uniform(base2k, &mut pt_have.data, 0, &mut source_xa); - ct.encrypt_sk( + glwe_out.encrypt_sk( module, &pt_have, &sk_dft, @@ -120,7 +137,7 @@ where let mut auto_keys: HashMap, B>> = HashMap::new(); let gal_els: Vec = GLWECiphertext::trace_galois_elements(module); - let mut tmp: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(n, basek, k_autokey, rows, digits, rank); + let mut tmp: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(&key_infos); gal_els.iter().for_each(|gal_el| { tmp.encrypt_sk( module, @@ -134,21 +151,21 @@ where auto_keys.insert(*gal_el, atk_prepared); }); - ct.trace_inplace(module, 0, 5, &auto_keys, scratch.borrow()); - ct.trace_inplace(module, 5, module.log_n(), &auto_keys, scratch.borrow()); + glwe_out.trace_inplace(module, 0, 5, &auto_keys, scratch.borrow()); + glwe_out.trace_inplace(module, 5, module.log_n(), &auto_keys, scratch.borrow()); (0..pt_want.size()).for_each(|i| pt_want.data.at_mut(0, i)[0] = pt_have.data.at(0, i)[0]); - ct.decrypt(module, &mut pt_have, &sk_dft, scratch.borrow()); + glwe_out.decrypt(module, &mut pt_have, &sk_dft, scratch.borrow()); - module.vec_znx_sub_ab_inplace(&mut pt_want.data, 0, &pt_have.data, 0); - module.vec_znx_normalize_inplace(basek, &mut pt_want.data, 0, scratch.borrow()); + module.vec_znx_sub_inplace(&mut pt_want.data, 0, &pt_have.data, 0); + module.vec_znx_normalize_inplace(base2k, &mut pt_want.data, 0, scratch.borrow()); let noise_have: f64 = pt_want.std().log2(); let mut noise_want: f64 = var_noise_gglwe_product( n as f64, - basek, + base2k, 0.5, 0.5, 1.0 / 12.0, @@ -164,9 +181,7 @@ where assert!( (noise_have - noise_want).abs() < 1.0, - "{} > {}", - noise_have, - noise_want + "{noise_have} > {noise_want}" ); - }); + } } diff --git a/poulpy-core/src/utils.rs b/poulpy-core/src/utils.rs index 6b8e8f5..9f9e4a0 100644 --- a/poulpy-core/src/utils.rs +++ b/poulpy-core/src/utils.rs @@ -1,52 
+1,52 @@ -use crate::layouts::{GLWEPlaintext, Infos, LWEPlaintext}; +use crate::layouts::{GLWEPlaintext, LWEInfos, LWEPlaintext, TorusPrecision}; use poulpy_hal::layouts::{DataMut, DataRef}; use rug::Float; impl GLWEPlaintext { - pub fn encode_vec_i64(&mut self, data: &[i64], k: usize) { - let basek: usize = self.basek(); - self.data - .encode_vec_i64(basek, 0, k, data, i64::BITS as usize); + pub fn encode_vec_i64(&mut self, data: &[i64], k: TorusPrecision) { + let base2k: usize = self.base2k().into(); + self.data.encode_vec_i64(base2k, 0, k.into(), data); } - pub fn encode_coeff_i64(&mut self, data: i64, k: usize, idx: usize) { - let basek: usize = self.basek(); - self.data - .encode_coeff_i64(basek, 0, k, idx, data, i64::BITS as usize); + pub fn encode_coeff_i64(&mut self, data: i64, k: TorusPrecision, idx: usize) { + let base2k: usize = self.base2k().into(); + self.data.encode_coeff_i64(base2k, 0, k.into(), idx, data); } } impl GLWEPlaintext { - pub fn decode_vec_i64(&self, data: &mut [i64], k: usize) { - self.data.decode_vec_i64(self.basek(), 0, k, data); + pub fn decode_vec_i64(&self, data: &mut [i64], k: TorusPrecision) { + self.data + .decode_vec_i64(self.base2k().into(), 0, k.into(), data); } - pub fn decode_coeff_i64(&self, k: usize, idx: usize) -> i64 { - self.data.decode_coeff_i64(self.basek(), 0, k, idx) + pub fn decode_coeff_i64(&self, k: TorusPrecision, idx: usize) -> i64 { + self.data + .decode_coeff_i64(self.base2k().into(), 0, k.into(), idx) } pub fn decode_vec_float(&self, data: &mut [Float]) { - self.data.decode_vec_float(self.basek(), 0, data); + self.data.decode_vec_float(self.base2k().into(), 0, data); } pub fn std(&self) -> f64 { - self.data.std(self.basek(), 0) + self.data.std(self.base2k().into(), 0) } } impl LWEPlaintext { - pub fn encode_i64(&mut self, data: i64, k: usize) { - let basek: usize = self.basek(); - self.data.encode_i64(basek, k, data, i64::BITS as usize); + pub fn encode_i64(&mut self, data: i64, k: TorusPrecision) { + let base2k: usize = self.base2k().into(); + self.data.encode_i64(base2k, k.into(), data); } } impl LWEPlaintext { - pub fn decode_i64(&self, k: usize) -> i64 { - self.data.decode_i64(self.basek(), k) + pub fn decode_i64(&self, k: TorusPrecision) -> i64 { + self.data.decode_i64(self.base2k().into(), k.into()) } pub fn decode_float(&self) -> Float { - self.data.decode_float(self.basek()) + self.data.decode_float(self.base2k().into()) } } diff --git a/poulpy-hal/README.md b/poulpy-hal/README.md index 757e562..51af9ec 100644 --- a/poulpy-hal/README.md +++ b/poulpy-hal/README.md @@ -20,7 +20,7 @@ A `scalar_znx` is a front-end generic and backend agnostic type that stores a si #### VecZnx -A `vec_znx` is a front-end generic and backend agnostic type that stores a vector of small polynomials (i.e. a vector of scalars). Each polynomial is a `limb` that provides an additional `basek`-bits of precision in the Torus. For example a `vec_znx` with `n`=1024 `basek`=2 with 3 limbs can store 1024 coefficients with 36 bits of precision in the torus. In practice this type is used for LWE and GLWE ciphertexts/plaintexts. +A `vec_znx` is a front-end generic and backend agnostic type that stores a vector of small polynomials (i.e. a vector of scalars). Each polynomial is a `limb` that provides an additional `base2k` bits of precision in the torus. For example, a `vec_znx` with `n`=1024, `base2k`=12, and 3 limbs can store 1024 coefficients with 36 bits of precision in the torus. In practice this type is used for LWE and GLWE ciphertexts/plaintexts.
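The precision rule is linear in the number of limbs: a `vec_znx` carries `base2k × limbs` bits of torus precision per coefficient. A tiny sanity check of the arithmetic (illustrative helper only, not an API of this crate):

```rust
// Each limb contributes base2k bits, so precision grows linearly with
// the number of limbs.
fn torus_precision_bits(base2k: usize, limbs: usize) -> usize {
    base2k * limbs
}

fn main() {
    assert_eq!(torus_precision_bits(12, 3), 36); // the example above
}
```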
#### VecZnxDft diff --git a/poulpy-hal/src/api/scratch.rs b/poulpy-hal/src/api/scratch.rs index 8ef66eb..38901bf 100644 --- a/poulpy-hal/src/api/scratch.rs +++ b/poulpy-hal/src/api/scratch.rs @@ -98,10 +98,3 @@ pub trait TakeMatZnx { size: usize, ) -> (MatZnx<&mut [u8]>, &mut Self); } - -/// Take a slice of bytes from a [Scratch], wraps it into the template's type and returns it -/// as well as a new [Scratch] minus the taken array of bytes. -pub trait TakeLike<'a, B: Backend, T> { - type Output; - fn take_like(&'a mut self, template: &T) -> (Self::Output, &'a mut Self); -} diff --git a/poulpy-hal/src/api/vec_znx.rs b/poulpy-hal/src/api/vec_znx.rs index 0c8c7bc..afa30b8 100644 --- a/poulpy-hal/src/api/vec_znx.rs +++ b/poulpy-hal/src/api/vec_znx.rs @@ -9,16 +9,25 @@ pub trait VecZnxNormalizeTmpBytes { } pub trait VecZnxNormalize { + #[allow(clippy::too_many_arguments)] /// Normalizes the selected column of `a` and stores the result into the selected column of `res`. - fn vec_znx_normalize(&self, basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) - where + fn vec_znx_normalize( + &self, + res_basek: usize, + res: &mut R, + res_col: usize, + a_basek: usize, + a: &A, + a_col: usize, + scratch: &mut Scratch, + ) where R: VecZnxToMut, A: VecZnxToRef; } pub trait VecZnxNormalizeInplace { /// Normalizes the selected column of `a`. - fn vec_znx_normalize_inplace(&self, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) + fn vec_znx_normalize_inplace(&self, base2k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) where A: VecZnxToMut; } @@ -67,21 +76,21 @@ pub trait VecZnxSub { B: VecZnxToRef; } -pub trait VecZnxSubABInplace { +pub trait VecZnxSubInplace { /// Subtracts the selected column of `a` from the selected column of `res` inplace. /// /// res\[res_col\] -= a\[a_col\] - fn vec_znx_sub_ab_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_sub_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, A: VecZnxToRef; } -pub trait VecZnxSubBAInplace { +pub trait VecZnxSubNegateInplace { /// Subtracts the selected column of `res` from the selected column of `a` and inplace mutates `res` /// /// res\[res_col\] = a\[a_col\] - res\[res_col\] - fn vec_znx_sub_ba_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_sub_negate_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, A: VecZnxToRef; @@ -127,8 +136,16 @@ pub trait VecZnxLshTmpBytes { pub trait VecZnxLsh { /// Left shift by k bits all columns of `a`. #[allow(clippy::too_many_arguments)] - fn vec_znx_lsh(&self, basek: usize, k: usize, r: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) - where + fn vec_znx_lsh( + &self, + base2k: usize, + k: usize, + r: &mut R, + res_col: usize, + a: &A, + a_col: usize, + scratch: &mut Scratch, + ) where R: VecZnxToMut, A: VecZnxToRef; } @@ -140,22 +157,30 @@ pub trait VecZnxRshTmpBytes { pub trait VecZnxRsh { /// Right shift by k bits all columns of `a`. #[allow(clippy::too_many_arguments)] - fn vec_znx_rsh(&self, basek: usize, k: usize, r: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) - where + fn vec_znx_rsh( + &self, + base2k: usize, + k: usize, + r: &mut R, + res_col: usize, + a: &A, + a_col: usize, + scratch: &mut Scratch, + ) where R: VecZnxToMut, A: VecZnxToRef; } pub trait VecZnxLshInplace { /// Left shift by k bits all columns of `a`. 
- fn vec_znx_lsh_inplace(&self, basek: usize, k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) + fn vec_znx_lsh_inplace(&self, base2k: usize, k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) where A: VecZnxToMut; } pub trait VecZnxRshInplace { /// Right shift by k bits all columns of `a`. - fn vec_znx_rsh_inplace(&self, basek: usize, k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) + fn vec_znx_rsh_inplace(&self, base2k: usize, k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) where A: VecZnxToMut; } @@ -264,8 +289,8 @@ pub trait VecZnxCopy { } pub trait VecZnxFillUniform { - /// Fills the first `size` size with uniform values in \[-2^{basek-1}, 2^{basek-1}\] - fn vec_znx_fill_uniform(&self, basek: usize, res: &mut R, res_col: usize, source: &mut Source) + /// Fills the first `size` size with uniform values in \[-2^{base2k-1}, 2^{base2k-1}\] + fn vec_znx_fill_uniform(&self, base2k: usize, res: &mut R, res_col: usize, source: &mut Source) where R: VecZnxToMut; } @@ -274,7 +299,7 @@ pub trait VecZnxFillUniform { pub trait VecZnxFillNormal { fn vec_znx_fill_normal( &self, - basek: usize, + base2k: usize, res: &mut R, res_col: usize, k: usize, @@ -290,7 +315,7 @@ pub trait VecZnxAddNormal { /// Adds a discrete normal vector scaled by 2^{-k} with the provided standard deviation and bounded to \[-bound, bound\]. fn vec_znx_add_normal( &self, - basek: usize, + base2k: usize, res: &mut R, res_col: usize, k: usize, diff --git a/poulpy-hal/src/api/vec_znx_big.rs b/poulpy-hal/src/api/vec_znx_big.rs index 09ff5b3..08159bb 100644 --- a/poulpy-hal/src/api/vec_znx_big.rs +++ b/poulpy-hal/src/api/vec_znx_big.rs @@ -30,7 +30,7 @@ pub trait VecZnxBigFromBytes { /// Add a discrete normal distribution on res. /// /// # Arguments -/// * `basek`: base two logarithm of the bivariate representation +/// * `base2k`: base two logarithm of the bivariate representation /// * `res`: receiver. /// * `res_col`: column of the receiver on which the operation is performed/stored. /// * `k`: @@ -40,7 +40,7 @@ pub trait VecZnxBigFromBytes { pub trait VecZnxBigAddNormal { fn vec_znx_big_add_normal>( &self, - basek: usize, + base2k: usize, res: &mut R, res_col: usize, k: usize, @@ -93,17 +93,17 @@ pub trait VecZnxBigSub { C: VecZnxBigToRef; } -pub trait VecZnxBigSubABInplace { +pub trait VecZnxBigSubInplace { /// Subtracts `a` from `b` and stores the result on `b`. - fn vec_znx_big_sub_ab_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_big_sub_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxBigToMut, A: VecZnxBigToRef; } -pub trait VecZnxBigSubBAInplace { +pub trait VecZnxBigSubNegateInplace { /// Subtracts `b` from `a` and stores the result on `b`. - fn vec_znx_big_sub_ba_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_big_sub_negate_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxBigToMut, A: VecZnxBigToRef; @@ -118,9 +118,9 @@ pub trait VecZnxBigSubSmallA { C: VecZnxBigToRef; } -pub trait VecZnxBigSubSmallAInplace { +pub trait VecZnxBigSubSmallInplace { /// Subtracts `a` from `res` and stores the result on `res`. 
- fn vec_znx_big_sub_small_a_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_big_sub_small_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxBigToMut, A: VecZnxToRef; @@ -135,9 +135,9 @@ pub trait VecZnxBigSubSmallB { C: VecZnxToRef; } -pub trait VecZnxBigSubSmallBInplace { +pub trait VecZnxBigSubSmallNegateInplace { /// Subtracts `res` from `a` and stores the result on `res`. - fn vec_znx_big_sub_small_b_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_big_sub_small_negate_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxBigToMut, A: VecZnxToRef; @@ -160,12 +160,14 @@ pub trait VecZnxBigNormalizeTmpBytes { fn vec_znx_big_normalize_tmp_bytes(&self) -> usize; } +#[allow(clippy::too_many_arguments)] pub trait VecZnxBigNormalize { fn vec_znx_big_normalize( &self, - basek: usize, + res_basek: usize, res: &mut R, res_col: usize, + a_basek: usize, a: &A, a_col: usize, scratch: &mut Scratch, diff --git a/poulpy-hal/src/api/vec_znx_dft.rs b/poulpy-hal/src/api/vec_znx_dft.rs index 588ec53..58589c3 100644 --- a/poulpy-hal/src/api/vec_znx_dft.rs +++ b/poulpy-hal/src/api/vec_znx_dft.rs @@ -68,15 +68,15 @@ pub trait VecZnxDftSub { D: VecZnxDftToRef; } -pub trait VecZnxDftSubABInplace { - fn vec_znx_dft_sub_ab_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) +pub trait VecZnxDftSubInplace { + fn vec_znx_dft_sub_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxDftToMut, A: VecZnxDftToRef; } -pub trait VecZnxDftSubBAInplace { - fn vec_znx_dft_sub_ba_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) +pub trait VecZnxDftSubNegateInplace { + fn vec_znx_dft_sub_negate_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxDftToMut, A: VecZnxDftToRef; diff --git a/poulpy-hal/src/api/zn.rs b/poulpy-hal/src/api/zn.rs index e7c4ef5..9e8e308 100644 --- a/poulpy-hal/src/api/zn.rs +++ b/poulpy-hal/src/api/zn.rs @@ -12,14 +12,14 @@ pub trait ZnNormalizeTmpBytes { pub trait ZnNormalizeInplace { /// Normalizes the selected column of `a`. 
- fn zn_normalize_inplace(&self, n: usize, basek: usize, res: &mut R, res_col: usize, scratch: &mut Scratch) + fn zn_normalize_inplace(&self, n: usize, base2k: usize, res: &mut R, res_col: usize, scratch: &mut Scratch) where R: ZnToMut; } pub trait ZnFillUniform { - /// Fills the first `size` size with uniform values in \[-2^{basek-1}, 2^{basek-1}\] - fn zn_fill_uniform(&self, n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source) + /// Fills the first `size` size with uniform values in \[-2^{base2k-1}, 2^{base2k-1}\] + fn zn_fill_uniform(&self, n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source) where R: ZnToMut; } @@ -29,7 +29,7 @@ pub trait ZnFillNormal { fn zn_fill_normal( &self, n: usize, - basek: usize, + base2k: usize, res: &mut R, res_col: usize, k: usize, @@ -46,7 +46,7 @@ pub trait ZnAddNormal { fn zn_add_normal( &self, n: usize, - basek: usize, + base2k: usize, res: &mut R, res_col: usize, k: usize, diff --git a/poulpy-hal/src/bench_suite/svp.rs b/poulpy-hal/src/bench_suite/svp.rs index 7007d9a..6781430 100644 --- a/poulpy-hal/src/bench_suite/svp.rs +++ b/poulpy-hal/src/bench_suite/svp.rs @@ -15,16 +15,16 @@ use crate::{ pub fn bench_svp_prepare(c: &mut Criterion, label: &str) where Module: SvpPrepare + SvpPPolAlloc + ModuleNew, - B: Backend, + B: Backend, { - let group_name: String = format!("svp_prepare::{}", label); + let group_name: String = format!("svp_prepare::{label}"); let mut group = c.benchmark_group(group_name); fn runner(log_n: usize) -> impl FnMut() where Module: SvpPrepare + SvpPPolAlloc + ModuleNew, - B: Backend, + B: Backend, { let module: Module = Module::::new(1 << log_n); @@ -53,16 +53,16 @@ where pub fn bench_svp_apply_dft(c: &mut Criterion, label: &str) where Module: SvpApplyDft + SvpPPolAlloc + ModuleNew + VecZnxDftAlloc, - B: Backend, + B: Backend, { - let group_name: String = format!("svp_apply_dft::{}", label); + let group_name: String = format!("svp_apply_dft::{label}"); let mut group = c.benchmark_group(group_name); fn runner(params: [usize; 3]) -> impl FnMut() where Module: SvpApplyDft + SvpPPolAlloc + ModuleNew + VecZnxDftAlloc, - B: Backend, + B: Backend, { let n: usize = 1 << params[0]; let cols: usize = params[1]; @@ -100,16 +100,16 @@ where pub fn bench_svp_apply_dft_to_dft(c: &mut Criterion, label: &str) where Module: SvpApplyDftToDft + SvpPPolAlloc + ModuleNew + VecZnxDftAlloc, - B: Backend, + B: Backend, { - let group_name: String = format!("svp_apply_dft_to_dft::{}", label); + let group_name: String = format!("svp_apply_dft_to_dft::{label}"); let mut group = c.benchmark_group(group_name); fn runner(params: [usize; 3]) -> impl FnMut() where Module: SvpApplyDftToDft + SvpPPolAlloc + ModuleNew + VecZnxDftAlloc, - B: Backend, + B: Backend, { let n: usize = 1 << params[0]; let cols: usize = params[1]; @@ -147,16 +147,16 @@ where pub fn bench_svp_apply_dft_to_dft_add(c: &mut Criterion, label: &str) where Module: SvpApplyDftToDftAdd + SvpPPolAlloc + ModuleNew + VecZnxDftAlloc, - B: Backend, + B: Backend, { - let group_name: String = format!("svp_apply_dft_to_dft_add::{}", label); + let group_name: String = format!("svp_apply_dft_to_dft_add::{label}"); let mut group = c.benchmark_group(group_name); fn runner(params: [usize; 3]) -> impl FnMut() where Module: SvpApplyDftToDftAdd + SvpPPolAlloc + ModuleNew + VecZnxDftAlloc, - B: Backend, + B: Backend, { let n: usize = 1 << params[0]; let cols: usize = params[1]; @@ -194,16 +194,16 @@ where pub fn bench_svp_apply_dft_to_dft_inplace(c: &mut Criterion, label: 
&str) where Module: SvpApplyDftToDftInplace + SvpPPolAlloc + ModuleNew + VecZnxDftAlloc, - B: Backend, + B: Backend, { - let group_name: String = format!("svp_apply_dft_to_dft_inplace::{}", label); + let group_name: String = format!("svp_apply_dft_to_dft_inplace::{label}"); let mut group = c.benchmark_group(group_name); fn runner(params: [usize; 3]) -> impl FnMut() where Module: SvpApplyDftToDftInplace + SvpPPolAlloc + ModuleNew + VecZnxDftAlloc, - B: Backend, + B: Backend, { let n: usize = 1 << params[0]; let cols: usize = params[1]; diff --git a/poulpy-hal/src/bench_suite/vec_znx_big.rs b/poulpy-hal/src/bench_suite/vec_znx_big.rs index 2f05b35..01e6812 100644 --- a/poulpy-hal/src/bench_suite/vec_znx_big.rs +++ b/poulpy-hal/src/bench_suite/vec_znx_big.rs @@ -8,7 +8,7 @@ use crate::{ ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxBigAdd, VecZnxBigAddInplace, VecZnxBigAddSmall, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAutomorphism, VecZnxBigAutomorphismInplace, VecZnxBigAutomorphismInplaceTmpBytes, VecZnxBigNegate, VecZnxBigNegateInplace, VecZnxBigNormalize, - VecZnxBigNormalizeTmpBytes, VecZnxBigSub, VecZnxBigSubABInplace, VecZnxBigSubBAInplace, VecZnxBigSubSmallA, + VecZnxBigNormalizeTmpBytes, VecZnxBigSub, VecZnxBigSubInplace, VecZnxBigSubNegateInplace, VecZnxBigSubSmallA, VecZnxBigSubSmallB, }, layouts::{Backend, DataViewMut, Module, ScratchOwned, VecZnx, VecZnxBig}, @@ -19,7 +19,7 @@ pub fn bench_vec_znx_big_add(c: &mut Criterion, label: &str) where Module: VecZnxBigAdd + ModuleNew + VecZnxBigAlloc, { - let group_name: String = format!("vec_znx_big_add::{}", label); + let group_name: String = format!("vec_znx_big_add::{label}"); let mut group = c.benchmark_group(group_name); @@ -65,7 +65,7 @@ pub fn bench_vec_znx_big_add_inplace(c: &mut Criterion, label: &str) where Module: VecZnxBigAddInplace + ModuleNew + VecZnxBigAlloc, { - let group_name: String = format!("vec_znx_big_add_inplace::{}", label); + let group_name: String = format!("vec_znx_big_add_inplace::{label}"); let mut group = c.benchmark_group(group_name); @@ -109,7 +109,7 @@ pub fn bench_vec_znx_big_add_small(c: &mut Criterion, label: &str) where Module: VecZnxBigAddSmall + ModuleNew + VecZnxBigAlloc, { - let group_name: String = format!("vec_znx_big_add_small::{}", label); + let group_name: String = format!("vec_znx_big_add_small::{label}"); let mut group = c.benchmark_group(group_name); @@ -155,7 +155,7 @@ pub fn bench_vec_znx_big_add_small_inplace(c: &mut Criterion, label: where Module: VecZnxBigAddSmallInplace + ModuleNew + VecZnxBigAlloc, { - let group_name: String = format!("vec_znx_big_add_small_inplace::{}", label); + let group_name: String = format!("vec_znx_big_add_small_inplace::{label}"); let mut group = c.benchmark_group(group_name); @@ -199,7 +199,7 @@ pub fn bench_vec_znx_big_automorphism(c: &mut Criterion, label: &str where Module: VecZnxBigAutomorphism + ModuleNew + VecZnxBigAlloc, { - let group_name: String = format!("vec_znx_big_automorphism::{}", label); + let group_name: String = format!("vec_znx_big_automorphism::{label}"); let mut group = c.benchmark_group(group_name); @@ -244,7 +244,7 @@ where Module: VecZnxBigAutomorphismInplace + VecZnxBigAutomorphismInplaceTmpBytes + ModuleNew + VecZnxBigAlloc, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, { - let group_name: String = format!("vec_znx_automorphism_inplace::{}", label); + let group_name: String = format!("vec_znx_automorphism_inplace::{label}"); let mut group = c.benchmark_group(group_name); @@ -289,7 +289,7 @@ pub fn 
bench_vec_znx_big_negate(c: &mut Criterion, label: &str) where Module: VecZnxBigNegate + ModuleNew + VecZnxBigAlloc, { - let group_name: String = format!("vec_znx_big_negate::{}", label); + let group_name: String = format!("vec_znx_big_negate::{label}"); let mut group = c.benchmark_group(group_name); @@ -332,7 +332,7 @@ pub fn bench_vec_znx_big_negate_inplace(c: &mut Criterion, label: &s where Module: VecZnxBigNegateInplace + ModuleNew + VecZnxBigAlloc, { - let group_name: String = format!("vec_znx_negate_big_inplace::{}", label); + let group_name: String = format!("vec_znx_big_negate_inplace::{label}"); let mut group = c.benchmark_group(group_name); @@ -374,7 +374,7 @@ where Module: VecZnxBigNormalize + ModuleNew + VecZnxBigNormalizeTmpBytes + VecZnxBigAlloc, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, { - let group_name: String = format!("vec_znx_big_normalize::{}", label); + let group_name: String = format!("vec_znx_big_normalize::{label}"); let mut group = c.benchmark_group(group_name); @@ -389,7 +389,7 @@ where let module: Module = Module::::new(n as u64); - let basek: usize = 50; + let base2k: usize = 50; let mut source: Source = Source::new([0u8; 32]); @@ -404,7 +404,7 @@ where move || { for i in 0..cols { - module.vec_znx_big_normalize(basek, &mut res, i, &a, i, scratch.borrow()); + module.vec_znx_big_normalize(base2k, &mut res, i, base2k, &a, i, scratch.borrow()); } black_box(()); } @@ -423,7 +423,7 @@ pub fn bench_vec_znx_big_sub(c: &mut Criterion, label: &str) where Module: VecZnxBigSub + ModuleNew + VecZnxBigAlloc, { - let group_name: String = format!("vec_znx_big_sub::{}", label); + let group_name: String = format!("vec_znx_big_sub::{label}"); let mut group = c.benchmark_group(group_name); @@ -464,17 +464,17 @@ where group.finish(); } -pub fn bench_vec_znx_big_sub_ab_inplace(c: &mut Criterion, label: &str) +pub fn bench_vec_znx_big_sub_inplace(c: &mut Criterion, label: &str) where - Module: VecZnxBigSubABInplace + ModuleNew + VecZnxBigAlloc, + Module: VecZnxBigSubInplace + ModuleNew + VecZnxBigAlloc, { - let group_name: String = format!("vec_znx_big_sub_inplace::{}", label); + let group_name: String = format!("vec_znx_big_sub_inplace::{label}"); let mut group = c.benchmark_group(group_name); fn runner(params: [usize; 3]) -> impl FnMut() where - Module: VecZnxBigSubABInplace + ModuleNew + VecZnxBigAlloc, + Module: VecZnxBigSubInplace + ModuleNew + VecZnxBigAlloc, { let module: Module = Module::::new(1 << params[0]); @@ -492,7 +492,7 @@ where move || { for i in 0..cols { - module.vec_znx_big_sub_ab_inplace(&mut c, i, &a, i); + module.vec_znx_big_sub_inplace(&mut c, i, &a, i); } black_box(()); } @@ -507,17 +507,17 @@ where group.finish(); } -pub fn bench_vec_znx_big_sub_ba_inplace(c: &mut Criterion, label: &str) +pub fn bench_vec_znx_big_sub_negate_inplace(c: &mut Criterion, label: &str) where - Module: VecZnxBigSubBAInplace + ModuleNew + VecZnxBigAlloc, + Module: VecZnxBigSubNegateInplace + ModuleNew + VecZnxBigAlloc, { - let group_name: String = format!("vec_znx_big_sub_inplace::{}", label); + let group_name: String = format!("vec_znx_big_sub_negate_inplace::{label}"); let mut group = c.benchmark_group(group_name); fn runner(params: [usize; 3]) -> impl FnMut() where - Module: VecZnxBigSubBAInplace + ModuleNew + VecZnxBigAlloc, + Module: VecZnxBigSubNegateInplace + ModuleNew + VecZnxBigAlloc, { let module: Module = Module::::new(1 << params[0]); @@ -535,7 +535,7 @@ where move || { for i in 0..cols { - 
module.vec_znx_big_sub_negate_inplace(&mut c, i, &a, i); } black_box(()); } @@ -554,7 +554,7 @@ pub fn bench_vec_znx_big_sub_small_a(c: &mut Criterion, label: &str) where Module: VecZnxBigSubSmallA + ModuleNew + VecZnxBigAlloc, { - let group_name: String = format!("vec_znx_big_sub_small_a::{}", label); + let group_name: String = format!("vec_znx_big_sub_small_a::{label}"); let mut group = c.benchmark_group(group_name); @@ -599,7 +599,7 @@ pub fn bench_vec_znx_big_sub_small_b(c: &mut Criterion, label: &str) where Module: VecZnxBigSubSmallB + ModuleNew + VecZnxBigAlloc, { - let group_name: String = format!("vec_znx_big_sub_small_b::{}", label); + let group_name: String = format!("vec_znx_big_sub_small_b::{label}"); let mut group = c.benchmark_group(group_name); diff --git a/poulpy-hal/src/bench_suite/vec_znx_dft.rs b/poulpy-hal/src/bench_suite/vec_znx_dft.rs index ac2f758..a5895ab 100644 --- a/poulpy-hal/src/bench_suite/vec_znx_dft.rs +++ b/poulpy-hal/src/bench_suite/vec_znx_dft.rs @@ -6,7 +6,7 @@ use rand::RngCore; use crate::{ api::{ ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxBigAlloc, VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAlloc, - VecZnxDftApply, VecZnxDftSub, VecZnxDftSubABInplace, VecZnxDftSubBAInplace, VecZnxIdftApply, VecZnxIdftApplyTmpA, + VecZnxDftApply, VecZnxDftSub, VecZnxDftSubInplace, VecZnxDftSubNegateInplace, VecZnxIdftApply, VecZnxIdftApplyTmpA, VecZnxIdftApplyTmpBytes, }, layouts::{Backend, DataViewMut, Module, ScratchOwned, VecZnx, VecZnxBig, VecZnxDft}, @@ -17,7 +17,7 @@ pub fn bench_vec_znx_dft_add(c: &mut Criterion, label: &str) where Module: VecZnxDftAdd + ModuleNew + VecZnxDftAlloc, { - let group_name: String = format!("vec_znx_dft_add::{}", label); + let group_name: String = format!("vec_znx_dft_add::{label}"); let mut group = c.benchmark_group(group_name); @@ -62,7 +62,7 @@ pub fn bench_vec_znx_dft_add_inplace(c: &mut Criterion, label: &str) where Module: VecZnxDftAddInplace + ModuleNew + VecZnxDftAlloc, { - let group_name: String = format!("vec_znx_dft_add_inplace::{}", label); + let group_name: String = format!("vec_znx_dft_add_inplace::{label}"); let mut group = c.benchmark_group(group_name); @@ -106,7 +106,7 @@ pub fn bench_vec_znx_dft_apply(c: &mut Criterion, label: &str) where Module: VecZnxDftApply + ModuleNew + VecZnxDftAlloc, { - let group_name: String = format!("vec_znx_dft_apply::{}", label); + let group_name: String = format!("vec_znx_dft_apply::{label}"); let mut group = c.benchmark_group(group_name); @@ -149,7 +149,7 @@ where Module: VecZnxIdftApply + ModuleNew + VecZnxIdftApplyTmpBytes + VecZnxDftAlloc + VecZnxBigAlloc, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, { - let group_name: String = format!("vec_znx_idft_apply::{}", label); + let group_name: String = format!("vec_znx_idft_apply::{label}"); let mut group = c.benchmark_group(group_name); @@ -194,7 +194,7 @@ pub fn bench_vec_znx_idft_apply_tmpa(c: &mut Criterion, label: &str) where Module: VecZnxIdftApplyTmpA + ModuleNew + VecZnxDftAlloc + VecZnxBigAlloc, { - let group_name: String = format!("vec_znx_idft_apply_tmpa::{}", label); + let group_name: String = format!("vec_znx_idft_apply_tmpa::{label}"); let mut group = c.benchmark_group(group_name); @@ -235,7 +235,7 @@ pub fn bench_vec_znx_dft_sub(c: &mut Criterion, label: &str) where Module: VecZnxDftSub + ModuleNew + VecZnxDftAlloc, { - let group_name: String = format!("vec_znx_dft_sub::{}", label); + let group_name: String = format!("vec_znx_dft_sub::{label}"); let mut group = c.benchmark_group(group_name); @@ -276,17 
+276,17 @@ where group.finish(); } -pub fn bench_vec_znx_dft_sub_ab_inplace(c: &mut Criterion, label: &str) +pub fn bench_vec_znx_dft_sub_inplace(c: &mut Criterion, label: &str) where - Module: VecZnxDftSubABInplace + ModuleNew + VecZnxDftAlloc, + Module: VecZnxDftSubInplace + ModuleNew + VecZnxDftAlloc, { - let group_name: String = format!("vec_znx_dft_sub_ab_inplace::{}", label); + let group_name: String = format!("vec_znx_dft_sub_inplace::{label}"); let mut group = c.benchmark_group(group_name); fn runner(params: [usize; 3]) -> impl FnMut() where - Module: VecZnxDftSubABInplace + ModuleNew + VecZnxDftAlloc, + Module: VecZnxDftSubInplace + ModuleNew + VecZnxDftAlloc, { let n: usize = params[0]; let cols: usize = params[1]; @@ -305,7 +305,7 @@ where move || { for i in 0..cols { - module.vec_znx_dft_sub_ab_inplace(&mut c, i, &a, i); + module.vec_znx_dft_sub_inplace(&mut c, i, &a, i); } black_box(()); } @@ -320,17 +320,17 @@ where group.finish(); } -pub fn bench_vec_znx_dft_sub_ba_inplace(c: &mut Criterion, label: &str) +pub fn bench_vec_znx_dft_sub_negate_inplace(c: &mut Criterion, label: &str) where - Module: VecZnxDftSubBAInplace + ModuleNew + VecZnxDftAlloc, + Module: VecZnxDftSubNegateInplace + ModuleNew + VecZnxDftAlloc, { - let group_name: String = format!("vec_znx_dft_sub_ba_inplace::{}", label); + let group_name: String = format!("vec_znx_dft_sub_negate_inplace::{label}"); let mut group = c.benchmark_group(group_name); fn runner(params: [usize; 3]) -> impl FnMut() where - Module: VecZnxDftSubBAInplace + ModuleNew + VecZnxDftAlloc, + Module: VecZnxDftSubNegateInplace + ModuleNew + VecZnxDftAlloc, { let n: usize = params[0]; let cols: usize = params[1]; @@ -349,7 +349,7 @@ where move || { for i in 0..cols { - module.vec_znx_dft_sub_ba_inplace(&mut c, i, &a, i); + module.vec_znx_dft_sub_negate_inplace(&mut c, i, &a, i); } black_box(()); } diff --git a/poulpy-hal/src/bench_suite/vmp.rs b/poulpy-hal/src/bench_suite/vmp.rs index 0fa2ff9..2686d7a 100644 --- a/poulpy-hal/src/bench_suite/vmp.rs +++ b/poulpy-hal/src/bench_suite/vmp.rs @@ -17,7 +17,7 @@ where Module: ModuleNew + VmpPMatAlloc + VmpPrepare + VmpPrepareTmpBytes, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, { - let group_name: String = format!("vmp_prepare::{}", label); + let group_name: String = format!("vmp_prepare::{label}"); let mut group = c.benchmark_group(group_name); @@ -76,7 +76,7 @@ where Module: ModuleNew + VmpApplyDftTmpBytes + VmpApplyDft + VmpPMatAlloc + VecZnxDftAlloc, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, { - let group_name: String = format!("vmp_apply_dft::{}", label); + let group_name: String = format!("vmp_apply_dft::{label}"); let mut group = c.benchmark_group(group_name); @@ -137,7 +137,7 @@ where Module: ModuleNew + VecZnxDftAlloc + VmpPMatAlloc + VmpApplyDftToDft + VmpApplyDftToDftTmpBytes, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, { - let group_name: String = format!("vmp_apply_dft_to_dft::{}", label); + let group_name: String = format!("vmp_apply_dft_to_dft::{label}"); let mut group = c.benchmark_group(group_name); @@ -200,7 +200,7 @@ where Module: ModuleNew + VecZnxDftAlloc + VmpPMatAlloc + VmpApplyDftToDftAdd + VmpApplyDftToDftAddTmpBytes, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, { - let group_name: String = format!("vmp_apply_dft_to_dft_add::{}", label); + let group_name: String = format!("vmp_apply_dft_to_dft_add::{label}"); let mut group = c.benchmark_group(group_name); diff --git a/poulpy-hal/src/delegates/scratch.rs 
b/poulpy-hal/src/delegates/scratch.rs index 95c5b92..ac022e3 100644 --- a/poulpy-hal/src/delegates/scratch.rs +++ b/poulpy-hal/src/delegates/scratch.rs @@ -1,11 +1,11 @@ use crate::{ api::{ - ScratchAvailable, ScratchFromBytes, ScratchOwnedAlloc, ScratchOwnedBorrow, TakeLike, TakeMatZnx, TakeScalarZnx, - TakeSlice, TakeSvpPPol, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, TakeVecZnxDftSlice, TakeVecZnxSlice, TakeVmpPMat, + ScratchAvailable, ScratchFromBytes, ScratchOwnedAlloc, ScratchOwnedBorrow, TakeMatZnx, TakeScalarZnx, TakeSlice, + TakeSvpPPol, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, TakeVecZnxDftSlice, TakeVecZnxSlice, TakeVmpPMat, }, - layouts::{Backend, DataRef, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat}, + layouts::{Backend, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat}, oep::{ - ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeLikeImpl, TakeMatZnxImpl, + ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeMatZnxImpl, TakeScalarZnxImpl, TakeSliceImpl, TakeSvpPPolImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, TakeVmpPMatImpl, }, @@ -156,80 +156,3 @@ where B::take_mat_znx_impl(self, n, rows, cols_in, cols_out, size) } } - -impl<'a, B: Backend, D> TakeLike<'a, B, ScalarZnx> for Scratch -where - B: TakeLikeImpl<'a, B, ScalarZnx, Output = ScalarZnx<&'a mut [u8]>>, - D: DataRef, -{ - type Output = ScalarZnx<&'a mut [u8]>; - fn take_like(&'a mut self, template: &ScalarZnx) -> (Self::Output, &'a mut Self) { - B::take_like_impl(self, template) - } -} - -impl<'a, B: Backend, D> TakeLike<'a, B, SvpPPol> for Scratch -where - B: TakeLikeImpl<'a, B, SvpPPol, Output = SvpPPol<&'a mut [u8], B>>, - D: DataRef, -{ - type Output = SvpPPol<&'a mut [u8], B>; - fn take_like(&'a mut self, template: &SvpPPol) -> (Self::Output, &'a mut Self) { - B::take_like_impl(self, template) - } -} - -impl<'a, B: Backend, D> TakeLike<'a, B, VecZnx> for Scratch -where - B: TakeLikeImpl<'a, B, VecZnx, Output = VecZnx<&'a mut [u8]>>, - D: DataRef, -{ - type Output = VecZnx<&'a mut [u8]>; - fn take_like(&'a mut self, template: &VecZnx) -> (Self::Output, &'a mut Self) { - B::take_like_impl(self, template) - } -} - -impl<'a, B: Backend, D> TakeLike<'a, B, VecZnxBig> for Scratch -where - B: TakeLikeImpl<'a, B, VecZnxBig, Output = VecZnxBig<&'a mut [u8], B>>, - D: DataRef, -{ - type Output = VecZnxBig<&'a mut [u8], B>; - fn take_like(&'a mut self, template: &VecZnxBig) -> (Self::Output, &'a mut Self) { - B::take_like_impl(self, template) - } -} - -impl<'a, B: Backend, D> TakeLike<'a, B, VecZnxDft> for Scratch -where - B: TakeLikeImpl<'a, B, VecZnxDft, Output = VecZnxDft<&'a mut [u8], B>>, - D: DataRef, -{ - type Output = VecZnxDft<&'a mut [u8], B>; - fn take_like(&'a mut self, template: &VecZnxDft) -> (Self::Output, &'a mut Self) { - B::take_like_impl(self, template) - } -} - -impl<'a, B: Backend, D> TakeLike<'a, B, MatZnx> for Scratch -where - B: TakeLikeImpl<'a, B, MatZnx, Output = MatZnx<&'a mut [u8]>>, - D: DataRef, -{ - type Output = MatZnx<&'a mut [u8]>; - fn take_like(&'a mut self, template: &MatZnx) -> (Self::Output, &'a mut Self) { - B::take_like_impl(self, template) - } -} - -impl<'a, B: Backend, D> TakeLike<'a, B, VmpPMat> for Scratch -where - B: TakeLikeImpl<'a, B, VmpPMat, Output = VmpPMat<&'a mut [u8], B>>, - D: DataRef, -{ - type Output = VmpPMat<&'a mut [u8], B>; - fn 
take_like(&'a mut self, template: &VmpPMat) -> (Self::Output, &'a mut Self) { - B::take_like_impl(self, template) - } -} diff --git a/poulpy-hal/src/delegates/vec_znx.rs b/poulpy-hal/src/delegates/vec_znx.rs index a5cd36f..60a961e 100644 --- a/poulpy-hal/src/delegates/vec_znx.rs +++ b/poulpy-hal/src/delegates/vec_znx.rs @@ -5,8 +5,8 @@ use crate::{ VecZnxLshInplace, VecZnxLshTmpBytes, VecZnxMergeRings, VecZnxMergeRingsTmpBytes, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxMulXpMinusOneInplaceTmpBytes, VecZnxNegate, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, - VecZnxRsh, VecZnxRshInplace, VecZnxRshTmpBytes, VecZnxSplitRing, VecZnxSplitRingTmpBytes, VecZnxSub, VecZnxSubABInplace, - VecZnxSubBAInplace, VecZnxSubScalar, VecZnxSubScalarInplace, VecZnxSwitchRing, + VecZnxRsh, VecZnxRshInplace, VecZnxRshTmpBytes, VecZnxSplitRing, VecZnxSplitRingTmpBytes, VecZnxSub, VecZnxSubInplace, + VecZnxSubNegateInplace, VecZnxSubScalar, VecZnxSubScalarInplace, VecZnxSwitchRing, }, layouts::{Backend, Module, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef}, oep::{ @@ -17,7 +17,7 @@ use crate::{ VecZnxMulXpMinusOneInplaceTmpBytesImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl, VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl, VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl, - VecZnxSplitRingTmpBytesImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarImpl, + VecZnxSplitRingTmpBytesImpl, VecZnxSubImpl, VecZnxSubInplaceImpl, VecZnxSubNegateInplaceImpl, VecZnxSubScalarImpl, VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl, }, source::Source, @@ -36,12 +36,21 @@ impl VecZnxNormalize for Module where B: Backend + VecZnxNormalizeImpl, { - fn vec_znx_normalize(&self, basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) - where + #[allow(clippy::too_many_arguments)] + fn vec_znx_normalize( + &self, + res_basek: usize, + res: &mut R, + res_col: usize, + a_basek: usize, + a: &A, + a_col: usize, + scratch: &mut Scratch, + ) where R: VecZnxToMut, A: VecZnxToRef, { - B::vec_znx_normalize_impl(self, basek, res, res_col, a, a_col, scratch) + B::vec_znx_normalize_impl(self, res_basek, res, res_col, a_basek, a, a_col, scratch) } } @@ -49,11 +58,11 @@ impl VecZnxNormalizeInplace for Module where B: Backend + VecZnxNormalizeInplaceImpl, { - fn vec_znx_normalize_inplace(&self, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) + fn vec_znx_normalize_inplace(&self, base2k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) where A: VecZnxToMut, { - B::vec_znx_normalize_inplace_impl(self, basek, a, a_col, scratch) + B::vec_znx_normalize_inplace_impl(self, base2k, a, a_col, scratch) } } @@ -125,29 +134,29 @@ where } } -impl VecZnxSubABInplace for Module +impl VecZnxSubInplace for Module where - B: Backend + VecZnxSubABInplaceImpl, + B: Backend + VecZnxSubInplaceImpl, { - fn vec_znx_sub_ab_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_sub_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, A: VecZnxToRef, { - B::vec_znx_sub_ab_inplace_impl(self, res, res_col, a, a_col) + B::vec_znx_sub_inplace_impl(self, res, res_col, a, a_col) } } -impl VecZnxSubBAInplace for Module +impl VecZnxSubNegateInplace for Module where - B: Backend + 
VecZnxSubBAInplaceImpl, + B: Backend + VecZnxSubNegateInplaceImpl, { - fn vec_znx_sub_ba_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_sub_negate_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, A: VecZnxToRef, { - B::vec_znx_sub_ba_inplace_impl(self, res, res_col, a, a_col) + B::vec_znx_sub_negate_inplace_impl(self, res, res_col, a, a_col) } } @@ -227,7 +236,7 @@ where { fn vec_znx_lsh( &self, - basek: usize, + base2k: usize, k: usize, res: &mut R, res_col: usize, @@ -238,7 +247,7 @@ where R: VecZnxToMut, A: VecZnxToRef, { - B::vec_znx_lsh_inplace_impl(self, basek, k, res, res_col, a, a_col, scratch); + B::vec_znx_lsh_impl(self, base2k, k, res, res_col, a, a_col, scratch); } } @@ -248,7 +257,7 @@ where { fn vec_znx_rsh( &self, - basek: usize, + base2k: usize, k: usize, res: &mut R, res_col: usize, @@ -259,7 +268,7 @@ where R: VecZnxToMut, A: VecZnxToRef, { - B::vec_znx_rsh_inplace_impl(self, basek, k, res, res_col, a, a_col, scratch); + B::vec_znx_rsh_impl(self, base2k, k, res, res_col, a, a_col, scratch); } } @@ -267,11 +276,11 @@ impl VecZnxLshInplace for Module where B: Backend + VecZnxLshInplaceImpl, { - fn vec_znx_lsh_inplace(&self, basek: usize, k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) + fn vec_znx_lsh_inplace(&self, base2k: usize, k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) where A: VecZnxToMut, { - B::vec_znx_lsh_inplace_impl(self, basek, k, a, a_col, scratch) + B::vec_znx_lsh_inplace_impl(self, base2k, k, a, a_col, scratch) } } @@ -279,11 +288,11 @@ impl VecZnxRshInplace for Module where B: Backend + VecZnxRshInplaceImpl, { - fn vec_znx_rsh_inplace(&self, basek: usize, k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) + fn vec_znx_rsh_inplace(&self, base2k: usize, k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) where A: VecZnxToMut, { - B::vec_znx_rsh_inplace_impl(self, basek, k, a, a_col, scratch) + B::vec_znx_rsh_inplace_impl(self, base2k, k, a, a_col, scratch) } } @@ -463,11 +472,11 @@ impl VecZnxFillUniform for Module where B: Backend + VecZnxFillUniformImpl, { - fn vec_znx_fill_uniform(&self, basek: usize, res: &mut R, res_col: usize, source: &mut Source) + fn vec_znx_fill_uniform(&self, base2k: usize, res: &mut R, res_col: usize, source: &mut Source) where R: VecZnxToMut, { - B::vec_znx_fill_uniform_impl(self, basek, res, res_col, source); + B::vec_znx_fill_uniform_impl(self, base2k, res, res_col, source); } } @@ -477,7 +486,7 @@ where { fn vec_znx_fill_normal( &self, - basek: usize, + base2k: usize, res: &mut R, res_col: usize, k: usize, @@ -487,7 +496,7 @@ where ) where R: VecZnxToMut, { - B::vec_znx_fill_normal_impl(self, basek, res, res_col, k, source, sigma, bound); + B::vec_znx_fill_normal_impl(self, base2k, res, res_col, k, source, sigma, bound); } } @@ -497,7 +506,7 @@ where { fn vec_znx_add_normal( &self, - basek: usize, + base2k: usize, res: &mut R, res_col: usize, k: usize, @@ -507,6 +516,6 @@ where ) where R: VecZnxToMut, { - B::vec_znx_add_normal_impl(self, basek, res, res_col, k, source, sigma, bound); + B::vec_znx_add_normal_impl(self, base2k, res, res_col, k, source, sigma, bound); } } diff --git a/poulpy-hal/src/delegates/vec_znx_big.rs b/poulpy-hal/src/delegates/vec_znx_big.rs index 1d0f8f1..1556a87 100644 --- a/poulpy-hal/src/delegates/vec_znx_big.rs +++ b/poulpy-hal/src/delegates/vec_znx_big.rs @@ -3,17 +3,17 @@ use crate::{ VecZnxBigAdd, VecZnxBigAddInplace, VecZnxBigAddNormal, VecZnxBigAddSmall, VecZnxBigAddSmallInplace, 
VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigAutomorphism, VecZnxBigAutomorphismInplace, VecZnxBigAutomorphismInplaceTmpBytes, VecZnxBigFromBytes, VecZnxBigFromSmall, VecZnxBigNegate, VecZnxBigNegateInplace, VecZnxBigNormalize, - VecZnxBigNormalizeTmpBytes, VecZnxBigSub, VecZnxBigSubABInplace, VecZnxBigSubBAInplace, VecZnxBigSubSmallA, - VecZnxBigSubSmallAInplace, VecZnxBigSubSmallB, VecZnxBigSubSmallBInplace, + VecZnxBigNormalizeTmpBytes, VecZnxBigSub, VecZnxBigSubInplace, VecZnxBigSubNegateInplace, VecZnxBigSubSmallA, + VecZnxBigSubSmallB, VecZnxBigSubSmallInplace, VecZnxBigSubSmallNegateInplace, }, layouts::{Backend, Module, Scratch, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef}, oep::{ VecZnxBigAddImpl, VecZnxBigAddInplaceImpl, VecZnxBigAddNormalImpl, VecZnxBigAddSmallImpl, VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl, VecZnxBigAutomorphismInplaceImpl, VecZnxBigAutomorphismInplaceTmpBytesImpl, VecZnxBigFromBytesImpl, VecZnxBigFromSmallImpl, VecZnxBigNegateImpl, - VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl, VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubABInplaceImpl, - VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl, VecZnxBigSubSmallAImpl, VecZnxBigSubSmallAInplaceImpl, - VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl, + VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl, VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubImpl, + VecZnxBigSubInplaceImpl, VecZnxBigSubNegateInplaceImpl, VecZnxBigSubSmallAImpl, VecZnxBigSubSmallBImpl, + VecZnxBigSubSmallInplaceImpl, VecZnxBigSubSmallNegateInplaceImpl, }, source::Source, }; @@ -64,7 +64,7 @@ where { fn vec_znx_big_add_normal>( &self, - basek: usize, + base2k: usize, res: &mut R, res_col: usize, k: usize, @@ -72,7 +72,7 @@ where sigma: f64, bound: f64, ) { - B::add_normal_impl(self, basek, res, res_col, k, source, sigma, bound); + B::add_normal_impl(self, base2k, res, res_col, k, source, sigma, bound); } } @@ -144,29 +144,29 @@ where } } -impl VecZnxBigSubABInplace for Module +impl VecZnxBigSubInplace for Module where - B: Backend + VecZnxBigSubABInplaceImpl, + B: Backend + VecZnxBigSubInplaceImpl, { - fn vec_znx_big_sub_ab_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_big_sub_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxBigToMut, A: VecZnxBigToRef, { - B::vec_znx_big_sub_ab_inplace_impl(self, res, res_col, a, a_col); + B::vec_znx_big_sub_inplace_impl(self, res, res_col, a, a_col); } } -impl VecZnxBigSubBAInplace for Module +impl VecZnxBigSubNegateInplace for Module where - B: Backend + VecZnxBigSubBAInplaceImpl, + B: Backend + VecZnxBigSubNegateInplaceImpl, { - fn vec_znx_big_sub_ba_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_big_sub_negate_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxBigToMut, A: VecZnxBigToRef, { - B::vec_znx_big_sub_ba_inplace_impl(self, res, res_col, a, a_col); + B::vec_znx_big_sub_negate_inplace_impl(self, res, res_col, a, a_col); } } @@ -184,16 +184,16 @@ where } } -impl VecZnxBigSubSmallAInplace for Module +impl VecZnxBigSubSmallInplace for Module where - B: Backend + VecZnxBigSubSmallAInplaceImpl, + B: Backend + VecZnxBigSubSmallInplaceImpl, { - fn vec_znx_big_sub_small_a_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_big_sub_small_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxBigToMut, A: VecZnxToRef, { - 
B::vec_znx_big_sub_small_a_inplace_impl(self, res, res_col, a, a_col); + B::vec_znx_big_sub_small_inplace_impl(self, res, res_col, a, a_col); } } @@ -211,16 +211,16 @@ where } } -impl VecZnxBigSubSmallBInplace for Module +impl VecZnxBigSubSmallNegateInplace for Module where - B: Backend + VecZnxBigSubSmallBInplaceImpl, + B: Backend + VecZnxBigSubSmallNegateInplaceImpl, { - fn vec_znx_big_sub_small_b_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_big_sub_small_negate_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxBigToMut, A: VecZnxToRef, { - B::vec_znx_big_sub_small_b_inplace_impl(self, res, res_col, a, a_col); + B::vec_znx_big_sub_small_negate_inplace_impl(self, res, res_col, a, a_col); } } @@ -264,9 +264,10 @@ where { fn vec_znx_big_normalize( &self, - basek: usize, + res_basek: usize, res: &mut R, res_col: usize, + a_basek: usize, a: &A, a_col: usize, scratch: &mut Scratch, @@ -274,7 +275,7 @@ where R: VecZnxToMut, A: VecZnxBigToRef, { - B::vec_znx_big_normalize_impl(self, basek, res, res_col, a, a_col, scratch); + B::vec_znx_big_normalize_impl(self, res_basek, res, res_col, a_basek, a, a_col, scratch); } } diff --git a/poulpy-hal/src/delegates/vec_znx_dft.rs b/poulpy-hal/src/delegates/vec_znx_dft.rs index b486b08..3736e34 100644 --- a/poulpy-hal/src/delegates/vec_znx_dft.rs +++ b/poulpy-hal/src/delegates/vec_znx_dft.rs @@ -1,7 +1,7 @@ use crate::{ api::{ VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy, - VecZnxDftFromBytes, VecZnxDftSub, VecZnxDftSubABInplace, VecZnxDftSubBAInplace, VecZnxDftZero, VecZnxIdftApply, + VecZnxDftFromBytes, VecZnxDftSub, VecZnxDftSubInplace, VecZnxDftSubNegateInplace, VecZnxDftZero, VecZnxIdftApply, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxIdftApplyTmpBytes, }, layouts::{ @@ -10,7 +10,7 @@ use crate::{ }, oep::{ VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftApplyImpl, - VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl, VecZnxDftSubImpl, + VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubImpl, VecZnxDftSubInplaceImpl, VecZnxDftSubNegateInplaceImpl, VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl, VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl, }, }; @@ -143,29 +143,29 @@ where } } -impl VecZnxDftSubABInplace for Module +impl VecZnxDftSubInplace for Module where - B: Backend + VecZnxDftSubABInplaceImpl, + B: Backend + VecZnxDftSubInplaceImpl, { - fn vec_znx_dft_sub_ab_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_dft_sub_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxDftToMut, A: VecZnxDftToRef, { - B::vec_znx_dft_sub_ab_inplace_impl(self, res, res_col, a, a_col); + B::vec_znx_dft_sub_inplace_impl(self, res, res_col, a, a_col); } } -impl VecZnxDftSubBAInplace for Module +impl VecZnxDftSubNegateInplace for Module where - B: Backend + VecZnxDftSubBAInplaceImpl, + B: Backend + VecZnxDftSubNegateInplaceImpl, { - fn vec_znx_dft_sub_ba_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_dft_sub_negate_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxDftToMut, A: VecZnxDftToRef, { - B::vec_znx_dft_sub_ba_inplace_impl(self, res, res_col, a, a_col); + B::vec_znx_dft_sub_negate_inplace_impl(self, res, res_col, a, a_col); } } diff --git a/poulpy-hal/src/delegates/zn.rs 
b/poulpy-hal/src/delegates/zn.rs index 450bdc9..6a4c999 100644 --- a/poulpy-hal/src/delegates/zn.rs +++ b/poulpy-hal/src/delegates/zn.rs @@ -18,11 +18,11 @@ impl ZnNormalizeInplace for Module where B: Backend + ZnNormalizeInplaceImpl, { - fn zn_normalize_inplace(&self, n: usize, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) + fn zn_normalize_inplace(&self, n: usize, base2k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) where A: ZnToMut, { - B::zn_normalize_inplace_impl(n, basek, a, a_col, scratch) + B::zn_normalize_inplace_impl(n, base2k, a, a_col, scratch) } } @@ -30,11 +30,11 @@ impl ZnFillUniform for Module where B: Backend + ZnFillUniformImpl, { - fn zn_fill_uniform(&self, n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source) + fn zn_fill_uniform(&self, n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source) where R: ZnToMut, { - B::zn_fill_uniform_impl(n, basek, res, res_col, source); + B::zn_fill_uniform_impl(n, base2k, res, res_col, source); } } @@ -45,7 +45,7 @@ where fn zn_fill_normal( &self, n: usize, - basek: usize, + base2k: usize, res: &mut R, res_col: usize, k: usize, @@ -55,7 +55,7 @@ where ) where R: ZnToMut, { - B::zn_fill_normal_impl(n, basek, res, res_col, k, source, sigma, bound); + B::zn_fill_normal_impl(n, base2k, res, res_col, k, source, sigma, bound); } } @@ -66,7 +66,7 @@ where fn zn_add_normal( &self, n: usize, - basek: usize, + base2k: usize, res: &mut R, res_col: usize, k: usize, @@ -76,6 +76,6 @@ where ) where R: ZnToMut, { - B::zn_add_normal_impl(n, basek, res, res_col, k, source, sigma, bound); + B::zn_add_normal_impl(n, base2k, res, res_col, k, source, sigma, bound); } } diff --git a/poulpy-hal/src/layouts/encoding.rs b/poulpy-hal/src/layouts/encoding.rs index 957d835..4d61f93 100644 --- a/poulpy-hal/src/layouts/encoding.rs +++ b/poulpy-hal/src/layouts/encoding.rs @@ -3,65 +3,108 @@ use rug::{Assign, Float}; use crate::{ layouts::{DataMut, DataRef, VecZnx, VecZnxToMut, VecZnxToRef, Zn, ZnToMut, ZnToRef, ZnxInfos, ZnxView, ZnxViewMut}, - reference::znx::znx_zero_ref, + reference::znx::{ + ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStepInplace, ZnxRef, ZnxZero, + get_carry_i128, get_digit_i128, znx_zero_ref, + }, }; impl VecZnx { - pub fn encode_vec_i64(&mut self, basek: usize, col: usize, k: usize, data: &[i64], log_max: usize) { - let size: usize = k.div_ceil(basek); + pub fn encode_vec_i64(&mut self, base2k: usize, col: usize, k: usize, data: &[i64]) { + let size: usize = k.div_ceil(base2k); #[cfg(debug_assertions)] { let a: VecZnx<&mut [u8]> = self.to_mut(); assert!( size <= a.size(), - "invalid argument k.div_ceil(basek)={} > a.size()={}", + "invalid argument k.div_ceil(base2k)={} > a.size()={}", size, a.size() ); assert!(col < a.cols()); - assert!(data.len() <= a.n()) + assert!(data.len() == a.n()) } - let data_len: usize = data.len(); let mut a: VecZnx<&mut [u8]> = self.to_mut(); - let k_rem: usize = basek - (k % basek); + let a_size: usize = a.size(); // Zeroes coefficients of the i-th column - (0..a.size()).for_each(|i| { + for i in 0..a_size { znx_zero_ref(a.at_mut(col, i)); - }); - - // If 2^{basek} * 2^{k_rem} < 2^{63}-1, then we can simply copy - // values on the last limb. - // Else we decompose values base2k. 
- if log_max + k_rem < 63 || k_rem == basek { - a.at_mut(col, size - 1)[..data_len].copy_from_slice(&data[..data_len]); - } else { - let mask: i64 = (1 << basek) - 1; - let steps: usize = size.min(log_max.div_ceil(basek)); - (size - steps..size) - .rev() - .enumerate() - .for_each(|(i, i_rev)| { - let shift: usize = i * basek; - izip!(a.at_mut(col, i_rev).iter_mut(), data.iter()).for_each(|(y, x)| *y = (x >> shift) & mask); - }) } - // Case where self.prec % self.k != 0. - if k_rem != basek { - let steps: usize = size.min(log_max.div_ceil(basek)); - (size - steps..size).rev().for_each(|i| { - a.at_mut(col, i)[..data_len] - .iter_mut() - .for_each(|x| *x <<= k_rem); - }) + } + + // Copies the data onto the correct limb + a.at_mut(col, size - 1).copy_from_slice(data); + + let mut carry: Vec<i64> = vec![0i64; a.n()]; + let k_rem: usize = (base2k - (k % base2k)) % base2k; + + // Normalizes and shifts if necessary. + for j in (0..size).rev() { + if j == size - 1 { + ZnxRef::znx_normalize_first_step_inplace(base2k, k_rem, a.at_mut(col, j), &mut carry); + } else if j == 0 { + ZnxRef::znx_normalize_final_step_inplace(base2k, k_rem, a.at_mut(col, j), &mut carry); + } else { + ZnxRef::znx_normalize_middle_step_inplace(base2k, k_rem, a.at_mut(col, j), &mut carry); + } } } - pub fn encode_coeff_i64(&mut self, basek: usize, col: usize, k: usize, idx: usize, data: i64, log_max: usize) { - let size: usize = k.div_ceil(basek); + pub fn encode_vec_i128(&mut self, base2k: usize, col: usize, k: usize, data: &[i128]) { + let size: usize = k.div_ceil(base2k); + + #[cfg(debug_assertions)] + { + let a: VecZnx<&mut [u8]> = self.to_mut(); + assert!( + size <= a.size(), + "invalid argument k.div_ceil(base2k)={} > a.size()={}", + size, + a.size() + ); + assert!(col < a.cols()); + assert!(data.len() == a.n()) + } + + let mut a: VecZnx<&mut [u8]> = self.to_mut(); + let a_size: usize = a.size(); + + { + let mut carry_i128: Vec<i128> = vec![0i128; a.n()]; + carry_i128.copy_from_slice(data); + + for j in (0..size).rev() { + for (x, a) in izip!(a.at_mut(col, j).iter_mut(), carry_i128.iter_mut()) { + let digit: i128 = get_digit_i128(base2k, *a); + let carry: i128 = get_carry_i128(base2k, *a, digit); + *x = digit as i64; + *a = carry; + } + } + } + + for j in size..a_size { + ZnxRef::znx_zero(a.at_mut(col, j)); + } + + let mut carry: Vec<i64> = vec![0i64; a.n()]; + let k_rem: usize = (base2k - (k % base2k)) % base2k; + + for j in (0..size).rev() { + if j == size - 1 { + ZnxRef::znx_normalize_first_step_inplace(base2k, k_rem, a.at_mut(col, j), &mut carry); + } else if j == 0 { + ZnxRef::znx_normalize_final_step_inplace(base2k, k_rem, a.at_mut(col, j), &mut carry); + } else { + ZnxRef::znx_normalize_middle_step_inplace(base2k, k_rem, a.at_mut(col, j), &mut carry); + } } + } + + pub fn encode_coeff_i64(&mut self, base2k: usize, col: usize, k: usize, idx: usize, data: i64) { + let size: usize = k.div_ceil(base2k); #[cfg(debug_assertions)] { @@ -69,46 +112,42 @@ impl VecZnx { assert!(idx < a.n()); assert!( size <= a.size(), - "invalid argument k.div_ceil(basek)={} > a.size()={}", + "invalid argument k.div_ceil(base2k)={} > a.size()={}", size, a.size() ); assert!(col < a.cols()); } - let k_rem: usize = basek - (k % basek); let mut a: VecZnx<&mut [u8]> = self.to_mut(); - (0..a.size()).for_each(|j| a.at_mut(col, j)[idx] = 0); + let a_size = a.size(); - // If 2^{basek} * 2^{k_rem} < 2^{63}-1, then we can simply copy - // values on the last limb. - // Else we decompose values base2k.
- if log_max + k_rem < 63 || k_rem == basek { - a.at_mut(col, size - 1)[idx] = data; - } else { - let mask: i64 = (1 << basek) - 1; - let steps: usize = size.min(log_max.div_ceil(basek)); - (size - steps..size) - .rev() - .enumerate() - .for_each(|(j, j_rev)| { - a.at_mut(col, j_rev)[idx] = (data >> (j * basek)) & mask; - }) + for j in 0..a_size { + a.at_mut(col, j)[idx] = 0 } - // Case where prec % k != 0. - if k_rem != basek { - let steps: usize = size.min(log_max.div_ceil(basek)); - (size - steps..size).rev().for_each(|j| { - a.at_mut(col, j)[idx] <<= k_rem; - }) + a.at_mut(col, size - 1)[idx] = data; + + let mut carry: Vec = vec![0i64; 1]; + let k_rem: usize = (base2k - (k % base2k)) % base2k; + + for j in (0..size).rev() { + let slice = &mut a.at_mut(col, j)[idx..idx + 1]; + + if j == size - 1 { + ZnxRef::znx_normalize_first_step_inplace(base2k, k_rem, slice, &mut carry); + } else if j == 0 { + ZnxRef::znx_normalize_final_step_inplace(base2k, k_rem, slice, &mut carry); + } else { + ZnxRef::znx_normalize_middle_step_inplace(base2k, k_rem, slice, &mut carry); + } } } } impl VecZnx { - pub fn decode_vec_i64(&self, basek: usize, col: usize, k: usize, data: &mut [i64]) { - let size: usize = k.div_ceil(basek); + pub fn decode_vec_i64(&self, base2k: usize, col: usize, k: usize, data: &mut [i64]) { + let size: usize = k.div_ceil(base2k); #[cfg(debug_assertions)] { let a: VecZnx<&[u8]> = self.to_ref(); @@ -123,26 +162,26 @@ impl VecZnx { let a: VecZnx<&[u8]> = self.to_ref(); data.copy_from_slice(a.at(col, 0)); - let rem: usize = basek - (k % basek); - if k < basek { + let rem: usize = base2k - (k % base2k); + if k < base2k { data.iter_mut().for_each(|x| *x >>= rem); } else { (1..size).for_each(|i| { - if i == size - 1 && rem != basek { - let k_rem: usize = basek - rem; + if i == size - 1 && rem != base2k { + let k_rem: usize = (base2k - rem) % base2k; izip!(a.at(col, i).iter(), data.iter_mut()).for_each(|(x, y)| { *y = (*y << k_rem) + (x >> rem); }); } else { izip!(a.at(col, i).iter(), data.iter_mut()).for_each(|(x, y)| { - *y = (*y << basek) + x; + *y = (*y << base2k) + x; }); } }) } } - pub fn decode_coeff_i64(&self, basek: usize, col: usize, k: usize, idx: usize) -> i64 { + pub fn decode_coeff_i64(&self, base2k: usize, col: usize, k: usize, idx: usize) -> i64 { #[cfg(debug_assertions)] { let a: VecZnx<&[u8]> = self.to_ref(); @@ -151,22 +190,22 @@ impl VecZnx { } let a: VecZnx<&[u8]> = self.to_ref(); - let size: usize = k.div_ceil(basek); + let size: usize = k.div_ceil(base2k); let mut res: i64 = 0; - let rem: usize = basek - (k % basek); + let rem: usize = base2k - (k % base2k); (0..size).for_each(|j| { let x: i64 = a.at(col, j)[idx]; - if j == size - 1 && rem != basek { - let k_rem: usize = basek - rem; + if j == size - 1 && rem != base2k { + let k_rem: usize = (base2k - rem) % base2k; res = (res << k_rem) + (x >> rem); } else { - res = (res << basek) + x; + res = (res << base2k) + x; } }); res } - pub fn decode_vec_float(&self, basek: usize, col: usize, data: &mut [Float]) { + pub fn decode_vec_float(&self, base2k: usize, col: usize, data: &mut [Float]) { #[cfg(debug_assertions)] { let a: VecZnx<&[u8]> = self.to_ref(); @@ -181,12 +220,12 @@ impl VecZnx { let a: VecZnx<&[u8]> = self.to_ref(); let size: usize = a.size(); - let prec: u32 = (basek * size) as u32; + let prec: u32 = (base2k * size) as u32; - // 2^{basek} - let base = Float::with_val(prec, (1u64 << basek) as f64); + // 2^{base2k} + let base: Float = Float::with_val(prec, (1u64 << base2k) as f64); - // y[i] = sum x[j][i] * 
2^{-basek*j} + // y[i] = sum x[j][i] * 2^{-base2k*j} (0..size).for_each(|i| { if i == 0 { izip!(a.at(col, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { @@ -204,78 +243,74 @@ impl VecZnx { } impl Zn { - pub fn encode_i64(&mut self, basek: usize, k: usize, data: i64, log_max: usize) { - let size: usize = k.div_ceil(basek); + pub fn encode_i64(&mut self, base2k: usize, k: usize, data: i64) { + let size: usize = k.div_ceil(base2k); #[cfg(debug_assertions)] { let a: Zn<&mut [u8]> = self.to_mut(); assert!( size <= a.size(), - "invalid argument k.div_ceil(basek)={} > a.size()={}", + "invalid argument k.div_ceil(base2k)={} > a.size()={}", size, a.size() ); } - let k_rem: usize = basek - (k % basek); let mut a: Zn<&mut [u8]> = self.to_mut(); - (0..a.size()).for_each(|j| a.at_mut(0, j)[0] = 0); + let a_size = a.size(); - // If 2^{basek} * 2^{k_rem} < 2^{63}-1, then we can simply copy - // values on the last limb. - // Else we decompose values base2k. - if log_max + k_rem < 63 || k_rem == basek { - a.at_mut(0, size - 1)[0] = data; - } else { - let mask: i64 = (1 << basek) - 1; - let steps: usize = size.min(log_max.div_ceil(basek)); - (size - steps..size) - .rev() - .enumerate() - .for_each(|(j, j_rev)| { - a.at_mut(0, j_rev)[0] = (data >> (j * basek)) & mask; - }) + for j in 0..a_size { + a.at_mut(0, j)[0] = 0 } - // Case where prec % k != 0. - if k_rem != basek { - let steps: usize = size.min(log_max.div_ceil(basek)); - (size - steps..size).rev().for_each(|j| { - a.at_mut(0, j)[0] <<= k_rem; - }) + a.at_mut(0, size - 1)[0] = data; + + let mut carry: Vec = vec![0i64; 1]; + let k_rem: usize = (base2k - (k % base2k)) % base2k; + + for j in (0..size).rev() { + let slice = &mut a.at_mut(0, j)[..1]; + + if j == size - 1 { + ZnxRef::znx_normalize_first_step_inplace(base2k, k_rem, slice, &mut carry); + } else if j == 0 { + ZnxRef::znx_normalize_final_step_inplace(base2k, k_rem, slice, &mut carry); + } else { + ZnxRef::znx_normalize_middle_step_inplace(base2k, k_rem, slice, &mut carry); + } } } } impl Zn { - pub fn decode_i64(&self, basek: usize, k: usize) -> i64 { + pub fn decode_i64(&self, base2k: usize, k: usize) -> i64 { let a: Zn<&[u8]> = self.to_ref(); - let size: usize = k.div_ceil(basek); + let size: usize = k.div_ceil(base2k); let mut res: i64 = 0; - let rem: usize = basek - (k % basek); + let rem: usize = base2k - (k % base2k); (0..size).for_each(|j| { let x: i64 = a.at(0, j)[0]; - if j == size - 1 && rem != basek { - let k_rem: usize = basek - rem; + if j == size - 1 && rem != base2k { + let k_rem: usize = (base2k - rem) % base2k; res = (res << k_rem) + (x >> rem); } else { - res = (res << basek) + x; + res = (res << base2k) + x; } }); res } - pub fn decode_float(&self, basek: usize) -> Float { + pub fn decode_float(&self, base2k: usize) -> Float { let a: Zn<&[u8]> = self.to_ref(); let size: usize = a.size(); - let prec: u32 = (basek * size) as u32; + let prec: u32 = (base2k * size) as u32; - // 2^{basek} - let base: Float = Float::with_val(prec, (1 << basek) as f64); - let mut res: Float = Float::with_val(prec, (1 << basek) as f64); + // 2^{base2k} + let base: Float = Float::with_val(prec, (1 << base2k) as f64); + let mut res: Float = Float::with_val(prec, (1 << base2k) as f64); - // y[i] = sum x[j][i] * 2^{-basek*j} + // y[i] = sum x[j][i] * 2^{-base2k*j} (0..size).for_each(|i| { if i == 0 { res.assign(a.at(0, size - i - 1)[0]); diff --git a/poulpy-hal/src/layouts/mat_znx.rs b/poulpy-hal/src/layouts/mat_znx.rs index 59a1c4c..01be1a1 100644 --- a/poulpy-hal/src/layouts/mat_znx.rs +++ 
b/poulpy-hal/src/layouts/mat_znx.rs @@ -1,7 +1,7 @@ use crate::{ alloc_aligned, layouts::{ - Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, FillUniform, ReaderFrom, Reset, ToOwnedDeep, VecZnx, WriterTo, + Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, FillUniform, ReaderFrom, ToOwnedDeep, VecZnx, WriterTo, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero, }, source::Source, @@ -54,7 +54,7 @@ impl ToOwnedDeep for MatZnx { impl fmt::Debug for MatZnx { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self) + write!(f, "{self}") } } @@ -211,17 +211,6 @@ impl FillUniform for MatZnx { } } -impl Reset for MatZnx { - fn reset(&mut self) { - self.zero(); - self.n = 0; - self.size = 0; - self.rows = 0; - self.cols_in = 0; - self.cols_out = 0; - } -} - pub type MatZnxOwned = MatZnx>; pub type MatZnxMut<'a> = MatZnx<&'a mut [u8]>; pub type MatZnxRef<'a> = MatZnx<&'a [u8]>; @@ -316,9 +305,9 @@ impl fmt::Display for MatZnx { )?; for row_i in 0..self.rows { - writeln!(f, "Row {}:", row_i)?; + writeln!(f, "Row {row_i}:")?; for col_i in 0..self.cols_in { - writeln!(f, "cols_in {}:", col_i)?; + writeln!(f, "cols_in {col_i}:")?; writeln!(f, "{}:", self.at(row_i, col_i))?; } } diff --git a/poulpy-hal/src/layouts/mod.rs b/poulpy-hal/src/layouts/mod.rs index d2b8fe5..cce4df4 100644 --- a/poulpy-hal/src/layouts/mod.rs +++ b/poulpy-hal/src/layouts/mod.rs @@ -26,7 +26,7 @@ pub use vmp_pmat::*; pub use zn::*; pub use znx_base::*; -pub trait Data = PartialEq + Eq + Sized; +pub trait Data = PartialEq + Eq + Sized + Default; pub trait DataRef = Data + AsRef<[u8]>; pub trait DataMut = DataRef + AsMut<[u8]>; diff --git a/poulpy-hal/src/layouts/scalar_znx.rs b/poulpy-hal/src/layouts/scalar_znx.rs index 6baf66f..296071b 100644 --- a/poulpy-hal/src/layouts/scalar_znx.rs +++ b/poulpy-hal/src/layouts/scalar_znx.rs @@ -7,7 +7,7 @@ use rand_distr::{Distribution, weighted::WeightedIndex}; use crate::{ alloc_aligned, layouts::{ - Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, FillUniform, ReaderFrom, Reset, ToOwnedDeep, VecZnx, WriterTo, + Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, FillUniform, ReaderFrom, ToOwnedDeep, VecZnx, WriterTo, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero, }, source::Source, @@ -173,14 +173,6 @@ impl FillUniform for ScalarZnx { } } -impl Reset for ScalarZnx { - fn reset(&mut self) { - self.zero(); - self.n = 0; - self.cols = 0; - } -} - pub type ScalarZnxOwned = ScalarZnx>; impl ScalarZnx { diff --git a/poulpy-hal/src/layouts/stats.rs b/poulpy-hal/src/layouts/stats.rs index 05dd087..d40ffa5 100644 --- a/poulpy-hal/src/layouts/stats.rs +++ b/poulpy-hal/src/layouts/stats.rs @@ -7,10 +7,10 @@ use rug::{ use crate::layouts::{Backend, DataRef, VecZnx, VecZnxBig, VecZnxBigToRef, ZnxInfos}; impl VecZnx { - pub fn std(&self, basek: usize, col: usize) -> f64 { - let prec: u32 = (self.size() * basek) as u32; + pub fn std(&self, base2k: usize, col: usize) -> f64 { + let prec: u32 = (self.size() * base2k) as u32; let mut data: Vec = (0..self.n()).map(|_| Float::with_val(prec, 0)).collect(); - self.decode_vec_float(basek, col, &mut data); + self.decode_vec_float(base2k, col, &mut data); // std = sqrt(sum((xi - avg)^2) / n) let mut avg: Float = Float::with_val(prec, 0); data.iter().for_each(|x| { @@ -29,7 +29,7 @@ impl VecZnx { } impl> VecZnxBig { - pub fn std(&self, basek: usize, col: usize) -> f64 { + pub fn std(&self, base2k: usize, col: usize) -> f64 { let self_ref: VecZnxBig<&[u8], B> = self.to_ref(); let znx: VecZnx<&[u8]> = 
VecZnx { data: self_ref.data, @@ -38,6 +38,6 @@ impl> VecZnxBig { size: self_ref.size, max_size: self_ref.max_size, }; - znx.std(basek, col) + znx.std(base2k, col) } } diff --git a/poulpy-hal/src/layouts/svp_ppol.rs b/poulpy-hal/src/layouts/svp_ppol.rs index 428f055..50523fc 100644 --- a/poulpy-hal/src/layouts/svp_ppol.rs +++ b/poulpy-hal/src/layouts/svp_ppol.rs @@ -176,7 +176,7 @@ impl fmt::Display for SvpPPol { writeln!(f, "SvpPPol(n={}, cols={})", self.n, self.cols)?; for col in 0..self.cols { - writeln!(f, "Column {}:", col)?; + writeln!(f, "Column {col}:")?; let coeffs = self.at(col, 0); write!(f, "[")?; @@ -187,7 +187,7 @@ impl fmt::Display for SvpPPol { if i > 0 { write!(f, ", ")?; } - write!(f, "{}", coeff)?; + write!(f, "{coeff}")?; } if coeffs.len() > max_show { diff --git a/poulpy-hal/src/layouts/vec_znx.rs b/poulpy-hal/src/layouts/vec_znx.rs index 0ff9454..d40ef4b 100644 --- a/poulpy-hal/src/layouts/vec_znx.rs +++ b/poulpy-hal/src/layouts/vec_znx.rs @@ -6,8 +6,8 @@ use std::{ use crate::{ alloc_aligned, layouts::{ - Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, FillUniform, ReaderFrom, Reset, ToOwnedDeep, WriterTo, - ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero, + Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, FillUniform, ReaderFrom, ToOwnedDeep, WriterTo, ZnxInfos, + ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero, }, source::Source, }; @@ -25,6 +25,18 @@ pub struct VecZnx { pub max_size: usize, } +impl Default for VecZnx { + fn default() -> Self { + Self { + data: D::default(), + n: 0, + cols: 0, + size: 0, + max_size: 0, + } + } +} + impl DigestU64 for VecZnx { fn digest_u64(&self) -> u64 { let mut h: DefaultHasher = DefaultHasher::new(); @@ -52,7 +64,7 @@ impl ToOwnedDeep for VecZnx { impl fmt::Debug for VecZnx { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self) + write!(f, "{self}") } } @@ -162,10 +174,10 @@ impl fmt::Display for VecZnx { )?; for col in 0..self.cols { - writeln!(f, "Column {}:", col)?; + writeln!(f, "Column {col}:")?; for size in 0..self.size { let coeffs = self.at(col, size); - write!(f, " Size {}: [", size)?; + write!(f, " Size {size}: [")?; let max_show = 100; let show_count = coeffs.len().min(max_show); @@ -174,7 +186,7 @@ impl fmt::Display for VecZnx { if i > 0 { write!(f, ", ")?; } - write!(f, "{}", coeff)?; + write!(f, "{coeff}")?; } if coeffs.len() > max_show { @@ -204,16 +216,6 @@ impl FillUniform for VecZnx { } } -impl Reset for VecZnx { - fn reset(&mut self) { - self.zero(); - self.n = 0; - self.cols = 0; - self.size = 0; - self.max_size = 0; - } -} - pub type VecZnxOwned = VecZnx>; pub type VecZnxMut<'a> = VecZnx<&'a mut [u8]>; pub type VecZnxRef<'a> = VecZnx<&'a [u8]>; diff --git a/poulpy-hal/src/layouts/vec_znx_big.rs b/poulpy-hal/src/layouts/vec_znx_big.rs index ee5e919..c50cf66 100644 --- a/poulpy-hal/src/layouts/vec_znx_big.rs +++ b/poulpy-hal/src/layouts/vec_znx_big.rs @@ -179,10 +179,10 @@ impl fmt::Display for VecZnxBig { )?; for col in 0..self.cols { - writeln!(f, "Column {}:", col)?; + writeln!(f, "Column {col}:")?; for size in 0..self.size { let coeffs = self.at(col, size); - write!(f, " Size {}: [", size)?; + write!(f, " Size {size}: [")?; let max_show = 100; let show_count = coeffs.len().min(max_show); @@ -191,7 +191,7 @@ impl fmt::Display for VecZnxBig { if i > 0 { write!(f, ", ")?; } - write!(f, "{}", coeff)?; + write!(f, "{coeff}")?; } if coeffs.len() > max_show { diff --git a/poulpy-hal/src/layouts/vec_znx_dft.rs b/poulpy-hal/src/layouts/vec_znx_dft.rs index 
027742c..3dc92d5 100644 --- a/poulpy-hal/src/layouts/vec_znx_dft.rs +++ b/poulpy-hal/src/layouts/vec_znx_dft.rs @@ -199,10 +199,10 @@ impl fmt::Display for VecZnxDft { )?; for col in 0..self.cols { - writeln!(f, "Column {}:", col)?; + writeln!(f, "Column {col}:")?; for size in 0..self.size { let coeffs = self.at(col, size); - write!(f, " Size {}: [", size)?; + write!(f, " Size {size}: [")?; let max_show = 100; let show_count = coeffs.len().min(max_show); @@ -211,7 +211,7 @@ impl fmt::Display for VecZnxDft { if i > 0 { write!(f, ", ")?; } - write!(f, "{}", coeff)?; + write!(f, "{coeff}")?; } if coeffs.len() > max_show { diff --git a/poulpy-hal/src/layouts/zn.rs b/poulpy-hal/src/layouts/zn.rs index 7807034..40f5622 100644 --- a/poulpy-hal/src/layouts/zn.rs +++ b/poulpy-hal/src/layouts/zn.rs @@ -6,8 +6,8 @@ use std::{ use crate::{ alloc_aligned, layouts::{ - Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, FillUniform, ReaderFrom, Reset, ToOwnedDeep, WriterTo, - ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero, + Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, FillUniform, ReaderFrom, ToOwnedDeep, WriterTo, ZnxInfos, + ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero, }, source::Source, }; @@ -52,7 +52,7 @@ impl ToOwnedDeep for Zn { impl fmt::Debug for Zn { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self) + write!(f, "{self}") } } @@ -162,10 +162,10 @@ impl fmt::Display for Zn { )?; for col in 0..self.cols { - writeln!(f, "Column {}:", col)?; + writeln!(f, "Column {col}:")?; for size in 0..self.size { let coeffs = self.at(col, size); - write!(f, " Size {}: [", size)?; + write!(f, " Size {size}: [")?; let max_show = 100; let show_count = coeffs.len().min(max_show); @@ -174,7 +174,7 @@ impl fmt::Display for Zn { if i > 0 { write!(f, ", ")?; } - write!(f, "{}", coeff)?; + write!(f, "{coeff}")?; } if coeffs.len() > max_show { @@ -204,16 +204,6 @@ impl FillUniform for Zn { } } -impl Reset for Zn { - fn reset(&mut self) { - self.zero(); - self.n = 0; - self.cols = 0; - self.size = 0; - self.max_size = 0; - } -} - pub type ZnOwned = Zn>; pub type ZnMut<'a> = Zn<&'a mut [u8]>; pub type ZnRef<'a> = Zn<&'a [u8]>; diff --git a/poulpy-hal/src/layouts/znx_base.rs b/poulpy-hal/src/layouts/znx_base.rs index 6173daf..7ca75b2 100644 --- a/poulpy-hal/src/layouts/znx_base.rs +++ b/poulpy-hal/src/layouts/znx_base.rs @@ -119,7 +119,3 @@ where pub trait FillUniform { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source); } - -pub trait Reset { - fn reset(&mut self); -} diff --git a/poulpy-hal/src/lib.rs b/poulpy-hal/src/lib.rs index ca09126..92d874f 100644 --- a/poulpy-hal/src/lib.rs +++ b/poulpy-hal/src/lib.rs @@ -56,15 +56,12 @@ pub fn cast_mut(data: &[T]) -> &mut [V] { fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec { assert!( align.is_power_of_two(), - "Alignment must be a power of two but is {}", - align + "Alignment must be a power of two but is {align}" ); assert_eq!( (size * size_of::()) % align, 0, - "size={} must be a multiple of align={}", - size, - align + "size={size} must be a multiple of align={align}" ); unsafe { let layout: std::alloc::Layout = std::alloc::Layout::from_size_align(size, align).expect("Invalid alignment"); @@ -74,9 +71,7 @@ fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec { } assert!( is_aligned_custom(ptr, align), - "Memory allocation at {:p} is not aligned to {} bytes", - ptr, - align + "Memory allocation at {ptr:p} is not aligned to {align} bytes" ); // Init allocated memory to zero 
std::ptr::write_bytes(ptr, 0, size); @@ -89,16 +84,14 @@ fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec { pub fn alloc_aligned_custom(size: usize, align: usize) -> Vec { assert!( align.is_power_of_two(), - "Alignment must be a power of two but is {}", - align + "Alignment must be a power of two but is {align}" ); assert_eq!( (size * size_of::()) % align, 0, - "size*size_of::()={} must be a multiple of align={}", + "size*size_of::()={} must be a multiple of align={align}", size * size_of::(), - align ); let mut vec_u8: Vec = alloc_aligned_custom_u8(size_of::() * size, align); diff --git a/poulpy-hal/src/oep/module.rs b/poulpy-hal/src/oep/module.rs index 760a872..bbf57f8 100644 --- a/poulpy-hal/src/oep/module.rs +++ b/poulpy-hal/src/oep/module.rs @@ -1,8 +1,8 @@ use crate::layouts::{Backend, Module}; /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/module.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/module.rs) reference implementation. +/// * See [crate::api::ModuleNew] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait ModuleNewImpl { fn new_impl(n: u64) -> Module; diff --git a/poulpy-hal/src/oep/scratch.rs b/poulpy-hal/src/oep/scratch.rs index 9a06b4e..51a9c56 100644 --- a/poulpy-hal/src/oep/scratch.rs +++ b/poulpy-hal/src/oep/scratch.rs @@ -1,74 +1,72 @@ -use crate::layouts::{ - Backend, DataRef, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, ZnxInfos, -}; +use crate::layouts::{Backend, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat}; /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation. +/// * See [crate::api::ScratchOwnedAlloc] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait ScratchOwnedAllocImpl { fn scratch_owned_alloc_impl(size: usize) -> ScratchOwned; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation. +/// * See [crate::api::ScratchOwnedBorrow] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait ScratchOwnedBorrowImpl { fn scratch_owned_borrow_impl(scratch: &mut ScratchOwned) -> &mut Scratch; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation. +/// * See [crate::api::ScratchFromBytes] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. 
pub unsafe trait ScratchFromBytesImpl { fn scratch_from_bytes_impl(data: &mut [u8]) -> &mut Scratch; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation. +/// * See [crate::api::ScratchAvailable] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait ScratchAvailableImpl { fn scratch_available_impl(scratch: &Scratch) -> usize; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation. +/// * See [crate::api::TakeSlice] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait TakeSliceImpl { fn take_slice_impl(scratch: &mut Scratch, len: usize) -> (&mut [T], &mut Scratch); } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation. +/// * See [crate::api::TakeScalarZnx] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait TakeScalarZnxImpl { fn take_scalar_znx_impl(scratch: &mut Scratch, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Scratch); } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation. +/// * See [crate::api::TakeSvpPPol] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait TakeSvpPPolImpl { fn take_svp_ppol_impl(scratch: &mut Scratch, n: usize, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Scratch); } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation. +/// * See [crate::api::TakeVecZnx] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait TakeVecZnxImpl { fn take_vec_znx_impl(scratch: &mut Scratch, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Scratch); } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation. +/// * See [crate::api::TakeVecZnxSlice] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait TakeVecZnxSliceImpl { fn take_vec_znx_slice_impl( @@ -81,8 +79,8 @@ pub unsafe trait TakeVecZnxSliceImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation. +/// * See [crate::api::TakeVecZnxBig] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait TakeVecZnxBigImpl { fn take_vec_znx_big_impl( @@ -94,8 +92,8 @@ pub unsafe trait TakeVecZnxBigImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation. +/// * See [crate::api::TakeVecZnxDft] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait TakeVecZnxDftImpl { fn take_vec_znx_dft_impl( @@ -107,8 +105,8 @@ pub unsafe trait TakeVecZnxDftImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation. +/// * See [crate::api::TakeVecZnxDftSlice] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait TakeVecZnxDftSliceImpl { fn take_vec_znx_dft_slice_impl( @@ -121,8 +119,8 @@ pub unsafe trait TakeVecZnxDftSliceImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation. +/// * See [crate::api::TakeVmpPMat] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait TakeVmpPMatImpl { fn take_vmp_pmat_impl( @@ -136,8 +134,8 @@ pub unsafe trait TakeVmpPMatImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation. +/// * See [crate::api::TakeMatZnx] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait TakeMatZnxImpl { fn take_mat_znx_impl( @@ -149,110 +147,3 @@ pub unsafe trait TakeMatZnxImpl { size: usize, ) -> (MatZnx<&mut [u8]>, &mut Scratch); } - -/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. -/// # Safety [crate::doc::backend_safety] for safety contract. 
-pub trait TakeLikeImpl<'a, B: Backend, T> { - type Output; - fn take_like_impl(scratch: &'a mut Scratch, template: &T) -> (Self::Output, &'a mut Scratch); -} - -impl<'a, B: Backend, D> TakeLikeImpl<'a, B, VmpPMat> for B -where - B: TakeVmpPMatImpl, - D: DataRef, -{ - type Output = VmpPMat<&'a mut [u8], B>; - - fn take_like_impl(scratch: &'a mut Scratch, template: &VmpPMat) -> (Self::Output, &'a mut Scratch) { - B::take_vmp_pmat_impl( - scratch, - template.n(), - template.rows(), - template.cols_in(), - template.cols_out(), - template.size(), - ) - } -} - -impl<'a, B: Backend, D> TakeLikeImpl<'a, B, MatZnx> for B -where - B: TakeMatZnxImpl, - D: DataRef, -{ - type Output = MatZnx<&'a mut [u8]>; - - fn take_like_impl(scratch: &'a mut Scratch, template: &MatZnx) -> (Self::Output, &'a mut Scratch) { - B::take_mat_znx_impl( - scratch, - template.n(), - template.rows(), - template.cols_in(), - template.cols_out(), - template.size(), - ) - } -} - -impl<'a, B: Backend, D> TakeLikeImpl<'a, B, VecZnxDft> for B -where - B: TakeVecZnxDftImpl, - D: DataRef, -{ - type Output = VecZnxDft<&'a mut [u8], B>; - - fn take_like_impl(scratch: &'a mut Scratch, template: &VecZnxDft) -> (Self::Output, &'a mut Scratch) { - B::take_vec_znx_dft_impl(scratch, template.n(), template.cols(), template.size()) - } -} - -impl<'a, B: Backend, D> TakeLikeImpl<'a, B, VecZnxBig> for B -where - B: TakeVecZnxBigImpl, - D: DataRef, -{ - type Output = VecZnxBig<&'a mut [u8], B>; - - fn take_like_impl(scratch: &'a mut Scratch, template: &VecZnxBig) -> (Self::Output, &'a mut Scratch) { - B::take_vec_znx_big_impl(scratch, template.n(), template.cols(), template.size()) - } -} - -impl<'a, B: Backend, D> TakeLikeImpl<'a, B, SvpPPol> for B -where - B: TakeSvpPPolImpl, - D: DataRef, -{ - type Output = SvpPPol<&'a mut [u8], B>; - - fn take_like_impl(scratch: &'a mut Scratch, template: &SvpPPol) -> (Self::Output, &'a mut Scratch) { - B::take_svp_ppol_impl(scratch, template.n(), template.cols()) - } -} - -impl<'a, B: Backend, D> TakeLikeImpl<'a, B, VecZnx> for B -where - B: TakeVecZnxImpl, - D: DataRef, -{ - type Output = VecZnx<&'a mut [u8]>; - - fn take_like_impl(scratch: &'a mut Scratch, template: &VecZnx) -> (Self::Output, &'a mut Scratch) { - B::take_vec_znx_impl(scratch, template.n(), template.cols(), template.size()) - } -} - -impl<'a, B: Backend, D> TakeLikeImpl<'a, B, ScalarZnx> for B -where - B: TakeScalarZnxImpl, - D: DataRef, -{ - type Output = ScalarZnx<&'a mut [u8]>; - - fn take_like_impl(scratch: &'a mut Scratch, template: &ScalarZnx) -> (Self::Output, &'a mut Scratch) { - B::take_scalar_znx_impl(scratch, template.n(), template.cols()) - } -} diff --git a/poulpy-hal/src/oep/svp_ppol.rs b/poulpy-hal/src/oep/svp_ppol.rs index b50208a..6550b6f 100644 --- a/poulpy-hal/src/oep/svp_ppol.rs +++ b/poulpy-hal/src/oep/svp_ppol.rs @@ -3,32 +3,32 @@ use crate::layouts::{ }; /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/svp.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/svp.rs) reference implementation. +/// * See [crate::api::SvpPPolFromBytes] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait SvpPPolFromBytesImpl { fn svp_ppol_from_bytes_impl(n: usize, cols: usize, bytes: Vec) -> SvpPPolOwned; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. 
-/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/svp.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/svp.rs) reference implementation. +/// * See [crate::api::SvpPPolAlloc] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait SvpPPolAllocImpl { fn svp_ppol_alloc_impl(n: usize, cols: usize) -> SvpPPolOwned; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/svp.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/svp.rs) reference implementation. +/// * See [crate::api::SvpPPolAllocBytes] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait SvpPPolAllocBytesImpl { fn svp_ppol_alloc_bytes_impl(n: usize, cols: usize) -> usize; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/svp.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/svp.rs) reference implementation. +/// * See [crate::api::SvpPrepare] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait SvpPrepareImpl { fn svp_prepare_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) @@ -38,8 +38,8 @@ pub unsafe trait SvpPrepareImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/svp.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/svp.rs) reference implementation. +/// * See [crate::api::SvpApplyDft] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait SvpApplyDftImpl { fn svp_apply_dft_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize) @@ -50,8 +50,8 @@ pub unsafe trait SvpApplyDftImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/svp.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/svp.rs) reference implementation. +/// * See [crate::api::SvpApplyDftToDft] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait SvpApplyDftToDftImpl { fn svp_apply_dft_to_dft_impl( @@ -69,8 +69,8 @@ pub unsafe trait SvpApplyDftToDftImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/svp.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/svp.rs) reference implementation. +/// * See [crate::api::SvpApplyDftToDftAdd] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait SvpApplyDftToDftAddImpl { fn svp_apply_dft_to_dft_add_impl( @@ -88,8 +88,8 @@ pub unsafe trait SvpApplyDftToDftAddImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. 
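As a value-level aid to the SvpApply* entries above: an SvpPPol is a single scalar polynomial prepared into the DFT domain, and the apply family multiplies it slot-wise into every limb of a DFT vector; the Add variant accumulates and the Inplace variant reuses the input as output. The sketch below uses f64 slots as a stand-in for the backend's packed DFT layout and only illustrates the semantics:

// Slot-wise scalar-vector product in the DFT domain (illustrative only).
fn svp_apply_dft_to_dft(res: &mut [Vec<f64>], ppol: &[f64], a: &[Vec<f64>]) {
    for (res_limb, a_limb) in res.iter_mut().zip(a) {
        for ((r, &x), &s) in res_limb.iter_mut().zip(a_limb).zip(ppol) {
            *r = s * x; // the Add variant would use `*r += s * x`
        }
    }
}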
-/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/svp.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/svp.rs) reference implementation. +/// * See [crate::api::SvpApplyDftToDftInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait SvpApplyDftToDftInplaceImpl: Backend { fn svp_apply_dft_to_dft_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) diff --git a/poulpy-hal/src/oep/vec_znx.rs b/poulpy-hal/src/oep/vec_znx.rs index 268e46b..380253e 100644 --- a/poulpy-hal/src/oep/vec_znx.rs +++ b/poulpy-hal/src/oep/vec_znx.rs @@ -4,7 +4,7 @@ use crate::{ }; /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [vec_znx_normalize_base2k_tmp_bytes_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L245C17-L245C55) for reference code. +/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation. /// * See [crate::api::VecZnxNormalizeTmpBytes] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxNormalizeTmpBytesImpl { @@ -12,15 +12,17 @@ pub unsafe trait VecZnxNormalizeTmpBytesImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [vec_znx_normalize_base2k_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L212) for reference code. +/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation. /// * See [crate::api::VecZnxNormalize] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxNormalizeImpl { + #[allow(clippy::too_many_arguments)] fn vec_znx_normalize_impl( module: &Module, - basek: usize, + res_basek: usize, res: &mut R, res_col: usize, + a_basek: usize, a: &A, a_col: usize, scratch: &mut Scratch, @@ -30,17 +32,17 @@ pub unsafe trait VecZnxNormalizeImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [vec_znx_normalize_base2k_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L212) for reference code. +/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation. /// * See [crate::api::VecZnxNormalizeInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxNormalizeInplaceImpl { - fn vec_znx_normalize_inplace_impl(module: &Module, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) + fn vec_znx_normalize_inplace_impl(module: &Module, base2k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) where A: VecZnxToMut; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [vec_znx_add_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L86) for reference code. 
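The vec_znx_normalize_impl signature change above splits the single basek into res_basek and a_basek, so normalization doubles as a change of limb base between input and output. A scalar analogue of that conversion, assuming plain non-negative digits and most-significant-limb-first order (poulpy's balanced signed digits need a carry pass on top):

// Scalar analogue of two-base normalization: rebuild the integer from
// base-2^a_base2k limbs, then redecompose into base-2^res_base2k limbs.
fn recompose(limbs: &[i64], base2k: usize) -> i128 {
    limbs.iter().fold(0i128, |acc, &l| (acc << base2k) + l as i128)
}

fn decompose(mut v: i128, base2k: usize, size: usize) -> Vec<i64> {
    let mask = (1i128 << base2k) - 1;
    let mut out = vec![0i64; size];
    for slot in out.iter_mut().rev() {
        *slot = (v & mask) as i64; // least-significant digit first, filled from the back
        v >>= base2k;
    }
    out
}

fn normalize_two_bases(a: &[i64], a_base2k: usize, res_base2k: usize, res_size: usize) -> Vec<i64> {
    decompose(recompose(a, a_base2k), res_base2k, res_size)
}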
+/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation. /// * See [crate::api::VecZnxAdd] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxAddImpl { @@ -52,7 +54,7 @@ pub unsafe trait VecZnxAddImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [vec_znx_add_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L86) for reference code. +/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation. /// * See [crate::api::VecZnxAddInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxAddInplaceImpl { @@ -63,7 +65,7 @@ pub unsafe trait VecZnxAddInplaceImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [vec_znx_add_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L86) for reference code. +/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation. /// * See [crate::api::VecZnxAddScalar] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxAddScalarImpl { @@ -85,7 +87,7 @@ pub unsafe trait VecZnxAddScalarImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [vec_znx_add_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L86) for reference code. +/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation. /// * See [crate::api::VecZnxAddScalarInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxAddScalarInplaceImpl { @@ -102,7 +104,7 @@ pub unsafe trait VecZnxAddScalarInplaceImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [vec_znx_sub_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L125) for reference code. +/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation. /// * See [crate::api::VecZnxSub] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxSubImpl { @@ -114,29 +116,29 @@ pub unsafe trait VecZnxSubImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [vec_znx_sub_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L125) for reference code. -/// * See [crate::api::VecZnxSubABInplace] for corresponding public API. +/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation. 
+/// * See [crate::api::VecZnxSubInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait VecZnxSubABInplaceImpl { - fn vec_znx_sub_ab_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) +pub unsafe trait VecZnxSubInplaceImpl { + fn vec_znx_sub_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, A: VecZnxToRef; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [vec_znx_sub_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L125) for reference code. -/// * See [crate::api::VecZnxSubBAInplace] for corresponding public API. +/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation. +/// * See [crate::api::VecZnxSubNegateInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait VecZnxSubBAInplaceImpl { - fn vec_znx_sub_ba_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) +pub unsafe trait VecZnxSubNegateInplaceImpl { + fn vec_znx_sub_negate_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, A: VecZnxToRef; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO. +/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation. /// * See [crate::api::VecZnxAddScalar] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxSubScalarImpl { @@ -158,7 +160,7 @@ pub unsafe trait VecZnxSubScalarImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [vec_znx_sub_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L125) for reference code. +/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation. /// * See [crate::api::VecZnxSubScalarInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxSubScalarInplaceImpl { @@ -175,7 +177,7 @@ pub unsafe trait VecZnxSubScalarInplaceImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [vec_znx_negate_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L322C13-L322C31) for reference code. +/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation. /// * See [crate::api::VecZnxNegate] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxNegateImpl { @@ -186,7 +188,7 @@ pub unsafe trait VecZnxNegateImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [vec_znx_negate_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L322C13-L322C31) for reference code. 
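The pair of renames above settles a convention used throughout this patch: the former *SubABInplace becomes plain *SubInplace and keeps res <- res - a, while the former *SubBAInplace becomes *SubNegateInplace, i.e. the same subtraction followed by a negation, res <- a - res. Coefficient-level semantics (plain i64 slices, not the actual VecZnx layout):

fn sub_inplace(res: &mut [i64], a: &[i64]) {
    for (r, &x) in res.iter_mut().zip(a) {
        *r -= x; // res <- res - a (formerly vec_znx_sub_ab_inplace)
    }
}

fn sub_negate_inplace(res: &mut [i64], a: &[i64]) {
    for (r, &x) in res.iter_mut().zip(a) {
        *r = x - *r; // res <- a - res (formerly vec_znx_sub_ba_inplace)
    }
}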
+/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation. /// * See [crate::api::VecZnxNegateInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxNegateInplaceImpl { @@ -196,7 +198,7 @@ pub unsafe trait VecZnxNegateInplaceImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [crate::reference::vec_znx::shift::vec_znx_rsh_tmp_bytes] for reference code. +/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation. /// * See [crate::api::VecZnxRshTmpBytes] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxRshTmpBytesImpl { @@ -204,14 +206,14 @@ pub unsafe trait VecZnxRshTmpBytesImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [crate::reference::vec_znx::shift::vec_znx_rsh_inplace] for reference code. +/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation. /// * See [crate::api::VecZnxRsh] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxRshImpl { #[allow(clippy::too_many_arguments)] - fn vec_znx_rsh_inplace_impl( + fn vec_znx_rsh_impl( module: &Module, - basek: usize, + base2k: usize, k: usize, res: &mut R, res_col: usize, @@ -224,7 +226,7 @@ pub unsafe trait VecZnxRshImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [crate::reference::vec_znx::shift::vec_znx_lsh_tmp_bytes] for reference code. +/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation. /// * See [crate::api::VecZnxLshTmpBytes] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxLshTmpBytesImpl { @@ -232,14 +234,14 @@ pub unsafe trait VecZnxLshTmpBytesImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [crate::reference::vec_znx::shift::vec_znx_lsh_inplace] for reference code. +/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation. /// * See [crate::api::VecZnxLsh] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxLshImpl { #[allow(clippy::too_many_arguments)] - fn vec_znx_lsh_inplace_impl( + fn vec_znx_lsh_impl( module: &Module, - basek: usize, + base2k: usize, k: usize, res: &mut R, res_col: usize, @@ -252,13 +254,13 @@ pub unsafe trait VecZnxLshImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [crate::cpu_spqlios::vec_znx::vec_znx_rsh_inplace_ref] for reference code. +/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation. /// * See [crate::api::VecZnxRshInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. 
pub unsafe trait VecZnxRshInplaceImpl { fn vec_znx_rsh_inplace_impl( module: &Module, - basek: usize, + base2k: usize, k: usize, res: &mut R, res_col: usize, @@ -268,13 +270,13 @@ pub unsafe trait VecZnxRshInplaceImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [crate::cpu_spqlios::vec_znx::vec_znx_lsh_inplace_ref] for reference code. +/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation. /// * See [crate::api::VecZnxLshInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxLshInplaceImpl { fn vec_znx_lsh_inplace_impl( module: &Module, - basek: usize, + base2k: usize, k: usize, res: &mut R, res_col: usize, @@ -284,7 +286,7 @@ pub unsafe trait VecZnxLshInplaceImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [vec_znx_rotate_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L164) for reference code. +/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation. /// * See [crate::api::VecZnxRotate] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxRotateImpl { @@ -295,7 +297,7 @@ pub unsafe trait VecZnxRotateImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO; +/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation. /// * See [crate::api::VecZnxRotateInplaceTmpBytes] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxRotateInplaceTmpBytesImpl { @@ -303,7 +305,7 @@ pub unsafe trait VecZnxRotateInplaceTmpBytesImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [vec_znx_rotate_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L164) for reference code. +/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation. /// * See [crate::api::VecZnxRotateInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxRotateInplaceImpl { @@ -313,7 +315,7 @@ pub unsafe trait VecZnxRotateInplaceImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [vec_znx_automorphism_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L188) for reference code. +/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation. /// * See [crate::api::VecZnxAutomorphism] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. 
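For the Rsh/Lsh entries above: the limbs encode a value in base 2^base2k, so shifting by k bits moves whole limbs for the k / base2k part and ripples a bit shift through the limbs for the k % base2k remainder. A simplified right-shift sketch with MSB-first, non-negative digits (balanced signed digits, as the backends use, need extra carry handling):

fn vec_rsh(limbs: &mut Vec<i64>, base2k: usize, k: usize) {
    let (limb_shift, bit_shift) = (k / base2k, k % base2k);
    for _ in 0..limb_shift {
        limbs.insert(0, 0); // push every limb one slot toward the LSB
        limbs.pop(); // drop the limb that falls off the end
    }
    if bit_shift > 0 {
        let mut carry = 0i64;
        for l in limbs.iter_mut() {
            let cur = (carry << base2k) | *l;
            *l = cur >> bit_shift;
            carry = cur & ((1 << bit_shift) - 1); // bits that move into the next limb
        }
    }
}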
pub unsafe trait VecZnxAutomorphismImpl { @@ -324,7 +326,7 @@ pub unsafe trait VecZnxAutomorphismImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO; +/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation. /// * See [crate::api::VecZnxAutomorphismInplaceTmpBytes] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxAutomorphismInplaceTmpBytesImpl { @@ -332,7 +334,7 @@ pub unsafe trait VecZnxAutomorphismInplaceTmpBytesImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [vec_znx_automorphism_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L188) for reference code. +/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation. /// * See [crate::api::VecZnxAutomorphismInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxAutomorphismInplaceImpl { @@ -342,7 +344,7 @@ pub unsafe trait VecZnxAutomorphismInplaceImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [vec_znx_mul_xp_minus_one_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/7160f588da49712a042931ea247b4259b95cefcc/spqlios/arithmetic/vec_znx.c#L200C13-L200C41) for reference code. +/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation. /// * See [crate::api::VecZnxMulXpMinusOne] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxMulXpMinusOneImpl { @@ -353,7 +355,7 @@ pub unsafe trait VecZnxMulXpMinusOneImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO; +/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation. /// * See [crate::api::VecZnxMulXpMinusOneInplaceTmpBytes] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxMulXpMinusOneInplaceTmpBytesImpl { @@ -361,7 +363,7 @@ pub unsafe trait VecZnxMulXpMinusOneInplaceTmpBytesImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [vec_znx_mul_xp_minus_one_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/7160f588da49712a042931ea247b4259b95cefcc/spqlios/arithmetic/vec_znx.c#L200C13-L200C41) for reference code. +/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation. /// * See [crate::api::VecZnxMulXpMinusOneInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxMulXpMinusOneInplaceImpl { @@ -376,7 +378,7 @@ pub unsafe trait VecZnxMulXpMinusOneInplaceImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO; +/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation. 
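The automorphism and mul-by-(X^p - 1) entries above act on the negacyclic ring Z[X]/(X^n + 1), where X^(i+n) = -X^i. Coefficient-level sketches of both maps:

fn automorphism(a: &[i64], g: usize) -> Vec<i64> {
    let n = a.len();
    let mut res = vec![0i64; n];
    for (i, &c) in a.iter().enumerate() {
        let j = (i * g) % (2 * n); // X^i maps to X^(i*g) = +/- X^(i*g mod n)
        if j < n { res[j] += c; } else { res[j - n] -= c; }
    }
    res
}

fn mul_xp_minus_one(a: &[i64], p: usize) -> Vec<i64> {
    let n = a.len();
    let p = p % (2 * n); // X^(2n) = 1 in this ring
    let mut res = vec![0i64; n];
    for (i, &c) in a.iter().enumerate() {
        let j = (i + p) % (2 * n); // the X^p * a term
        if j < n { res[j] += c; } else { res[j - n] -= c; }
        res[i] -= c; // the -a term
    }
    res
}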
/// * See [crate::api::VecZnxSplitRingTmpBytes] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxSplitRingTmpBytesImpl { @@ -401,7 +403,7 @@ pub unsafe trait VecZnxSplitRingImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO; +/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation. /// * See [crate::api::VecZnxMergeRingsTmpBytes] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxMergeRingsTmpBytesImpl { @@ -426,7 +428,7 @@ pub unsafe trait VecZnxMergeRingsImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [crate::cpu_spqlios::vec_znx::vec_znx_switch_degree_ref] for reference code. +/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation. /// * See [crate::api::VecZnxSwithcDegree] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxSwitchRingImpl { @@ -440,7 +442,7 @@ pub unsafe trait VecZnxSwitchRingImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [crate::cpu_spqlios::vec_znx::vec_znx_copy_ref] for reference code. +/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation. /// * See [crate::api::VecZnxCopy] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxCopyImpl { @@ -451,22 +453,24 @@ pub unsafe trait VecZnxCopyImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation. /// * See [crate::api::VecZnxFillUniform] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxFillUniformImpl { - fn vec_znx_fill_uniform_impl(module: &Module, basek: usize, res: &mut R, res_col: usize, source: &mut Source) + fn vec_znx_fill_uniform_impl(module: &Module, base2k: usize, res: &mut R, res_col: usize, source: &mut Source) where R: VecZnxToMut; } #[allow(clippy::too_many_arguments)] /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation. /// * See [crate::api::VecZnxFillNormal] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxFillNormalImpl { fn vec_znx_fill_normal_impl( module: &Module, - basek: usize, + base2k: usize, res: &mut R, res_col: usize, k: usize, @@ -479,12 +483,13 @@ pub unsafe trait VecZnxFillNormalImpl { #[allow(clippy::too_many_arguments)] /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation. /// * See [crate::api::VecZnxAddNormal] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. 
pub unsafe trait VecZnxAddNormalImpl { fn vec_znx_add_normal_impl( module: &Module, - basek: usize, + base2k: usize, res: &mut R, res_col: usize, k: usize, diff --git a/poulpy-hal/src/oep/vec_znx_big.rs b/poulpy-hal/src/oep/vec_znx_big.rs index 8398983..b2bd4c8 100644 --- a/poulpy-hal/src/oep/vec_znx_big.rs +++ b/poulpy-hal/src/oep/vec_znx_big.rs @@ -4,8 +4,8 @@ use crate::{ }; /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation. +/// * See [crate::api::VecZnxBigFromSmall] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxBigFromSmallImpl { fn vec_znx_big_from_small_impl(res: &mut R, res_col: usize, a: &A, a_col: usize) @@ -15,24 +15,24 @@ pub unsafe trait VecZnxBigFromSmallImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation. +/// * See [crate::api::VecZnxBigAlloc] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxBigAllocImpl { fn vec_znx_big_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxBigOwned; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation. +/// * See [crate::api::VecZnxBigFromBytes] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxBigFromBytesImpl { fn vec_znx_big_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec) -> VecZnxBigOwned; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation. +/// * See [crate::api::VecZnxBigAllocBytes] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxBigAllocBytesImpl { fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize; @@ -40,13 +40,13 @@ pub unsafe trait VecZnxBigAllocBytesImpl { #[allow(clippy::too_many_arguments)] /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation. +/// * See [crate::api::VecZnxBigAddNormal] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. 
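On the sampling entries whose parameter is renamed above: fill_uniform draws every limb coefficient uniformly from the base-2^base2k digit set, while add_normal injects a truncated Gaussian at scale 2^-k. The sketch below additionally assumes balanced digits in [-2^(base2k-1), 2^(base2k-1)) and that k is a multiple of base2k; both are simplifications for illustration, not statements about poulpy's exact layout:

fn fill_uniform(limbs: &mut [i64], base2k: usize, mut next_u64: impl FnMut() -> u64) {
    let half = 1i64 << (base2k - 1);
    for l in limbs.iter_mut() {
        // uniform balanced digit in [-2^(base2k-1), 2^(base2k-1))
        *l = (next_u64() as i64).rem_euclid(1 << base2k) - half;
    }
}

fn add_normal(limbs: &mut [i64], base2k: usize, k: usize, sigma: f64, bound: f64, std_normal: f64) {
    debug_assert!(k % base2k == 0 && k / base2k <= limbs.len());
    let e = (std_normal * sigma).clamp(-bound, bound); // truncated Gaussian sample
    limbs[k / base2k - 1] += e.round() as i64; // lands at scale 2^-k
}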
pub unsafe trait VecZnxBigAddNormalImpl { fn add_normal_impl>( module: &Module, - basek: usize, + res_basek: usize, res: &mut R, res_col: usize, k: usize, @@ -57,8 +57,8 @@ pub unsafe trait VecZnxBigAddNormalImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation. +/// * See [crate::api::VecZnxBigAdd] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxBigAddImpl { fn vec_znx_big_add_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize) @@ -69,8 +69,8 @@ pub unsafe trait VecZnxBigAddImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation. +/// * See [crate::api::VecZnxBigAddInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxBigAddInplaceImpl { fn vec_znx_big_add_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) @@ -80,8 +80,8 @@ pub unsafe trait VecZnxBigAddInplaceImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation. +/// * See [crate::api::VecZnxBigAddSmall] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxBigAddSmallImpl { fn vec_znx_big_add_small_impl( @@ -99,8 +99,8 @@ pub unsafe trait VecZnxBigAddSmallImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation. +/// * See [crate::api::VecZnxBigAddSmallInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxBigAddSmallInplaceImpl { fn vec_znx_big_add_small_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) @@ -110,8 +110,8 @@ pub unsafe trait VecZnxBigAddSmallInplaceImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation. +/// * See [crate::api::VecZnxBigSub] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. 
pub unsafe trait VecZnxBigSubImpl { fn vec_znx_big_sub_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize) @@ -122,30 +122,30 @@ pub unsafe trait VecZnxBigSubImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation. +/// * See [crate::api::VecZnxBigSubInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait VecZnxBigSubABInplaceImpl { - fn vec_znx_big_sub_ab_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) +pub unsafe trait VecZnxBigSubInplaceImpl { + fn vec_znx_big_sub_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxBigToMut, A: VecZnxBigToRef; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation. +/// * See [crate::api::VecZnxBigSubNegateInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait VecZnxBigSubBAInplaceImpl { - fn vec_znx_big_sub_ba_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) +pub unsafe trait VecZnxBigSubNegateInplaceImpl { + fn vec_znx_big_sub_negate_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxBigToMut, A: VecZnxBigToRef; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation. +/// * See [crate::api::VecZnxBigSubSmallA] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxBigSubSmallAImpl { fn vec_znx_big_sub_small_a_impl( @@ -163,19 +163,19 @@ pub unsafe trait VecZnxBigSubSmallAImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation. +/// * See [crate::api::VecZnxBigSubSmallInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait VecZnxBigSubSmallAInplaceImpl { - fn vec_znx_big_sub_small_a_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) +pub unsafe trait VecZnxBigSubSmallInplaceImpl { + fn vec_znx_big_sub_small_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxBigToMut, A: VecZnxToRef; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. 
+/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation. +/// * See [crate::api::VecZnxBigSubSmallB] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxBigSubSmallBImpl { fn vec_znx_big_sub_small_b_impl( @@ -193,19 +193,19 @@ pub unsafe trait VecZnxBigSubSmallBImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation. +/// * See [crate::api::VecZnxBigSubSmallNegateInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait VecZnxBigSubSmallBInplaceImpl { - fn vec_znx_big_sub_small_b_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) +pub unsafe trait VecZnxBigSubSmallNegateInplaceImpl { + fn vec_znx_big_sub_small_negate_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxBigToMut, A: VecZnxToRef; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation. +/// * See [crate::api::VecZnxBigNegate] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxBigNegateImpl { fn vec_znx_big_negate_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) @@ -215,8 +215,8 @@ pub unsafe trait VecZnxBigNegateImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation. +/// * See [crate::api::VecZnxBigNegateInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxBigNegateInplaceImpl { fn vec_znx_big_negate_inplace_impl(module: &Module, a: &mut A, a_col: usize) @@ -225,23 +225,25 @@ pub unsafe trait VecZnxBigNegateInplaceImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation. +/// * See [crate::api::VecZnxBigNormalizeTmpBytes] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxBigNormalizeTmpBytesImpl { fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module) -> usize; } +#[allow(clippy::too_many_arguments)] /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. 
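The mixed big/small entries above follow the same rename scheme as the plain VecZnx subtractions earlier in this diff (SubSmallAInplace -> SubSmallInplace for res <- res - a, SubSmallBInplace -> SubSmallNegateInplace for res <- a - res). The point of the big/small mix, as far as these signatures show, is that a VecZnxBig carries the same base-2^base2k limbs as a VecZnx but with un-normalized headroom, so small operands fold in limb-wise and carries are deferred to a later normalize. A value-level sketch of one such fold:

// "Small" limbs are absorbed digit-by-digit into the big accumulator;
// digits may leave the normalized range until a normalize pass runs.
fn big_add_small_inplace(res_big: &mut [i64], a_small: &[i64]) {
    for (r, &x) in res_big.iter_mut().zip(a_small) {
        *r += x;
    }
}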
+/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation. +/// * See [crate::api::VecZnxBigNormalize] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxBigNormalizeImpl { fn vec_znx_big_normalize_impl( module: &Module, - basek: usize, + res_basek: usize, res: &mut R, res_col: usize, + a_basek: usize, a: &A, a_col: usize, scratch: &mut Scratch, @@ -251,8 +253,8 @@ pub unsafe trait VecZnxBigNormalizeImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation. +/// * See [crate::api::VecZnxBigAutomorphism] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxBigAutomorphismImpl { fn vec_znx_big_automorphism_impl(module: &Module, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) @@ -262,16 +264,16 @@ pub unsafe trait VecZnxBigAutomorphismImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation. +/// * See [crate::api::VecZnxBigAutomorphismInplaceTmpBytes] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxBigAutomorphismInplaceTmpBytesImpl { fn vec_znx_big_automorphism_inplace_tmp_bytes_impl(module: &Module) -> usize; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation. +/// * See [crate::api::VecZnxBigAutomorphismInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxBigAutomorphismInplaceImpl { fn vec_znx_big_automorphism_inplace_impl(module: &Module, k: i64, a: &mut A, a_col: usize, scratch: &mut Scratch) diff --git a/poulpy-hal/src/oep/vec_znx_dft.rs b/poulpy-hal/src/oep/vec_znx_dft.rs index c4fd46a..e5a2bcb 100644 --- a/poulpy-hal/src/oep/vec_znx_dft.rs +++ b/poulpy-hal/src/oep/vec_znx_dft.rs @@ -4,24 +4,24 @@ use crate::layouts::{ }; /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs) reference implementation. +/// * See [crate::api::VecZnxDftAlloc] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxDftAllocImpl { fn vec_znx_dft_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxDftOwned; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. 
+/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs) reference implementation. +/// * See [crate::api::VecZnxDftFromBytes] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxDftFromBytesImpl { fn vec_znx_dft_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec) -> VecZnxDftOwned; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs) reference implementation. +/// * See [crate::api::VecZnxDftApply] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxDftApplyImpl { fn vec_znx_dft_apply_impl( @@ -38,24 +38,24 @@ pub unsafe trait VecZnxDftApplyImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs) reference implementation. +/// * See [crate::api::VecZnxDftAllocBytes] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxDftAllocBytesImpl { fn vec_znx_dft_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs) reference implementation. +/// * See [crate::api::VecZnxIdftApplyTmpBytes] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxIdftApplyTmpBytesImpl { fn vec_znx_idft_apply_tmp_bytes_impl(module: &Module) -> usize; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs) reference implementation. +/// * See [crate::api::VecZnxIdftApply] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxIdftApplyImpl { fn vec_znx_idft_apply_impl( @@ -71,8 +71,8 @@ pub unsafe trait VecZnxIdftApplyImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs) reference implementation. +/// * See [crate::api::VecZnxIdftApplyTmpA] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. 
pub unsafe trait VecZnxIdftApplyTmpAImpl { fn vec_znx_idft_apply_tmpa_impl(module: &Module, res: &mut R, res_col: usize, a: &mut A, a_col: usize) @@ -82,8 +82,8 @@ pub unsafe trait VecZnxIdftApplyTmpAImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs) reference implementation. +/// * See [crate::api::VecZnxIdftApplyConsume] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxIdftApplyConsumeImpl { fn vec_znx_idft_apply_consume_impl(module: &Module, a: VecZnxDft) -> VecZnxBig @@ -92,8 +92,8 @@ pub unsafe trait VecZnxIdftApplyConsumeImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs) reference implementation. +/// * See [crate::api::VecZnxDftAdd] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxDftAddImpl { fn vec_znx_dft_add_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize) @@ -104,8 +104,8 @@ pub unsafe trait VecZnxDftAddImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs) reference implementation. +/// * See [crate::api::VecZnxDftAddInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxDftAddInplaceImpl { fn vec_znx_dft_add_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) @@ -115,8 +115,8 @@ pub unsafe trait VecZnxDftAddInplaceImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs) reference implementation. +/// * See [crate::api::VecZnxDftSub] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxDftSubImpl { fn vec_znx_dft_sub_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize) @@ -127,30 +127,30 @@ pub unsafe trait VecZnxDftSubImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs) reference implementation. +/// * See [crate::api::VecZnxDftSubInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. 
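The three IDFT entries above differ only in where the temporary lives: IdftApply borrows scratch, IdftApplyTmpA is allowed to scribble over its input, and IdftApplyConsume takes the DFT vector by value so its storage can be recycled for the result. A shape-only sketch of the consume flavor, with stand-in types rather than the poulpy layouts:

struct DftVec(Vec<f64>); // stand-in for VecZnxDft
struct BigVec(Vec<i64>); // stand-in for VecZnxBig

fn idft_apply_consume(a: DftVec) -> BigVec {
    // a real backend runs the inverse transform in place; the by-value
    // signature is what lets the buffer be reused without scratch
    BigVec(a.0.into_iter().map(|x| x.round() as i64).collect())
}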
-pub unsafe trait VecZnxDftSubABInplaceImpl { - fn vec_znx_dft_sub_ab_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) +pub unsafe trait VecZnxDftSubInplaceImpl { + fn vec_znx_dft_sub_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxDftToMut, A: VecZnxDftToRef; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs) reference implementation. +/// * See [crate::api::VecZnxDftSubNegateInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait VecZnxDftSubBAInplaceImpl { - fn vec_znx_dft_sub_ba_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) +pub unsafe trait VecZnxDftSubNegateInplaceImpl { + fn vec_znx_dft_sub_negate_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxDftToMut, A: VecZnxDftToRef; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs) reference implementation. +/// * See [crate::api::VecZnxDftCopy] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxDftCopyImpl { fn vec_znx_dft_copy_impl( @@ -167,8 +167,8 @@ pub unsafe trait VecZnxDftCopyImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs) reference implementation. +/// * See [crate::api::VecZnxDftZero] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxDftZeroImpl { fn vec_znx_dft_zero_impl(module: &Module, res: &mut R) diff --git a/poulpy-hal/src/oep/vmp_pmat.rs b/poulpy-hal/src/oep/vmp_pmat.rs index 74d8cd2..e399f00 100644 --- a/poulpy-hal/src/oep/vmp_pmat.rs +++ b/poulpy-hal/src/oep/vmp_pmat.rs @@ -3,24 +3,24 @@ use crate::layouts::{ }; /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See [poulpy-backend/src/cpu_fft64_ref/vmp.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vmp.rs) for reference implementation. +/// * See [crate::api::VmpPMatAlloc] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VmpPMatAllocImpl { fn vmp_pmat_alloc_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See [poulpy-backend/src/cpu_fft64_ref/vmp.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vmp.rs) for reference implementation. +/// * See [crate::api::VmpPMatAllocBytes] for corresponding public API. 
/// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VmpPMatAllocBytesImpl { fn vmp_pmat_alloc_bytes_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See [poulpy-backend/src/cpu_fft64_ref/vmp.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vmp.rs) for reference implementation. +/// * See [crate::api::VmpPMatFromBytes] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VmpPMatFromBytesImpl { fn vmp_pmat_from_bytes_impl( @@ -34,16 +34,16 @@ pub unsafe trait VmpPMatFromBytesImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See [poulpy-backend/src/cpu_fft64_ref/vmp.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vmp.rs) for reference implementation. +/// * See [crate::api::VmpPrepareTmpBytes] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VmpPrepareTmpBytesImpl { fn vmp_prepare_tmp_bytes_impl(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See [poulpy-backend/src/cpu_fft64_ref/vmp.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vmp.rs) for reference implementation. +/// * See [crate::api::VmpPrepare] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VmpPrepareImpl { fn vmp_prepare_impl(module: &Module, res: &mut R, a: &A, scratch: &mut Scratch) @@ -54,8 +54,8 @@ pub unsafe trait VmpPrepareImpl { #[allow(clippy::too_many_arguments)] /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See [poulpy-backend/src/cpu_fft64_ref/vmp.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vmp.rs) for reference implementation. +/// * See [crate::api::VmpApplyDftTmpBytes] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VmpApplyDftTmpBytesImpl { fn vmp_apply_dft_tmp_bytes_impl( @@ -70,8 +70,8 @@ pub unsafe trait VmpApplyDftTmpBytesImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See [poulpy-backend/src/cpu_fft64_ref/vmp.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vmp.rs) for reference implementation. +/// * See [crate::api::VmpApplyDft] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VmpApplyDftImpl { fn vmp_apply_dft_impl(module: &Module, res: &mut R, a: &A, b: &C, scratch: &mut Scratch) @@ -83,8 +83,8 @@ pub unsafe trait VmpApplyDftImpl { #[allow(clippy::too_many_arguments)] /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. 
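For the vmp_pmat entries: a VmpPMat is a prepared matrix of polynomials whose rows are indexed by input limbs, and vmp_apply_dft_to_dft computes the row-vector-times-matrix product in the DFT domain; per the comment further down in this hunk, the *Add variant accumulates onto the result instead of overwriting it. A scalar stand-in for the accumulation pattern (f64 scalars in place of DFT-domain polynomials):

fn vmp_apply(a: &[f64], pmat: &[Vec<f64>]) -> Vec<f64> {
    let cols_out = pmat.first().map_or(0, Vec::len);
    let mut res = vec![0.0; cols_out];
    for (&ai, row) in a.iter().zip(pmat) {
        // limb i of the input row vector hits row i of the prepared matrix
        for (r, &mij) in res.iter_mut().zip(row) {
            *r += ai * mij;
        }
    }
    res
}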
+/// * See [poulpy-backend/src/cpu_fft64_ref/vmp.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vmp.rs) for reference implementation. +/// * See [crate::api::VmpApplyDftToDftTmpBytes] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VmpApplyDftToDftTmpBytesImpl { fn vmp_apply_dft_to_dft_tmp_bytes_impl( @@ -99,8 +99,8 @@ pub unsafe trait VmpApplyDftToDftTmpBytesImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See [poulpy-backend/src/cpu_fft64_ref/vmp.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vmp.rs) for reference implementation. +/// * See [crate::api::VmpApplyDftToDft] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VmpApplyDftToDftImpl { fn vmp_apply_dft_to_dft_impl(module: &Module, res: &mut R, a: &A, b: &C, scratch: &mut Scratch) @@ -112,8 +112,8 @@ pub unsafe trait VmpApplyDftToDftImpl { #[allow(clippy::too_many_arguments)] /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See [poulpy-backend/src/cpu_fft64_ref/vmp.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vmp.rs) for reference implementation. +/// * See [crate::api::VmpApplyDftToDftAddTmpBytes] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VmpApplyDftToDftAddTmpBytesImpl { fn vmp_apply_dft_to_dft_add_tmp_bytes_impl( @@ -128,8 +128,8 @@ pub unsafe trait VmpApplyDftToDftAddTmpBytesImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. +/// * See [poulpy-backend/src/cpu_fft64_ref/vmp.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vmp.rs) for reference implementation. +/// * See [crate::api::VmpApplyDftToDftAdd] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VmpApplyDftToDftAddImpl { // Same as [MatZnxDftOps::vmp_apply] except result is added on R instead of overwritting R. diff --git a/poulpy-hal/src/oep/zn.rs b/poulpy-hal/src/oep/zn.rs index 2a1122a..d2e03ad 100644 --- a/poulpy-hal/src/oep/zn.rs +++ b/poulpy-hal/src/oep/zn.rs @@ -4,7 +4,7 @@ use crate::{ }; /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO +/// * See [poulpy-backend/src/cpu_fft64_ref/zn.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/zn.rs) for reference implementation. /// * See [crate::api::ZnNormalizeTmpBytes] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait ZnNormalizeTmpBytesImpl { @@ -12,32 +12,34 @@ pub unsafe trait ZnNormalizeTmpBytesImpl { } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [zn_normalize_base2k_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/zn64.c#L9) for reference code. +/// * See [poulpy-backend/src/cpu_fft64_ref/zn.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/zn.rs) for reference implementation. /// * See [crate::api::ZnNormalizeInplace] for corresponding public API. 
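+/// * A minimal forwarding sketch (the backend type `MyBackend` and the
+///   `take_slice` scratch helper are hypothetical; `zn_normalize_inplace` is the
+///   reference routine from `poulpy-hal/src/reference/zn/normalization.rs` and
+///   `ZnxRef` the reference arithmetic):
+/// ```ignore
+/// unsafe impl ZnNormalizeInplaceImpl<MyBackend> for MyBackend {
+///     fn zn_normalize_inplace_impl<R: ZnToMut>(
+///         n: usize,
+///         base2k: usize,
+///         res: &mut R,
+///         res_col: usize,
+///         scratch: &mut Scratch<MyBackend>,
+///     ) {
+///         // Hypothetical borrow of n i64 carry words from the scratch space.
+///         let carry: &mut [i64] = scratch.take_slice(n);
+///         zn_normalize_inplace::<_, ZnxRef>(n, base2k, res, res_col, carry);
+///     }
+/// }
+/// ```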
/// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait ZnNormalizeInplaceImpl { - fn zn_normalize_inplace_impl(n: usize, basek: usize, res: &mut R, res_col: usize, scratch: &mut Scratch) + fn zn_normalize_inplace_impl(n: usize, base2k: usize, res: &mut R, res_col: usize, scratch: &mut Scratch) where R: ZnToMut; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See [poulpy-backend/src/cpu_fft64_ref/zn.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/zn.rs) for reference implementation. /// * See [crate::api::ZnFillUniform] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait ZnFillUniformImpl { - fn zn_fill_uniform_impl(n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source) + fn zn_fill_uniform_impl(n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source) where R: ZnToMut; } #[allow(clippy::too_many_arguments)] /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See [poulpy-backend/src/cpu_fft64_ref/zn.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/zn.rs) for reference implementation. /// * See [crate::api::ZnFillNormal] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait ZnFillNormalImpl { fn zn_fill_normal_impl( n: usize, - basek: usize, + base2k: usize, res: &mut R, res_col: usize, k: usize, @@ -50,12 +52,13 @@ pub unsafe trait ZnFillNormalImpl { #[allow(clippy::too_many_arguments)] /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See [poulpy-backend/src/cpu_fft64_ref/zn.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/zn.rs) for reference implementation. /// * See [crate::api::ZnAddNormal] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. 
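+///
+/// Reference behaviour: the sample is added on limb `k.div_ceil(base2k) - 1`,
+/// pre-scaled by `2^{(limb + 1) * base2k - k}`, so the noise lands at precision
+/// `2^{-k}` (see `zn_add_normal` in `poulpy-hal/src/reference/zn/sampling.rs`).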
pub unsafe trait ZnAddNormalImpl { fn zn_add_normal_impl( n: usize, - basek: usize, + base2k: usize, res: &mut R, res_col: usize, k: usize, diff --git a/poulpy-hal/src/reference/fft64/reim/fft_vec.rs b/poulpy-hal/src/reference/fft64/reim/fft_vec.rs index 63b4a80..54b4d7f 100644 --- a/poulpy-hal/src/reference/fft64/reim/fft_vec.rs +++ b/poulpy-hal/src/reference/fft64/reim/fft_vec.rs @@ -37,7 +37,7 @@ pub fn reim_sub_ref(res: &mut [f64], a: &[f64], b: &[f64]) { } #[inline(always)] -pub fn reim_sub_ab_inplace_ref(res: &mut [f64], a: &[f64]) { +pub fn reim_sub_inplace_ref(res: &mut [f64], a: &[f64]) { #[cfg(debug_assertions)] { assert_eq!(a.len(), res.len()); @@ -49,7 +49,7 @@ pub fn reim_sub_ab_inplace_ref(res: &mut [f64], a: &[f64]) { } #[inline(always)] -pub fn reim_sub_ba_inplace_ref(res: &mut [f64], a: &[f64]) { +pub fn reim_sub_negate_inplace_ref(res: &mut [f64], a: &[f64]) { #[cfg(debug_assertions)] { assert_eq!(a.len(), res.len()); diff --git a/poulpy-hal/src/reference/fft64/reim/mod.rs b/poulpy-hal/src/reference/fft64/reim/mod.rs index 28c90b1..7decf3a 100644 --- a/poulpy-hal/src/reference/fft64/reim/mod.rs +++ b/poulpy-hal/src/reference/fft64/reim/mod.rs @@ -91,12 +91,12 @@ pub trait ReimSub { fn reim_sub(res: &mut [f64], a: &[f64], b: &[f64]); } -pub trait ReimSubABInplace { - fn reim_sub_ab_inplace(res: &mut [f64], a: &[f64]); +pub trait ReimSubInplace { + fn reim_sub_inplace(res: &mut [f64], a: &[f64]); } -pub trait ReimSubBAInplace { - fn reim_sub_ba_inplace(res: &mut [f64], a: &[f64]); +pub trait ReimSubNegateInplace { + fn reim_sub_negate_inplace(res: &mut [f64], a: &[f64]); } pub trait ReimNegate { diff --git a/poulpy-hal/src/reference/fft64/reim/table_fft.rs b/poulpy-hal/src/reference/fft64/reim/table_fft.rs index 452678f..f76d128 100644 --- a/poulpy-hal/src/reference/fft64/reim/table_fft.rs +++ b/poulpy-hal/src/reference/fft64/reim/table_fft.rs @@ -22,7 +22,7 @@ pub struct ReimFFTTable { impl ReimFFTTable { pub fn new(m: usize) -> Self { - assert!(m & (m - 1) == 0, "m must be a power of two but is {}", m); + assert!(m & (m - 1) == 0, "m must be a power of two but is {m}"); let mut omg: Vec = alloc_aligned::(2 * m); let quarter: R = R::from(1. 
/ 4.).unwrap(); diff --git a/poulpy-hal/src/reference/fft64/reim/table_ifft.rs b/poulpy-hal/src/reference/fft64/reim/table_ifft.rs index 929b933..b1f0bb7 100644 --- a/poulpy-hal/src/reference/fft64/reim/table_ifft.rs +++ b/poulpy-hal/src/reference/fft64/reim/table_ifft.rs @@ -22,7 +22,7 @@ pub struct ReimIFFTTable { impl ReimIFFTTable { pub fn new(m: usize) -> Self { - assert!(m & (m - 1) == 0, "m must be a power of two but is {}", m); + assert!(m & (m - 1) == 0, "m must be a power of two but is {m}"); let mut omg: Vec = alloc_aligned::(2 * m); let quarter: R = R::exp2(R::from(-2).unwrap()); diff --git a/poulpy-hal/src/reference/fft64/vec_znx_big.rs b/poulpy-hal/src/reference/fft64/vec_znx_big.rs index 7b9ceb5..2295ee7 100644 --- a/poulpy-hal/src/reference/fft64/vec_znx_big.rs +++ b/poulpy-hal/src/reference/fft64/vec_znx_big.rs @@ -9,12 +9,13 @@ use crate::{ reference::{ vec_znx::{ vec_znx_add, vec_znx_add_inplace, vec_znx_automorphism, vec_znx_automorphism_inplace, vec_znx_negate, - vec_znx_negate_inplace, vec_znx_normalize, vec_znx_sub, vec_znx_sub_ab_inplace, vec_znx_sub_ba_inplace, + vec_znx_negate_inplace, vec_znx_normalize, vec_znx_sub, vec_znx_sub_inplace, vec_znx_sub_negate_inplace, }, znx::{ - ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxNormalizeFinalStep, - ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, - ZnxSub, ZnxSubABInplace, ZnxSubBAInplace, ZnxZero, znx_add_normal_f64_ref, + ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxExtractDigitAddMul, ZnxMulPowerOfTwoInplace, ZnxNegate, + ZnxNegateInplace, ZnxNormalizeDigit, ZnxNormalizeFinalStep, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, + ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxSub, ZnxSubInplace, ZnxSubNegateInplace, ZnxZero, + znx_add_normal_f64_ref, }, }, source::Source, @@ -230,20 +231,32 @@ where } pub fn vec_znx_big_normalize_tmp_bytes(n: usize) -> usize { - n * size_of::() + 2 * n * size_of::() } -pub fn vec_znx_big_normalize(basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64]) -where +pub fn vec_znx_big_normalize( + res_basek: usize, + res: &mut R, + res_col: usize, + a_basek: usize, + a: &A, + a_col: usize, + carry: &mut [i64], +) where R: VecZnxToMut, A: VecZnxBigToRef, BE: Backend + + ZnxZero + + ZnxCopy + + ZnxAddInplace + + ZnxMulPowerOfTwoInplace + ZnxNormalizeFirstStepCarryOnly + ZnxNormalizeMiddleStepCarryOnly + ZnxNormalizeMiddleStep + ZnxNormalizeFinalStep + ZnxNormalizeFirstStep - + ZnxZero, + + ZnxExtractDigitAddMul + + ZnxNormalizeDigit, { let a: VecZnxBig<&[u8], _> = a.to_ref(); let a_vznx: VecZnx<&[u8]> = VecZnx { @@ -254,11 +267,11 @@ where max_size: a.max_size, }; - vec_znx_normalize::<_, _, BE>(basek, res, res_col, &a_vznx, a_col, carry); + vec_znx_normalize::<_, _, BE>(res_basek, res, res_col, a_basek, &a_vznx, a_col, carry); } pub fn vec_znx_big_add_normal_ref>( - basek: usize, + base2k: usize, res: &mut R, res_col: usize, k: usize, @@ -275,8 +288,8 @@ pub fn vec_znx_big_add_normal_ref>( (bound.log2().ceil() as i64) ); - let limb: usize = k.div_ceil(basek) - 1; - let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64; + let limb: usize = k.div_ceil(base2k) - 1; + let scale: f64 = (1 << ((limb + 1) * base2k - k)) as f64; znx_add_normal_f64_ref( res.at_mut(res_col, limb), sigma * scale, @@ -291,7 +304,7 @@ where B: Backend + VecZnxBigAllocBytesImpl, { let n: usize = module.n(); - let basek: usize = 17; + let base2k: usize = 17; let k: 
usize = 2 * 17; let size: usize = 5; let sigma: f64 = 3.2; @@ -303,15 +316,15 @@ where let sqrt2: f64 = SQRT_2; (0..cols).for_each(|col_i| { let mut a: VecZnxBig, B> = VecZnxBig::alloc(n, cols, size); - module.vec_znx_big_add_normal(basek, &mut a, col_i, k, &mut source, sigma, bound); - module.vec_znx_big_add_normal(basek, &mut a, col_i, k, &mut source, sigma, bound); + module.vec_znx_big_add_normal(base2k, &mut a, col_i, k, &mut source, sigma, bound); + module.vec_znx_big_add_normal(base2k, &mut a, col_i, k, &mut source, sigma, bound); (0..cols).for_each(|col_j| { if col_j != col_i { (0..size).for_each(|limb_i| { assert_eq!(a.at(col_j, limb_i), zero); }) } else { - let std: f64 = a.std(basek, col_i) * k_f64; + let std: f64 = a.std(base2k, col_i) * k_f64; assert!( (std - sigma * sqrt2).abs() < 0.1, "std={} ~!= {}", @@ -363,9 +376,9 @@ where } /// R <- A - B -pub fn vec_znx_big_sub_ab_inplace(res: &mut R, res_col: usize, a: &A, a_col: usize) +pub fn vec_znx_big_sub_inplace(res: &mut R, res_col: usize, a: &A, a_col: usize) where - BE: Backend + ZnxSubABInplace, + BE: Backend + ZnxSubInplace, R: VecZnxBigToMut, A: VecZnxBigToRef, { @@ -388,13 +401,13 @@ where max_size: a.max_size, }; - vec_znx_sub_ab_inplace::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col); + vec_znx_sub_inplace::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col); } /// R <- B - A -pub fn vec_znx_big_sub_ba_inplace(res: &mut R, res_col: usize, a: &A, a_col: usize) +pub fn vec_znx_big_sub_negate_inplace(res: &mut R, res_col: usize, a: &A, a_col: usize) where - BE: Backend + ZnxSubBAInplace + ZnxNegateInplace, + BE: Backend + ZnxSubNegateInplace + ZnxNegateInplace, R: VecZnxBigToMut, A: VecZnxBigToRef, { @@ -417,7 +430,7 @@ where max_size: a.max_size, }; - vec_znx_sub_ba_inplace::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col); + vec_znx_sub_negate_inplace::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col); } /// R <- A - B @@ -483,7 +496,7 @@ where /// R <- R - A pub fn vec_znx_big_sub_small_a_inplace(res: &mut R, res_col: usize, a: &A, a_col: usize) where - BE: Backend + ZnxSubABInplace, + BE: Backend + ZnxSubInplace, R: VecZnxBigToMut, A: VecZnxToRef, { @@ -497,13 +510,13 @@ where max_size: res.max_size, }; - vec_znx_sub_ab_inplace::<_, _, BE>(&mut res_vznx, res_col, a, a_col); + vec_znx_sub_inplace::<_, _, BE>(&mut res_vznx, res_col, a, a_col); } /// R <- A - R pub fn vec_znx_big_sub_small_b_inplace(res: &mut R, res_col: usize, a: &A, a_col: usize) where - BE: Backend + ZnxSubBAInplace + ZnxNegateInplace, + BE: Backend + ZnxSubNegateInplace + ZnxNegateInplace, R: VecZnxBigToMut, A: VecZnxToRef, { @@ -517,5 +530,5 @@ where max_size: res.max_size, }; - vec_znx_sub_ba_inplace::<_, _, BE>(&mut res_vznx, res_col, a, a_col); + vec_znx_sub_negate_inplace::<_, _, BE>(&mut res_vznx, res_col, a, a_col); } diff --git a/poulpy-hal/src/reference/fft64/vec_znx_dft.rs b/poulpy-hal/src/reference/fft64/vec_znx_dft.rs index 5abf0da..4bb086d 100644 --- a/poulpy-hal/src/reference/fft64/vec_znx_dft.rs +++ b/poulpy-hal/src/reference/fft64/vec_znx_dft.rs @@ -8,7 +8,7 @@ use crate::{ reference::{ fft64::reim::{ ReimAdd, ReimAddInplace, ReimCopy, ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimIFFTTable, ReimNegate, - ReimNegateInplace, ReimSub, ReimSubABInplace, ReimSubBAInplace, ReimToZnx, ReimToZnxInplace, ReimZero, + ReimNegateInplace, ReimSub, ReimSubInplace, ReimSubNegateInplace, ReimToZnx, ReimToZnxInplace, ReimZero, }, znx::ZnxZero, }, @@ -308,9 +308,9 @@ where } } -pub fn vec_znx_dft_sub_ab_inplace(res: &mut R, res_col: usize, a: &A, 
a_col: usize) +pub fn vec_znx_dft_sub_inplace(res: &mut R, res_col: usize, a: &A, a_col: usize) where - BE: Backend + ReimSubABInplace, + BE: Backend + ReimSubInplace, R: VecZnxDftToMut, A: VecZnxDftToRef, { @@ -328,13 +328,13 @@ where let sum_size: usize = a_size.min(res_size); for j in 0..sum_size { - BE::reim_sub_ab_inplace(res.at_mut(res_col, j), a.at(a_col, j)); + BE::reim_sub_inplace(res.at_mut(res_col, j), a.at(a_col, j)); } } -pub fn vec_znx_dft_sub_ba_inplace(res: &mut R, res_col: usize, a: &A, a_col: usize) +pub fn vec_znx_dft_sub_negate_inplace(res: &mut R, res_col: usize, a: &A, a_col: usize) where - BE: Backend + ReimSubBAInplace + ReimNegateInplace, + BE: Backend + ReimSubNegateInplace + ReimNegateInplace, R: VecZnxDftToMut, A: VecZnxDftToRef, { @@ -352,7 +352,7 @@ where let sum_size: usize = a_size.min(res_size); for j in 0..sum_size { - BE::reim_sub_ba_inplace(res.at_mut(res_col, j), a.at(a_col, j)); + BE::reim_sub_negate_inplace(res.at_mut(res_col, j), a.at(a_col, j)); } for j in sum_size..res_size { diff --git a/poulpy-hal/src/reference/vec_znx/add.rs b/poulpy-hal/src/reference/vec_znx/add.rs index 56c9eb9..3fe5f91 100644 --- a/poulpy-hal/src/reference/vec_znx/add.rs +++ b/poulpy-hal/src/reference/vec_znx/add.rs @@ -91,7 +91,7 @@ pub fn bench_vec_znx_add(c: &mut Criterion, label: &str) where Module: VecZnxAdd + ModuleNew, { - let group_name: String = format!("vec_znx_add::{}", label); + let group_name: String = format!("vec_znx_add::{label}"); let mut group = c.benchmark_group(group_name); @@ -136,7 +136,7 @@ pub fn bench_vec_znx_add_inplace(c: &mut Criterion, label: &str) where Module: VecZnxAddInplace + ModuleNew, { - let group_name: String = format!("vec_znx_add_inplace::{}", label); + let group_name: String = format!("vec_znx_add_inplace::{label}"); let mut group = c.benchmark_group(group_name); diff --git a/poulpy-hal/src/reference/vec_znx/add_scalar.rs b/poulpy-hal/src/reference/vec_znx/add_scalar.rs index 68f830e..b35d853 100644 --- a/poulpy-hal/src/reference/vec_znx/add_scalar.rs +++ b/poulpy-hal/src/reference/vec_znx/add_scalar.rs @@ -18,12 +18,7 @@ where #[cfg(debug_assertions)] { - assert!( - b_limb < min_size, - "b_limb: {} > min_size: {}", - b_limb, - min_size - ); + assert!(b_limb < min_size, "b_limb: {b_limb} > min_size: {min_size}"); } for j in 0..min_size { diff --git a/poulpy-hal/src/reference/vec_znx/automorphism.rs b/poulpy-hal/src/reference/vec_znx/automorphism.rs index 91069cb..4ec1dc1 100644 --- a/poulpy-hal/src/reference/vec_znx/automorphism.rs +++ b/poulpy-hal/src/reference/vec_znx/automorphism.rs @@ -63,7 +63,7 @@ pub fn bench_vec_znx_automorphism(c: &mut Criterion, label: &str) where Module: VecZnxAutomorphism + ModuleNew, { - let group_name: String = format!("vec_znx_automorphism::{}", label); + let group_name: String = format!("vec_znx_automorphism::{label}"); let mut group = c.benchmark_group(group_name); @@ -108,7 +108,7 @@ where Module: VecZnxAutomorphismInplace + VecZnxAutomorphismInplaceTmpBytes + ModuleNew, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, { - let group_name: String = format!("vec_znx_automorphism_inplace::{}", label); + let group_name: String = format!("vec_znx_automorphism_inplace::{label}"); let mut group = c.benchmark_group(group_name); diff --git a/poulpy-hal/src/reference/vec_znx/mul_xp_minus_one.rs b/poulpy-hal/src/reference/vec_znx/mul_xp_minus_one.rs index b07599d..a4738c6 100644 --- a/poulpy-hal/src/reference/vec_znx/mul_xp_minus_one.rs +++ b/poulpy-hal/src/reference/vec_znx/mul_xp_minus_one.rs @@ -9,8 +9,8 
@@ use crate::{ }, layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut}, reference::{ - vec_znx::{vec_znx_rotate, vec_znx_sub_ab_inplace}, - znx::{ZnxNegate, ZnxRotate, ZnxSubABInplace, ZnxSubBAInplace, ZnxZero}, + vec_znx::{vec_znx_rotate, vec_znx_sub_inplace}, + znx::{ZnxNegate, ZnxRotate, ZnxSubInplace, ZnxSubNegateInplace, ZnxZero}, }, source::Source, }; @@ -23,16 +23,16 @@ pub fn vec_znx_mul_xp_minus_one(p: i64, res: &mut R, res_col: usiz where R: VecZnxToMut, A: VecZnxToRef, - ZNXARI: ZnxRotate + ZnxZero + ZnxSubABInplace, + ZNXARI: ZnxRotate + ZnxZero + ZnxSubInplace, { vec_znx_rotate::<_, _, ZNXARI>(p, res, res_col, a, a_col); - vec_znx_sub_ab_inplace::<_, _, ZNXARI>(res, res_col, a, a_col); + vec_znx_sub_inplace::<_, _, ZNXARI>(res, res_col, a, a_col); } pub fn vec_znx_mul_xp_minus_one_inplace(p: i64, res: &mut R, res_col: usize, tmp: &mut [i64]) where R: VecZnxToMut, - ZNXARI: ZnxRotate + ZnxNegate + ZnxSubBAInplace, + ZNXARI: ZnxRotate + ZnxNegate + ZnxSubNegateInplace, { let mut res: VecZnx<&mut [u8]> = res.to_mut(); #[cfg(debug_assertions)] @@ -41,7 +41,7 @@ where } for j in 0..res.size() { ZNXARI::znx_rotate(p, tmp, res.at(res_col, j)); - ZNXARI::znx_sub_ba_inplace(res.at_mut(res_col, j), tmp); + ZNXARI::znx_sub_negate_inplace(res.at_mut(res_col, j), tmp); } } @@ -49,7 +49,7 @@ pub fn bench_vec_znx_mul_xp_minus_one(c: &mut Criterion, label: &str where Module: VecZnxMulXpMinusOne + ModuleNew, { - let group_name: String = format!("vec_znx_mul_xp_minus_one::{}", label); + let group_name: String = format!("vec_znx_mul_xp_minus_one::{label}"); let mut group = c.benchmark_group(group_name); @@ -94,7 +94,7 @@ where Module: VecZnxMulXpMinusOneInplace + VecZnxMulXpMinusOneInplaceTmpBytes + ModuleNew, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, { - let group_name: String = format!("vec_znx_mul_xp_minus_one_inplace::{}", label); + let group_name: String = format!("vec_znx_mul_xp_minus_one_inplace::{label}"); let mut group = c.benchmark_group(group_name); diff --git a/poulpy-hal/src/reference/vec_znx/negate.rs b/poulpy-hal/src/reference/vec_znx/negate.rs index f446467..5c1866a 100644 --- a/poulpy-hal/src/reference/vec_znx/negate.rs +++ b/poulpy-hal/src/reference/vec_znx/negate.rs @@ -49,7 +49,7 @@ pub fn bench_vec_znx_negate(c: &mut Criterion, label: &str) where Module: VecZnxNegate + ModuleNew, { - let group_name: String = format!("vec_znx_negate::{}", label); + let group_name: String = format!("vec_znx_negate::{label}"); let mut group = c.benchmark_group(group_name); @@ -93,7 +93,7 @@ pub fn bench_vec_znx_negate_inplace(c: &mut Criterion, label: &str) where Module: VecZnxNegateInplace + ModuleNew, { - let group_name: String = format!("vec_znx_negate_inplace::{}", label); + let group_name: String = format!("vec_znx_negate_inplace::{label}"); let mut group = c.benchmark_group(group_name); diff --git a/poulpy-hal/src/reference/vec_znx/normalize.rs b/poulpy-hal/src/reference/vec_znx/normalize.rs index 98795b8..a62c106 100644 --- a/poulpy-hal/src/reference/vec_znx/normalize.rs +++ b/poulpy-hal/src/reference/vec_znx/normalize.rs @@ -6,71 +6,204 @@ use crate::{ api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes}, layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut}, reference::znx::{ - ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, 
ZnxNormalizeFirstStepCarryOnly, - ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, - ZnxZero, + ZnxAddInplace, ZnxCopy, ZnxExtractDigitAddMul, ZnxMulPowerOfTwoInplace, ZnxNormalizeDigit, ZnxNormalizeFinalStep, + ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeFirstStepInplace, + ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, ZnxZero, }, source::Source, }; pub fn vec_znx_normalize_tmp_bytes(n: usize) -> usize { - n * size_of::() + 2 * n * size_of::() } -pub fn vec_znx_normalize(basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64]) -where +pub fn vec_znx_normalize( + res_base2k: usize, + res: &mut R, + res_col: usize, + a_base2k: usize, + a: &A, + a_col: usize, + carry: &mut [i64], +) where R: VecZnxToMut, A: VecZnxToRef, ZNXARI: ZnxZero + + ZnxCopy + + ZnxAddInplace + + ZnxMulPowerOfTwoInplace + ZnxNormalizeFirstStepCarryOnly + ZnxNormalizeMiddleStepCarryOnly + ZnxNormalizeMiddleStep + ZnxNormalizeFinalStep - + ZnxNormalizeFirstStep, + + ZnxNormalizeFirstStep + + ZnxExtractDigitAddMul + + ZnxNormalizeDigit, { let mut res: VecZnx<&mut [u8]> = res.to_mut(); let a: VecZnx<&[u8]> = a.to_ref(); #[cfg(debug_assertions)] { - assert!(carry.len() >= res.n()); + assert!(carry.len() >= 2 * res.n()); + assert_eq!(res.n(), a.n()); } + let n: usize = res.n(); let res_size: usize = res.size(); - let a_size = a.size(); + let a_size: usize = a.size(); - if a_size > res_size { - for j in (res_size..a_size).rev() { - if j == a_size - 1 { - ZNXARI::znx_normalize_first_step_carry_only(basek, 0, a.at(a_col, j), carry); - } else { - ZNXARI::znx_normalize_middle_step_carry_only(basek, 0, a.at(a_col, j), carry); + if res_base2k == a_base2k { + if a_size > res_size { + for j in (res_size..a_size).rev() { + if j == a_size - 1 { + ZNXARI::znx_normalize_first_step_carry_only(res_base2k, 0, a.at(a_col, j), carry); + } else { + ZNXARI::znx_normalize_middle_step_carry_only(res_base2k, 0, a.at(a_col, j), carry); + } + } + + for j in (1..res_size).rev() { + ZNXARI::znx_normalize_middle_step(res_base2k, 0, res.at_mut(res_col, j), a.at(a_col, j), carry); + } + + ZNXARI::znx_normalize_final_step(res_base2k, 0, res.at_mut(res_col, 0), a.at(a_col, 0), carry); + } else { + for j in (0..a_size).rev() { + if j == a_size - 1 { + ZNXARI::znx_normalize_first_step(res_base2k, 0, res.at_mut(res_col, j), a.at(a_col, j), carry); + } else if j == 0 { + ZNXARI::znx_normalize_final_step(res_base2k, 0, res.at_mut(res_col, j), a.at(a_col, j), carry); + } else { + ZNXARI::znx_normalize_middle_step(res_base2k, 0, res.at_mut(res_col, j), a.at(a_col, j), carry); + } + } + + for j in a_size..res_size { + ZNXARI::znx_zero(res.at_mut(res_col, j)); } } - - for j in (1..res_size).rev() { - ZNXARI::znx_normalize_middle_step(basek, 0, res.at_mut(res_col, j), a.at(a_col, j), carry); - } - - ZNXARI::znx_normalize_final_step(basek, 0, res.at_mut(res_col, 0), a.at(a_col, 0), carry); } else { - for j in (0..a_size).rev() { + let (a_norm, carry) = carry.split_at_mut(n); + + // Relevant limbs of res + let res_min_size: usize = (a_size * a_base2k).div_ceil(res_base2k).min(res_size); + + // Relevant limbs of a + let a_min_size: usize = (res_size * res_base2k).div_ceil(a_base2k).min(a_size); + + // Get carry for limbs of a that have higher precision than res + for j in (a_min_size..a_size).rev() { if j == a_size - 1 { - ZNXARI::znx_normalize_first_step(basek, 0, 
res.at_mut(res_col, j), a.at(a_col, j), carry);
-            } else if j == 0 {
-                ZNXARI::znx_normalize_final_step(basek, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
+                ZNXARI::znx_normalize_first_step_carry_only(res_base2k, 0, a.at(a_col, j), carry);
             } else {
-                ZNXARI::znx_normalize_middle_step(basek, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
+                ZNXARI::znx_normalize_middle_step_carry_only(res_base2k, 0, a.at(a_col, j), carry);
             }
         }
 
-        for j in a_size..res_size {
+        if a_min_size == a_size {
+            ZNXARI::znx_zero(carry);
+        }
+
+        // Maximum relevant precision of a
+        let a_prec: usize = a_min_size * a_base2k;
+
+        // Maximum relevant precision of res
+        let res_prec: usize = res_min_size * res_base2k;
+
+        // Res limb index
+        let mut res_idx: usize = res_min_size - 1;
+
+        // Tracks how much of the current limb of res
+        // is left to be populated.
+        let mut res_left: usize = res_base2k;
+
+        for j in (0..a_min_size).rev() {
+            // Tracks how much of a_norm is left to
+            // be flushed onto res.
+            let mut a_left: usize = a_base2k;
+
+            // Normalizes the j-th limb of a and stores the result into a_norm.
+            // This step is required to avoid overflow in the next step,
+            // which assumes that |a| is bounded by 2^{a_base2k - 1}.
+            if j != 0 {
+                ZNXARI::znx_normalize_middle_step(a_base2k, 0, a_norm, a.at(a_col, j), carry);
+            } else {
+                ZNXARI::znx_normalize_final_step(a_base2k, 0, a_norm, a.at(a_col, j), carry);
+            }
+
+            // In the first iteration we need to match the precision of the input and output.
+            // If a_min_size * a_base2k > res_min_size * res_base2k,
+            // div-round a_norm by the precision difference and act as if
+            // a_norm had already been partially consumed.
+            // Otherwise act as if res had already been populated
+            // by the difference.
+            if j == a_min_size - 1 {
+                if a_prec > res_prec {
+                    ZNXARI::znx_mul_power_of_two_inplace(res_prec as i64 - a_prec as i64, a_norm);
+                    a_left -= a_prec - res_prec;
+                } else if res_prec > a_prec {
+                    res_left -= res_prec - a_prec;
+                }
+            }
+
+            // Flushes a into res
+            loop {
+                // Selects the maximum amount of a that can be flushed
+                let a_take: usize = a_base2k.min(a_left).min(res_left);
+
+                // Output limb
+                let res_slice: &mut [i64] = res.at_mut(res_col, res_idx);
+
+                // Scaling of the value to flush
+                let lsh: usize = res_base2k - res_left;
+
+                // Extracts the bits to flush onto the output and updates
+                // a_norm accordingly.
+                ZNXARI::znx_extract_digit_addmul(a_take, lsh, res_slice, a_norm);
+
+                // Updates the trackers
+                a_left -= a_take;
+                res_left -= a_take;
+
+                // If the current limb of res is full,
+                // normalizes this limb and adds
+                // the carry onto a_norm.
+                if res_left == 0 {
+                    // Updates the tracker
+                    res_left += res_base2k;
+
+                    // Normalizes res and propagates the carry onto a.
+                    ZNXARI::znx_normalize_digit(res_base2k, res_slice, a_norm);
+
+                    // If we have reached the last limb of res, break;
+                    // the loop above may still rerun when the base2k of a
+                    // is much smaller than the base2k of res.
+                    if res_idx == 0 {
+                        ZNXARI::znx_add_inplace(carry, a_norm);
+                        break;
+                    }
+
+                    // Otherwise moves to the next limb of res.
+                    res_idx -= 1
+                }
+
+                // If a_norm is exhausted, break the loop.
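+                // Worked example with hypothetical bases (and ignoring the
+                // first-iteration precision adjustment above): for a_base2k = 20
+                // and res_base2k = 12, a fresh a_norm starts with a_left = 20 and
+                // first flushes a_take = 12 bits, filling the current limb of res
+                // (res_left reaches 0, the limb is renormalized and res_idx moves
+                // on); the next pass flushes the remaining 8 bits into the next
+                // limb (res_left = 4) and exits through the exhaustion check below.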
+ if a_left == 0 { + ZNXARI::znx_add_inplace(carry, a_norm); + break; + } + } + } + + for j in res_min_size..res_size { ZNXARI::znx_zero(res.at_mut(res_col, j)); } } } -pub fn vec_znx_normalize_inplace(basek: usize, res: &mut R, res_col: usize, carry: &mut [i64]) +pub fn vec_znx_normalize_inplace(base2k: usize, res: &mut R, res_col: usize, carry: &mut [i64]) where ZNXARI: ZnxNormalizeFirstStepInplace + ZnxNormalizeMiddleStepInplace + ZnxNormalizeFinalStepInplace, { @@ -85,11 +218,11 @@ where for j in (0..res_size).rev() { if j == res_size - 1 { - ZNXARI::znx_normalize_first_step_inplace(basek, 0, res.at_mut(res_col, j), carry); + ZNXARI::znx_normalize_first_step_inplace(base2k, 0, res.at_mut(res_col, j), carry); } else if j == 0 { - ZNXARI::znx_normalize_final_step_inplace(basek, 0, res.at_mut(res_col, j), carry); + ZNXARI::znx_normalize_final_step_inplace(base2k, 0, res.at_mut(res_col, j), carry); } else { - ZNXARI::znx_normalize_middle_step_inplace(basek, 0, res.at_mut(res_col, j), carry); + ZNXARI::znx_normalize_middle_step_inplace(base2k, 0, res.at_mut(res_col, j), carry); } } } @@ -99,7 +232,7 @@ where Module: VecZnxNormalize + ModuleNew + VecZnxNormalizeTmpBytes, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, { - let group_name: String = format!("vec_znx_normalize::{}", label); + let group_name: String = format!("vec_znx_normalize::{label}"); let mut group = c.benchmark_group(group_name); @@ -114,7 +247,7 @@ where let module: Module = Module::::new(n as u64); - let basek: usize = 50; + let base2k: usize = 50; let mut source: Source = Source::new([0u8; 32]); @@ -129,7 +262,7 @@ where move || { for i in 0..cols { - module.vec_znx_normalize(basek, &mut res, i, &a, i, scratch.borrow()); + module.vec_znx_normalize(base2k, &mut res, i, base2k, &a, i, scratch.borrow()); } black_box(()); } @@ -149,7 +282,7 @@ where Module: VecZnxNormalizeInplace + ModuleNew + VecZnxNormalizeTmpBytes, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, { - let group_name: String = format!("vec_znx_normalize_inplace::{}", label); + let group_name: String = format!("vec_znx_normalize_inplace::{label}"); let mut group = c.benchmark_group(group_name); @@ -164,7 +297,7 @@ where let module: Module = Module::::new(n as u64); - let basek: usize = 50; + let base2k: usize = 50; let mut source: Source = Source::new([0u8; 32]); @@ -177,7 +310,7 @@ where move || { for i in 0..cols { - module.vec_znx_normalize_inplace(basek, &mut a, i, scratch.borrow()); + module.vec_znx_normalize_inplace(base2k, &mut a, i, scratch.borrow()); } black_box(()); } @@ -191,3 +324,83 @@ where group.finish(); } + +#[test] +fn test_vec_znx_normalize_conv() { + let n: usize = 8; + + let mut carry: Vec = vec![0i64; 2 * n]; + + use crate::reference::znx::ZnxRef; + use rug::ops::SubAssignRound; + use rug::{Float, float::Round}; + + let mut source: Source = Source::new([1u8; 32]); + + let prec: usize = 128; + + let mut data: Vec = vec![0i128; n]; + + data.iter_mut().for_each(|x| *x = source.next_i128()); + + for start_base2k in 1..50 { + for end_base2k in 1..50 { + let end_size: usize = prec.div_ceil(end_base2k); + + let mut want: VecZnx> = VecZnx::alloc(n, 1, end_size); + want.encode_vec_i128(end_base2k, 0, prec, &data); + vec_znx_normalize_inplace::<_, ZnxRef>(end_base2k, &mut want, 0, &mut carry); + + // Creates a temporary poly where encoding is in start_base2k + let mut tmp: VecZnx> = VecZnx::alloc(n, 1, prec.div_ceil(start_base2k)); + tmp.encode_vec_i128(start_base2k, 0, prec, &data); + + vec_znx_normalize_inplace::<_, 
ZnxRef>(start_base2k, &mut tmp, 0, &mut carry);
+
+            let mut data_tmp: Vec<Float> = (0..n).map(|_| Float::with_val(prec as u32, 0)).collect();
+            tmp.decode_vec_float(start_base2k, 0, &mut data_tmp);
+
+            let mut have: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, end_size);
+            vec_znx_normalize::<_, _, ZnxRef>(end_base2k, &mut have, 0, start_base2k, &tmp, 0, &mut carry);
+
+            let out_prec: u32 = (end_size * end_base2k) as u32;
+
+            let mut data_want: Vec<Float> = (0..n)
+                .map(|_| Float::with_val(out_prec, 0))
+                .collect();
+            let mut data_res: Vec<Float> = (0..n)
+                .map(|_| Float::with_val(out_prec, 0))
+                .collect();
+
+            want.decode_vec_float(end_base2k, 0, &mut data_want);
+            have.decode_vec_float(end_base2k, 0, &mut data_res);
+
+            for i in 0..n {
+                let mut err: Float = data_want[i].clone();
+                err.sub_assign_round(&data_res[i], Round::Nearest);
+                err = err.abs();
+
+                // println!(
+                //     "want: {} have: {} tmp: {} (want-have): {}",
+                //     data_want[i].to_f64(),
+                //     data_res[i].to_f64(),
+                //     data_tmp[i].to_f64(),
+                //     err.to_f64()
+                // );
+
+                let err_log2: f64 = err
+                    .clone()
+                    .max(&Float::with_val(prec as u32, 1e-60))
+                    .log2()
+                    .to_f64();
+
+                assert!(
+                    err_log2 <= -(out_prec as f64) + 1.,
+                    "{} {}",
+                    err_log2,
+                    -(out_prec as f64) + 1.
+                )
+            }
+        }
+    }
+}
diff --git a/poulpy-hal/src/reference/vec_znx/rotate.rs b/poulpy-hal/src/reference/vec_znx/rotate.rs
index 78ef17c..3b29677 100644
--- a/poulpy-hal/src/reference/vec_znx/rotate.rs
+++ b/poulpy-hal/src/reference/vec_znx/rotate.rs
@@ -61,7 +61,7 @@ pub fn bench_vec_znx_rotate<B: Backend>(c: &mut Criterion, label: &str)
 where
     Module<B>: VecZnxRotate + ModuleNew,
 {
-    let group_name: String = format!("vec_znx_rotate::{}", label);
+    let group_name: String = format!("vec_znx_rotate::{label}");
 
     let mut group = c.benchmark_group(group_name);
 
@@ -106,7 +106,7 @@ where
     Module<B>: VecZnxRotateInplace + VecZnxRotateInplaceTmpBytes + ModuleNew,
     ScratchOwned<B>: ScratchOwnedAlloc + ScratchOwnedBorrow,
 {
-    let group_name: String = format!("vec_znx_rotate_inplace::{}", label);
+    let group_name: String = format!("vec_znx_rotate_inplace::{label}");
 
     let mut group = c.benchmark_group(group_name);
 
diff --git a/poulpy-hal/src/reference/vec_znx/sampling.rs b/poulpy-hal/src/reference/vec_znx/sampling.rs
index d29edab..d1e12eb 100644
--- a/poulpy-hal/src/reference/vec_znx/sampling.rs
+++ b/poulpy-hal/src/reference/vec_znx/sampling.rs
@@ -4,18 +4,18 @@ use crate::{
     source::Source,
 };
 
-pub fn vec_znx_fill_uniform_ref<R>(basek: usize, res: &mut R, res_col: usize, source: &mut Source)
+pub fn vec_znx_fill_uniform_ref<R>(base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
 where
     R: VecZnxToMut,
 {
     let mut res: VecZnx<&mut [u8]> = res.to_mut();
     for j in 0..res.size() {
-        znx_fill_uniform_ref(basek, res.at_mut(res_col, j), source)
+        znx_fill_uniform_ref(base2k, res.at_mut(res_col, j), source)
     }
 }
 
 pub fn vec_znx_fill_normal_ref<R>(
-    basek: usize,
+    base2k: usize,
     res: &mut R,
     res_col: usize,
     k: usize,
@@ -32,8 +32,8 @@ pub fn vec_znx_fill_normal_ref<R>(
         (bound.log2().ceil() as i64)
     );
 
-    let limb: usize = k.div_ceil(basek) - 1;
-    let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64;
+    let limb: usize = k.div_ceil(base2k) - 1;
+    let scale: f64 = (1 << ((limb + 1) * base2k - k)) as f64;
     znx_fill_normal_f64_ref(
         res.at_mut(res_col, limb),
         sigma * scale,
@@ -42,8 +42,15 @@ pub fn vec_znx_fill_normal_ref<R>(
     )
 }
 
-pub fn vec_znx_add_normal_ref<R>(basek: usize, res: &mut R, res_col: usize, k: usize, sigma: f64, bound: f64, source: &mut Source)
-where
+pub fn vec_znx_add_normal_ref<R>(
+    base2k: usize,
+    res: &mut R,
+    res_col:
usize, + k: usize, + sigma: f64, + bound: f64, + source: &mut Source, +) where R: VecZnxToMut, { let mut res: VecZnx<&mut [u8]> = res.to_mut(); @@ -53,8 +60,8 @@ where (bound.log2().ceil() as i64) ); - let limb: usize = k.div_ceil(basek) - 1; - let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64; + let limb: usize = k.div_ceil(base2k) - 1; + let scale: f64 = (1 << ((limb + 1) * base2k - k)) as f64; znx_add_normal_f64_ref( res.at_mut(res_col, limb), sigma * scale, diff --git a/poulpy-hal/src/reference/vec_znx/shift.rs b/poulpy-hal/src/reference/vec_znx/shift.rs index 5b64d46..a13b982 100644 --- a/poulpy-hal/src/reference/vec_znx/shift.rs +++ b/poulpy-hal/src/reference/vec_znx/shift.rs @@ -20,7 +20,7 @@ pub fn vec_znx_lsh_tmp_bytes(n: usize) -> usize { n * size_of::() } -pub fn vec_znx_lsh_inplace(basek: usize, k: usize, res: &mut R, res_col: usize, carry: &mut [i64]) +pub fn vec_znx_lsh_inplace(base2k: usize, k: usize, res: &mut R, res_col: usize, carry: &mut [i64]) where R: VecZnxToMut, ZNXARI: ZnxZero @@ -35,8 +35,8 @@ where let n: usize = res.n(); let cols: usize = res.cols(); let size: usize = res.size(); - let steps: usize = k / basek; - let k_rem: usize = k % basek; + let steps: usize = k / base2k; + let k_rem: usize = k % base2k; if steps >= size { for j in 0..size { @@ -45,7 +45,7 @@ where return; } - // Inplace shift of limbs by a k/basek + // Inplace shift of limbs by a k/base2k if steps > 0 { let start: usize = n * res_col; let end: usize = start + n; @@ -65,21 +65,21 @@ where } } - // Inplace normalization with left shift of k % basek - if !k.is_multiple_of(basek) { + // Inplace normalization with left shift of k % base2k + if !k.is_multiple_of(base2k) { for j in (0..size - steps).rev() { if j == size - steps - 1 { - ZNXARI::znx_normalize_first_step_inplace(basek, k_rem, res.at_mut(res_col, j), carry); + ZNXARI::znx_normalize_first_step_inplace(base2k, k_rem, res.at_mut(res_col, j), carry); } else if j == 0 { - ZNXARI::znx_normalize_final_step_inplace(basek, k_rem, res.at_mut(res_col, j), carry); + ZNXARI::znx_normalize_final_step_inplace(base2k, k_rem, res.at_mut(res_col, j), carry); } else { - ZNXARI::znx_normalize_middle_step_inplace(basek, k_rem, res.at_mut(res_col, j), carry); + ZNXARI::znx_normalize_middle_step_inplace(base2k, k_rem, res.at_mut(res_col, j), carry); } } } } -pub fn vec_znx_lsh(basek: usize, k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64]) +pub fn vec_znx_lsh(base2k: usize, k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64]) where R: VecZnxToMut, A: VecZnxToRef, @@ -90,8 +90,8 @@ where let res_size: usize = res.size(); let a_size = a.size(); - let steps: usize = k / basek; - let k_rem: usize = k % basek; + let steps: usize = k / base2k; + let k_rem: usize = k % base2k; if steps >= res_size.min(a_size) { for j in 0..res_size { @@ -103,12 +103,12 @@ where let min_size: usize = a_size.min(res_size) - steps; // Simply a left shifted normalization of limbs - // by k/basek and intra-limb by basek - k%basek - if !k.is_multiple_of(basek) { + // by k/base2k and intra-limb by base2k - k%base2k + if !k.is_multiple_of(base2k) { for j in (0..min_size).rev() { if j == min_size - 1 { ZNXARI::znx_normalize_first_step( - basek, + base2k, k_rem, res.at_mut(res_col, j), a.at(a_col, j + steps), @@ -116,7 +116,7 @@ where ); } else if j == 0 { ZNXARI::znx_normalize_final_step( - basek, + base2k, k_rem, res.at_mut(res_col, j), a.at(a_col, j + steps), @@ -124,7 +124,7 @@ where ); } else { 
ZNXARI::znx_normalize_middle_step( - basek, + base2k, k_rem, res.at_mut(res_col, j), a.at(a_col, j + steps), @@ -133,7 +133,7 @@ where } } } else { - // If k % basek = 0, then this is simply a copy. + // If k % base2k = 0, then this is simply a copy. for j in (0..min_size).rev() { ZNXARI::znx_copy(res.at_mut(res_col, j), a.at(a_col, j + steps)); } @@ -149,7 +149,7 @@ pub fn vec_znx_rsh_tmp_bytes(n: usize) -> usize { n * size_of::() } -pub fn vec_znx_rsh_inplace(basek: usize, k: usize, res: &mut R, res_col: usize, carry: &mut [i64]) +pub fn vec_znx_rsh_inplace(base2k: usize, k: usize, res: &mut R, res_col: usize, carry: &mut [i64]) where R: VecZnxToMut, ZNXARI: ZnxZero @@ -166,8 +166,8 @@ where let cols: usize = res.cols(); let size: usize = res.size(); - let mut steps: usize = k / basek; - let k_rem: usize = k % basek; + let mut steps: usize = k / base2k; + let k_rem: usize = k % base2k; if k == 0 { return; @@ -184,8 +184,8 @@ where let end: usize = start + n; let slice_size: usize = n * cols; - if !k.is_multiple_of(basek) { - // We rsh by an additional basek and then lsh by basek-k + if !k.is_multiple_of(base2k) { + // We rsh by an additional base2k and then lsh by base2k-k // Allows to re-use efficient normalization code, avoids // avoids overflows & produce output that is normalized steps += 1; @@ -194,9 +194,9 @@ where // but the carry still need to be computed. (size - steps..size).rev().for_each(|j| { if j == size - 1 { - ZNXARI::znx_normalize_first_step_carry_only(basek, basek - k_rem, res.at(res_col, j), carry); + ZNXARI::znx_normalize_first_step_carry_only(base2k, base2k - k_rem, res.at(res_col, j), carry); } else { - ZNXARI::znx_normalize_middle_step_carry_only(basek, basek - k_rem, res.at(res_col, j), carry); + ZNXARI::znx_normalize_middle_step_carry_only(base2k, base2k - k_rem, res.at(res_col, j), carry); } }); @@ -206,20 +206,20 @@ where let (lhs, rhs) = res_raw.split_at_mut(slice_size * j); let rhs_slice: &mut [i64] = &mut rhs[start..end]; let lhs_slice: &[i64] = &lhs[(j - steps) * slice_size + start..(j - steps) * slice_size + end]; - ZNXARI::znx_normalize_middle_step(basek, basek - k_rem, rhs_slice, lhs_slice, carry); + ZNXARI::znx_normalize_middle_step(base2k, base2k - k_rem, rhs_slice, lhs_slice, carry); }); // Propagates carry on the rest of the limbs of res for j in (0..steps).rev() { ZNXARI::znx_zero(res.at_mut(res_col, j)); if j == 0 { - ZNXARI::znx_normalize_final_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry); + ZNXARI::znx_normalize_final_step_inplace(base2k, base2k - k_rem, res.at_mut(res_col, j), carry); } else { - ZNXARI::znx_normalize_middle_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry); + ZNXARI::znx_normalize_middle_step_inplace(base2k, base2k - k_rem, res.at_mut(res_col, j), carry); } } } else { - // Shift by multiples of basek + // Shift by multiples of base2k let res_raw: &mut [i64] = res.raw_mut(); (steps..size).rev().for_each(|j| { let (lhs, rhs) = res_raw.split_at_mut(slice_size * j); @@ -236,7 +236,7 @@ where } } -pub fn vec_znx_rsh(basek: usize, k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64]) +pub fn vec_znx_rsh(base2k: usize, k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64]) where R: VecZnxToMut, A: VecZnxToRef, @@ -256,8 +256,8 @@ where let res_size: usize = res.size(); let a_size: usize = a.size(); - let mut steps: usize = k / basek; - let k_rem: usize = k % basek; + let mut steps: usize = k / base2k; + let k_rem: usize = k % base2k; if k == 0 { 
vec_znx_copy::<_, _, ZNXARI>(&mut res, res_col, &a, a_col); @@ -271,8 +271,8 @@ where return; } - if !k.is_multiple_of(basek) { - // We rsh by an additional basek and then lsh by basek-k + if !k.is_multiple_of(base2k) { + // We rsh by an additional base2k and then lsh by base2k-k // Allows to re-use efficient normalization code, avoids // avoids overflows & produce output that is normalized steps += 1; @@ -281,9 +281,9 @@ where // but the carry still need to be computed. for j in (res_size..a_size + steps).rev() { if j == a_size + steps - 1 { - ZNXARI::znx_normalize_first_step_carry_only(basek, basek - k_rem, a.at(a_col, j - steps), carry); + ZNXARI::znx_normalize_first_step_carry_only(base2k, base2k - k_rem, a.at(a_col, j - steps), carry); } else { - ZNXARI::znx_normalize_middle_step_carry_only(basek, basek - k_rem, a.at(a_col, j - steps), carry); + ZNXARI::znx_normalize_middle_step_carry_only(base2k, base2k - k_rem, a.at(a_col, j - steps), carry); } } @@ -300,16 +300,16 @@ where // Case if no limb of a was previously discarded if res_size.saturating_sub(steps) >= a_size && j == min_size - 1 { ZNXARI::znx_normalize_first_step( - basek, - basek - k_rem, + base2k, + base2k - k_rem, res.at_mut(res_col, j), a.at(a_col, j - steps), carry, ); } else { ZNXARI::znx_normalize_middle_step( - basek, - basek - k_rem, + base2k, + base2k - k_rem, res.at_mut(res_col, j), a.at(a_col, j - steps), carry, @@ -321,9 +321,9 @@ where for j in (0..steps).rev() { ZNXARI::znx_zero(res.at_mut(res_col, j)); if j == 0 { - ZNXARI::znx_normalize_final_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry); + ZNXARI::znx_normalize_final_step_inplace(base2k, base2k - k_rem, res.at_mut(res_col, j), carry); } else { - ZNXARI::znx_normalize_middle_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry); + ZNXARI::znx_normalize_middle_step_inplace(base2k, base2k - k_rem, res.at_mut(res_col, j), carry); } } } else { @@ -351,7 +351,7 @@ where Module: ModuleNew + VecZnxLshInplace, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, { - let group_name: String = format!("vec_znx_lsh_inplace::{}", label); + let group_name: String = format!("vec_znx_lsh_inplace::{label}"); let mut group = c.benchmark_group(group_name); @@ -366,7 +366,7 @@ where let module: Module = Module::::new(n as u64); - let basek: usize = 50; + let base2k: usize = 50; let mut source: Source = Source::new([0u8; 32]); @@ -381,7 +381,7 @@ where move || { for i in 0..cols { - module.vec_znx_lsh_inplace(basek, basek - 1, &mut b, i, scratch.borrow()); + module.vec_znx_lsh_inplace(base2k, base2k - 1, &mut b, i, scratch.borrow()); } black_box(()); } @@ -401,7 +401,7 @@ where Module: VecZnxLsh + ModuleNew, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, { - let group_name: String = format!("vec_znx_lsh::{}", label); + let group_name: String = format!("vec_znx_lsh::{label}"); let mut group = c.benchmark_group(group_name); @@ -416,7 +416,7 @@ where let module: Module = Module::::new(n as u64); - let basek: usize = 50; + let base2k: usize = 50; let mut source: Source = Source::new([0u8; 32]); @@ -431,7 +431,7 @@ where move || { for i in 0..cols { - module.vec_znx_lsh(basek, basek - 1, &mut res, i, &a, i, scratch.borrow()); + module.vec_znx_lsh(base2k, base2k - 1, &mut res, i, &a, i, scratch.borrow()); } black_box(()); } @@ -451,7 +451,7 @@ where Module: VecZnxRshInplace + ModuleNew, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, { - let group_name: String = format!("vec_znx_rsh_inplace::{}", label); + let group_name: String = 
format!("vec_znx_rsh_inplace::{label}"); let mut group = c.benchmark_group(group_name); @@ -466,7 +466,7 @@ where let module: Module = Module::::new(n as u64); - let basek: usize = 50; + let base2k: usize = 50; let mut source: Source = Source::new([0u8; 32]); @@ -481,7 +481,7 @@ where move || { for i in 0..cols { - module.vec_znx_rsh_inplace(basek, basek - 1, &mut b, i, scratch.borrow()); + module.vec_znx_rsh_inplace(base2k, base2k - 1, &mut b, i, scratch.borrow()); } black_box(()); } @@ -501,7 +501,7 @@ where Module: VecZnxRsh + ModuleNew, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, { - let group_name: String = format!("vec_znx_rsh::{}", label); + let group_name: String = format!("vec_znx_rsh::{label}"); let mut group = c.benchmark_group(group_name); @@ -516,7 +516,7 @@ where let module: Module = Module::::new(n as u64); - let basek: usize = 50; + let base2k: usize = 50; let mut source: Source = Source::new([0u8; 32]); @@ -531,7 +531,7 @@ where move || { for i in 0..cols { - module.vec_znx_rsh(basek, basek - 1, &mut res, i, &a, i, scratch.borrow()); + module.vec_znx_rsh(base2k, base2k - 1, &mut res, i, &a, i, scratch.borrow()); } black_box(()); } @@ -553,7 +553,7 @@ mod tests { reference::{ vec_znx::{ vec_znx_copy, vec_znx_lsh, vec_znx_lsh_inplace, vec_znx_normalize_inplace, vec_znx_rsh, vec_znx_rsh_inplace, - vec_znx_sub_ab_inplace, + vec_znx_sub_inplace, }, znx::ZnxRef, }, @@ -574,20 +574,20 @@ mod tests { let mut carry: Vec = vec![0i64; n]; - let basek: usize = 50; + let base2k: usize = 50; for k in 0..256 { a.fill_uniform(50, &mut source); for i in 0..cols { - vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut a, i, &mut carry); + vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut a, i, &mut carry); vec_znx_copy::<_, _, ZnxRef>(&mut res_ref, i, &a, i); } for i in 0..cols { - vec_znx_lsh_inplace::<_, ZnxRef>(basek, k, &mut res_ref, i, &mut carry); - vec_znx_lsh::<_, _, ZnxRef>(basek, k, &mut res_test, i, &a, i, &mut carry); - vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut res_test, i, &mut carry); + vec_znx_lsh_inplace::<_, ZnxRef>(base2k, k, &mut res_ref, i, &mut carry); + vec_znx_lsh::<_, _, ZnxRef>(base2k, k, &mut res_test, i, &a, i, &mut carry); + vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut res_test, i, &mut carry); } assert_eq!(res_ref, res_test); @@ -606,7 +606,7 @@ mod tests { let mut carry: Vec = vec![0i64; n]; - let basek: usize = 50; + let base2k: usize = 50; let mut source: Source = Source::new([0u8; 32]); @@ -615,29 +615,29 @@ mod tests { for a_size in [res_size - 1, res_size, res_size + 1] { let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); - for k in 0..res_size * basek { + for k in 0..res_size * base2k { a.fill_uniform(50, &mut source); for i in 0..cols { - vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut a, i, &mut carry); + vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut a, i, &mut carry); vec_znx_copy::<_, _, ZnxRef>(&mut res_ref, i, &a, i); } res_test.fill_uniform(50, &mut source); for j in 0..cols { - vec_znx_rsh_inplace::<_, ZnxRef>(basek, k, &mut res_ref, j, &mut carry); - vec_znx_rsh::<_, _, ZnxRef>(basek, k, &mut res_test, j, &a, j, &mut carry); + vec_znx_rsh_inplace::<_, ZnxRef>(base2k, k, &mut res_ref, j, &mut carry); + vec_znx_rsh::<_, _, ZnxRef>(base2k, k, &mut res_test, j, &a, j, &mut carry); } for j in 0..cols { - vec_znx_lsh_inplace::<_, ZnxRef>(basek, k, &mut res_ref, j, &mut carry); - vec_znx_lsh_inplace::<_, ZnxRef>(basek, k, &mut res_test, j, &mut carry); + vec_znx_lsh_inplace::<_, ZnxRef>(base2k, k, &mut res_ref, j, 
&mut carry); + vec_znx_lsh_inplace::<_, ZnxRef>(base2k, k, &mut res_test, j, &mut carry); } // Case where res has enough to fully store a right shifted without any loss // In this case we can check exact equality. - if a_size + k.div_ceil(basek) <= res_size { + if a_size + k.div_ceil(base2k) <= res_size { assert_eq!(res_ref, res_test); for i in 0..cols { @@ -656,14 +656,14 @@ mod tests { // res. } else { for j in 0..cols { - vec_znx_sub_ab_inplace::<_, _, ZnxRef>(&mut res_ref, j, &a, j); - vec_znx_sub_ab_inplace::<_, _, ZnxRef>(&mut res_test, j, &a, j); + vec_znx_sub_inplace::<_, _, ZnxRef>(&mut res_ref, j, &a, j); + vec_znx_sub_inplace::<_, _, ZnxRef>(&mut res_test, j, &a, j); - vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut res_ref, j, &mut carry); - vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut res_test, j, &mut carry); + vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut res_ref, j, &mut carry); + vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut res_test, j, &mut carry); - assert!(res_ref.std(basek, j).log2() - (k as f64) <= (k * basek) as f64); - assert!(res_test.std(basek, j).log2() - (k as f64) <= (k * basek) as f64); + assert!(res_ref.std(base2k, j).log2() - (k as f64) <= (k * base2k) as f64); + assert!(res_test.std(base2k, j).log2() - (k as f64) <= (k * base2k) as f64); } } } diff --git a/poulpy-hal/src/reference/vec_znx/sub.rs b/poulpy-hal/src/reference/vec_znx/sub.rs index e9341ff..5497bcb 100644 --- a/poulpy-hal/src/reference/vec_znx/sub.rs +++ b/poulpy-hal/src/reference/vec_znx/sub.rs @@ -3,10 +3,10 @@ use std::hint::black_box; use criterion::{BenchmarkId, Criterion}; use crate::{ - api::{ModuleNew, VecZnxSub, VecZnxSubABInplace, VecZnxSubBAInplace}, + api::{ModuleNew, VecZnxSub, VecZnxSubInplace, VecZnxSubNegateInplace}, layouts::{Backend, FillUniform, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut}, - oep::{ModuleNewImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl}, - reference::znx::{ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxSub, ZnxSubABInplace, ZnxSubBAInplace, ZnxZero}, + oep::{ModuleNewImpl, VecZnxSubImpl, VecZnxSubInplaceImpl, VecZnxSubNegateInplaceImpl}, + reference::znx::{ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxSub, ZnxSubInplace, ZnxSubNegateInplace, ZnxZero}, source::Source, }; @@ -64,11 +64,11 @@ where } } -pub fn vec_znx_sub_ab_inplace(res: &mut R, res_col: usize, a: &A, a_col: usize) +pub fn vec_znx_sub_inplace(res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, A: VecZnxToRef, - ZNXARI: ZnxSubABInplace, + ZNXARI: ZnxSubInplace, { let a: VecZnx<&[u8]> = a.to_ref(); let mut res: VecZnx<&mut [u8]> = res.to_mut(); @@ -84,15 +84,15 @@ where let sum_size: usize = a_size.min(res_size); for j in 0..sum_size { - ZNXARI::znx_sub_ab_inplace(res.at_mut(res_col, j), a.at(a_col, j)); + ZNXARI::znx_sub_inplace(res.at_mut(res_col, j), a.at(a_col, j)); } } -pub fn vec_znx_sub_ba_inplace(res: &mut R, res_col: usize, a: &A, a_col: usize) +pub fn vec_znx_sub_negate_inplace(res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, A: VecZnxToRef, - ZNXARI: ZnxSubBAInplace + ZnxNegateInplace, + ZNXARI: ZnxSubNegateInplace + ZnxNegateInplace, { let a: VecZnx<&[u8]> = a.to_ref(); let mut res: VecZnx<&mut [u8]> = res.to_mut(); @@ -108,7 +108,7 @@ where let sum_size: usize = a_size.min(res_size); for j in 0..sum_size { - ZNXARI::znx_sub_ba_inplace(res.at_mut(res_col, j), a.at(a_col, j)); + ZNXARI::znx_sub_negate_inplace(res.at_mut(res_col, j), a.at(a_col, j)); } for j in sum_size..res_size { @@ 
-120,7 +120,7 @@ pub fn bench_vec_znx_sub(c: &mut Criterion, label: &str) where B: Backend + ModuleNewImpl + VecZnxSubImpl, { - let group_name: String = format!("vec_znx_sub::{}", label); + let group_name: String = format!("vec_znx_sub::{label}"); let mut group = c.benchmark_group(group_name); @@ -161,17 +161,17 @@ where group.finish(); } -pub fn bench_vec_znx_sub_ab_inplace(c: &mut Criterion, label: &str) +pub fn bench_vec_znx_sub_inplace(c: &mut Criterion, label: &str) where - B: Backend + ModuleNewImpl + VecZnxSubABInplaceImpl, + B: Backend + ModuleNewImpl + VecZnxSubInplaceImpl, { - let group_name: String = format!("vec_znx_sub_ab_inplace::{}", label); + let group_name: String = format!("vec_znx_sub_inplace::{label}"); let mut group = c.benchmark_group(group_name); fn runner(params: [usize; 3]) -> impl FnMut() where - Module: VecZnxSubABInplace + ModuleNew, + Module: VecZnxSubInplace + ModuleNew, { let n: usize = 1 << params[0]; let cols: usize = params[1]; @@ -190,7 +190,7 @@ where move || { for i in 0..cols { - module.vec_znx_sub_ab_inplace(&mut b, i, &a, i); + module.vec_znx_sub_inplace(&mut b, i, &a, i); } black_box(()); } @@ -205,17 +205,17 @@ where group.finish(); } -pub fn bench_vec_znx_sub_ba_inplace(c: &mut Criterion, label: &str) +pub fn bench_vec_znx_sub_negate_inplace(c: &mut Criterion, label: &str) where - B: Backend + ModuleNewImpl + VecZnxSubBAInplaceImpl, + B: Backend + ModuleNewImpl + VecZnxSubNegateInplaceImpl, { - let group_name: String = format!("vec_znx_sub_ba_inplace::{}", label); + let group_name: String = format!("vec_znx_sub_negate_inplace::{label}"); let mut group = c.benchmark_group(group_name); fn runner(params: [usize; 3]) -> impl FnMut() where - Module: VecZnxSubBAInplace + ModuleNew, + Module: VecZnxSubNegateInplace + ModuleNew, { let n: usize = 1 << params[0]; let cols: usize = params[1]; @@ -234,7 +234,7 @@ where move || { for i in 0..cols { - module.vec_znx_sub_ba_inplace(&mut b, i, &a, i); + module.vec_znx_sub_negate_inplace(&mut b, i, &a, i); } black_box(()); } diff --git a/poulpy-hal/src/reference/vec_znx/sub_scalar.rs b/poulpy-hal/src/reference/vec_znx/sub_scalar.rs index 04d405e..872e07b 100644 --- a/poulpy-hal/src/reference/vec_znx/sub_scalar.rs +++ b/poulpy-hal/src/reference/vec_znx/sub_scalar.rs @@ -1,7 +1,7 @@ use crate::layouts::{ScalarZnxToRef, VecZnxToMut, VecZnxToRef}; use crate::{ layouts::{ScalarZnx, VecZnx, ZnxInfos, ZnxView, ZnxViewMut}, - reference::znx::{ZnxSub, ZnxSubABInplace, ZnxZero}, + reference::znx::{ZnxSub, ZnxSubInplace, ZnxZero}, }; pub fn vec_znx_sub_scalar(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize, b_limb: usize) @@ -19,12 +19,7 @@ where #[cfg(debug_assertions)] { - assert!( - b_limb < min_size, - "b_limb: {} > min_size: {}", - b_limb, - min_size - ); + assert!(b_limb < min_size, "b_limb: {b_limb} > min_size: {min_size}"); } for j in 0..min_size { @@ -44,7 +39,7 @@ pub fn vec_znx_sub_scalar_inplace(res: &mut R, res_col: usize, res where R: VecZnxToMut, A: ScalarZnxToRef, - ZNXARI: ZnxSubABInplace, + ZNXARI: ZnxSubInplace, { let a: ScalarZnx<&[u8]> = a.to_ref(); let mut res: VecZnx<&mut [u8]> = res.to_mut(); @@ -54,5 +49,5 @@ where assert!(res_limb < res.size()); } - ZNXARI::znx_sub_ab_inplace(res.at_mut(res_col, res_limb), a.at(a_col, 0)); + ZNXARI::znx_sub_inplace(res.at_mut(res_col, res_limb), a.at(a_col, 0)); } diff --git a/poulpy-hal/src/reference/zn/normalization.rs b/poulpy-hal/src/reference/zn/normalization.rs index 4412369..83cfeb7 100644 --- 
a/poulpy-hal/src/reference/zn/normalization.rs +++ b/poulpy-hal/src/reference/zn/normalization.rs @@ -9,7 +9,7 @@ pub fn zn_normalize_tmp_bytes(n: usize) -> usize { n * size_of::() } -pub fn zn_normalize_inplace(n: usize, basek: usize, res: &mut R, res_col: usize, carry: &mut [i64]) +pub fn zn_normalize_inplace(n: usize, base2k: usize, res: &mut R, res_col: usize, carry: &mut [i64]) where R: ZnToMut, ARI: ZnxNormalizeFirstStepInplace + ZnxNormalizeFinalStepInplace + ZnxNormalizeMiddleStepInplace, @@ -27,11 +27,11 @@ where let out = &mut res.at_mut(res_col, j)[..n]; if j == res_size - 1 { - ARI::znx_normalize_first_step_inplace(basek, 0, out, carry); + ARI::znx_normalize_first_step_inplace(base2k, 0, out, carry); } else if j == 0 { - ARI::znx_normalize_final_step_inplace(basek, 0, out, carry); + ARI::znx_normalize_final_step_inplace(base2k, 0, out, carry); } else { - ARI::znx_normalize_middle_step_inplace(basek, 0, out, carry); + ARI::znx_normalize_middle_step_inplace(base2k, 0, out, carry); } } } @@ -43,7 +43,7 @@ where { let mut source: Source = Source::new([0u8; 32]); let cols: usize = 2; - let basek: usize = 12; + let base2k: usize = 12; let n = 33; @@ -63,8 +63,8 @@ where // Reference for i in 0..cols { - zn_normalize_inplace::<_, ZnxRef>(n, basek, &mut res_0, i, &mut carry); - module.zn_normalize_inplace(n, basek, &mut res_1, i, scratch.borrow()); + zn_normalize_inplace::<_, ZnxRef>(n, base2k, &mut res_0, i, &mut carry); + module.zn_normalize_inplace(n, base2k, &mut res_1, i, scratch.borrow()); } assert_eq!(res_0.raw(), res_1.raw()); diff --git a/poulpy-hal/src/reference/zn/sampling.rs b/poulpy-hal/src/reference/zn/sampling.rs index 9c46f7a..b376dcc 100644 --- a/poulpy-hal/src/reference/zn/sampling.rs +++ b/poulpy-hal/src/reference/zn/sampling.rs @@ -4,20 +4,20 @@ use crate::{ source::Source, }; -pub fn zn_fill_uniform(n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source) +pub fn zn_fill_uniform(n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source) where R: ZnToMut, { let mut res: Zn<&mut [u8]> = res.to_mut(); for j in 0..res.size() { - znx_fill_uniform_ref(basek, &mut res.at_mut(res_col, j)[..n], source) + znx_fill_uniform_ref(base2k, &mut res.at_mut(res_col, j)[..n], source) } } #[allow(clippy::too_many_arguments)] pub fn zn_fill_normal( n: usize, - basek: usize, + base2k: usize, res: &mut R, res_col: usize, k: usize, @@ -34,8 +34,8 @@ pub fn zn_fill_normal( (bound.log2().ceil() as i64) ); - let limb: usize = k.div_ceil(basek) - 1; - let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64; + let limb: usize = k.div_ceil(base2k) - 1; + let scale: f64 = (1 << ((limb + 1) * base2k - k)) as f64; znx_fill_normal_f64_ref( &mut res.at_mut(res_col, limb)[..n], sigma * scale, @@ -47,7 +47,7 @@ pub fn zn_fill_normal( #[allow(clippy::too_many_arguments)] pub fn zn_add_normal( n: usize, - basek: usize, + base2k: usize, res: &mut R, res_col: usize, k: usize, @@ -64,8 +64,8 @@ pub fn zn_add_normal( (bound.log2().ceil() as i64) ); - let limb: usize = k.div_ceil(basek) - 1; - let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64; + let limb: usize = k.div_ceil(base2k) - 1; + let scale: f64 = (1 << ((limb + 1) * base2k - k)) as f64; znx_add_normal_f64_ref( &mut res.at_mut(res_col, limb)[..n], sigma * scale, diff --git a/poulpy-hal/src/reference/znx/arithmetic_ref.rs b/poulpy-hal/src/reference/znx/arithmetic_ref.rs index ba21ede..9955b94 100644 --- a/poulpy-hal/src/reference/znx/arithmetic_ref.rs +++ b/poulpy-hal/src/reference/znx/arithmetic_ref.rs @@ -1,8 
+1,9 @@ use crate::reference::znx::{ - ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxNormalizeFinalStep, - ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeFirstStepInplace, - ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, ZnxSub, ZnxSubABInplace, - ZnxSubBAInplace, ZnxSwitchRing, ZnxZero, + ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxExtractDigitAddMul, ZnxMulAddPowerOfTwo, ZnxMulPowerOfTwo, + ZnxMulPowerOfTwoInplace, ZnxNegate, ZnxNegateInplace, ZnxNormalizeDigit, ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace, + ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep, + ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, ZnxSub, ZnxSubInplace, ZnxSubNegateInplace, ZnxSwitchRing, + ZnxZero, add::{znx_add_inplace_ref, znx_add_ref}, automorphism::znx_automorphism_ref, copy::znx_copy_ref, @@ -12,9 +13,11 @@ use crate::reference::znx::{ znx_normalize_first_step_inplace_ref, znx_normalize_first_step_ref, znx_normalize_middle_step_carry_only_ref, znx_normalize_middle_step_inplace_ref, znx_normalize_middle_step_ref, }, - sub::{znx_sub_ab_inplace_ref, znx_sub_ba_inplace_ref, znx_sub_ref}, + sub::{znx_sub_inplace_ref, znx_sub_negate_inplace_ref, znx_sub_ref}, switch_ring::znx_switch_ring_ref, zero::znx_zero_ref, + znx_extract_digit_addmul_ref, znx_mul_add_power_of_two_ref, znx_mul_power_of_two_inplace_ref, znx_mul_power_of_two_ref, + znx_normalize_digit_ref, }; pub struct ZnxRef {} @@ -40,17 +43,17 @@ impl ZnxSub for ZnxRef { } } -impl ZnxSubABInplace for ZnxRef { +impl ZnxSubInplace for ZnxRef { #[inline(always)] - fn znx_sub_ab_inplace(res: &mut [i64], a: &[i64]) { - znx_sub_ab_inplace_ref(res, a); + fn znx_sub_inplace(res: &mut [i64], a: &[i64]) { + znx_sub_inplace_ref(res, a); } } -impl ZnxSubBAInplace for ZnxRef { +impl ZnxSubNegateInplace for ZnxRef { #[inline(always)] - fn znx_sub_ba_inplace(res: &mut [i64], a: &[i64]) { - znx_sub_ba_inplace_ref(res, a); + fn znx_sub_negate_inplace(res: &mut [i64], a: &[i64]) { + znx_sub_negate_inplace_ref(res, a); } } @@ -61,6 +64,27 @@ impl ZnxAutomorphism for ZnxRef { } } +impl ZnxMulPowerOfTwo for ZnxRef { + #[inline(always)] + fn znx_mul_power_of_two(k: i64, res: &mut [i64], a: &[i64]) { + znx_mul_power_of_two_ref(k, res, a); + } +} + +impl ZnxMulAddPowerOfTwo for ZnxRef { + #[inline(always)] + fn znx_muladd_power_of_two(k: i64, res: &mut [i64], a: &[i64]) { + znx_mul_add_power_of_two_ref(k, res, a); + } +} + +impl ZnxMulPowerOfTwoInplace for ZnxRef { + #[inline(always)] + fn znx_mul_power_of_two_inplace(k: i64, res: &mut [i64]) { + znx_mul_power_of_two_inplace_ref(k, res); + } +} + impl ZnxCopy for ZnxRef { #[inline(always)] fn znx_copy(res: &mut [i64], a: &[i64]) { @@ -98,56 +122,70 @@ impl ZnxSwitchRing for ZnxRef { impl ZnxNormalizeFinalStep for ZnxRef { #[inline(always)] - fn znx_normalize_final_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { - znx_normalize_final_step_ref(basek, lsh, x, a, carry); + fn znx_normalize_final_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { + znx_normalize_final_step_ref(base2k, lsh, x, a, carry); } } impl ZnxNormalizeFinalStepInplace for ZnxRef { #[inline(always)] - fn znx_normalize_final_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { - znx_normalize_final_step_inplace_ref(basek, lsh, x, carry); + fn 
znx_normalize_final_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { + znx_normalize_final_step_inplace_ref(base2k, lsh, x, carry); } } impl ZnxNormalizeFirstStep for ZnxRef { #[inline(always)] - fn znx_normalize_first_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { - znx_normalize_first_step_ref(basek, lsh, x, a, carry); + fn znx_normalize_first_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { + znx_normalize_first_step_ref(base2k, lsh, x, a, carry); } } impl ZnxNormalizeFirstStepCarryOnly for ZnxRef { #[inline(always)] - fn znx_normalize_first_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) { - znx_normalize_first_step_carry_only_ref(basek, lsh, x, carry); + fn znx_normalize_first_step_carry_only(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) { + znx_normalize_first_step_carry_only_ref(base2k, lsh, x, carry); } } impl ZnxNormalizeFirstStepInplace for ZnxRef { #[inline(always)] - fn znx_normalize_first_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { - znx_normalize_first_step_inplace_ref(basek, lsh, x, carry); + fn znx_normalize_first_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { + znx_normalize_first_step_inplace_ref(base2k, lsh, x, carry); } } impl ZnxNormalizeMiddleStep for ZnxRef { #[inline(always)] - fn znx_normalize_middle_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { - znx_normalize_middle_step_ref(basek, lsh, x, a, carry); + fn znx_normalize_middle_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { + znx_normalize_middle_step_ref(base2k, lsh, x, a, carry); } } impl ZnxNormalizeMiddleStepCarryOnly for ZnxRef { #[inline(always)] - fn znx_normalize_middle_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) { - znx_normalize_middle_step_carry_only_ref(basek, lsh, x, carry); + fn znx_normalize_middle_step_carry_only(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) { + znx_normalize_middle_step_carry_only_ref(base2k, lsh, x, carry); } } impl ZnxNormalizeMiddleStepInplace for ZnxRef { #[inline(always)] - fn znx_normalize_middle_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { - znx_normalize_middle_step_inplace_ref(basek, lsh, x, carry); + fn znx_normalize_middle_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { + znx_normalize_middle_step_inplace_ref(base2k, lsh, x, carry); + } +} + +impl ZnxExtractDigitAddMul for ZnxRef { + #[inline(always)] + fn znx_extract_digit_addmul(base2k: usize, lsh: usize, res: &mut [i64], src: &mut [i64]) { + znx_extract_digit_addmul_ref(base2k, lsh, res, src); + } +} + +impl ZnxNormalizeDigit for ZnxRef { + #[inline(always)] + fn znx_normalize_digit(base2k: usize, res: &mut [i64], src: &mut [i64]) { + znx_normalize_digit_ref(base2k, res, src); } } diff --git a/poulpy-hal/src/reference/znx/mod.rs b/poulpy-hal/src/reference/znx/mod.rs index 9659e7d..493ba0b 100644 --- a/poulpy-hal/src/reference/znx/mod.rs +++ b/poulpy-hal/src/reference/znx/mod.rs @@ -2,6 +2,7 @@ mod add; mod arithmetic_ref; mod automorphism; mod copy; +mod mul; mod neg; mod normalization; mod rotate; @@ -14,6 +15,7 @@ pub use add::*; pub use arithmetic_ref::*; pub use automorphism::*; pub use copy::*; +pub use mul::*; pub use neg::*; pub use normalization::*; pub use rotate::*; @@ -35,12 +37,12 @@ pub trait ZnxSub { fn znx_sub(res: &mut [i64], a: &[i64], b: &[i64]); } -pub trait 
ZnxSubABInplace { - fn znx_sub_ab_inplace(res: &mut [i64], a: &[i64]); +pub trait ZnxSubInplace { + fn znx_sub_inplace(res: &mut [i64], a: &[i64]); } -pub trait ZnxSubBAInplace { - fn znx_sub_ba_inplace(res: &mut [i64], a: &[i64]); +pub trait ZnxSubNegateInplace { + fn znx_sub_negate_inplace(res: &mut [i64], a: &[i64]); } pub trait ZnxAutomorphism { @@ -67,38 +69,58 @@ pub trait ZnxZero { fn znx_zero(res: &mut [i64]); } +pub trait ZnxMulPowerOfTwo { + fn znx_mul_power_of_two(k: i64, res: &mut [i64], a: &[i64]); +} + +pub trait ZnxMulAddPowerOfTwo { + fn znx_muladd_power_of_two(k: i64, res: &mut [i64], a: &[i64]); +} + +pub trait ZnxMulPowerOfTwoInplace { + fn znx_mul_power_of_two_inplace(k: i64, res: &mut [i64]); +} + pub trait ZnxSwitchRing { fn znx_switch_ring(res: &mut [i64], a: &[i64]); } pub trait ZnxNormalizeFirstStepCarryOnly { - fn znx_normalize_first_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]); + fn znx_normalize_first_step_carry_only(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]); } pub trait ZnxNormalizeFirstStepInplace { - fn znx_normalize_first_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]); + fn znx_normalize_first_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]); } pub trait ZnxNormalizeFirstStep { - fn znx_normalize_first_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]); + fn znx_normalize_first_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]); } pub trait ZnxNormalizeMiddleStepCarryOnly { - fn znx_normalize_middle_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]); + fn znx_normalize_middle_step_carry_only(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]); } pub trait ZnxNormalizeMiddleStepInplace { - fn znx_normalize_middle_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]); + fn znx_normalize_middle_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]); } pub trait ZnxNormalizeMiddleStep { - fn znx_normalize_middle_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]); + fn znx_normalize_middle_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]); } pub trait ZnxNormalizeFinalStepInplace { - fn znx_normalize_final_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]); + fn znx_normalize_final_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]); } pub trait ZnxNormalizeFinalStep { - fn znx_normalize_final_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]); + fn znx_normalize_final_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]); +} + +pub trait ZnxExtractDigitAddMul { + fn znx_extract_digit_addmul(base2k: usize, lsh: usize, res: &mut [i64], src: &mut [i64]); +} + +pub trait ZnxNormalizeDigit { + fn znx_normalize_digit(base2k: usize, res: &mut [i64], src: &mut [i64]); } diff --git a/poulpy-hal/src/reference/znx/mul.rs b/poulpy-hal/src/reference/znx/mul.rs new file mode 100644 index 0000000..8d48c66 --- /dev/null +++ b/poulpy-hal/src/reference/znx/mul.rs @@ -0,0 +1,76 @@ +use crate::reference::znx::{znx_add_inplace_ref, znx_copy_ref}; + +pub fn znx_mul_power_of_two_ref(mut k: i64, res: &mut [i64], a: &[i64]) { + #[cfg(debug_assertions)] + { + assert_eq!(res.len(), a.len()); + } + + if k == 0 { + znx_copy_ref(res, a); + return; + } + + if k > 0 { + for (y, x) in res.iter_mut().zip(a.iter()) { + *y = *x << k + } + return; + } + 
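// [Editor's note, not part of the patch] The k < 0 branch below implements a
// rounding right shift: with bias = (1 << (k - 1)) - sign_bit, the result
// (x + bias) >> k equals x / 2^k rounded to the nearest integer, with ties
// rounded away from zero (e.g. for k = 1: 1 -> 1, -1 -> -1, 2 -> 1, -2 -> -1).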
+ k = -k; + + for (y, x) in res.iter_mut().zip(a.iter()) { + let sign_bit: i64 = (x >> 63) & 1; + let bias: i64 = (1_i64 << (k - 1)) - sign_bit; + *y = (x + bias) >> k; + } +} + +pub fn znx_mul_power_of_two_inplace_ref(mut k: i64, res: &mut [i64]) { + if k == 0 { + return; + } + + if k > 0 { + for x in res.iter_mut() { + *x <<= k + } + return; + } + + k = -k; + + for x in res.iter_mut() { + let sign_bit: i64 = (*x >> 63) & 1; + let bias: i64 = (1_i64 << (k - 1)) - sign_bit; + *x = (*x + bias) >> k; + } +} + +pub fn znx_mul_add_power_of_two_ref(mut k: i64, res: &mut [i64], a: &[i64]) { + #[cfg(debug_assertions)] + { + assert_eq!(res.len(), a.len()); + } + + if k == 0 { + znx_add_inplace_ref(res, a); + return; + } + + if k > 0 { + for (y, x) in res.iter_mut().zip(a.iter()) { + *y += *x << k + } + return; + } + + k = -k; + + for (y, x) in res.iter_mut().zip(a.iter()) { + let sign_bit: i64 = (x >> 63) & 1; + let bias: i64 = (1_i64 << (k - 1)) - sign_bit; + *y += (x + bias) >> k; + } +} diff --git a/poulpy-hal/src/reference/znx/normalization.rs b/poulpy-hal/src/reference/znx/normalization.rs index e9f57cf..95100e4 100644 --- a/poulpy-hal/src/reference/znx/normalization.rs +++ b/poulpy-hal/src/reference/znx/normalization.rs @@ -1,199 +1,229 @@ use itertools::izip; #[inline(always)] -pub fn get_digit(basek: usize, x: i64) -> i64 { - (x << (u64::BITS - basek as u32)) >> (u64::BITS - basek as u32) +pub fn get_digit_i64(base2k: usize, x: i64) -> i64 { + (x << (u64::BITS - base2k as u32)) >> (u64::BITS - base2k as u32) } #[inline(always)] -pub fn get_carry(basek: usize, x: i64, digit: i64) -> i64 { - (x.wrapping_sub(digit)) >> basek +pub fn get_carry_i64(base2k: usize, x: i64, digit: i64) -> i64 { + (x.wrapping_sub(digit)) >> base2k } #[inline(always)] -pub fn znx_normalize_first_step_carry_only_ref(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) { +pub fn get_digit_i128(base2k: usize, x: i128) -> i128 { + (x << (u128::BITS - base2k as u32)) >> (u128::BITS - base2k as u32) +} + +#[inline(always)] +pub fn get_carry_i128(base2k: usize, x: i128, digit: i128) -> i128 { + (x.wrapping_sub(digit)) >> base2k +} + +#[inline(always)] +pub fn znx_normalize_first_step_carry_only_ref(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) { #[cfg(debug_assertions)] { assert!(x.len() <= carry.len()); - assert!(lsh < basek); + assert!(lsh < base2k); } if lsh == 0 { x.iter().zip(carry.iter_mut()).for_each(|(x, c)| { - *c = get_carry(basek, *x, get_digit(basek, *x)); + *c = get_carry_i64(base2k, *x, get_digit_i64(base2k, *x)); }); } else { - let basek_lsh: usize = basek - lsh; + let basek_lsh: usize = base2k - lsh; x.iter().zip(carry.iter_mut()).for_each(|(x, c)| { - *c = get_carry(basek_lsh, *x, get_digit(basek_lsh, *x)); + *c = get_carry_i64(basek_lsh, *x, get_digit_i64(basek_lsh, *x)); }); } } #[inline(always)] -pub fn znx_normalize_first_step_inplace_ref(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { +pub fn znx_normalize_first_step_inplace_ref(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { #[cfg(debug_assertions)] { assert!(x.len() <= carry.len()); - assert!(lsh < basek); + assert!(lsh < base2k); } if lsh == 0 { x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| { - let digit: i64 = get_digit(basek, *x); - *c = get_carry(basek, *x, digit); + let digit: i64 = get_digit_i64(base2k, *x); + *c = get_carry_i64(base2k, *x, digit); *x = digit; }); } else { - let basek_lsh: usize = basek - lsh; + let basek_lsh: usize = base2k - lsh; 
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| { - let digit: i64 = get_digit(basek_lsh, *x); - *c = get_carry(basek_lsh, *x, digit); + let digit: i64 = get_digit_i64(basek_lsh, *x); + *c = get_carry_i64(basek_lsh, *x, digit); *x = digit << lsh; }); } } #[inline(always)] -pub fn znx_normalize_first_step_ref(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { +pub fn znx_normalize_first_step_ref(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { #[cfg(debug_assertions)] { assert_eq!(x.len(), a.len()); assert!(x.len() <= carry.len()); - assert!(lsh < basek); + assert!(lsh < base2k); } if lsh == 0 { izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| { - let digit: i64 = get_digit(basek, *a); - *c = get_carry(basek, *a, digit); + let digit: i64 = get_digit_i64(base2k, *a); + *c = get_carry_i64(base2k, *a, digit); *x = digit; }); } else { - let basek_lsh: usize = basek - lsh; + let basek_lsh: usize = base2k - lsh; izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| { - let digit: i64 = get_digit(basek_lsh, *a); - *c = get_carry(basek_lsh, *a, digit); + let digit: i64 = get_digit_i64(basek_lsh, *a); + *c = get_carry_i64(basek_lsh, *a, digit); *x = digit << lsh; }); } } #[inline(always)] -pub fn znx_normalize_middle_step_carry_only_ref(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) { +pub fn znx_normalize_middle_step_carry_only_ref(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) { #[cfg(debug_assertions)] { assert!(x.len() <= carry.len()); - assert!(lsh < basek); + assert!(lsh < base2k); } if lsh == 0 { x.iter().zip(carry.iter_mut()).for_each(|(x, c)| { - let digit: i64 = get_digit(basek, *x); - let carry: i64 = get_carry(basek, *x, digit); + let digit: i64 = get_digit_i64(base2k, *x); + let carry: i64 = get_carry_i64(base2k, *x, digit); let digit_plus_c: i64 = digit + *c; - *c = carry + get_carry(basek, digit_plus_c, get_digit(basek, digit_plus_c)); + *c = carry + get_carry_i64(base2k, digit_plus_c, get_digit_i64(base2k, digit_plus_c)); }); } else { - let basek_lsh: usize = basek - lsh; + let basek_lsh: usize = base2k - lsh; x.iter().zip(carry.iter_mut()).for_each(|(x, c)| { - let digit: i64 = get_digit(basek_lsh, *x); - let carry: i64 = get_carry(basek_lsh, *x, digit); + let digit: i64 = get_digit_i64(basek_lsh, *x); + let carry: i64 = get_carry_i64(basek_lsh, *x, digit); let digit_plus_c: i64 = (digit << lsh) + *c; - *c = carry + get_carry(basek, digit_plus_c, get_digit(basek, digit_plus_c)); + *c = carry + get_carry_i64(base2k, digit_plus_c, get_digit_i64(base2k, digit_plus_c)); }); } } #[inline(always)] -pub fn znx_normalize_middle_step_inplace_ref(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { +pub fn znx_normalize_middle_step_inplace_ref(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { #[cfg(debug_assertions)] { assert!(x.len() <= carry.len()); - assert!(lsh < basek); + assert!(lsh < base2k); } + if lsh == 0 { x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| { - let digit: i64 = get_digit(basek, *x); - let carry: i64 = get_carry(basek, *x, digit); + let digit: i64 = get_digit_i64(base2k, *x); + let carry: i64 = get_carry_i64(base2k, *x, digit); let digit_plus_c: i64 = digit + *c; - *x = get_digit(basek, digit_plus_c); - *c = carry + get_carry(basek, digit_plus_c, *x); + *x = get_digit_i64(base2k, digit_plus_c); + *c = carry + get_carry_i64(base2k, digit_plus_c, *x); }); } else { - let basek_lsh: usize = basek - lsh; + let basek_lsh: usize = base2k - lsh; 
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| { - let digit: i64 = get_digit(basek_lsh, *x); - let carry: i64 = get_carry(basek_lsh, *x, digit); + let digit: i64 = get_digit_i64(basek_lsh, *x); + let carry: i64 = get_carry_i64(basek_lsh, *x, digit); let digit_plus_c: i64 = (digit << lsh) + *c; - *x = get_digit(basek, digit_plus_c); - *c = carry + get_carry(basek, digit_plus_c, *x); + *x = get_digit_i64(base2k, digit_plus_c); + *c = carry + get_carry_i64(base2k, digit_plus_c, *x); }); } } #[inline(always)] -pub fn znx_normalize_middle_step_ref(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { +pub fn znx_extract_digit_addmul_ref(base2k: usize, lsh: usize, res: &mut [i64], src: &mut [i64]) { + for (r, s) in res.iter_mut().zip(src.iter_mut()) { + let digit: i64 = get_digit_i64(base2k, *s); + *s = get_carry_i64(base2k, *s, digit); + *r += digit << lsh; + } +} + +#[inline(always)] +pub fn znx_normalize_digit_ref(base2k: usize, res: &mut [i64], src: &mut [i64]) { + for (r, s) in res.iter_mut().zip(src.iter_mut()) { + let ri_digit: i64 = get_digit_i64(base2k, *r); + let ri_carry: i64 = get_carry_i64(base2k, *r, ri_digit); + *r = ri_digit; + *s += ri_carry; + } +} + +#[inline(always)] +pub fn znx_normalize_middle_step_ref(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { #[cfg(debug_assertions)] { assert_eq!(x.len(), a.len()); assert!(x.len() <= carry.len()); - assert!(lsh < basek); + assert!(lsh < base2k); } if lsh == 0 { izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| { - let digit: i64 = get_digit(basek, *a); - let carry: i64 = get_carry(basek, *a, digit); + let digit: i64 = get_digit_i64(base2k, *a); + let carry: i64 = get_carry_i64(base2k, *a, digit); let digit_plus_c: i64 = digit + *c; - *x = get_digit(basek, digit_plus_c); - *c = carry + get_carry(basek, digit_plus_c, *x); + *x = get_digit_i64(base2k, digit_plus_c); + *c = carry + get_carry_i64(base2k, digit_plus_c, *x); }); } else { - let basek_lsh: usize = basek - lsh; + let basek_lsh: usize = base2k - lsh; izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| { - let digit: i64 = get_digit(basek_lsh, *a); - let carry: i64 = get_carry(basek_lsh, *a, digit); + let digit: i64 = get_digit_i64(basek_lsh, *a); + let carry: i64 = get_carry_i64(basek_lsh, *a, digit); let digit_plus_c: i64 = (digit << lsh) + *c; - *x = get_digit(basek, digit_plus_c); - *c = carry + get_carry(basek, digit_plus_c, *x); + *x = get_digit_i64(base2k, digit_plus_c); + *c = carry + get_carry_i64(base2k, digit_plus_c, *x); }); } } #[inline(always)] -pub fn znx_normalize_final_step_inplace_ref(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { +pub fn znx_normalize_final_step_inplace_ref(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { #[cfg(debug_assertions)] { assert!(x.len() <= carry.len()); - assert!(lsh < basek); + assert!(lsh < base2k); } if lsh == 0 { x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| { - *x = get_digit(basek, get_digit(basek, *x) + *c); + *x = get_digit_i64(base2k, get_digit_i64(base2k, *x) + *c); }); } else { - let basek_lsh: usize = basek - lsh; + let basek_lsh: usize = base2k - lsh; x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| { - *x = get_digit(basek, (get_digit(basek_lsh, *x) << lsh) + *c); + *x = get_digit_i64(base2k, (get_digit_i64(basek_lsh, *x) << lsh) + *c); }); } } #[inline(always)] -pub fn znx_normalize_final_step_ref(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { +pub fn 
znx_normalize_final_step_ref(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { #[cfg(debug_assertions)] { assert!(x.len() <= carry.len()); - assert!(lsh < basek); + assert!(lsh < base2k); } if lsh == 0 { izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| { - *x = get_digit(basek, get_digit(basek, *a) + *c); + *x = get_digit_i64(base2k, get_digit_i64(base2k, *a) + *c); }); } else { - let basek_lsh: usize = basek - lsh; + let basek_lsh: usize = base2k - lsh; izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| { - *x = get_digit(basek, (get_digit(basek_lsh, *a) << lsh) + *c); + *x = get_digit_i64(base2k, (get_digit_i64(basek_lsh, *a) << lsh) + *c); }); } } diff --git a/poulpy-hal/src/reference/znx/sampling.rs b/poulpy-hal/src/reference/znx/sampling.rs index feaa393..0646d44 100644 --- a/poulpy-hal/src/reference/znx/sampling.rs +++ b/poulpy-hal/src/reference/znx/sampling.rs @@ -2,8 +2,8 @@ use rand_distr::{Distribution, Normal}; use crate::source::Source; -pub fn znx_fill_uniform_ref(basek: usize, res: &mut [i64], source: &mut Source) { - let pow2k: u64 = 1 << basek; +pub fn znx_fill_uniform_ref(base2k: usize, res: &mut [i64], source: &mut Source) { + let pow2k: u64 = 1 << base2k; let mask: u64 = pow2k - 1; let pow2k_half: i64 = (pow2k >> 1) as i64; res.iter_mut() diff --git a/poulpy-hal/src/reference/znx/sub.rs b/poulpy-hal/src/reference/znx/sub.rs index 7cb4599..06193a6 100644 --- a/poulpy-hal/src/reference/znx/sub.rs +++ b/poulpy-hal/src/reference/znx/sub.rs @@ -11,7 +11,7 @@ pub fn znx_sub_ref(res: &mut [i64], a: &[i64], b: &[i64]) { } } -pub fn znx_sub_ab_inplace_ref(res: &mut [i64], a: &[i64]) { +pub fn znx_sub_inplace_ref(res: &mut [i64], a: &[i64]) { #[cfg(debug_assertions)] { assert_eq!(res.len(), a.len()); @@ -23,7 +23,7 @@ pub fn znx_sub_ab_inplace_ref(res: &mut [i64], a: &[i64]) { } } -pub fn znx_sub_ba_inplace_ref(res: &mut [i64], a: &[i64]) { +pub fn znx_sub_negate_inplace_ref(res: &mut [i64], a: &[i64]) { #[cfg(debug_assertions)] { assert_eq!(res.len(), a.len()); diff --git a/poulpy-hal/src/source.rs b/poulpy-hal/src/source.rs index 4852b31..71c38f6 100644 --- a/poulpy-hal/src/source.rs +++ b/poulpy-hal/src/source.rs @@ -48,6 +48,16 @@ impl Source { pub fn next_i64(&mut self) -> i64 { self.next_u64() as i64 } + + #[inline(always)] + pub fn next_i128(&mut self) -> i128 { + self.next_u128() as i128 + } + + #[inline(always)] + pub fn next_u128(&mut self) -> u128 { + (self.next_u64() as u128) << 64 | (self.next_u64() as u128) + } } impl RngCore for Source { diff --git a/poulpy-hal/src/test_suite/mod.rs b/poulpy-hal/src/test_suite/mod.rs index bbdaf05..f31c856 100644 --- a/poulpy-hal/src/test_suite/mod.rs +++ b/poulpy-hal/src/test_suite/mod.rs @@ -41,7 +41,7 @@ macro_rules! cross_backend_test_suite { backend_ref = $backend_ref:ty, backend_test = $backend_test:ty, size = $size:expr, - basek = $basek:expr, + base2k = $base2k:expr, tests = { $( $(#[$attr:meta])* $test_name:ident => $impl:path ),+ $(,)? } @@ -60,7 +60,7 @@ macro_rules! 
cross_backend_test_suite { $(#[$attr])* #[test] fn $test_name() { - ($impl)($basek, &*MODULE_REF, &*MODULE_TEST); + ($impl)($base2k, &*MODULE_REF, &*MODULE_TEST); } )+ } diff --git a/poulpy-hal/src/test_suite/serialization.rs b/poulpy-hal/src/test_suite/serialization.rs index f2f9ee7..dacd656 100644 --- a/poulpy-hal/src/test_suite/serialization.rs +++ b/poulpy-hal/src/test_suite/serialization.rs @@ -1,7 +1,7 @@ use std::fmt::Debug; use crate::{ - layouts::{FillUniform, ReaderFrom, Reset, WriterTo}, + layouts::{FillUniform, ReaderFrom, WriterTo}, source::Source, }; @@ -10,7 +10,7 @@ use crate::{ /// - `T` must implement I/O traits, cloning, and random filling. pub fn test_reader_writer_interface<T>(mut original: T) where - T: WriterTo + ReaderFrom + PartialEq + Eq + Debug + Clone + Reset + FillUniform, + T: WriterTo + ReaderFrom + PartialEq + Eq + Debug + Clone + FillUniform, { // Fill original with uniform random data let mut source = Source::new([0u8; 32]); @@ -20,9 +20,9 @@ where let mut buffer = Vec::new(); original.write_to(&mut buffer).expect("write_to failed"); - // Prepare receiver: same shape, but zeroed + // Prepare receiver: same shape, but randomized let mut receiver = original.clone(); - receiver.reset(); + receiver.fill_uniform(50, &mut source); // Deserialize from buffer let mut reader: &[u8] = &buffer; diff --git a/poulpy-hal/src/test_suite/svp.rs b/poulpy-hal/src/test_suite/svp.rs index ec3a649..e72dc5a 100644 --- a/poulpy-hal/src/test_suite/svp.rs +++ b/poulpy-hal/src/test_suite/svp.rs @@ -10,7 +10,7 @@ use crate::{ source::Source, }; -pub fn test_svp_apply_dft<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>) +pub fn test_svp_apply_dft<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>) where Module<BR>: SvpPrepare<BR> + SvpApplyDft<BR>
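// [Editor's note, not part of the patch] Alongside the basek -> base2k rename,
// this patch widens the normalization entry points: vec_znx_big_normalize (and
// vec_znx_normalize in vec_znx.rs below) now take a base for the destination
// and a separate base for the source, where a single basek previously covered
// both. Every call site in these hunks passes the same value twice; a call in
// the new shape looks like this (argument roles inferred from the call sites,
// hypothetical variable names):
//
//     module.vec_znx_big_normalize(
//         res_base2k,          // base of the normalized destination
//         &mut res, res_col,
//         src_base2k,          // base the big (unnormalized) source is read in
//         &src_big, src_col,
//         scratch.borrow(),
//     );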
@@ -40,7 +40,7 @@ where let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes()); let mut scalar: ScalarZnx> = ScalarZnx::alloc(n, cols); - scalar.fill_uniform(basek, &mut source); + scalar.fill_uniform(base2k, &mut source); let scalar_digest: u64 = scalar.digest_u64(); @@ -60,7 +60,7 @@ where for a_size in [1, 2, 3, 4] { // Create a random input VecZnx let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let a_digest: u64 = a.digest_u64(); @@ -91,17 +91,19 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - basek, + base2k, &mut res_ref, j, + base2k, &res_big_ref, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - basek, + base2k, &mut res_test, j, + base2k, &res_big_test, j, scratch_test.borrow(), @@ -113,7 +115,7 @@ where } } -pub fn test_svp_apply_dft_to_dft(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>) +pub fn test_svp_apply_dft_to_dft<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>) where Module<BR>: SvpPrepare<BR> + SvpApplyDftToDft<BR>
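// [Editor's note, not part of the patch] The normalization these tests exercise
// is built on the balanced digit split from znx/normalization.rs earlier in
// this diff: get_digit_i64 keeps the low base2k bits of a limb, sign-extended,
// and get_carry_i64 recovers the remainder, so every limb satisfies
// x == digit + (carry << base2k). A minimal self-contained check of that
// identity, using the same wrapping arithmetic as the reference code:
fn digit_carry_identity_holds(base2k: usize, x: i64) -> bool {
    let digit: i64 = (x << (u64::BITS - base2k as u32)) >> (u64::BITS - base2k as u32);
    let carry: i64 = x.wrapping_sub(digit) >> base2k;
    digit.wrapping_add(carry.wrapping_shl(base2k as u32)) == x
}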
@@ -145,7 +147,7 @@ where let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes()); let mut scalar: ScalarZnx> = ScalarZnx::alloc(n, cols); - scalar.fill_uniform(basek, &mut source); + scalar.fill_uniform(base2k, &mut source); let scalar_digest: u64 = scalar.digest_u64(); @@ -165,7 +167,7 @@ where for a_size in [3] { // Create a random input VecZnx let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let a_digest: u64 = a.digest_u64(); @@ -211,17 +213,19 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - basek, + base2k, &mut res_ref, j, + base2k, &res_big_ref, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - basek, + base2k, &mut res_test, j, + base2k, &res_big_test, j, scratch_test.borrow(), @@ -233,7 +237,7 @@ where } } -pub fn test_svp_apply_dft_to_dft_add(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>) +pub fn test_svp_apply_dft_to_dft_add<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>) where Module<BR>: SvpPrepare<BR> + SvpApplyDftToDftAdd<BR>
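// [Editor's note, not part of the patch] All of these cross-backend tests share
// the same scratch-space discipline, visible in the hunks above: the module
// reports the temporary byte count an operation needs (here
// vec_znx_big_normalize_tmp_bytes()), ScratchOwned::alloc(..) reserves it once
// up front, and each call then re-borrows it with scratch.borrow(), so the
// kernels themselves never allocate.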
@@ -265,7 +269,7 @@ where let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes()); let mut scalar: ScalarZnx> = ScalarZnx::alloc(n, cols); - scalar.fill_uniform(basek, &mut source); + scalar.fill_uniform(base2k, &mut source); let scalar_digest: u64 = scalar.digest_u64(); @@ -285,7 +289,7 @@ where for a_size in [1, 2, 3, 4] { // Create a random input VecZnx let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let a_digest: u64 = a.digest_u64(); @@ -302,7 +306,7 @@ where for res_size in [1, 2, 3, 4] { let mut res: VecZnx> = VecZnx::alloc(n, cols, res_size); - res.fill_uniform(basek, &mut source); + res.fill_uniform(base2k, &mut source); let mut res_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols, res_size); let mut res_dft_test: VecZnxDft, BT> = module_test.vec_znx_dft_alloc(cols, res_size); @@ -336,17 +340,19 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - basek, + base2k, &mut res_ref, j, + base2k, &res_big_ref, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - basek, + base2k, &mut res_test, j, + base2k, &res_big_test, j, scratch_test.borrow(), @@ -359,7 +365,7 @@ where } pub fn test_svp_apply_dft_to_dft_inplace( - basek: usize, + base2k: usize, module_ref: &Module
, module_test: &Module, ) where @@ -393,7 +399,7 @@ pub fn test_svp_apply_dft_to_dft_inplace( let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes()); let mut scalar: ScalarZnx> = ScalarZnx::alloc(n, cols); - scalar.fill_uniform(basek, &mut source); + scalar.fill_uniform(base2k, &mut source); let scalar_digest: u64 = scalar.digest_u64(); @@ -412,7 +418,7 @@ pub fn test_svp_apply_dft_to_dft_inplace( for res_size in [1, 2, 3, 4] { let mut res: VecZnx> = VecZnx::alloc(n, cols, res_size); - res.fill_uniform(basek, &mut source); + res.fill_uniform(base2k, &mut source); let res_digest: u64 = res.digest_u64(); let mut res_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols, res_size); @@ -442,17 +448,19 @@ pub fn test_svp_apply_dft_to_dft_inplace( for j in 0..cols { module_ref.vec_znx_big_normalize( - basek, + base2k, &mut res_ref, j, + base2k, &res_big_ref, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - basek, + base2k, &mut res_test, j, + base2k, &res_big_test, j, scratch_test.borrow(), diff --git a/poulpy-hal/src/test_suite/vec_znx.rs b/poulpy-hal/src/test_suite/vec_znx.rs index f5d180e..f1c4e86 100644 --- a/poulpy-hal/src/test_suite/vec_znx.rs +++ b/poulpy-hal/src/test_suite/vec_znx.rs @@ -8,38 +8,18 @@ use crate::{ VecZnxMergeRingsTmpBytes, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxMulXpMinusOneInplaceTmpBytes, VecZnxNegate, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, VecZnxRsh, VecZnxRshInplace, VecZnxSplitRing, VecZnxSplitRingTmpBytes, - VecZnxSub, VecZnxSubABInplace, VecZnxSubBAInplace, VecZnxSubScalar, VecZnxSubScalarInplace, VecZnxSwitchRing, + VecZnxSub, VecZnxSubInplace, VecZnxSubNegateInplace, VecZnxSubScalar, VecZnxSubScalarInplace, VecZnxSwitchRing, }, layouts::{Backend, DigestU64, FillUniform, Module, ScalarZnx, ScratchOwned, VecZnx, ZnxInfos, ZnxView, ZnxViewMut}, reference::znx::znx_copy_ref, source::Source, }; -pub fn test_vec_znx_encode_vec_i64_lo_norm() { +pub fn test_vec_znx_encode_vec_i64() { let n: usize = 32; - let basek: usize = 17; + let base2k: usize = 17; let size: usize = 5; - let k: usize = size * basek - 5; - let mut a: VecZnx> = VecZnx::alloc(n, 2, size); - let mut source: Source = Source::new([0u8; 32]); - let raw: &mut [i64] = a.raw_mut(); - raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); - (0..a.cols()).for_each(|col_i| { - let mut have: Vec = vec![i64::default(); n]; - have.iter_mut() - .for_each(|x| *x = (source.next_i64() << 56) >> 56); - a.encode_vec_i64(basek, col_i, k, &have, 10); - let mut want: Vec = vec![i64::default(); n]; - a.decode_vec_i64(basek, col_i, k, &mut want); - assert_eq!(have, want, "{:?} != {:?}", &have, &want); - }); -} - -pub fn test_vec_znx_encode_vec_i64_hi_norm() { - let n: usize = 32; - let basek: usize = 17; - let size: usize = 5; - for k in [1, basek / 2, size * basek - 5] { + for k in [1, base2k / 2, size * base2k - 5] { let mut a: VecZnx> = VecZnx::alloc(n, 2, size); let mut source = Source::new([0u8; 32]); let raw: &mut [i64] = a.raw_mut(); @@ -53,15 +33,15 @@ pub fn test_vec_znx_encode_vec_i64_hi_norm() { *x = source.next_i64(); } }); - a.encode_vec_i64(basek, col_i, k, &have, 63); + a.encode_vec_i64(base2k, col_i, k, &have); let mut want: Vec = vec![i64::default(); n]; - a.decode_vec_i64(basek, col_i, k, &mut want); + a.decode_vec_i64(base2k, col_i, k, &mut want); assert_eq!(have, want, "{:?} != {:?}", &have, &want); }) } 
} -pub fn test_vec_znx_add_scalar(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>) +pub fn test_vec_znx_add_scalar<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>) where Module<BR>
: VecZnxAddScalar, Module: VecZnxAddScalar, @@ -74,12 +54,12 @@ where let cols: usize = 2; let mut a: ScalarZnx> = ScalarZnx::alloc(n, cols); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let a_digest = a.digest_u64(); for a_size in [1, 2, 3, 4] { let mut b: VecZnx> = VecZnx::alloc(n, cols, a_size); - b.fill_uniform(basek, &mut source); + b.fill_uniform(base2k, &mut source); let b_digest: u64 = b.digest_u64(); for res_size in [1, 2, 3, 4] { @@ -87,8 +67,8 @@ where let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); // Set d to garbage - rest_ref.fill_uniform(basek, &mut source); - res_test.fill_uniform(basek, &mut source); + rest_ref.fill_uniform(base2k, &mut source); + res_test.fill_uniform(base2k, &mut source); // Reference for i in 0..cols { @@ -103,7 +83,7 @@ where } } -pub fn test_vec_znx_add_scalar_inplace(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>) +pub fn test_vec_znx_add_scalar_inplace<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>) where Module<BR>
: VecZnxAddScalarInplace, Module: VecZnxAddScalarInplace, @@ -116,14 +96,14 @@ where let cols: usize = 2; let mut b: ScalarZnx> = ScalarZnx::alloc(n, cols); - b.fill_uniform(basek, &mut source); + b.fill_uniform(base2k, &mut source); let b_digest: u64 = b.digest_u64(); for res_size in [1, 2, 3, 4] { let mut rest_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); - rest_ref.fill_uniform(basek, &mut source); + rest_ref.fill_uniform(base2k, &mut source); res_test.raw_mut().copy_from_slice(rest_ref.raw()); for i in 0..cols { @@ -135,7 +115,7 @@ where assert_eq!(rest_ref, res_test); } } -pub fn test_vec_znx_add(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>) +pub fn test_vec_znx_add<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>) where Module<BR>
: VecZnxAdd, Module: VecZnxAdd, @@ -148,13 +128,13 @@ where for a_size in [1, 2, 3, 4] { let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let a_digest: u64 = a.digest_u64(); for b_size in [1, 2, 3, 4] { let mut b: VecZnx> = VecZnx::alloc(n, cols, b_size); - b.fill_uniform(basek, &mut source); + b.fill_uniform(base2k, &mut source); let b_digest: u64 = b.digest_u64(); @@ -163,8 +143,8 @@ where let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); // Set d to garbage - res_ref.fill_uniform(basek, &mut source); - res_test.fill_uniform(basek, &mut source); + res_ref.fill_uniform(base2k, &mut source); + res_test.fill_uniform(base2k, &mut source); // Reference for i in 0..cols { @@ -181,7 +161,7 @@ where } } -pub fn test_vec_znx_add_inplace(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>) +pub fn test_vec_znx_add_inplace<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>) where Module<BR>
: VecZnxAddInplace, Module: VecZnxAddInplace, @@ -194,14 +174,14 @@ where for a_size in [1, 2, 3, 4] { let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let a_digest: u64 = a.digest_u64(); for res_size in [1, 2, 3, 4] { let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); - res_ref.fill_uniform(basek, &mut source); + res_ref.fill_uniform(base2k, &mut source); res_test.raw_mut().copy_from_slice(res_ref.raw()); for i in 0..cols { @@ -215,7 +195,7 @@ where } } -pub fn test_vec_znx_automorphism(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>) +pub fn test_vec_znx_automorphism<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>) where Module<BR>
: VecZnxAutomorphism, Module: VecZnxAutomorphism, @@ -228,7 +208,7 @@ where for a_size in [1, 2, 3, 4] { let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let a_digest: u64 = a.digest_u64(); for res_size in [1, 2, 3, 4] { @@ -261,7 +241,7 @@ where } pub fn test_vec_znx_automorphism_inplace( - basek: usize, + base2k: usize, module_ref: &Module
, module_test: &Module, ) where @@ -284,7 +264,7 @@ pub fn test_vec_znx_automorphism_inplace( let mut res_test: VecZnx> = VecZnx::alloc(n, cols, size); // Fill a with random i64 - res_ref.fill_uniform(basek, &mut source); + res_ref.fill_uniform(base2k, &mut source); znx_copy_ref(res_test.raw_mut(), res_ref.raw()); let p: i64 = -7; @@ -309,7 +289,7 @@ pub fn test_vec_znx_automorphism_inplace( } } -pub fn test_vec_znx_copy(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>) +pub fn test_vec_znx_copy<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>) where Module<BR>
: VecZnxCopy, Module: VecZnxCopy, @@ -322,7 +302,7 @@ where for a_size in [1, 2, 3, 4] { let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let a_digest: u64 = a.digest_u64(); for res_size in [1, 2, 3, 4] { @@ -330,8 +310,8 @@ where let mut res_1: VecZnx> = VecZnx::alloc(n, cols, res_size); // Set d to garbage - res_0.fill_uniform(basek, &mut source); - res_1.fill_uniform(basek, &mut source); + res_0.fill_uniform(base2k, &mut source); + res_1.fill_uniform(base2k, &mut source); // Reference for i in 0..cols { @@ -345,7 +325,7 @@ where } } -pub fn test_vec_znx_merge_rings(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>) +pub fn test_vec_znx_merge_rings<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>) where Module<BR>: VecZnxMergeRings<BR> + ModuleNew<BR>
+ VecZnxMergeRingsTmpBytes, Module: VecZnxMergeRings + ModuleNew + VecZnxMergeRingsTmpBytes, @@ -367,7 +347,7 @@ where ]; a.iter_mut().for_each(|ai| { - ai.fill_uniform(basek, &mut source); + ai.fill_uniform(base2k, &mut source); }); let a_digests: [u64; 2] = [a[0].digest_u64(), a[1].digest_u64()]; @@ -376,8 +356,8 @@ where let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); - res_ref.fill_uniform(basek, &mut source); - res_test.fill_uniform(basek, &mut source); + res_ref.fill_uniform(base2k, &mut source); + res_test.fill_uniform(base2k, &mut source); for i in 0..cols { module_ref.vec_znx_merge_rings(&mut res_test, i, &a, i, scratch_ref.borrow()); @@ -390,7 +370,7 @@ where } } -pub fn test_vec_znx_mul_xp_minus_one(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>) +pub fn test_vec_znx_mul_xp_minus_one<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>) where Module<BR>
: VecZnxMulXpMinusOne, Module: VecZnxMulXpMinusOne, @@ -403,7 +383,7 @@ where for a_size in [1, 2, 3, 4] { let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let a_digest: u64 = a.digest_u64(); @@ -437,7 +417,7 @@ where } pub fn test_vec_znx_mul_xp_minus_one_inplace( - basek: usize, + base2k: usize, module_ref: &Module
, module_test: &Module, ) where @@ -460,7 +440,7 @@ pub fn test_vec_znx_mul_xp_minus_one_inplace( let mut res_test: VecZnx> = VecZnx::alloc(n, cols, size); // Fill a with random i64 - res_ref.fill_uniform(basek, &mut source); + res_ref.fill_uniform(base2k, &mut source); znx_copy_ref(res_test.raw_mut(), res_ref.raw()); let p: i64 = -7; @@ -483,7 +463,7 @@ pub fn test_vec_znx_mul_xp_minus_one_inplace( } } -pub fn test_vec_znx_negate(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>) +pub fn test_vec_znx_negate<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>) where Module<BR>
: VecZnxNegate, Module: VecZnxNegate, @@ -496,14 +476,14 @@ where for a_size in [1, 2, 3, 4] { let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let a_digest: u64 = a.digest_u64(); for res_size in [1, 2, 3, 4] { let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); - res_ref.fill_uniform(basek, &mut source); + res_ref.fill_uniform(base2k, &mut source); res_test.raw_mut().copy_from_slice(res_ref.raw()); for i in 0..cols { @@ -517,7 +497,7 @@ where } } -pub fn test_vec_znx_negate_inplace(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>) +pub fn test_vec_znx_negate_inplace<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>) where Module<BR>
: VecZnxNegateInplace, Module: VecZnxNegateInplace, @@ -532,7 +512,7 @@ where let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); - res_ref.fill_uniform(basek, &mut source); + res_ref.fill_uniform(base2k, &mut source); res_test.raw_mut().copy_from_slice(res_ref.raw()); for i in 0..cols { @@ -544,7 +524,7 @@ where } } -pub fn test_vec_znx_normalize(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>) +pub fn test_vec_znx_normalize<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>) where Module<BR>: VecZnxNormalize<BR> + VecZnxNormalizeTmpBytes, ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>
, @@ -562,7 +542,7 @@ where for a_size in [1, 2, 3, 4] { let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let a_digest: u64 = a.digest_u64(); for res_size in [1, 2, 3, 4] { @@ -570,13 +550,21 @@ where let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); // Set d to garbage - res_ref.fill_uniform(basek, &mut source); - res_test.fill_uniform(basek, &mut source); + res_ref.fill_uniform(base2k, &mut source); + res_test.fill_uniform(base2k, &mut source); // Reference for i in 0..cols { - module_ref.vec_znx_normalize(basek, &mut res_ref, i, &a, i, scratch_ref.borrow()); - module_test.vec_znx_normalize(basek, &mut res_test, i, &a, i, scratch_test.borrow()); + module_ref.vec_znx_normalize(base2k, &mut res_ref, i, base2k, &a, i, scratch_ref.borrow()); + module_test.vec_znx_normalize( + base2k, + &mut res_test, + i, + base2k, + &a, + i, + scratch_test.borrow(), + ); } assert_eq!(a.digest_u64(), a_digest); @@ -585,7 +573,7 @@ where } } -pub fn test_vec_znx_normalize_inplace(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>) +pub fn test_vec_znx_normalize_inplace<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>) where Module<BR>: VecZnxNormalizeInplace<BR> + VecZnxNormalizeTmpBytes, ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>
, @@ -605,20 +593,20 @@ where let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); - res_ref.fill_uniform(basek, &mut source); + res_ref.fill_uniform(base2k, &mut source); res_test.raw_mut().copy_from_slice(res_ref.raw()); // Reference for i in 0..cols { - module_ref.vec_znx_normalize_inplace(basek, &mut res_ref, i, scratch_ref.borrow()); - module_test.vec_znx_normalize_inplace(basek, &mut res_test, i, scratch_test.borrow()); + module_ref.vec_znx_normalize_inplace(base2k, &mut res_ref, i, scratch_ref.borrow()); + module_test.vec_znx_normalize_inplace(base2k, &mut res_test, i, scratch_test.borrow()); } assert_eq!(res_ref, res_test); } } -pub fn test_vec_znx_rotate(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>) +pub fn test_vec_znx_rotate<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>) where Module<BR>
: VecZnxRotate, Module: VecZnxRotate, @@ -631,7 +619,7 @@ where for a_size in [1, 2, 3, 4] { let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let a_digest: u64 = a.digest_u64(); for res_size in [1, 2, 3, 4] { @@ -663,7 +651,7 @@ where } } -pub fn test_vec_znx_rotate_inplace(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>) +pub fn test_vec_znx_rotate_inplace<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>) where Module<BR>: VecZnxRotateInplace<BR> + VecZnxRotateInplaceTmpBytes, ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>
, @@ -684,7 +672,7 @@ where let mut res_test: VecZnx> = VecZnx::alloc(n, cols, size); // Fill a with random i64 - res_ref.fill_uniform(basek, &mut source); + res_ref.fill_uniform(base2k, &mut source); znx_copy_ref(res_test.raw_mut(), res_ref.raw()); let p: i64 = -5; @@ -714,7 +702,7 @@ where Module: VecZnxFillUniform, { let n: usize = module.n(); - let basek: usize = 17; + let base2k: usize = 17; let size: usize = 5; let mut source: Source = Source::new([0u8; 32]); let cols: usize = 2; @@ -722,19 +710,17 @@ where let one_12_sqrt: f64 = 0.28867513459481287; (0..cols).for_each(|col_i| { let mut a: VecZnx<_> = VecZnx::alloc(n, cols, size); - module.vec_znx_fill_uniform(basek, &mut a, col_i, &mut source); + module.vec_znx_fill_uniform(base2k, &mut a, col_i, &mut source); (0..cols).for_each(|col_j| { if col_j != col_i { (0..size).for_each(|limb_i| { assert_eq!(a.at(col_j, limb_i), zero); }) } else { - let std: f64 = a.std(basek, col_i); + let std: f64 = a.std(base2k, col_i); assert!( (std - one_12_sqrt).abs() < 0.01, - "std={} ~!= {}", - std, - one_12_sqrt + "std={std} ~!= {one_12_sqrt}", ); } }) @@ -746,7 +732,7 @@ where Module: VecZnxFillNormal, { let n: usize = module.n(); - let basek: usize = 17; + let base2k: usize = 17; let k: usize = 2 * 17; let size: usize = 5; let sigma: f64 = 3.2; @@ -757,15 +743,15 @@ where let k_f64: f64 = (1u64 << k as u64) as f64; (0..cols).for_each(|col_i| { let mut a: VecZnx<_> = VecZnx::alloc(n, cols, size); - module.vec_znx_fill_normal(basek, &mut a, col_i, k, &mut source, sigma, bound); + module.vec_znx_fill_normal(base2k, &mut a, col_i, k, &mut source, sigma, bound); (0..cols).for_each(|col_j| { if col_j != col_i { (0..size).for_each(|limb_i| { assert_eq!(a.at(col_j, limb_i), zero); }) } else { - let std: f64 = a.std(basek, col_i) * k_f64; - assert!((std - sigma).abs() < 0.1, "std={} ~!= {}", std, sigma); + let std: f64 = a.std(base2k, col_i) * k_f64; + assert!((std - sigma).abs() < 0.1, "std={std} ~!= {sigma}"); } }) }); @@ -776,7 +762,7 @@ where Module: VecZnxFillNormal + VecZnxAddNormal, { let n: usize = module.n(); - let basek: usize = 17; + let base2k: usize = 17; let k: usize = 2 * 17; let size: usize = 5; let sigma: f64 = 3.2; @@ -788,19 +774,18 @@ where let sqrt2: f64 = SQRT_2; (0..cols).for_each(|col_i| { let mut a: VecZnx<_> = VecZnx::alloc(n, cols, size); - module.vec_znx_fill_normal(basek, &mut a, col_i, k, &mut source, sigma, bound); - module.vec_znx_add_normal(basek, &mut a, col_i, k, &mut source, sigma, bound); + module.vec_znx_fill_normal(base2k, &mut a, col_i, k, &mut source, sigma, bound); + module.vec_znx_add_normal(base2k, &mut a, col_i, k, &mut source, sigma, bound); (0..cols).for_each(|col_j| { if col_j != col_i { (0..size).for_each(|limb_i| { assert_eq!(a.at(col_j, limb_i), zero); }) } else { - let std: f64 = a.std(basek, col_i) * k_f64; + let std: f64 = a.std(base2k, col_i) * k_f64; assert!( (std - sigma * sqrt2).abs() < 0.1, - "std={} ~!= {}", - std, + "std={std} ~!= {}", sigma * sqrt2 ); } @@ -808,7 +793,7 @@ where }); } -pub fn test_vec_znx_lsh(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>) +pub fn test_vec_znx_lsh<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>) where Module<BR>: VecZnxLsh<BR> + VecZnxLshTmpBytes, ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>
, @@ -826,22 +811,22 @@ where for a_size in [1, 2, 3, 4] { let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let a_digest: u64 = a.digest_u64(); for res_size in [1, 2, 3, 4] { - for k in 0..res_size * basek { + for k in 0..res_size * base2k { let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); // Set d to garbage - res_ref.fill_uniform(basek, &mut source); - res_test.fill_uniform(basek, &mut source); + res_ref.fill_uniform(base2k, &mut source); + res_test.fill_uniform(base2k, &mut source); // Reference for i in 0..cols { - module_ref.vec_znx_lsh(basek, k, &mut res_ref, i, &a, i, scratch_ref.borrow()); - module_test.vec_znx_lsh(basek, k, &mut res_test, i, &a, i, scratch_test.borrow()); + module_ref.vec_znx_lsh(base2k, k, &mut res_ref, i, &a, i, scratch_ref.borrow()); + module_test.vec_znx_lsh(base2k, k, &mut res_test, i, &a, i, scratch_test.borrow()); } assert_eq!(a.digest_u64(), a_digest); @@ -851,7 +836,7 @@ where } } -pub fn test_vec_znx_lsh_inplace(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>) +pub fn test_vec_znx_lsh_inplace<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>) where Module<BR>: VecZnxLshInplace<BR> + VecZnxLshTmpBytes, ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>
, @@ -868,16 +853,16 @@ where let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_lsh_tmp_bytes()); for res_size in [1, 2, 3, 4] { - for k in 0..basek * res_size { + for k in 0..base2k * res_size { let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); - res_ref.fill_uniform(basek, &mut source); + res_ref.fill_uniform(base2k, &mut source); res_test.raw_mut().copy_from_slice(res_ref.raw()); for i in 0..cols { - module_ref.vec_znx_lsh_inplace(basek, k, &mut res_ref, i, scratch_ref.borrow()); - module_test.vec_znx_lsh_inplace(basek, k, &mut res_test, i, scratch_test.borrow()); + module_ref.vec_znx_lsh_inplace(base2k, k, &mut res_ref, i, scratch_ref.borrow()); + module_test.vec_znx_lsh_inplace(base2k, k, &mut res_test, i, scratch_test.borrow()); } assert_eq!(res_ref, res_test); @@ -885,7 +870,7 @@ where } } -pub fn test_vec_znx_rsh(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>) +pub fn test_vec_znx_rsh<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>) where Module<BR>: VecZnxRsh<BR> + VecZnxLshTmpBytes, ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>
, @@ -902,22 +887,22 @@ where for a_size in [1, 2, 3, 4] { let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let a_digest: u64 = a.digest_u64(); for res_size in [1, 2, 3, 4] { - for k in 0..res_size * basek { + for k in 0..res_size * base2k { let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); // Set d to garbage - res_ref.fill_uniform(basek, &mut source); - res_test.fill_uniform(basek, &mut source); + res_ref.fill_uniform(base2k, &mut source); + res_test.fill_uniform(base2k, &mut source); // Reference for i in 0..cols { - module_ref.vec_znx_rsh(basek, k, &mut res_ref, i, &a, i, scratch_ref.borrow()); - module_test.vec_znx_rsh(basek, k, &mut res_test, i, &a, i, scratch_test.borrow()); + module_ref.vec_znx_rsh(base2k, k, &mut res_ref, i, &a, i, scratch_ref.borrow()); + module_test.vec_znx_rsh(base2k, k, &mut res_test, i, &a, i, scratch_test.borrow()); } assert_eq!(a.digest_u64(), a_digest); @@ -927,7 +912,7 @@ where } } -pub fn test_vec_znx_rsh_inplace(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>) +pub fn test_vec_znx_rsh_inplace<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>) where Module<BR>: VecZnxRshInplace<BR> + VecZnxLshTmpBytes, ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>
, @@ -943,16 +928,16 @@ where let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_lsh_tmp_bytes()); for res_size in [1, 2, 3, 4] { - for k in 0..basek * res_size { + for k in 0..base2k * res_size { let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); - res_ref.fill_uniform(basek, &mut source); + res_ref.fill_uniform(base2k, &mut source); res_test.raw_mut().copy_from_slice(res_ref.raw()); for i in 0..cols { - module_ref.vec_znx_rsh_inplace(basek, k, &mut res_ref, i, scratch_ref.borrow()); - module_test.vec_znx_rsh_inplace(basek, k, &mut res_test, i, scratch_test.borrow()); + module_ref.vec_znx_rsh_inplace(base2k, k, &mut res_ref, i, scratch_ref.borrow()); + module_test.vec_znx_rsh_inplace(base2k, k, &mut res_test, i, scratch_test.borrow()); } assert_eq!(res_ref, res_test); @@ -960,7 +945,7 @@ where } } -pub fn test_vec_znx_split_ring(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>)
+pub fn test_vec_znx_split_ring<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
    Module<BR>: VecZnxSplitRing<BR> + ModuleNew<BR> + VecZnxSplitRingTmpBytes,
    Module<BT>: VecZnxSplitRing<BT> + ModuleNew<BT> + VecZnxSplitRingTmpBytes,
    ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
    ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>
, @@ -977,7 +962,7 @@ where for a_size in [1, 2, 3, 4] { let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let a_digest = a.digest_u64(); for res_size in [1, 2, 3, 4] { @@ -992,11 +977,11 @@ where ]; res_ref.iter_mut().for_each(|ri| { - ri.fill_uniform(basek, &mut source); + ri.fill_uniform(base2k, &mut source); }); res_test.iter_mut().for_each(|ri| { - ri.fill_uniform(basek, &mut source); + ri.fill_uniform(base2k, &mut source); }); for i in 0..cols { @@ -1013,7 +998,7 @@ where } } -pub fn test_vec_znx_sub_scalar(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>)
+pub fn test_vec_znx_sub_scalar<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
    Module<BR>
: VecZnxSubScalar, Module: VecZnxSubScalar, @@ -1025,12 +1010,12 @@ where let cols: usize = 2; let mut a: ScalarZnx> = ScalarZnx::alloc(n, cols); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let a_digest: u64 = a.digest_u64(); for b_size in [1, 2, 3, 4] { let mut b: VecZnx> = VecZnx::alloc(n, cols, b_size); - b.fill_uniform(basek, &mut source); + b.fill_uniform(base2k, &mut source); let b_digest: u64 = b.digest_u64(); for res_size in [1, 2, 3, 4] { @@ -1038,8 +1023,8 @@ where let mut res_1: VecZnx> = VecZnx::alloc(n, cols, res_size); // Set d to garbage - res_0.fill_uniform(basek, &mut source); - res_1.fill_uniform(basek, &mut source); + res_0.fill_uniform(base2k, &mut source); + res_1.fill_uniform(base2k, &mut source); // Reference for i in 0..cols { @@ -1054,7 +1039,7 @@ where } } -pub fn test_vec_znx_sub_scalar_inplace(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>)
+pub fn test_vec_znx_sub_scalar_inplace<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
    Module<BR>
: VecZnxSubScalarInplace, Module: VecZnxSubScalarInplace, @@ -1066,14 +1051,14 @@ where let cols: usize = 2; let mut a: ScalarZnx> = ScalarZnx::alloc(n, cols); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let a_digest: u64 = a.digest_u64(); for res_size in [1, 2, 3, 4] { let mut res_0: VecZnx> = VecZnx::alloc(n, cols, res_size); let mut res_1: VecZnx> = VecZnx::alloc(n, cols, res_size); - res_0.fill_uniform(basek, &mut source); + res_0.fill_uniform(base2k, &mut source); res_1.raw_mut().copy_from_slice(res_0.raw()); for i in 0..cols { @@ -1086,7 +1071,7 @@ where } } -pub fn test_vec_znx_sub(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>)
+pub fn test_vec_znx_sub<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
    Module<BR>
: VecZnxSub, Module: VecZnxSub, @@ -1099,12 +1084,12 @@ where for a_size in [1, 2, 3, 4] { let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let a_digest: u64 = a.digest_u64(); for b_size in [1, 2, 3, 4] { let mut b: VecZnx> = VecZnx::alloc(n, cols, b_size); - b.fill_uniform(basek, &mut source); + b.fill_uniform(base2k, &mut source); let b_digest: u64 = b.digest_u64(); for res_size in [1, 2, 3, 4] { @@ -1112,8 +1097,8 @@ where let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); // Set d to garbage - res_ref.fill_uniform(basek, &mut source); - res_test.fill_uniform(basek, &mut source); + res_ref.fill_uniform(base2k, &mut source); + res_test.fill_uniform(base2k, &mut source); // Reference for i in 0..cols { @@ -1130,10 +1115,10 @@ where } } -pub fn test_vec_znx_sub_ab_inplace(basek: usize, module_ref: &Module
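
Every function in this suite follows the same cross-backend equivalence pattern. Below is a minimal self-contained sketch of that pattern, assuming the `poulpy_hal` paths imported elsewhere in this patch and abbreviating the trait bounds; `check_sub_inplace` is a name invented here, not crate code.

```rust
// Illustrative sketch (not verbatim crate code): mirror one random input into
// both backends, apply the same column-wise op, and require exact agreement.
use poulpy_hal::{
    api::VecZnxSubInplace,
    layouts::{Backend, DataViewMut, FillUniform, Module, VecZnx},
    source::Source,
};

pub fn check_sub_inplace<BR: Backend, BT: Backend>(
    base2k: usize,
    module_ref: &Module<BR>,
    module_test: &Module<BT>,
) where
    Module<BR>: VecZnxSubInplace,
    Module<BT>: VecZnxSubInplace,
{
    let (n, cols, size) = (module_ref.n(), 2, 3);
    let mut source = Source::new([0u8; 32]);

    let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
    a.fill_uniform(base2k, &mut source);

    let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
    res_ref.fill_uniform(base2k, &mut source);
    let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
    res_test.raw_mut().copy_from_slice(res_ref.raw()); // identical starting limbs

    for i in 0..cols {
        module_ref.vec_znx_sub_inplace(&mut res_ref, i, &a, i);
        module_test.vec_znx_sub_inplace(&mut res_test, i, &a, i);
    }
    assert_eq!(res_ref, res_test); // both backends must agree bit-for-bit
}
```
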
, module_test: &Module) +pub fn test_vec_znx_sub_inplace(base2k: usize, module_ref: &Module
, module_test: &Module) where - Module
: VecZnxSubABInplace, - Module: VecZnxSubABInplace, + Module
: VecZnxSubInplace, + Module: VecZnxSubInplace, { assert_eq!(module_ref.n(), module_test.n()); let n: usize = module_ref.n(); @@ -1143,19 +1128,19 @@ where for a_size in [1, 2, 3, 4] { let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let a_digest: u64 = a.digest_u64(); for res_size in [1, 2, 3, 4] { let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); - res_ref.fill_uniform(basek, &mut source); + res_ref.fill_uniform(base2k, &mut source); res_test.raw_mut().copy_from_slice(res_ref.raw()); for i in 0..cols { - module_test.vec_znx_sub_ab_inplace(&mut res_ref, i, &a, i); - module_ref.vec_znx_sub_ab_inplace(&mut res_test, i, &a, i); + module_test.vec_znx_sub_inplace(&mut res_ref, i, &a, i); + module_ref.vec_znx_sub_inplace(&mut res_test, i, &a, i); } assert_eq!(a.digest_u64(), a_digest); @@ -1164,10 +1149,10 @@ where } } -pub fn test_vec_znx_sub_ba_inplace(basek: usize, module_ref: &Module
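
The dropped `AB`/`BA` suffixes encoded the operand orientation. Reading the renamed methods from their use in these tests (an inference from the naming, not spelled out in the patch), with `module`, `res`, `a`, and `col` standing for the bindings above:

```rust
// res = res - a   (formerly vec_znx_sub_ab_inplace)
module.vec_znx_sub_inplace(&mut res, col, &a, col);
// res = a - res   (formerly vec_znx_sub_ba_inplace; "negate" because the
//                  receiver's sign flips)
module.vec_znx_sub_negate_inplace(&mut res, col, &a, col);
```
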
, module_test: &Module) +pub fn test_vec_znx_sub_negate_inplace(base2k: usize, module_ref: &Module
, module_test: &Module) where - Module
: VecZnxSubBAInplace, - Module: VecZnxSubBAInplace, + Module
: VecZnxSubNegateInplace, + Module: VecZnxSubNegateInplace, { assert_eq!(module_ref.n(), module_test.n()); let n: usize = module_ref.n(); @@ -1177,19 +1162,19 @@ where for a_size in [1, 2, 3, 4] { let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let a_digest: u64 = a.digest_u64(); for res_size in [1, 2, 3, 4] { let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); - res_ref.fill_uniform(basek, &mut source); + res_ref.fill_uniform(base2k, &mut source); res_test.raw_mut().copy_from_slice(res_ref.raw()); for i in 0..cols { - module_test.vec_znx_sub_ba_inplace(&mut res_ref, i, &a, i); - module_ref.vec_znx_sub_ba_inplace(&mut res_test, i, &a, i); + module_test.vec_znx_sub_negate_inplace(&mut res_ref, i, &a, i); + module_ref.vec_znx_sub_negate_inplace(&mut res_test, i, &a, i); } assert_eq!(a.digest_u64(), a_digest); @@ -1198,7 +1183,7 @@ where } } -pub fn test_vec_znx_switch_ring(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>)
+pub fn test_vec_znx_switch_ring<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
    Module<BR>
: VecZnxSwitchRing, Module: VecZnxSwitchRing, @@ -1213,7 +1198,7 @@ where let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); // Fill a with random i64 - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let a_digest: u64 = a.digest_u64(); for res_size in [1, 2, 3, 4] { @@ -1221,8 +1206,8 @@ where let mut res_ref: VecZnx> = VecZnx::alloc(n << 1, cols, res_size); let mut res_test: VecZnx> = VecZnx::alloc(n << 1, cols, res_size); - res_ref.fill_uniform(basek, &mut source); - res_test.fill_uniform(basek, &mut source); + res_ref.fill_uniform(base2k, &mut source); + res_test.fill_uniform(base2k, &mut source); // Normalize on c for i in 0..cols { @@ -1238,8 +1223,8 @@ where let mut res_ref: VecZnx> = VecZnx::alloc(n >> 1, cols, res_size); let mut res_test: VecZnx> = VecZnx::alloc(n >> 1, cols, res_size); - res_ref.fill_uniform(basek, &mut source); - res_test.fill_uniform(basek, &mut source); + res_ref.fill_uniform(base2k, &mut source); + res_test.fill_uniform(base2k, &mut source); // Normalize on c for i in 0..cols { diff --git a/poulpy-hal/src/test_suite/vec_znx_big.rs b/poulpy-hal/src/test_suite/vec_znx_big.rs index d888403..2de43ad 100644 --- a/poulpy-hal/src/test_suite/vec_znx_big.rs +++ b/poulpy-hal/src/test_suite/vec_znx_big.rs @@ -5,14 +5,14 @@ use crate::{ ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxBigAdd, VecZnxBigAddInplace, VecZnxBigAddSmall, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAutomorphism, VecZnxBigAutomorphismInplace, VecZnxBigAutomorphismInplaceTmpBytes, VecZnxBigFromSmall, VecZnxBigNegate, VecZnxBigNegateInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, - VecZnxBigSub, VecZnxBigSubABInplace, VecZnxBigSubBAInplace, VecZnxBigSubSmallA, VecZnxBigSubSmallAInplace, - VecZnxBigSubSmallB, VecZnxBigSubSmallBInplace, + VecZnxBigSub, VecZnxBigSubInplace, VecZnxBigSubNegateInplace, VecZnxBigSubSmallA, VecZnxBigSubSmallB, + VecZnxBigSubSmallInplace, VecZnxBigSubSmallNegateInplace, }, layouts::{Backend, DataViewMut, DigestU64, FillUniform, Module, ScratchOwned, VecZnx, VecZnxBig}, source::Source, }; -pub fn test_vec_znx_big_add(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>)
+pub fn test_vec_znx_big_add<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
    Module<BR>: VecZnxBigAdd<BR>
        + VecZnxBigAlloc<BR>
        + VecZnxBigFromSmall<BR>
        + VecZnxBigNormalize<BR>
+ VecZnxBigNormalizeTmpBytes, @@ -32,7 +32,7 @@ where for a_size in [1, 2, 3, 4] { let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let a_digest = a.digest_u64(); let mut a_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, a_size); @@ -50,7 +50,7 @@ where for b_size in [1, 2, 3, 4] { let mut b: VecZnx> = VecZnx::alloc(n, cols, b_size); - b.fill_uniform(basek, &mut source); + b.fill_uniform(base2k, &mut source); let b_digest = b.digest_u64(); let mut b_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, b_size); @@ -93,17 +93,19 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - basek, + base2k, &mut res_small_ref, j, + base2k, &res_big_ref, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - basek, + base2k, &mut res_small_test, j, + base2k, &res_big_test, j, scratch_test.borrow(), @@ -119,7 +121,7 @@ where } } -pub fn test_vec_znx_big_add_inplace(basek: usize, module_ref: &Module
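
The second `base2k` threaded through every `vec_znx_big_normalize` call is the signature change this suite is adapting to: the base of the normalized output and the base at which the big input is interpreted are now passed separately. Schematically (`res_base2k`/`src_base2k` are names invented here for clarity; these tests pass the same value for both):

```rust
module.vec_znx_big_normalize(
    res_base2k,         // limb base of the normalized VecZnx output
    &mut res_small, j,  // output and its column
    src_base2k,         // limb base used to read the VecZnxBig input (new parameter)
    &res_big, j,        // input and its column
    scratch.borrow(),
);
```
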
, module_test: &Module) +pub fn test_vec_znx_big_add_inplace(base2k: usize, module_ref: &Module
, module_test: &Module) where Module
: VecZnxBigAddInplace
+ VecZnxBigAlloc
@@ -145,7 +147,7 @@ where for a_size in [1, 2, 3, 4] { let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let mut a_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, a_size); let mut a_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, a_size); @@ -160,7 +162,7 @@ where for res_size in [1, 2, 3, 4] { let mut res: VecZnx> = VecZnx::alloc(n, cols, res_size); - res.fill_uniform(basek, &mut source); + res.fill_uniform(base2k, &mut source); let mut res_big_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, res_size); let mut res_big_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, res_size); @@ -186,17 +188,19 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - basek, + base2k, &mut res_small_ref, j, + base2k, &res_big_ref, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - basek, + base2k, &mut res_small_test, j, + base2k, &res_big_test, j, scratch_test.borrow(), @@ -211,7 +215,7 @@ where } } -pub fn test_vec_znx_big_add_small(basek: usize, module_ref: &Module
, module_test: &Module) +pub fn test_vec_znx_big_add_small(base2k: usize, module_ref: &Module
, module_test: &Module) where Module
: VecZnxBigAddSmall
+ VecZnxBigAlloc
+ VecZnxBigFromSmall
+ VecZnxBigNormalize
+ VecZnxBigNormalizeTmpBytes, @@ -231,7 +235,7 @@ where for a_size in [1, 2, 3, 4] { let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let mut a_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, a_size); let mut a_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, a_size); @@ -246,7 +250,7 @@ where for b_size in [1, 2, 3, 4] { let mut b: VecZnx> = VecZnx::alloc(n, cols, b_size); - b.fill_uniform(basek, &mut source); + b.fill_uniform(base2k, &mut source); let b_digest: u64 = b.digest_u64(); for res_size in [1, 2, 3, 4] { @@ -275,17 +279,19 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - basek, + base2k, &mut res_small_ref, j, + base2k, &res_big_ref, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - basek, + base2k, &mut res_small_test, j, + base2k, &res_big_test, j, scratch_test.borrow(), @@ -302,7 +308,7 @@ where } pub fn test_vec_znx_big_add_small_inplace( - basek: usize, + base2k: usize, module_ref: &Module
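
A detail worth noting in these hunks: read-only operands are fingerprinted with `digest_u64()` before the call and re-checked after, so a backend that clobbers its inputs fails the test even when the output happens to match. The guard, extracted:

```rust
let b_digest: u64 = b.digest_u64();   // fingerprint the read-only operand
// ... run the operation on both backends ...
assert_eq!(b.digest_u64(), b_digest); // the operand must come back untouched
```
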
, module_test: &Module, ) where @@ -330,13 +336,13 @@ pub fn test_vec_znx_big_add_small_inplace( for a_size in [1, 2, 3, 4] { let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let a_digest: u64 = a.digest_u64(); for res_size in [1, 2, 3, 4] { let mut res: VecZnx> = VecZnx::alloc(n, cols, res_size); - res.fill_uniform(basek, &mut source); + res.fill_uniform(base2k, &mut source); let mut res_big_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, res_size); let mut res_big_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, res_size); @@ -361,17 +367,19 @@ pub fn test_vec_znx_big_add_small_inplace( for j in 0..cols { module_ref.vec_znx_big_normalize( - basek, + base2k, &mut res_small_ref, j, + base2k, &res_big_ref, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - basek, + base2k, &mut res_small_test, j, + base2k, &res_big_test, j, scratch_test.borrow(), @@ -386,7 +394,7 @@ pub fn test_vec_znx_big_add_small_inplace( } } -pub fn test_vec_znx_big_automorphism(basek: usize, module_ref: &Module
, module_test: &Module) +pub fn test_vec_znx_big_automorphism(base2k: usize, module_ref: &Module
, module_test: &Module) where Module
: VecZnxBigAutomorphism
+ VecZnxBigAlloc
@@ -412,7 +420,7 @@ where for a_size in [1, 2, 3, 4] { let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let mut a_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, a_size); let mut a_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, a_size); @@ -451,17 +459,19 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - basek, + base2k, &mut res_small_ref, j, + base2k, &res_big_ref, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - basek, + base2k, &mut res_small_test, j, + base2k, &res_big_test, j, scratch_test.borrow(), @@ -478,7 +488,7 @@ where } pub fn test_vec_znx_big_automorphism_inplace( - basek: usize, + base2k: usize, module_ref: &Module
, module_test: &Module, ) where @@ -512,7 +522,7 @@ pub fn test_vec_znx_big_automorphism_inplace( for res_size in [1, 2, 3, 4] { let mut res: VecZnx> = VecZnx::alloc(n, cols, res_size); - res.fill_uniform(basek, &mut source); + res.fill_uniform(base2k, &mut source); let mut res_big_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, res_size); let mut res_big_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, res_size); @@ -536,17 +546,19 @@ pub fn test_vec_znx_big_automorphism_inplace( for j in 0..cols { module_ref.vec_znx_big_normalize( - basek, + base2k, &mut res_small_ref, j, + base2k, &res_big_ref, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - basek, + base2k, &mut res_small_test, j, + base2k, &res_big_test, j, scratch_test.borrow(), @@ -561,7 +573,7 @@ pub fn test_vec_znx_big_automorphism_inplace( } } -pub fn test_vec_znx_big_negate(basek: usize, module_ref: &Module
, module_test: &Module) +pub fn test_vec_znx_big_negate(base2k: usize, module_ref: &Module
, module_test: &Module) where Module
: VecZnxBigNegate
+ VecZnxBigAlloc
+ VecZnxBigFromSmall
+ VecZnxBigNormalize
+ VecZnxBigNormalizeTmpBytes, @@ -581,7 +593,7 @@ where for a_size in [1, 2, 3, 4] { let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let mut a_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, a_size); let mut a_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, a_size); @@ -619,17 +631,19 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - basek, + base2k, &mut res_small_ref, j, + base2k, &res_big_ref, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - basek, + base2k, &mut res_small_test, j, + base2k, &res_big_test, j, scratch_test.borrow(), @@ -644,7 +658,7 @@ where } } -pub fn test_vec_znx_big_negate_inplace(basek: usize, module_ref: &Module
, module_test: &Module) +pub fn test_vec_znx_big_negate_inplace(base2k: usize, module_ref: &Module
, module_test: &Module) where Module
: VecZnxBigNegateInplace
+ VecZnxBigAlloc
@@ -672,7 +686,7 @@ where for res_size in [1, 2, 3, 4] { let mut res: VecZnx> = VecZnx::alloc(n, cols, res_size); - res.fill_uniform(basek, &mut source); + res.fill_uniform(base2k, &mut source); let mut res_big_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, res_size); let mut res_big_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, res_size); @@ -695,17 +709,19 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - basek, + base2k, &mut res_small_ref, j, + base2k, &res_big_ref, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - basek, + base2k, &mut res_small_test, j, + base2k, &res_big_test, j, scratch_test.borrow(), @@ -719,7 +735,7 @@ where } } -pub fn test_vec_znx_big_normalize(basek: usize, module_ref: &Module
, module_test: &Module) +pub fn test_vec_znx_big_normalize(base2k: usize, module_ref: &Module
, module_test: &Module) where Module
: VecZnxBigAlloc
+ VecZnxBigFromSmall
@@ -772,8 +788,24 @@ where // Reference for j in 0..cols { - module_ref.vec_znx_big_normalize(basek, &mut res_ref, j, &a_ref, j, scratch_ref.borrow()); - module_test.vec_znx_big_normalize(basek, &mut res_test, j, &a_test, j, scratch_test.borrow()); + module_ref.vec_znx_big_normalize( + base2k, + &mut res_ref, + j, + base2k, + &a_ref, + j, + scratch_ref.borrow(), + ); + module_test.vec_znx_big_normalize( + base2k, + &mut res_test, + j, + base2k, + &a_test, + j, + scratch_test.borrow(), + ); } assert_eq!(a_ref.digest_u64(), a_ref_digest); @@ -784,7 +816,7 @@ where } } -pub fn test_vec_znx_big_sub(basek: usize, module_ref: &Module
, module_test: &Module) +pub fn test_vec_znx_big_sub(base2k: usize, module_ref: &Module
, module_test: &Module) where Module
: VecZnxBigSub
+ VecZnxBigAlloc
+ VecZnxBigFromSmall
+ VecZnxBigNormalize
+ VecZnxBigNormalizeTmpBytes, @@ -804,7 +836,7 @@ where for a_size in [1, 2, 3, 4] { let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let mut a_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, a_size); let mut a_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, a_size); @@ -819,7 +851,7 @@ where for b_size in [1, 2, 3, 4] { let mut b: VecZnx> = VecZnx::alloc(n, cols, b_size); - b.fill_uniform(basek, &mut source); + b.fill_uniform(base2k, &mut source); let mut b_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, b_size); let mut b_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, b_size); @@ -859,17 +891,19 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - basek, + base2k, &mut res_small_ref, j, + base2k, &res_big_ref, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - basek, + base2k, &mut res_small_test, j, + base2k, &res_big_test, j, scratch_test.borrow(), @@ -885,14 +919,14 @@ where } } -pub fn test_vec_znx_big_sub_ab_inplace(basek: usize, module_ref: &Module
, module_test: &Module) +pub fn test_vec_znx_big_sub_inplace(base2k: usize, module_ref: &Module
, module_test: &Module) where - Module
: VecZnxBigSubABInplace
+ Module
: VecZnxBigSubInplace
+ VecZnxBigAlloc
+ VecZnxBigFromSmall
+ VecZnxBigNormalize
+ VecZnxBigNormalizeTmpBytes, - Module: VecZnxBigSubABInplace + Module: VecZnxBigSubInplace + VecZnxBigAlloc + VecZnxBigFromSmall + VecZnxBigNormalize @@ -911,7 +945,7 @@ where for a_size in [1, 2, 3, 4] { let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let mut a_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, a_size); let mut a_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, a_size); @@ -926,7 +960,7 @@ where for res_size in [1, 2, 3, 4] { let mut res: VecZnx> = VecZnx::alloc(n, cols, res_size); - res.fill_uniform(basek, &mut source); + res.fill_uniform(base2k, &mut source); let mut res_big_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, res_size); let mut res_big_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, res_size); @@ -937,8 +971,8 @@ where } for i in 0..cols { - module_ref.vec_znx_big_sub_ab_inplace(&mut res_big_ref, i, &a_ref, i); - module_test.vec_znx_big_sub_ab_inplace(&mut res_big_test, i, &a_test, i); + module_ref.vec_znx_big_sub_inplace(&mut res_big_ref, i, &a_ref, i); + module_test.vec_znx_big_sub_inplace(&mut res_big_test, i, &a_test, i); } assert_eq!(a_ref.digest_u64(), a_ref_digest); @@ -952,17 +986,19 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - basek, + base2k, &mut res_small_ref, j, + base2k, &res_big_ref, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - basek, + base2k, &mut res_small_test, j, + base2k, &res_big_test, j, scratch_test.borrow(), @@ -977,14 +1013,17 @@ where } } -pub fn test_vec_znx_big_sub_ba_inplace(basek: usize, module_ref: &Module
, module_test: &Module) -where - Module
: VecZnxBigSubBAInplace
+pub fn test_vec_znx_big_sub_negate_inplace( + base2k: usize, + module_ref: &Module
, + module_test: &Module, +) where + Module
: VecZnxBigSubNegateInplace
+ VecZnxBigAlloc
+ VecZnxBigFromSmall
+ VecZnxBigNormalize
+ VecZnxBigNormalizeTmpBytes, - Module: VecZnxBigSubBAInplace + Module: VecZnxBigSubNegateInplace + VecZnxBigAlloc + VecZnxBigFromSmall + VecZnxBigNormalize @@ -1003,7 +1042,7 @@ where for a_size in [1, 2, 3, 4] { let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let mut a_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, a_size); let mut a_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, a_size); @@ -1018,7 +1057,7 @@ where for res_size in [1, 2, 3, 4] { let mut res: VecZnx> = VecZnx::alloc(n, cols, res_size); - res.fill_uniform(basek, &mut source); + res.fill_uniform(base2k, &mut source); let mut res_big_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, res_size); let mut res_big_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, res_size); @@ -1029,8 +1068,8 @@ where } for i in 0..cols { - module_ref.vec_znx_big_sub_ba_inplace(&mut res_big_ref, i, &a_ref, i); - module_test.vec_znx_big_sub_ba_inplace(&mut res_big_test, i, &a_test, i); + module_ref.vec_znx_big_sub_negate_inplace(&mut res_big_ref, i, &a_ref, i); + module_test.vec_znx_big_sub_negate_inplace(&mut res_big_test, i, &a_test, i); } assert_eq!(a_ref.digest_u64(), a_ref_digest); @@ -1044,17 +1083,19 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - basek, + base2k, &mut res_small_ref, j, + base2k, &res_big_ref, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - basek, + base2k, &mut res_small_test, j, + base2k, &res_big_test, j, scratch_test.borrow(), @@ -1069,7 +1110,7 @@ where } } -pub fn test_vec_znx_big_sub_small_a(basek: usize, module_ref: &Module
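
Note that the `VecZnxBig` results are never compared directly, presumably because their limb representation is backend-specific; both sides are first folded back to a canonical `VecZnx` (with the doubled `base2k`, as above) and equality is asserted there:

```rust
module_ref.vec_znx_big_normalize(base2k, &mut res_small_ref, j, base2k, &res_big_ref, j, scratch_ref.borrow());
module_test.vec_znx_big_normalize(base2k, &mut res_small_test, j, base2k, &res_big_test, j, scratch_test.borrow());
assert_eq!(res_small_ref, res_small_test); // compared only after normalization
```
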
, module_test: &Module) +pub fn test_vec_znx_big_sub_small_a(base2k: usize, module_ref: &Module
, module_test: &Module) where Module
: VecZnxBigSubSmallA
+ VecZnxBigAlloc
@@ -1095,7 +1136,7 @@ where for a_size in [1, 2, 3, 4] { let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let mut a_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, a_size); let mut a_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, a_size); @@ -1110,7 +1151,7 @@ where for b_size in [1, 2, 3, 4] { let mut b: VecZnx> = VecZnx::alloc(n, cols, b_size); - b.fill_uniform(basek, &mut source); + b.fill_uniform(base2k, &mut source); let b_digest: u64 = b.digest_u64(); for res_size in [1, 2, 3, 4] { @@ -1139,17 +1180,19 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - basek, + base2k, &mut res_small_ref, j, + base2k, &res_big_ref, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - basek, + base2k, &mut res_small_test, j, + base2k, &res_big_test, j, scratch_test.borrow(), @@ -1165,7 +1208,7 @@ where } } -pub fn test_vec_znx_big_sub_small_b(basek: usize, module_ref: &Module
, module_test: &Module) +pub fn test_vec_znx_big_sub_small_b(base2k: usize, module_ref: &Module
, module_test: &Module) where Module
: VecZnxBigSubSmallB
+ VecZnxBigAlloc
@@ -1191,7 +1234,7 @@ where for a_size in [1, 2, 3, 4] { let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let mut a_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, a_size); let mut a_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, a_size); @@ -1206,7 +1249,7 @@ where for b_size in [1, 2, 3, 4] { let mut b: VecZnx> = VecZnx::alloc(n, cols, b_size); - b.fill_uniform(basek, &mut source); + b.fill_uniform(base2k, &mut source); let b_digest: u64 = b.digest_u64(); for res_size in [1, 2, 3, 4] { @@ -1235,17 +1278,19 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - basek, + base2k, &mut res_small_ref, j, + base2k, &res_big_ref, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - basek, + base2k, &mut res_small_test, j, + base2k, &res_big_test, j, scratch_test.borrow(), @@ -1262,16 +1307,16 @@ where } pub fn test_vec_znx_big_sub_small_a_inplace( - basek: usize, + base2k: usize, module_ref: &Module
, module_test: &Module, ) where - Module
: VecZnxBigSubSmallAInplace
+ Module
: VecZnxBigSubSmallInplace
+ VecZnxBigAlloc
+ VecZnxBigFromSmall
+ VecZnxBigNormalize
+ VecZnxBigNormalizeTmpBytes, - Module: VecZnxBigSubSmallAInplace + Module: VecZnxBigSubSmallInplace + VecZnxBigAlloc + VecZnxBigFromSmall + VecZnxBigNormalize @@ -1290,13 +1335,13 @@ pub fn test_vec_znx_big_sub_small_a_inplace( for a_size in [1, 2, 3, 4] { let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let a_digest: u64 = a.digest_u64(); for res_size in [1, 2, 3, 4] { let mut res: VecZnx> = VecZnx::alloc(n, cols, res_size); - res.fill_uniform(basek, &mut source); + res.fill_uniform(base2k, &mut source); let mut res_big_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, res_size); let mut res_big_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, res_size); @@ -1307,8 +1352,8 @@ pub fn test_vec_znx_big_sub_small_a_inplace( } for i in 0..cols { - module_ref.vec_znx_big_sub_small_a_inplace(&mut res_big_ref, i, &a, i); - module_test.vec_znx_big_sub_small_a_inplace(&mut res_big_test, i, &a, i); + module_ref.vec_znx_big_sub_small_inplace(&mut res_big_ref, i, &a, i); + module_test.vec_znx_big_sub_small_inplace(&mut res_big_test, i, &a, i); } assert_eq!(a.digest_u64(), a_digest); @@ -1321,17 +1366,19 @@ pub fn test_vec_znx_big_sub_small_a_inplace( for j in 0..cols { module_ref.vec_znx_big_normalize( - basek, + base2k, &mut res_small_ref, j, + base2k, &res_big_ref, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - basek, + base2k, &mut res_small_test, j, + base2k, &res_big_test, j, scratch_test.borrow(), @@ -1347,16 +1394,16 @@ pub fn test_vec_znx_big_sub_small_a_inplace( } pub fn test_vec_znx_big_sub_small_b_inplace( - basek: usize, + base2k: usize, module_ref: &Module
, module_test: &Module, ) where - Module
: VecZnxBigSubSmallBInplace
+ Module
: VecZnxBigSubSmallNegateInplace
+ VecZnxBigAlloc
+ VecZnxBigFromSmall
+ VecZnxBigNormalize
+ VecZnxBigNormalizeTmpBytes, - Module: VecZnxBigSubSmallBInplace + Module: VecZnxBigSubSmallNegateInplace + VecZnxBigAlloc + VecZnxBigFromSmall + VecZnxBigNormalize @@ -1375,13 +1422,13 @@ pub fn test_vec_znx_big_sub_small_b_inplace( for a_size in [1, 2, 3, 4] { let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let a_digest: u64 = a.digest_u64(); for res_size in [1, 2, 3, 4] { let mut res: VecZnx> = VecZnx::alloc(n, cols, res_size); - res.fill_uniform(basek, &mut source); + res.fill_uniform(base2k, &mut source); let mut res_big_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, res_size); let mut res_big_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, res_size); @@ -1392,8 +1439,8 @@ pub fn test_vec_znx_big_sub_small_b_inplace( } for i in 0..cols { - module_ref.vec_znx_big_sub_small_b_inplace(&mut res_big_ref, i, &a, i); - module_test.vec_znx_big_sub_small_b_inplace(&mut res_big_test, i, &a, i); + module_ref.vec_znx_big_sub_small_negate_inplace(&mut res_big_ref, i, &a, i); + module_test.vec_znx_big_sub_small_negate_inplace(&mut res_big_test, i, &a, i); } assert_eq!(a.digest_u64(), a_digest); @@ -1406,17 +1453,19 @@ pub fn test_vec_znx_big_sub_small_b_inplace( for j in 0..cols { module_ref.vec_znx_big_normalize( - basek, + base2k, &mut res_small_ref, j, + base2k, &res_big_ref, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - basek, + base2k, &mut res_small_test, j, + base2k, &res_big_test, j, scratch_test.borrow(), diff --git a/poulpy-hal/src/test_suite/vec_znx_dft.rs b/poulpy-hal/src/test_suite/vec_znx_dft.rs index 674f1d9..87e2992 100644 --- a/poulpy-hal/src/test_suite/vec_znx_dft.rs +++ b/poulpy-hal/src/test_suite/vec_znx_dft.rs @@ -3,14 +3,14 @@ use rand::RngCore; use crate::{ api::{ ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxBigAlloc, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAdd, - VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftApply, VecZnxDftCopy, VecZnxDftSub, VecZnxDftSubABInplace, - VecZnxDftSubBAInplace, VecZnxIdftApply, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxIdftApplyTmpBytes, + VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftApply, VecZnxDftCopy, VecZnxDftSub, VecZnxDftSubInplace, + VecZnxDftSubNegateInplace, VecZnxIdftApply, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxIdftApplyTmpBytes, }, layouts::{Backend, DataViewMut, DigestU64, FillUniform, Module, ScratchOwned, VecZnx, VecZnxBig, VecZnxDft}, source::Source, }; -pub fn test_vec_znx_dft_add(basek: usize, module_ref: &Module
, module_test: &Module) +pub fn test_vec_znx_dft_add(base2k: usize, module_ref: &Module
, module_test: &Module) where Module
: VecZnxDftAdd
+ VecZnxDftAlloc
@@ -38,7 +38,7 @@ where for a_size in [1, 2, 3, 4] { let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let a_digest = a.digest_u64(); let mut a_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols, a_size); @@ -56,7 +56,7 @@ where for b_size in [1, 2, 3, 4] { let mut b: VecZnx> = VecZnx::alloc(n, cols, b_size); - b.fill_uniform(basek, &mut source); + b.fill_uniform(base2k, &mut source); let b_digest: u64 = b.digest_u64(); let mut b_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols, b_size); @@ -102,17 +102,19 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - basek, + base2k, &mut res_small_ref, j, + base2k, &res_big_ref, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - basek, + base2k, &mut res_small_test, j, + base2k, &res_big_test, j, scratch_test.borrow(), @@ -128,7 +130,7 @@ where } } -pub fn test_vec_znx_dft_add_inplace(basek: usize, module_ref: &Module
, module_test: &Module) +pub fn test_vec_znx_dft_add_inplace(base2k: usize, module_ref: &Module
, module_test: &Module) where Module
: VecZnxDftAddInplace
+ VecZnxDftAlloc
@@ -155,7 +157,7 @@ where let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes()); for a_size in [1, 2, 3, 4] { let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let a_digest = a.digest_u64(); let mut a_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols, a_size); @@ -173,7 +175,7 @@ where for res_size in [1, 2, 3, 4] { let mut res: VecZnx> = VecZnx::alloc(n, cols, a_size); - res.fill_uniform(basek, &mut source); + res.fill_uniform(base2k, &mut source); let res_digest: u64 = res.digest_u64(); let mut res_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols, res_size); @@ -206,17 +208,19 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - basek, + base2k, &mut res_small_ref, j, + base2k, &res_big_ref, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - basek, + base2k, &mut res_small_test, j, + base2k, &res_big_test, j, scratch_test.borrow(), @@ -231,7 +235,7 @@ where } } -pub fn test_vec_znx_copy(basek: usize, module_ref: &Module
, module_test: &Module) +pub fn test_vec_znx_copy(base2k: usize, module_ref: &Module
, module_test: &Module) where Module
: VecZnxDftCopy
+ VecZnxDftAlloc
@@ -259,7 +263,7 @@ where for a_size in [1, 2, 6, 11] { let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let a_digest = a.digest_u64(); let mut a_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols, a_size); @@ -307,17 +311,19 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - basek, + base2k, &mut res_small_ref, j, + base2k, &res_big_ref, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - basek, + base2k, &mut res_small_test, j, + base2k, &res_big_test, j, scratch_test.borrow(), @@ -333,7 +339,7 @@ where } } -pub fn test_vec_znx_idft_apply(basek: usize, module_ref: &Module
, module_test: &Module) +pub fn test_vec_znx_idft_apply(base2k: usize, module_ref: &Module
, module_test: &Module) where Module
: VecZnxDftApply
+ VecZnxDftAlloc
@@ -361,7 +367,7 @@ where for a_size in [1, 2, 3, 4] { let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let a_digest: u64 = a.digest_u64(); for res_size in [1, 2, 3, 4] { @@ -406,17 +412,19 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - basek, + base2k, &mut res_small_ref, j, + base2k, &res_big_ref, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - basek, + base2k, &mut res_small_test, j, + base2k, &res_big_test, j, scratch_test.borrow(), @@ -432,7 +440,7 @@ where } } -pub fn test_vec_znx_idft_apply_tmpa(basek: usize, module_ref: &Module
, module_test: &Module) +pub fn test_vec_znx_idft_apply_tmpa(base2k: usize, module_ref: &Module
, module_test: &Module) where Module
: VecZnxDftApply
+ VecZnxDftAlloc
@@ -460,7 +468,7 @@ where for a_size in [1, 2, 3, 4] { let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let a_digest: u64 = a.digest_u64(); for res_size in [1, 2, 3, 4] { @@ -494,17 +502,19 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - basek, + base2k, &mut res_small_ref, j, + base2k, &res_big_ref, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - basek, + base2k, &mut res_small_test, j, + base2k, &res_big_test, j, scratch_test.borrow(), @@ -520,7 +530,7 @@ where } } -pub fn test_vec_znx_idft_apply_consume(basek: usize, module_ref: &Module
, module_test: &Module) +pub fn test_vec_znx_idft_apply_consume(base2k: usize, module_ref: &Module
, module_test: &Module) where Module
: VecZnxDftApply
+ VecZnxIdftApplyTmpBytes @@ -550,7 +560,7 @@ where for a_size in [1, 2, 3, 4] { let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let a_digest: u64 = a.digest_u64(); for res_size in [1, 2, 3, 4] { @@ -579,17 +589,19 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - basek, + base2k, &mut res_small_ref, j, + base2k, &res_big_ref, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - basek, + base2k, &mut res_small_test, j, + base2k, &res_big_test, j, scratch_test.borrow(), @@ -605,7 +617,7 @@ where } } -pub fn test_vec_znx_dft_sub(basek: usize, module_ref: &Module
, module_test: &Module) +pub fn test_vec_znx_dft_sub(base2k: usize, module_ref: &Module
, module_test: &Module) where Module
: VecZnxDftSub
+ VecZnxDftAlloc
@@ -633,7 +645,7 @@ where for a_size in [1, 2, 3, 4] { let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let a_digest = a.digest_u64(); let mut a_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols, a_size); @@ -651,7 +663,7 @@ where for b_size in [1, 2, 3, 4] { let mut b: VecZnx> = VecZnx::alloc(n, cols, b_size); - b.fill_uniform(basek, &mut source); + b.fill_uniform(base2k, &mut source); let b_digest: u64 = b.digest_u64(); let mut b_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols, b_size); @@ -697,17 +709,19 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - basek, + base2k, &mut res_small_ref, j, + base2k, &res_big_ref, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - basek, + base2k, &mut res_small_test, j, + base2k, &res_big_test, j, scratch_test.borrow(), @@ -723,15 +737,15 @@ where } } -pub fn test_vec_znx_dft_sub_ab_inplace(basek: usize, module_ref: &Module
, module_test: &Module) +pub fn test_vec_znx_dft_sub_inplace(base2k: usize, module_ref: &Module
, module_test: &Module) where - Module
: VecZnxDftSubABInplace
+ Module
: VecZnxDftSubInplace
+ VecZnxDftAlloc
+ VecZnxDftApply
+ VecZnxIdftApplyConsume
+ VecZnxBigNormalize
+ VecZnxBigNormalizeTmpBytes, - Module: VecZnxDftSubABInplace + Module: VecZnxDftSubInplace + VecZnxDftAlloc + VecZnxDftApply + VecZnxIdftApplyConsume @@ -750,7 +764,7 @@ where let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes()); for a_size in [1, 2, 3, 4] { let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let a_digest = a.digest_u64(); let mut a_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols, a_size); @@ -768,7 +782,7 @@ where for res_size in [1, 2, 3, 4] { let mut res: VecZnx> = VecZnx::alloc(n, cols, a_size); - res.fill_uniform(basek, &mut source); + res.fill_uniform(base2k, &mut source); let res_digest: u64 = res.digest_u64(); let mut res_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols, res_size); @@ -783,8 +797,8 @@ where // Reference for i in 0..cols { - module_ref.vec_znx_dft_sub_ab_inplace(&mut res_dft_ref, i, &a_dft_ref, i); - module_test.vec_znx_dft_sub_ab_inplace(&mut res_dft_test, i, &a_dft_test, i); + module_ref.vec_znx_dft_sub_inplace(&mut res_dft_ref, i, &a_dft_ref, i); + module_test.vec_znx_dft_sub_inplace(&mut res_dft_test, i, &a_dft_test, i); } assert_eq!(a_dft_ref.digest_u64(), a_dft_ref_digest); @@ -801,17 +815,19 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - basek, + base2k, &mut res_small_ref, j, + base2k, &res_big_ref, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - basek, + base2k, &mut res_small_test, j, + base2k, &res_big_test, j, scratch_test.borrow(), @@ -826,15 +842,18 @@ where } } -pub fn test_vec_znx_dft_sub_ba_inplace(basek: usize, module_ref: &Module
, module_test: &Module) -where - Module
: VecZnxDftSubBAInplace
+pub fn test_vec_znx_dft_sub_negate_inplace( + base2k: usize, + module_ref: &Module
, + module_test: &Module, +) where + Module
: VecZnxDftSubNegateInplace
+ VecZnxDftAlloc
+ VecZnxDftApply
+ VecZnxIdftApplyConsume
+ VecZnxBigNormalize
+ VecZnxBigNormalizeTmpBytes, - Module: VecZnxDftSubBAInplace + Module: VecZnxDftSubNegateInplace + VecZnxDftAlloc + VecZnxDftApply + VecZnxIdftApplyConsume @@ -853,7 +872,7 @@ where let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes()); for a_size in [1, 2, 3, 4] { let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let a_digest = a.digest_u64(); let mut a_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols, a_size); @@ -871,7 +890,7 @@ where for res_size in [1, 2, 3, 4] { let mut res: VecZnx> = VecZnx::alloc(n, cols, a_size); - res.fill_uniform(basek, &mut source); + res.fill_uniform(base2k, &mut source); let res_digest: u64 = res.digest_u64(); let mut res_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols, res_size); @@ -886,8 +905,8 @@ where // Reference for i in 0..cols { - module_ref.vec_znx_dft_sub_ba_inplace(&mut res_dft_ref, i, &a_dft_ref, i); - module_test.vec_znx_dft_sub_ba_inplace(&mut res_dft_test, i, &a_dft_test, i); + module_ref.vec_znx_dft_sub_negate_inplace(&mut res_dft_ref, i, &a_dft_ref, i); + module_test.vec_znx_dft_sub_negate_inplace(&mut res_dft_test, i, &a_dft_test, i); } assert_eq!(a_dft_ref.digest_u64(), a_dft_ref_digest); @@ -904,17 +923,19 @@ where for j in 0..cols { module_ref.vec_znx_big_normalize( - basek, + base2k, &mut res_small_ref, j, + base2k, &res_big_ref, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - basek, + base2k, &mut res_small_test, j, + base2k, &res_big_test, j, scratch_test.borrow(), diff --git a/poulpy-hal/src/test_suite/vmp.rs b/poulpy-hal/src/test_suite/vmp.rs index e46d194..4cf0880 100644 --- a/poulpy-hal/src/test_suite/vmp.rs +++ b/poulpy-hal/src/test_suite/vmp.rs @@ -11,7 +11,7 @@ use rand::RngCore; use crate::layouts::{Backend, VecZnxDft, VmpPMat}; -pub fn test_vmp_apply_dft(basek: usize, module_ref: &Module
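
All of the `vec_znx_dft` tests share one pipeline; only the arithmetic step varies per test. A sketch, with the argument lists of the apply steps abbreviated since this patch does not show them in full:

```rust
// 1. coeff domain -> DFT domain (VecZnxDftApply; arguments abbreviated):
// let mut a_dft = module.vec_znx_dft_alloc(cols, a_size);
// 2. arithmetic entirely in the DFT domain, e.g. for the renamed sub:
module.vec_znx_dft_sub_inplace(&mut res_dft, i, &a_dft, i);
// 3. DFT -> big domain, reusing the DFT buffer (VecZnxIdftApplyConsume):
// let res_big = module.vec_znx_idft_apply_consume(res_dft);
// 4. big -> canonical VecZnx (VecZnxBigNormalize, as in the hunks above).
```
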
, module_test: &Module) +pub fn test_vmp_apply_dft(base2k: usize, module_ref: &Module
, module_test: &Module) where Module
: ModuleNew
+ VmpApplyDftTmpBytes @@ -53,11 +53,11 @@ where let rows: usize = cols_in; let mut a: VecZnx> = VecZnx::alloc(n, cols_in, size_in); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let a_digest: u64 = a.digest_u64(); let mut mat: MatZnx> = MatZnx::alloc(n, rows, cols_in, cols_out, size_out); - mat.fill_uniform(basek, &mut source); + mat.fill_uniform(base2k, &mut source); let mat_digest: u64 = mat.digest_u64(); let mut pmat_ref: VmpPMat, BR> = module_ref.vmp_pmat_alloc(rows, cols_in, cols_out, size_out); @@ -90,17 +90,19 @@ where for j in 0..cols_out { module_ref.vec_znx_big_normalize( - basek, + base2k, &mut res_small_ref, j, + base2k, &res_big_ref, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - basek, + base2k, &mut res_small_test, j, + base2k, &res_big_test, j, scratch_test.borrow(), @@ -117,7 +119,7 @@ where } } -pub fn test_vmp_apply_dft_to_dft(basek: usize, module_ref: &Module
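
The vmp tests exercise the vector-matrix-product path: a `MatZnx` is prepared once into a backend-specific `VmpPMat`, then applied to (the DFT of) the input vector. Allocation as in the hunk above; the prepare/apply calls are sketched as assumptions, since their full argument lists are not shown here:

```rust
let mut pmat: VmpPMat<Vec<u8>, BR> = module_ref.vmp_pmat_alloc(rows, cols_in, cols_out, size_out);
// module_ref.vmp_prepare(&mut pmat, &mat, scratch.borrow());          // MatZnx -> VmpPMat (assumed shape)
// module_ref.vmp_apply_dft(&mut res_dft, &a, &pmat, scratch.borrow()); // apply to the input's DFT (assumed shape)
```
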
, module_test: &Module) +pub fn test_vmp_apply_dft_to_dft(base2k: usize, module_ref: &Module
, module_test: &Module) where Module
: ModuleNew
+ VmpApplyDftToDftTmpBytes @@ -162,7 +164,7 @@ where let rows: usize = size_in; let mut a: VecZnx> = VecZnx::alloc(n, cols_in, size_in); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let a_digest: u64 = a.digest_u64(); let mut a_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols_in, size_in); @@ -176,7 +178,7 @@ where assert_eq!(a.digest_u64(), a_digest); let mut mat: MatZnx> = MatZnx::alloc(n, rows, cols_in, cols_out, size_out); - mat.fill_uniform(basek, &mut source); + mat.fill_uniform(base2k, &mut source); let mat_digest: u64 = mat.digest_u64(); let mut pmat_ref: VmpPMat, BR> = module_ref.vmp_pmat_alloc(rows, cols_in, cols_out, size_out); @@ -217,17 +219,19 @@ where for j in 0..cols_out { module_ref.vec_znx_big_normalize( - basek, + base2k, &mut res_small_ref, j, + base2k, &res_big_ref, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - basek, + base2k, &mut res_small_test, j, + base2k, &res_big_test, j, scratch_test.borrow(), @@ -244,7 +248,7 @@ where } } -pub fn test_vmp_apply_dft_to_dft_add(basek: usize, module_ref: &Module
, module_test: &Module) +pub fn test_vmp_apply_dft_to_dft_add(base2k: usize, module_ref: &Module
, module_test: &Module) where Module
: ModuleNew
+ VmpApplyDftToDftAddTmpBytes @@ -289,7 +293,7 @@ where let rows: usize = size_in; let mut a: VecZnx> = VecZnx::alloc(n, cols_in, size_in); - a.fill_uniform(basek, &mut source); + a.fill_uniform(base2k, &mut source); let a_digest: u64 = a.digest_u64(); let mut a_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols_in, size_in); @@ -303,7 +307,7 @@ where assert_eq!(a.digest_u64(), a_digest); let mut mat: MatZnx> = MatZnx::alloc(n, rows, cols_in, cols_out, size_out); - mat.fill_uniform(basek, &mut source); + mat.fill_uniform(base2k, &mut source); let mat_digest: u64 = mat.digest_u64(); let mut pmat_ref: VmpPMat, BR> = module_ref.vmp_pmat_alloc(rows, cols_in, cols_out, size_out); @@ -316,7 +320,7 @@ where for limb_offset in 0..size_out { let mut res: VecZnx> = VecZnx::alloc(n, cols_out, size_out); - res.fill_uniform(basek, &mut source); + res.fill_uniform(base2k, &mut source); let res_digest: u64 = res.digest_u64(); let mut res_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols_out, size_out); @@ -355,17 +359,19 @@ where for j in 0..cols_out { module_ref.vec_znx_big_normalize( - basek, + base2k, &mut res_small_ref, j, + base2k, &res_big_ref, j, scratch_ref.borrow(), ); module_test.vec_znx_big_normalize( - basek, + base2k, &mut res_small_test, j, + base2k, &res_big_test, j, scratch_test.borrow(), diff --git a/poulpy-schemes/Cargo.toml b/poulpy-schemes/Cargo.toml index 7422f1e..9afa99e 100644 --- a/poulpy-schemes/Cargo.toml +++ b/poulpy-schemes/Cargo.toml @@ -13,5 +13,10 @@ documentation = "https://docs.rs/poulpy" poulpy-backend = {path="../poulpy-backend"} poulpy-hal = {path="../poulpy-hal"} poulpy-core = {path="../poulpy-core"} +criterion = {workspace = true} itertools = "0.14.0" -byteorder = "1.5.0" \ No newline at end of file +byteorder = "1.5.0" + +[[bench]] +name = "circuit_bootstrapping" +harness = false \ No newline at end of file diff --git a/poulpy-schemes/benches/circuit_bootstrapping.rs b/poulpy-schemes/benches/circuit_bootstrapping.rs new file mode 100644 index 0000000..d056938 --- /dev/null +++ b/poulpy-schemes/benches/circuit_bootstrapping.rs @@ -0,0 +1,307 @@ +use std::hint::black_box; + +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use poulpy_backend::{FFT64Avx, FFT64Ref, FFT64Spqlios}; +use poulpy_core::layouts::{ + Digits, GGLWEAutomorphismKeyLayout, GGLWETensorKeyLayout, GGSWCiphertext, GGSWCiphertextLayout, GLWESecret, LWECiphertext, + LWECiphertextLayout, LWESecret, prepared::PrepareAlloc, +}; +use poulpy_hal::{ + api::{ + ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, + SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, + VecZnxAutomorphismInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAllocBytes, + VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallNegateInplace, VecZnxCopy, + VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy, VecZnxFillUniform, + VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, + VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, VecZnxRshInplace, VecZnxSub, + VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, + VmpPrepare, ZnAddNormal, ZnFillUniform, ZnNormalizeInplace, + }, + layouts::{Backend, Module, ScratchOwned}, + 
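
The manifest change above is what turns the new file into a Criterion benchmark: `criterion` moves into the dependencies and `harness = false` disables libtest's default bench harness, so the `criterion_group!`/`criterion_main!` macros in `benches/circuit_bootstrapping.rs` provide `main`. It should then run with `cargo bench -p poulpy-schemes --bench circuit_bootstrapping` (standard Cargo flags; the exact invocation is not part of this diff).
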
oep::{ + ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeMatZnxImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, + TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, + }, + source::Source, +}; +use poulpy_schemes::tfhe::{ + blind_rotation::{ + BlincRotationExecute, BlindRotationAlgo, BlindRotationKey, BlindRotationKeyAlloc, BlindRotationKeyEncryptSk, + BlindRotationKeyInfos, BlindRotationKeyLayout, BlindRotationKeyPrepared, CGGI, + }, + circuit_bootstrapping::{ + CircuitBootstrappingKey, CircuitBootstrappingKeyEncryptSk, CircuitBootstrappingKeyLayout, + CircuitBootstrappingKeyPrepared, CirtuitBootstrappingExecute, + }, +}; + +pub fn benc_circuit_bootstrapping(c: &mut Criterion, label: &str) +where + Module: ModuleNew + + VecZnxFillUniform + + VecZnxAddNormal + + VecZnxNormalizeInplace + + VecZnxDftAllocBytes + + VecZnxBigNormalize + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxNormalizeTmpBytes + + VecZnxSubInplace + + VecZnxAddInplace + + VecZnxNormalize + + VecZnxSub + + VecZnxAddScalarInplace + + VecZnxAutomorphism + + VecZnxSwitchRing + + VecZnxBigAllocBytes + + VecZnxIdftApplyTmpA + + SvpApplyDftToDft + + VecZnxBigAddInplace + + VecZnxBigAddSmallInplace + + VecZnxBigAlloc + + VecZnxDftAlloc + + VecZnxBigNormalizeTmpBytes + + VmpPMatAlloc + + VmpPrepare + + SvpPrepare + + SvpPPolAlloc + + VmpApplyDftToDftTmpBytes + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + SvpPPolAllocBytes + + VecZnxRotateInplace + + VecZnxBigAutomorphismInplace + + VecZnxRshInplace + + VecZnxDftCopy + + VecZnxNegateInplace + + VecZnxCopy + + VecZnxAutomorphismInplace + + VecZnxBigSubSmallNegateInplace + + VecZnxRotateInplaceTmpBytes + + VecZnxBigAllocBytes + + VecZnxDftAddInplace + + VecZnxRotate + + ZnFillUniform + + ZnAddNormal + + ZnNormalizeInplace, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + B: Backend + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + TakeVecZnxDftImpl + + ScratchAvailableImpl + + TakeVecZnxImpl + + TakeScalarZnxImpl + + TakeSvpPPolImpl + + TakeVecZnxBigImpl + + TakeVecZnxDftSliceImpl + + TakeMatZnxImpl + + TakeVecZnxSliceImpl, + BlindRotationKey, BRA>: PrepareAlloc, BRA, B>>, + BlindRotationKeyPrepared, BRA, B>: BlincRotationExecute, + BlindRotationKey, BRA>: BlindRotationKeyAlloc + BlindRotationKeyEncryptSk, +{ + let group_name: String = format!("circuit_bootstrapping::{label}"); + + let mut group = c.benchmark_group(group_name); + + struct Params { + name: String, + extension_factor: usize, + k_pt: usize, + block_size: usize, + lwe_infos: LWECiphertextLayout, + ggsw_infos: GGSWCiphertextLayout, + cbt_infos: CircuitBootstrappingKeyLayout, + } + + fn runner(params: &Params) -> impl FnMut() + where + Module: ModuleNew + + VecZnxFillUniform + + VecZnxAddNormal + + VecZnxNormalizeInplace + + VecZnxDftAllocBytes + + VecZnxBigNormalize + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxNormalizeTmpBytes + + VecZnxSubInplace + + VecZnxAddInplace + + VecZnxNormalize + + VecZnxSub + + VecZnxAddScalarInplace + + VecZnxAutomorphism + + VecZnxSwitchRing + + VecZnxBigAllocBytes + + VecZnxIdftApplyTmpA + + SvpApplyDftToDft + + VecZnxBigAddInplace + + VecZnxBigAddSmallInplace + + VecZnxBigAlloc + + VecZnxDftAlloc + + VecZnxBigNormalizeTmpBytes + + VmpPMatAlloc + + VmpPrepare + + SvpPrepare + + SvpPPolAlloc + + VmpApplyDftToDftTmpBytes + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + SvpPPolAllocBytes + + VecZnxRotateInplace + + 
VecZnxBigAutomorphismInplace + + VecZnxRshInplace + + VecZnxDftCopy + + VecZnxNegateInplace + + VecZnxCopy + + VecZnxAutomorphismInplace + + VecZnxBigSubSmallNegateInplace + + VecZnxRotateInplaceTmpBytes + + VecZnxBigAllocBytes + + VecZnxDftAddInplace + + VecZnxRotate + + ZnFillUniform + + ZnAddNormal + + ZnNormalizeInplace, + B: Backend + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + TakeVecZnxDftImpl + + ScratchAvailableImpl + + TakeVecZnxImpl + + TakeScalarZnxImpl + + TakeSvpPPolImpl + + TakeVecZnxBigImpl + + TakeVecZnxDftSliceImpl + + TakeMatZnxImpl + + TakeVecZnxSliceImpl, + BlindRotationKey, BRA>: PrepareAlloc, BRA, B>>, + BlindRotationKeyPrepared, BRA, B>: BlincRotationExecute, + BlindRotationKey, BRA>: BlindRotationKeyAlloc + BlindRotationKeyEncryptSk, + { + // Scratch space (4MB) + let mut scratch: ScratchOwned = ScratchOwned::alloc(1 << 22); + + let n_glwe: poulpy_core::layouts::Degree = params.cbt_infos.layout_brk.n_glwe(); + let n_lwe: poulpy_core::layouts::Degree = params.cbt_infos.layout_brk.n_lwe(); + let rank: poulpy_core::layouts::Rank = params.cbt_infos.layout_brk.rank; + + let module: Module = Module::::new(n_glwe.as_u32() as u64); + + let mut source_xs: Source = Source::new([1u8; 32]); + let mut source_xa: Source = Source::new([1u8; 32]); + let mut source_xe: Source = Source::new([1u8; 32]); + + let mut sk_lwe: LWESecret> = LWESecret::alloc(n_lwe); + sk_lwe.fill_binary_block(params.block_size, &mut source_xs); + sk_lwe.fill_zero(); + + let mut sk_glwe: GLWESecret> = GLWESecret::alloc_with(n_glwe, rank); + sk_glwe.fill_ternary_prob(0.5, &mut source_xs); + + let ct_lwe: LWECiphertext> = LWECiphertext::alloc(¶ms.lwe_infos); + + // Circuit bootstrapping evaluation key + let cbt_key: CircuitBootstrappingKey, BRA> = CircuitBootstrappingKey::encrypt_sk( + &module, + &sk_lwe, + &sk_glwe, + ¶ms.cbt_infos, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + let mut res: GGSWCiphertext> = GGSWCiphertext::alloc(¶ms.ggsw_infos); + let cbt_prepared: CircuitBootstrappingKeyPrepared, BRA, B> = cbt_key.prepare_alloc(&module, scratch.borrow()); + + move || { + cbt_prepared.execute_to_constant( + &module, + &mut res, + &ct_lwe, + params.k_pt, + params.extension_factor, + scratch.borrow(), + ); + black_box(()); + } + } + + for params in [Params { + name: String::from("1-bit"), + extension_factor: 1, + k_pt: 1, + lwe_infos: LWECiphertextLayout { + n: 574_u32.into(), + k: 13_u32.into(), + base2k: 13_u32.into(), + }, + block_size: 7, + ggsw_infos: GGSWCiphertextLayout { + n: 1024_u32.into(), + base2k: 13_u32.into(), + k: 26_u32.into(), + rows: 2_u32.into(), + digits: 1_u32.into(), + rank: 2_u32.into(), + }, + cbt_infos: CircuitBootstrappingKeyLayout { + layout_brk: BlindRotationKeyLayout { + n_glwe: 1024_u32.into(), + n_lwe: 574_u32.into(), + base2k: 13_u32.into(), + k: 52_u32.into(), + rows: 3_u32.into(), + rank: 2_u32.into(), + }, + layout_atk: GGLWEAutomorphismKeyLayout { + n: 1024_u32.into(), + base2k: 13_u32.into(), + k: 52_u32.into(), + rows: 3_u32.into(), + digits: Digits(1), + rank: 2_u32.into(), + }, + layout_tsk: GGLWETensorKeyLayout { + n: 1024_u32.into(), + base2k: 13_u32.into(), + k: 52_u32.into(), + rows: 3_u32.into(), + digits: Digits(1), + rank: 2_u32.into(), + }, + }, + }] { + let id: BenchmarkId = BenchmarkId::from_parameter(params.name.clone()); + let mut runner = runner::(¶ms); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +fn bench_circuit_bootstrapping_cpu_ref_fft64(c: &mut Criterion) { + 
benc_circuit_bootstrapping::(c, "fft64_ref"); +} + +fn bench_circuit_bootstrapping_cpu_avx_fft64(c: &mut Criterion) { + benc_circuit_bootstrapping::(c, "fft64_avx"); +} + +fn bench_circuit_bootstrapping_cpu_spqlios_fft64(c: &mut Criterion) { + benc_circuit_bootstrapping::(c, "fft64_spqlios"); +} + +criterion_group!( + benches, + bench_circuit_bootstrapping_cpu_ref_fft64, + bench_circuit_bootstrapping_cpu_avx_fft64, + bench_circuit_bootstrapping_cpu_spqlios_fft64, +); + +criterion_main!(benches); diff --git a/poulpy-schemes/examples/circuit_bootstrapping.rs b/poulpy-schemes/examples/circuit_bootstrapping.rs index f0e4e3f..b208287 100644 --- a/poulpy-schemes/examples/circuit_bootstrapping.rs +++ b/poulpy-schemes/examples/circuit_bootstrapping.rs @@ -1,7 +1,8 @@ use poulpy_core::{ GLWEOperations, layouts::{ - GGSWCiphertext, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, LWECiphertext, LWEPlaintext, LWESecret, + GGLWEAutomorphismKeyLayout, GGLWETensorKeyLayout, GGSWCiphertext, GGSWCiphertextLayout, GLWECiphertext, + GLWECiphertextLayout, GLWEPlaintext, GLWESecret, LWECiphertext, LWECiphertextLayout, LWEInfos, LWEPlaintext, LWESecret, prepared::{GGSWCiphertextPrepared, GLWESecretPrepared, PrepareAlloc}, }, }; @@ -20,9 +21,10 @@ use poulpy_hal::{ }; use poulpy_schemes::tfhe::{ - blind_rotation::CGGI, + blind_rotation::{BlindRotationKeyLayout, CGGI}, circuit_bootstrapping::{ - CircuitBootstrappingKey, CircuitBootstrappingKeyEncryptSk, CircuitBootstrappingKeyPrepared, CirtuitBootstrappingExecute, + CircuitBootstrappingKey, CircuitBootstrappingKeyEncryptSk, CircuitBootstrappingKeyLayout, + CircuitBootstrappingKeyPrepared, CirtuitBootstrappingExecute, }, }; @@ -34,7 +36,7 @@ fn main() { let module: Module = Module::::new(n_glwe as u64); // Base 2 loga - let basek: usize = 13; + let base2k: usize = 13; // Lookup table extension factor let extension_factor: usize = 1; @@ -58,25 +60,67 @@ fn main() { let rows_ggsw_res: usize = 2; // GGSW output modulus - let k_ggsw_res: usize = (rows_ggsw_res + 1) * basek; + let k_ggsw_res: usize = (rows_ggsw_res + 1) * base2k; // Blind rotation key GGSW number of rows let rows_brk: usize = rows_ggsw_res + 1; // Blind rotation key GGSW modulus - let k_brk: usize = (rows_brk + 1) * basek; + let k_brk: usize = (rows_brk + 1) * base2k; // GGLWE automorphism keys number of rows let rows_trace: usize = rows_ggsw_res + 1; // GGLWE automorphism keys modulus - let k_trace: usize = (rows_trace + 1) * basek; + let k_trace: usize = (rows_trace + 1) * base2k; // GGLWE tensor key number of rows let rows_tsk: usize = rows_ggsw_res + 1; // GGLWE tensor key modulus - let k_tsk: usize = (rows_tsk + 1) * basek; + let k_tsk: usize = (rows_tsk + 1) * base2k; + + let cbt_infos: CircuitBootstrappingKeyLayout = CircuitBootstrappingKeyLayout { + layout_brk: BlindRotationKeyLayout { + n_glwe: n_glwe.into(), + n_lwe: n_lwe.into(), + base2k: base2k.into(), + k: k_brk.into(), + rows: rows_brk.into(), + rank: rank.into(), + }, + layout_atk: GGLWEAutomorphismKeyLayout { + n: n_glwe.into(), + base2k: base2k.into(), + k: k_trace.into(), + rows: rows_trace.into(), + digits: 1_u32.into(), + rank: rank.into(), + }, + layout_tsk: GGLWETensorKeyLayout { + n: n_glwe.into(), + base2k: base2k.into(), + k: k_tsk.into(), + rows: rows_tsk.into(), + digits: 1_u32.into(), + rank: rank.into(), + }, + }; + + let ggsw_infos: GGSWCiphertextLayout = GGSWCiphertextLayout { + n: n_glwe.into(), + base2k: base2k.into(), + k: k_ggsw_res.into(), + rows: rows_ggsw_res.into(), + digits: 1_u32.into(), + rank: rank.into(), 
+ }; + + let lwe_infos = LWECiphertextLayout { + n: n_lwe.into(), + k: k_lwe_ct.into(), + base2k: base2k.into(), + }; // Scratch space (4MB) let mut scratch: ScratchOwned = ScratchOwned::alloc(1 << 22); @@ -91,12 +135,12 @@ fn main() { let mut source_xe: Source = Source::new([1u8; 32]); // LWE secret - let mut sk_lwe: LWESecret> = LWESecret::alloc(n_lwe); + let mut sk_lwe: LWESecret> = LWESecret::alloc(n_lwe.into()); sk_lwe.fill_binary_block(block_size, &mut source_xs); sk_lwe.fill_zero(); // GLWE secret - let mut sk_glwe: GLWESecret> = GLWESecret::alloc(n_glwe, rank); + let mut sk_glwe: GLWESecret> = GLWESecret::alloc_with(n_glwe.into(), rank.into()); sk_glwe.fill_ternary_prob(0.5, &mut source_xs); // sk_glwe.fill_zero(); @@ -107,17 +151,23 @@ fn main() { let data: i64 = 1 % (1 << k_lwe_pt); // LWE plaintext - let mut pt_lwe: LWEPlaintext> = LWEPlaintext::alloc(basek, k_lwe_pt); + let mut pt_lwe: LWEPlaintext> = LWEPlaintext::alloc_with(base2k.into(), k_lwe_pt.into()); // LWE plaintext(data * 2^{- (k_lwe_pt - 1)}) - pt_lwe.encode_i64(data, k_lwe_pt + 1); // +1 for padding bit + pt_lwe.encode_i64(data, (k_lwe_pt + 1).into()); // +1 for padding bit // Normalize plaintext to nicely print coefficients - module.zn_normalize_inplace(pt_lwe.n(), basek, pt_lwe.data_mut(), 0, scratch.borrow()); - println!("pt_lwe: {}", pt_lwe); + module.zn_normalize_inplace( + pt_lwe.n().into(), + base2k, + pt_lwe.data_mut(), + 0, + scratch.borrow(), + ); + println!("pt_lwe: {pt_lwe}"); // LWE ciphertext - let mut ct_lwe: LWECiphertext> = LWECiphertext::alloc(n_lwe, basek, k_lwe_ct); + let mut ct_lwe: LWECiphertext> = LWECiphertext::alloc(&lwe_infos); // Encrypt LWE Plaintext ct_lwe.encrypt_sk(&module, &pt_lwe, &sk_lwe, &mut source_xa, &mut source_xe); @@ -127,15 +177,9 @@ fn main() { // Circuit bootstrapping evaluation key let cbt_key: CircuitBootstrappingKey, CGGI> = CircuitBootstrappingKey::encrypt_sk( &module, - basek, &sk_lwe, &sk_glwe, - k_brk, - rows_brk, - k_trace, - rows_trace, - k_tsk, - rows_tsk, + &cbt_infos, &mut source_xa, &mut source_xe, scratch.borrow(), @@ -143,7 +187,7 @@ fn main() { println!("CBT-KGEN: {} ms", now.elapsed().as_millis()); // Output GGSW - let mut res: GGSWCiphertext> = GGSWCiphertext::alloc(n_glwe, basek, k_ggsw_res, rows_ggsw_res, 1, rank); + let mut res: GGSWCiphertext> = GGSWCiphertext::alloc(&ggsw_infos); // Circuit bootstrapping key prepared (opaque, backend-dependent, write-only struct) let cbt_prepared: CircuitBootstrappingKeyPrepared, CGGI, BackendImpl> = @@ -170,19 +214,26 @@ fn main() { // Tests RLWE(1) * GGSW(data) + let glwe_infos: GLWECiphertextLayout = GLWECiphertextLayout { + n: n_glwe.into(), + base2k: base2k.into(), + k: (k_ggsw_res - base2k).into(), + rank: rank.into(), + }; + // GLWE ciphertext modulus - let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(n_glwe, basek, k_ggsw_res - basek, rank); + let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(&glwe_infos); // Some GLWE plaintext with signed data let k_glwe_pt: usize = 3; - let mut pt_glwe: GLWEPlaintext> = GLWEPlaintext::alloc(n_glwe, basek, basek); + let mut pt_glwe: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_infos); let mut data_vec: Vec = vec![0i64; n_glwe]; data_vec .iter_mut() .enumerate() .for_each(|(x, y)| *y = (x % (1 << (k_glwe_pt - 1))) as i64 - (1 << (k_glwe_pt - 2))); - pt_glwe.encode_vec_i64(&data_vec, k_lwe_pt + 2); + pt_glwe.encode_vec_i64(&data_vec, (k_lwe_pt + 2).into()); pt_glwe.normalize_inplace(&module, scratch.borrow()); println!("{}", pt_glwe); @@ -204,7 +255,7 @@ fn main()
{ ct_glwe.external_product_inplace(&module, &res_prepared, scratch.borrow()); // Decrypt - let mut pt_res: GLWEPlaintext> = GLWEPlaintext::alloc(n_glwe, basek, ct_glwe.k()); + let mut pt_res: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_infos); ct_glwe.decrypt(&module, &mut pt_res, &sk_glwe_prepared, scratch.borrow()); println!("pt_res: {:?}", &pt_res.data.at(0, 0)[..64]); diff --git a/poulpy-schemes/src/tfhe/blind_rotation/cggi_algo.rs b/poulpy-schemes/src/tfhe/blind_rotation/cggi_algo.rs index a8bfb87..3814b48 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/cggi_algo.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/cggi_algo.rs @@ -4,34 +4,33 @@ use poulpy_hal::{ ScratchAvailable, SvpApplyDftToDft, SvpPPolAllocBytes, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, TakeVecZnxDftSlice, TakeVecZnxSlice, VecZnxAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftApply, - VecZnxDftSubABInplace, VecZnxDftZero, VecZnxIdftApply, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpBytes, + VecZnxDftSubInplace, VecZnxDftZero, VecZnxIdftApply, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpBytes, VecZnxMulXpMinusOneInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, - VecZnxSubABInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + VecZnxSubInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataRef, Module, Scratch, SvpPPol, VecZnx, ZnxView, ZnxZero}, }; use poulpy_core::{ Distribution, GLWEOperations, TakeGLWECt, - layouts::{GLWECiphertext, GLWECiphertextToMut, Infos, LWECiphertext, LWECiphertextToRef}, + layouts::{GGSWInfos, GLWECiphertext, GLWECiphertextToMut, GLWEInfos, LWECiphertext, LWECiphertextToRef, LWEInfos}, }; use crate::tfhe::blind_rotation::{ - BlincRotationExecute, BlindRotationKeyPrepared, CGGI, LookUpTable, LookUpTableRotationDirection, + BlincRotationExecute, BlindRotationKeyInfos, BlindRotationKeyPrepared, CGGI, LookUpTable, LookUpTableRotationDirection, }; #[allow(clippy::too_many_arguments)] -pub fn cggi_blind_rotate_scratch_space( +pub fn cggi_blind_rotate_scratch_space( module: &Module, block_size: usize, extension_factor: usize, - basek: usize, - k_res: usize, - k_brk: usize, - rows: usize, - rank: usize, + glwe_infos: &OUT, + brk_infos: &GGSW, ) -> usize where + OUT: GLWEInfos, + GGSW: GGSWInfos, Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes @@ -39,10 +38,11 @@ where + VecZnxIdftApplyTmpBytes + VecZnxBigNormalizeTmpBytes, { - let brk_size: usize = k_brk.div_ceil(basek); + let brk_size: usize = brk_infos.size(); if block_size > 1 { - let cols: usize = rank + 1; + let cols: usize = (brk_infos.rank() + 1).into(); + let rows: usize = brk_infos.rows().into(); let acc_dft: usize = module.vec_znx_dft_alloc_bytes(cols, rows) * extension_factor; let acc_big: usize = module.vec_znx_big_alloc_bytes(1, brk_size); let vmp_res: usize = module.vec_znx_dft_alloc_bytes(cols, brk_size) * extension_factor; @@ -50,7 +50,7 @@ where let acc_dft_add: usize = vmp_res; let vmp: usize = module.vmp_apply_dft_to_dft_tmp_bytes(brk_size, rows, rows, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2) let acc: usize = if extension_factor > 1 { - VecZnx::alloc_bytes(module.n(), cols, k_res.div_ceil(basek)) * extension_factor + VecZnx::alloc_bytes(module.n(), cols, glwe_infos.size()) * extension_factor } else { 0 }; @@ -61,8 +61,8 @@ where + vmp_xai + 
(vmp | (acc_big + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_apply_tmp_bytes()))) } else { - GLWECiphertext::bytes_of(module.n(), basek, k_res, rank) - + GLWECiphertext::external_product_scratch_space(module, basek, k_res, k_res, k_brk, 1, rank) + GLWECiphertext::alloc_bytes(glwe_infos) + + GLWECiphertext::external_product_inplace_scratch_space(module, glwe_infos, brk_infos) } } @@ -80,11 +80,11 @@ where + VecZnxDftApply + VecZnxDftZero + SvpApplyDftToDft - + VecZnxDftSubABInplace + + VecZnxDftSubInplace + VecZnxBigAddSmallInplace + VecZnxRotate + VecZnxAddInplace - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxNormalize + VecZnxNormalizeInplace + VecZnxCopy @@ -142,11 +142,11 @@ fn execute_block_binary_extended( + VecZnxDftApply + VecZnxDftZero + SvpApplyDftToDft - + VecZnxDftSubABInplace + + VecZnxDftSubInplace + VecZnxBigAddSmallInplace + VecZnxRotate + VecZnxAddInplace - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxNormalize + VecZnxNormalizeInplace + VecZnxCopy @@ -155,11 +155,11 @@ fn execute_block_binary_extended( + VmpApplyDftToDft, Scratch: TakeVecZnxDftSlice + TakeVecZnxDft + TakeVecZnxBig + TakeVecZnxSlice + ScratchAvailable + TakeVecZnx, { - let n_glwe: usize = brk.n(); + let n_glwe: usize = brk.n_glwe().into(); let extension_factor: usize = lut.extension_factor(); - let basek: usize = res.basek(); - let rows: usize = brk.rows(); - let cols: usize = res.rank() + 1; + let base2k: usize = res.base2k().into(); + let rows: usize = brk.rows().into(); + let cols: usize = (res.rank() + 1).into(); let (mut acc, scratch_1) = scratch.take_vec_znx_slice(extension_factor, n_glwe, cols, res.size()); let (mut acc_dft, scratch_2) = scratch_1.take_vec_znx_dft_slice(extension_factor, n_glwe, cols, rows); @@ -178,7 +178,7 @@ fn execute_block_binary_extended( panic!("invalid key: x_pow_a has not been initialized") } - let mut lwe_2n: Vec = vec![0i64; lwe.n() + 1]; // TODO: from scratch space + let mut lwe_2n: Vec = vec![0i64; (lwe.n() + 1).as_usize()]; // TODO: from scratch space let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref(); let two_n: usize = 2 * n_glwe; @@ -233,7 +233,7 @@ fn execute_block_binary_extended( (0..cols).for_each(|i| { module.svp_apply_dft_to_dft(&mut vmp_xai, 0, &x_pow_a[ai_hi], 0, &vmp_res[j], i); module.vec_znx_dft_add_inplace(&mut acc_add_dft[j], i, &vmp_xai, 0); - module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[j], i, &vmp_res[j], i); + module.vec_znx_dft_sub_inplace(&mut acc_add_dft[j], i, &vmp_res[j], i); }); }); } @@ -249,7 +249,7 @@ fn execute_block_binary_extended( (0..cols).for_each(|k| { module.svp_apply_dft_to_dft(&mut vmp_xai, 0, &x_pow_a[ai_hi + 1], 0, &vmp_res[j], k); module.vec_znx_dft_add_inplace(&mut acc_add_dft[i], k, &vmp_xai, 0); - module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[i], k, &vmp_res[i], k); + module.vec_znx_dft_sub_inplace(&mut acc_add_dft[i], k, &vmp_res[i], k); }); } } @@ -261,7 +261,7 @@ fn execute_block_binary_extended( (0..cols).for_each(|k| { module.svp_apply_dft_to_dft(&mut vmp_xai, 0, &x_pow_a[ai_hi], 0, &vmp_res[j], k); module.vec_znx_dft_add_inplace(&mut acc_add_dft[i], k, &vmp_xai, 0); - module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[i], k, &vmp_res[i], k); + module.vec_znx_dft_sub_inplace(&mut acc_add_dft[i], k, &vmp_res[i], k); }); } } @@ -275,14 +275,14 @@ fn execute_block_binary_extended( (0..cols).for_each(|i| { module.vec_znx_idft_apply(&mut acc_add_big, 0, &acc_add_dft[j], i, scratch7); module.vec_znx_big_add_small_inplace(&mut acc_add_big, 0, &acc[j], i); - 
module.vec_znx_big_normalize(basek, &mut acc[j], i, &acc_add_big, 0, scratch7); + module.vec_znx_big_normalize(base2k, &mut acc[j], i, base2k, &acc_add_big, 0, scratch7); }); }); } }); (0..cols).for_each(|i| { - module.vec_znx_copy(&mut res.data, i, &acc[0], i); + module.vec_znx_copy(res.data_mut(), i, &acc[0], i); }); } @@ -309,11 +309,11 @@ fn execute_block_binary( + VecZnxDftApply + VecZnxDftZero + SvpApplyDftToDft - + VecZnxDftSubABInplace + + VecZnxDftSubInplace + VecZnxBigAddSmallInplace + VecZnxRotate + VecZnxAddInplace - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxNormalize + VecZnxNormalizeInplace + VecZnxCopy @@ -322,15 +322,15 @@ fn execute_block_binary( + VecZnxBigNormalize, Scratch: TakeVecZnxDftSlice + TakeVecZnxDft + TakeVecZnxBig + TakeVecZnxSlice + ScratchAvailable + TakeVecZnx, { - let n_glwe: usize = brk.n(); - let mut lwe_2n: Vec = vec![0i64; lwe.n() + 1]; // TODO: from scratch space + let n_glwe: usize = brk.n_glwe().into(); + let mut lwe_2n: Vec = vec![0i64; (lwe.n() + 1).into()]; // TODO: from scratch space let mut out_mut: GLWECiphertext<&mut [u8]> = res.to_mut(); let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref(); let two_n: usize = n_glwe << 1; - let basek: usize = brk.basek(); - let rows: usize = brk.rows(); + let base2k: usize = brk.base2k().into(); + let rows: usize = brk.rows().into(); - let cols: usize = out_mut.rank() + 1; + let cols: usize = (out_mut.rank() + 1).into(); mod_switch_2n( 2 * lut.domain_size(), @@ -342,10 +342,10 @@ fn execute_block_binary( let a: &[i64] = &lwe_2n[1..]; let b: i64 = lwe_2n[0]; - out_mut.data.zero(); + out_mut.data_mut().zero(); // Initialize out to X^{b} * LUT(X) - module.vec_znx_rotate(b, &mut out_mut.data, 0, &lut.data[0], 0); + module.vec_znx_rotate(b, out_mut.data_mut(), 0, &lut.data[0], 0); let block_size: usize = brk.block_size(); @@ -369,7 +369,7 @@ fn execute_block_binary( ) .for_each(|(ai, ski)| { (0..cols).for_each(|j| { - module.vec_znx_dft_apply(1, 0, &mut acc_dft, j, &out_mut.data, j); + module.vec_znx_dft_apply(1, 0, &mut acc_dft, j, out_mut.data_mut(), j); }); module.vec_znx_dft_zero(&mut acc_add_dft); @@ -384,7 +384,7 @@ fn execute_block_binary( (0..cols).for_each(|i| { module.svp_apply_dft_to_dft(&mut vmp_xai, 0, &x_pow_a[ai_pos], 0, &vmp_res, i); module.vec_znx_dft_add_inplace(&mut acc_add_dft, i, &vmp_xai, 0); - module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft, i, &vmp_res, i); + module.vec_znx_dft_sub_inplace(&mut acc_add_dft, i, &vmp_res, i); }); }); @@ -393,8 +393,16 @@ fn execute_block_binary( (0..cols).for_each(|i| { module.vec_znx_idft_apply(&mut acc_add_big, 0, &acc_add_dft, i, scratch_5); - module.vec_znx_big_add_small_inplace(&mut acc_add_big, 0, &out_mut.data, i); - module.vec_znx_big_normalize(basek, &mut out_mut.data, i, &acc_add_big, 0, scratch_5); + module.vec_znx_big_add_small_inplace(&mut acc_add_big, 0, out_mut.data_mut(), i); + module.vec_znx_big_normalize( + base2k, + out_mut.data_mut(), + i, + base2k, + &acc_add_big, + 0, + scratch_5, + ); }); } }); @@ -423,11 +431,11 @@ fn execute_standard( + VecZnxDftApply + VecZnxDftZero + SvpApplyDftToDft - + VecZnxDftSubABInplace + + VecZnxDftSubInplace + VecZnxBigAddSmallInplace + VecZnxRotate + VecZnxAddInplace - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxNormalize + VecZnxNormalizeInplace + VecZnxCopy @@ -450,10 +458,10 @@ fn execute_standard( ); assert_eq!( lut.domain_size(), - brk.n(), + brk.n_glwe().as_usize(), "lut.n(): {} != brk.n(): {}", lut.domain_size(), - brk.n() + brk.n_glwe().as_usize() ); assert_eq!( res.rank(), @@ -464,17 
+472,16 @@ fn execute_standard( ); assert_eq!( lwe.n(), - brk.data.len(), + brk.n_lwe(), "lwe.n(): {} != brk.data.len(): {}", lwe.n(), - brk.data.len() + brk.n_lwe() ); } - let mut lwe_2n: Vec = vec![0i64; lwe.n() + 1]; // TODO: from scratch space + let mut lwe_2n: Vec = vec![0i64; (lwe.n() + 1).into()]; // TODO: from scratch space let mut out_mut: GLWECiphertext<&mut [u8]> = res.to_mut(); let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref(); - let basek: usize = brk.basek(); mod_switch_2n( 2 * lut.domain_size(), @@ -486,13 +493,13 @@ fn execute_standard( let a: &[i64] = &lwe_2n[1..]; let b: i64 = lwe_2n[0]; - out_mut.data.zero(); + out_mut.data_mut().zero(); // Initialize out to X^{b} * LUT(X) - module.vec_znx_rotate(b, &mut out_mut.data, 0, &lut.data[0], 0); + module.vec_znx_rotate(b, out_mut.data_mut(), 0, &lut.data[0], 0); // ACC + [sum DFT(X^ai -1) * (DFT(ACC) x BRKi)] - let (mut acc_tmp, scratch_1) = scratch.take_glwe_ct(out_mut.n(), basek, out_mut.k(), out_mut.rank()); + let (mut acc_tmp, scratch_1) = scratch.take_glwe_ct(&out_mut); // TODO: see if faster by skipping normalization in external product and keeping acc in big coeffs // TODO: first iteration can be optimized to be a gglwe product @@ -507,13 +514,13 @@ fn execute_standard( out_mut.add_inplace(module, &acc_tmp); }); - // We can normalize only at the end because we add normalized values in [-2^{basek-1}, 2^{basek-1}] - // on top of each others, thus ~ 2^{63-basek} additions are supported before overflow. + // We can normalize only at the end because we add normalized values in [-2^{base2k-1}, 2^{base2k-1}] + // on top of each other, thus ~ 2^{63-base2k} additions are supported before overflow. out_mut.normalize_inplace(module, scratch_1); } pub fn mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWECiphertext<&[u8]>, rot_dir: LookUpTableRotationDirection) { - let basek: usize = lwe.basek(); + let base2k: usize = lwe.base2k().into(); let log2n: usize = usize::BITS as usize - (n - 1).leading_zeros() as usize + 1; @@ -526,23 +533,23 @@ pub fn mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWECiphertext<&[u8]>, rot_ LookUpTableRotationDirection::Right => {} } - if basek > log2n { - let diff: usize = basek - (log2n - 1); // additional -1 because we map to [-N/2, N/2) instead of [0, N) + if base2k > log2n { + let diff: usize = base2k - (log2n - 1); // additional -1 because we map to [-N/2, N/2) instead of [0, N) res.iter_mut().for_each(|x| { *x = div_round_by_pow2(x, diff); }) } else { - let rem: usize = basek - (log2n % basek); - let size: usize = log2n.div_ceil(basek); + let rem: usize = base2k - (log2n % base2k); + let size: usize = log2n.div_ceil(base2k); (1..size).for_each(|i| { - if i == size - 1 && rem != basek { - let k_rem: usize = basek - rem; + if i == size - 1 && rem != base2k { + let k_rem: usize = base2k - rem; izip!(lwe.data().at(0, i).iter(), res.iter_mut()).for_each(|(x, y)| { *y = (*y << k_rem) + (x >> rem); }); } else { izip!(lwe.data().at(0, i).iter(), res.iter_mut()).for_each(|(x, y)| { - *y = (*y << basek) + x; + *y = (*y << base2k) + x; }); } }) diff --git a/poulpy-schemes/src/tfhe/blind_rotation/cggi_key.rs b/poulpy-schemes/src/tfhe/blind_rotation/cggi_key.rs index 1f33341..fbf506b 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/cggi_key.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/cggi_key.rs @@ -2,7 +2,7 @@ use poulpy_hal::{ api::{ ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes,
VecZnxDftApply, VecZnxFillUniform, - VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, + VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VmpPMatAlloc, VmpPrepare, }, layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxView, ZnxViewMut}, @@ -14,21 +14,27 @@ use std::marker::PhantomData; use poulpy_core::{ Distribution, layouts::{ - GGSWCiphertext, LWESecret, + GGSWCiphertext, GGSWInfos, LWESecret, compressed::GGSWCiphertextCompressed, prepared::{GGSWCiphertextPrepared, GLWESecretPrepared}, }, }; use crate::tfhe::blind_rotation::{ - BlindRotationKey, BlindRotationKeyAlloc, BlindRotationKeyCompressed, BlindRotationKeyEncryptSk, BlindRotationKeyPrepared, - BlindRotationKeyPreparedAlloc, CGGI, + BlindRotationKey, BlindRotationKeyAlloc, BlindRotationKeyCompressed, BlindRotationKeyEncryptSk, BlindRotationKeyInfos, + BlindRotationKeyPrepared, BlindRotationKeyPreparedAlloc, CGGI, }; impl BlindRotationKeyAlloc for BlindRotationKey, CGGI> { - fn alloc(n_gglwe: usize, n_lwe: usize, basek: usize, k: usize, rows: usize, rank: usize) -> Self { - let mut data: Vec>> = Vec::with_capacity(n_lwe); - (0..n_lwe).for_each(|_| data.push(GGSWCiphertext::alloc(n_gglwe, basek, k, rows, 1, rank))); + fn alloc
(infos: &A) -> Self + where + A: BlindRotationKeyInfos, + { + let mut data: Vec>> = Vec::with_capacity(infos.n_lwe().into()); + for _ in 0..infos.n_lwe().as_usize() { + data.push(GGSWCiphertext::alloc(infos)); + } + Self { keys: data, dist: Distribution::NONE, @@ -38,11 +44,12 @@ impl BlindRotationKeyAlloc for BlindRotationKey, CGGI> { } impl BlindRotationKey, CGGI> { - pub fn generate_from_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize + pub fn generate_from_sk_scratch_space(module: &Module, infos: &A) -> usize where + A: GGSWInfos, Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes, { - GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k, rank) + GGSWCiphertext::encrypt_sk_scratch_space(module, infos) } } @@ -56,7 +63,7 @@ where + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -78,9 +85,11 @@ where { #[cfg(debug_assertions)] { - assert_eq!(self.keys.len(), sk_lwe.n()); - assert!(sk_glwe.n() <= module.n()); - assert_eq!(sk_glwe.rank(), self.keys[0].rank()); + use poulpy_core::layouts::{GLWEInfos, LWEInfos}; + + assert_eq!(self.keys.len() as u32, sk_lwe.n()); + assert!(sk_glwe.n() <= module.n() as u32); + assert_eq!(sk_glwe.rank(), self.rank()); match sk_lwe.dist() { Distribution::BinaryBlock(_) | Distribution::BinaryFixed(_) @@ -94,7 +103,7 @@ where self.dist = sk_lwe.dist(); - let mut pt: ScalarZnx> = ScalarZnx::alloc(sk_glwe.n(), 1); + let mut pt: ScalarZnx> = ScalarZnx::alloc(sk_glwe.n().into(), 1); let sk_ref: ScalarZnx<&[u8]> = sk_lwe.data().to_ref(); self.keys.iter_mut().enumerate().for_each(|(i, ggsw)| { @@ -108,13 +117,12 @@ impl BlindRotationKeyPreparedAlloc for BlindRotationKeyPrepared: VmpPMatAlloc + VmpPrepare, { - fn alloc(module: &Module, n_lwe: usize, basek: usize, k: usize, rows: usize, rank: usize) -> Self { - let mut data: Vec, B>> = Vec::with_capacity(n_lwe); - (0..n_lwe).for_each(|_| { - data.push(GGSWCiphertextPrepared::alloc( - module, basek, k, rows, 1, rank, - )) - }); + fn alloc(module: &Module, infos: &A) -> Self + where + A: BlindRotationKeyInfos, + { + let mut data: Vec, B>> = Vec::with_capacity(infos.n_lwe().into()); + (0..infos.n_lwe().as_usize()).for_each(|_| data.push(GGSWCiphertextPrepared::alloc(module, infos))); Self { data, dist: Distribution::NONE, @@ -125,13 +133,12 @@ where } impl BlindRotationKeyCompressed, CGGI> { - pub fn alloc(n_gglwe: usize, n_lwe: usize, basek: usize, k: usize, rows: usize, rank: usize) -> Self { - let mut data: Vec>> = Vec::with_capacity(n_lwe); - (0..n_lwe).for_each(|_| { - data.push(GGSWCiphertextCompressed::alloc( - n_gglwe, basek, k, rows, 1, rank, - )) - }); + pub fn alloc(infos: &A) -> Self + where + A: BlindRotationKeyInfos, + { + let mut data: Vec>> = Vec::with_capacity(infos.n_lwe().into()); + (0..infos.n_lwe().as_usize()).for_each(|_| data.push(GGSWCiphertextCompressed::alloc(infos))); Self { keys: data, dist: Distribution::NONE, @@ -139,11 +146,12 @@ impl BlindRotationKeyCompressed, CGGI> { } } - pub fn generate_from_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize + pub fn generate_from_sk_scratch_space(module: &Module, infos: &A) -> usize where + A: GGSWInfos, Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes, { - GGSWCiphertextCompressed::encrypt_sk_scratch_space(module, basek, k, rank) + GGSWCiphertextCompressed::encrypt_sk_scratch_space(module, infos) } } @@ -168,7 +176,7 @@ impl BlindRotationKeyCompressed 
{ + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxAddNormal @@ -178,9 +186,11 @@ impl BlindRotationKeyCompressed { { #[cfg(debug_assertions)] { - assert_eq!(self.keys.len(), sk_lwe.n()); - assert!(sk_glwe.n() <= module.n()); - assert_eq!(sk_glwe.rank(), self.keys[0].rank()); + use poulpy_core::layouts::{GLWEInfos, LWEInfos}; + + assert_eq!(self.n_lwe(), sk_lwe.n()); + assert!(sk_glwe.n() <= module.n() as u32); + assert_eq!(sk_glwe.rank(), self.rank()); match sk_lwe.dist() { Distribution::BinaryBlock(_) | Distribution::BinaryFixed(_) @@ -194,7 +204,7 @@ impl BlindRotationKeyCompressed { self.dist = sk_lwe.dist(); - let mut pt: ScalarZnx> = ScalarZnx::alloc(sk_glwe.n(), 1); + let mut pt: ScalarZnx> = ScalarZnx::alloc(sk_glwe.n().into(), 1); let sk_ref: ScalarZnx<&[u8]> = sk_lwe.data().to_ref(); let mut source_xa: Source = Source::new(seed_xa); diff --git a/poulpy-schemes/src/tfhe/blind_rotation/key.rs b/poulpy-schemes/src/tfhe/blind_rotation/key.rs index 6b6163a..b4fbda8 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/key.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/key.rs @@ -1,5 +1,5 @@ use poulpy_hal::{ - layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, Reset, Scratch, WriterTo}, + layouts::{Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, Scratch, WriterTo}, source::Source, }; @@ -7,15 +7,78 @@ use std::{fmt, marker::PhantomData}; use poulpy_core::{ Distribution, - layouts::{GGSWCiphertext, Infos, LWESecret, prepared::GLWESecretPrepared}, + layouts::{ + Base2K, Degree, Digits, GGSWCiphertext, GGSWInfos, GLWEInfos, LWEInfos, LWESecret, Rank, Rows, TorusPrecision, + prepared::GLWESecretPrepared, + }, }; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use crate::tfhe::blind_rotation::BlindRotationAlgo; +#[derive(PartialEq, Eq, Copy, Clone, Debug)] +pub struct BlindRotationKeyLayout { + pub n_glwe: Degree, + pub n_lwe: Degree, + pub base2k: Base2K, + pub k: TorusPrecision, + pub rows: Rows, + pub rank: Rank, +} + +impl BlindRotationKeyInfos for BlindRotationKeyLayout { + fn n_glwe(&self) -> Degree { + self.n_glwe + } + + fn n_lwe(&self) -> Degree { + self.n_lwe + } +} + +impl GGSWInfos for BlindRotationKeyLayout { + fn digits(&self) -> Digits { + Digits(1) + } + + fn rows(&self) -> Rows { + self.rows + } +} + +impl GLWEInfos for BlindRotationKeyLayout { + fn rank(&self) -> Rank { + self.rank + } +} + +impl LWEInfos for BlindRotationKeyLayout { + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } + + fn n(&self) -> Degree { + self.n_glwe + } +} + +pub trait BlindRotationKeyInfos +where + Self: GGSWInfos, +{ + fn n_glwe(&self) -> Degree; + fn n_lwe(&self) -> Degree; +} + pub trait BlindRotationKeyAlloc { - fn alloc(n_gglwe: usize, n_lwe: usize, basek: usize, k: usize, rows: usize, rank: usize) -> Self; + fn alloc(infos: &A) -> Self + where + A: BlindRotationKeyInfos; } pub trait BlindRotationKeyEncryptSk { @@ -42,7 +105,7 @@ pub struct BlindRotationKey { impl fmt::Debug for BlindRotationKey { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self) + write!(f, "{self}") } } @@ -66,19 +129,12 @@ impl Eq for BlindRotationKey {} impl fmt::Display for BlindRotationKey { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { for (i, key) in self.keys.iter().enumerate() { - write!(f, "key[{}]: {}", i, key)?; + write!(f, "key[{i}]: {key}")?; } 
writeln!(f, "{:?}", self.dist) } } -impl Reset for BlindRotationKey { - fn reset(&mut self) { - self.keys.iter_mut().for_each(|key| key.reset()); - self.dist = Distribution::NONE; - } -} - impl FillUniform for BlindRotationKey { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { self.keys @@ -121,41 +177,55 @@ impl WriterTo for BlindRotationKey { } } +impl BlindRotationKeyInfos for BlindRotationKey { + fn n_glwe(&self) -> Degree { + self.n() + } + + fn n_lwe(&self) -> Degree { + Degree(self.keys.len() as u32) + } +} + impl BlindRotationKey { #[allow(dead_code)] - pub(crate) fn n(&self) -> usize { - self.keys[0].n() - } - - #[allow(dead_code)] - pub(crate) fn rows(&self) -> usize { - self.keys[0].rows() - } - - #[allow(dead_code)] - pub(crate) fn k(&self) -> usize { - self.keys[0].k() - } - - #[allow(dead_code)] - pub(crate) fn size(&self) -> usize { - self.keys[0].size() - } - - #[allow(dead_code)] - pub(crate) fn rank(&self) -> usize { - self.keys[0].rank() - } - - pub(crate) fn basek(&self) -> usize { - self.keys[0].basek() - } - - #[allow(dead_code)] - pub(crate) fn block_size(&self) -> usize { + fn block_size(&self) -> usize { match self.dist { Distribution::BinaryBlock(value) => value, _ => 1, } } } + +impl LWEInfos for BlindRotationKey { + fn base2k(&self) -> Base2K { + self.keys[0].base2k() + } + + fn k(&self) -> TorusPrecision { + self.keys[0].k() + } + + fn n(&self) -> Degree { + self.keys[0].n() + } + + fn size(&self) -> usize { + self.keys[0].size() + } +} + +impl GLWEInfos for BlindRotationKey { + fn rank(&self) -> Rank { + self.keys[0].rank() + } +} +impl GGSWInfos for BlindRotationKey { + fn digits(&self) -> poulpy_core::layouts::Digits { + Digits(1) + } + + fn rows(&self) -> Rows { + self.keys[0].rows() + } +} diff --git a/poulpy-schemes/src/tfhe/blind_rotation/key_compressed.rs b/poulpy-schemes/src/tfhe/blind_rotation/key_compressed.rs index 7fa463e..22a98da 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/key_compressed.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/key_compressed.rs @@ -1,5 +1,5 @@ use poulpy_hal::{ - layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, Reset, WriterTo}, + layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, WriterTo}, source::Source, }; @@ -8,10 +8,10 @@ use std::{fmt, marker::PhantomData}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use poulpy_core::{ Distribution, - layouts::{Infos, compressed::GGSWCiphertextCompressed}, + layouts::{Base2K, Degree, Digits, GGSWInfos, GLWEInfos, LWEInfos, TorusPrecision, compressed::GGSWCiphertextCompressed}, }; -use crate::tfhe::blind_rotation::BlindRotationAlgo; +use crate::tfhe::blind_rotation::{BlindRotationAlgo, BlindRotationKeyInfos}; #[derive(Clone)] pub struct BlindRotationKeyCompressed { @@ -22,7 +22,7 @@ pub struct BlindRotationKeyCompressed { impl fmt::Debug for BlindRotationKeyCompressed { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self) + write!(f, "{self}") } } @@ -45,19 +45,12 @@ impl Eq for BlindRotationKeyCompressed impl fmt::Display for BlindRotationKeyCompressed { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { for (i, key) in self.keys.iter().enumerate() { - write!(f, "key[{}]: {}", i, key)?; + write!(f, "key[{i}]: {key}")?; } writeln!(f, "{:?}", self.dist) } } -impl Reset for BlindRotationKeyCompressed { - fn reset(&mut self) { - self.keys.iter_mut().for_each(|key| key.reset()); - self.dist = Distribution::NONE; - } -} - impl FillUniform for BlindRotationKeyCompressed { fn fill_uniform(&mut 
self, log_bound: usize, source: &mut Source) { self.keys @@ -100,37 +93,51 @@ impl WriterTo for BlindRotationKeyCompressed } } -impl BlindRotationKeyCompressed { - #[allow(dead_code)] - pub(crate) fn n(&self) -> usize { +impl BlindRotationKeyInfos for BlindRotationKeyCompressed { + fn n_glwe(&self) -> Degree { + self.n() + } + + fn n_lwe(&self) -> Degree { + Degree(self.keys.len() as u32) + } +} + +impl LWEInfos for BlindRotationKeyCompressed { + fn n(&self) -> Degree { self.keys[0].n() } - #[allow(dead_code)] - pub(crate) fn rows(&self) -> usize { - self.keys[0].rows() - } - - #[allow(dead_code)] - pub(crate) fn k(&self) -> usize { - self.keys[0].k() - } - - #[allow(dead_code)] - pub(crate) fn size(&self) -> usize { + fn size(&self) -> usize { self.keys[0].size() } - #[allow(dead_code)] - pub(crate) fn rank(&self) -> usize { + fn k(&self) -> TorusPrecision { + self.keys[0].k() + } + + fn base2k(&self) -> Base2K { + self.keys[0].base2k() + } +} + +impl GLWEInfos for BlindRotationKeyCompressed { + fn rank(&self) -> poulpy_core::layouts::Rank { self.keys[0].rank() } +} - #[allow(dead_code)] - pub(crate) fn basek(&self) -> usize { - self.keys[0].basek() +impl GGSWInfos for BlindRotationKeyCompressed { + fn rows(&self) -> poulpy_core::layouts::Rows { + self.keys[0].rows() } + fn digits(&self) -> poulpy_core::layouts::Digits { + Digits(1) + } +} + +impl BlindRotationKeyCompressed { #[allow(dead_code)] pub(crate) fn block_size(&self) -> usize { match self.dist { diff --git a/poulpy-schemes/src/tfhe/blind_rotation/key_prepared.rs b/poulpy-schemes/src/tfhe/blind_rotation/key_prepared.rs index 6431167..d7dad82 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/key_prepared.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/key_prepared.rs @@ -8,15 +8,17 @@ use std::marker::PhantomData; use poulpy_core::{ Distribution, layouts::{ - Infos, + Base2K, Degree, Digits, GGSWInfos, GLWEInfos, LWEInfos, Rank, Rows, TorusPrecision, prepared::{GGSWCiphertextPrepared, Prepare, PrepareAlloc}, }, }; -use crate::tfhe::blind_rotation::{BlindRotationAlgo, BlindRotationKey, utils::set_xai_plus_y}; +use crate::tfhe::blind_rotation::{BlindRotationAlgo, BlindRotationKey, BlindRotationKeyInfos, utils::set_xai_plus_y}; pub trait BlindRotationKeyPreparedAlloc { - fn alloc(module: &Module, n_lwe: usize, basek: usize, k: usize, rows: usize, rank: usize) -> Self; + fn alloc(module: &Module, infos: &A) -> Self + where + A: BlindRotationKeyInfos; } #[derive(PartialEq, Eq)] @@ -27,37 +29,51 @@ pub struct BlindRotationKeyPrepared pub(crate) _phantom: PhantomData, } -impl BlindRotationKeyPrepared { - #[allow(dead_code)] - pub(crate) fn n(&self) -> usize { - self.data[0].n() +impl BlindRotationKeyInfos for BlindRotationKeyPrepared { + fn n_glwe(&self) -> Degree { + self.n() } - #[allow(dead_code)] - pub(crate) fn rows(&self) -> usize { - self.data[0].rows() + fn n_lwe(&self) -> Degree { + Degree(self.data.len() as u32) + } +} + +impl LWEInfos for BlindRotationKeyPrepared { + fn base2k(&self) -> Base2K { + self.data[0].base2k() } - #[allow(dead_code)] - pub(crate) fn k(&self) -> usize { + fn k(&self) -> TorusPrecision { self.data[0].k() } - #[allow(dead_code)] - pub(crate) fn size(&self) -> usize { + fn n(&self) -> Degree { + self.data[0].n() + } + + fn size(&self) -> usize { self.data[0].size() } +} - #[allow(dead_code)] - pub(crate) fn rank(&self) -> usize { +impl GLWEInfos for BlindRotationKeyPrepared { + fn rank(&self) -> Rank { self.data[0].rank() } - - pub(crate) fn basek(&self) -> usize { - self.data[0].basek() +} +impl 
GGSWInfos for BlindRotationKeyPrepared { + fn digits(&self) -> poulpy_core::layouts::Digits { + Digits(1) } - pub(crate) fn block_size(&self) -> usize { + fn rows(&self) -> Rows { + self.data[0].rows() + } +} + +impl BlindRotationKeyPrepared { + pub fn block_size(&self) -> usize { match self.dist { Distribution::BinaryBlock(value) => value, _ => 1, @@ -72,14 +88,7 @@ where BlindRotationKeyPrepared, BRA, B>: Prepare>, { fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> BlindRotationKeyPrepared, BRA, B> { - let mut brk: BlindRotationKeyPrepared, BRA, B> = BlindRotationKeyPrepared::alloc( - module, - self.keys.len(), - self.basek(), - self.k(), - self.rows(), - self.rank(), - ); + let mut brk: BlindRotationKeyPrepared, BRA, B> = BlindRotationKeyPrepared::alloc(module, self); brk.prepare(module, self, scratch); brk } @@ -96,7 +105,7 @@ where assert_eq!(self.data.len(), other.keys.len()); } - let n: usize = other.n(); + let n: usize = other.n().as_usize(); self.data .iter_mut() diff --git a/poulpy-schemes/src/tfhe/blind_rotation/lut.rs b/poulpy-schemes/src/tfhe/blind_rotation/lut.rs index 4ff19f8..106bcff 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/lut.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/lut.rs @@ -15,29 +15,28 @@ pub enum LookUpTableRotationDirection { pub struct LookUpTable { pub(crate) data: Vec>>, pub(crate) rot_dir: LookUpTableRotationDirection, - pub(crate) basek: usize, + pub(crate) base2k: usize, pub(crate) k: usize, pub(crate) drift: usize, } impl LookUpTable { - pub fn alloc(module: &Module, basek: usize, k: usize, extension_factor: usize) -> Self { + pub fn alloc(module: &Module, base2k: usize, k: usize, extension_factor: usize) -> Self { #[cfg(debug_assertions)] { assert!( extension_factor & (extension_factor - 1) == 0, - "extension_factor must be a power of two but is: {}", - extension_factor + "extension_factor must be a power of two but is: {extension_factor}" ); } - let size: usize = k.div_ceil(basek); + let size: usize = k.div_ceil(base2k); let mut data: Vec>> = Vec::with_capacity(extension_factor); (0..extension_factor).for_each(|_| { data.push(VecZnx::alloc(module.n(), 1, size)); }); Self { data, - basek, + base2k, k, drift: 0, rot_dir: LookUpTableRotationDirection::Left, @@ -80,27 +79,27 @@ impl LookUpTable { { assert!(f.len() <= module.n()); - let basek: usize = self.basek; + let base2k: usize = self.base2k; let mut scratch: ScratchOwned = ScratchOwned::alloc(module.vec_znx_normalize_tmp_bytes()); // Get the number minimum limb to store the message modulus - let limbs: usize = k.div_ceil(basek); + let limbs: usize = k.div_ceil(base2k); #[cfg(debug_assertions)] { assert!(f.len() <= module.n()); assert!( - (max_bit_size(f) + (k % basek) as u32) < i64::BITS, - "overflow: max(|f|) << (k%basek) > i64::BITS" + (max_bit_size(f) + (k % base2k) as u32) < i64::BITS, + "overflow: max(|f|) << (k%base2k) > i64::BITS" ); assert!(limbs <= self.data[0].size()); } // Scaling factor let mut scale = 1; - if !k.is_multiple_of(basek) { - scale <<= basek - (k % basek); + if !k.is_multiple_of(base2k) { + scale <<= base2k - (k % base2k); } // #elements in lookup table @@ -109,7 +108,7 @@ impl LookUpTable { // If LUT size > TakeScalarZnx let domain_size: usize = self.domain_size(); - let size: usize = self.k.div_ceil(self.basek); + let size: usize = self.k.div_ceil(self.base2k); // Equivalent to AUTO([f(0), -f(n-1), -f(n-2), ..., -f(1)], -1) let mut lut_full: VecZnx> = VecZnx::alloc(domain_size, 1, size); @@ -140,7 +139,7 @@ impl LookUpTable { } 
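// Editor's note (illustrative, not part of the diff): a minimal sketch of the
// negacyclic LUT layout referenced by the "Equivalent to
// AUTO([f(0), -f(n-1), -f(n-2), ..., -f(1)], -1)" comment above. In
// Z[X]/(X^n + 1), rotating the LUT polynomial by -i brings f(i) into the
// constant coefficient, and wrapping past degree n-1 picks up a sign flip,
// hence the mirrored, negated tail. Names `f` and `n` are illustrative.
fn negacyclic_lut(f: &[i64]) -> Vec<i64> {
    let n = f.len();
    let mut lut = vec![0i64; n];
    lut[0] = f[0];
    for i in 1..n {
        // Coefficient at position i holds -f(n - i): X^{i} * X^{n-i} = X^n = -1.
        lut[i] = -f[n - i];
    }
    lut
}
// Example with n = 4: f = [f0, f1, f2, f3] becomes [f0, -f3, -f2, -f1].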
self.data.iter_mut().for_each(|a| { - module.vec_znx_normalize_inplace(self.basek, a, 0, scratch.borrow()); + module.vec_znx_normalize_inplace(self.base2k, a, 0, scratch.borrow()); }); self.rotate(module, -(drift as i64)); diff --git a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_blind_rotation.rs b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_blind_rotation.rs index 4304a7f..f2fc246 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_blind_rotation.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_blind_rotation.rs @@ -3,9 +3,9 @@ use poulpy_hal::{ ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAdd, VecZnxDftAddInplace, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftSubABInplace, VecZnxDftZero, VecZnxFillUniform, VecZnxIdftApply, + VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftSubInplace, VecZnxDftZero, VecZnxFillUniform, VecZnxIdftApply, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpBytes, VecZnxMulXpMinusOneInplace, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, VecZnxSub, VecZnxSubABInplace, + VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, ZnAddNormal, ZnFillUniform, ZnNormalizeInplace, }, @@ -18,12 +18,13 @@ use poulpy_hal::{ }; use crate::tfhe::blind_rotation::{ - BlincRotationExecute, BlindRotationKey, BlindRotationKeyAlloc, BlindRotationKeyEncryptSk, BlindRotationKeyPrepared, CGGI, - LookUpTable, cggi_blind_rotate_scratch_space, mod_switch_2n, + BlincRotationExecute, BlindRotationKey, BlindRotationKeyAlloc, BlindRotationKeyEncryptSk, BlindRotationKeyLayout, + BlindRotationKeyPrepared, CGGI, LookUpTable, cggi_blind_rotate_scratch_space, mod_switch_2n, }; use poulpy_core::layouts::{ - GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, LWECiphertext, LWECiphertextToRef, LWEPlaintext, LWESecret, + GLWECiphertext, GLWECiphertextLayout, GLWEPlaintext, GLWESecret, LWECiphertext, LWECiphertextLayout, LWECiphertextToRef, + LWEInfos, LWEPlaintext, LWESecret, prepared::{GLWESecretPrepared, PrepareAlloc}, }; @@ -41,11 +42,11 @@ where + VecZnxDftApply + VecZnxDftZero + SvpApplyDftToDft - + VecZnxDftSubABInplace + + VecZnxDftSubInplace + VecZnxBigAddSmallInplace + VecZnxRotate + VecZnxAddInplace - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxNormalize + VecZnxNormalizeInplace + VecZnxCopy @@ -83,16 +84,16 @@ where + TakeVecZnxImpl + TakeVecZnxSliceImpl, { - let n: usize = module.n(); - let basek: usize = 19; + let n_glwe: usize = module.n(); + let base2k: usize = 19; let k_lwe: usize = 24; - let k_brk: usize = 3 * basek; + let k_brk: usize = 3 * base2k; let rows_brk: usize = 2; // Ensures first limb is noise-free. 
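// Editor's note (illustrative, not part of the diff): how these test
// parameters relate. With base2k bits per limb, torus precision k occupies
// k.div_ceil(base2k) limbs, and a gadget key with `rows` rows decomposes
// rows * base2k bits of payload; the extra limb in k_brk = 3 * base2k
// (i.e. (rows_brk + 1) * base2k) absorbs rounding noise, which is what the
// "first limb is noise-free" comment refers to.
fn limb_count(k: usize, base2k: usize) -> usize {
    k.div_ceil(base2k)
}
// With base2k = 19: limb_count(3 * 19, 19) == 3 limbs for the key, while
// rows_brk = 2 rows cover 2 * 19 = 38 bits of message precision.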
- let k_lut: usize = basek; - let k_res: usize = 2 * basek; + let k_lut: usize = base2k; + let k_res: usize = 2 * base2k; let rank: usize = 1; - let log_message_modulus = 4; + let log_message_modulus: usize = 4; let message_modulus: usize = 1 << log_message_modulus; @@ -100,30 +101,48 @@ where let mut source_xe: Source = Source::new([2u8; 32]); let mut source_xa: Source = Source::new([1u8; 32]); + let brk_infos: BlindRotationKeyLayout = BlindRotationKeyLayout { + n_glwe: n_glwe.into(), + n_lwe: n_lwe.into(), + base2k: base2k.into(), + k: k_brk.into(), + rows: rows_brk.into(), + rank: rank.into(), + }; + + let glwe_infos: GLWECiphertextLayout = GLWECiphertextLayout { + n: n_glwe.into(), + base2k: base2k.into(), + k: k_res.into(), + rank: rank.into(), + }; + + let lwe_infos: LWECiphertextLayout = LWECiphertextLayout { + n: n_lwe.into(), + k: k_lwe.into(), + base2k: base2k.into(), + }; + let mut scratch: ScratchOwned = ScratchOwned::::alloc(BlindRotationKey::generate_from_sk_scratch_space( - module, basek, k_brk, rank, + module, &brk_infos, )); - let mut sk_glwe: GLWESecret> = GLWESecret::alloc(n, rank); + let mut sk_glwe: GLWESecret> = GLWESecret::alloc(&glwe_infos); sk_glwe.fill_ternary_prob(0.5, &mut source_xs); let sk_glwe_dft: GLWESecretPrepared, B> = sk_glwe.prepare_alloc(module, scratch.borrow()); - let mut sk_lwe: LWESecret> = LWESecret::alloc(n_lwe); + let mut sk_lwe: LWESecret> = LWESecret::alloc(n_lwe.into()); sk_lwe.fill_binary_block(block_size, &mut source_xs); let mut scratch_br: ScratchOwned = ScratchOwned::::alloc(cggi_blind_rotate_scratch_space( module, block_size, extension_factor, - basek, - k_res, - k_brk, - rows_brk, - rank, + &glwe_infos, + &brk_infos, )); - let mut brk: BlindRotationKey, CGGI> = - BlindRotationKey::, CGGI>::alloc(n, n_lwe, basek, k_brk, rows_brk, rank); + let mut brk: BlindRotationKey, CGGI> = BlindRotationKey::, CGGI>::alloc(&brk_infos); brk.encrypt_sk( module, @@ -134,13 +153,13 @@ where scratch.borrow(), ); - let mut lwe: LWECiphertext> = LWECiphertext::alloc(n_lwe, basek, k_lwe); + let mut lwe: LWECiphertext> = LWECiphertext::alloc(&lwe_infos); - let mut pt_lwe: LWEPlaintext> = LWEPlaintext::alloc(basek, k_lwe); + let mut pt_lwe: LWEPlaintext> = LWEPlaintext::alloc(&lwe_infos); let x: i64 = 15 % (message_modulus as i64); - pt_lwe.encode_i64(x, log_message_modulus + 1); + pt_lwe.encode_i64(x, (log_message_modulus + 1).into()); lwe.encrypt_sk(module, &pt_lwe, &sk_lwe, &mut source_xa, &mut source_xe); @@ -152,20 +171,20 @@ where .enumerate() .for_each(|(i, x)| *x = f(i as i64)); - let mut lut: LookUpTable = LookUpTable::alloc(module, basek, k_lut, extension_factor); + let mut lut: LookUpTable = LookUpTable::alloc(module, base2k, k_lut, extension_factor); lut.set(module, &f_vec, log_message_modulus + 1); - let mut res: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_res, rank); + let mut res: GLWECiphertext> = GLWECiphertext::alloc(&glwe_infos); let brk_prepared: BlindRotationKeyPrepared, CGGI, B> = brk.prepare_alloc(module, scratch.borrow()); brk_prepared.execute(module, &mut res, &lwe, &lut, scratch_br.borrow()); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_res); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(&glwe_infos); res.decrypt(module, &mut pt_have, &sk_glwe_dft, scratch.borrow()); - let mut lwe_2n: Vec = vec![0i64; lwe.n() + 1]; // TODO: from scratch space + let mut lwe_2n: Vec = vec![0i64; (lwe.n() + 1).into()]; // TODO: from scratch space mod_switch_2n( 2 * lut.domain_size(), @@ -189,7 +208,7 @@ where 
assert_eq!(pt_have.data.at(0, 0), lut.data[0].at(0, 0)); // Verify that it effectively computes f(x) - let mut have: i64 = pt_have.decode_coeff_i64(log_message_modulus + 1, 0); + let mut have: i64 = pt_have.decode_coeff_i64((log_message_modulus + 1).into(), 0); // Get positive representative and assert equality have = (have + message_modulus as i64) % (message_modulus as i64); diff --git a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_lut.rs b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_lut.rs index a57d493..20adc6d 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_lut.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_lut.rs @@ -21,19 +21,19 @@ where + VecZnxRotateInplaceTmpBytes, B: Backend + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, { - let basek: usize = 20; + let base2k: usize = 20; let k_lut: usize = 40; let message_modulus: usize = 16; let extension_factor: usize = 1; - let log_scale: usize = basek + 1; + let log_scale: usize = base2k + 1; let mut f: Vec = vec![0i64; message_modulus]; f.iter_mut() .enumerate() .for_each(|(i, x)| *x = (i as i64) - 8); - let mut lut: LookUpTable = LookUpTable::alloc(module, basek, k_lut, extension_factor); + let mut lut: LookUpTable = LookUpTable::alloc(module, base2k, k_lut, extension_factor); lut.set(module, &f, log_scale); let half_step: i64 = lut.domain_size().div_round(message_modulus << 1) as i64; @@ -42,7 +42,7 @@ where let step: usize = lut.domain_size().div_round(message_modulus); let mut lut_dec: Vec = vec![0i64; module.n()]; - lut.data[0].decode_vec_i64(basek, 0, log_scale, &mut lut_dec); + lut.data[0].decode_vec_i64(base2k, 0, log_scale, &mut lut_dec); (0..lut.domain_size()).step_by(step).for_each(|i| { (0..step).for_each(|_| { @@ -61,19 +61,19 @@ where + VecZnxRotateInplaceTmpBytes, B: Backend + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, { - let basek: usize = 20; + let base2k: usize = 20; let k_lut: usize = 40; let message_modulus: usize = 16; let extension_factor: usize = 4; - let log_scale: usize = basek + 1; + let log_scale: usize = base2k + 1; let mut f: Vec = vec![0i64; message_modulus]; f.iter_mut() .enumerate() .for_each(|(i, x)| *x = (i as i64) - 8); - let mut lut: LookUpTable = LookUpTable::alloc(module, basek, k_lut, extension_factor); + let mut lut: LookUpTable = LookUpTable::alloc(module, base2k, k_lut, extension_factor); lut.set(module, &f, log_scale); let half_step: i64 = lut.domain_size().div_round(message_modulus << 1) as i64; @@ -84,7 +84,7 @@ where let mut lut_dec: Vec = vec![0i64; module.n()]; (0..extension_factor).for_each(|ext| { - lut.data[ext].decode_vec_i64(basek, 0, log_scale, &mut lut_dec); + lut.data[ext].decode_vec_i64(base2k, 0, log_scale, &mut lut_dec); (0..module.n()).step_by(step).for_each(|i| { (0..step).for_each(|_| { assert_eq!(f[i / step] % message_modulus as i64, lut_dec[i]); diff --git a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_serialization.rs b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_serialization.rs index 15c6fdb..b56042a 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_serialization.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_serialization.rs @@ -1,15 +1,35 @@ use poulpy_hal::test_suite::serialization::test_reader_writer_interface; -use crate::tfhe::blind_rotation::{BlindRotationKey, BlindRotationKeyAlloc, BlindRotationKeyCompressed, CGGI}; +use crate::tfhe::blind_rotation::{ + BlindRotationKey, BlindRotationKeyAlloc, BlindRotationKeyCompressed, BlindRotationKeyLayout, CGGI, +};
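// Editor's note (illustrative, not part of the diff): the layout refactor
// relies on newtype wrappers such as `Degree`, `Base2K`, `TorusPrecision`,
// `Rows` and `Rank`. The `256_u32.into()` / `64_usize.into()` calls in the
// tests below assume `From<u32>` / `From<usize>` impls along the lines of
// this hypothetical sketch (`DegreeLike` is not a real poulpy type):
#[derive(PartialEq, Eq, Copy, Clone, Debug)]
struct DegreeLike(u32);

impl From<u32> for DegreeLike {
    fn from(v: u32) -> Self {
        DegreeLike(v)
    }
}

impl From<usize> for DegreeLike {
    fn from(v: usize) -> Self {
        // Truncating cast is fine for a sketch; real code may want a checked cast.
        DegreeLike(v as u32)
    }
}
// This is why one `BlindRotationKeyLayout` literal can replace the old
// six-argument `alloc(256, 64, 12, 54, 2, 2)` call without losing type
// safety: each field now states which parameter it is.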
#[test] fn test_cggi_blind_rotation_key_serialization() { - let original: BlindRotationKey, CGGI> = BlindRotationKey::alloc(256, 64, 12, 54, 2, 2); + let layout: BlindRotationKeyLayout = BlindRotationKeyLayout { + n_glwe: 256_u32.into(), + n_lwe: 64_usize.into(), + base2k: 12_usize.into(), + k: 54_usize.into(), + rows: 2_usize.into(), + rank: 2_usize.into(), + }; + + let original: BlindRotationKey, CGGI> = BlindRotationKey::alloc(&layout); test_reader_writer_interface(original); } #[test] fn test_cggi_blind_rotation_key_compressed_serialization() { - let original: BlindRotationKeyCompressed, CGGI> = BlindRotationKeyCompressed::alloc(256, 64, 12, 54, 2, 2); + let layout: BlindRotationKeyLayout = BlindRotationKeyLayout { + n_glwe: 256_u32.into(), + n_lwe: 64_usize.into(), + base2k: 12_usize.into(), + k: 54_usize.into(), + rows: 2_usize.into(), + rank: 2_usize.into(), + }; + + let original: BlindRotationKeyCompressed, CGGI> = BlindRotationKeyCompressed::alloc(&layout); test_reader_writer_interface(original); } diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs index 5497700..d462ff7 100644 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs @@ -4,17 +4,20 @@ use poulpy_hal::{ api::{ ScratchAvailable, TakeMatZnx, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, TakeVecZnxDftSlice, TakeVecZnxSlice, VecZnxAddInplace, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigAutomorphismInplace, - VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace, VecZnxCopy, VecZnxDftAddInplace, + VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallNegateInplace, VecZnxCopy, VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNegateInplace, - VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, - VecZnxRshInplace, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, - VmpApplyDftToDftTmpBytes, + VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, + VecZnxRotateInplaceTmpBytes, VecZnxRshInplace, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, + VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataRef, Module, Scratch, ToOwnedDeep}, oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl}, }; -use poulpy_core::{GLWEOperations, TakeGGLWE, TakeGLWECt, layouts::Infos}; +use poulpy_core::{ + GLWEOperations, TakeGGLWE, TakeGLWECt, + layouts::{Digits, GGLWECiphertextLayout, GGSWInfos, GLWEInfos, LWEInfos}, +}; use poulpy_core::layouts::{GGSWCiphertext, GLWECiphertext, LWECiphertext, prepared::GGLWEAutomorphismKeyPrepared}; @@ -39,7 +42,7 @@ where + VecZnxAddInplace + VecZnxNegateInplace + VecZnxCopy - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes @@ -50,11 +53,12 @@ where + VecZnxBigAddSmallInplace + VecZnxBigNormalize + VecZnxAutomorphismInplace - + VecZnxBigSubSmallBInplace + + VecZnxBigSubSmallNegateInplace + VecZnxRotateInplaceTmpBytes + VecZnxBigAllocBytes + VecZnxDftAddInplace - + VecZnxRotate, + + VecZnxRotate + + VecZnxNormalize, B: Backend + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, Scratch: TakeVecZnx + TakeVecZnxDftSlice @@ -138,7 +142,7 @@ pub fn circuit_bootstrap_core( + 
VecZnxAddInplace + VecZnxNegateInplace + VecZnxCopy - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes @@ -149,11 +153,12 @@ pub fn circuit_bootstrap_core( + VecZnxBigAddSmallInplace + VecZnxBigNormalize + VecZnxAutomorphismInplace - + VecZnxBigSubSmallBInplace + + VecZnxBigSubSmallNegateInplace + VecZnxBigAllocBytes + VecZnxDftAddInplace + VecZnxRotateInplaceTmpBytes - + VecZnxRotate, + + VecZnxRotate + + VecZnxNormalize, B: Backend + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, Scratch: TakeVecZnxDftSlice + TakeVecZnxBig @@ -166,16 +171,18 @@ pub fn circuit_bootstrap_core( { #[cfg(debug_assertions)] { + use poulpy_core::layouts::LWEInfos; + assert_eq!(res.n(), key.brk.n()); - assert_eq!(lwe.basek(), key.brk.basek()); - assert_eq!(res.basek(), key.brk.basek()); + assert_eq!(lwe.base2k(), key.brk.base2k()); + assert_eq!(res.base2k(), key.brk.base2k()); } - let n: usize = res.n(); - let basek: usize = res.basek(); - let rows: usize = res.rows(); - let rank: usize = res.rank(); - let k: usize = res.k(); + let n: usize = res.n().into(); + let base2k: usize = res.base2k().into(); + let rows: usize = res.rows().into(); + let rank: usize = res.rank().into(); + let k: usize = res.k().into(); let alpha: usize = rows.next_power_of_two(); @@ -183,27 +190,38 @@ pub fn circuit_bootstrap_core( if to_exponent { (0..rows).for_each(|i| { - f[i] = 1 << (basek * (rows - 1 - i)); + f[i] = 1 << (base2k * (rows - 1 - i)); }); } else { (0..1 << log_domain).for_each(|j| { (0..rows).for_each(|i| { - f[j * alpha + i] = j as i64 * (1 << (basek * (rows - 1 - i))); + f[j * alpha + i] = j as i64 * (1 << (base2k * (rows - 1 - i))); }); }); } // Lut precision, basically must be able to hold the decomposition power basis of the GGSW - let mut lut: LookUpTable = LookUpTable::alloc(module, basek, basek * rows, extension_factor); - lut.set(module, &f, basek * rows); + let mut lut: LookUpTable = LookUpTable::alloc(module, base2k, base2k * rows, extension_factor); + lut.set(module, &f, base2k * rows); if to_exponent { lut.set_rotation_direction(LookUpTableRotationDirection::Right); } // TODO: separate GGSW k from output of blind rotation k - let (mut res_glwe, scratch_1) = scratch.take_glwe_ct(n, basek, k, rank); - let (mut tmp_gglwe, scratch_2) = scratch_1.take_gglwe(n, basek, k, rows, 1, rank.max(1), rank); + let (mut res_glwe, scratch_1) = scratch.take_glwe_ct(res); + + let gglwe_infos: GGLWECiphertextLayout = GGLWECiphertextLayout { + n: n.into(), + base2k: base2k.into(), + k: k.into(), + rows: rows.into(), + digits: Digits(1), + rank_in: rank.max(1).into(), + rank_out: rank.into(), + }; + + let (mut tmp_gglwe, scratch_2) = scratch_1.take_gglwe(&gglwe_infos); key.brk.execute(module, &mut res_glwe, lwe, &lut, scratch_2); @@ -264,7 +282,7 @@ fn post_process( + VecZnxAddInplace + VecZnxNegateInplace + VecZnxCopy - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes @@ -275,8 +293,9 @@ fn post_process( + VecZnxBigAddSmallInplace + VecZnxBigNormalize + VecZnxAutomorphismInplace - + VecZnxBigSubSmallBInplace - + VecZnxRotate, + + VecZnxBigSubSmallNegateInplace + + VecZnxRotate + + VecZnxNormalize, Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { let log_n: usize = module.log_n(); @@ -336,7 +355,7 @@ pub fn pack( + VecZnxAddInplace + VecZnxNegateInplace + VecZnxCopy - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + 
VecZnxBigNormalizeTmpBytes @@ -347,16 +366,13 @@ pub fn pack( + VecZnxBigAddSmallInplace + VecZnxBigNormalize + VecZnxAutomorphismInplace - + VecZnxBigSubSmallBInplace - + VecZnxRotate, + + VecZnxBigSubSmallNegateInplace + + VecZnxRotate + + VecZnxNormalize, Scratch: TakeVecZnx + TakeVecZnxDft + ScratchAvailable, { let log_n: usize = module.log_n(); - let basek: usize = cts.get(&0).unwrap().basek(); - let k: usize = cts.get(&0).unwrap().k(); - let rank: usize = cts.get(&0).unwrap().rank(); - (0..log_n - log_gap_out).for_each(|i| { let t: usize = 16.min(1 << (log_n - 1 - i)); @@ -370,17 +386,7 @@ pub fn pack( let mut a: Option> = cts.remove(&j); let mut b: Option> = cts.remove(&(j + t)); - combine( - module, - basek, - k, - rank, - a.as_mut(), - b.as_mut(), - i, - auto_key, - scratch, - ); + combine(module, a.as_mut(), b.as_mut(), i, auto_key, scratch); if let Some(a) = a { cts.insert(j, a); @@ -394,9 +400,6 @@ pub fn pack( #[allow(clippy::too_many_arguments)] fn combine( module: &Module, - basek: usize, - k: usize, - rank: usize, a: Option<&mut GLWECiphertext>, b: Option<&mut GLWECiphertext>, i: usize, @@ -415,7 +418,7 @@ fn combine( + VecZnxAddInplace + VecZnxNegateInplace + VecZnxCopy - + VecZnxSubABInplace + + VecZnxSubInplace + VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes @@ -426,8 +429,9 @@ fn combine( + VecZnxBigAddSmallInplace + VecZnxBigNormalize + VecZnxAutomorphismInplace - + VecZnxBigSubSmallBInplace - + VecZnxRotate, + + VecZnxBigSubSmallNegateInplace + + VecZnxRotate + + VecZnxNormalize, Scratch: TakeVecZnx + TakeVecZnxDft + ScratchAvailable, { // Goal is to evaluate: a = a + b*X^t + phi(a - b*X^t)) @@ -441,12 +445,10 @@ fn combine( // either mapped to garbage or twice their value which vanishes I(X) // since 2*(I(X) * Q/2) = I(X) * Q = 0 mod Q. 
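// Editor's note (illustrative, not from the diff): a small worked instance of
// the cancellation mechanism described in the comment above. Take n = 4 and
// the automorphism phi_5: X -> X^5 ≡ -X in Z[X]/(X^4 + 1), so
// phi_5(X^j) = (-1)^j X^j. Then a + phi_5(a) doubles the even-indexed
// coefficients and cancels the odd ones, which is the "twice their value" /
// garbage-elimination effect; the rsh-by-1 in combine() compensates the
// doubling. Names below are illustrative only.
fn phi5_n4(a: &[i64; 4]) -> [i64; 4] {
    // phi_5 on coefficient vectors: negate the odd-degree coefficients.
    [a[0], -a[1], a[2], -a[3]]
}

fn pack_step_n4(a: &[i64; 4]) -> [i64; 4] {
    let p = phi5_n4(a);
    [a[0] + p[0], a[1] + p[1], a[2] + p[2], a[3] + p[3]]
}
// pack_step_n4(&[1, 2, 3, 4]) == [2, 0, 6, 0]: even slots doubled, odd slots
// zeroed out.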
if let Some(a) = a { - let n: usize = a.n(); - let log_n: usize = (u64::BITS - (n - 1).leading_zeros()) as _; - let t: i64 = 1 << (log_n - i - 1); + let t: i64 = 1 << (a.n().log2() - i - 1); if let Some(b) = b { - let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(n, basek, k, rank); + let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(a); // a = a * X^-t a.rotate_inplace(module, -t, scratch_1); @@ -478,15 +480,13 @@ fn combine( a.automorphism_add_inplace(module, auto_key, scratch); } } else if let Some(b) = b { - let n: usize = b.n(); - let log_n: usize = (u64::BITS - (n - 1).leading_zeros()) as _; - let t: i64 = 1 << (log_n - i - 1); + let t: i64 = 1 << (b.n().log2() - i - 1); - let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(n, basek, k, rank); + let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(b); tmp_b.rotate(module, t, b); tmp_b.rsh(module, 1, scratch_1); // a = (b* X^t - phi(b* X^t)) - b.automorphism_sub_ba(module, &tmp_b, auto_key, scratch_1); + b.automorphism_sub_negate(module, &tmp_b, auto_key, scratch_1); } } diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/key.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/key.rs index 69a5586..b9ed6b9 100644 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/key.rs +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/key.rs @@ -1,5 +1,6 @@ use poulpy_core::layouts::{ - GGLWEAutomorphismKey, GGLWETensorKey, GLWECiphertext, GLWESecret, LWESecret, + GGLWEAutomorphismKey, GGLWEAutomorphismKeyLayout, GGLWELayoutInfos, GGLWETensorKey, GGLWETensorKeyLayout, GGSWInfos, + GLWECiphertext, GLWEInfos, GLWESecret, LWEInfos, LWESecret, prepared::{GGLWEAutomorphismKeyPrepared, GGLWETensorKeyPrepared, GLWESecretPrepared, PrepareAlloc}, }; use std::collections::HashMap; @@ -9,7 +10,7 @@ use poulpy_hal::{ ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeSvpPPol, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, - VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, + VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, VmpPMatAlloc, VmpPrepare, }, layouts::{Backend, Data, DataRef, Module, Scratch}, @@ -17,27 +18,49 @@ use poulpy_hal::{ }; use crate::tfhe::blind_rotation::{ - BlindRotationAlgo, BlindRotationKey, BlindRotationKeyAlloc, BlindRotationKeyEncryptSk, BlindRotationKeyPrepared, + BlindRotationAlgo, BlindRotationKey, BlindRotationKeyAlloc, BlindRotationKeyEncryptSk, BlindRotationKeyInfos, + BlindRotationKeyLayout, BlindRotationKeyPrepared, }; +pub trait CircuitBootstrappingKeyInfos { + fn layout_brk(&self) -> BlindRotationKeyLayout; + fn layout_atk(&self) -> GGLWEAutomorphismKeyLayout; + fn layout_tsk(&self) -> GGLWETensorKeyLayout; +} + +pub struct CircuitBootstrappingKeyLayout { + pub layout_brk: BlindRotationKeyLayout, + pub layout_atk: GGLWEAutomorphismKeyLayout, + pub layout_tsk: GGLWETensorKeyLayout, +} + +impl CircuitBootstrappingKeyInfos for CircuitBootstrappingKeyLayout { + fn layout_atk(&self) -> GGLWEAutomorphismKeyLayout { + self.layout_atk + } + + fn layout_brk(&self) -> BlindRotationKeyLayout { + self.layout_brk + } + + fn layout_tsk(&self) -> GGLWETensorKeyLayout { + self.layout_tsk + } +} + pub trait CircuitBootstrappingKeyEncryptSk { 
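// Editor's note (illustrative call site, mirroring the updated example in
// poulpy-schemes/examples/circuit_bootstrapping.rs): the six positional size
// parameters of the old signature are now carried by one layout struct.
// `cbt_infos`, `sk_lwe`, `sk_glwe`, `source_xa`, `source_xe` and `scratch`
// are assumed to be set up as in that example.
//
// let cbt_key: CircuitBootstrappingKey<Vec<u8>, CGGI> =
//     CircuitBootstrappingKey::encrypt_sk(
//         &module,
//         &sk_lwe,
//         &sk_glwe,
//         &cbt_infos, // bundles the brk, atk and tsk layouts
//         &mut source_xa,
//         &mut source_xe,
//         scratch.borrow(),
//     );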
diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/key.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/key.rs
index 69a5586..b9ed6b9 100644
--- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/key.rs
+++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/key.rs
@@ -1,5 +1,6 @@
 use poulpy_core::layouts::{
-    GGLWEAutomorphismKey, GGLWETensorKey, GLWECiphertext, GLWESecret, LWESecret,
+    GGLWEAutomorphismKey, GGLWEAutomorphismKeyLayout, GGLWELayoutInfos, GGLWETensorKey, GGLWETensorKeyLayout, GGSWInfos,
+    GLWECiphertext, GLWEInfos, GLWESecret, LWEInfos, LWESecret,
     prepared::{GGLWEAutomorphismKeyPrepared, GGLWETensorKeyPrepared, GLWESecretPrepared, PrepareAlloc},
 };
 use std::collections::HashMap;
@@ -9,7 +10,7 @@ use poulpy_hal::{
         ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx,
         TakeSvpPPol, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace,
         VecZnxAutomorphism, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume,
-        VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace,
+        VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace,
         VecZnxSwitchRing, VmpPMatAlloc, VmpPrepare,
     },
     layouts::{Backend, Data, DataRef, Module, Scratch},
@@ -17,27 +18,49 @@ use poulpy_hal::{
 };

 use crate::tfhe::blind_rotation::{
-    BlindRotationAlgo, BlindRotationKey, BlindRotationKeyAlloc, BlindRotationKeyEncryptSk, BlindRotationKeyPrepared,
+    BlindRotationAlgo, BlindRotationKey, BlindRotationKeyAlloc, BlindRotationKeyEncryptSk, BlindRotationKeyInfos,
+    BlindRotationKeyLayout, BlindRotationKeyPrepared,
 };

+pub trait CircuitBootstrappingKeyInfos {
+    fn layout_brk(&self) -> BlindRotationKeyLayout;
+    fn layout_atk(&self) -> GGLWEAutomorphismKeyLayout;
+    fn layout_tsk(&self) -> GGLWETensorKeyLayout;
+}
+
+pub struct CircuitBootstrappingKeyLayout {
+    pub layout_brk: BlindRotationKeyLayout,
+    pub layout_atk: GGLWEAutomorphismKeyLayout,
+    pub layout_tsk: GGLWETensorKeyLayout,
+}
+
+impl CircuitBootstrappingKeyInfos for CircuitBootstrappingKeyLayout {
+    fn layout_atk(&self) -> GGLWEAutomorphismKeyLayout {
+        self.layout_atk
+    }
+
+    fn layout_brk(&self) -> BlindRotationKeyLayout {
+        self.layout_brk
+    }
+
+    fn layout_tsk(&self) -> GGLWETensorKeyLayout {
+        self.layout_tsk
+    }
+}
+
 pub trait CircuitBootstrappingKeyEncryptSk<B: Backend> {
     #[allow(clippy::too_many_arguments)]
-    fn encrypt_sk<DLwe, DGlwe>(
+    fn encrypt_sk<DLwe, DGlwe, INFOS>(
         module: &Module<B>,
-        basek: usize,
         sk_lwe: &LWESecret<DLwe>,
         sk_glwe: &GLWESecret<DGlwe>,
-        k_brk: usize,
-        rows_brk: usize,
-        k_trace: usize,
-        rows_trace: usize,
-        k_tsk: usize,
-        rows_tsk: usize,
+        cbt_infos: &INFOS,
         source_xa: &mut Source,
         source_xe: &mut Source,
         scratch: &mut Scratch<B>,
     ) -> Self
     where
+        INFOS: CircuitBootstrappingKeyInfos,
         DLwe: DataRef,
         DGlwe: DataRef;
 }
@@ -61,7 +84,7 @@ where
         + VecZnxIdftApplyConsume
         + VecZnxNormalizeTmpBytes
         + VecZnxFillUniform
-        + VecZnxSubABInplace
+        + VecZnxSubInplace
         + VecZnxAddInplace
         + VecZnxNormalizeInplace
         + VecZnxAddNormal
@@ -74,46 +97,41 @@ where
         + VecZnxAutomorphism,
     Scratch<B>: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeSvpPPol + TakeVecZnxBig,
 {
-    fn encrypt_sk<DLwe, DGlwe>(
+    fn encrypt_sk<DLwe, DGlwe, INFOS>(
         module: &Module<B>,
-        basek: usize,
         sk_lwe: &LWESecret<DLwe>,
         sk_glwe: &GLWESecret<DGlwe>,
-        k_brk: usize,
-        rows_brk: usize,
-        k_trace: usize,
-        rows_trace: usize,
-        k_tsk: usize,
-        rows_tsk: usize,
+        cbt_infos: &INFOS,
         source_xa: &mut Source,
         source_xe: &mut Source,
         scratch: &mut Scratch<B>,
     ) -> Self
     where
+        INFOS: CircuitBootstrappingKeyInfos,
         DLwe: DataRef,
         DGlwe: DataRef,
         Module<B>:,
     {
+        assert_eq!(sk_lwe.n(), cbt_infos.layout_brk().n_lwe());
+        assert_eq!(sk_glwe.n(), cbt_infos.layout_brk().n_glwe());
+        assert_eq!(sk_glwe.n(), cbt_infos.layout_atk().n());
+        assert_eq!(sk_glwe.n(), cbt_infos.layout_tsk().n());
+
+        let atk_infos: GGLWEAutomorphismKeyLayout = cbt_infos.layout_atk();
+        let brk_infos: BlindRotationKeyLayout = cbt_infos.layout_brk();
+        let trk_infos: GGLWETensorKeyLayout = cbt_infos.layout_tsk();
+
         let mut auto_keys: HashMap<i64, GGLWEAutomorphismKey<Vec<u8>>> = HashMap::new();
         let gal_els: Vec<i64> = GLWECiphertext::trace_galois_elements(module);
         gal_els.iter().for_each(|gal_el| {
-            let mut key: GGLWEAutomorphismKey<Vec<u8>> =
-                GGLWEAutomorphismKey::alloc(sk_glwe.n(), basek, k_trace, rows_trace, 1, sk_glwe.rank());
+            let mut key: GGLWEAutomorphismKey<Vec<u8>> = GGLWEAutomorphismKey::alloc(&atk_infos);
             key.encrypt_sk(module, *gal_el, sk_glwe, source_xa, source_xe, scratch);
             auto_keys.insert(*gal_el, key);
         });

         let sk_glwe_prepared: GLWESecretPrepared<Vec<u8>, B> = sk_glwe.prepare_alloc(module, scratch);

-        let mut brk: BlindRotationKey<Vec<u8>, BRA> = BlindRotationKey::<Vec<u8>, BRA>::alloc(
-            sk_glwe.n(),
-            sk_lwe.n(),
-            basek,
-            k_brk,
-            rows_brk,
-            sk_glwe.rank(),
-        );
-
+        let mut brk: BlindRotationKey<Vec<u8>, BRA> = BlindRotationKey::<Vec<u8>, BRA>::alloc(&brk_infos);
         brk.encrypt_sk(
             module,
             &sk_glwe_prepared,
@@ -123,7 +141,7 @@ where
             scratch,
         );

-        let mut tsk: GGLWETensorKey<Vec<u8>> = GGLWETensorKey::alloc(sk_glwe.n(), basek, k_tsk, rows_tsk, 1, sk_glwe.rank());
+        let mut tsk: GGLWETensorKey<Vec<u8>> = GGLWETensorKey::alloc(&trk_infos);
         tsk.encrypt_sk(module, sk_glwe, source_xa, source_xe, scratch);

         Self {
@@ -140,6 +158,42 @@ pub struct CircuitBootstrappingKeyPrepared
     pub(crate) atk: HashMap<i64, GGLWEAutomorphismKeyPrepared<Vec<u8>, B>>,
 }

+impl<D: Data, BRA: BlindRotationAlgo, B: Backend> CircuitBootstrappingKeyInfos for CircuitBootstrappingKeyPrepared<D, BRA, B> {
+    fn layout_atk(&self) -> GGLWEAutomorphismKeyLayout {
+        let (_, atk) = self.atk.iter().next().expect("atk is empty");
+        GGLWEAutomorphismKeyLayout {
+            n: atk.n(),
+            base2k: atk.base2k(),
+            k: atk.k(),
+            rows: atk.rows(),
+            digits: atk.digits(),
+            rank: atk.rank(),
+        }
+    }
+
+    fn layout_brk(&self) -> BlindRotationKeyLayout {
+        BlindRotationKeyLayout {
+            n_glwe: self.brk.n_glwe(),
+            n_lwe: self.brk.n_lwe(),
+            base2k: self.brk.base2k(),
+            k: self.brk.k(),
+            rows: self.brk.rows(),
+            rank: self.brk.rank(),
+        }
+    }
+
+    fn layout_tsk(&self) -> GGLWETensorKeyLayout {
+        GGLWETensorKeyLayout {
+            n: self.tsk.n(),
+            base2k: self.tsk.base2k(),
+            k: self.tsk.k(),
+            rows: self.tsk.rows(),
+            digits: self.tsk.digits(),
+            rank: self.tsk.rank(),
+        }
+    }
+}
+
 impl<D: DataRef, BRA: BlindRotationAlgo, B: Backend> PrepareAlloc<B, CircuitBootstrappingKeyPrepared<Vec<u8>, BRA, B>> for CircuitBootstrappingKey<D, BRA>
 where
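The recurring move in this patch is to replace loose (basek, k, rows, rank, ...) arguments with layout structs read back through an *Infos trait, so that a key and its declared layout are interchangeable as parameter sources. A standalone distillation of the pattern (toy names, not the poulpy-core definitions):

    #[derive(Clone, Copy, Debug)]
    struct KeyLayout {
        n: usize,      // ring degree
        base2k: usize, // limb size in bits (base-2^base2k decomposition)
        k: usize,      // total precision in bits
        rows: usize,   // gadget rows
        rank: usize,   // GLWE rank
    }

    trait KeyInfos {
        fn layout(&self) -> KeyLayout;

        // Number of base-2^base2k limbs needed to cover k bits.
        fn limbs(&self) -> usize {
            let l = self.layout();
            l.k.div_ceil(l.base2k)
        }
    }

    impl KeyInfos for KeyLayout {
        fn layout(&self) -> KeyLayout {
            *self
        }
    }

    // Any KeyInfos holder can drive allocation: one i64 coefficient vector of
    // length n per limb, per row, per (rank + 1) column.
    fn alloc_key_words<I: KeyInfos>(infos: &I) -> Vec<i64> {
        let l = infos.layout();
        vec![0i64; l.n * infos.limbs() * l.rows * (l.rank + 1)]
    }

    fn main() {
        let base2k = 17usize;
        let layout = KeyLayout { n: 1024, base2k, k: 5 * base2k, rows: 4, rank: 1 };
        assert_eq!(layout.limbs(), 5); // k = 5 * base2k => exactly 5 limbs
        let key = alloc_key_words(&layout);
        println!("{} words for {:?}", key.len(), layout);
    }

The payoff is visible in `encrypt_sk` above: dimension consistency becomes a set of up-front `assert_eq!`s against one source of truth, instead of being implied by a long positional argument list.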
diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/circuit_bootstrapping.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/circuit_bootstrapping.rs
index 8e543eb..7094255 100644
--- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/circuit_bootstrapping.rs
+++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/circuit_bootstrapping.rs
@@ -5,10 +5,10 @@ use poulpy_hal::{
         ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes,
         SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace,
         VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigAutomorphismInplace,
-        VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace, VecZnxCopy, VecZnxDftAddInplace,
+        VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallNegateInplace, VecZnxCopy, VecZnxDftAddInplace,
         VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy, VecZnxFillUniform, VecZnxIdftApplyConsume,
         VecZnxIdftApplyTmpA, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate,
-        VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, VecZnxRshInplace, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing,
+        VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, VecZnxRshInplace, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing,
         VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, ZnAddNormal, ZnFillUniform,
         ZnNormalizeInplace,
     },
@@ -23,14 +23,17 @@ use poulpy_hal::{
 use crate::tfhe::{
     blind_rotation::{
         BlincRotationExecute, BlindRotationAlgo, BlindRotationKey, BlindRotationKeyAlloc, BlindRotationKeyEncryptSk,
-        BlindRotationKeyPrepared,
+        BlindRotationKeyLayout, BlindRotationKeyPrepared,
     },
     circuit_bootstrapping::{
-        CircuitBootstrappingKey, CircuitBootstrappingKeyEncryptSk, CircuitBootstrappingKeyPrepared, CirtuitBootstrappingExecute,
+        CircuitBootstrappingKey, CircuitBootstrappingKeyEncryptSk, CircuitBootstrappingKeyLayout,
+        CircuitBootstrappingKeyPrepared, CirtuitBootstrappingExecute,
     },
 };

-use poulpy_core::layouts::prepared::PrepareAlloc;
+use poulpy_core::layouts::{
+    Digits, GGLWEAutomorphismKeyLayout, GGLWETensorKeyLayout, GGSWCiphertextLayout, LWECiphertextLayout, prepared::PrepareAlloc,
+};

 use poulpy_core::layouts::{
     GGSWCiphertext, GLWECiphertext, GLWEPlaintext, GLWESecret, LWECiphertext, LWEPlaintext, LWESecret,
@@ -48,7 +51,7 @@ where
         + SvpApplyDftToDftInplace
         + VecZnxIdftApplyConsume
         + VecZnxNormalizeTmpBytes
-        + VecZnxSubABInplace
+        + VecZnxSubInplace
         + VecZnxAddInplace
         + VecZnxNormalize
         + VecZnxSub
@@ -78,7 +81,7 @@ where
         + VecZnxNegateInplace
         + VecZnxCopy
         + VecZnxAutomorphismInplace
-        + VecZnxBigSubSmallBInplace
+        + VecZnxBigSubSmallNegateInplace
         + VecZnxRotateInplaceTmpBytes
         + VecZnxBigAllocBytes
         + VecZnxDftAddInplace
@@ -102,8 +105,8 @@ where
     BlindRotationKeyPrepared<Vec<u8>, BRA, B>: BlincRotationExecute<B>,
     BlindRotationKey<Vec<u8>, BRA>: BlindRotationKeyAlloc + BlindRotationKeyEncryptSk<B>,
 {
-    let n: usize = module.n();
-    let basek: usize = 17;
+    let n_glwe: usize = module.n();
+    let base2k: usize = 17;

     let extension_factor: usize = 1;
     let rank: usize = 1;
@@ -112,61 +115,97 @@ where
     let k_lwe_ct: usize = 22;
     let block_size: usize = 7;

-    let k_brk: usize = 5 * basek;
+    let k_brk: usize = 5 * base2k;
     let rows_brk: usize = 4;

-    let k_trace: usize = 5 * basek;
-    let rows_trace: usize = 4;
+    let k_atk: usize = 5 * base2k;
+    let rows_atk: usize = 4;

-    let k_tsk: usize = 5 * basek;
+    let k_tsk: usize = 5 * base2k;
     let rows_tsk: usize = 4;

+    let k_ggsw_res: usize = 4 * base2k;
+    let rows_ggsw_res: usize = 2;
+
+    let lwe_infos: LWECiphertextLayout = LWECiphertextLayout {
+        n: n_lwe.into(),
+        k: k_lwe_ct.into(),
+        base2k: base2k.into(),
+    };
+
+    let cbt_infos: CircuitBootstrappingKeyLayout = CircuitBootstrappingKeyLayout {
+        layout_brk: BlindRotationKeyLayout {
+            n_glwe: n_glwe.into(),
+            n_lwe: n_lwe.into(),
+            base2k: base2k.into(),
+            k: k_brk.into(),
+            rows: rows_brk.into(),
+            rank: rank.into(),
+        },
+        layout_atk: GGLWEAutomorphismKeyLayout {
+            n: n_glwe.into(),
+            base2k: base2k.into(),
+            k: k_atk.into(),
+            rows: rows_atk.into(),
+            rank: rank.into(),
+            digits: Digits(1),
+        },
+        layout_tsk: GGLWETensorKeyLayout {
+            n: n_glwe.into(),
+            base2k: base2k.into(),
+            k: k_tsk.into(),
+            rows: rows_tsk.into(),
+            digits: Digits(1),
+            rank: rank.into(),
+        },
+    };
+
+    let ggsw_infos: GGSWCiphertextLayout = GGSWCiphertextLayout {
+        n: n_glwe.into(),
+        base2k: base2k.into(),
+        k: k_ggsw_res.into(),
+        rows: rows_ggsw_res.into(),
+        digits: Digits(1),
+        rank: rank.into(),
+    };
+
     let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(1 << 23);

     let mut source_xs: Source = Source::new([1u8; 32]);
     let mut source_xa: Source = Source::new([1u8; 32]);
     let mut source_xe: Source = Source::new([1u8; 32]);

-    let mut sk_lwe: LWESecret<Vec<u8>> = LWESecret::alloc(n_lwe);
+    let mut sk_lwe: LWESecret<Vec<u8>> = LWESecret::alloc(n_lwe.into());
     sk_lwe.fill_binary_block(block_size, &mut source_xs);

-    let mut sk_glwe: GLWESecret<Vec<u8>> = GLWESecret::alloc(n, rank);
+    let mut sk_glwe: GLWESecret<Vec<u8>> = GLWESecret::alloc_with(n_glwe.into(), rank.into());
     sk_glwe.fill_ternary_prob(0.5, &mut source_xs);
     let sk_glwe_prepared: GLWESecretPrepared<Vec<u8>, B> = sk_glwe.prepare_alloc(module, scratch.borrow());

     let data: i64 = 1;

-    let mut pt_lwe: LWEPlaintext<Vec<u8>> = LWEPlaintext::alloc(basek, k_lwe_pt);
-    pt_lwe.encode_i64(data, k_lwe_pt + 1);
+    let mut pt_lwe: LWEPlaintext<Vec<u8>> = LWEPlaintext::alloc_with(base2k.into(), k_lwe_pt.into());
+    pt_lwe.encode_i64(data, (k_lwe_pt + 1).into());

-    println!("pt_lwe: {}", pt_lwe);
+    println!("pt_lwe: {pt_lwe}");

-    let mut ct_lwe: LWECiphertext<Vec<u8>> = LWECiphertext::alloc(n_lwe, basek, k_lwe_ct);
+    let mut ct_lwe: LWECiphertext<Vec<u8>> = LWECiphertext::alloc(&lwe_infos);
     ct_lwe.encrypt_sk(module, &pt_lwe, &sk_lwe, &mut source_xa, &mut source_xe);

     let now: Instant = Instant::now();

     let cbt_key: CircuitBootstrappingKey<Vec<u8>, BRA> = CircuitBootstrappingKey::encrypt_sk(
         module,
-        basek,
         &sk_lwe,
         &sk_glwe,
-        k_brk,
-        rows_brk,
-        k_trace,
-        rows_trace,
-        k_tsk,
-        rows_tsk,
+        &cbt_infos,
         &mut source_xa,
         &mut source_xe,
         scratch.borrow(),
     );

     println!("CBT-KGEN: {} ms", now.elapsed().as_millis());

-    let k_ggsw_res: usize = 4 * basek;
-    let rows_ggsw_res: usize = 2;
-
-    let mut res: GGSWCiphertext<Vec<u8>> = GGSWCiphertext::alloc(n, basek, k_ggsw_res, rows_ggsw_res, 1, rank);
+    let mut res: GGSWCiphertext<Vec<u8>> = GGSWCiphertext::alloc(&ggsw_infos);

     let log_gap_out = 1;
@@ -185,7 +224,7 @@ where
     println!("CBT: {} ms", now.elapsed().as_millis());

     // X^{data * 2^log_gap_out}
-    let mut pt_ggsw: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, 1);
+    let mut pt_ggsw: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n_glwe, 1);
     pt_ggsw.at_mut(0, 0)[0] = 1;
     module.vec_znx_rotate_inplace(
         data * (1 << log_gap_out),
@@ -196,11 +235,9 @@ where

     res.print_noise(module, &sk_glwe_prepared, &pt_ggsw);

-    let k_glwe: usize = k_ggsw_res;
-
-    let mut ct_glwe: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(n, basek, k_glwe, rank);
-    let mut pt_glwe: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(n, basek, basek);
-    pt_glwe.data.at_mut(0, 0)[0] = 1 << (basek - 2);
+    let mut ct_glwe: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(&ggsw_infos);
+    let mut pt_glwe: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(&ggsw_infos);
+    pt_glwe.data.at_mut(0, 0)[0] = 1 << (base2k - 2);

     ct_glwe.encrypt_sk(
         module,
@@ -215,7 +252,7 @@ where

     ct_glwe.external_product_inplace(module, &res_prepared, scratch.borrow());

-    let mut pt_res: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(n, basek, k_glwe);
+    let mut pt_res: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(&ggsw_infos);
     ct_glwe.decrypt(module, &mut pt_res, &sk_glwe_prepared, scratch.borrow());

     // Parameters are set such that the first limb should be noiseless.
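Both tests end on the same check. A one-limb toy model of why the first limb can be compared exactly (standalone Rust; the noise bound 2^(base2k - 3) is illustrative only, not the tests' exact budget): with the plaintext placed at 1 << (base2k - 2) in limb 0 and the noise kept below that scale, rounding limb 0 to the nearest multiple of the scale recovers the message.

    fn main() {
        let base2k: u32 = 17;
        let scale: i64 = 1 << (base2k - 2); // plaintext scaling used by the test
        for msg in -2i64..=1 {
            let noisy = msg * scale + 333; // 333 < 2^(base2k - 3): small noise
            let decoded = (noisy + scale / 2).div_euclid(scale); // round to nearest multiple
            assert_eq!(decoded, msg);
        }
        println!("noise below the plaintext scale is removed by rounding limb 0");
    }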
@@ -235,7 +272,7 @@ where
         + SvpApplyDftToDftInplace
         + VecZnxIdftApplyConsume
         + VecZnxNormalizeTmpBytes
-        + VecZnxSubABInplace
+        + VecZnxSubInplace
         + VecZnxAddInplace
         + VecZnxNormalize
         + VecZnxSub
@@ -266,7 +303,7 @@ where
         + VecZnxNegateInplace
         + VecZnxCopy
         + VecZnxAutomorphismInplace
-        + VecZnxBigSubSmallBInplace
+        + VecZnxBigSubSmallNegateInplace
         + VecZnxBigAllocBytes
         + VecZnxDftAddInplace
         + VecZnxRotate
@@ -289,8 +326,8 @@ where
     BlindRotationKeyPrepared<Vec<u8>, BRA, B>: BlincRotationExecute<B>,
     BlindRotationKey<Vec<u8>, BRA>: BlindRotationKeyAlloc + BlindRotationKeyEncryptSk<B>,
 {
-    let n: usize = module.n();
-    let basek: usize = 14;
+    let n_glwe: usize = module.n();
+    let base2k: usize = 14;

     let extension_factor: usize = 1;
     let rank: usize = 2;
@@ -299,61 +336,97 @@ where
     let k_lwe_ct: usize = 13;
     let block_size: usize = 7;

-    let k_brk: usize = 5 * basek;
+    let k_brk: usize = 5 * base2k;
     let rows_brk: usize = 3;

-    let k_trace: usize = 5 * basek;
-    let rows_trace: usize = 4;
+    let k_atk: usize = 5 * base2k;
+    let rows_atk: usize = 4;

-    let k_tsk: usize = 5 * basek;
+    let k_tsk: usize = 5 * base2k;
     let rows_tsk: usize = 4;

+    let k_ggsw_res: usize = 4 * base2k;
+    let rows_ggsw_res: usize = 3;
+
+    let lwe_infos: LWECiphertextLayout = LWECiphertextLayout {
+        n: n_lwe.into(),
+        k: k_lwe_ct.into(),
+        base2k: base2k.into(),
+    };
+
+    let cbt_infos: CircuitBootstrappingKeyLayout = CircuitBootstrappingKeyLayout {
+        layout_brk: BlindRotationKeyLayout {
+            n_glwe: n_glwe.into(),
+            n_lwe: n_lwe.into(),
+            base2k: base2k.into(),
+            k: k_brk.into(),
+            rows: rows_brk.into(),
+            rank: rank.into(),
+        },
+        layout_atk: GGLWEAutomorphismKeyLayout {
+            n: n_glwe.into(),
+            base2k: base2k.into(),
+            k: k_atk.into(),
+            rows: rows_atk.into(),
+            rank: rank.into(),
+            digits: Digits(1),
+        },
+        layout_tsk: GGLWETensorKeyLayout {
+            n: n_glwe.into(),
+            base2k: base2k.into(),
+            k: k_tsk.into(),
+            rows: rows_tsk.into(),
+            digits: Digits(1),
+            rank: rank.into(),
+        },
+    };
+
+    let ggsw_infos: GGSWCiphertextLayout = GGSWCiphertextLayout {
+        n: n_glwe.into(),
+        base2k: base2k.into(),
+        k: k_ggsw_res.into(),
+        rows: rows_ggsw_res.into(),
+        digits: Digits(1),
+        rank: rank.into(),
+    };
+
     let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(1 << 23);

     let mut source_xs: Source = Source::new([1u8; 32]);
     let mut source_xa: Source = Source::new([1u8; 32]);
     let mut source_xe: Source = Source::new([1u8; 32]);

-    let mut sk_lwe: LWESecret<Vec<u8>> = LWESecret::alloc(n_lwe);
+    let mut sk_lwe: LWESecret<Vec<u8>> = LWESecret::alloc(n_lwe.into());
     sk_lwe.fill_binary_block(block_size, &mut source_xs);

-    let mut sk_glwe: GLWESecret<Vec<u8>> = GLWESecret::alloc(n, rank);
+    let mut sk_glwe: GLWESecret<Vec<u8>> = GLWESecret::alloc_with(n_glwe.into(), rank.into());
     sk_glwe.fill_ternary_prob(0.5, &mut source_xs);
     let sk_glwe_prepared: GLWESecretPrepared<Vec<u8>, B> = sk_glwe.prepare_alloc(module, scratch.borrow());

     let data: i64 = 1;

-    let mut pt_lwe: LWEPlaintext<Vec<u8>> = LWEPlaintext::alloc(basek, k_lwe_pt);
-    pt_lwe.encode_i64(data, k_lwe_pt + 1);
+    let mut pt_lwe: LWEPlaintext<Vec<u8>> = LWEPlaintext::alloc_with(base2k.into(), k_lwe_pt.into());
+    pt_lwe.encode_i64(data, (k_lwe_pt + 1).into());

-    println!("pt_lwe: {}", pt_lwe);
+    println!("pt_lwe: {pt_lwe}");

-    let mut ct_lwe: LWECiphertext<Vec<u8>> = LWECiphertext::alloc(n_lwe, basek, k_lwe_ct);
+    let mut ct_lwe: LWECiphertext<Vec<u8>> = LWECiphertext::alloc(&lwe_infos);
     ct_lwe.encrypt_sk(module, &pt_lwe, &sk_lwe, &mut source_xa, &mut source_xe);

     let now: Instant = Instant::now();

     let cbt_key: CircuitBootstrappingKey<Vec<u8>, BRA> = CircuitBootstrappingKey::encrypt_sk(
         module,
-        basek,
         &sk_lwe,
         &sk_glwe,
-        k_brk,
-        rows_brk,
-        k_trace,
-        rows_trace,
-        k_tsk,
-        rows_tsk,
+        &cbt_infos,
         &mut source_xa,
         &mut source_xe,
         scratch.borrow(),
     );

     println!("CBT-KGEN: {} ms", now.elapsed().as_millis());

-    let k_ggsw_res: usize = 4 * basek;
-    let rows_ggsw_res: usize = 3;
-
-    let mut res: GGSWCiphertext<Vec<u8>> = GGSWCiphertext::alloc(n, basek, k_ggsw_res, rows_ggsw_res, 1, rank);
+    let mut res: GGSWCiphertext<Vec<u8>> = GGSWCiphertext::alloc(&ggsw_infos);

     let cbt_prepared: CircuitBootstrappingKeyPrepared<Vec<u8>, BRA, B> = cbt_key.prepare_alloc(module, scratch.borrow());
@@ -369,16 +442,14 @@ where
     println!("CBT: {} ms", now.elapsed().as_millis());

     // X^{data * 2^log_gap_out}
-    let mut pt_ggsw: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, 1);
+    let mut pt_ggsw: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n_glwe, 1);
     pt_ggsw.at_mut(0, 0)[0] = data;

     res.print_noise(module, &sk_glwe_prepared, &pt_ggsw);

-    let k_glwe: usize = k_ggsw_res;
-
-    let mut ct_glwe: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(n, basek, k_glwe, rank);
-    let mut pt_glwe: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(n, basek, basek);
-    pt_glwe.data.at_mut(0, 0)[0] = 1 << (basek - k_lwe_pt - 1);
+    let mut ct_glwe: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(&ggsw_infos);
+    let mut pt_glwe: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(&ggsw_infos);
+    pt_glwe.data.at_mut(0, 0)[0] = 1 << (base2k - k_lwe_pt - 1);

     ct_glwe.encrypt_sk(
         module,
@@ -393,7 +464,7 @@ where

     ct_glwe.external_product_inplace(module, &res_prepared, scratch.borrow());

-    let mut pt_res: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(n, basek, k_glwe);
+    let mut pt_res: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(&ggsw_infos);
     ct_glwe.decrypt(module, &mut pt_res, &sk_glwe_prepared, scratch.borrow());

     // Parameters are set such that the first limb should be noiseless.
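The `rows` and `digits` fields that recur in these layouts size the gadget decomposition behind `external_product_inplace`: each limb is split into `rows` signed digits in base 2^base2k. A minimal sketch of that signed decomposition and its recomposition (illustrative; poulpy decomposes per-limb torus values, not a raw i64):

    // Split x into `rows` centered digits in [-2^(base2k-1), 2^(base2k-1)).
    fn decompose(mut x: i64, base2k: u32, rows: usize) -> Vec<i64> {
        let base = 1i64 << base2k;
        let half = base >> 1;
        let mut digits = vec![0i64; rows];
        for d in digits.iter_mut() {
            let mut r = x & (base - 1); // low base2k bits
            if r >= half {
                r -= base; // center the digit
            }
            *d = r;
            x = (x - r) >> base2k; // exact: x - r is divisible by the base
        }
        digits
    }

    fn main() {
        let (base2k, rows) = (17u32, 3usize);
        let x: i64 = 123_456_789;
        let digits = decompose(x, base2k, rows);
        let recomposed: i64 = digits
            .iter()
            .enumerate()
            .map(|(i, &d)| d << (base2k as usize * i))
            .sum();
        assert_eq!(recomposed, x); // exact while rows * base2k covers the value
        println!("{x} -> {digits:?}");
    }

Centered digits keep each digit's magnitude at most 2^(base2k-1), which is what bounds the noise growth of an external product; `rows` can then be chosen per key (4, 3, 2 above) to trade key size against how much of the input's precision the product preserves.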