diff --git a/Cargo.lock b/Cargo.lock index 5f72436..b68462f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -47,6 +47,12 @@ version = "3.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" +[[package]] +name = "bytemuck" +version = "1.23.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3995eaeebcdf32f91f980d360f78732ddc061097ab4e39991ae7a6ace9194677" + [[package]] name = "byteorder" version = "1.5.0" @@ -307,9 +313,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.20.2" +version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" [[package]] name = "oorandom" @@ -353,6 +359,7 @@ dependencies = [ "cmake", "criterion", "itertools 0.14.0", + "once_cell", "poulpy-hal", "rand", "rand_chacha", @@ -368,6 +375,7 @@ dependencies = [ "byteorder", "criterion", "itertools 0.14.0", + "once_cell", "poulpy-backend", "poulpy-hal", "rug", @@ -377,10 +385,12 @@ dependencies = [ name = "poulpy-hal" version = "0.1.2" dependencies = [ + "bytemuck", "byteorder", "cmake", "criterion", "itertools 0.14.0", + "once_cell", "rand", "rand_chacha", "rand_core", diff --git a/Cargo.toml b/Cargo.toml index 2b08242..affc467 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,3 +12,4 @@ itertools = "0.14.0" criterion = "0.7.0" byteorder = "1.5.0" zstd = "0.13.3" +once_cell = "1.21.3" \ No newline at end of file diff --git a/poulpy-backend/Cargo.toml b/poulpy-backend/Cargo.toml index 0838b01..5820fb5 100644 --- a/poulpy-backend/Cargo.toml +++ b/poulpy-backend/Cargo.toml @@ -18,6 +18,7 @@ rand = {workspace = true} rand_distr = {workspace = true} rand_core = {workspace = true} byteorder = {workspace = true} +once_cell = {workspace = true} rand_chacha = "0.9.0" [build-dependencies] @@ -25,4 +26,9 @@ cmake = "0.1.54" [package.metadata.docs.rs] all-features = true -rustdoc-args = ["--cfg", "docsrs"] \ No newline at end of file +rustdoc-args = ["--cfg", "docsrs"] + + +[[bench]] +name = "vmp" +harness = false \ No newline at end of file diff --git a/poulpy-backend/benches/fft.rs b/poulpy-backend/benches/fft.rs new file mode 100644 index 0000000..7f0f6af --- /dev/null +++ b/poulpy-backend/benches/fft.rs @@ -0,0 +1,224 @@ +use std::{ffi::c_void, hint::black_box}; + +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use poulpy_backend::cpu_spqlios::reim; +use poulpy_hal::reference::fft64::reim::{ReimDFTExecute, ReimFFTRef, ReimFFTTable, ReimIFFTRef, ReimIFFTTable}; + +pub fn bench_fft_ref(c: &mut Criterion) { + let group_name: String = "fft_ref".to_string(); + + let mut group = c.benchmark_group(group_name); + + fn runner(m: usize) -> impl FnMut() { + let mut values: Vec = vec![0f64; m << 1]; + let scale: f64 = 1.0f64 / (2 * m) as f64; + values + .iter_mut() + .enumerate() + .for_each(|(i, x)| *x = (i + 1) as f64 * scale); + let table: ReimFFTTable = ReimFFTTable::::new(m); + move || { + ReimFFTRef::reim_dft_execute(&table, &mut values); + black_box(()); + } + } + + for log_m in [9, 10, 11, 12, 13, 14, 15] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("n: {}", 2 << log_m)); + let mut runner = runner(1 << log_m); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_fft_avx2_fma(c: &mut Criterion) { + let group_name: String = 
"fft_avx2_fma".to_string(); + + let mut group = c.benchmark_group(group_name); + + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + #[target_feature(enable = "avx2,fma")] + fn runner(m: usize) -> impl FnMut() { + let mut values: Vec = vec![0f64; m << 1]; + + let scale = 1.0f64 / (2 * m) as f64; + values + .iter_mut() + .enumerate() + .for_each(|(i, x)| *x = (i + 1) as f64 * scale); + + let table: ReimFFTTable = ReimFFTTable::::new(m); + move || { + use poulpy_backend::cpu_fft64_avx::ReimFFTAvx; + + ReimFFTAvx::reim_dft_execute(&table, &mut values); + black_box(()); + } + } + + if std::is_x86_feature_detected!("avx2") { + for log_m in [9, 10, 11, 12, 13, 14, 15] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("n: {}", 2 << log_m)); + unsafe { + let mut runner = runner(1 << log_m); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + } + } else { + eprintln!("skipping: CPU lacks avx2"); + return; + } + + group.finish(); +} + +pub fn bench_fft_spqlios(c: &mut Criterion) { + let group_name: String = "fft_spqlios".to_string(); + + let mut group = c.benchmark_group(group_name); + + fn runner(m: usize) -> impl FnMut() { + let mut values: Vec = vec![0f64; m << 1]; + + let scale = 1.0f64 / (2 * m) as f64; + values + .iter_mut() + .enumerate() + .for_each(|(i, x)| *x = (i + 1) as f64 * scale); + + unsafe { + reim::reim_fft_simple(m as u32, values.as_mut_ptr() as *mut c_void); + } + + move || { + unsafe { + reim::reim_fft_simple(m as u32, values.as_mut_ptr() as *mut c_void); + } + black_box(()); + } + } + + for log_m in [9, 10, 11, 12, 13, 14, 15] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("n: {}", 2 << log_m)); + let mut runner = runner(1 << log_m); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_ifft_ref(c: &mut Criterion) { + let group_name: String = "ifft_ref".to_string(); + + let mut group = c.benchmark_group(group_name); + + fn runner(m: usize) -> impl FnMut() { + let mut values: Vec = vec![0f64; m << 1]; + let scale: f64 = 1.0f64 / (2 * m) as f64; + values + .iter_mut() + .enumerate() + .for_each(|(i, x)| *x = (i + 1) as f64 * scale); + let table: ReimIFFTTable = ReimIFFTTable::::new(m); + move || { + ReimIFFTRef::reim_dft_execute(&table, &mut values); + black_box(()); + } + } + + for log_m in [9, 10, 11, 12, 13, 14, 15] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("n: {}", 2 << log_m)); + let mut runner = runner(1 << log_m); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_ifft_avx2_fma(c: &mut Criterion) { + let group_name: String = "ifft_avx2_fma".to_string(); + + let mut group = c.benchmark_group(group_name); + + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + #[target_feature(enable = "avx2,fma")] + fn runner(m: usize) -> impl FnMut() { + let mut values: Vec = vec![0f64; m << 1]; + + let scale = 1.0f64 / (2 * m) as f64; + values + .iter_mut() + .enumerate() + .for_each(|(i, x)| *x = (i + 1) as f64 * scale); + + let table: ReimIFFTTable = ReimIFFTTable::::new(m); + move || { + use poulpy_backend::cpu_fft64_avx::ReimIFFTAvx; + + ReimIFFTAvx::reim_dft_execute(&table, &mut values); + black_box(()); + } + } + + if std::is_x86_feature_detected!("avx2") { + for log_m in [9, 10, 11, 12, 13, 14, 15] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("n: {}", 2 << log_m)); + unsafe { + let mut runner = runner(1 << log_m); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + 
} + } else { + eprintln!("skipping: CPU lacks avx2"); + return; + } + + group.finish(); +} + +pub fn bench_ifft_spqlios(c: &mut Criterion) { + let group_name: String = "ifft_spqlios".to_string(); + + let mut group = c.benchmark_group(group_name); + + fn runner(m: usize) -> impl FnMut() { + let mut values: Vec = vec![0f64; m << 1]; + + let scale = 1.0f64 / (2 * m) as f64; + values + .iter_mut() + .enumerate() + .for_each(|(i, x)| *x = (i + 1) as f64 * scale); + + unsafe { + reim::reim_ifft_simple(m as u32, values.as_mut_ptr() as *mut c_void); + } + + move || { + unsafe { + reim::reim_ifft_simple(m as u32, values.as_mut_ptr() as *mut c_void); + } + black_box(()); + } + } + + for log_m in [9, 10, 11, 12, 13, 14, 15] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("n: {}", 2 << log_m)); + let mut runner = runner(1 << log_m); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +criterion_group!( + benches, + bench_fft_ref, + bench_fft_avx2_fma, + bench_fft_spqlios, + bench_ifft_ref, + bench_ifft_avx2_fma, + bench_ifft_spqlios +); +criterion_main!(benches); diff --git a/poulpy-backend/benches/vec_znx.rs b/poulpy-backend/benches/vec_znx.rs new file mode 100644 index 0000000..8dafd16 --- /dev/null +++ b/poulpy-backend/benches/vec_znx.rs @@ -0,0 +1,43 @@ +// poulpy-backend/benches/vec_znx_add.rs +use criterion::{Criterion, criterion_group, criterion_main}; +use poulpy_backend::{cpu_fft64_ref, cpu_spqlios}; +use poulpy_hal::reference::vec_znx::{bench_vec_znx_add, bench_vec_znx_automorphism, bench_vec_znx_normalize_inplace}; + +#[allow(dead_code)] +fn bench_vec_znx_add_cpu_spqlios_fft64(c: &mut Criterion) { + bench_vec_znx_add::(c, "cpu_spqlios::fft64"); +} + +#[allow(dead_code)] +fn bench_vec_znx_add_cpu_ref_fft64(c: &mut Criterion) { + bench_vec_znx_add::(c, "cpu_spqlios::fft64"); +} + +#[allow(dead_code)] +fn bench_vec_znx_normalize_inplace_cpu_ref_fft64(c: &mut Criterion) { + bench_vec_znx_normalize_inplace::(c, "cpu_ref::fft64"); +} + +#[allow(dead_code)] +fn bench_vec_znx_normalize_inplace_cpu_spqlios_fft64(c: &mut Criterion) { + bench_vec_znx_normalize_inplace::(c, "cpu_spqlios::fft64"); +} + +fn bench_vec_znx_automorphism_cpu_ref_fft64(c: &mut Criterion) { + bench_vec_znx_automorphism::(c, "cpu_ref::fft64"); +} + +fn bench_vec_znx_automorphism_cpu_spqlios_fft64(c: &mut Criterion) { + bench_vec_znx_automorphism::(c, "cpu_spqlios::fft64"); +} + +criterion_group!( + benches, + // bench_vec_znx_add_cpu_spqlios_fft64, + // bench_vec_znx_add_cpu_ref_fft64, + // bench_vec_znx_normalize_inplace_cpu_ref_fft64, + // bench_vec_znx_normalize_inplace_cpu_spqlios_fft64, + bench_vec_znx_automorphism_cpu_ref_fft64, + bench_vec_znx_automorphism_cpu_spqlios_fft64, +); +criterion_main!(benches); diff --git a/poulpy-backend/benches/vmp.rs b/poulpy-backend/benches/vmp.rs new file mode 100644 index 0000000..7f6d068 --- /dev/null +++ b/poulpy-backend/benches/vmp.rs @@ -0,0 +1,24 @@ +// poulpy-backend/benches/vec_znx_add.rs +use criterion::{Criterion, criterion_group, criterion_main}; +use poulpy_backend::{FFT64Avx, FFT64Ref, FFT64Spqlios}; +use poulpy_hal::bench_suite::vmp::bench_vmp_apply_dft_to_dft; + +fn bench_vmp_apply_dft_to_dft_cpu_spqlios_fft64(c: &mut Criterion) { + bench_vmp_apply_dft_to_dft::(c, "cpu_spqlios::fft64"); +} + +fn bench_vmp_apply_dft_to_dft_cpu_ref_fft64(c: &mut Criterion) { + bench_vmp_apply_dft_to_dft::(c, "cpu_ref::fft64"); +} + +fn bench_vmp_apply_dft_to_dft_cpu_avx_fft64(c: &mut Criterion) { + bench_vmp_apply_dft_to_dft::(c, 
"cpu_avx::fft64"); +} + +criterion_group!( + benches, + bench_vmp_apply_dft_to_dft_cpu_spqlios_fft64, + bench_vmp_apply_dft_to_dft_cpu_ref_fft64, + bench_vmp_apply_dft_to_dft_cpu_avx_fft64, +); +criterion_main!(benches); diff --git a/poulpy-backend/examples/rlwe_encrypt.rs b/poulpy-backend/examples/rlwe_encrypt.rs index ec9b066..63a526a 100644 --- a/poulpy-backend/examples/rlwe_encrypt.rs +++ b/poulpy-backend/examples/rlwe_encrypt.rs @@ -1,10 +1,10 @@ use itertools::izip; -use poulpy_backend::cpu_spqlios::FFT64; +use poulpy_backend::cpu_spqlios::FFT64Spqlios; use poulpy_hal::{ api::{ - DFT, IDFTTmpA, ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPrepare, - VecZnxAddNormal, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, - VecZnxBigSubSmallBInplace, VecZnxDftAlloc, VecZnxFillUniform, VecZnxNormalizeInplace, + ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPrepare, VecZnxAddNormal, + VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace, + VecZnxDftAlloc, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyTmpA, VecZnxNormalizeInplace, }, layouts::{Module, ScalarZnx, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, ZnxInfos}, source::Source, @@ -16,9 +16,9 @@ fn main() { let ct_size: usize = 3; let msg_size: usize = 2; let log_scale: usize = msg_size * basek - 5; - let module: Module = Module::::new(n as u64); + let module: Module = Module::::new(n as u64); - let mut scratch: ScratchOwned = ScratchOwned::::alloc(module.vec_znx_big_normalize_tmp_bytes()); + let mut scratch: ScratchOwned = ScratchOwned::::alloc(module.vec_znx_big_normalize_tmp_bytes()); let seed: [u8; 32] = [0; 32]; let mut source: Source = Source::new(seed); @@ -28,7 +28,7 @@ fn main() { s.fill_ternary_prob(0, 0.5, &mut source); // Buffer to store s in the DFT domain - let mut s_dft: SvpPPol, FFT64> = module.svp_ppol_alloc(s.cols()); + let mut s_dft: SvpPPol, FFT64Spqlios> = module.svp_ppol_alloc(s.cols()); // s_dft <- DFT(s) module.svp_prepare(&mut s_dft, 0, &s, 0); @@ -41,14 +41,14 @@ fn main() { ); // Fill the second column with random values: ct = (0, a) - module.vec_znx_fill_uniform(basek, &mut ct, 1, ct_size * basek, &mut source); + module.vec_znx_fill_uniform(basek, &mut ct, 1, &mut source); - let mut buf_dft: VecZnxDft, FFT64> = module.vec_znx_dft_alloc(1, ct_size); + let mut buf_dft: VecZnxDft, FFT64Spqlios> = module.vec_znx_dft_alloc(1, ct_size); - module.dft(1, 0, &mut buf_dft, 0, &ct, 1); + module.vec_znx_dft_apply(1, 0, &mut buf_dft, 0, &ct, 1); // Applies DFT(ct[1]) * DFT(s) - module.svp_apply_inplace( + module.svp_apply_dft_to_dft_inplace( &mut buf_dft, // DFT(ct[1] * s) 0, // Selects the first column of res &s_dft, // DFT(s) @@ -58,8 +58,8 @@ fn main() { // Alias scratch space (VecZnxDft is always at least as big as VecZnxBig) // BIG(ct[1] * s) <- IDFT(DFT(ct[1] * s)) (not normalized) - let mut buf_big: VecZnxBig, FFT64> = module.vec_znx_big_alloc(1, ct_size); - module.idft_tmp_a(&mut buf_big, 0, &mut buf_dft, 0); + let mut buf_big: VecZnxBig, FFT64Spqlios> = module.vec_znx_big_alloc(1, ct_size); + module.vec_znx_idft_apply_tmpa(&mut buf_big, 0, &mut buf_dft, 0); // Creates a plaintext: VecZnx with 1 column let mut m = VecZnx::alloc( @@ -109,8 +109,8 @@ fn main() { // Decryption // DFT(ct[1] * s) - module.dft(1, 0, &mut buf_dft, 0, &ct, 1); - module.svp_apply_inplace( + module.vec_znx_dft_apply(1, 0, &mut buf_dft, 0, &ct, 
1); + module.svp_apply_dft_to_dft_inplace( &mut buf_dft, 0, // Selects the first column of res. &s_dft, @@ -118,7 +118,7 @@ fn main() { ); // BIG(c1 * s) = IDFT(DFT(c1 * s)) - module.idft_tmp_a(&mut buf_big, 0, &mut buf_dft, 0); + module.vec_znx_idft_apply_tmpa(&mut buf_big, 0, &mut buf_dft, 0); // BIG(c1 * s) + ct[0] module.vec_znx_big_add_small_inplace(&mut buf_big, 0, &ct, 0); diff --git a/poulpy-backend/src/cpu_fft64_avx/mod.rs b/poulpy-backend/src/cpu_fft64_avx/mod.rs new file mode 100644 index 0000000..c64d3df --- /dev/null +++ b/poulpy-backend/src/cpu_fft64_avx/mod.rs @@ -0,0 +1,18 @@ +mod module; +mod reim; +mod reim4; +mod scratch; +mod svp; +mod vec_znx; +mod vec_znx_big; +mod vec_znx_dft; +mod vmp; +mod zn; +mod znx_avx; + +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +pub struct FFT64Avx {} +pub use reim::*; + +#[cfg(test)] +pub mod tests; diff --git a/poulpy-backend/src/cpu_fft64_avx/module.rs b/poulpy-backend/src/cpu_fft64_avx/module.rs new file mode 100644 index 0000000..c90d74c --- /dev/null +++ b/poulpy-backend/src/cpu_fft64_avx/module.rs @@ -0,0 +1,478 @@ +use std::ptr::NonNull; + +use poulpy_hal::{ + layouts::{Backend, Module}, + oep::ModuleNewImpl, + reference::{ + fft64::{ + reim::{ + ReimAdd, ReimAddInplace, ReimAddMul, ReimCopy, ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimIFFTTable, ReimMul, + ReimMulInplace, ReimNegate, ReimNegateInplace, ReimSub, ReimSubABInplace, ReimSubBAInplace, ReimToZnx, + ReimToZnxInplace, ReimZero, reim_copy_ref, reim_zero_ref, + }, + reim4::{ + Reim4Extract1Blk, Reim4Mat1ColProd, Reim4Mat2Cols2ndColProd, Reim4Mat2ColsProd, Reim4Save1Blk, Reim4Save2Blks, + }, + }, + znx::{ + ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxNormalizeFinalStep, + ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeFirstStepInplace, + ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, ZnxRotate, ZnxSub, + ZnxSubABInplace, ZnxSubBAInplace, ZnxSwitchRing, ZnxZero, znx_copy_ref, znx_rotate, znx_zero_ref, + }, + }, +}; + +use crate::cpu_fft64_avx::{ + FFT64Avx, + reim::{ + ReimFFTAvx, ReimIFFTAvx, reim_add_avx2_fma, reim_add_inplace_avx2_fma, reim_addmul_avx2_fma, reim_from_znx_i64_bnd50_fma, + reim_mul_avx2_fma, reim_mul_inplace_avx2_fma, reim_negate_avx2_fma, reim_negate_inplace_avx2_fma, + reim_sub_ab_inplace_avx2_fma, reim_sub_avx2_fma, reim_sub_ba_inplace_avx2_fma, reim_to_znx_i64_inplace_bnd63_avx2_fma, + }, + reim_to_znx_i64_bnd63_avx2_fma, + reim4::{ + reim4_extract_1blk_from_reim_avx, reim4_save_1blk_to_reim_avx, reim4_save_2blk_to_reim_avx, + reim4_vec_mat1col_product_avx, reim4_vec_mat2cols_2ndcol_product_avx, reim4_vec_mat2cols_product_avx, + }, + znx_avx::{ + znx_add_avx, znx_add_inplace_avx, znx_automorphism_avx, znx_negate_avx, znx_negate_inplace_avx, + znx_normalize_final_step_avx, znx_normalize_final_step_inplace_avx, znx_normalize_first_step_avx, + znx_normalize_first_step_carry_only_avx, znx_normalize_first_step_inplace_avx, znx_normalize_middle_step_avx, + znx_normalize_middle_step_carry_only_avx, znx_normalize_middle_step_inplace_avx, znx_sub_ab_inplace_avx, znx_sub_avx, + znx_sub_ba_inplace_avx, znx_switch_ring_avx, + }, +}; + +#[repr(C)] +pub struct FFT64AvxHandle { + table_fft: ReimFFTTable, + table_ifft: ReimIFFTTable, +} + +impl Backend for FFT64Avx { + type ScalarPrep = f64; + type ScalarBig = i64; + type Handle = FFT64AvxHandle; + unsafe fn destroy(handle: NonNull) { + unsafe { + 
drop(Box::from_raw(handle.as_ptr())); + } + } + + fn layout_big_word_count() -> usize { + 1 + } + + fn layout_prep_word_count() -> usize { + 1 + } +} + +unsafe impl ModuleNewImpl for FFT64Avx { + fn new_impl(n: u64) -> Module { + if !std::arch::is_x86_feature_detected!("avx") + || !std::arch::is_x86_feature_detected!("avx2") + || !std::arch::is_x86_feature_detected!("fma") + { + panic!("arch must support avx2, avx and fma") + } + + let handle: FFT64AvxHandle = FFT64AvxHandle { + table_fft: ReimFFTTable::new(n as usize >> 1), + table_ifft: ReimIFFTTable::new(n as usize >> 1), + }; + // Leak Box to get a stable NonNull pointer + let ptr: NonNull = NonNull::from(Box::leak(Box::new(handle))); + unsafe { Module::from_nonnull(ptr, n) } + } +} + +pub trait FFT64ModuleHandle { + fn get_fft_table(&self) -> &ReimFFTTable; + fn get_ifft_table(&self) -> &ReimIFFTTable; +} + +impl FFT64ModuleHandle for Module { + fn get_fft_table(&self) -> &ReimFFTTable { + let h: &FFT64AvxHandle = unsafe { &*self.ptr() }; + &h.table_fft + } + fn get_ifft_table(&self) -> &ReimIFFTTable { + let h: &FFT64AvxHandle = unsafe { &*self.ptr() }; + &h.table_ifft + } +} + +impl ZnxAdd for FFT64Avx { + #[inline(always)] + fn znx_add(res: &mut [i64], a: &[i64], b: &[i64]) { + unsafe { + znx_add_avx(res, a, b); + } + } +} + +impl ZnxAddInplace for FFT64Avx { + #[inline(always)] + fn znx_add_inplace(res: &mut [i64], a: &[i64]) { + unsafe { + znx_add_inplace_avx(res, a); + } + } +} + +impl ZnxSub for FFT64Avx { + #[inline(always)] + fn znx_sub(res: &mut [i64], a: &[i64], b: &[i64]) { + unsafe { + znx_sub_avx(res, a, b); + } + } +} + +impl ZnxSubABInplace for FFT64Avx { + #[inline(always)] + fn znx_sub_ab_inplace(res: &mut [i64], a: &[i64]) { + unsafe { + znx_sub_ab_inplace_avx(res, a); + } + } +} + +impl ZnxSubBAInplace for FFT64Avx { + #[inline(always)] + fn znx_sub_ba_inplace(res: &mut [i64], a: &[i64]) { + unsafe { + znx_sub_ba_inplace_avx(res, a); + } + } +} + +impl ZnxAutomorphism for FFT64Avx { + #[inline(always)] + fn znx_automorphism(p: i64, res: &mut [i64], a: &[i64]) { + unsafe { + znx_automorphism_avx(p, res, a); + } + } +} + +impl ZnxCopy for FFT64Avx { + #[inline(always)] + fn znx_copy(res: &mut [i64], a: &[i64]) { + znx_copy_ref(res, a); + } +} + +impl ZnxNegate for FFT64Avx { + #[inline(always)] + fn znx_negate(res: &mut [i64], src: &[i64]) { + unsafe { + znx_negate_avx(res, src); + } + } +} + +impl ZnxNegateInplace for FFT64Avx { + #[inline(always)] + fn znx_negate_inplace(res: &mut [i64]) { + unsafe { + znx_negate_inplace_avx(res); + } + } +} + +impl ZnxRotate for FFT64Avx { + #[inline(always)] + fn znx_rotate(p: i64, res: &mut [i64], src: &[i64]) { + znx_rotate::(p, res, src); + } +} + +impl ZnxZero for FFT64Avx { + #[inline(always)] + fn znx_zero(res: &mut [i64]) { + znx_zero_ref(res); + } +} + +impl ZnxSwitchRing for FFT64Avx { + #[inline(always)] + fn znx_switch_ring(res: &mut [i64], a: &[i64]) { + unsafe { + znx_switch_ring_avx(res, a); + } + } +} + +impl ZnxNormalizeFinalStep for FFT64Avx { + #[inline(always)] + fn znx_normalize_final_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { + unsafe { + znx_normalize_final_step_avx(basek, lsh, x, a, carry); + } + } +} + +impl ZnxNormalizeFinalStepInplace for FFT64Avx { + #[inline(always)] + fn znx_normalize_final_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { + unsafe { + znx_normalize_final_step_inplace_avx(basek, lsh, x, carry); + } + } +} + +impl ZnxNormalizeFirstStep for FFT64Avx { + #[inline(always)] + fn 
znx_normalize_first_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { + unsafe { + znx_normalize_first_step_avx(basek, lsh, x, a, carry); + } + } +} + +impl ZnxNormalizeFirstStepCarryOnly for FFT64Avx { + #[inline(always)] + fn znx_normalize_first_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) { + unsafe { + znx_normalize_first_step_carry_only_avx(basek, lsh, x, carry); + } + } +} + +impl ZnxNormalizeFirstStepInplace for FFT64Avx { + #[inline(always)] + fn znx_normalize_first_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { + unsafe { + znx_normalize_first_step_inplace_avx(basek, lsh, x, carry); + } + } +} + +impl ZnxNormalizeMiddleStep for FFT64Avx { + #[inline(always)] + fn znx_normalize_middle_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { + unsafe { + znx_normalize_middle_step_avx(basek, lsh, x, a, carry); + } + } +} + +impl ZnxNormalizeMiddleStepCarryOnly for FFT64Avx { + #[inline(always)] + fn znx_normalize_middle_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) { + unsafe { + znx_normalize_middle_step_carry_only_avx(basek, lsh, x, carry); + } + } +} + +impl ZnxNormalizeMiddleStepInplace for FFT64Avx { + #[inline(always)] + fn znx_normalize_middle_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { + unsafe { + znx_normalize_middle_step_inplace_avx(basek, lsh, x, carry); + } + } +} + +impl ReimDFTExecute, f64> for FFT64Avx { + #[inline(always)] + fn reim_dft_execute(table: &ReimFFTTable, data: &mut [f64]) { + ReimFFTAvx::reim_dft_execute(table, data); + } +} + +impl ReimDFTExecute, f64> for FFT64Avx { + #[inline(always)] + fn reim_dft_execute(table: &ReimIFFTTable, data: &mut [f64]) { + ReimIFFTAvx::reim_dft_execute(table, data); + } +} + +impl ReimFromZnx for FFT64Avx { + #[inline(always)] + fn reim_from_znx(res: &mut [f64], a: &[i64]) { + unsafe { + reim_from_znx_i64_bnd50_fma(res, a); + } + } +} + +impl ReimToZnx for FFT64Avx { + #[inline(always)] + fn reim_to_znx(res: &mut [i64], divisor: f64, a: &[f64]) { + unsafe { + reim_to_znx_i64_bnd63_avx2_fma(res, divisor, a); + } + } +} + +impl ReimToZnxInplace for FFT64Avx { + #[inline(always)] + fn reim_to_znx_inplace(res: &mut [f64], divisor: f64) { + unsafe { + reim_to_znx_i64_inplace_bnd63_avx2_fma(res, divisor); + } + } +} + +impl ReimAdd for FFT64Avx { + #[inline(always)] + fn reim_add(res: &mut [f64], a: &[f64], b: &[f64]) { + unsafe { + reim_add_avx2_fma(res, a, b); + } + } +} + +impl ReimAddInplace for FFT64Avx { + #[inline(always)] + fn reim_add_inplace(res: &mut [f64], a: &[f64]) { + unsafe { + reim_add_inplace_avx2_fma(res, a); + } + } +} + +impl ReimSub for FFT64Avx { + #[inline(always)] + fn reim_sub(res: &mut [f64], a: &[f64], b: &[f64]) { + unsafe { + reim_sub_avx2_fma(res, a, b); + } + } +} + +impl ReimSubABInplace for FFT64Avx { + #[inline(always)] + fn reim_sub_ab_inplace(res: &mut [f64], a: &[f64]) { + unsafe { + reim_sub_ab_inplace_avx2_fma(res, a); + } + } +} + +impl ReimSubBAInplace for FFT64Avx { + #[inline(always)] + fn reim_sub_ba_inplace(res: &mut [f64], a: &[f64]) { + unsafe { + reim_sub_ba_inplace_avx2_fma(res, a); + } + } +} + +impl ReimNegate for FFT64Avx { + #[inline(always)] + fn reim_negate(res: &mut [f64], a: &[f64]) { + unsafe { + reim_negate_avx2_fma(res, a); + } + } +} + +impl ReimNegateInplace for FFT64Avx { + #[inline(always)] + fn reim_negate_inplace(res: &mut [f64]) { + unsafe { + reim_negate_inplace_avx2_fma(res); + } + } +} + +impl ReimMul 
for FFT64Avx { + #[inline(always)] + fn reim_mul(res: &mut [f64], a: &[f64], b: &[f64]) { + unsafe { + reim_mul_avx2_fma(res, a, b); + } + } +} + +impl ReimMulInplace for FFT64Avx { + #[inline(always)] + fn reim_mul_inplace(res: &mut [f64], a: &[f64]) { + unsafe { + reim_mul_inplace_avx2_fma(res, a); + } + } +} + +impl ReimAddMul for FFT64Avx { + #[inline(always)] + fn reim_addmul(res: &mut [f64], a: &[f64], b: &[f64]) { + unsafe { + reim_addmul_avx2_fma(res, a, b); + } + } +} + +impl ReimCopy for FFT64Avx { + #[inline(always)] + fn reim_copy(res: &mut [f64], a: &[f64]) { + reim_copy_ref(res, a); + } +} + +impl ReimZero for FFT64Avx { + #[inline(always)] + fn reim_zero(res: &mut [f64]) { + reim_zero_ref(res); + } +} + +impl Reim4Extract1Blk for FFT64Avx { + #[inline(always)] + fn reim4_extract_1blk(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) { + unsafe { + reim4_extract_1blk_from_reim_avx(m, rows, blk, dst, src); + } + } +} + +impl Reim4Save1Blk for FFT64Avx { + #[inline(always)] + fn reim4_save_1blk(m: usize, blk: usize, dst: &mut [f64], src: &[f64]) { + unsafe { + reim4_save_1blk_to_reim_avx::(m, blk, dst, src); + } + } +} + +impl Reim4Save2Blks for FFT64Avx { + #[inline(always)] + fn reim4_save_2blks(m: usize, blk: usize, dst: &mut [f64], src: &[f64]) { + unsafe { + reim4_save_2blk_to_reim_avx::(m, blk, dst, src); + } + } +} + +impl Reim4Mat1ColProd for FFT64Avx { + #[inline(always)] + fn reim4_mat1col_prod(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]) { + unsafe { + reim4_vec_mat1col_product_avx(nrows, dst, u, v); + } + } +} + +impl Reim4Mat2ColsProd for FFT64Avx { + #[inline(always)] + fn reim4_mat2cols_prod(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]) { + unsafe { + reim4_vec_mat2cols_product_avx(nrows, dst, u, v); + } + } +} + +impl Reim4Mat2Cols2ndColProd for FFT64Avx { + #[inline(always)] + fn reim4_mat2cols_2ndcol_prod(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]) { + unsafe { + reim4_vec_mat2cols_2ndcol_product_avx(nrows, dst, u, v); + } + } +} diff --git a/poulpy-backend/src/cpu_fft64_avx/reim/conversion.rs b/poulpy-backend/src/cpu_fft64_avx/reim/conversion.rs new file mode 100644 index 0000000..fbab715 --- /dev/null +++ b/poulpy-backend/src/cpu_fft64_avx/reim/conversion.rs @@ -0,0 +1,271 @@ +/// # Correctness +/// Ensured for inputs absolute value bounded by 2^50-1 +/// # Safety +/// Caller must ensure the CPU supports FMA (e.g., via `is_x86_feature_detected!("fma")`); +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +#[target_feature(enable = "fma")] +pub fn reim_from_znx_i64_bnd50_fma(res: &mut [f64], a: &[i64]) { + #[cfg(debug_assertions)] + { + assert_eq!(res.len(), a.len()) + } + + let n: usize = res.len(); + + unsafe { + use std::arch::x86_64::{ + __m256d, __m256i, _mm256_add_epi64, _mm256_castsi256_pd, _mm256_loadu_si256, _mm256_or_pd, _mm256_set1_epi64x, + _mm256_set1_pd, _mm256_storeu_pd, _mm256_sub_pd, + }; + + let expo: f64 = (1i64 << 52) as f64; + let add_cst: i64 = 1i64 << 51; + let sub_cst: f64 = (3i64 << 51) as f64; + + let expo_256: __m256d = _mm256_set1_pd(expo); + let add_cst_256: __m256i = _mm256_set1_epi64x(add_cst); + let sub_cst_256: __m256d = _mm256_set1_pd(sub_cst); + + let mut res_ptr: *mut f64 = res.as_mut_ptr(); + let mut a_ptr: *const __m256i = a.as_ptr() as *const __m256i; + + let span: usize = n >> 2; + + for _ in 0..span { + let mut ai64_256: __m256i = _mm256_loadu_si256(a_ptr); + + ai64_256 = _mm256_add_epi64(ai64_256, add_cst_256); + + let mut af64_256: __m256d = _mm256_castsi256_pd(ai64_256); 
+ af64_256 = _mm256_or_pd(af64_256, expo_256); + af64_256 = _mm256_sub_pd(af64_256, sub_cst_256); + + _mm256_storeu_pd(res_ptr, af64_256); + + res_ptr = res_ptr.add(4); + a_ptr = a_ptr.add(1); + } + + if !res.len().is_multiple_of(4) { + use poulpy_hal::reference::fft64::reim::reim_from_znx_i64_ref; + reim_from_znx_i64_ref(&mut res[span << 2..], &a[span << 2..]) + } + } +} + +/// # Correctness +/// Only ensured for inputs absoluate value bounded by 2^63-1 +/// # Safety +/// Caller must ensure the CPU supports FMA (e.g., via `is_x86_feature_detected!("fma,avx2")`); +#[allow(dead_code)] +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +#[target_feature(enable = "avx2,fma")] +pub fn reim_to_znx_i64_bnd63_avx2_fma(res: &mut [i64], divisor: f64, a: &[f64]) { + #[cfg(debug_assertions)] + { + assert_eq!(res.len(), a.len()) + } + + let sign_mask: u64 = 0x8000000000000000u64; + let expo_mask: u64 = 0x7FF0000000000000u64; + let mantissa_mask: u64 = (i64::MAX as u64) ^ expo_mask; + let mantissa_msb: u64 = 0x0010000000000000u64; + let divi_bits: f64 = divisor * (1i64 << 52) as f64; + let offset: f64 = divisor / 2.; + + unsafe { + use std::arch::x86_64::{ + __m256d, __m256i, _mm256_add_pd, _mm256_and_pd, _mm256_and_si256, _mm256_castpd_si256, _mm256_castsi256_pd, + _mm256_loadu_pd, _mm256_or_pd, _mm256_or_si256, _mm256_set1_epi64x, _mm256_set1_pd, _mm256_sllv_epi64, + _mm256_srli_epi64, _mm256_srlv_epi64, _mm256_sub_epi64, _mm256_xor_si256, + }; + + let sign_mask_256: __m256d = _mm256_castsi256_pd(_mm256_set1_epi64x(sign_mask as i64)); + let expo_mask_256: __m256i = _mm256_set1_epi64x(expo_mask as i64); + let mantissa_mask_256: __m256i = _mm256_set1_epi64x(mantissa_mask as i64); + let mantissa_msb_256: __m256i = _mm256_set1_epi64x(mantissa_msb as i64); + let offset_256 = _mm256_set1_pd(offset); + let divi_bits_256 = _mm256_castpd_si256(_mm256_set1_pd(divi_bits)); + + let mut res_ptr: *mut __m256i = res.as_mut_ptr() as *mut __m256i; + let mut a_ptr: *const f64 = a.as_ptr(); + + let span: usize = res.len() >> 2; + + for _ in 0..span { + // read the next value + use std::arch::x86_64::_mm256_storeu_si256; + let mut a: __m256d = _mm256_loadu_pd(a_ptr); + + // a += sign(a) * m/2 + let asign: __m256d = _mm256_and_pd(a, sign_mask_256); + a = _mm256_add_pd(a, _mm256_or_pd(asign, offset_256)); + + // sign: either 0 or -1 + let mut sign_mask: __m256i = _mm256_castpd_si256(asign); + sign_mask = _mm256_sub_epi64(_mm256_set1_epi64x(0), _mm256_srli_epi64(sign_mask, 63)); + + // compute the exponents + let a0exp: __m256i = _mm256_and_si256(_mm256_castpd_si256(a), expo_mask_256); + let mut a0lsh: __m256i = _mm256_sub_epi64(a0exp, divi_bits_256); + let mut a0rsh: __m256i = _mm256_sub_epi64(divi_bits_256, a0exp); + a0lsh = _mm256_srli_epi64(a0lsh, 52); + a0rsh = _mm256_srli_epi64(a0rsh, 52); + + // compute the new mantissa + let mut a0pos: __m256i = _mm256_and_si256(_mm256_castpd_si256(a), mantissa_mask_256); + a0pos = _mm256_or_si256(a0pos, mantissa_msb_256); + a0lsh = _mm256_sllv_epi64(a0pos, a0lsh); + a0rsh = _mm256_srlv_epi64(a0pos, a0rsh); + let mut out: __m256i = _mm256_or_si256(a0lsh, a0rsh); + + // negate if the sign was negative + out = _mm256_xor_si256(out, sign_mask); + out = _mm256_sub_epi64(out, sign_mask); + + // stores + _mm256_storeu_si256(res_ptr, out); + + res_ptr = res_ptr.add(1); + a_ptr = a_ptr.add(4); + } + + if !res.len().is_multiple_of(4) { + use poulpy_hal::reference::fft64::reim::reim_to_znx_i64_ref; + reim_to_znx_i64_ref(&mut res[span << 2..], divisor, &a[span << 2..]) + } + } +} + 
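Note: the bnd50 conversion kernels in this file rely on the classic exponent-forcing trick (bias the value so the integer of interest lands in the mantissa field of a fixed binade, then read or write the bits directly), and the bnd63 variants generalize it by shifting the mantissa by the exponent difference. The scalar sketch below is for reference only; it is not part of the patch, the function names are illustrative rather than poulpy_hal APIs, and it assumes |x| <= 2^50 - 1 and a power-of-two divisor (as in the FFT64 use here).

    // Scalar sketch of the bound-limited f64 <-> i64 conversions (illustrative only,
    // not part of the patch). Assumes |x| <= 2^50 - 1 and a power-of-two divisor.

    /// Exact i64 -> f64 for |x| <= 2^50 - 1: bias x into [0, 2^52), OR in the
    /// exponent bits of 2^52 so the integer sits in the mantissa field, then
    /// subtract the bias (3 * 2^51) back in floating point.
    fn znx_i64_to_f64_bnd50(x: i64) -> f64 {
        let biased = (x + (1i64 << 51)) as u64; // non-negative, fits in 52 bits
        // 0x4330_0000_0000_0000 is the bit pattern of 2^52
        f64::from_bits(biased | 0x4330_0000_0000_0000) - (3i64 << 51) as f64
    }

    /// Rounded f64 -> i64 of a / divisor for |a / divisor| <= 2^50 - 1: adding
    /// divisor * 3 * 2^51 pins the exponent, so the rounded quotient (offset by
    /// 2^51) can be read straight out of the mantissa bits.
    fn f64_to_znx_i64_bnd50(a: f64, divisor: f64) -> i64 {
        let shifted = a + divisor * (3i64 << 51) as f64;
        ((shifted.to_bits() & 0x000F_FFFF_FFFF_FFFF) as i64) - (1i64 << 51)
    }

    fn main() {
        for &x in &[0i64, 1, -1, 123_456_789, -(1 << 49)] {
            let f = znx_i64_to_f64_bnd50(x);
            assert_eq!(f, x as f64); // exact within the stated bound
            assert_eq!(f64_to_znx_i64_bnd50(f, 1.0), x); // round-trips
        }
        println!("scalar conversion sketch ok");
    }
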
+
+/// # Correctness
+/// Only ensured for inputs absolute value bounded by 2^63-1
+/// # Safety
+/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2")` and `is_x86_feature_detected!("fma")`);
+#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
+#[target_feature(enable = "avx2,fma")]
+pub fn reim_to_znx_i64_inplace_bnd63_avx2_fma(res: &mut [f64], divisor: f64) {
+    let sign_mask: u64 = 0x8000000000000000u64;
+    let expo_mask: u64 = 0x7FF0000000000000u64;
+    let mantissa_mask: u64 = (i64::MAX as u64) ^ expo_mask;
+    let mantissa_msb: u64 = 0x0010000000000000u64;
+    let divi_bits: f64 = divisor * (1i64 << 52) as f64;
+    let offset: f64 = divisor / 2.;
+
+    unsafe {
+        use std::arch::x86_64::{
+            __m256d, __m256i, _mm256_add_pd, _mm256_and_pd, _mm256_and_si256, _mm256_castpd_si256, _mm256_castsi256_pd,
+            _mm256_loadu_pd, _mm256_or_pd, _mm256_or_si256, _mm256_set1_epi64x, _mm256_set1_pd, _mm256_sllv_epi64,
+            _mm256_srli_epi64, _mm256_srlv_epi64, _mm256_sub_epi64, _mm256_xor_si256,
+        };
+
+        use poulpy_hal::reference::fft64::reim::reim_to_znx_i64_inplace_ref;
+
+        let sign_mask_256: __m256d = _mm256_castsi256_pd(_mm256_set1_epi64x(sign_mask as i64));
+        let expo_mask_256: __m256i = _mm256_set1_epi64x(expo_mask as i64);
+        let mantissa_mask_256: __m256i = _mm256_set1_epi64x(mantissa_mask as i64);
+        let mantissa_msb_256: __m256i = _mm256_set1_epi64x(mantissa_msb as i64);
+        let offset_256: __m256d = _mm256_set1_pd(offset);
+        let divi_bits_256: __m256i = _mm256_castpd_si256(_mm256_set1_pd(divi_bits));
+
+        let mut res_ptr_4xi64: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
+        let mut res_ptr_1xf64: *mut f64 = res.as_mut_ptr();
+
+        let span: usize = res.len() >> 2;
+
+        for _ in 0..span {
+            // read the next value
+            use std::arch::x86_64::_mm256_storeu_si256;
+            let mut a: __m256d = _mm256_loadu_pd(res_ptr_1xf64);
+
+            // a += sign(a) * m/2
+            let asign: __m256d = _mm256_and_pd(a, sign_mask_256);
+            a = _mm256_add_pd(a, _mm256_or_pd(asign, offset_256));
+
+            // sign: either 0 or -1
+            let mut sign_mask: __m256i = _mm256_castpd_si256(asign);
+            sign_mask = _mm256_sub_epi64(_mm256_set1_epi64x(0), _mm256_srli_epi64(sign_mask, 63));
+
+            // compute the exponents
+            let a0exp: __m256i = _mm256_and_si256(_mm256_castpd_si256(a), expo_mask_256);
+            let mut a0lsh: __m256i = _mm256_sub_epi64(a0exp, divi_bits_256);
+            let mut a0rsh: __m256i = _mm256_sub_epi64(divi_bits_256, a0exp);
+            a0lsh = _mm256_srli_epi64(a0lsh, 52);
+            a0rsh = _mm256_srli_epi64(a0rsh, 52);
+
+            // compute the new mantissa
+            let mut a0pos: __m256i = _mm256_and_si256(_mm256_castpd_si256(a), mantissa_mask_256);
+            a0pos = _mm256_or_si256(a0pos, mantissa_msb_256);
+            a0lsh = _mm256_sllv_epi64(a0pos, a0lsh);
+            a0rsh = _mm256_srlv_epi64(a0pos, a0rsh);
+            let mut out: __m256i = _mm256_or_si256(a0lsh, a0rsh);
+
+            // negate if the sign was negative
+            out = _mm256_xor_si256(out, sign_mask);
+            out = _mm256_sub_epi64(out, sign_mask);
+
+            // stores
+            _mm256_storeu_si256(res_ptr_4xi64, out);
+
+            res_ptr_4xi64 = res_ptr_4xi64.add(1);
+            res_ptr_1xf64 = res_ptr_1xf64.add(4);
+        }
+
+        if !res.len().is_multiple_of(4) {
+            reim_to_znx_i64_inplace_ref(&mut res[span << 2..], divisor)
+        }
+    }
+}
+
+/// # Correctness
+/// Only ensured for inputs absolute value bounded by 2^50-1
+/// # Safety
+/// Caller must ensure the CPU supports FMA (e.g., via `is_x86_feature_detected!("fma")`);
+#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
+#[target_feature(enable = "fma")]
+#[allow(dead_code)]
+pub fn reim_to_znx_i64_avx2_bnd50_fma(res: &mut [i64], divisor: f64, a:
&[f64]) { + #[cfg(debug_assertions)] + { + assert_eq!(res.len(), a.len()) + } + + unsafe { + use std::arch::x86_64::{ + __m256d, __m256i, _mm256_add_pd, _mm256_and_si256, _mm256_castpd_si256, _mm256_loadu_pd, _mm256_set1_epi64x, + _mm256_set1_pd, _mm256_storeu_si256, _mm256_sub_epi64, + }; + + let mantissa_mask: u64 = 0x000FFFFFFFFFFFFFu64; + let sub_cst: i64 = 1i64 << 51; + let add_cst: f64 = divisor * (3i64 << 51) as f64; + + let sub_cst_4: __m256i = _mm256_set1_epi64x(sub_cst); + let add_cst_4: std::arch::x86_64::__m256d = _mm256_set1_pd(add_cst); + let mantissa_mask_4: __m256i = _mm256_set1_epi64x(mantissa_mask as i64); + + let mut res_ptr: *mut __m256i = res.as_mut_ptr() as *mut __m256i; + let mut a_ptr = a.as_ptr(); + + let span: usize = res.len() >> 2; + + for _ in 0..span { + // read the next value + let mut a: __m256d = _mm256_loadu_pd(a_ptr); + a = _mm256_add_pd(a, add_cst_4); + let mut ai: __m256i = _mm256_castpd_si256(a); + ai = _mm256_and_si256(ai, mantissa_mask_4); + ai = _mm256_sub_epi64(ai, sub_cst_4); + // store the next value + _mm256_storeu_si256(res_ptr, ai); + + res_ptr = res_ptr.add(1); + a_ptr = a_ptr.add(4); + } + + if !res.len().is_multiple_of(4) { + use poulpy_hal::reference::fft64::reim::reim_to_znx_i64_ref; + reim_to_znx_i64_ref(&mut res[span << 2..], divisor, &a[span << 2..]) + } + } +} diff --git a/poulpy-backend/src/cpu_fft64_avx/reim/fft16_avx2_fma.s b/poulpy-backend/src/cpu_fft64_avx/reim/fft16_avx2_fma.s new file mode 100644 index 0000000..7cab12f --- /dev/null +++ b/poulpy-backend/src/cpu_fft64_avx/reim/fft16_avx2_fma.s @@ -0,0 +1,162 @@ +# ---------------------------------------------------------------------- +# This kernel is a direct port of the FFT16 routine from spqlios-arithmetic +# (https://github.com/tfhe/spqlios-arithmetic) +# ---------------------------------------------------------------------- +# + +.text +.globl fft16_avx2_fma_asm +.hidden fft16_avx2_fma_asm +.p2align 4, 0x90 +.type fft16_avx2_fma_asm,@function +fft16_avx2_fma_asm: +.att_syntax prefix + +# SysV args: %rdi = re*, %rsi = im*, %rdx = omg* +# stage 0: load inputs +vmovupd (%rdi),%ymm0 # ra0 +vmovupd 0x20(%rdi),%ymm1 # ra4 +vmovupd 0x40(%rdi),%ymm2 # ra8 +vmovupd 0x60(%rdi),%ymm3 # ra12 +vmovupd (%rsi),%ymm4 # ia0 +vmovupd 0x20(%rsi),%ymm5 # ia4 +vmovupd 0x40(%rsi),%ymm6 # ia8 +vmovupd 0x60(%rsi),%ymm7 # ia12 + +# stage 1 +vmovupd (%rdx),%xmm12 +vinsertf128 $1, %xmm12, %ymm12, %ymm12 # omriri +vshufpd $15, %ymm12, %ymm12, %ymm13 # omai +vshufpd $0, %ymm12, %ymm12, %ymm12 # omar +vmulpd %ymm6,%ymm13,%ymm8 +vmulpd %ymm7,%ymm13,%ymm9 +vmulpd %ymm2,%ymm13,%ymm10 +vmulpd %ymm3,%ymm13,%ymm11 +vfmsub231pd %ymm2,%ymm12,%ymm8 +vfmsub231pd %ymm3,%ymm12,%ymm9 +vfmadd231pd %ymm6,%ymm12,%ymm10 +vfmadd231pd %ymm7,%ymm12,%ymm11 +vsubpd %ymm8,%ymm0,%ymm2 +vsubpd %ymm9,%ymm1,%ymm3 +vsubpd %ymm10,%ymm4,%ymm6 +vsubpd %ymm11,%ymm5,%ymm7 +vaddpd %ymm8,%ymm0,%ymm0 +vaddpd %ymm9,%ymm1,%ymm1 +vaddpd %ymm10,%ymm4,%ymm4 +vaddpd %ymm11,%ymm5,%ymm5 + +# stage 2 +vmovupd 16(%rdx),%xmm12 +vinsertf128 $1, %xmm12, %ymm12, %ymm12 # omriri +vshufpd $15, %ymm12, %ymm12, %ymm13 # omai +vshufpd $0, %ymm12, %ymm12, %ymm12 # omar +vmulpd %ymm5,%ymm13,%ymm8 +vmulpd %ymm7,%ymm12,%ymm9 +vmulpd %ymm1,%ymm13,%ymm10 +vmulpd %ymm3,%ymm12,%ymm11 +vfmsub231pd %ymm1,%ymm12,%ymm8 +vfmadd231pd %ymm3,%ymm13,%ymm9 +vfmadd231pd %ymm5,%ymm12,%ymm10 +vfmsub231pd %ymm7,%ymm13,%ymm11 +vsubpd %ymm8,%ymm0,%ymm1 +vaddpd %ymm9,%ymm2,%ymm3 +vsubpd %ymm10,%ymm4,%ymm5 +vaddpd %ymm11,%ymm6,%ymm7 +vaddpd %ymm8,%ymm0,%ymm0 +vsubpd 
%ymm9,%ymm2,%ymm2 +vaddpd %ymm10,%ymm4,%ymm4 +vsubpd %ymm11,%ymm6,%ymm6 + +# stage 3 +vmovupd 0x20(%rdx),%ymm12 +vshufpd $15, %ymm12, %ymm12, %ymm13 # omai +vshufpd $0, %ymm12, %ymm12, %ymm12 # omar + +vperm2f128 $0x31,%ymm2,%ymm0,%ymm8 +vperm2f128 $0x31,%ymm3,%ymm1,%ymm9 +vperm2f128 $0x31,%ymm6,%ymm4,%ymm10 +vperm2f128 $0x31,%ymm7,%ymm5,%ymm11 +vperm2f128 $0x20,%ymm2,%ymm0,%ymm0 +vperm2f128 $0x20,%ymm3,%ymm1,%ymm1 +vperm2f128 $0x20,%ymm6,%ymm4,%ymm2 +vperm2f128 $0x20,%ymm7,%ymm5,%ymm3 + +vmulpd %ymm10,%ymm13,%ymm4 +vmulpd %ymm11,%ymm12,%ymm5 +vmulpd %ymm8,%ymm13,%ymm6 +vmulpd %ymm9,%ymm12,%ymm7 +vfmsub231pd %ymm8,%ymm12,%ymm4 +vfmadd231pd %ymm9,%ymm13,%ymm5 +vfmadd231pd %ymm10,%ymm12,%ymm6 +vfmsub231pd %ymm11,%ymm13,%ymm7 +vsubpd %ymm4,%ymm0,%ymm8 +vaddpd %ymm5,%ymm1,%ymm9 +vsubpd %ymm6,%ymm2,%ymm10 +vaddpd %ymm7,%ymm3,%ymm11 +vaddpd %ymm4,%ymm0,%ymm0 +vsubpd %ymm5,%ymm1,%ymm1 +vaddpd %ymm6,%ymm2,%ymm2 +vsubpd %ymm7,%ymm3,%ymm3 + +# stage 4 +vmovupd 0x40(%rdx),%ymm12 +vmovupd 0x60(%rdx),%ymm13 + +vunpckhpd %ymm1,%ymm0,%ymm4 +vunpckhpd %ymm3,%ymm2,%ymm6 +vunpckhpd %ymm9,%ymm8,%ymm5 +vunpckhpd %ymm11,%ymm10,%ymm7 +vunpcklpd %ymm1,%ymm0,%ymm0 +vunpcklpd %ymm3,%ymm2,%ymm2 +vunpcklpd %ymm9,%ymm8,%ymm1 +vunpcklpd %ymm11,%ymm10,%ymm3 + +vmulpd %ymm6,%ymm13,%ymm8 +vmulpd %ymm7,%ymm12,%ymm9 +vmulpd %ymm4,%ymm13,%ymm10 +vmulpd %ymm5,%ymm12,%ymm11 +vfmsub231pd %ymm4,%ymm12,%ymm8 +vfmadd231pd %ymm5,%ymm13,%ymm9 +vfmadd231pd %ymm6,%ymm12,%ymm10 +vfmsub231pd %ymm7,%ymm13,%ymm11 +vsubpd %ymm8,%ymm0,%ymm4 +vaddpd %ymm9,%ymm1,%ymm5 +vsubpd %ymm10,%ymm2,%ymm6 +vaddpd %ymm11,%ymm3,%ymm7 +vaddpd %ymm8,%ymm0,%ymm0 +vsubpd %ymm9,%ymm1,%ymm1 +vaddpd %ymm10,%ymm2,%ymm2 +vsubpd %ymm11,%ymm3,%ymm3 + +vunpckhpd %ymm7,%ymm3,%ymm11 +vunpckhpd %ymm5,%ymm1,%ymm9 +vunpcklpd %ymm7,%ymm3,%ymm10 +vunpcklpd %ymm5,%ymm1,%ymm8 +vunpckhpd %ymm6,%ymm2,%ymm3 +vunpckhpd %ymm4,%ymm0,%ymm1 +vunpcklpd %ymm6,%ymm2,%ymm2 +vunpcklpd %ymm4,%ymm0,%ymm0 + +vperm2f128 $0x31,%ymm10,%ymm2,%ymm6 +vperm2f128 $0x31,%ymm11,%ymm3,%ymm7 +vperm2f128 $0x20,%ymm10,%ymm2,%ymm4 +vperm2f128 $0x20,%ymm11,%ymm3,%ymm5 +vperm2f128 $0x31,%ymm8,%ymm0,%ymm2 +vperm2f128 $0x31,%ymm9,%ymm1,%ymm3 +vperm2f128 $0x20,%ymm8,%ymm0,%ymm0 +vperm2f128 $0x20,%ymm9,%ymm1,%ymm1 + +# stores +vmovupd %ymm0,(%rdi) # ra0 +vmovupd %ymm1,0x20(%rdi) # ra4 +vmovupd %ymm2,0x40(%rdi) # ra8 +vmovupd %ymm3,0x60(%rdi) # ra12 +vmovupd %ymm4,(%rsi) # ia0 +vmovupd %ymm5,0x20(%rsi) # ia4 +vmovupd %ymm6,0x40(%rsi) # ia8 +vmovupd %ymm7,0x60(%rsi) # ia12 +vzeroupper +ret + +.size fft16_avx2_fma_asm, .-fft16_avx2_fma_asm +.section .note.GNU-stack,"",@progbits diff --git a/poulpy-backend/src/cpu_fft64_avx/reim/fft_avx2_fma.rs b/poulpy-backend/src/cpu_fft64_avx/reim/fft_avx2_fma.rs new file mode 100644 index 0000000..932ff3d --- /dev/null +++ b/poulpy-backend/src/cpu_fft64_avx/reim/fft_avx2_fma.rs @@ -0,0 +1,278 @@ +use std::arch::x86_64::{ + __m128d, __m256d, _mm_load_pd, _mm256_add_pd, _mm256_fmadd_pd, _mm256_fmsub_pd, _mm256_loadu_pd, _mm256_mul_pd, + _mm256_permute2f128_pd, _mm256_set_m128d, _mm256_storeu_pd, _mm256_sub_pd, _mm256_unpackhi_pd, _mm256_unpacklo_pd, +}; + +use crate::cpu_fft64_avx::reim::{as_arr, as_arr_mut}; + +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +#[target_feature(enable = "avx2,fma")] +pub(crate) fn fft_avx2_fma(m: usize, omg: &[f64], data: &mut [f64]) { + if m < 16 { + use poulpy_hal::reference::fft64::reim::fft_ref; + + fft_ref(m, omg, data); + return; + } + + assert!(data.len() == 2 * m); + let (re, im) = data.split_at_mut(m); + + if m == 16 { + 
fft16_avx2_fma( + as_arr_mut::<16, f64>(re), + as_arr_mut::<16, f64>(im), + as_arr::<16, f64>(omg), + ) + } else if m <= 2048 { + fft_bfs_16_avx2_fma(m, re, im, omg, 0); + } else { + fft_rec_16_avx2_fma(m, re, im, omg, 0); + } +} + +unsafe extern "sysv64" { + unsafe fn fft16_avx2_fma_asm(re: *mut f64, im: *mut f64, omg: *const f64); +} + +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +#[target_feature(enable = "avx2,fma")] +fn fft16_avx2_fma(re: &mut [f64; 16], im: &mut [f64; 16], omg: &[f64; 16]) { + unsafe { + fft16_avx2_fma_asm(re.as_mut_ptr(), im.as_mut_ptr(), omg.as_ptr()); + } +} + +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +#[target_feature(enable = "avx2,fma")] +fn fft_rec_16_avx2_fma(m: usize, re: &mut [f64], im: &mut [f64], omg: &[f64], mut pos: usize) -> usize { + if m <= 2048 { + return fft_bfs_16_avx2_fma(m, re, im, omg, pos); + }; + + let h: usize = m >> 1; + twiddle_fft_avx2_fma(h, re, im, *as_arr::<2, f64>(&omg[pos..])); + pos += 2; + pos = fft_rec_16_avx2_fma(h, re, im, omg, pos); + pos = fft_rec_16_avx2_fma(h, &mut re[h..], &mut im[h..], omg, pos); + pos +} + +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +#[target_feature(enable = "avx2,fma")] +fn fft_bfs_16_avx2_fma(m: usize, re: &mut [f64], im: &mut [f64], omg: &[f64], mut pos: usize) -> usize { + let log_m: usize = (usize::BITS - (m - 1).leading_zeros()) as usize; + let mut mm: usize = m; + + if !log_m.is_multiple_of(2) { + let h: usize = mm >> 1; + twiddle_fft_avx2_fma(h, re, im, *as_arr::<2, f64>(&omg[pos..])); + pos += 2; + mm = h + } + + while mm > 16 { + let h: usize = mm >> 2; + for off in (0..m).step_by(mm) { + bitwiddle_fft_avx2_fma( + h, + &mut re[off..], + &mut im[off..], + as_arr::<4, f64>(&omg[pos..]), + ); + + pos += 4; + } + mm = h + } + + for off in (0..m).step_by(16) { + fft16_avx2_fma( + as_arr_mut::<16, f64>(&mut re[off..]), + as_arr_mut::<16, f64>(&mut im[off..]), + as_arr::<16, f64>(&omg[pos..]), + ); + + pos += 16; + } + + pos +} + +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +#[target_feature(enable = "avx2,fma")] +fn twiddle_fft_avx2_fma(h: usize, re: &mut [f64], im: &mut [f64], omg: [f64; 2]) { + unsafe { + let omx: __m128d = _mm_load_pd(omg.as_ptr()); + let omra: __m256d = _mm256_set_m128d(omx, omx); + let omi: __m256d = _mm256_unpackhi_pd(omra, omra); + let omr: __m256d = _mm256_unpacklo_pd(omra, omra); + let mut r0: *mut f64 = re.as_mut_ptr(); + let mut r1: *mut f64 = re.as_mut_ptr().add(h); + let mut i0: *mut f64 = im.as_mut_ptr(); + let mut i1: *mut f64 = im.as_mut_ptr().add(h); + + for _ in (0..h).step_by(4) { + let mut ur0: __m256d = _mm256_loadu_pd(r0); + let mut ur1: __m256d = _mm256_loadu_pd(r1); + let mut ui0: __m256d = _mm256_loadu_pd(i0); + let mut ui1: __m256d = _mm256_loadu_pd(i1); + let mut tra: __m256d = _mm256_mul_pd(omi, ui1); + let mut tia: __m256d = _mm256_mul_pd(omi, ur1); + + tra = _mm256_fmsub_pd(omr, ur1, tra); + tia = _mm256_fmadd_pd(omr, ui1, tia); + ur1 = _mm256_sub_pd(ur0, tra); + ui1 = _mm256_sub_pd(ui0, tia); + ur0 = _mm256_add_pd(ur0, tra); + ui0 = _mm256_add_pd(ui0, tia); + + _mm256_storeu_pd(r0, ur0); + _mm256_storeu_pd(r1, ur1); + _mm256_storeu_pd(i0, ui0); + _mm256_storeu_pd(i1, ui1); + + r0 = r0.add(4); + r1 = r1.add(4); + i0 = i0.add(4); + i1 = i1.add(4); + } + } +} + +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +#[target_feature(enable = "avx2,fma")] +fn bitwiddle_fft_avx2_fma(h: usize, re: &mut [f64], im: &mut [f64], omg: &[f64; 4]) { + unsafe { + let mut r0: *mut f64 = re.as_mut_ptr(); + let 
mut r1: *mut f64 = re.as_mut_ptr().add(h); + let mut r2: *mut f64 = re.as_mut_ptr().add(2 * h); + let mut r3: *mut f64 = re.as_mut_ptr().add(3 * h); + let mut i0: *mut f64 = im.as_mut_ptr(); + let mut i1: *mut f64 = im.as_mut_ptr().add(h); + let mut i2: *mut f64 = im.as_mut_ptr().add(2 * h); + let mut i3: *mut f64 = im.as_mut_ptr().add(3 * h); + let om0: __m256d = _mm256_loadu_pd(omg.as_ptr()); + let omb: __m256d = _mm256_permute2f128_pd(om0, om0, 0x11); + let oma: __m256d = _mm256_permute2f128_pd(om0, om0, 0x00); + let omai: __m256d = _mm256_unpackhi_pd(oma, oma); + let omar: __m256d = _mm256_unpacklo_pd(oma, oma); + let ombi: __m256d = _mm256_unpackhi_pd(omb, omb); + let ombr: __m256d = _mm256_unpacklo_pd(omb, omb); + for _ in (0..h).step_by(4) { + let mut ur0: __m256d = _mm256_loadu_pd(r0); + let mut ur1: __m256d = _mm256_loadu_pd(r1); + let mut ur2: __m256d = _mm256_loadu_pd(r2); + let mut ur3: __m256d = _mm256_loadu_pd(r3); + let mut ui0: __m256d = _mm256_loadu_pd(i0); + let mut ui1: __m256d = _mm256_loadu_pd(i1); + let mut ui2: __m256d = _mm256_loadu_pd(i2); + let mut ui3: __m256d = _mm256_loadu_pd(i3); + + let mut tra: __m256d = _mm256_mul_pd(omai, ui2); + let mut trb: __m256d = _mm256_mul_pd(omai, ui3); + let mut tia: __m256d = _mm256_mul_pd(omai, ur2); + let mut tib: __m256d = _mm256_mul_pd(omai, ur3); + tra = _mm256_fmsub_pd(omar, ur2, tra); + trb = _mm256_fmsub_pd(omar, ur3, trb); + tia = _mm256_fmadd_pd(omar, ui2, tia); + tib = _mm256_fmadd_pd(omar, ui3, tib); + ur2 = _mm256_sub_pd(ur0, tra); + ur3 = _mm256_sub_pd(ur1, trb); + ui2 = _mm256_sub_pd(ui0, tia); + ui3 = _mm256_sub_pd(ui1, tib); + ur0 = _mm256_add_pd(ur0, tra); + ur1 = _mm256_add_pd(ur1, trb); + ui0 = _mm256_add_pd(ui0, tia); + ui1 = _mm256_add_pd(ui1, tib); + + tra = _mm256_mul_pd(ombi, ui1); + trb = _mm256_mul_pd(ombr, ui3); + tia = _mm256_mul_pd(ombi, ur1); + tib = _mm256_mul_pd(ombr, ur3); + tra = _mm256_fmsub_pd(ombr, ur1, tra); + trb = _mm256_fmadd_pd(ombi, ur3, trb); + tia = _mm256_fmadd_pd(ombr, ui1, tia); + tib = _mm256_fmsub_pd(ombi, ui3, tib); + ur1 = _mm256_sub_pd(ur0, tra); + ur3 = _mm256_add_pd(ur2, trb); + ui1 = _mm256_sub_pd(ui0, tia); + ui3 = _mm256_add_pd(ui2, tib); + ur0 = _mm256_add_pd(ur0, tra); + ur2 = _mm256_sub_pd(ur2, trb); + ui0 = _mm256_add_pd(ui0, tia); + ui2 = _mm256_sub_pd(ui2, tib); + + _mm256_storeu_pd(r0, ur0); + _mm256_storeu_pd(r1, ur1); + _mm256_storeu_pd(r2, ur2); + _mm256_storeu_pd(r3, ur3); + _mm256_storeu_pd(i0, ui0); + _mm256_storeu_pd(i1, ui1); + _mm256_storeu_pd(i2, ui2); + _mm256_storeu_pd(i3, ui3); + + r0 = r0.add(4); + r1 = r1.add(4); + r2 = r2.add(4); + r3 = r3.add(4); + i0 = i0.add(4); + i1 = i1.add(4); + i2 = i2.add(4); + i3 = i3.add(4); + } + } +} + +#[test] +fn test_fft_avx2_fma() { + use super::*; + + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + #[target_feature(enable = "avx2,fma")] + fn internal(log_m: usize) { + use poulpy_hal::reference::fft64::reim::ReimFFTRef; + + let m = 1 << log_m; + + let table: ReimFFTTable = ReimFFTTable::::new(m); + + let mut values_0: Vec = vec![0f64; m << 1]; + let scale: f64 = 1.0f64 / m as f64; + values_0 + .iter_mut() + .enumerate() + .for_each(|(i, x)| *x = (i + 1) as f64 * scale); + + let mut values_1: Vec = vec![0f64; m << 1]; + values_1 + .iter_mut() + .zip(values_0.iter()) + .for_each(|(y, x)| *y = *x); + + ReimFFTAvx::reim_dft_execute(&table, &mut values_0); + ReimFFTRef::reim_dft_execute(&table, &mut values_1); + + let max_diff: f64 = 1.0 / ((1u64 << (53 - log_m - 1)) as f64); + + for i in 0..m * 2 { + let diff: 
f64 = (values_0[i] - values_1[i]).abs(); + assert!( + diff <= max_diff, + "{} -> {}-{} = {}", + i, + values_0[i], + values_1[i], + diff + ) + } + } + + if std::is_x86_feature_detected!("avx2") { + for log_m in 0..16 { + unsafe { internal(log_m) } + } + } else { + eprintln!("skipping: CPU lacks avx2"); + } +} diff --git a/poulpy-backend/src/cpu_fft64_avx/reim/fft_vec_avx2_fma.rs b/poulpy-backend/src/cpu_fft64_avx/reim/fft_vec_avx2_fma.rs new file mode 100644 index 0000000..3f4a9c9 --- /dev/null +++ b/poulpy-backend/src/cpu_fft64_avx/reim/fft_vec_avx2_fma.rs @@ -0,0 +1,350 @@ +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +#[target_feature(enable = "avx2,fma")] +pub fn reim_add_avx2_fma(res: &mut [f64], a: &[f64], b: &[f64]) { + #[cfg(debug_assertions)] + { + assert_eq!(a.len(), res.len()); + assert_eq!(b.len(), res.len()); + } + + use std::arch::x86_64::{__m256d, _mm256_add_pd, _mm256_loadu_pd, _mm256_storeu_pd}; + + let span: usize = res.len() >> 2; + + unsafe { + let mut rr: *mut f64 = res.as_mut_ptr(); + let mut aa: *const f64 = a.as_ptr(); + let mut bb: *const f64 = b.as_ptr(); + + for _ in 0..span { + let a_256: __m256d = _mm256_loadu_pd(aa); + let b_256: __m256d = _mm256_loadu_pd(bb); + _mm256_storeu_pd(rr, _mm256_add_pd(a_256, b_256)); + rr = rr.add(4); + aa = aa.add(4); + bb = bb.add(4); + } + } +} + +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +#[target_feature(enable = "avx2,fma")] +pub fn reim_add_inplace_avx2_fma(res: &mut [f64], a: &[f64]) { + #[cfg(debug_assertions)] + { + assert_eq!(a.len(), res.len()); + } + + use std::arch::x86_64::{__m256d, _mm256_add_pd, _mm256_loadu_pd, _mm256_storeu_pd}; + + let span: usize = res.len() >> 2; + + unsafe { + let mut rr: *mut f64 = res.as_mut_ptr(); + let mut aa: *const f64 = a.as_ptr(); + + for _ in 0..span { + let a_256: __m256d = _mm256_loadu_pd(aa); + let r_256: __m256d = _mm256_loadu_pd(rr); + _mm256_storeu_pd(rr, _mm256_add_pd(r_256, a_256)); + rr = rr.add(4); + aa = aa.add(4); + } + } +} + +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +#[target_feature(enable = "avx2,fma")] +pub fn reim_sub_avx2_fma(res: &mut [f64], a: &[f64], b: &[f64]) { + #[cfg(debug_assertions)] + { + assert_eq!(a.len(), res.len()); + assert_eq!(b.len(), res.len()); + } + + use std::arch::x86_64::{__m256d, _mm256_loadu_pd, _mm256_storeu_pd, _mm256_sub_pd}; + + let span: usize = res.len() >> 2; + + unsafe { + let mut rr: *mut f64 = res.as_mut_ptr(); + let mut aa: *const f64 = a.as_ptr(); + let mut bb: *const f64 = b.as_ptr(); + + for _ in 0..span { + let a_256: __m256d = _mm256_loadu_pd(aa); + let b_256: __m256d = _mm256_loadu_pd(bb); + _mm256_storeu_pd(rr, _mm256_sub_pd(a_256, b_256)); + rr = rr.add(4); + aa = aa.add(4); + bb = bb.add(4); + } + } +} + +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +#[target_feature(enable = "avx2,fma")] +pub fn reim_sub_ab_inplace_avx2_fma(res: &mut [f64], a: &[f64]) { + #[cfg(debug_assertions)] + { + assert_eq!(a.len(), res.len()); + } + + use std::arch::x86_64::{__m256d, _mm256_loadu_pd, _mm256_storeu_pd, _mm256_sub_pd}; + + let span: usize = res.len() >> 
2; + + unsafe { + let mut rr: *mut f64 = res.as_mut_ptr(); + let mut aa: *const f64 = a.as_ptr(); + + for _ in 0..span { + let a_256: __m256d = _mm256_loadu_pd(aa); + let r_256: __m256d = _mm256_loadu_pd(rr); + _mm256_storeu_pd(rr, _mm256_sub_pd(r_256, a_256)); + rr = rr.add(4); + aa = aa.add(4); + } + } +} + +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +#[target_feature(enable = "avx2,fma")] +pub fn reim_sub_ba_inplace_avx2_fma(res: &mut [f64], a: &[f64]) { + #[cfg(debug_assertions)] + { + assert_eq!(a.len(), res.len()); + } + + use std::arch::x86_64::{__m256d, _mm256_loadu_pd, _mm256_storeu_pd, _mm256_sub_pd}; + + let span: usize = res.len() >> 2; + + unsafe { + let mut rr: *mut f64 = res.as_mut_ptr(); + let mut aa: *const f64 = a.as_ptr(); + + for _ in 0..span { + let a_256: __m256d = _mm256_loadu_pd(aa); + let r_256: __m256d = _mm256_loadu_pd(rr); + _mm256_storeu_pd(rr, _mm256_sub_pd(a_256, r_256)); + rr = rr.add(4); + aa = aa.add(4); + } + } +} + +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +#[target_feature(enable = "avx2,fma")] +pub fn reim_negate_avx2_fma(res: &mut [f64], a: &[f64]) { + #[cfg(debug_assertions)] + { + assert_eq!(a.len(), res.len()); + } + + use std::arch::x86_64::{__m256d, _mm256_loadu_pd, _mm256_storeu_pd, _mm256_xor_pd}; + + let span: usize = res.len() >> 2; + + unsafe { + use std::arch::x86_64::_mm256_set1_pd; + + let mut rr: *mut f64 = res.as_mut_ptr(); + let mut aa: *const f64 = a.as_ptr(); + + let neg0: __m256d = _mm256_set1_pd(-0.0); + + for _ in 0..span { + let a_256: __m256d = _mm256_loadu_pd(aa); + _mm256_storeu_pd(rr, _mm256_xor_pd(a_256, neg0)); + rr = rr.add(4); + aa = aa.add(4); + } + } +} + +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +#[target_feature(enable = "avx2,fma")] +pub fn reim_negate_inplace_avx2_fma(res: &mut [f64]) { + use std::arch::x86_64::{__m256d, _mm256_loadu_pd, _mm256_storeu_pd, _mm256_xor_pd}; + + let span: usize = res.len() >> 2; + + unsafe { + use std::arch::x86_64::_mm256_set1_pd; + + let mut rr: *mut f64 = res.as_mut_ptr(); + let neg0: __m256d = _mm256_set1_pd(-0.0); + + for _ in 0..span { + let r_256: __m256d = _mm256_loadu_pd(rr); + _mm256_storeu_pd(rr, _mm256_xor_pd(r_256, neg0)); + rr = rr.add(4); + } + } +} + +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +#[target_feature(enable = "avx2,fma")] +pub fn reim_addmul_avx2_fma(res: &mut [f64], a: &[f64], b: &[f64]) { + #[cfg(debug_assertions)] + { + assert_eq!(a.len(), res.len()); + assert_eq!(b.len(), res.len()); + } + + let m: usize = res.len() >> 1; + + let (rr, ri) = res.split_at_mut(m); + let (ar, ai) = a.split_at(m); + let (br, bi) = b.split_at(m); + + unsafe { + let mut rr_ptr: *mut f64 = rr.as_mut_ptr(); + let mut ri_ptr: *mut f64 = ri.as_mut_ptr(); + let mut ar_ptr: *const f64 = ar.as_ptr(); + let mut ai_ptr: *const f64 = ai.as_ptr(); + let mut br_ptr: *const f64 = br.as_ptr(); + let mut bi_ptr: *const f64 = bi.as_ptr(); + + use std::arch::x86_64::{__m256d, _mm256_fmadd_pd, _mm256_fmsub_pd, _mm256_loadu_pd, _mm256_storeu_pd}; + + for _ in 0..(m >> 2) { + let mut rr: __m256d = 
_mm256_loadu_pd(rr_ptr); + let mut ri: __m256d = _mm256_loadu_pd(ri_ptr); + let ar: __m256d = _mm256_loadu_pd(ar_ptr); + let ai: __m256d = _mm256_loadu_pd(ai_ptr); + let br: __m256d = _mm256_loadu_pd(br_ptr); + let bi: __m256d = _mm256_loadu_pd(bi_ptr); + + rr = _mm256_fmsub_pd(ai, bi, rr); + rr = _mm256_fmsub_pd(ar, br, rr); + ri = _mm256_fmadd_pd(ar, bi, ri); + ri = _mm256_fmadd_pd(ai, br, ri); + + _mm256_storeu_pd(rr_ptr, rr); + _mm256_storeu_pd(ri_ptr, ri); + + rr_ptr = rr_ptr.add(4); + ri_ptr = ri_ptr.add(4); + ar_ptr = ar_ptr.add(4); + ai_ptr = ai_ptr.add(4); + br_ptr = br_ptr.add(4); + bi_ptr = bi_ptr.add(4); + } + } +} + +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +#[target_feature(enable = "avx2,fma")] +pub fn reim_mul_avx2_fma(res: &mut [f64], a: &[f64], b: &[f64]) { + #[cfg(debug_assertions)] + { + assert_eq!(a.len(), res.len()); + assert_eq!(b.len(), res.len()); + } + + let m: usize = res.len() >> 1; + + let (rr, ri) = res.split_at_mut(m); + let (ar, ai) = a.split_at(m); + let (br, bi) = b.split_at(m); + + unsafe { + let mut rr_ptr: *mut f64 = rr.as_mut_ptr(); + let mut ri_ptr: *mut f64 = ri.as_mut_ptr(); + let mut ar_ptr: *const f64 = ar.as_ptr(); + let mut ai_ptr: *const f64 = ai.as_ptr(); + let mut br_ptr: *const f64 = br.as_ptr(); + let mut bi_ptr: *const f64 = bi.as_ptr(); + + use std::arch::x86_64::{__m256d, _mm256_fmadd_pd, _mm256_fmsub_pd, _mm256_loadu_pd, _mm256_mul_pd, _mm256_storeu_pd}; + + for _ in 0..(m >> 2) { + let ar: __m256d = _mm256_loadu_pd(ar_ptr); + let ai: __m256d = _mm256_loadu_pd(ai_ptr); + let br: __m256d = _mm256_loadu_pd(br_ptr); + let bi: __m256d = _mm256_loadu_pd(bi_ptr); + + let t1: __m256d = _mm256_mul_pd(ai, bi); + let t2: __m256d = _mm256_mul_pd(ar, bi); + + let rr: __m256d = _mm256_fmsub_pd(ar, br, t1); + let ri: __m256d = _mm256_fmadd_pd(ai, br, t2); + + _mm256_storeu_pd(rr_ptr, rr); + _mm256_storeu_pd(ri_ptr, ri); + + rr_ptr = rr_ptr.add(4); + ri_ptr = ri_ptr.add(4); + ar_ptr = ar_ptr.add(4); + ai_ptr = ai_ptr.add(4); + br_ptr = br_ptr.add(4); + bi_ptr = bi_ptr.add(4); + } + } +} + +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +#[target_feature(enable = "avx2,fma")] +pub fn reim_mul_inplace_avx2_fma(res: &mut [f64], a: &[f64]) { + #[cfg(debug_assertions)] + { + assert_eq!(a.len(), res.len()); + } + + let m: usize = res.len() >> 1; + + let (rr, ri) = res.split_at_mut(m); + let (ar, ai) = a.split_at(m); + + unsafe { + let mut rr_ptr: *mut f64 = rr.as_mut_ptr(); + let mut ri_ptr: *mut f64 = ri.as_mut_ptr(); + let mut ar_ptr: *const f64 = ar.as_ptr(); + let mut ai_ptr: *const f64 = ai.as_ptr(); + + use std::arch::x86_64::{__m256d, _mm256_fmadd_pd, _mm256_fmsub_pd, _mm256_loadu_pd, _mm256_mul_pd, _mm256_storeu_pd}; + + for _ in 0..(m >> 2) { + let ar: __m256d = _mm256_loadu_pd(ar_ptr); + let ai: __m256d = _mm256_loadu_pd(ai_ptr); + let br: __m256d = _mm256_loadu_pd(rr_ptr); + let bi: __m256d = _mm256_loadu_pd(ri_ptr); + + let t1: __m256d = _mm256_mul_pd(ai, bi); + let t2: __m256d = _mm256_mul_pd(ar, bi); + + let rr = _mm256_fmsub_pd(ar, br, t1); + let ri = _mm256_fmadd_pd(ai, br, t2); + + _mm256_storeu_pd(rr_ptr, rr); + _mm256_storeu_pd(ri_ptr, ri); + + rr_ptr = rr_ptr.add(4); + ri_ptr = ri_ptr.add(4); + ar_ptr = ar_ptr.add(4); + ai_ptr = ai_ptr.add(4); + } + } +} diff --git 
a/poulpy-backend/src/cpu_fft64_avx/reim/ifft16_avx2_fma.s b/poulpy-backend/src/cpu_fft64_avx/reim/ifft16_avx2_fma.s new file mode 100644 index 0000000..df344e1 --- /dev/null +++ b/poulpy-backend/src/cpu_fft64_avx/reim/ifft16_avx2_fma.s @@ -0,0 +1,181 @@ +# ---------------------------------------------------------------------- +# This kernel is a direct port of the IFFT16 routine from spqlios-arithmetic +# (https://github.com/tfhe/spqlios-arithmetic) +# ---------------------------------------------------------------------- +# + +.text +.globl ifft16_avx2_fma_asm +.hidden ifft16_avx2_fma_asm +.p2align 4, 0x90 +.type ifft16_avx2_fma_asm,@function +ifft16_avx2_fma_asm: +.att_syntax prefix + +vmovupd (%rdi),%ymm0 # ra0 +vmovupd 0x20(%rdi),%ymm1 # ra4 +vmovupd 0x40(%rdi),%ymm2 # ra8 +vmovupd 0x60(%rdi),%ymm3 # ra12 +vmovupd (%rsi),%ymm4 # ia0 +vmovupd 0x20(%rsi),%ymm5 # ia4 +vmovupd 0x40(%rsi),%ymm6 # ia8 +vmovupd 0x60(%rsi),%ymm7 # ia12 + +1: +vmovupd 0x00(%rdx),%ymm12 +vmovupd 0x20(%rdx),%ymm13 + +vperm2f128 $0x31,%ymm2,%ymm0,%ymm8 # ymm8 contains re to mul (tw) +vperm2f128 $0x31,%ymm3,%ymm1,%ymm9 # ymm9 contains re to mul (itw) +vperm2f128 $0x31,%ymm6,%ymm4,%ymm10 # ymm10 contains im to mul (tw) +vperm2f128 $0x31,%ymm7,%ymm5,%ymm11 # ymm11 contains im to mul (itw) +vperm2f128 $0x20,%ymm2,%ymm0,%ymm0 # ymm0 contains re to add (tw) +vperm2f128 $0x20,%ymm3,%ymm1,%ymm1 # ymm1 contains re to add (itw) +vperm2f128 $0x20,%ymm6,%ymm4,%ymm2 # ymm2 contains im to add (tw) +vperm2f128 $0x20,%ymm7,%ymm5,%ymm3 # ymm3 contains im to add (itw) + +vunpckhpd %ymm1,%ymm0,%ymm4 # (0,1) -> (0,4) +vunpckhpd %ymm3,%ymm2,%ymm6 # (2,3) -> (2,6) +vunpckhpd %ymm9,%ymm8,%ymm5 # (8,9) -> (1,5) +vunpckhpd %ymm11,%ymm10,%ymm7 # (10,11) -> (3,7) +vunpcklpd %ymm1,%ymm0,%ymm0 +vunpcklpd %ymm3,%ymm2,%ymm2 +vunpcklpd %ymm9,%ymm8,%ymm1 +vunpcklpd %ymm11,%ymm10,%ymm3 + +# invctwiddle Re:(ymm0,ymm4) and Im:(ymm2,ymm6) with omega=(ymm12,ymm13) +# invcitwiddle Re:(ymm1,ymm5) and Im:(ymm3,ymm7) with omega=(ymm12,ymm13) +vsubpd %ymm4,%ymm0,%ymm8 # retw +vsubpd %ymm5,%ymm1,%ymm9 # reitw +vsubpd %ymm6,%ymm2,%ymm10 # imtw +vsubpd %ymm7,%ymm3,%ymm11 # imitw +vaddpd %ymm4,%ymm0,%ymm0 +vaddpd %ymm5,%ymm1,%ymm1 +vaddpd %ymm6,%ymm2,%ymm2 +vaddpd %ymm7,%ymm3,%ymm3 +# multiply 8,9,10,11 by 12,13, result to: 4,5,6,7 +# twiddles use reom=ymm12, imom=ymm13 +# invtwiddles use reom=ymm13, imom=-ymm12 +vmulpd %ymm10,%ymm13,%ymm4 # imtw.omai (tw) +vmulpd %ymm11,%ymm12,%ymm5 # imitw.omar (itw) +vmulpd %ymm8,%ymm13,%ymm6 # retw.omai (tw) +vmulpd %ymm9,%ymm12,%ymm7 # reitw.omar (itw) +vfmsub231pd %ymm8,%ymm12,%ymm4 # rprod0 (tw) +vfmadd231pd %ymm9,%ymm13,%ymm5 # rprod4 (itw) +vfmadd231pd %ymm10,%ymm12,%ymm6 # iprod0 (tw) +vfmsub231pd %ymm11,%ymm13,%ymm7 # iprod4 (itw) + +vunpckhpd %ymm7,%ymm3,%ymm11 # (0,4) -> (0,1) +vunpckhpd %ymm5,%ymm1,%ymm9 # (2,6) -> (2,3) +vunpcklpd %ymm7,%ymm3,%ymm10 +vunpcklpd %ymm5,%ymm1,%ymm8 +vunpckhpd %ymm6,%ymm2,%ymm3 # (1,5) -> (8,9) +vunpckhpd %ymm4,%ymm0,%ymm1 # (3,7) -> (10,11) +vunpcklpd %ymm6,%ymm2,%ymm2 +vunpcklpd %ymm4,%ymm0,%ymm0 + +2: +vmovupd 0x40(%rdx),%ymm12 +vshufpd $15, %ymm12, %ymm12, %ymm13 # ymm13: omaiii'i' +vshufpd $0, %ymm12, %ymm12, %ymm12 # ymm12: omarrr'r' + +# invctwiddle Re:(ymm0,ymm8) and Im:(ymm2,ymm10) with omega=(ymm12,ymm13) +# invcitwiddle Re:(ymm1,ymm9) and Im:(ymm3,ymm11) with omega=(ymm12,ymm13) +vsubpd %ymm8,%ymm0,%ymm4 # retw +vsubpd %ymm9,%ymm1,%ymm5 # reitw +vsubpd %ymm10,%ymm2,%ymm6 # imtw +vsubpd %ymm11,%ymm3,%ymm7 # imitw +vaddpd %ymm8,%ymm0,%ymm0 +vaddpd %ymm9,%ymm1,%ymm1 +vaddpd 
%ymm10,%ymm2,%ymm2 +vaddpd %ymm11,%ymm3,%ymm3 +# multiply 4,5,6,7 by 12,13, result to 8,9,10,11 +# twiddles use reom=ymm12, imom=ymm13 +# invtwiddles use reom=ymm13, imom=-ymm12 +vmulpd %ymm6,%ymm13,%ymm8 # imtw.omai (tw) +vmulpd %ymm7,%ymm12,%ymm9 # imitw.omar (itw) +vmulpd %ymm4,%ymm13,%ymm10 # retw.omai (tw) +vmulpd %ymm5,%ymm12,%ymm11 # reitw.omar (itw) +vfmsub231pd %ymm4,%ymm12,%ymm8 # rprod0 (tw) +vfmadd231pd %ymm5,%ymm13,%ymm9 # rprod4 (itw) +vfmadd231pd %ymm6,%ymm12,%ymm10 # iprod0 (tw) +vfmsub231pd %ymm7,%ymm13,%ymm11 # iprod4 (itw) + +vperm2f128 $0x31,%ymm10,%ymm2,%ymm6 +vperm2f128 $0x31,%ymm11,%ymm3,%ymm7 +vperm2f128 $0x20,%ymm10,%ymm2,%ymm4 +vperm2f128 $0x20,%ymm11,%ymm3,%ymm5 +vperm2f128 $0x31,%ymm8,%ymm0,%ymm2 +vperm2f128 $0x31,%ymm9,%ymm1,%ymm3 +vperm2f128 $0x20,%ymm8,%ymm0,%ymm0 +vperm2f128 $0x20,%ymm9,%ymm1,%ymm1 + +3: +vmovupd 0x60(%rdx),%xmm12 +vinsertf128 $1, %xmm12, %ymm12, %ymm12 # omriri +vshufpd $15, %ymm12, %ymm12, %ymm13 # ymm13: omai +vshufpd $0, %ymm12, %ymm12, %ymm12 # ymm12: omar + +# invctwiddle Re:(ymm0,ymm1) and Im:(ymm4,ymm5) with omega=(ymm12,ymm13) +# invcitwiddle Re:(ymm2,ymm3) and Im:(ymm6,ymm7) with omega=(ymm12,ymm13) +vsubpd %ymm1,%ymm0,%ymm8 # retw +vsubpd %ymm3,%ymm2,%ymm9 # reitw +vsubpd %ymm5,%ymm4,%ymm10 # imtw +vsubpd %ymm7,%ymm6,%ymm11 # imitw +vaddpd %ymm1,%ymm0,%ymm0 +vaddpd %ymm3,%ymm2,%ymm2 +vaddpd %ymm5,%ymm4,%ymm4 +vaddpd %ymm7,%ymm6,%ymm6 +# multiply 8,9,10,11 by 12,13, result to 1,3,5,7 +# twiddles use reom=ymm12, imom=ymm13 +# invtwiddles use reom=ymm13, imom=-ymm12 +vmulpd %ymm10,%ymm13,%ymm1 # imtw.omai (tw) +vmulpd %ymm11,%ymm12,%ymm3 # imitw.omar (itw) +vmulpd %ymm8,%ymm13,%ymm5 # retw.omai (tw) +vmulpd %ymm9,%ymm12,%ymm7 # reitw.omar (itw) +vfmsub231pd %ymm8,%ymm12,%ymm1 # rprod0 (tw) +vfmadd231pd %ymm9,%ymm13,%ymm3 # rprod4 (itw) +vfmadd231pd %ymm10,%ymm12,%ymm5 # iprod0 (tw) +vfmsub231pd %ymm11,%ymm13,%ymm7 # iprod4 (itw) + +4: +vmovupd 0x70(%rdx),%xmm12 +vinsertf128 $1, %xmm12, %ymm12, %ymm12 # omriri +vshufpd $15, %ymm12, %ymm12, %ymm13 # ymm13: omai +vshufpd $0, %ymm12, %ymm12, %ymm12 # ymm12: omar + +# invctwiddle Re:(ymm0,ymm2) and Im:(ymm4,ymm6) with omega=(ymm12,ymm13) +# invctwiddle Re:(ymm1,ymm3) and Im:(ymm5,ymm7) with omega=(ymm12,ymm13) +vsubpd %ymm2,%ymm0,%ymm8 # retw1 +vsubpd %ymm3,%ymm1,%ymm9 # retw2 +vsubpd %ymm6,%ymm4,%ymm10 # imtw1 +vsubpd %ymm7,%ymm5,%ymm11 # imtw2 +vaddpd %ymm2,%ymm0,%ymm0 +vaddpd %ymm3,%ymm1,%ymm1 +vaddpd %ymm6,%ymm4,%ymm4 +vaddpd %ymm7,%ymm5,%ymm5 +# multiply 8,9,10,11 by 12,13, result to 2,3,6,7 +# twiddles use reom=ymm12, imom=ymm13 +vmulpd %ymm10,%ymm13,%ymm2 # imtw1.omai +vmulpd %ymm11,%ymm13,%ymm3 # imtw2.omai +vmulpd %ymm8,%ymm13,%ymm6 # retw1.omai +vmulpd %ymm9,%ymm13,%ymm7 # retw2.omai +vfmsub231pd %ymm8,%ymm12,%ymm2 # rprod0 +vfmsub231pd %ymm9,%ymm12,%ymm3 # rprod4 +vfmadd231pd %ymm10,%ymm12,%ymm6 # iprod0 +vfmadd231pd %ymm11,%ymm12,%ymm7 # iprod4 + +5: +vmovupd %ymm0,(%rdi) # ra0 +vmovupd %ymm1,0x20(%rdi) # ra4 +vmovupd %ymm2,0x40(%rdi) # ra8 +vmovupd %ymm3,0x60(%rdi) # ra12 +vmovupd %ymm4,(%rsi) # ia0 +vmovupd %ymm5,0x20(%rsi) # ia4 +vmovupd %ymm6,0x40(%rsi) # ia8 +vmovupd %ymm7,0x60(%rsi) # ia12 +vzeroupper +ret + +.size ifft16_avx_fma, .-ifft16_avx_fma +.section .note.GNU-stack,"",@progbits \ No newline at end of file diff --git a/poulpy-backend/src/cpu_fft64_avx/reim/ifft_avx2_fma.rs b/poulpy-backend/src/cpu_fft64_avx/reim/ifft_avx2_fma.rs new file mode 100644 index 0000000..5d93ea1 --- /dev/null +++ b/poulpy-backend/src/cpu_fft64_avx/reim/ifft_avx2_fma.rs @@ -0,0 +1,271 @@ 
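For orientation, the butterfly that the assembly kernel above and the `inv_twiddle_ifft_avx2_fma` intrinsics below both vectorize can be written out in scalar form. A minimal sketch (illustrative only; the function name is hypothetical, and the arithmetic mirrors the (a + b, (a - b)*omega) inverse twiddle used throughout this file):

fn inv_butterfly(a: (f64, f64), b: (f64, f64), om: (f64, f64)) -> ((f64, f64), (f64, f64)) {
    // a, b: one complex coefficient each as (re, im); om: the twiddle factor.
    let (tr, ti) = (a.0 - b.0, a.1 - b.1);                      // t  = a - b
    let sum = (a.0 + b.0, a.1 + b.1);                           // a' = a + b
    let diff = (om.0 * tr - om.1 * ti, om.0 * ti + om.1 * tr);  // b' = t * om
    (sum, diff)
}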
+use std::arch::x86_64::{ + __m128d, __m256d, _mm_load_pd, _mm256_add_pd, _mm256_fmadd_pd, _mm256_fmsub_pd, _mm256_loadu_pd, _mm256_mul_pd, + _mm256_permute2f128_pd, _mm256_set_m128d, _mm256_storeu_pd, _mm256_sub_pd, _mm256_unpackhi_pd, _mm256_unpacklo_pd, +}; + +use crate::cpu_fft64_avx::reim::{as_arr, as_arr_mut}; + +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2,fma")] +pub(crate) fn ifft_avx2_fma(m: usize, omg: &[f64], data: &mut [f64]) { + if m < 16 { + use poulpy_hal::reference::fft64::reim::ifft_ref; + ifft_ref(m, omg, data); + return; + } + + assert!(data.len() == 2 * m); + let (re, im) = data.split_at_mut(m); + + if m == 16 { + ifft16_avx2_fma( + as_arr_mut::<16, f64>(re), + as_arr_mut::<16, f64>(im), + as_arr::<16, f64>(omg), + ) + } else if m <= 2048 { + ifft_bfs_16_avx2_fma(m, re, im, omg, 0); + } else { + ifft_rec_16_avx2_fma(m, re, im, omg, 0); + } +} + +unsafe extern "sysv64" { + unsafe fn ifft16_avx2_fma_asm(re: *mut f64, im: *mut f64, omg: *const f64); +} + +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +#[target_feature(enable = "avx2,fma")] +fn ifft16_avx2_fma(re: &mut [f64; 16], im: &mut [f64; 16], omg: &[f64; 16]) { + unsafe { + ifft16_avx2_fma_asm(re.as_mut_ptr(), im.as_mut_ptr(), omg.as_ptr()); + } +} + +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2,fma")] +fn ifft_rec_16_avx2_fma(m: usize, re: &mut [f64], im: &mut [f64], omg: &[f64], mut pos: usize) -> usize { + if m <= 2048 { + return ifft_bfs_16_avx2_fma(m, re, im, omg, pos); + }; + let h: usize = m >> 1; + pos = ifft_rec_16_avx2_fma(h, re, im, omg, pos); + pos = ifft_rec_16_avx2_fma(h, &mut re[h..], &mut im[h..], omg, pos); + inv_twiddle_ifft_avx2_fma(h, re, im, *as_arr::<2, f64>(&omg[pos..])); + pos += 2; + pos +} + +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2,fma")] +fn ifft_bfs_16_avx2_fma(m: usize, re: &mut [f64], im: &mut [f64], omg: &[f64], mut pos: usize) -> usize { + let log_m: usize = (usize::BITS - (m - 1).leading_zeros()) as usize; + + for off in (0..m).step_by(16) { + ifft16_avx2_fma( + as_arr_mut::<16, f64>(&mut re[off..]), + as_arr_mut::<16, f64>(&mut im[off..]), + as_arr::<16, f64>(&omg[pos..]), + ); + pos += 16; + } + + let mut h: usize = 16; + let m_half: usize = m >> 1; + + while h < m_half { + let mm: usize = h << 2; + for off in (0..m).step_by(mm) { + inv_bitwiddle_ifft_avx2_fma( + h, + &mut re[off..], + &mut im[off..], + as_arr::<4, f64>(&omg[pos..]), + ); + pos += 4; + } + h = mm; + } + + if !log_m.is_multiple_of(2) { + inv_twiddle_ifft_avx2_fma(h, re, im, *as_arr::<2, f64>(&omg[pos..])); + pos += 2; + } + + pos +} + +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +#[target_feature(enable = "avx2,fma")] +fn inv_twiddle_ifft_avx2_fma(h: usize, re: &mut [f64], im: &mut [f64], omg: [f64; 2]) { + unsafe { + let omx: __m128d = _mm_load_pd(omg.as_ptr()); + let omra: __m256d = _mm256_set_m128d(omx, omx); + let omi: __m256d = _mm256_unpackhi_pd(omra, omra); + let omr: __m256d = _mm256_unpacklo_pd(omra, omra); + let mut r0: *mut f64 = re.as_mut_ptr(); + let mut r1: *mut f64 = re.as_mut_ptr().add(h); + let mut i0: *mut f64 = im.as_mut_ptr(); + let mut i1: *mut f64 = im.as_mut_ptr().add(h); + for _ in (0..h).step_by(4) { + let mut ur0: __m256d = _mm256_loadu_pd(r0); + let mut ur1: __m256d = _mm256_loadu_pd(r1); + let mut ui0: __m256d = _mm256_loadu_pd(i0); + let mut ui1: __m256d = _mm256_loadu_pd(i1); + let tra = _mm256_sub_pd(ur0, ur1); + let tia = _mm256_sub_pd(ui0, ui1); + ur0 = _mm256_add_pd(ur0, ur1); + ui0 = 
_mm256_add_pd(ui0, ui1); + ur1 = _mm256_mul_pd(omi, tia); + ui1 = _mm256_mul_pd(omi, tra); + ur1 = _mm256_fmsub_pd(omr, tra, ur1); + ui1 = _mm256_fmadd_pd(omr, tia, ui1); + _mm256_storeu_pd(r0, ur0); + _mm256_storeu_pd(r1, ur1); + _mm256_storeu_pd(i0, ui0); + _mm256_storeu_pd(i1, ui1); + + r0 = r0.add(4); + r1 = r1.add(4); + i0 = i0.add(4); + i1 = i1.add(4); + } + } +} + +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +#[target_feature(enable = "avx2,fma")] +fn inv_bitwiddle_ifft_avx2_fma(h: usize, re: &mut [f64], im: &mut [f64], omg: &[f64; 4]) { + unsafe { + let mut r0: *mut f64 = re.as_mut_ptr(); + let mut r1: *mut f64 = re.as_mut_ptr().add(h); + let mut r2: *mut f64 = re.as_mut_ptr().add(2 * h); + let mut r3: *mut f64 = re.as_mut_ptr().add(3 * h); + let mut i0: *mut f64 = im.as_mut_ptr(); + let mut i1: *mut f64 = im.as_mut_ptr().add(h); + let mut i2: *mut f64 = im.as_mut_ptr().add(2 * h); + let mut i3: *mut f64 = im.as_mut_ptr().add(3 * h); + let om0: __m256d = _mm256_loadu_pd(omg.as_ptr()); + let omb: __m256d = _mm256_permute2f128_pd(om0, om0, 0x11); + let oma: __m256d = _mm256_permute2f128_pd(om0, om0, 0x00); + let omai: __m256d = _mm256_unpackhi_pd(oma, oma); + let omar: __m256d = _mm256_unpacklo_pd(oma, oma); + let ombi: __m256d = _mm256_unpackhi_pd(omb, omb); + let ombr: __m256d = _mm256_unpacklo_pd(omb, omb); + for _ in (0..h).step_by(4) { + let mut ur0: __m256d = _mm256_loadu_pd(r0); + let mut ur1: __m256d = _mm256_loadu_pd(r1); + let mut ur2: __m256d = _mm256_loadu_pd(r2); + let mut ur3: __m256d = _mm256_loadu_pd(r3); + let mut ui0: __m256d = _mm256_loadu_pd(i0); + let mut ui1: __m256d = _mm256_loadu_pd(i1); + let mut ui2: __m256d = _mm256_loadu_pd(i2); + let mut ui3: __m256d = _mm256_loadu_pd(i3); + + let mut tra: __m256d = _mm256_sub_pd(ur0, ur1); + let mut trb: __m256d = _mm256_sub_pd(ur2, ur3); + let mut tia: __m256d = _mm256_sub_pd(ui0, ui1); + let mut tib: __m256d = _mm256_sub_pd(ui2, ui3); + ur0 = _mm256_add_pd(ur0, ur1); + ur2 = _mm256_add_pd(ur2, ur3); + ui0 = _mm256_add_pd(ui0, ui1); + ui2 = _mm256_add_pd(ui2, ui3); + ur1 = _mm256_mul_pd(omai, tia); + ur3 = _mm256_mul_pd(omar, tib); + ui1 = _mm256_mul_pd(omai, tra); + ui3 = _mm256_mul_pd(omar, trb); + ur1 = _mm256_fmsub_pd(omar, tra, ur1); + ur3 = _mm256_fmadd_pd(omai, trb, ur3); + ui1 = _mm256_fmadd_pd(omar, tia, ui1); + ui3 = _mm256_fmsub_pd(omai, tib, ui3); + + tra = _mm256_sub_pd(ur0, ur2); + trb = _mm256_sub_pd(ur1, ur3); + tia = _mm256_sub_pd(ui0, ui2); + tib = _mm256_sub_pd(ui1, ui3); + ur0 = _mm256_add_pd(ur0, ur2); + ur1 = _mm256_add_pd(ur1, ur3); + ui0 = _mm256_add_pd(ui0, ui2); + ui1 = _mm256_add_pd(ui1, ui3); + ur2 = _mm256_mul_pd(ombi, tia); + ur3 = _mm256_mul_pd(ombi, tib); + ui2 = _mm256_mul_pd(ombi, tra); + ui3 = _mm256_mul_pd(ombi, trb); + ur2 = _mm256_fmsub_pd(ombr, tra, ur2); + ur3 = _mm256_fmsub_pd(ombr, trb, ur3); + ui2 = _mm256_fmadd_pd(ombr, tia, ui2); + ui3 = _mm256_fmadd_pd(ombr, tib, ui3); + + _mm256_storeu_pd(r0, ur0); + _mm256_storeu_pd(r1, ur1); + _mm256_storeu_pd(r2, ur2); + _mm256_storeu_pd(r3, ur3); + _mm256_storeu_pd(i0, ui0); + _mm256_storeu_pd(i1, ui1); + _mm256_storeu_pd(i2, ui2); + _mm256_storeu_pd(i3, ui3); + + r0 = r0.add(4); + r1 = r1.add(4); + r2 = r2.add(4); + r3 = r3.add(4); + i0 = i0.add(4); + i1 = i1.add(4); + i2 = i2.add(4); + i3 = i3.add(4); + } + } +} + +#[test] +fn test_ifft_avx2_fma() { + use super::*; + + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + #[target_feature(enable = "avx2,fma")] + fn internal(log_m: usize) { + use 
poulpy_hal::reference::fft64::reim::ReimIFFTRef; + + let m: usize = 1 << log_m; + + let table: ReimIFFTTable = ReimIFFTTable::::new(m); + + let mut values_0: Vec = vec![0f64; m << 1]; + let scale: f64 = 1.0f64 / m as f64; + values_0 + .iter_mut() + .enumerate() + .for_each(|(i, x)| *x = (i + 1) as f64 * scale); + + let mut values_1: Vec = vec![0f64; m << 1]; + values_1 + .iter_mut() + .zip(values_0.iter()) + .for_each(|(y, x)| *y = *x); + + ReimIFFTAvx::reim_dft_execute(&table, &mut values_0); + ReimIFFTRef::reim_dft_execute(&table, &mut values_1); + + let max_diff: f64 = 1.0 / ((1u64 << (53 - log_m - 1)) as f64); + + for i in 0..m * 2 { + let diff: f64 = (values_0[i] - values_1[i]).abs(); + assert!( + diff <= max_diff, + "{} -> {}-{} = {}", + i, + values_0[i], + values_1[i], + diff + ) + } + } + + if std::is_x86_feature_detected!("avx2") { + for log_m in 0..16 { + unsafe { internal(log_m) } + } + } else { + eprintln!("skipping: CPU lacks avx2"); + } +} diff --git a/poulpy-backend/src/cpu_fft64_avx/reim/mod.rs b/poulpy-backend/src/cpu_fft64_avx/reim/mod.rs new file mode 100644 index 0000000..cb57bd3 --- /dev/null +++ b/poulpy-backend/src/cpu_fft64_avx/reim/mod.rs @@ -0,0 +1,72 @@ +// ---------------------------------------------------------------------- +// DISCLAIMER +// +// This module contains code that has been directly ported from the +// spqlios-arithmetic library +// (https://github.com/tfhe/spqlios-arithmetic), which is licensed +// under the Apache License, Version 2.0. +// +// The porting process from C to Rust was done with minimal changes +// in order to preserve the semantics and performance characteristics +// of the original implementation. +// +// Both Poulpy and spqlios-arithmetic are distributed under the terms +// of the Apache License, Version 2.0. See the LICENSE file for details. 
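Since every `#[target_feature]` entry point in this backend is only sound on CPUs with AVX2/FMA, a caller would typically guard the AVX path at runtime, as the test above does. A possible dispatch sketch for x86_64 builds (illustrative usage only, not the crate's own selection mechanism; the function name is hypothetical and the `<f64>` table parameter is an assumption reconstructed from the surrounding code):

use poulpy_backend::cpu_fft64_avx::ReimFFTAvx;
use poulpy_hal::reference::fft64::reim::{ReimDFTExecute, ReimFFTRef, ReimFFTTable};

fn reim_fft_dispatch(table: &ReimFFTTable<f64>, data: &mut [f64]) {
    // Prefer the AVX2/FMA kernels when the CPU supports them, otherwise fall
    // back to the portable reference implementation.
    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    {
        if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
            ReimFFTAvx::reim_dft_execute(table, data);
            return;
        }
    }
    ReimFFTRef::reim_dft_execute(table, data);
}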
+// +// ---------------------------------------------------------------------- + +#![allow(bad_asm_style)] + +mod conversion; +mod fft_avx2_fma; +mod fft_vec_avx2_fma; +mod ifft_avx2_fma; + +use std::arch::global_asm; + +pub(crate) use conversion::*; +pub(crate) use fft_vec_avx2_fma::*; + +use poulpy_hal::reference::fft64::reim::{ReimDFTExecute, ReimFFTTable, ReimIFFTTable}; +use rand_distr::num_traits::{Float, FloatConst}; + +use crate::cpu_fft64_avx::reim::{fft_avx2_fma::fft_avx2_fma, ifft_avx2_fma::ifft_avx2_fma}; + +global_asm!( + include_str!("fft16_avx2_fma.s"), + include_str!("ifft16_avx2_fma.s") +); + +#[inline(always)] +pub(crate) fn as_arr(x: &[R]) -> &[R; SIZE] { + debug_assert!(x.len() >= SIZE); + unsafe { &*(x.as_ptr() as *const [R; SIZE]) } +} + +#[inline(always)] +pub(crate) fn as_arr_mut(x: &mut [R]) -> &mut [R; SIZE] { + debug_assert!(x.len() >= SIZE); + unsafe { &mut *(x.as_mut_ptr() as *mut [R; SIZE]) } +} + +pub struct ReimFFTAvx; + +impl ReimDFTExecute, f64> for ReimFFTAvx { + #[inline(always)] + fn reim_dft_execute(table: &ReimFFTTable, data: &mut [f64]) { + unsafe { + fft_avx2_fma(table.m(), table.omg(), data); + } + } +} + +pub struct ReimIFFTAvx; + +impl ReimDFTExecute, f64> for ReimIFFTAvx { + #[inline(always)] + fn reim_dft_execute(table: &ReimIFFTTable, data: &mut [f64]) { + unsafe { + ifft_avx2_fma(table.m(), table.omg(), data); + } + } +} diff --git a/poulpy-backend/src/cpu_fft64_avx/reim4/arithmetic_avx.rs b/poulpy-backend/src/cpu_fft64_avx/reim4/arithmetic_avx.rs new file mode 100644 index 0000000..2bfc371 --- /dev/null +++ b/poulpy-backend/src/cpu_fft64_avx/reim4/arithmetic_avx.rs @@ -0,0 +1,264 @@ +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx")] +pub fn reim4_extract_1blk_from_reim_avx(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) { + use core::arch::x86_64::{__m256d, _mm256_loadu_pd, _mm256_storeu_pd}; + + unsafe { + let mut src_ptr: *const __m256d = src.as_ptr().add(blk << 2) as *const __m256d; // src + 4*blk + let mut dst_ptr: *mut __m256d = dst.as_mut_ptr() as *mut __m256d; + + let step: usize = m >> 2; + + // Each iteration copies 4 doubles; advance src by m doubles each row + for _ in 0..2 * rows { + let v: __m256d = _mm256_loadu_pd(src_ptr as *const f64); + _mm256_storeu_pd(dst_ptr as *mut f64, v); + dst_ptr = dst_ptr.add(1); + src_ptr = src_ptr.add(step); + } + } +} + +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2,fma")] +pub fn reim4_save_1blk_to_reim_avx(m: usize, blk: usize, dst: &mut [f64], src: &[f64]) { + use core::arch::x86_64::{__m256d, _mm256_add_pd, _mm256_loadu_pd, _mm256_storeu_pd}; + unsafe { + let off: usize = blk * 4; + let src_ptr: *const f64 = src.as_ptr(); + + let s0: __m256d = _mm256_loadu_pd(src_ptr); + let s1: __m256d = _mm256_loadu_pd(src_ptr.add(4)); + + let d0_ptr: *mut f64 = dst.as_mut_ptr().add(off); + let d1_ptr: *mut f64 = d0_ptr.add(m); + + if OVERWRITE { + _mm256_storeu_pd(d0_ptr, s0); + _mm256_storeu_pd(d1_ptr, s1); + } else { + let d0: __m256d = _mm256_loadu_pd(d0_ptr); + let d1: __m256d = _mm256_loadu_pd(d1_ptr); + _mm256_storeu_pd(d0_ptr, _mm256_add_pd(d0, s0)); + _mm256_storeu_pd(d1_ptr, _mm256_add_pd(d1, s1)); + } + } +} + +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +#[cfg(target_arch 
= "x86_64")] +#[target_feature(enable = "avx2,fma")] +pub fn reim4_save_2blk_to_reim_avx( + m: usize, // + blk: usize, // block index + dst: &mut [f64], // + src: &[f64], // 16 doubles [re1(4), im1(4), re2(4), im2(4)] +) { + use core::arch::x86_64::{__m256d, _mm256_add_pd, _mm256_loadu_pd, _mm256_storeu_pd}; + unsafe { + let off: usize = blk * 4; + let src_ptr: *const f64 = src.as_ptr(); + + let d0_ptr: *mut f64 = dst.as_mut_ptr().add(off); + let d1_ptr: *mut f64 = d0_ptr.add(m); + let d2_ptr: *mut f64 = d1_ptr.add(m); + let d3_ptr: *mut f64 = d2_ptr.add(m); + + let s0: __m256d = _mm256_loadu_pd(src_ptr); + let s1: __m256d = _mm256_loadu_pd(src_ptr.add(4)); + let s2: __m256d = _mm256_loadu_pd(src_ptr.add(8)); + let s3: __m256d = _mm256_loadu_pd(src_ptr.add(12)); + + if OVERWRITE { + _mm256_storeu_pd(d0_ptr, s0); + _mm256_storeu_pd(d1_ptr, s1); + _mm256_storeu_pd(d2_ptr, s2); + _mm256_storeu_pd(d3_ptr, s3); + } else { + let d0: __m256d = _mm256_loadu_pd(d0_ptr); + let d1: __m256d = _mm256_loadu_pd(d1_ptr); + let d2: __m256d = _mm256_loadu_pd(d2_ptr); + let d3: __m256d = _mm256_loadu_pd(d3_ptr); + _mm256_storeu_pd(d0_ptr, _mm256_add_pd(d0, s0)); + _mm256_storeu_pd(d1_ptr, _mm256_add_pd(d1, s1)); + _mm256_storeu_pd(d2_ptr, _mm256_add_pd(d2, s2)); + _mm256_storeu_pd(d3_ptr, _mm256_add_pd(d3, s3)); + } + } +} + +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2", enable = "fma")] +pub fn reim4_vec_mat1col_product_avx(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]) { + use core::arch::x86_64::{__m256d, _mm256_fmadd_pd, _mm256_loadu_pd, _mm256_setzero_pd, _mm256_storeu_pd}; + + #[cfg(debug_assertions)] + { + assert!(dst.len() >= 8, "dst must have at least 8 doubles"); + assert!(u.len() >= nrows * 8, "u must be at least nrows * 8 doubles"); + assert!(v.len() >= nrows * 8, "v must be at least nrows * 8 doubles"); + } + + unsafe { + use std::arch::x86_64::{_mm256_add_pd, _mm256_sub_pd}; + + let mut re1: __m256d = _mm256_setzero_pd(); + let mut im1: __m256d = _mm256_setzero_pd(); + let mut re2: __m256d = _mm256_setzero_pd(); + let mut im2: __m256d = _mm256_setzero_pd(); + + let mut u_ptr: *const f64 = u.as_ptr(); + let mut v_ptr: *const f64 = v.as_ptr(); + + for _ in 0..nrows { + let ur: __m256d = _mm256_loadu_pd(u_ptr); + let ui: __m256d = _mm256_loadu_pd(u_ptr.add(4)); + let vr: __m256d = _mm256_loadu_pd(v_ptr); + let vi: __m256d = _mm256_loadu_pd(v_ptr.add(4)); + + // re1 = re1 + ur*vr; + re1 = _mm256_fmadd_pd(ur, vr, re1); + // im1 = im1 + ur*d; + im1 = _mm256_fmadd_pd(ur, vi, im1); + // re2 = re2 + ui*d; + re2 = _mm256_fmadd_pd(ui, vi, re2); + // im2 = im2 + ui*vr; + im2 = _mm256_fmadd_pd(ui, vr, im2); + + u_ptr = u_ptr.add(8); + v_ptr = v_ptr.add(8); + } + + // re1 - re2 + _mm256_storeu_pd(dst.as_mut_ptr(), _mm256_sub_pd(re1, re2)); + + // im1 + im2 + _mm256_storeu_pd(dst.as_mut_ptr().add(4), _mm256_add_pd(im1, im2)); + } +} + +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2", enable = "fma")] +pub fn reim4_vec_mat2cols_product_avx(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]) { + use core::arch::x86_64::{__m256d, _mm256_fmadd_pd, _mm256_fmsub_pd, _mm256_loadu_pd, _mm256_setzero_pd, _mm256_storeu_pd}; + + #[cfg(debug_assertions)] + { + assert!( + dst.len() >= 8, + "dst must be at least 8 doubles but is {}", + dst.len() + ); + assert!( + 
u.len() >= nrows * 8, + "u must be at least nrows={} * 8 doubles but is {}", + nrows, + u.len() + ); + assert!( + v.len() >= nrows * 16, + "v must be at least nrows={} * 16 doubles but is {}", + nrows, + v.len() + ); + } + + unsafe { + let mut re1: __m256d = _mm256_setzero_pd(); + let mut im1: __m256d = _mm256_setzero_pd(); + let mut re2: __m256d = _mm256_setzero_pd(); + let mut im2: __m256d = _mm256_setzero_pd(); + + let mut u_ptr: *const f64 = u.as_ptr(); + let mut v_ptr: *const f64 = v.as_ptr(); + + for _ in 0..nrows { + let ur: __m256d = _mm256_loadu_pd(u_ptr); + let ui: __m256d = _mm256_loadu_pd(u_ptr.add(4)); + + let ar: __m256d = _mm256_loadu_pd(v_ptr); + let ai: __m256d = _mm256_loadu_pd(v_ptr.add(4)); + let br: __m256d = _mm256_loadu_pd(v_ptr.add(8)); + let bi: __m256d = _mm256_loadu_pd(v_ptr.add(12)); + + // re1 = re1 - ui*ai; re2 = re2 - ui*bi; + re1 = _mm256_fmsub_pd(ui, ai, re1); + re2 = _mm256_fmsub_pd(ui, bi, re2); + // im1 = im1 + ur*ai; im2 = im2 + ur*bi; + im1 = _mm256_fmadd_pd(ur, ai, im1); + im2 = _mm256_fmadd_pd(ur, bi, im2); + // re1 = re1 - ur*ar; re2 = re2 - ur*br; + re1 = _mm256_fmsub_pd(ur, ar, re1); + re2 = _mm256_fmsub_pd(ur, br, re2); + // im1 = im1 + ui*ar; im2 = im2 + ui*br; + im1 = _mm256_fmadd_pd(ui, ar, im1); + im2 = _mm256_fmadd_pd(ui, br, im2); + + u_ptr = u_ptr.add(8); + v_ptr = v_ptr.add(16); + } + + _mm256_storeu_pd(dst.as_mut_ptr(), re1); + _mm256_storeu_pd(dst.as_mut_ptr().add(4), im1); + _mm256_storeu_pd(dst.as_mut_ptr().add(8), re2); + _mm256_storeu_pd(dst.as_mut_ptr().add(12), im2); + } +} + +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2", enable = "fma")] +pub fn reim4_vec_mat2cols_2ndcol_product_avx(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]) { + use core::arch::x86_64::{__m256d, _mm256_fmadd_pd, _mm256_fmsub_pd, _mm256_loadu_pd, _mm256_setzero_pd, _mm256_storeu_pd}; + + #[cfg(debug_assertions)] + { + assert_eq!(dst.len(), 16, "dst must have 16 doubles"); + assert!(u.len() >= nrows * 8, "u must be at least nrows * 8 doubles"); + assert!( + v.len() >= nrows * 16, + "v must be at least nrows * 16 doubles" + ); + } + + unsafe { + let mut re1: __m256d = _mm256_setzero_pd(); + let mut im1: __m256d = _mm256_setzero_pd(); + + let mut u_ptr: *const f64 = u.as_ptr(); + let mut v_ptr: *const f64 = v.as_ptr().add(8); // Offset to 2nd column + + for _ in 0..nrows { + let ur: __m256d = _mm256_loadu_pd(u_ptr); + let ui: __m256d = _mm256_loadu_pd(u_ptr.add(4)); + + let ar: __m256d = _mm256_loadu_pd(v_ptr); + let ai: __m256d = _mm256_loadu_pd(v_ptr.add(4)); + + // re1 = re1 - ui*ai; re2 = re2 - ui*bi; + re1 = _mm256_fmsub_pd(ui, ai, re1); + // im1 = im1 + ur*ai; im2 = im2 + ur*bi; + im1 = _mm256_fmadd_pd(ur, ai, im1); + // re1 = re1 - ur*ar; re2 = re2 - ur*br; + re1 = _mm256_fmsub_pd(ur, ar, re1); + // im1 = im1 + ui*ar; im2 = im2 + ui*br; + im1 = _mm256_fmadd_pd(ui, ar, im1); + + u_ptr = u_ptr.add(8); + v_ptr = v_ptr.add(16); + } + + _mm256_storeu_pd(dst.as_mut_ptr(), re1); + _mm256_storeu_pd(dst.as_mut_ptr().add(4), im1); + } +} diff --git a/poulpy-backend/src/cpu_fft64_avx/reim4/mod.rs b/poulpy-backend/src/cpu_fft64_avx/reim4/mod.rs new file mode 100644 index 0000000..ca49ff0 --- /dev/null +++ b/poulpy-backend/src/cpu_fft64_avx/reim4/mod.rs @@ -0,0 +1,3 @@ +mod arithmetic_avx; + +pub(crate) use arithmetic_avx::*; diff --git a/poulpy-backend/src/cpu_fft64_avx/scratch.rs b/poulpy-backend/src/cpu_fft64_avx/scratch.rs new file 
mode 100644 index 0000000..c3975b3 --- /dev/null +++ b/poulpy-backend/src/cpu_fft64_avx/scratch.rs @@ -0,0 +1,261 @@ +use std::marker::PhantomData; + +use poulpy_hal::{ + DEFAULTALIGN, alloc_aligned, + api::ScratchFromBytes, + layouts::{Backend, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat}, + oep::{ + ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, SvpPPolAllocBytesImpl, + TakeMatZnxImpl, TakeScalarZnxImpl, TakeSliceImpl, TakeSvpPPolImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl, + TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, TakeVmpPMatImpl, VecZnxBigAllocBytesImpl, + VecZnxDftAllocBytesImpl, VmpPMatAllocBytesImpl, + }, +}; + +use crate::cpu_fft64_avx::FFT64Avx; + +unsafe impl ScratchOwnedAllocImpl for FFT64Avx { + fn scratch_owned_alloc_impl(size: usize) -> ScratchOwned { + let data: Vec = alloc_aligned(size); + ScratchOwned { + data, + _phantom: PhantomData, + } + } +} + +unsafe impl ScratchOwnedBorrowImpl for FFT64Avx +where + B: ScratchFromBytesImpl, +{ + fn scratch_owned_borrow_impl(scratch: &mut ScratchOwned) -> &mut Scratch { + Scratch::from_bytes(&mut scratch.data) + } +} + +unsafe impl ScratchFromBytesImpl for FFT64Avx { + fn scratch_from_bytes_impl(data: &mut [u8]) -> &mut Scratch { + unsafe { &mut *(data as *mut [u8] as *mut Scratch) } + } +} + +unsafe impl ScratchAvailableImpl for FFT64Avx { + fn scratch_available_impl(scratch: &Scratch) -> usize { + let ptr: *const u8 = scratch.data.as_ptr(); + let self_len: usize = scratch.data.len(); + let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN); + self_len.saturating_sub(aligned_offset) + } +} + +unsafe impl TakeSliceImpl for FFT64Avx +where + B: ScratchFromBytesImpl, +{ + fn take_slice_impl(scratch: &mut Scratch, len: usize) -> (&mut [T], &mut Scratch) { + let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, len * std::mem::size_of::()); + + unsafe { + ( + &mut *(std::ptr::slice_from_raw_parts_mut(take_slice.as_mut_ptr() as *mut T, len)), + Scratch::from_bytes(rem_slice), + ) + } + } +} + +unsafe impl TakeScalarZnxImpl for FFT64Avx +where + B: ScratchFromBytesImpl, +{ + fn take_scalar_znx_impl(scratch: &mut Scratch, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Scratch) { + let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, ScalarZnx::alloc_bytes(n, cols)); + ( + ScalarZnx::from_data(take_slice, n, cols), + Scratch::from_bytes(rem_slice), + ) + } +} + +unsafe impl TakeSvpPPolImpl for FFT64Avx +where + B: SvpPPolAllocBytesImpl + ScratchFromBytesImpl, +{ + fn take_svp_ppol_impl(scratch: &mut Scratch, n: usize, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Scratch) { + let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, B::svp_ppol_alloc_bytes_impl(n, cols)); + ( + SvpPPol::from_data(take_slice, n, cols), + Scratch::from_bytes(rem_slice), + ) + } +} + +unsafe impl TakeVecZnxImpl for FFT64Avx +where + B: ScratchFromBytesImpl, +{ + fn take_vec_znx_impl(scratch: &mut Scratch, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Scratch) { + let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, VecZnx::alloc_bytes(n, cols, size)); + ( + VecZnx::from_data(take_slice, n, cols, size), + Scratch::from_bytes(rem_slice), + ) + } +} + +unsafe impl TakeVecZnxBigImpl for FFT64Avx +where + B: VecZnxBigAllocBytesImpl + ScratchFromBytesImpl, +{ + fn take_vec_znx_big_impl( + scratch: &mut Scratch, + n: usize, + cols: usize, + size: usize, + ) -> (VecZnxBig<&mut [u8], 
B>, &mut Scratch) { + let (take_slice, rem_slice) = take_slice_aligned( + &mut scratch.data, + B::vec_znx_big_alloc_bytes_impl(n, cols, size), + ); + ( + VecZnxBig::from_data(take_slice, n, cols, size), + Scratch::from_bytes(rem_slice), + ) + } +} + +unsafe impl TakeVecZnxDftImpl for FFT64Avx +where + B: VecZnxDftAllocBytesImpl + ScratchFromBytesImpl, +{ + fn take_vec_znx_dft_impl( + scratch: &mut Scratch, + n: usize, + cols: usize, + size: usize, + ) -> (VecZnxDft<&mut [u8], B>, &mut Scratch) { + let (take_slice, rem_slice) = take_slice_aligned( + &mut scratch.data, + B::vec_znx_dft_alloc_bytes_impl(n, cols, size), + ); + + ( + VecZnxDft::from_data(take_slice, n, cols, size), + Scratch::from_bytes(rem_slice), + ) + } +} + +unsafe impl TakeVecZnxDftSliceImpl for FFT64Avx +where + B: VecZnxDftAllocBytesImpl + ScratchFromBytesImpl + TakeVecZnxDftImpl, +{ + fn take_vec_znx_dft_slice_impl( + scratch: &mut Scratch, + len: usize, + n: usize, + cols: usize, + size: usize, + ) -> (Vec>, &mut Scratch) { + let mut scratch: &mut Scratch = scratch; + let mut slice: Vec> = Vec::with_capacity(len); + for _ in 0..len { + let (znx, new_scratch) = B::take_vec_znx_dft_impl(scratch, n, cols, size); + scratch = new_scratch; + slice.push(znx); + } + (slice, scratch) + } +} + +unsafe impl TakeVecZnxSliceImpl for FFT64Avx +where + B: ScratchFromBytesImpl + TakeVecZnxImpl, +{ + fn take_vec_znx_slice_impl( + scratch: &mut Scratch, + len: usize, + n: usize, + cols: usize, + size: usize, + ) -> (Vec>, &mut Scratch) { + let mut scratch: &mut Scratch = scratch; + let mut slice: Vec> = Vec::with_capacity(len); + for _ in 0..len { + let (znx, new_scratch) = B::take_vec_znx_impl(scratch, n, cols, size); + scratch = new_scratch; + slice.push(znx); + } + (slice, scratch) + } +} + +unsafe impl TakeVmpPMatImpl for FFT64Avx +where + B: VmpPMatAllocBytesImpl + ScratchFromBytesImpl, +{ + fn take_vmp_pmat_impl( + scratch: &mut Scratch, + n: usize, + rows: usize, + cols_in: usize, + cols_out: usize, + size: usize, + ) -> (VmpPMat<&mut [u8], B>, &mut Scratch) { + let (take_slice, rem_slice) = take_slice_aligned( + &mut scratch.data, + B::vmp_pmat_alloc_bytes_impl(n, rows, cols_in, cols_out, size), + ); + ( + VmpPMat::from_data(take_slice, n, rows, cols_in, cols_out, size), + Scratch::from_bytes(rem_slice), + ) + } +} + +unsafe impl TakeMatZnxImpl for FFT64Avx +where + B: ScratchFromBytesImpl, +{ + fn take_mat_znx_impl( + scratch: &mut Scratch, + n: usize, + rows: usize, + cols_in: usize, + cols_out: usize, + size: usize, + ) -> (MatZnx<&mut [u8]>, &mut Scratch) { + let (take_slice, rem_slice) = take_slice_aligned( + &mut scratch.data, + MatZnx::alloc_bytes(n, rows, cols_in, cols_out, size), + ); + ( + MatZnx::from_data(take_slice, n, rows, cols_in, cols_out, size), + Scratch::from_bytes(rem_slice), + ) + } +} + +fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]) { + let ptr: *mut u8 = data.as_mut_ptr(); + let self_len: usize = data.len(); + + let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN); + let aligned_len: usize = self_len.saturating_sub(aligned_offset); + + if let Some(rem_len) = aligned_len.checked_sub(take_len) { + unsafe { + let rem_ptr: *mut u8 = ptr.add(aligned_offset).add(take_len); + let rem_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(rem_ptr, rem_len); + + let take_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(ptr.add(aligned_offset), take_len); + + (take_slice, rem_slice) + } + } else { + panic!( + "Attempted to take {} from scratch with {} 
aligned bytes left", + take_len, aligned_len, + ); + } +} diff --git a/poulpy-backend/src/cpu_fft64_avx/svp.rs b/poulpy-backend/src/cpu_fft64_avx/svp.rs new file mode 100644 index 0000000..f505597 --- /dev/null +++ b/poulpy-backend/src/cpu_fft64_avx/svp.rs @@ -0,0 +1,66 @@ +use poulpy_hal::{ + layouts::{Backend, Module, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef}, + oep::{ + SvpApplyDftToDftImpl, SvpApplyDftToDftInplaceImpl, SvpPPolAllocBytesImpl, SvpPPolAllocImpl, SvpPPolFromBytesImpl, + SvpPrepareImpl, + }, + reference::fft64::svp::{svp_apply_dft_to_dft, svp_apply_dft_to_dft_inplace, svp_prepare}, +}; + +use crate::cpu_fft64_avx::{FFT64Avx, module::FFT64ModuleHandle}; + +unsafe impl SvpPPolFromBytesImpl for FFT64Avx { + fn svp_ppol_from_bytes_impl(n: usize, cols: usize, bytes: Vec) -> SvpPPolOwned { + SvpPPolOwned::from_bytes(n, cols, bytes) + } +} + +unsafe impl SvpPPolAllocImpl for FFT64Avx { + fn svp_ppol_alloc_impl(n: usize, cols: usize) -> SvpPPolOwned { + SvpPPolOwned::alloc(n, cols) + } +} + +unsafe impl SvpPPolAllocBytesImpl for FFT64Avx { + fn svp_ppol_alloc_bytes_impl(n: usize, cols: usize) -> usize { + Self::layout_prep_word_count() * n * cols * size_of::() + } +} + +unsafe impl SvpPrepareImpl for FFT64Avx { + fn svp_prepare_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: SvpPPolToMut, + A: ScalarZnxToRef, + { + svp_prepare(module.get_fft_table(), res, res_col, a, a_col); + } +} + +unsafe impl SvpApplyDftToDftImpl for FFT64Avx { + fn svp_apply_dft_to_dft_impl( + _module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + b: &B, + b_col: usize, + ) where + R: VecZnxDftToMut, + A: SvpPPolToRef, + B: VecZnxDftToRef, + { + svp_apply_dft_to_dft(res, res_col, a, a_col, b, b_col); + } +} + +unsafe impl SvpApplyDftToDftInplaceImpl for FFT64Avx { + fn svp_apply_dft_to_dft_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: SvpPPolToRef, + { + svp_apply_dft_to_dft_inplace(res, res_col, a, a_col); + } +} diff --git a/poulpy-backend/src/cpu_fft64_avx/tests.rs b/poulpy-backend/src/cpu_fft64_avx/tests.rs new file mode 100644 index 0000000..2b4532d --- /dev/null +++ b/poulpy-backend/src/cpu_fft64_avx/tests.rs @@ -0,0 +1,117 @@ +use poulpy_hal::{backend_test_suite, cross_backend_test_suite}; + +cross_backend_test_suite! 
{ + mod vec_znx, + backend_ref = crate::cpu_fft64_ref::FFT64Ref, + backend_test = crate::cpu_fft64_avx::FFT64Avx, + size = 1 << 5, + basek = 12, + tests = { + test_vec_znx_add => poulpy_hal::test_suite::vec_znx::test_vec_znx_add, + test_vec_znx_add_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_inplace, + test_vec_znx_add_scalar => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_scalar, + test_vec_znx_add_scalar_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_scalar_inplace, + test_vec_znx_sub => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub, + test_vec_znx_sub_ab_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_ab_inplace, + test_vec_znx_sub_ba_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_ba_inplace, + test_vec_znx_sub_scalar => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_scalar, + test_vec_znx_sub_scalar_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_scalar_inplace, + test_vec_znx_rsh => poulpy_hal::test_suite::vec_znx::test_vec_znx_rsh, + test_vec_znx_rsh_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_rsh_inplace, + test_vec_znx_lsh => poulpy_hal::test_suite::vec_znx::test_vec_znx_lsh, + test_vec_znx_lsh_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_lsh_inplace, + test_vec_znx_negate => poulpy_hal::test_suite::vec_znx::test_vec_znx_negate, + test_vec_znx_negate_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_negate_inplace, + test_vec_znx_rotate => poulpy_hal::test_suite::vec_znx::test_vec_znx_rotate, + test_vec_znx_rotate_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_rotate_inplace, + test_vec_znx_automorphism => poulpy_hal::test_suite::vec_znx::test_vec_znx_automorphism, + test_vec_znx_automorphism_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_automorphism_inplace, + test_vec_znx_mul_xp_minus_one => poulpy_hal::test_suite::vec_znx::test_vec_znx_mul_xp_minus_one, + test_vec_znx_mul_xp_minus_one_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_mul_xp_minus_one_inplace, + test_vec_znx_normalize => poulpy_hal::test_suite::vec_znx::test_vec_znx_normalize, + test_vec_znx_normalize_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_normalize_inplace, + test_vec_znx_switch_ring => poulpy_hal::test_suite::vec_znx::test_vec_znx_switch_ring, + test_vec_znx_split_ring => poulpy_hal::test_suite::vec_znx::test_vec_znx_split_ring, + test_vec_znx_copy => poulpy_hal::test_suite::vec_znx::test_vec_znx_copy, + } +} + +cross_backend_test_suite! { + mod svp, + backend_ref = crate::cpu_fft64_ref::FFT64Ref, + backend_test = crate::cpu_fft64_avx::FFT64Avx, + size = 1 << 5, + basek = 12, + tests = { + test_svp_apply_dft_to_dft => poulpy_hal::test_suite::svp::test_svp_apply_dft_to_dft, + test_svp_apply_dft_to_dft_inplace => poulpy_hal::test_suite::svp::test_svp_apply_dft_to_dft_inplace, + } +} + +cross_backend_test_suite! 
{ + mod vec_znx_big, + backend_ref = crate::cpu_fft64_ref::FFT64Ref, + backend_test = crate::cpu_fft64_avx::FFT64Avx, + size = 1 << 5, + basek = 12, + tests = { + test_vec_znx_big_add => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add, + test_vec_znx_big_add_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add_inplace, + test_vec_znx_big_add_small => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add_small, + test_vec_znx_big_add_small_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add_small_inplace, + test_vec_znx_big_sub => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub, + test_vec_znx_big_sub_ab_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_ab_inplace, + test_vec_znx_big_automorphism => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_automorphism, + test_vec_znx_big_automorphism_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_automorphism_inplace, + test_vec_znx_big_negate => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_negate, + test_vec_znx_big_negate_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_negate_inplace, + test_vec_znx_big_normalize => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_normalize, + test_vec_znx_big_sub_ba_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_ba_inplace, + test_vec_znx_big_sub_small_a => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_a, + test_vec_znx_big_sub_small_a_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_a_inplace, + test_vec_znx_big_sub_small_b => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_b, + test_vec_znx_big_sub_small_b_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_b_inplace, + } +} + +cross_backend_test_suite! { + mod vec_znx_dft, + backend_ref = crate::cpu_fft64_ref::FFT64Ref, + backend_test = crate::cpu_fft64_avx::FFT64Avx, + size = 1 << 5, + basek = 12, + tests = { + test_vec_znx_dft_add => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_add, + test_vec_znx_dft_add_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_add_inplace, + test_vec_znx_dft_sub => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub, + test_vec_znx_dft_sub_ab_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub_ab_inplace, + test_vec_znx_dft_sub_ba_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub_ba_inplace, + test_vec_znx_idft_apply => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_idft_apply, + test_vec_znx_idft_apply_consume => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_idft_apply_consume, + test_vec_znx_idft_apply_tmpa => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_idft_apply_tmpa, + } +} + +cross_backend_test_suite! { + mod vmp, + backend_ref = crate::cpu_fft64_ref::FFT64Ref, + backend_test = crate::cpu_fft64_avx::FFT64Avx, + size = 1 << 5, + basek = 12, + tests = { + test_vmp_apply_dft_to_dft => poulpy_hal::test_suite::vmp::test_vmp_apply_dft_to_dft, + test_vmp_apply_dft_to_dft_add => poulpy_hal::test_suite::vmp::test_vmp_apply_dft_to_dft_add, + } +} + +backend_test_suite! 
{ + mod sampling, + backend = crate::cpu_fft64_avx::FFT64Avx, + size = 1 << 12, + tests = { + test_vec_znx_fill_uniform => poulpy_hal::test_suite::vec_znx::test_vec_znx_fill_uniform, + test_vec_znx_fill_normal => poulpy_hal::test_suite::vec_znx::test_vec_znx_fill_normal, + test_vec_znx_add_normal => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_normal, + test_vec_znx_big_sub_small_b_inplace => poulpy_hal::reference::fft64::vec_znx_big::test_vec_znx_big_add_normal, + } +} diff --git a/poulpy-backend/src/cpu_fft64_avx/vec_znx.rs b/poulpy-backend/src/cpu_fft64_avx/vec_znx.rs new file mode 100644 index 0000000..d61e021 --- /dev/null +++ b/poulpy-backend/src/cpu_fft64_avx/vec_znx.rs @@ -0,0 +1,538 @@ +use poulpy_hal::{ + api::{ + TakeSlice, VecZnxAutomorphismInplaceTmpBytes, VecZnxMergeRingsTmpBytes, VecZnxMulXpMinusOneInplaceTmpBytes, + VecZnxNormalizeTmpBytes, VecZnxRotateInplaceTmpBytes, VecZnxSplitRingTmpBytes, + }, + layouts::{Module, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef}, + oep::{ + TakeSliceImpl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, VecZnxAddScalarImpl, VecZnxAddScalarInplaceImpl, + VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxAutomorphismInplaceTmpBytesImpl, VecZnxCopyImpl, + VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxLshImpl, VecZnxLshInplaceImpl, VecZnxLshTmpBytesImpl, + VecZnxMergeRingsImpl, VecZnxMergeRingsTmpBytesImpl, VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl, + VecZnxMulXpMinusOneInplaceTmpBytesImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl, + VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl, + VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl, + VecZnxSplitRingTmpBytesImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarImpl, + VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl, + }, + reference::vec_znx::{ + vec_znx_add, vec_znx_add_inplace, vec_znx_add_normal_ref, vec_znx_add_scalar, vec_znx_add_scalar_inplace, + vec_znx_automorphism, vec_znx_automorphism_inplace, vec_znx_automorphism_inplace_tmp_bytes, vec_znx_copy, + vec_znx_fill_normal_ref, vec_znx_fill_uniform_ref, vec_znx_lsh, vec_znx_lsh_inplace, vec_znx_lsh_tmp_bytes, + vec_znx_merge_rings, vec_znx_merge_rings_tmp_bytes, vec_znx_mul_xp_minus_one, vec_znx_mul_xp_minus_one_inplace, + vec_znx_mul_xp_minus_one_inplace_tmp_bytes, vec_znx_negate, vec_znx_negate_inplace, vec_znx_normalize, + vec_znx_normalize_inplace, vec_znx_normalize_tmp_bytes, vec_znx_rotate, vec_znx_rotate_inplace, + vec_znx_rotate_inplace_tmp_bytes, vec_znx_rsh, vec_znx_rsh_inplace, vec_znx_rsh_tmp_bytes, vec_znx_split_ring, + vec_znx_split_ring_tmp_bytes, vec_znx_sub, vec_znx_sub_ab_inplace, vec_znx_sub_ba_inplace, vec_znx_sub_scalar, + vec_znx_sub_scalar_inplace, vec_znx_switch_ring, + }, + source::Source, +}; + +use crate::cpu_fft64_avx::FFT64Avx; + +unsafe impl VecZnxNormalizeTmpBytesImpl for FFT64Avx { + fn vec_znx_normalize_tmp_bytes_impl(module: &Module) -> usize { + vec_znx_normalize_tmp_bytes(module.n()) + } +} + +unsafe impl VecZnxNormalizeImpl for FFT64Avx +where + Self: TakeSliceImpl + VecZnxNormalizeTmpBytesImpl, +{ + fn vec_znx_normalize_impl( + module: &Module, + basek: usize, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxToMut, + A: VecZnxToRef, + { + let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::()); + 
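        // `vec_znx_normalize_tmp_bytes` reports the scratch requirement in bytes,
        // while `take_slice` counts elements of the requested slice type (its impl
        // multiplies the length by the element size), hence the division here; the
        // elided element type is the carry limb type, presumably i64. The discarded
        // second tuple element is the remaining scratch.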
vec_znx_normalize::(basek, res, res_col, a, a_col, carry); + } +} + +unsafe impl VecZnxNormalizeInplaceImpl for FFT64Avx +where + Self: TakeSliceImpl + VecZnxNormalizeTmpBytesImpl, +{ + fn vec_znx_normalize_inplace_impl( + module: &Module, + basek: usize, + res: &mut R, + res_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxToMut, + { + let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::()); + vec_znx_normalize_inplace::(basek, res, res_col, carry); + } +} + +unsafe impl VecZnxAddImpl for FFT64Avx { + fn vec_znx_add_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + B: VecZnxToRef, + { + vec_znx_add::(res, res_col, a, a_col, b, b_col); + } +} + +unsafe impl VecZnxAddInplaceImpl for FFT64Avx { + fn vec_znx_add_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + vec_znx_add_inplace::(res, res_col, a, a_col); + } +} + +unsafe impl VecZnxAddScalarInplaceImpl for FFT64Avx { + fn vec_znx_add_scalar_inplace_impl( + _module: &Module, + res: &mut R, + res_col: usize, + res_limb: usize, + a: &A, + a_col: usize, + ) where + R: VecZnxToMut, + A: ScalarZnxToRef, + { + vec_znx_add_scalar_inplace::(res, res_col, res_limb, a, a_col); + } +} + +unsafe impl VecZnxAddScalarImpl for FFT64Avx { + fn vec_znx_add_scalar_impl( + _module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + b: &B, + b_col: usize, + b_limb: usize, + ) where + R: VecZnxToMut, + A: ScalarZnxToRef, + B: VecZnxToRef, + { + vec_znx_add_scalar::(res, res_col, a, a_col, b, b_col, b_limb); + } +} + +unsafe impl VecZnxSubImpl for FFT64Avx { + fn vec_znx_sub_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + B: VecZnxToRef, + { + vec_znx_sub::(res, res_col, a, a_col, b, b_col); + } +} + +unsafe impl VecZnxSubABInplaceImpl for FFT64Avx { + fn vec_znx_sub_ab_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + vec_znx_sub_ab_inplace::(res, res_col, a, a_col); + } +} + +unsafe impl VecZnxSubBAInplaceImpl for FFT64Avx { + fn vec_znx_sub_ba_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + vec_znx_sub_ba_inplace::(res, res_col, a, a_col); + } +} + +unsafe impl VecZnxSubScalarImpl for FFT64Avx { + fn vec_znx_sub_scalar_impl( + _module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + b: &B, + b_col: usize, + b_limb: usize, + ) where + R: VecZnxToMut, + A: ScalarZnxToRef, + B: VecZnxToRef, + { + vec_znx_sub_scalar::(res, res_col, a, a_col, b, b_col, b_limb); + } +} + +unsafe impl VecZnxSubScalarInplaceImpl for FFT64Avx { + fn vec_znx_sub_scalar_inplace_impl( + _module: &Module, + res: &mut R, + res_col: usize, + res_limb: usize, + a: &A, + a_col: usize, + ) where + R: VecZnxToMut, + A: ScalarZnxToRef, + { + vec_znx_sub_scalar_inplace::(res, res_col, res_limb, a, a_col); + } +} + +unsafe impl VecZnxNegateImpl for FFT64Avx { + fn vec_znx_negate_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + vec_znx_negate::(res, res_col, a, a_col); + } +} + +unsafe impl VecZnxNegateInplaceImpl for FFT64Avx { + fn vec_znx_negate_inplace_impl(_module: &Module, res: &mut R, res_col: usize) + where + R: 
VecZnxToMut, + { + vec_znx_negate_inplace::(res, res_col); + } +} + +unsafe impl VecZnxLshTmpBytesImpl for FFT64Avx { + fn vec_znx_lsh_tmp_bytes_impl(module: &Module) -> usize { + vec_znx_lsh_tmp_bytes(module.n()) + } +} + +unsafe impl VecZnxRshTmpBytesImpl for FFT64Avx { + fn vec_znx_rsh_tmp_bytes_impl(module: &Module) -> usize { + vec_znx_rsh_tmp_bytes(module.n()) + } +} + +unsafe impl VecZnxLshImpl for FFT64Avx +where + Module: VecZnxNormalizeTmpBytes, + Scratch: TakeSlice, +{ + fn vec_znx_lsh_inplace_impl( + module: &Module, + basek: usize, + k: usize, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxToMut, + A: VecZnxToRef, + { + let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::()); + vec_znx_lsh::<_, _, Self>(basek, k, res, res_col, a, a_col, carry); + } +} + +unsafe impl VecZnxLshInplaceImpl for FFT64Avx +where + Module: VecZnxNormalizeTmpBytes, + Scratch: TakeSlice, +{ + fn vec_znx_lsh_inplace_impl( + module: &Module, + basek: usize, + k: usize, + a: &mut A, + a_col: usize, + scratch: &mut Scratch, + ) where + A: VecZnxToMut, + { + let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::()); + vec_znx_lsh_inplace::<_, Self>(basek, k, a, a_col, carry); + } +} + +unsafe impl VecZnxRshImpl for FFT64Avx +where + Module: VecZnxNormalizeTmpBytes, + Scratch: TakeSlice, +{ + fn vec_znx_rsh_inplace_impl( + module: &Module, + basek: usize, + k: usize, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxToMut, + A: VecZnxToRef, + { + let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::()); + vec_znx_rsh::<_, _, Self>(basek, k, res, res_col, a, a_col, carry); + } +} + +unsafe impl VecZnxRshInplaceImpl for FFT64Avx +where + Module: VecZnxNormalizeTmpBytes, + Scratch: TakeSlice, +{ + fn vec_znx_rsh_inplace_impl( + module: &Module, + basek: usize, + k: usize, + a: &mut A, + a_col: usize, + scratch: &mut Scratch, + ) where + A: VecZnxToMut, + { + let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::()); + vec_znx_rsh_inplace::<_, Self>(basek, k, a, a_col, carry); + } +} + +unsafe impl VecZnxRotateImpl for FFT64Avx { + fn vec_znx_rotate_impl(_module: &Module, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + vec_znx_rotate::(p, res, res_col, a, a_col); + } +} + +unsafe impl VecZnxRotateInplaceTmpBytesImpl for FFT64Avx +where + Scratch: TakeSlice, +{ + fn vec_znx_rotate_inplace_tmp_bytes_impl(module: &Module) -> usize { + vec_znx_rotate_inplace_tmp_bytes(module.n()) + } +} + +unsafe impl VecZnxRotateInplaceImpl for FFT64Avx +where + Scratch: TakeSlice, + Self: VecZnxRotateInplaceTmpBytesImpl, +{ + fn vec_znx_rotate_inplace_impl(module: &Module, p: i64, res: &mut R, res_col: usize, scratch: &mut Scratch) + where + R: VecZnxToMut, + { + let (tmp, _) = scratch.take_slice(module.vec_znx_rotate_inplace_tmp_bytes() / size_of::()); + vec_znx_rotate_inplace::(p, res, res_col, tmp); + } +} + +unsafe impl VecZnxAutomorphismImpl for FFT64Avx { + fn vec_znx_automorphism_impl(_module: &Module, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + vec_znx_automorphism::(p, res, res_col, a, a_col); + } +} + +unsafe impl VecZnxAutomorphismInplaceTmpBytesImpl for FFT64Avx { + fn vec_znx_automorphism_inplace_tmp_bytes_impl(module: &Module) -> usize { + 
vec_znx_automorphism_inplace_tmp_bytes(module.n()) + } +} + +unsafe impl VecZnxAutomorphismInplaceImpl for FFT64Avx +where + Scratch: TakeSlice, + Self: VecZnxAutomorphismInplaceTmpBytesImpl, +{ + fn vec_znx_automorphism_inplace_impl( + module: &Module, + p: i64, + res: &mut R, + res_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxToMut, + { + let (tmp, _) = scratch.take_slice(module.vec_znx_automorphism_inplace_tmp_bytes() / size_of::()); + vec_znx_automorphism_inplace::(p, res, res_col, tmp); + } +} + +unsafe impl VecZnxMulXpMinusOneImpl for FFT64Avx { + fn vec_znx_mul_xp_minus_one_impl(_module: &Module, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + vec_znx_mul_xp_minus_one::(p, res, res_col, a, a_col); + } +} + +unsafe impl VecZnxMulXpMinusOneInplaceTmpBytesImpl for FFT64Avx +where + Scratch: TakeSlice, + Self: VecZnxMulXpMinusOneImpl, +{ + fn vec_znx_mul_xp_minus_one_inplace_tmp_bytes_impl(module: &Module) -> usize { + vec_znx_mul_xp_minus_one_inplace_tmp_bytes(module.n()) + } +} + +unsafe impl VecZnxMulXpMinusOneInplaceImpl for FFT64Avx { + fn vec_znx_mul_xp_minus_one_inplace_impl( + module: &Module, + p: i64, + res: &mut R, + res_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxToMut, + { + let (tmp, _) = scratch.take_slice(module.vec_znx_mul_xp_minus_one_inplace_tmp_bytes() / size_of::()); + vec_znx_mul_xp_minus_one_inplace::(p, res, res_col, tmp); + } +} + +unsafe impl VecZnxSplitRingTmpBytesImpl for FFT64Avx { + fn vec_znx_split_ring_tmp_bytes_impl(module: &Module) -> usize { + vec_znx_split_ring_tmp_bytes(module.n()) + } +} + +unsafe impl VecZnxSplitRingImpl for FFT64Avx +where + Module: VecZnxSplitRingTmpBytes, + Scratch: TakeSlice, +{ + fn vec_znx_split_ring_impl( + module: &Module, + res: &mut [R], + res_col: usize, + a: &A, + a_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxToMut, + A: VecZnxToRef, + { + let (tmp, _) = scratch.take_slice(module.vec_znx_split_ring_tmp_bytes() / size_of::()); + vec_znx_split_ring::(res, res_col, a, a_col, tmp); + } +} + +unsafe impl VecZnxMergeRingsTmpBytesImpl for FFT64Avx { + fn vec_znx_merge_rings_tmp_bytes_impl(module: &Module) -> usize { + vec_znx_merge_rings_tmp_bytes(module.n()) + } +} + +unsafe impl VecZnxMergeRingsImpl for FFT64Avx +where + Module: VecZnxMergeRingsTmpBytes, +{ + fn vec_znx_merge_rings_impl( + module: &Module, + res: &mut R, + res_col: usize, + a: &[A], + a_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxToMut, + A: VecZnxToRef, + { + let (tmp, _) = scratch.take_slice(module.vec_znx_merge_rings_tmp_bytes() / size_of::()); + vec_znx_merge_rings::(res, res_col, a, a_col, tmp); + } +} + +unsafe impl VecZnxSwitchRingImpl for FFT64Avx +where + Self: VecZnxCopyImpl, +{ + fn vec_znx_switch_ring_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + vec_znx_switch_ring::(res, res_col, a, a_col); + } +} + +unsafe impl VecZnxCopyImpl for FFT64Avx { + fn vec_znx_copy_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + vec_znx_copy::(res, res_col, a, a_col) + } +} + +unsafe impl VecZnxFillUniformImpl for FFT64Avx { + fn vec_znx_fill_uniform_impl(_module: &Module, basek: usize, res: &mut R, res_col: usize, source: &mut Source) + where + R: VecZnxToMut, + { + vec_znx_fill_uniform_ref(basek, res, res_col, source) + } +} + +unsafe impl VecZnxFillNormalImpl for FFT64Avx { + fn 
vec_znx_fill_normal_impl( + _module: &Module, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + sigma: f64, + bound: f64, + ) where + R: VecZnxToMut, + { + vec_znx_fill_normal_ref(basek, res, res_col, k, sigma, bound, source); + } +} + +unsafe impl VecZnxAddNormalImpl for FFT64Avx { + fn vec_znx_add_normal_impl( + _module: &Module, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + sigma: f64, + bound: f64, + ) where + R: VecZnxToMut, + { + vec_znx_add_normal_ref(basek, res, res_col, k, sigma, bound, source); + } +} diff --git a/poulpy-backend/src/cpu_fft64_avx/vec_znx_big.rs b/poulpy-backend/src/cpu_fft64_avx/vec_znx_big.rs new file mode 100644 index 0000000..f8d9180 --- /dev/null +++ b/poulpy-backend/src/cpu_fft64_avx/vec_znx_big.rs @@ -0,0 +1,332 @@ +use crate::cpu_fft64_avx::FFT64Avx; +use poulpy_hal::{ + api::{TakeSlice, VecZnxBigAutomorphismInplaceTmpBytes, VecZnxBigNormalizeTmpBytes}, + layouts::{ + Backend, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef, + ZnxInfos, ZnxView, ZnxViewMut, + }, + oep::{ + TakeSliceImpl, VecZnxBigAddImpl, VecZnxBigAddInplaceImpl, VecZnxBigAddNormalImpl, VecZnxBigAddSmallImpl, + VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl, + VecZnxBigAutomorphismInplaceImpl, VecZnxBigAutomorphismInplaceTmpBytesImpl, VecZnxBigFromBytesImpl, + VecZnxBigFromSmallImpl, VecZnxBigNegateImpl, VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl, + VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubABInplaceImpl, VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl, + VecZnxBigSubSmallAImpl, VecZnxBigSubSmallAInplaceImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl, + }, + reference::{ + fft64::vec_znx_big::{ + vec_znx_big_add, vec_znx_big_add_inplace, vec_znx_big_add_normal_ref, vec_znx_big_add_small, + vec_znx_big_add_small_inplace, vec_znx_big_automorphism, vec_znx_big_automorphism_inplace, + vec_znx_big_automorphism_inplace_tmp_bytes, vec_znx_big_negate, vec_znx_big_negate_inplace, vec_znx_big_normalize, + vec_znx_big_normalize_tmp_bytes, vec_znx_big_sub, vec_znx_big_sub_ab_inplace, vec_znx_big_sub_ba_inplace, + vec_znx_big_sub_small_a, vec_znx_big_sub_small_a_inplace, vec_znx_big_sub_small_b, vec_znx_big_sub_small_b_inplace, + }, + znx::{znx_copy_ref, znx_zero_ref}, + }, + source::Source, +}; + +unsafe impl VecZnxBigAllocBytesImpl for FFT64Avx { + fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize { + Self::layout_big_word_count() * n * cols * size * size_of::() + } +} + +unsafe impl VecZnxBigAllocImpl for FFT64Avx { + fn vec_znx_big_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxBigOwned { + VecZnxBig::alloc(n, cols, size) + } +} + +unsafe impl VecZnxBigFromBytesImpl for FFT64Avx { + fn vec_znx_big_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec) -> VecZnxBigOwned { + VecZnxBig::from_bytes(n, cols, size, bytes) + } +} + +unsafe impl VecZnxBigFromSmallImpl for FFT64Avx { + fn vec_znx_big_from_small_impl(res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef, + { + let mut res: VecZnxBig<&mut [u8], FFT64Avx> = res.to_mut(); + let a: VecZnx<&[u8]> = a.to_ref(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.n(), a.n()); + } + + let res_size: usize = res.size(); + let a_size: usize = a.size(); + + let min_size: usize = res_size.min(a_size); + + for j in 0..min_size { + znx_copy_ref(res.at_mut(res_col, j), 
a.at(a_col, j)); + } + + for j in min_size..res_size { + znx_zero_ref(res.at_mut(res_col, j)); + } + } +} + +unsafe impl VecZnxBigAddNormalImpl for FFT64Avx { + fn add_normal_impl>( + _module: &Module, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + sigma: f64, + bound: f64, + ) { + vec_znx_big_add_normal_ref(basek, res, res_col, k, sigma, bound, source); + } +} + +unsafe impl VecZnxBigAddImpl for FFT64Avx { + /// Adds `a` to `b` and stores the result on `c`. + fn vec_znx_big_add_impl( + _module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + b: &B, + b_col: usize, + ) where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + B: VecZnxBigToRef, + { + vec_znx_big_add(res, res_col, a, a_col, b, b_col); + } +} + +unsafe impl VecZnxBigAddInplaceImpl for FFT64Avx { + /// Adds `a` to `b` and stores the result on `b`. + fn vec_znx_big_add_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + { + vec_znx_big_add_inplace(res, res_col, a, a_col); + } +} + +unsafe impl VecZnxBigAddSmallImpl for FFT64Avx { + /// Adds `a` to `b` and stores the result on `c`. + fn vec_znx_big_add_small_impl( + _module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + b: &B, + b_col: usize, + ) where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + B: VecZnxToRef, + { + vec_znx_big_add_small(res, res_col, a, a_col, b, b_col); + } +} + +unsafe impl VecZnxBigAddSmallInplaceImpl for FFT64Avx { + /// Adds `a` to `b` and stores the result on `b`. + fn vec_znx_big_add_small_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef, + { + vec_znx_big_add_small_inplace(res, res_col, a, a_col); + } +} + +unsafe impl VecZnxBigSubImpl for FFT64Avx { + /// Subtracts `a` to `b` and stores the result on `c`. + fn vec_znx_big_sub_impl( + _module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + b: &B, + b_col: usize, + ) where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + B: VecZnxBigToRef, + { + vec_znx_big_sub(res, res_col, a, a_col, b, b_col); + } +} + +unsafe impl VecZnxBigSubABInplaceImpl for FFT64Avx { + /// Subtracts `a` from `b` and stores the result on `b`. + fn vec_znx_big_sub_ab_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + { + vec_znx_big_sub_ab_inplace(res, res_col, a, a_col); + } +} + +unsafe impl VecZnxBigSubBAInplaceImpl for FFT64Avx { + /// Subtracts `b` from `a` and stores the result on `b`. + fn vec_znx_big_sub_ba_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + { + vec_znx_big_sub_ba_inplace(res, res_col, a, a_col); + } +} + +unsafe impl VecZnxBigSubSmallAImpl for FFT64Avx { + /// Subtracts `b` from `a` and stores the result on `c`. + fn vec_znx_big_sub_small_a_impl( + _module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + b: &B, + b_col: usize, + ) where + R: VecZnxBigToMut, + A: VecZnxToRef, + B: VecZnxBigToRef, + { + vec_znx_big_sub_small_a(res, res_col, a, a_col, b, b_col); + } +} + +unsafe impl VecZnxBigSubSmallAInplaceImpl for FFT64Avx { + /// Subtracts `a` from `res` and stores the result on `res`. 
+ fn vec_znx_big_sub_small_a_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef, + { + vec_znx_big_sub_small_a_inplace(res, res_col, a, a_col); + } +} + +unsafe impl VecZnxBigSubSmallBImpl for FFT64Avx { + /// Subtracts `b` from `a` and stores the result on `c`. + fn vec_znx_big_sub_small_b_impl( + _module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + b: &B, + b_col: usize, + ) where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + B: VecZnxToRef, + { + vec_znx_big_sub_small_b(res, res_col, a, a_col, b, b_col); + } +} + +unsafe impl VecZnxBigSubSmallBInplaceImpl for FFT64Avx { + /// Subtracts `res` from `a` and stores the result on `res`. + fn vec_znx_big_sub_small_b_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef, + { + vec_znx_big_sub_small_b_inplace(res, res_col, a, a_col); + } +} + +unsafe impl VecZnxBigNegateImpl for FFT64Avx { + fn vec_znx_big_negate_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + { + vec_znx_big_negate(res, res_col, a, a_col); + } +} + +unsafe impl VecZnxBigNegateInplaceImpl for FFT64Avx { + fn vec_znx_big_negate_inplace_impl(_module: &Module, res: &mut R, res_col: usize) + where + R: VecZnxBigToMut, + { + vec_znx_big_negate_inplace(res, res_col); + } +} + +unsafe impl VecZnxBigNormalizeTmpBytesImpl for FFT64Avx { + fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module) -> usize { + vec_znx_big_normalize_tmp_bytes(module.n()) + } +} + +unsafe impl VecZnxBigNormalizeImpl for FFT64Avx +where + Self: TakeSliceImpl, +{ + fn vec_znx_big_normalize_impl( + module: &Module, + basek: usize, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxToMut, + A: VecZnxBigToRef, + { + let (carry, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes() / size_of::()); + vec_znx_big_normalize(basek, res, res_col, a, a_col, carry); + } +} + +unsafe impl VecZnxBigAutomorphismImpl for FFT64Avx { + /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`. + fn vec_znx_big_automorphism_impl(_module: &Module, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + { + vec_znx_big_automorphism(p, res, res_col, a, a_col); + } +} + +unsafe impl VecZnxBigAutomorphismInplaceTmpBytesImpl for FFT64Avx { + fn vec_znx_big_automorphism_inplace_tmp_bytes_impl(module: &Module) -> usize { + vec_znx_big_automorphism_inplace_tmp_bytes(module.n()) + } +} + +unsafe impl VecZnxBigAutomorphismInplaceImpl for FFT64Avx +where + Module: VecZnxBigAutomorphismInplaceTmpBytes, +{ + /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`. 
+ fn vec_znx_big_automorphism_inplace_impl( + module: &Module, + p: i64, + res: &mut R, + res_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxBigToMut, + { + let (tmp, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes() / size_of::()); + vec_znx_big_automorphism_inplace(p, res, res_col, tmp); + } +} diff --git a/poulpy-backend/src/cpu_fft64_avx/vec_znx_dft.rs b/poulpy-backend/src/cpu_fft64_avx/vec_znx_dft.rs new file mode 100644 index 0000000..bca555e --- /dev/null +++ b/poulpy-backend/src/cpu_fft64_avx/vec_znx_dft.rs @@ -0,0 +1,186 @@ +use poulpy_hal::{ + layouts::{ + Backend, Data, Module, Scratch, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef, + VecZnxToRef, + }, + oep::{ + VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftApplyImpl, + VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl, VecZnxDftSubImpl, + VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl, VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl, + }, + reference::fft64::vec_znx_dft::{ + vec_znx_dft_add, vec_znx_dft_add_inplace, vec_znx_dft_apply, vec_znx_dft_copy, vec_znx_dft_sub, + vec_znx_dft_sub_ab_inplace, vec_znx_dft_sub_ba_inplace, vec_znx_dft_zero, vec_znx_idft_apply, vec_znx_idft_apply_consume, + vec_znx_idft_apply_tmpa, + }, +}; + +use crate::cpu_fft64_avx::{FFT64Avx, module::FFT64ModuleHandle}; + +unsafe impl VecZnxDftFromBytesImpl for FFT64Avx { + fn vec_znx_dft_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec) -> VecZnxDftOwned { + VecZnxDft::, Self>::from_bytes(n, cols, size, bytes) + } +} + +unsafe impl VecZnxDftAllocBytesImpl for FFT64Avx { + fn vec_znx_dft_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize { + Self::layout_prep_word_count() * n * cols * size * size_of::<::ScalarPrep>() + } +} + +unsafe impl VecZnxDftAllocImpl for FFT64Avx { + fn vec_znx_dft_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxDftOwned { + VecZnxDftOwned::alloc(n, cols, size) + } +} + +unsafe impl VecZnxIdftApplyTmpBytesImpl for FFT64Avx { + fn vec_znx_idft_apply_tmp_bytes_impl(_module: &Module) -> usize { + 0 + } +} + +unsafe impl VecZnxIdftApplyImpl for FFT64Avx { + fn vec_znx_idft_apply_impl( + module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + _scratch: &mut Scratch, + ) where + R: VecZnxBigToMut, + A: VecZnxDftToRef, + { + vec_znx_idft_apply(module.get_ifft_table(), res, res_col, a, a_col); + } +} + +unsafe impl VecZnxIdftApplyTmpAImpl for FFT64Avx { + fn vec_znx_idft_apply_tmpa_impl(module: &Module, res: &mut R, res_col: usize, a: &mut A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxDftToMut, + { + vec_znx_idft_apply_tmpa(module.get_ifft_table(), res, res_col, a, a_col); + } +} + +unsafe impl VecZnxIdftApplyConsumeImpl for FFT64Avx { + fn vec_znx_idft_apply_consume_impl(module: &Module, res: VecZnxDft) -> VecZnxBig + where + VecZnxDft: VecZnxDftToMut, + { + vec_znx_idft_apply_consume(module.get_ifft_table(), res) + } +} + +unsafe impl VecZnxDftApplyImpl for FFT64Avx { + fn vec_znx_dft_apply_impl( + module: &Module, + step: usize, + offset: usize, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + ) where + R: VecZnxDftToMut, + A: VecZnxToRef, + { + vec_znx_dft_apply(module.get_fft_table(), step, offset, res, res_col, a, a_col); + } +} + +unsafe impl VecZnxDftAddImpl for FFT64Avx { + fn vec_znx_dft_add_impl( + _module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: 
usize, + b: &B, + b_col: usize, + ) where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + B: VecZnxDftToRef, + { + vec_znx_dft_add(res, res_col, a, a_col, b, b_col); + } +} + +unsafe impl VecZnxDftAddInplaceImpl for FFT64Avx { + fn vec_znx_dft_add_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + { + vec_znx_dft_add_inplace(res, res_col, a, a_col); + } +} + +unsafe impl VecZnxDftSubImpl for FFT64Avx { + fn vec_znx_dft_sub_impl( + _module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + b: &B, + b_col: usize, + ) where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + B: VecZnxDftToRef, + { + vec_znx_dft_sub(res, res_col, a, a_col, b, b_col); + } +} + +unsafe impl VecZnxDftSubABInplaceImpl for FFT64Avx { + fn vec_znx_dft_sub_ab_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + { + vec_znx_dft_sub_ab_inplace(res, res_col, a, a_col); + } +} + +unsafe impl VecZnxDftSubBAInplaceImpl for FFT64Avx { + fn vec_znx_dft_sub_ba_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + { + vec_znx_dft_sub_ba_inplace(res, res_col, a, a_col); + } +} + +unsafe impl VecZnxDftCopyImpl for FFT64Avx { + fn vec_znx_dft_copy_impl( + _module: &Module, + step: usize, + offset: usize, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + ) where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + { + vec_znx_dft_copy(step, offset, res, res_col, a, a_col); + } +} + +unsafe impl VecZnxDftZeroImpl for FFT64Avx { + fn vec_znx_dft_zero_impl(_module: &Module, res: &mut R) + where + R: VecZnxDftToMut, + { + vec_znx_dft_zero(res); + } +} diff --git a/poulpy-backend/src/cpu_fft64_avx/vmp.rs b/poulpy-backend/src/cpu_fft64_avx/vmp.rs new file mode 100644 index 0000000..6b87ce1 --- /dev/null +++ b/poulpy-backend/src/cpu_fft64_avx/vmp.rs @@ -0,0 +1,143 @@ +use poulpy_hal::{ + api::{TakeSlice, VmpPrepareTmpBytes}, + layouts::{ + Backend, MatZnx, MatZnxToRef, Module, Scratch, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, VmpPMat, VmpPMatOwned, + VmpPMatToMut, VmpPMatToRef, ZnxInfos, + }, + oep::{ + VmpApplyDftToDftAddImpl, VmpApplyDftToDftAddTmpBytesImpl, VmpApplyDftToDftImpl, VmpApplyDftToDftTmpBytesImpl, + VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPrepareImpl, VmpPrepareTmpBytesImpl, + }, + reference::fft64::vmp::{ + vmp_apply_dft_to_dft, vmp_apply_dft_to_dft_add, vmp_apply_dft_to_dft_tmp_bytes, vmp_prepare, vmp_prepare_tmp_bytes, + }, +}; + +use crate::cpu_fft64_avx::{FFT64Avx, module::FFT64ModuleHandle}; + +unsafe impl VmpPMatAllocBytesImpl for FFT64Avx { + fn vmp_pmat_alloc_bytes_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { + Self::layout_prep_word_count() * n * rows * cols_in * cols_out * size * size_of::() + } +} + +unsafe impl VmpPMatAllocImpl for FFT64Avx { + fn vmp_pmat_alloc_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned { + VmpPMatOwned::alloc(n, rows, cols_in, cols_out, size) + } +} + +unsafe impl VmpApplyDftToDftImpl for FFT64Avx +where + Scratch: TakeSlice, + FFT64Avx: VmpApplyDftToDftTmpBytesImpl, +{ + fn vmp_apply_dft_to_dft_impl(module: &Module, res: &mut R, a: &A, pmat: &C, scratch: &mut Scratch) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + C: VmpPMatToRef, + { + let mut res: VecZnxDft<&mut [u8], Self> = res.to_mut(); + let a: VecZnxDft<&[u8], Self> = a.to_ref(); + let pmat: 
VmpPMat<&[u8], Self> = pmat.to_ref(); + + let (tmp, _) = scratch.take_slice( + Self::vmp_apply_dft_to_dft_tmp_bytes_impl( + module, + res.size(), + a.size(), + pmat.rows(), + pmat.cols_in(), + pmat.cols_out(), + pmat.size(), + ) / size_of::(), + ); + vmp_apply_dft_to_dft(&mut res, &a, &pmat, tmp); + } +} + +unsafe impl VmpApplyDftToDftAddImpl for FFT64Avx +where + Scratch: TakeSlice, + FFT64Avx: VmpApplyDftToDftTmpBytesImpl, +{ + fn vmp_apply_dft_to_dft_add_impl( + module: &Module, + res: &mut R, + a: &A, + pmat: &C, + limb_offset: usize, + scratch: &mut Scratch, + ) where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + C: VmpPMatToRef, + { + let mut res: VecZnxDft<&mut [u8], Self> = res.to_mut(); + let a: VecZnxDft<&[u8], Self> = a.to_ref(); + let pmat: VmpPMat<&[u8], Self> = pmat.to_ref(); + + let (tmp, _) = scratch.take_slice( + Self::vmp_apply_dft_to_dft_tmp_bytes_impl( + module, + res.size(), + a.size(), + pmat.rows(), + pmat.cols_in(), + pmat.cols_out(), + pmat.size(), + ) / size_of::(), + ); + vmp_apply_dft_to_dft_add(&mut res, &a, &pmat, limb_offset * pmat.cols_out(), tmp); + } +} + +unsafe impl VmpPrepareTmpBytesImpl for FFT64Avx { + fn vmp_prepare_tmp_bytes_impl(module: &Module, _rows: usize, _cols_in: usize, _cols_out: usize, _size: usize) -> usize { + vmp_prepare_tmp_bytes(module.n()) + } +} + +unsafe impl VmpPrepareImpl for FFT64Avx { + fn vmp_prepare_impl(module: &Module, res: &mut R, a: &A, scratch: &mut Scratch) + where + R: VmpPMatToMut, + A: MatZnxToRef, + { + {} + let mut res: VmpPMat<&mut [u8], Self> = res.to_mut(); + let a: MatZnx<&[u8]> = a.to_ref(); + let (tmp, _) = + scratch.take_slice(module.vmp_prepare_tmp_bytes(a.rows(), a.cols_in(), a.cols_out(), a.size()) / size_of::()); + vmp_prepare(module.get_fft_table(), &mut res, &a, tmp); + } +} + +unsafe impl VmpApplyDftToDftTmpBytesImpl for FFT64Avx { + fn vmp_apply_dft_to_dft_tmp_bytes_impl( + _module: &Module, + _res_size: usize, + a_size: usize, + b_rows: usize, + b_cols_in: usize, + _b_cols_out: usize, + _b_size: usize, + ) -> usize { + vmp_apply_dft_to_dft_tmp_bytes(a_size, b_rows, b_cols_in) + } +} + +unsafe impl VmpApplyDftToDftAddTmpBytesImpl for FFT64Avx { + fn vmp_apply_dft_to_dft_add_tmp_bytes_impl( + _module: &Module, + _res_size: usize, + a_size: usize, + b_rows: usize, + b_cols_in: usize, + _b_cols_out: usize, + _b_size: usize, + ) -> usize { + vmp_apply_dft_to_dft_tmp_bytes(a_size, b_rows, b_cols_in) + } +} diff --git a/poulpy-backend/src/cpu_fft64_avx/zn.rs b/poulpy-backend/src/cpu_fft64_avx/zn.rs new file mode 100644 index 0000000..033c3e2 --- /dev/null +++ b/poulpy-backend/src/cpu_fft64_avx/zn.rs @@ -0,0 +1,73 @@ +use poulpy_hal::{ + api::TakeSlice, + layouts::{Scratch, ZnToMut}, + oep::{TakeSliceImpl, ZnAddNormalImpl, ZnFillNormalImpl, ZnFillUniformImpl, ZnNormalizeInplaceImpl, ZnNormalizeTmpBytesImpl}, + reference::zn::{zn_add_normal, zn_fill_normal, zn_fill_uniform, zn_normalize_inplace, zn_normalize_tmp_bytes}, + source::Source, +}; + +use crate::cpu_fft64_avx::FFT64Avx; + +unsafe impl ZnNormalizeTmpBytesImpl for FFT64Avx { + fn zn_normalize_tmp_bytes_impl(n: usize) -> usize { + zn_normalize_tmp_bytes(n) + } +} + +unsafe impl ZnNormalizeInplaceImpl for FFT64Avx +where + Self: TakeSliceImpl, +{ + fn zn_normalize_inplace_impl(n: usize, basek: usize, res: &mut R, res_col: usize, scratch: &mut Scratch) + where + R: ZnToMut, + { + let (carry, _) = scratch.take_slice(n); + zn_normalize_inplace::(n, basek, res, res_col, carry); + } +} + +unsafe impl ZnFillUniformImpl for FFT64Avx { + fn 
zn_fill_uniform_impl(n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source) + where + R: ZnToMut, + { + zn_fill_uniform(n, basek, res, res_col, source); + } +} + +unsafe impl ZnFillNormalImpl for FFT64Avx { + #[allow(clippy::too_many_arguments)] + fn zn_fill_normal_impl( + n: usize, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + sigma: f64, + bound: f64, + ) where + R: ZnToMut, + { + zn_fill_normal(n, basek, res, res_col, k, source, sigma, bound); + } +} + +unsafe impl ZnAddNormalImpl for FFT64Avx { + #[allow(clippy::too_many_arguments)] + fn zn_add_normal_impl( + n: usize, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + sigma: f64, + bound: f64, + ) where + R: ZnToMut, + { + zn_add_normal(n, basek, res, res_col, k, source, sigma, bound); + } +} diff --git a/poulpy-backend/src/cpu_fft64_avx/znx_avx/add.rs b/poulpy-backend/src/cpu_fft64_avx/znx_avx/add.rs new file mode 100644 index 0000000..47f7552 --- /dev/null +++ b/poulpy-backend/src/cpu_fft64_avx/znx_avx/add.rs @@ -0,0 +1,76 @@ +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +/// all inputs must have the same length and must not alias. +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2")] +pub fn znx_add_avx(res: &mut [i64], a: &[i64], b: &[i64]) { + #[cfg(debug_assertions)] + { + assert_eq!(res.len(), a.len()); + assert_eq!(res.len(), b.len()); + } + + use core::arch::x86_64::{__m256i, _mm256_add_epi64, _mm256_loadu_si256, _mm256_storeu_si256}; + + let n: usize = res.len(); + + let span: usize = n >> 2; + + let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i; + let mut aa: *const __m256i = a.as_ptr() as *const __m256i; + let mut bb: *const __m256i = b.as_ptr() as *const __m256i; + + unsafe { + for _ in 0..span { + let sum: __m256i = _mm256_add_epi64(_mm256_loadu_si256(aa), _mm256_loadu_si256(bb)); + _mm256_storeu_si256(rr, sum); + rr = rr.add(1); + aa = aa.add(1); + bb = bb.add(1); + } + } + + // tail + if !res.len().is_multiple_of(4) { + use poulpy_hal::reference::znx::znx_add_ref; + + znx_add_ref(&mut res[span << 2..], &a[span << 2..], &b[span << 2..]); + } +} + +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +/// all inputs must have the same length and must not alias. 
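+/// Computes `res += a` lane-wise; a tail shorter than 4 elements falls back to the scalar reference implementation.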
+#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2")] +pub fn znx_add_inplace_avx(res: &mut [i64], a: &[i64]) { + #[cfg(debug_assertions)] + { + assert_eq!(res.len(), a.len()); + } + + use core::arch::x86_64::{__m256i, _mm256_add_epi64, _mm256_loadu_si256, _mm256_storeu_si256}; + + let n: usize = res.len(); + + let span: usize = n >> 2; + + let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i; + let mut aa: *const __m256i = a.as_ptr() as *const __m256i; + + unsafe { + for _ in 0..span { + let sum: __m256i = _mm256_add_epi64(_mm256_loadu_si256(rr), _mm256_loadu_si256(aa)); + _mm256_storeu_si256(rr, sum); + rr = rr.add(1); + aa = aa.add(1); + } + } + + // tail + if !res.len().is_multiple_of(4) { + use poulpy_hal::reference::znx::znx_add_inplace_ref; + + znx_add_inplace_ref(&mut res[span << 2..], &a[span << 2..]); + } +} diff --git a/poulpy-backend/src/cpu_fft64_avx/znx_avx/automorphism.rs b/poulpy-backend/src/cpu_fft64_avx/znx_avx/automorphism.rs new file mode 100644 index 0000000..b1b7d82 --- /dev/null +++ b/poulpy-backend/src/cpu_fft64_avx/znx_avx/automorphism.rs @@ -0,0 +1,133 @@ +use core::arch::x86_64::*; + +#[inline] +fn inv_mod_pow2(p: usize, bits: u32) -> usize { + // Compute p^{-1} mod 2^bits (p must be odd) through Hensel lifting. + debug_assert!(p % 2 == 1); + let mut x: usize = 1usize; // inverse mod 2 + let mut i: u32 = 1; + while i < bits { + // x <- x * (2 - p*x) mod 2^(2^i) (wrapping arithmetic) + x = x.wrapping_mul(2usize.wrapping_sub(p.wrapping_mul(x))); + i <<= 1; + } + x & ((1usize << bits) - 1) +} + +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +/// all inputs must have the same length and must not alias. +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2", enable = "fma")] +pub fn znx_automorphism_avx(p: i64, res: &mut [i64], a: &[i64]) { + debug_assert_eq!(res.len(), a.len()); + let n: usize = res.len(); + if n == 0 { + return; + } + debug_assert!(n.is_power_of_two(), "n must be power of two"); + debug_assert!(p & 1 == 1, "p must be odd (invertible mod 2n)"); + + if n < 4 { + use poulpy_hal::reference::znx::znx_automorphism_ref; + + znx_automorphism_ref(p, res, a); + return; + } + + unsafe { + let two_n: usize = n << 1; + let span: usize = n >> 2; + let bits: u32 = (two_n as u64).trailing_zeros(); + let mask_2n: usize = two_n - 1; + let mask_1n: usize = n - 1; + + // p mod 2n (positive) + let p_2n: usize = (((p & mask_2n as i64) + two_n as i64) as usize) & mask_2n; + + // p^-1 mod 2n + let inv: usize = inv_mod_pow2(p_2n, bits); + + // Broadcast constants + let n_minus1_vec: __m256i = _mm256_set1_epi64x((n as i64) - 1); + let mask_2n_vec: __m256i = _mm256_set1_epi64x(mask_2n as i64); + let mask_1n_vec: __m256i = _mm256_set1_epi64x(mask_1n as i64); + + // Lane offsets [0, inv, 2*inv, 3*inv] (mod 2n) + let lane_offsets: __m256i = _mm256_set_epi64x( + ((inv * 3) & mask_2n) as i64, + ((inv * 2) & mask_2n) as i64, + inv as i64, + 0i64, + ); + + // t_base = (j * inv) mod 2n. 
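+ // For output index j, t = j * p^{-1} mod 2n selects the source coefficient:
+ // res[j] = a[t mod n], negated when t >= n, since X^n = -1 in Z[X]/(X^n + 1).
+ // Each iteration emits 4 outputs, so t_base advances by (4 * inv) mod 2n.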
+ let mut t_base: usize = 0; + let step: usize = (inv << 2) & mask_2n; + + let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i; + let aa: *const i64 = a.as_ptr(); + + for _ in 0..span { + // t_vec = (t_base + [0, inv, 2*inv, 3*inv]) & (2n-1) + let t_base_vec: __m256i = _mm256_set1_epi64x(t_base as i64); + let t_vec: __m256i = _mm256_and_si256(_mm256_add_epi64(t_base_vec, lane_offsets), mask_2n_vec); + + // idx = t_vec & (n-1) + let idx_vec: __m256i = _mm256_and_si256(t_vec, mask_1n_vec); + + // sign = t >= n ? -1 : 0 (mask of all-ones where negate) + let sign_mask: __m256i = _mm256_cmpgt_epi64(t_vec, n_minus1_vec); + + // gather a[idx] (scale = 8 bytes per i64) + let vals: __m256i = _mm256_i64gather_epi64(aa, idx_vec, 8); + + // Conditional negate: (vals ^ sign_mask) - sign_mask + let vals_x: __m256i = _mm256_xor_si256(vals, sign_mask); + let out: __m256i = _mm256_sub_epi64(vals_x, sign_mask); + + // store to res[j..j+4] + _mm256_storeu_si256(rr, out); + + // advance + rr = rr.add(1); + t_base = (t_base + step) & mask_2n; + } + } +} + +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +/// all inputs must have the same length and must not alias. +#[cfg(all(test, any(target_arch = "x86_64", target_arch = "x86")))] +mod tests { + use poulpy_hal::reference::znx::znx_automorphism_ref; + + use super::*; + + #[target_feature(enable = "avx2", enable = "fma")] + fn test_znx_automorphism_internal() { + let a: [i64; 16] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; + + let p: i64 = -5; + + let mut r0: Vec = vec![0i64; a.len()]; + let mut r1: Vec = vec![0i64; a.len()]; + + znx_automorphism_ref(p, &mut r0, &a); + znx_automorphism_avx(p, &mut r1, &a); + + assert_eq!(r0, r1); + } + + #[test] + fn test_znx_automorphism_avx() { + if !std::is_x86_feature_detected!("avx2") { + eprintln!("skipping: CPU lacks avx2"); + return; + }; + unsafe { + test_znx_automorphism_internal(); + } + } +} diff --git a/poulpy-backend/src/cpu_fft64_avx/znx_avx/mod.rs b/poulpy-backend/src/cpu_fft64_avx/znx_avx/mod.rs new file mode 100644 index 0000000..70ab48b --- /dev/null +++ b/poulpy-backend/src/cpu_fft64_avx/znx_avx/mod.rs @@ -0,0 +1,13 @@ +mod add; +mod automorphism; +mod neg; +mod normalization; +mod sub; +mod switch_ring; + +pub(crate) use add::*; +pub(crate) use automorphism::*; +pub(crate) use neg::*; +pub(crate) use normalization::*; +pub(crate) use sub::*; +pub(crate) use switch_ring::*; diff --git a/poulpy-backend/src/cpu_fft64_avx/znx_avx/neg.rs b/poulpy-backend/src/cpu_fft64_avx/znx_avx/neg.rs new file mode 100644 index 0000000..1e922bd --- /dev/null +++ b/poulpy-backend/src/cpu_fft64_avx/znx_avx/neg.rs @@ -0,0 +1,64 @@ +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +/// all inputs must have the same length and must not alias. 
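+/// Computes `res = -src` lane-wise; the tail falls back to the scalar reference implementation.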
+#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2")] +pub fn znx_negate_avx(res: &mut [i64], src: &[i64]) { + #[cfg(debug_assertions)] + { + assert_eq!(res.len(), src.len()) + } + + let n: usize = res.len(); + + use std::arch::x86_64::{__m256i, _mm256_loadu_si256, _mm256_setzero_si256, _mm256_storeu_si256, _mm256_sub_epi64}; + let span: usize = n >> 2; + + unsafe { + let mut aa: *const __m256i = src.as_ptr() as *const __m256i; + let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i; + let zero: __m256i = _mm256_setzero_si256(); + for _ in 0..span { + let v: __m256i = _mm256_loadu_si256(aa); + let neg: __m256i = _mm256_sub_epi64(zero, v); + _mm256_storeu_si256(rr, neg); + aa = aa.add(1); + rr = rr.add(1); + } + } + + if !res.len().is_multiple_of(4) { + use poulpy_hal::reference::znx::znx_negate_ref; + + znx_negate_ref(&mut res[span << 2..], &src[span << 2..]) + } +} + +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +/// all inputs must have the same length and must not alias. +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2")] +pub fn znx_negate_inplace_avx(res: &mut [i64]) { + let n: usize = res.len(); + + use std::arch::x86_64::{__m256i, _mm256_loadu_si256, _mm256_setzero_si256, _mm256_storeu_si256, _mm256_sub_epi64}; + let span: usize = n >> 2; + + unsafe { + let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i; + let zero: __m256i = _mm256_setzero_si256(); + for _ in 0..span { + let v: __m256i = _mm256_loadu_si256(rr); + let neg: __m256i = _mm256_sub_epi64(zero, v); + _mm256_storeu_si256(rr, neg); + rr = rr.add(1); + } + } + + if !res.len().is_multiple_of(4) { + use poulpy_hal::reference::znx::znx_negate_inplace_ref; + + znx_negate_inplace_ref(&mut res[span << 2..]) + } +} diff --git a/poulpy-backend/src/cpu_fft64_avx/znx_avx/normalization.rs b/poulpy-backend/src/cpu_fft64_avx/znx_avx/normalization.rs new file mode 100644 index 0000000..e89cd58 --- /dev/null +++ b/poulpy-backend/src/cpu_fft64_avx/znx_avx/normalization.rs @@ -0,0 +1,1023 @@ +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::__m256i; + +/// Vector forms of those constants (broadcast to all lanes) +/// +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +/// all inputs must have the same length and must not alias. +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2")] +fn normalize_consts_avx(basek: usize) -> (__m256i, __m256i, __m256i, __m256i) { + use std::arch::x86_64::_mm256_set1_epi64x; + + assert!((1..=63).contains(&basek)); + let mask_k: i64 = ((1u64 << basek) - 1) as i64; // 0..k-1 bits set + let sign_k: i64 = (1u64 << (basek - 1)) as i64; // bit k-1 + let topmask: i64 = (!0u64 << (64 - basek)) as i64; // top k bits set + let sh_k: __m256i = _mm256_set1_epi64x(basek as i64); + ( + _mm256_set1_epi64x(mask_k), // mask_k_vec + _mm256_set1_epi64x(sign_k), // sign_k_vec + sh_k, // shift_k_vec + _mm256_set1_epi64x(topmask), // topmask_vec + ) +} + +/// AVX2 get_digit using masks (no arithmetic shift needed): +/// digit = ((x & mask_k) ^ sign_k) - sign_k +/// +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +/// all inputs must have the same length and must not alias. 
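+/// This is the xor/subtract sign-extension trick: it keeps the low `basek` bits of each lane and sign-extends them to 64 bits.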
+#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2")] +fn get_digit_avx(x: __m256i, mask_k: __m256i, sign_k: __m256i) -> __m256i { + use std::arch::x86_64::{_mm256_and_si256, _mm256_sub_epi64, _mm256_xor_si256}; + let low: __m256i = _mm256_and_si256(x, mask_k); + let t: __m256i = _mm256_xor_si256(low, sign_k); + _mm256_sub_epi64(t, sign_k) +} + +/// AVX2 get_carry using precomputed shift and topmask: +/// carry = (x - digit) >>_arith k +/// +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +/// all inputs must have the same length and must not alias. +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2")] +unsafe fn get_carry_avx( + x: __m256i, + digit: __m256i, + basek: __m256i, // _mm256_set1_epi64x(k) + top_mask: __m256i, // (!0 << (64 - k)) broadcast +) -> __m256i { + use std::arch::x86_64::{ + __m256i, _mm256_and_si256, _mm256_cmpgt_epi64, _mm256_or_si256, _mm256_setzero_si256, _mm256_srlv_epi64, _mm256_sub_epi64, + }; + let diff: __m256i = _mm256_sub_epi64(x, digit); + let lsr: __m256i = _mm256_srlv_epi64(diff, basek); // logical >> + let neg: __m256i = _mm256_cmpgt_epi64(_mm256_setzero_si256(), diff); // 0xFFFF.. where v<0 + let fill: __m256i = _mm256_and_si256(neg, top_mask); // top k bits if negative + _mm256_or_si256(lsr, fill) +} + +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +/// all inputs must have the same length and must not alias. +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2")] +pub fn znx_normalize_first_step_carry_only_avx(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) { + #[cfg(debug_assertions)] + { + assert_eq!(x.len(), carry.len()); + assert!(lsh < basek); + } + + use std::arch::x86_64::{_mm256_loadu_si256, _mm256_storeu_si256}; + + let n: usize = x.len(); + + let span: usize = n >> 2; + + unsafe { + let mut xx: *const __m256i = x.as_ptr() as *const __m256i; + let mut cc: *mut __m256i = carry.as_ptr() as *mut __m256i; + + let (mask, sign, basek_vec, top_mask) = if lsh == 0 { + normalize_consts_avx(basek) + } else { + normalize_consts_avx(basek - lsh) + }; + + for _ in 0..span { + let xx_256: __m256i = _mm256_loadu_si256(xx); + + // (x << (64 - basek)) >> (64 - basek) + let digit_256: __m256i = get_digit_avx(xx_256, mask, sign); + + // (x - digit) >> basek + let carry_256: __m256i = get_carry_avx(xx_256, digit_256, basek_vec, top_mask); + + _mm256_storeu_si256(cc, carry_256); + + xx = xx.add(1); + cc = cc.add(1); + } + } + + // tail + if !x.len().is_multiple_of(4) { + use poulpy_hal::reference::znx::znx_normalize_first_step_carry_only_ref; + + znx_normalize_first_step_carry_only_ref(basek, lsh, &x[span << 2..], &mut carry[span << 2..]); + } +} + +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +/// all inputs must have the same length and must not alias. 
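+/// Splits each limb of `x` into a digit (written back to `x`, left-shifted by `lsh` when nonzero) and a carry written to `carry`.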
+#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2")] +pub fn znx_normalize_first_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { + #[cfg(debug_assertions)] + { + assert_eq!(x.len(), carry.len()); + assert!(lsh < basek); + } + + use std::arch::x86_64::{_mm256_loadu_si256, _mm256_set1_epi64x, _mm256_sllv_epi64, _mm256_storeu_si256}; + + let n: usize = x.len(); + + let span: usize = n >> 2; + + unsafe { + let mut xx: *mut __m256i = x.as_mut_ptr() as *mut __m256i; + let mut cc: *mut __m256i = carry.as_ptr() as *mut __m256i; + + if lsh == 0 { + let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek); + + for _ in 0..span { + let xx_256: __m256i = _mm256_loadu_si256(xx); + + // (x << (64 - basek)) >> (64 - basek) + let digit_256: __m256i = get_digit_avx(xx_256, mask, sign); + + // (x - digit) >> basek + let carry_256: __m256i = get_carry_avx(xx_256, digit_256, basek_vec, top_mask); + + _mm256_storeu_si256(xx, digit_256); + _mm256_storeu_si256(cc, carry_256); + + xx = xx.add(1); + cc = cc.add(1); + } + } else { + let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek - lsh); + + let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64); + + for _ in 0..span { + let xx_256: __m256i = _mm256_loadu_si256(xx); + + // (x << (64 - basek)) >> (64 - basek) + let digit_256: __m256i = get_digit_avx(xx_256, mask, sign); + + // (x - digit) >> basek + let carry_256: __m256i = get_carry_avx(xx_256, digit_256, basek_vec, top_mask); + + _mm256_storeu_si256(xx, _mm256_sllv_epi64(digit_256, lsh_v)); + _mm256_storeu_si256(cc, carry_256); + + xx = xx.add(1); + cc = cc.add(1); + } + } + } + + // tail + if !x.len().is_multiple_of(4) { + use poulpy_hal::reference::znx::znx_normalize_first_step_inplace_ref; + + znx_normalize_first_step_inplace_ref(basek, lsh, &mut x[span << 2..], &mut carry[span << 2..]); + } +} + +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +/// all inputs must have the same length and must not alias. 
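+/// Out-of-place variant of `znx_normalize_first_step_inplace_avx`: digits are read from `a` and written to `x`.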
+#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2")] +pub fn znx_normalize_first_step_avx(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { + #[cfg(debug_assertions)] + { + assert_eq!(x.len(), carry.len()); + assert_eq!(a.len(), carry.len()); + assert!(lsh < basek); + } + + use std::arch::x86_64::{_mm256_loadu_si256, _mm256_sllv_epi64, _mm256_storeu_si256}; + + let n: usize = x.len(); + + let span: usize = n >> 2; + + unsafe { + let mut xx: *mut __m256i = x.as_mut_ptr() as *mut __m256i; + let mut aa: *const __m256i = a.as_ptr() as *const __m256i; + let mut cc: *mut __m256i = carry.as_ptr() as *mut __m256i; + + if lsh == 0 { + let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek); + + for _ in 0..span { + let aa_256: __m256i = _mm256_loadu_si256(aa); + + // (x << (64 - basek)) >> (64 - basek) + let digit_256: __m256i = get_digit_avx(aa_256, mask, sign); + + // (x - digit) >> basek + let carry_256: __m256i = get_carry_avx(aa_256, digit_256, basek_vec, top_mask); + + _mm256_storeu_si256(xx, digit_256); + _mm256_storeu_si256(cc, carry_256); + + xx = xx.add(1); + aa = aa.add(1); + cc = cc.add(1); + } + } else { + use std::arch::x86_64::_mm256_set1_epi64x; + + let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek - lsh); + + let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64); + + for _ in 0..span { + let aa_256: __m256i = _mm256_loadu_si256(aa); + + // (x << (64 - basek)) >> (64 - basek) + let digit_256: __m256i = get_digit_avx(aa_256, mask, sign); + + // (x - digit) >> basek + let carry_256: __m256i = get_carry_avx(aa_256, digit_256, basek_vec, top_mask); + + _mm256_storeu_si256(xx, _mm256_sllv_epi64(digit_256, lsh_v)); + _mm256_storeu_si256(cc, carry_256); + + xx = xx.add(1); + aa = aa.add(1); + cc = cc.add(1); + } + } + } + + // tail + if !x.len().is_multiple_of(4) { + use poulpy_hal::reference::znx::znx_normalize_first_step_ref; + + znx_normalize_first_step_ref( + basek, + lsh, + &mut x[span << 2..], + &a[span << 2..], + &mut carry[span << 2..], + ); + } +} + +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +/// all inputs must have the same length and must not alias. 
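+/// Adds the incoming carry to the digit of each limb of `x`, writes the renormalized digit back to `x`, and accumulates the outgoing carry in `carry`.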
+#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +#[target_feature(enable = "avx2")] +pub fn znx_normalize_middle_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { + #[cfg(debug_assertions)] + { + assert_eq!(x.len(), carry.len()); + assert!(lsh < basek); + } + + use std::arch::x86_64::{_mm256_add_epi64, _mm256_loadu_si256, _mm256_sllv_epi64, _mm256_storeu_si256}; + + let n: usize = x.len(); + + let span: usize = n >> 2; + + let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek); + + unsafe { + let mut xx: *mut __m256i = x.as_mut_ptr() as *mut __m256i; + let mut cc: *mut __m256i = carry.as_mut_ptr() as *mut __m256i; + + if lsh == 0 { + for _ in 0..span { + let xx_256: __m256i = _mm256_loadu_si256(xx); + let cc_256: __m256i = _mm256_loadu_si256(cc); + + let d0: __m256i = get_digit_avx(xx_256, mask, sign); + let c0: __m256i = get_carry_avx(xx_256, d0, basek_vec, top_mask); + + let s: __m256i = _mm256_add_epi64(d0, cc_256); + let x1: __m256i = get_digit_avx(s, mask, sign); + let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask); + let cout: __m256i = _mm256_add_epi64(c0, c1); + + _mm256_storeu_si256(xx, x1); + _mm256_storeu_si256(cc, cout); + + xx = xx.add(1); + cc = cc.add(1); + } + } else { + use std::arch::x86_64::_mm256_set1_epi64x; + + let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(basek - lsh); + + let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64); + + for _ in 0..span { + let xx_256: __m256i = _mm256_loadu_si256(xx); + let cc_256: __m256i = _mm256_loadu_si256(cc); + + let d0: __m256i = get_digit_avx(xx_256, mask_lsh, sign_lsh); + let c0: __m256i = get_carry_avx(xx_256, d0, basek_vec_lsh, top_mask_lsh); + + let d0_lsh: __m256i = _mm256_sllv_epi64(d0, lsh_v); + + let s: __m256i = _mm256_add_epi64(d0_lsh, cc_256); + let x1: __m256i = get_digit_avx(s, mask, sign); + let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask); + let cout: __m256i = _mm256_add_epi64(c0, c1); + + _mm256_storeu_si256(xx, x1); + _mm256_storeu_si256(cc, cout); + + xx = xx.add(1); + cc = cc.add(1); + } + } + } + + if !x.len().is_multiple_of(4) { + use poulpy_hal::reference::znx::znx_normalize_middle_step_inplace_ref; + + znx_normalize_middle_step_inplace_ref(basek, lsh, &mut x[span << 2..], &mut carry[span << 2..]); + } +} + +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +/// all inputs must have the same length and must not alias. 
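+/// Same as the middle step, but only `carry` is updated; `x` is left untouched.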
+#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +#[target_feature(enable = "avx2")] +pub fn znx_normalize_middle_step_carry_only_avx(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) { + #[cfg(debug_assertions)] + { + assert_eq!(x.len(), carry.len()); + assert!(lsh < basek); + } + + use std::arch::x86_64::{_mm256_add_epi64, _mm256_loadu_si256, _mm256_sllv_epi64, _mm256_storeu_si256}; + + let n: usize = x.len(); + + let span: usize = n >> 2; + + let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek); + + unsafe { + let mut xx: *const __m256i = x.as_ptr() as *const __m256i; + let mut cc: *mut __m256i = carry.as_mut_ptr() as *mut __m256i; + + if lsh == 0 { + for _ in 0..span { + let xx_256: __m256i = _mm256_loadu_si256(xx); + let cc_256: __m256i = _mm256_loadu_si256(cc); + + let d0: __m256i = get_digit_avx(xx_256, mask, sign); + let c0: __m256i = get_carry_avx(xx_256, d0, basek_vec, top_mask); + + let s: __m256i = _mm256_add_epi64(d0, cc_256); + let x1: __m256i = get_digit_avx(s, mask, sign); + let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask); + let cout: __m256i = _mm256_add_epi64(c0, c1); + + _mm256_storeu_si256(cc, cout); + + xx = xx.add(1); + cc = cc.add(1); + } + } else { + use std::arch::x86_64::_mm256_set1_epi64x; + + let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(basek - lsh); + + let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64); + + for _ in 0..span { + let xx_256: __m256i = _mm256_loadu_si256(xx); + let cc_256: __m256i = _mm256_loadu_si256(cc); + + let d0: __m256i = get_digit_avx(xx_256, mask_lsh, sign_lsh); + let c0: __m256i = get_carry_avx(xx_256, d0, basek_vec_lsh, top_mask_lsh); + + let d0_lsh: __m256i = _mm256_sllv_epi64(d0, lsh_v); + + let s: __m256i = _mm256_add_epi64(d0_lsh, cc_256); + let x1: __m256i = get_digit_avx(s, mask, sign); + let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask); + let cout: __m256i = _mm256_add_epi64(c0, c1); + + _mm256_storeu_si256(cc, cout); + + xx = xx.add(1); + cc = cc.add(1); + } + } + } + + if !x.len().is_multiple_of(4) { + use poulpy_hal::reference::znx::znx_normalize_middle_step_carry_only_ref; + + znx_normalize_middle_step_carry_only_ref(basek, lsh, &x[span << 2..], &mut carry[span << 2..]); + } +} + +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +/// all inputs must have the same length and must not alias. 
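+/// Out-of-place middle step: digits are read from `a`, the renormalized digit is written to `x`, and `carry` is updated.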
+#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +#[target_feature(enable = "avx2")] +pub fn znx_normalize_middle_step_avx(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { + #[cfg(debug_assertions)] + { + assert_eq!(x.len(), carry.len()); + assert_eq!(a.len(), carry.len()); + assert!(lsh < basek); + } + + use std::arch::x86_64::{_mm256_add_epi64, _mm256_loadu_si256, _mm256_sllv_epi64, _mm256_storeu_si256}; + + let n: usize = x.len(); + + let span: usize = n >> 2; + + let (mask, sign, basek_vec, top_mask) = normalize_consts_avx(basek); + + unsafe { + let mut xx: *mut __m256i = x.as_mut_ptr() as *mut __m256i; + let mut aa: *const __m256i = a.as_ptr() as *const __m256i; + let mut cc: *mut __m256i = carry.as_ptr() as *mut __m256i; + + if lsh == 0 { + for _ in 0..span { + let aa_256: __m256i = _mm256_loadu_si256(aa); + let cc_256: __m256i = _mm256_loadu_si256(cc); + + let d0: __m256i = get_digit_avx(aa_256, mask, sign); + let c0: __m256i = get_carry_avx(aa_256, d0, basek_vec, top_mask); + + let s: __m256i = _mm256_add_epi64(d0, cc_256); + let x1: __m256i = get_digit_avx(s, mask, sign); + let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask); + let cout: __m256i = _mm256_add_epi64(c0, c1); + + _mm256_storeu_si256(xx, x1); + _mm256_storeu_si256(cc, cout); + + xx = xx.add(1); + aa = aa.add(1); + cc = cc.add(1); + } + } else { + use std::arch::x86_64::_mm256_set1_epi64x; + + let (mask_lsh, sign_lsh, basek_vec_lsh, top_mask_lsh) = normalize_consts_avx(basek - lsh); + + let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64); + + for _ in 0..span { + let aa_256: __m256i = _mm256_loadu_si256(aa); + let cc_256: __m256i = _mm256_loadu_si256(cc); + + let d0: __m256i = get_digit_avx(aa_256, mask_lsh, sign_lsh); + let c0: __m256i = get_carry_avx(aa_256, d0, basek_vec_lsh, top_mask_lsh); + + let d0_lsh: __m256i = _mm256_sllv_epi64(d0, lsh_v); + + let s: __m256i = _mm256_add_epi64(d0_lsh, cc_256); + let x1: __m256i = get_digit_avx(s, mask, sign); + let c1: __m256i = get_carry_avx(s, x1, basek_vec, top_mask); + let cout: __m256i = _mm256_add_epi64(c0, c1); + + _mm256_storeu_si256(xx, x1); + _mm256_storeu_si256(cc, cout); + + xx = xx.add(1); + aa = aa.add(1); + cc = cc.add(1); + } + } + } + + if !x.len().is_multiple_of(4) { + use poulpy_hal::reference::znx::znx_normalize_middle_step_ref; + + znx_normalize_middle_step_ref( + basek, + lsh, + &mut x[span << 2..], + &a[span << 2..], + &mut carry[span << 2..], + ); + } +} + +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +/// all inputs must have the same length and must not alias. 
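+/// Adds the incoming carry to the digit of each limb of `x` and writes the reduced digit back; no carry is propagated.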
+#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +#[target_feature(enable = "avx2")] +pub fn znx_normalize_final_step_inplace_avx(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { + #[cfg(debug_assertions)] + { + assert_eq!(x.len(), carry.len()); + assert!(lsh < basek); + } + + use std::arch::x86_64::{_mm256_add_epi64, _mm256_loadu_si256, _mm256_sllv_epi64, _mm256_storeu_si256}; + + let n: usize = x.len(); + + let span: usize = n >> 2; + + let (mask, sign, _, _) = normalize_consts_avx(basek); + + unsafe { + let mut xx: *mut __m256i = x.as_mut_ptr() as *mut __m256i; + let mut cc: *mut __m256i = carry.as_ptr() as *mut __m256i; + + if lsh == 0 { + for _ in 0..span { + let xv: __m256i = _mm256_loadu_si256(xx); + let cv: __m256i = _mm256_loadu_si256(cc); + + let d0: __m256i = get_digit_avx(xv, mask, sign); + let s: __m256i = _mm256_add_epi64(d0, cv); + let x1: __m256i = get_digit_avx(s, mask, sign); + + _mm256_storeu_si256(xx, x1); + + xx = xx.add(1); + cc = cc.add(1); + } + } else { + use std::arch::x86_64::_mm256_set1_epi64x; + + let (mask_lsh, sign_lsh, _, _) = normalize_consts_avx(basek - lsh); + + let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64); + + for _ in 0..span { + let xv: __m256i = _mm256_loadu_si256(xx); + let cv: __m256i = _mm256_loadu_si256(cc); + + let d0: __m256i = get_digit_avx(xv, mask_lsh, sign_lsh); + + let d0_lsh: __m256i = _mm256_sllv_epi64(d0, lsh_v); + + let s: __m256i = _mm256_add_epi64(d0_lsh, cv); + let x1: __m256i = get_digit_avx(s, mask, sign); + + _mm256_storeu_si256(xx, x1); + + xx = xx.add(1); + cc = cc.add(1); + } + } + } + + if !x.len().is_multiple_of(4) { + use poulpy_hal::reference::znx::znx_normalize_final_step_inplace_ref; + + znx_normalize_final_step_inplace_ref(basek, lsh, &mut x[span << 2..], &mut carry[span << 2..]); + } +} + +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +/// all inputs must have the same length and must not alias. 
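+/// Out-of-place final step: reads digits from `a`, adds the incoming carry, and writes the reduced digits to `x`; no carry is propagated.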
+#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +#[target_feature(enable = "avx2")] +pub fn znx_normalize_final_step_avx(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { + #[cfg(debug_assertions)] + { + assert_eq!(x.len(), carry.len()); + assert_eq!(a.len(), carry.len()); + assert!(lsh < basek); + } + + use std::arch::x86_64::{_mm256_add_epi64, _mm256_loadu_si256, _mm256_sllv_epi64, _mm256_storeu_si256}; + + let n: usize = x.len(); + + let span: usize = n >> 2; + + let (mask, sign, _, _) = normalize_consts_avx(basek); + + unsafe { + let mut xx: *mut __m256i = x.as_mut_ptr() as *mut __m256i; + let mut aa: *mut __m256i = a.as_ptr() as *mut __m256i; + let mut cc: *mut __m256i = carry.as_ptr() as *mut __m256i; + + if lsh == 0 { + for _ in 0..span { + let av: __m256i = _mm256_loadu_si256(aa); + let cv: __m256i = _mm256_loadu_si256(cc); + + let d0: __m256i = get_digit_avx(av, mask, sign); + let s: __m256i = _mm256_add_epi64(d0, cv); + let x1: __m256i = get_digit_avx(s, mask, sign); + + _mm256_storeu_si256(xx, x1); + + xx = xx.add(1); + aa = aa.add(1); + cc = cc.add(1); + } + } else { + use std::arch::x86_64::_mm256_set1_epi64x; + + let (mask_lsh, sign_lsh, _, _) = normalize_consts_avx(basek - lsh); + + let lsh_v: __m256i = _mm256_set1_epi64x(lsh as i64); + + for _ in 0..span { + let av: __m256i = _mm256_loadu_si256(aa); + let cv: __m256i = _mm256_loadu_si256(cc); + + let d0: __m256i = get_digit_avx(av, mask_lsh, sign_lsh); + let d0_lsh: __m256i = _mm256_sllv_epi64(d0, lsh_v); + + let s: __m256i = _mm256_add_epi64(d0_lsh, cv); + let x1: __m256i = get_digit_avx(s, mask, sign); + + _mm256_storeu_si256(xx, x1); + + xx = xx.add(1); + aa = aa.add(1); + cc = cc.add(1); + } + } + } + + if !x.len().is_multiple_of(4) { + use poulpy_hal::reference::znx::znx_normalize_final_step_ref; + + znx_normalize_final_step_ref( + basek, + lsh, + &mut x[span << 2..], + &a[span << 2..], + &mut carry[span << 2..], + ); + } +} + +#[cfg(all(test, any(target_arch = "x86_64", target_arch = "x86")))] +mod tests { + use poulpy_hal::reference::znx::{ + get_carry, get_digit, znx_normalize_final_step_inplace_ref, znx_normalize_final_step_ref, + znx_normalize_first_step_inplace_ref, znx_normalize_first_step_ref, znx_normalize_middle_step_inplace_ref, + znx_normalize_middle_step_ref, + }; + + use super::*; + + use std::arch::x86_64::{_mm256_loadu_si256, _mm256_storeu_si256}; + + #[target_feature(enable = "avx2")] + fn test_get_digit_avx_internal() { + let basek: usize = 12; + let x: [i64; 4] = [ + 7638646372408325293, + -61440197422348985, + 6835891051541717957, + -4835376105455195188, + ]; + let y0: Vec = vec![ + get_digit(basek, x[0]), + get_digit(basek, x[1]), + get_digit(basek, x[2]), + get_digit(basek, x[3]), + ]; + let mut y1: Vec = vec![0i64; 4]; + unsafe { + let x_256: __m256i = _mm256_loadu_si256(x.as_ptr() as *const __m256i); + let (mask, sign, _, _) = normalize_consts_avx(basek); + let digit: __m256i = get_digit_avx(x_256, mask, sign); + _mm256_storeu_si256(y1.as_mut_ptr() as *mut __m256i, digit); + } + assert_eq!(y0, y1); + } + + #[test] + fn test_get_digit_avx() { + if !std::is_x86_feature_detected!("avx2") { + eprintln!("skipping: CPU lacks avx2"); + return; + }; + unsafe { + test_get_digit_avx_internal(); + } + } + + #[target_feature(enable = "avx2")] + fn test_get_carry_avx_internal() { + let basek: usize = 12; + let x: [i64; 4] = [ + 7638646372408325293, + -61440197422348985, + 6835891051541717957, + -4835376105455195188, + ]; + let carry: [i64; 4] = [1174467039, -144794816, 
-1466676977, 513122840]; + let y0: Vec = vec![ + get_carry(basek, x[0], carry[0]), + get_carry(basek, x[1], carry[1]), + get_carry(basek, x[2], carry[2]), + get_carry(basek, x[3], carry[3]), + ]; + let mut y1: Vec = vec![0i64; 4]; + unsafe { + let x_256: __m256i = _mm256_loadu_si256(x.as_ptr() as *const __m256i); + let d_256: __m256i = _mm256_loadu_si256(carry.as_ptr() as *const __m256i); + let (_, _, basek_vec, top_mask) = normalize_consts_avx(basek); + let digit: __m256i = get_carry_avx(x_256, d_256, basek_vec, top_mask); + _mm256_storeu_si256(y1.as_mut_ptr() as *mut __m256i, digit); + } + assert_eq!(y0, y1); + } + + #[test] + fn test_get_carry_avx() { + if !std::is_x86_feature_detected!("avx2") { + eprintln!("skipping: CPU lacks avx2"); + return; + }; + unsafe { + test_get_carry_avx_internal(); + } + } + + #[target_feature(enable = "avx2")] + fn test_znx_normalize_first_step_inplace_avx_internal() { + let mut y0: [i64; 4] = [ + 7638646372408325293, + -61440197422348985, + 6835891051541717957, + -4835376105455195188, + ]; + let mut y1: [i64; 4] = y0; + + let mut c0: [i64; 4] = [ + 621182201135793202, + 9000856573317006236, + 5542252755421113668, + -6036847263131690631, + ]; + let mut c1: [i64; 4] = c0; + + let basek = 12; + + znx_normalize_first_step_inplace_ref(basek, 0, &mut y0, &mut c0); + znx_normalize_first_step_inplace_avx(basek, 0, &mut y1, &mut c1); + + assert_eq!(y0, y1); + assert_eq!(c0, c1); + + znx_normalize_first_step_inplace_ref(basek, basek - 1, &mut y0, &mut c0); + znx_normalize_first_step_inplace_avx(basek, basek - 1, &mut y1, &mut c1); + + assert_eq!(y0, y1); + assert_eq!(c0, c1); + } + + #[test] + fn test_znx_normalize_first_step_inplace_avx() { + if !std::is_x86_feature_detected!("avx2") { + eprintln!("skipping: CPU lacks avx2"); + return; + }; + unsafe { + test_znx_normalize_first_step_inplace_avx_internal(); + } + } + + #[target_feature(enable = "avx2")] + fn test_znx_normalize_middle_step_inplace_avx_internal() { + let mut y0: [i64; 4] = [ + 7638646372408325293, + -61440197422348985, + 6835891051541717957, + -4835376105455195188, + ]; + let mut y1: [i64; 4] = y0; + + let mut c0: [i64; 4] = [ + 621182201135793202, + 9000856573317006236, + 5542252755421113668, + -6036847263131690631, + ]; + let mut c1: [i64; 4] = c0; + + let basek = 12; + + znx_normalize_middle_step_inplace_ref(basek, 0, &mut y0, &mut c0); + znx_normalize_middle_step_inplace_avx(basek, 0, &mut y1, &mut c1); + + assert_eq!(y0, y1); + assert_eq!(c0, c1); + + znx_normalize_middle_step_inplace_ref(basek, basek - 1, &mut y0, &mut c0); + znx_normalize_middle_step_inplace_avx(basek, basek - 1, &mut y1, &mut c1); + + assert_eq!(y0, y1); + assert_eq!(c0, c1); + } + + #[test] + fn test_znx_normalize_middle_step_inplace_avx() { + if !std::is_x86_feature_detected!("avx2") { + eprintln!("skipping: CPU lacks avx2"); + return; + }; + unsafe { + test_znx_normalize_middle_step_inplace_avx_internal(); + } + } + + #[target_feature(enable = "avx2")] + fn test_znx_normalize_final_step_inplace_avx_internal() { + let mut y0: [i64; 4] = [ + 7638646372408325293, + -61440197422348985, + 6835891051541717957, + -4835376105455195188, + ]; + let mut y1: [i64; 4] = y0; + + let mut c0: [i64; 4] = [ + 621182201135793202, + 9000856573317006236, + 5542252755421113668, + -6036847263131690631, + ]; + let mut c1: [i64; 4] = c0; + + let basek = 12; + + znx_normalize_final_step_inplace_ref(basek, 0, &mut y0, &mut c0); + znx_normalize_final_step_inplace_avx(basek, 0, &mut y1, &mut c1); + + assert_eq!(y0, y1); + assert_eq!(c0, c1); + + 
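+ // Repeat with lsh = basek - 1 to exercise the left-shift path.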
znx_normalize_final_step_inplace_ref(basek, basek - 1, &mut y0, &mut c0); + znx_normalize_final_step_inplace_avx(basek, basek - 1, &mut y1, &mut c1); + + assert_eq!(y0, y1); + assert_eq!(c0, c1); + } + + #[test] + fn test_znx_normalize_final_step_inplace_avx() { + if !std::is_x86_feature_detected!("avx2") { + eprintln!("skipping: CPU lacks avx2"); + return; + }; + unsafe { + test_znx_normalize_final_step_inplace_avx_internal(); + } + } + + #[target_feature(enable = "avx2")] + fn test_znx_normalize_first_step_avx_internal() { + let mut y0: [i64; 4] = [ + 7638646372408325293, + -61440197422348985, + 6835891051541717957, + -4835376105455195188, + ]; + let mut y1: [i64; 4] = y0; + let a: [i64; 4] = y0; + + let mut c0: [i64; 4] = [ + 621182201135793202, + 9000856573317006236, + 5542252755421113668, + -6036847263131690631, + ]; + let mut c1: [i64; 4] = c0; + + let basek = 12; + + znx_normalize_first_step_ref(basek, 0, &mut y0, &a, &mut c0); + znx_normalize_first_step_avx(basek, 0, &mut y1, &a, &mut c1); + + assert_eq!(y0, y1); + assert_eq!(c0, c1); + + znx_normalize_first_step_ref(basek, basek - 1, &mut y0, &a, &mut c0); + znx_normalize_first_step_avx(basek, basek - 1, &mut y1, &a, &mut c1); + + assert_eq!(y0, y1); + assert_eq!(c0, c1); + } + + #[test] + fn test_znx_normalize_first_step_avx() { + if !std::is_x86_feature_detected!("avx2") { + eprintln!("skipping: CPU lacks avx2"); + return; + }; + unsafe { + test_znx_normalize_first_step_avx_internal(); + } + } + + #[target_feature(enable = "avx2")] + fn test_znx_normalize_middle_step_avx_internal() { + let mut y0: [i64; 4] = [ + 7638646372408325293, + -61440197422348985, + 6835891051541717957, + -4835376105455195188, + ]; + let mut y1: [i64; 4] = y0; + let a: [i64; 4] = y0; + + let mut c0: [i64; 4] = [ + 621182201135793202, + 9000856573317006236, + 5542252755421113668, + -6036847263131690631, + ]; + let mut c1: [i64; 4] = c0; + + let basek = 12; + + znx_normalize_middle_step_ref(basek, 0, &mut y0, &a, &mut c0); + znx_normalize_middle_step_avx(basek, 0, &mut y1, &a, &mut c1); + + assert_eq!(y0, y1); + assert_eq!(c0, c1); + + znx_normalize_middle_step_ref(basek, basek - 1, &mut y0, &a, &mut c0); + znx_normalize_middle_step_avx(basek, basek - 1, &mut y1, &a, &mut c1); + + assert_eq!(y0, y1); + assert_eq!(c0, c1); + } + + #[test] + fn test_znx_normalize_middle_step_avx() { + if !std::is_x86_feature_detected!("avx2") { + eprintln!("skipping: CPU lacks avx2"); + return; + }; + unsafe { + test_znx_normalize_middle_step_avx_internal(); + } + } + + #[target_feature(enable = "avx2")] + fn test_znx_normalize_final_step_avx_internal() { + let mut y0: [i64; 4] = [ + 7638646372408325293, + -61440197422348985, + 6835891051541717957, + -4835376105455195188, + ]; + let mut y1: [i64; 4] = y0; + let a: [i64; 4] = y0; + + let mut c0: [i64; 4] = [ + 621182201135793202, + 9000856573317006236, + 5542252755421113668, + -6036847263131690631, + ]; + let mut c1: [i64; 4] = c0; + + let basek = 12; + + znx_normalize_final_step_ref(basek, 0, &mut y0, &a, &mut c0); + znx_normalize_final_step_avx(basek, 0, &mut y1, &a, &mut c1); + + assert_eq!(y0, y1); + assert_eq!(c0, c1); + + znx_normalize_final_step_ref(basek, basek - 1, &mut y0, &a, &mut c0); + znx_normalize_final_step_avx(basek, basek - 1, &mut y1, &a, &mut c1); + + assert_eq!(y0, y1); + assert_eq!(c0, c1); + } + + #[test] + fn test_znx_normalize_final_step_avx() { + if !std::is_x86_feature_detected!("avx2") { + eprintln!("skipping: CPU lacks avx2"); + return; + }; + unsafe { + 
test_znx_normalize_final_step_avx_internal(); + } + } +} diff --git a/poulpy-backend/src/cpu_fft64_avx/znx_avx/sub.rs b/poulpy-backend/src/cpu_fft64_avx/znx_avx/sub.rs new file mode 100644 index 0000000..a49b90d --- /dev/null +++ b/poulpy-backend/src/cpu_fft64_avx/znx_avx/sub.rs @@ -0,0 +1,113 @@ +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +/// all inputs must have the same length and must not alias. +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2")] +pub fn znx_sub_avx(res: &mut [i64], a: &[i64], b: &[i64]) { + #[cfg(debug_assertions)] + { + assert_eq!(res.len(), a.len()); + assert_eq!(res.len(), b.len()); + } + + use core::arch::x86_64::{__m256i, _mm256_loadu_si256, _mm256_storeu_si256, _mm256_sub_epi64}; + + let n: usize = res.len(); + + let span: usize = n >> 2; + + let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i; + let mut aa: *const __m256i = a.as_ptr() as *const __m256i; + let mut bb: *const __m256i = b.as_ptr() as *const __m256i; + + unsafe { + for _ in 0..span { + let sum: __m256i = _mm256_sub_epi64(_mm256_loadu_si256(aa), _mm256_loadu_si256(bb)); + _mm256_storeu_si256(rr, sum); + rr = rr.add(1); + aa = aa.add(1); + bb = bb.add(1); + } + } + + // tail + if !res.len().is_multiple_of(4) { + use poulpy_hal::reference::znx::znx_sub_ref; + + znx_sub_ref(&mut res[span << 2..], &a[span << 2..], &b[span << 2..]); + } +} + +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +/// all inputs must have the same length and must not alias. +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2")] +pub fn znx_sub_ab_inplace_avx(res: &mut [i64], a: &[i64]) { + #[cfg(debug_assertions)] + { + assert_eq!(res.len(), a.len()); + } + + use core::arch::x86_64::{__m256i, _mm256_loadu_si256, _mm256_storeu_si256, _mm256_sub_epi64}; + + let n: usize = res.len(); + + let span: usize = n >> 2; + + let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i; + let mut aa: *const __m256i = a.as_ptr() as *const __m256i; + + unsafe { + for _ in 0..span { + let sum: __m256i = _mm256_sub_epi64(_mm256_loadu_si256(rr), _mm256_loadu_si256(aa)); + _mm256_storeu_si256(rr, sum); + rr = rr.add(1); + aa = aa.add(1); + } + } + + // tail + if !res.len().is_multiple_of(4) { + use poulpy_hal::reference::znx::znx_sub_ab_inplace_ref; + + znx_sub_ab_inplace_ref(&mut res[span << 2..], &a[span << 2..]); + } +} + +/// # Safety +/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`); +/// all inputs must have the same length and must not alias. 
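+/// Computes `res = a - res` lane-wise; the tail falls back to the scalar reference implementation.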
+#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2")] +pub fn znx_sub_ba_inplace_avx(res: &mut [i64], a: &[i64]) { + #[cfg(debug_assertions)] + { + assert_eq!(res.len(), a.len()); + } + + use core::arch::x86_64::{__m256i, _mm256_loadu_si256, _mm256_storeu_si256, _mm256_sub_epi64}; + + let n: usize = res.len(); + + let span: usize = n >> 2; + + let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i; + let mut aa: *const __m256i = a.as_ptr() as *const __m256i; + + unsafe { + for _ in 0..span { + let sum: __m256i = _mm256_sub_epi64(_mm256_loadu_si256(aa), _mm256_loadu_si256(rr)); + _mm256_storeu_si256(rr, sum); + rr = rr.add(1); + aa = aa.add(1); + } + } + + // tail + if !res.len().is_multiple_of(4) { + use poulpy_hal::reference::znx::znx_sub_ba_inplace_ref; + + znx_sub_ba_inplace_ref(&mut res[span << 2..], &a[span << 2..]); + } +} diff --git a/poulpy-backend/src/cpu_fft64_avx/znx_avx/switch_ring.rs b/poulpy-backend/src/cpu_fft64_avx/znx_avx/switch_ring.rs new file mode 100644 index 0000000..28fdd4c --- /dev/null +++ b/poulpy-backend/src/cpu_fft64_avx/znx_avx/switch_ring.rs @@ -0,0 +1,87 @@ +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2")] +pub unsafe fn znx_switch_ring_avx(res: &mut [i64], a: &[i64]) { + unsafe { + use core::arch::x86_64::*; + + let (n_in, n_out) = (a.len(), res.len()); + + #[cfg(debug_assertions)] + { + assert!(n_in.is_power_of_two()); + assert!(n_in.max(n_out).is_multiple_of(n_in.min(n_out))) + } + + if n_in == n_out { + use poulpy_hal::reference::znx::znx_copy_ref; + + znx_copy_ref(res, a); + return; + } + + if n_in > n_out { + // Downsample: res[k] = a[k * gap_in], contiguous stores + let gap_in: usize = n_in / n_out; + + // index vector: [0*gap, 1*gap, 2*gap, 3*gap] * gap_in + let step: __m256i = _mm256_setr_epi64x(0, gap_in as i64, 2 * gap_in as i64, 3 * gap_in as i64); + + let span: usize = n_out >> 2; + let bump: __m256i = _mm256_set1_epi64x(4 * gap_in as i64); + + let mut res_4xi64: *mut __m256i = res.as_mut_ptr() as *mut __m256i; + let a_ptr: *const i64 = a.as_ptr(); + + let mut base: __m256i = _mm256_setzero_si256(); // starts at 0*gap + + for _ in 0..span { + // idx = base + step + let idx: __m256i = _mm256_add_epi64(base, step); + + // gather 4 spaced i64 (scale=8 bytes) + let v: __m256i = _mm256_i64gather_epi64(a_ptr, idx, 8); + + // store contiguously + _mm256_storeu_si256(res_4xi64, v); + + base = _mm256_add_epi64(base, bump); + res_4xi64 = res_4xi64.add(1); + } + } else { + // Upsample: res[k * gap_out] = a[k], i.e. res has holes; + + use poulpy_hal::reference::znx::znx_zero_ref; + let gap_out = n_out / n_in; + + // zero then scatter scalar stores + znx_zero_ref(res); + + let mut a_4xi64: *const __m256i = a.as_ptr() as *const __m256i; + + for i in (0..n_in).step_by(4) { + // Load contiguously 4 inputs + let v = _mm256_loadu_si256(a_4xi64); + + // extract 4 lanes (pextrq). This is still the best we can do on AVX2. 
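+ // AVX2 offers 64-bit gathers but no 64-bit scatter (scatters only arrive with
+ // AVX-512), so each lane is pulled into a general-purpose register and written
+ // back with four strided scalar stores, spaced `gap_out` elements apart.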
+ let x0: i64 = _mm256_extract_epi64(v, 0); + let x1: i64 = _mm256_extract_epi64(v, 1); + let x2: i64 = _mm256_extract_epi64(v, 2); + let x3: i64 = _mm256_extract_epi64(v, 3); + + // starting output pointer for this group + let mut p: *mut i64 = res.as_mut_ptr().add(i * gap_out); + + // four strided stores with pointer bump (avoid mul each time) + *p = x0; + p = p.add(gap_out); + *p = x1; + p = p.add(gap_out); + *p = x2; + p = p.add(gap_out); + *p = x3; + + a_4xi64 = a_4xi64.add(1) + } + } + } +} diff --git a/poulpy-backend/src/cpu_fft64_ref/mod.rs b/poulpy-backend/src/cpu_fft64_ref/mod.rs new file mode 100644 index 0000000..9f1be05 --- /dev/null +++ b/poulpy-backend/src/cpu_fft64_ref/mod.rs @@ -0,0 +1,12 @@ +mod module; +mod reim; +mod scratch; +mod svp; +mod vec_znx; +mod vec_znx_big; +mod vec_znx_dft; +mod vmp; +mod zn; +mod znx; + +pub struct FFT64Ref {} diff --git a/poulpy-backend/src/cpu_fft64_ref/module.rs b/poulpy-backend/src/cpu_fft64_ref/module.rs new file mode 100644 index 0000000..d9815e2 --- /dev/null +++ b/poulpy-backend/src/cpu_fft64_ref/module.rs @@ -0,0 +1,62 @@ +use std::ptr::NonNull; + +use poulpy_hal::{ + layouts::{Backend, Module}, + oep::ModuleNewImpl, + reference::fft64::reim::{ReimFFTTable, ReimIFFTTable}, +}; + +use crate::cpu_fft64_ref::FFT64Ref; + +#[repr(C)] +pub struct FFT64RefHandle { + table_fft: ReimFFTTable, + table_ifft: ReimIFFTTable, +} + +impl Backend for FFT64Ref { + type ScalarPrep = f64; + type ScalarBig = i64; + type Handle = FFT64RefHandle; + unsafe fn destroy(handle: NonNull) { + unsafe { + drop(Box::from_raw(handle.as_ptr())); + } + } + + fn layout_big_word_count() -> usize { + 1 + } + + fn layout_prep_word_count() -> usize { + 1 + } +} + +unsafe impl ModuleNewImpl for FFT64Ref { + fn new_impl(n: u64) -> Module { + let handle: FFT64RefHandle = FFT64RefHandle { + table_fft: ReimFFTTable::new(n as usize >> 1), + table_ifft: ReimIFFTTable::new(n as usize >> 1), + }; + // Leak Box to get a stable NonNull pointer + let ptr: NonNull = NonNull::from(Box::leak(Box::new(handle))); + unsafe { Module::from_nonnull(ptr, n) } + } +} + +pub trait FFT64ModuleHandle { + fn get_fft_table(&self) -> &ReimFFTTable; + fn get_ifft_table(&self) -> &ReimIFFTTable; +} + +impl FFT64ModuleHandle for Module { + fn get_fft_table(&self) -> &ReimFFTTable { + let h: &FFT64RefHandle = unsafe { &*self.ptr() }; + &h.table_fft + } + fn get_ifft_table(&self) -> &ReimIFFTTable { + let h: &FFT64RefHandle = unsafe { &*self.ptr() }; + &h.table_ifft + } +} diff --git a/poulpy-backend/src/cpu_fft64_ref/reim.rs b/poulpy-backend/src/cpu_fft64_ref/reim.rs new file mode 100644 index 0000000..411ee0a --- /dev/null +++ b/poulpy-backend/src/cpu_fft64_ref/reim.rs @@ -0,0 +1,175 @@ +use poulpy_hal::reference::fft64::{ + reim::{ + ReimAdd, ReimAddInplace, ReimAddMul, ReimCopy, ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimIFFTTable, ReimMul, + ReimMulInplace, ReimNegate, ReimNegateInplace, ReimSub, ReimSubABInplace, ReimSubBAInplace, ReimToZnx, ReimToZnxInplace, + ReimZero, fft_ref, ifft_ref, reim_add_inplace_ref, reim_add_ref, reim_addmul_ref, reim_copy_ref, reim_from_znx_i64_ref, + reim_mul_inplace_ref, reim_mul_ref, reim_negate_inplace_ref, reim_negate_ref, reim_sub_ab_inplace_ref, + reim_sub_ba_inplace_ref, reim_sub_ref, reim_to_znx_i64_inplace_ref, reim_to_znx_i64_ref, reim_zero_ref, + }, + reim4::{ + Reim4Extract1Blk, Reim4Mat1ColProd, Reim4Mat2Cols2ndColProd, Reim4Mat2ColsProd, Reim4Save1Blk, Reim4Save2Blks, + reim4_extract_1blk_from_reim_ref, reim4_save_1blk_to_reim_ref, 
reim4_save_2blk_to_reim_ref, + reim4_vec_mat1col_product_ref, reim4_vec_mat2cols_2ndcol_product_ref, reim4_vec_mat2cols_product_ref, + }, +}; + +use crate::FFT64Ref; + +impl ReimDFTExecute, f64> for FFT64Ref { + fn reim_dft_execute(table: &ReimFFTTable, data: &mut [f64]) { + fft_ref(table.m(), table.omg(), data); + } +} + +impl ReimDFTExecute, f64> for FFT64Ref { + fn reim_dft_execute(table: &ReimIFFTTable, data: &mut [f64]) { + ifft_ref(table.m(), table.omg(), data); + } +} + +impl ReimFromZnx for FFT64Ref { + #[inline(always)] + fn reim_from_znx(res: &mut [f64], a: &[i64]) { + reim_from_znx_i64_ref(res, a); + } +} + +impl ReimToZnx for FFT64Ref { + #[inline(always)] + fn reim_to_znx(res: &mut [i64], divisor: f64, a: &[f64]) { + reim_to_znx_i64_ref(res, divisor, a); + } +} + +impl ReimToZnxInplace for FFT64Ref { + #[inline(always)] + fn reim_to_znx_inplace(res: &mut [f64], divisor: f64) { + reim_to_znx_i64_inplace_ref(res, divisor); + } +} + +impl ReimAdd for FFT64Ref { + #[inline(always)] + fn reim_add(res: &mut [f64], a: &[f64], b: &[f64]) { + reim_add_ref(res, a, b); + } +} + +impl ReimAddInplace for FFT64Ref { + #[inline(always)] + fn reim_add_inplace(res: &mut [f64], a: &[f64]) { + reim_add_inplace_ref(res, a); + } +} + +impl ReimSub for FFT64Ref { + #[inline(always)] + fn reim_sub(res: &mut [f64], a: &[f64], b: &[f64]) { + reim_sub_ref(res, a, b); + } +} + +impl ReimSubABInplace for FFT64Ref { + #[inline(always)] + fn reim_sub_ab_inplace(res: &mut [f64], a: &[f64]) { + reim_sub_ab_inplace_ref(res, a); + } +} + +impl ReimSubBAInplace for FFT64Ref { + #[inline(always)] + fn reim_sub_ba_inplace(res: &mut [f64], a: &[f64]) { + reim_sub_ba_inplace_ref(res, a); + } +} + +impl ReimNegate for FFT64Ref { + #[inline(always)] + fn reim_negate(res: &mut [f64], a: &[f64]) { + reim_negate_ref(res, a); + } +} + +impl ReimNegateInplace for FFT64Ref { + #[inline(always)] + fn reim_negate_inplace(res: &mut [f64]) { + reim_negate_inplace_ref(res); + } +} + +impl ReimMul for FFT64Ref { + #[inline(always)] + fn reim_mul(res: &mut [f64], a: &[f64], b: &[f64]) { + reim_mul_ref(res, a, b); + } +} + +impl ReimMulInplace for FFT64Ref { + #[inline(always)] + fn reim_mul_inplace(res: &mut [f64], a: &[f64]) { + reim_mul_inplace_ref(res, a); + } +} + +impl ReimAddMul for FFT64Ref { + #[inline(always)] + fn reim_addmul(res: &mut [f64], a: &[f64], b: &[f64]) { + reim_addmul_ref(res, a, b); + } +} + +impl ReimCopy for FFT64Ref { + #[inline(always)] + fn reim_copy(res: &mut [f64], a: &[f64]) { + reim_copy_ref(res, a); + } +} + +impl ReimZero for FFT64Ref { + #[inline(always)] + fn reim_zero(res: &mut [f64]) { + reim_zero_ref(res); + } +} + +impl Reim4Extract1Blk for FFT64Ref { + #[inline(always)] + fn reim4_extract_1blk(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) { + reim4_extract_1blk_from_reim_ref(m, rows, blk, dst, src); + } +} + +impl Reim4Save1Blk for FFT64Ref { + #[inline(always)] + fn reim4_save_1blk(m: usize, blk: usize, dst: &mut [f64], src: &[f64]) { + reim4_save_1blk_to_reim_ref::(m, blk, dst, src); + } +} + +impl Reim4Save2Blks for FFT64Ref { + #[inline(always)] + fn reim4_save_2blks(m: usize, blk: usize, dst: &mut [f64], src: &[f64]) { + reim4_save_2blk_to_reim_ref::(m, blk, dst, src); + } +} + +impl Reim4Mat1ColProd for FFT64Ref { + #[inline(always)] + fn reim4_mat1col_prod(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]) { + reim4_vec_mat1col_product_ref(nrows, dst, u, v); + } +} + +impl Reim4Mat2ColsProd for FFT64Ref { + #[inline(always)] + fn reim4_mat2cols_prod(nrows: 
usize, dst: &mut [f64], u: &[f64], v: &[f64]) { + reim4_vec_mat2cols_product_ref(nrows, dst, u, v); + } +} + +impl Reim4Mat2Cols2ndColProd for FFT64Ref { + #[inline(always)] + fn reim4_mat2cols_2ndcol_prod(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]) { + reim4_vec_mat2cols_2ndcol_product_ref(nrows, dst, u, v); + } +} diff --git a/poulpy-backend/src/cpu_fft64_ref/scratch.rs b/poulpy-backend/src/cpu_fft64_ref/scratch.rs new file mode 100644 index 0000000..41eae29 --- /dev/null +++ b/poulpy-backend/src/cpu_fft64_ref/scratch.rs @@ -0,0 +1,261 @@ +use std::marker::PhantomData; + +use poulpy_hal::{ + DEFAULTALIGN, alloc_aligned, + api::ScratchFromBytes, + layouts::{Backend, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat}, + oep::{ + ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, SvpPPolAllocBytesImpl, + TakeMatZnxImpl, TakeScalarZnxImpl, TakeSliceImpl, TakeSvpPPolImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl, + TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, TakeVmpPMatImpl, VecZnxBigAllocBytesImpl, + VecZnxDftAllocBytesImpl, VmpPMatAllocBytesImpl, + }, +}; + +use crate::cpu_fft64_ref::FFT64Ref; + +unsafe impl ScratchOwnedAllocImpl for FFT64Ref { + fn scratch_owned_alloc_impl(size: usize) -> ScratchOwned { + let data: Vec = alloc_aligned(size); + ScratchOwned { + data, + _phantom: PhantomData, + } + } +} + +unsafe impl ScratchOwnedBorrowImpl for FFT64Ref +where + B: ScratchFromBytesImpl, +{ + fn scratch_owned_borrow_impl(scratch: &mut ScratchOwned) -> &mut Scratch { + Scratch::from_bytes(&mut scratch.data) + } +} + +unsafe impl ScratchFromBytesImpl for FFT64Ref { + fn scratch_from_bytes_impl(data: &mut [u8]) -> &mut Scratch { + unsafe { &mut *(data as *mut [u8] as *mut Scratch) } + } +} + +unsafe impl ScratchAvailableImpl for FFT64Ref { + fn scratch_available_impl(scratch: &Scratch) -> usize { + let ptr: *const u8 = scratch.data.as_ptr(); + let self_len: usize = scratch.data.len(); + let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN); + self_len.saturating_sub(aligned_offset) + } +} + +unsafe impl TakeSliceImpl for FFT64Ref +where + B: ScratchFromBytesImpl, +{ + fn take_slice_impl(scratch: &mut Scratch, len: usize) -> (&mut [T], &mut Scratch) { + let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, len * std::mem::size_of::()); + + unsafe { + ( + &mut *(std::ptr::slice_from_raw_parts_mut(take_slice.as_mut_ptr() as *mut T, len)), + Scratch::from_bytes(rem_slice), + ) + } + } +} + +unsafe impl TakeScalarZnxImpl for FFT64Ref +where + B: ScratchFromBytesImpl, +{ + fn take_scalar_znx_impl(scratch: &mut Scratch, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Scratch) { + let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, ScalarZnx::alloc_bytes(n, cols)); + ( + ScalarZnx::from_data(take_slice, n, cols), + Scratch::from_bytes(rem_slice), + ) + } +} + +unsafe impl TakeSvpPPolImpl for FFT64Ref +where + B: SvpPPolAllocBytesImpl + ScratchFromBytesImpl, +{ + fn take_svp_ppol_impl(scratch: &mut Scratch, n: usize, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Scratch) { + let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, B::svp_ppol_alloc_bytes_impl(n, cols)); + ( + SvpPPol::from_data(take_slice, n, cols), + Scratch::from_bytes(rem_slice), + ) + } +} + +unsafe impl TakeVecZnxImpl for FFT64Ref +where + B: ScratchFromBytesImpl, +{ + fn take_vec_znx_impl(scratch: &mut Scratch, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Scratch) 
{ + let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, VecZnx::alloc_bytes(n, cols, size)); + ( + VecZnx::from_data(take_slice, n, cols, size), + Scratch::from_bytes(rem_slice), + ) + } +} + +unsafe impl TakeVecZnxBigImpl for FFT64Ref +where + B: VecZnxBigAllocBytesImpl + ScratchFromBytesImpl, +{ + fn take_vec_znx_big_impl( + scratch: &mut Scratch, + n: usize, + cols: usize, + size: usize, + ) -> (VecZnxBig<&mut [u8], B>, &mut Scratch) { + let (take_slice, rem_slice) = take_slice_aligned( + &mut scratch.data, + B::vec_znx_big_alloc_bytes_impl(n, cols, size), + ); + ( + VecZnxBig::from_data(take_slice, n, cols, size), + Scratch::from_bytes(rem_slice), + ) + } +} + +unsafe impl TakeVecZnxDftImpl for FFT64Ref +where + B: VecZnxDftAllocBytesImpl + ScratchFromBytesImpl, +{ + fn take_vec_znx_dft_impl( + scratch: &mut Scratch, + n: usize, + cols: usize, + size: usize, + ) -> (VecZnxDft<&mut [u8], B>, &mut Scratch) { + let (take_slice, rem_slice) = take_slice_aligned( + &mut scratch.data, + B::vec_znx_dft_alloc_bytes_impl(n, cols, size), + ); + + ( + VecZnxDft::from_data(take_slice, n, cols, size), + Scratch::from_bytes(rem_slice), + ) + } +} + +unsafe impl TakeVecZnxDftSliceImpl for FFT64Ref +where + B: VecZnxDftAllocBytesImpl + ScratchFromBytesImpl + TakeVecZnxDftImpl, +{ + fn take_vec_znx_dft_slice_impl( + scratch: &mut Scratch, + len: usize, + n: usize, + cols: usize, + size: usize, + ) -> (Vec>, &mut Scratch) { + let mut scratch: &mut Scratch = scratch; + let mut slice: Vec> = Vec::with_capacity(len); + for _ in 0..len { + let (znx, new_scratch) = B::take_vec_znx_dft_impl(scratch, n, cols, size); + scratch = new_scratch; + slice.push(znx); + } + (slice, scratch) + } +} + +unsafe impl TakeVecZnxSliceImpl for FFT64Ref +where + B: ScratchFromBytesImpl + TakeVecZnxImpl, +{ + fn take_vec_znx_slice_impl( + scratch: &mut Scratch, + len: usize, + n: usize, + cols: usize, + size: usize, + ) -> (Vec>, &mut Scratch) { + let mut scratch: &mut Scratch = scratch; + let mut slice: Vec> = Vec::with_capacity(len); + for _ in 0..len { + let (znx, new_scratch) = B::take_vec_znx_impl(scratch, n, cols, size); + scratch = new_scratch; + slice.push(znx); + } + (slice, scratch) + } +} + +unsafe impl TakeVmpPMatImpl for FFT64Ref +where + B: VmpPMatAllocBytesImpl + ScratchFromBytesImpl, +{ + fn take_vmp_pmat_impl( + scratch: &mut Scratch, + n: usize, + rows: usize, + cols_in: usize, + cols_out: usize, + size: usize, + ) -> (VmpPMat<&mut [u8], B>, &mut Scratch) { + let (take_slice, rem_slice) = take_slice_aligned( + &mut scratch.data, + B::vmp_pmat_alloc_bytes_impl(n, rows, cols_in, cols_out, size), + ); + ( + VmpPMat::from_data(take_slice, n, rows, cols_in, cols_out, size), + Scratch::from_bytes(rem_slice), + ) + } +} + +unsafe impl TakeMatZnxImpl for FFT64Ref +where + B: ScratchFromBytesImpl, +{ + fn take_mat_znx_impl( + scratch: &mut Scratch, + n: usize, + rows: usize, + cols_in: usize, + cols_out: usize, + size: usize, + ) -> (MatZnx<&mut [u8]>, &mut Scratch) { + let (take_slice, rem_slice) = take_slice_aligned( + &mut scratch.data, + MatZnx::alloc_bytes(n, rows, cols_in, cols_out, size), + ); + ( + MatZnx::from_data(take_slice, n, rows, cols_in, cols_out, size), + Scratch::from_bytes(rem_slice), + ) + } +} + +fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]) { + let ptr: *mut u8 = data.as_mut_ptr(); + let self_len: usize = data.len(); + + let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN); + let aligned_len: usize = self_len.saturating_sub(aligned_offset); + 
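+ // Partition the backing bytes as [padding up to DEFAULTALIGN][take_len][remainder].
+ // The branch below only succeeds when `take_len` fits in the aligned region; the
+ // remainder starts right after the taken slice and is re-aligned lazily on the
+ // next `take_*` call via `align_offset`.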
+ if let Some(rem_len) = aligned_len.checked_sub(take_len) { + unsafe { + let rem_ptr: *mut u8 = ptr.add(aligned_offset).add(take_len); + let rem_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(rem_ptr, rem_len); + + let take_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(ptr.add(aligned_offset), take_len); + + (take_slice, rem_slice) + } + } else { + panic!( + "Attempted to take {} from scratch with {} aligned bytes left", + take_len, aligned_len, + ); + } +} diff --git a/poulpy-backend/src/cpu_fft64_ref/svp.rs b/poulpy-backend/src/cpu_fft64_ref/svp.rs new file mode 100644 index 0000000..06dad9e --- /dev/null +++ b/poulpy-backend/src/cpu_fft64_ref/svp.rs @@ -0,0 +1,66 @@ +use poulpy_hal::{ + layouts::{Backend, Module, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef}, + oep::{ + SvpApplyDftToDftImpl, SvpApplyDftToDftInplaceImpl, SvpPPolAllocBytesImpl, SvpPPolAllocImpl, SvpPPolFromBytesImpl, + SvpPrepareImpl, + }, + reference::fft64::svp::{svp_apply_dft_to_dft, svp_apply_dft_to_dft_inplace, svp_prepare}, +}; + +use crate::cpu_fft64_ref::{FFT64Ref, module::FFT64ModuleHandle}; + +unsafe impl SvpPPolFromBytesImpl for FFT64Ref { + fn svp_ppol_from_bytes_impl(n: usize, cols: usize, bytes: Vec) -> SvpPPolOwned { + SvpPPolOwned::from_bytes(n, cols, bytes) + } +} + +unsafe impl SvpPPolAllocImpl for FFT64Ref { + fn svp_ppol_alloc_impl(n: usize, cols: usize) -> SvpPPolOwned { + SvpPPolOwned::alloc(n, cols) + } +} + +unsafe impl SvpPPolAllocBytesImpl for FFT64Ref { + fn svp_ppol_alloc_bytes_impl(n: usize, cols: usize) -> usize { + Self::layout_prep_word_count() * n * cols * size_of::() + } +} + +unsafe impl SvpPrepareImpl for FFT64Ref { + fn svp_prepare_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: SvpPPolToMut, + A: ScalarZnxToRef, + { + svp_prepare(module.get_fft_table(), res, res_col, a, a_col); + } +} + +unsafe impl SvpApplyDftToDftImpl for FFT64Ref { + fn svp_apply_dft_to_dft_impl( + _module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + b: &B, + b_col: usize, + ) where + R: VecZnxDftToMut, + A: SvpPPolToRef, + B: VecZnxDftToRef, + { + svp_apply_dft_to_dft(res, res_col, a, a_col, b, b_col); + } +} + +unsafe impl SvpApplyDftToDftInplaceImpl for FFT64Ref { + fn svp_apply_dft_to_dft_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: SvpPPolToRef, + { + svp_apply_dft_to_dft_inplace(res, res_col, a, a_col); + } +} diff --git a/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs b/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs new file mode 100644 index 0000000..ee213a9 --- /dev/null +++ b/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs @@ -0,0 +1,538 @@ +use poulpy_hal::{ + api::{ + TakeSlice, VecZnxAutomorphismInplaceTmpBytes, VecZnxMergeRingsTmpBytes, VecZnxMulXpMinusOneInplaceTmpBytes, + VecZnxNormalizeTmpBytes, VecZnxRotateInplaceTmpBytes, VecZnxSplitRingTmpBytes, + }, + layouts::{Module, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef}, + oep::{ + TakeSliceImpl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, VecZnxAddScalarImpl, VecZnxAddScalarInplaceImpl, + VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxAutomorphismInplaceTmpBytesImpl, VecZnxCopyImpl, + VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxLshImpl, VecZnxLshInplaceImpl, VecZnxLshTmpBytesImpl, + VecZnxMergeRingsImpl, VecZnxMergeRingsTmpBytesImpl, VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl, + 
VecZnxMulXpMinusOneInplaceTmpBytesImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl, + VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl, + VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl, + VecZnxSplitRingTmpBytesImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarImpl, + VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl, + }, + reference::vec_znx::{ + vec_znx_add, vec_znx_add_inplace, vec_znx_add_normal_ref, vec_znx_add_scalar, vec_znx_add_scalar_inplace, + vec_znx_automorphism, vec_znx_automorphism_inplace, vec_znx_automorphism_inplace_tmp_bytes, vec_znx_copy, + vec_znx_fill_normal_ref, vec_znx_fill_uniform_ref, vec_znx_lsh, vec_znx_lsh_inplace, vec_znx_lsh_tmp_bytes, + vec_znx_merge_rings, vec_znx_merge_rings_tmp_bytes, vec_znx_mul_xp_minus_one, vec_znx_mul_xp_minus_one_inplace, + vec_znx_mul_xp_minus_one_inplace_tmp_bytes, vec_znx_negate, vec_znx_negate_inplace, vec_znx_normalize, + vec_znx_normalize_inplace, vec_znx_normalize_tmp_bytes, vec_znx_rotate, vec_znx_rotate_inplace, + vec_znx_rotate_inplace_tmp_bytes, vec_znx_rsh, vec_znx_rsh_inplace, vec_znx_rsh_tmp_bytes, vec_znx_split_ring, + vec_znx_split_ring_tmp_bytes, vec_znx_sub, vec_znx_sub_ab_inplace, vec_znx_sub_ba_inplace, vec_znx_sub_scalar, + vec_znx_sub_scalar_inplace, vec_znx_switch_ring, + }, + source::Source, +}; + +use crate::cpu_fft64_ref::FFT64Ref; + +unsafe impl VecZnxNormalizeTmpBytesImpl for FFT64Ref { + fn vec_znx_normalize_tmp_bytes_impl(module: &Module) -> usize { + vec_znx_normalize_tmp_bytes(module.n()) + } +} + +unsafe impl VecZnxNormalizeImpl for FFT64Ref +where + Self: TakeSliceImpl + VecZnxNormalizeTmpBytesImpl, +{ + fn vec_znx_normalize_impl( + module: &Module, + basek: usize, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxToMut, + A: VecZnxToRef, + { + let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::()); + vec_znx_normalize::(basek, res, res_col, a, a_col, carry); + } +} + +unsafe impl VecZnxNormalizeInplaceImpl for FFT64Ref +where + Self: TakeSliceImpl + VecZnxNormalizeTmpBytesImpl, +{ + fn vec_znx_normalize_inplace_impl( + module: &Module, + basek: usize, + res: &mut R, + res_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxToMut, + { + let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::()); + vec_znx_normalize_inplace::(basek, res, res_col, carry); + } +} + +unsafe impl VecZnxAddImpl for FFT64Ref { + fn vec_znx_add_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + B: VecZnxToRef, + { + vec_znx_add::(res, res_col, a, a_col, b, b_col); + } +} + +unsafe impl VecZnxAddInplaceImpl for FFT64Ref { + fn vec_znx_add_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + vec_znx_add_inplace::(res, res_col, a, a_col); + } +} + +unsafe impl VecZnxAddScalarInplaceImpl for FFT64Ref { + fn vec_znx_add_scalar_inplace_impl( + _module: &Module, + res: &mut R, + res_col: usize, + res_limb: usize, + a: &A, + a_col: usize, + ) where + R: VecZnxToMut, + A: ScalarZnxToRef, + { + vec_znx_add_scalar_inplace::(res, res_col, res_limb, a, a_col); + } +} + +unsafe impl VecZnxAddScalarImpl for FFT64Ref { + fn vec_znx_add_scalar_impl( + _module: &Module, + res: &mut R, + 
res_col: usize, + a: &A, + a_col: usize, + b: &B, + b_col: usize, + b_limb: usize, + ) where + R: VecZnxToMut, + A: ScalarZnxToRef, + B: VecZnxToRef, + { + vec_znx_add_scalar::(res, res_col, a, a_col, b, b_col, b_limb); + } +} + +unsafe impl VecZnxSubImpl for FFT64Ref { + fn vec_znx_sub_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + B: VecZnxToRef, + { + vec_znx_sub::(res, res_col, a, a_col, b, b_col); + } +} + +unsafe impl VecZnxSubABInplaceImpl for FFT64Ref { + fn vec_znx_sub_ab_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + vec_znx_sub_ab_inplace::(res, res_col, a, a_col); + } +} + +unsafe impl VecZnxSubBAInplaceImpl for FFT64Ref { + fn vec_znx_sub_ba_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + vec_znx_sub_ba_inplace::(res, res_col, a, a_col); + } +} + +unsafe impl VecZnxSubScalarImpl for FFT64Ref { + fn vec_znx_sub_scalar_impl( + _module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + b: &B, + b_col: usize, + b_limb: usize, + ) where + R: VecZnxToMut, + A: ScalarZnxToRef, + B: VecZnxToRef, + { + vec_znx_sub_scalar::(res, res_col, a, a_col, b, b_col, b_limb); + } +} + +unsafe impl VecZnxSubScalarInplaceImpl for FFT64Ref { + fn vec_znx_sub_scalar_inplace_impl( + _module: &Module, + res: &mut R, + res_col: usize, + res_limb: usize, + a: &A, + a_col: usize, + ) where + R: VecZnxToMut, + A: ScalarZnxToRef, + { + vec_znx_sub_scalar_inplace::(res, res_col, res_limb, a, a_col); + } +} + +unsafe impl VecZnxNegateImpl for FFT64Ref { + fn vec_znx_negate_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + vec_znx_negate::(res, res_col, a, a_col); + } +} + +unsafe impl VecZnxNegateInplaceImpl for FFT64Ref { + fn vec_znx_negate_inplace_impl(_module: &Module, res: &mut R, res_col: usize) + where + R: VecZnxToMut, + { + vec_znx_negate_inplace::(res, res_col); + } +} + +unsafe impl VecZnxLshTmpBytesImpl for FFT64Ref { + fn vec_znx_lsh_tmp_bytes_impl(module: &Module) -> usize { + vec_znx_lsh_tmp_bytes(module.n()) + } +} + +unsafe impl VecZnxRshTmpBytesImpl for FFT64Ref { + fn vec_znx_rsh_tmp_bytes_impl(module: &Module) -> usize { + vec_znx_rsh_tmp_bytes(module.n()) + } +} + +unsafe impl VecZnxLshImpl for FFT64Ref +where + Module: VecZnxNormalizeTmpBytes, + Scratch: TakeSlice, +{ + fn vec_znx_lsh_inplace_impl( + module: &Module, + basek: usize, + k: usize, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxToMut, + A: VecZnxToRef, + { + let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::()); + vec_znx_lsh::<_, _, Self>(basek, k, res, res_col, a, a_col, carry); + } +} + +unsafe impl VecZnxLshInplaceImpl for FFT64Ref +where + Module: VecZnxNormalizeTmpBytes, + Scratch: TakeSlice, +{ + fn vec_znx_lsh_inplace_impl( + module: &Module, + basek: usize, + k: usize, + a: &mut A, + a_col: usize, + scratch: &mut Scratch, + ) where + A: VecZnxToMut, + { + let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::()); + vec_znx_lsh_inplace::<_, Self>(basek, k, a, a_col, carry); + } +} + +unsafe impl VecZnxRshImpl for FFT64Ref +where + Module: VecZnxNormalizeTmpBytes, + Scratch: TakeSlice, +{ + fn vec_znx_rsh_inplace_impl( + module: &Module, + basek: usize, 
+ k: usize, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxToMut, + A: VecZnxToRef, + { + let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::()); + vec_znx_rsh::<_, _, Self>(basek, k, res, res_col, a, a_col, carry); + } +} + +unsafe impl VecZnxRshInplaceImpl for FFT64Ref +where + Module: VecZnxNormalizeTmpBytes, + Scratch: TakeSlice, +{ + fn vec_znx_rsh_inplace_impl( + module: &Module, + basek: usize, + k: usize, + a: &mut A, + a_col: usize, + scratch: &mut Scratch, + ) where + A: VecZnxToMut, + { + let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::()); + vec_znx_rsh_inplace::<_, Self>(basek, k, a, a_col, carry); + } +} + +unsafe impl VecZnxRotateImpl for FFT64Ref { + fn vec_znx_rotate_impl(_module: &Module, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + vec_znx_rotate::(p, res, res_col, a, a_col); + } +} + +unsafe impl VecZnxRotateInplaceTmpBytesImpl for FFT64Ref +where + Scratch: TakeSlice, +{ + fn vec_znx_rotate_inplace_tmp_bytes_impl(module: &Module) -> usize { + vec_znx_rotate_inplace_tmp_bytes(module.n()) + } +} + +unsafe impl VecZnxRotateInplaceImpl for FFT64Ref +where + Scratch: TakeSlice, + Self: VecZnxRotateInplaceTmpBytesImpl, +{ + fn vec_znx_rotate_inplace_impl(module: &Module, p: i64, res: &mut R, res_col: usize, scratch: &mut Scratch) + where + R: VecZnxToMut, + { + let (tmp, _) = scratch.take_slice(module.vec_znx_rotate_inplace_tmp_bytes() / size_of::()); + vec_znx_rotate_inplace::(p, res, res_col, tmp); + } +} + +unsafe impl VecZnxAutomorphismImpl for FFT64Ref { + fn vec_znx_automorphism_impl(_module: &Module, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + vec_znx_automorphism::(p, res, res_col, a, a_col); + } +} + +unsafe impl VecZnxAutomorphismInplaceTmpBytesImpl for FFT64Ref { + fn vec_znx_automorphism_inplace_tmp_bytes_impl(module: &Module) -> usize { + vec_znx_automorphism_inplace_tmp_bytes(module.n()) + } +} + +unsafe impl VecZnxAutomorphismInplaceImpl for FFT64Ref +where + Scratch: TakeSlice, + Self: VecZnxAutomorphismInplaceTmpBytesImpl, +{ + fn vec_znx_automorphism_inplace_impl( + module: &Module, + p: i64, + res: &mut R, + res_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxToMut, + { + let (tmp, _) = scratch.take_slice(module.vec_znx_automorphism_inplace_tmp_bytes() / size_of::()); + vec_znx_automorphism_inplace::(p, res, res_col, tmp); + } +} + +unsafe impl VecZnxMulXpMinusOneImpl for FFT64Ref { + fn vec_znx_mul_xp_minus_one_impl(_module: &Module, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + vec_znx_mul_xp_minus_one::(p, res, res_col, a, a_col); + } +} + +unsafe impl VecZnxMulXpMinusOneInplaceTmpBytesImpl for FFT64Ref +where + Scratch: TakeSlice, + Self: VecZnxMulXpMinusOneImpl, +{ + fn vec_znx_mul_xp_minus_one_inplace_tmp_bytes_impl(module: &Module) -> usize { + vec_znx_mul_xp_minus_one_inplace_tmp_bytes(module.n()) + } +} + +unsafe impl VecZnxMulXpMinusOneInplaceImpl for FFT64Ref { + fn vec_znx_mul_xp_minus_one_inplace_impl( + module: &Module, + p: i64, + res: &mut R, + res_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxToMut, + { + let (tmp, _) = scratch.take_slice(module.vec_znx_mul_xp_minus_one_inplace_tmp_bytes() / size_of::()); + vec_znx_mul_xp_minus_one_inplace::(p, res, res_col, tmp); + } +} + +unsafe impl 
VecZnxSplitRingTmpBytesImpl for FFT64Ref { + fn vec_znx_split_ring_tmp_bytes_impl(module: &Module) -> usize { + vec_znx_split_ring_tmp_bytes(module.n()) + } +} + +unsafe impl VecZnxSplitRingImpl for FFT64Ref +where + Module: VecZnxSplitRingTmpBytes, + Scratch: TakeSlice, +{ + fn vec_znx_split_ring_impl( + module: &Module, + res: &mut [R], + res_col: usize, + a: &A, + a_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxToMut, + A: VecZnxToRef, + { + let (tmp, _) = scratch.take_slice(module.vec_znx_split_ring_tmp_bytes() / size_of::()); + vec_znx_split_ring::(res, res_col, a, a_col, tmp); + } +} + +unsafe impl VecZnxMergeRingsTmpBytesImpl for FFT64Ref { + fn vec_znx_merge_rings_tmp_bytes_impl(module: &Module) -> usize { + vec_znx_merge_rings_tmp_bytes(module.n()) + } +} + +unsafe impl VecZnxMergeRingsImpl for FFT64Ref +where + Module: VecZnxMergeRingsTmpBytes, +{ + fn vec_znx_merge_rings_impl( + module: &Module, + res: &mut R, + res_col: usize, + a: &[A], + a_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxToMut, + A: VecZnxToRef, + { + let (tmp, _) = scratch.take_slice(module.vec_znx_merge_rings_tmp_bytes() / size_of::()); + vec_znx_merge_rings::(res, res_col, a, a_col, tmp); + } +} + +unsafe impl VecZnxSwitchRingImpl for FFT64Ref +where + Self: VecZnxCopyImpl, +{ + fn vec_znx_switch_ring_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + vec_znx_switch_ring::(res, res_col, a, a_col); + } +} + +unsafe impl VecZnxCopyImpl for FFT64Ref { + fn vec_znx_copy_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + vec_znx_copy::(res, res_col, a, a_col) + } +} + +unsafe impl VecZnxFillUniformImpl for FFT64Ref { + fn vec_znx_fill_uniform_impl(_module: &Module, basek: usize, res: &mut R, res_col: usize, source: &mut Source) + where + R: VecZnxToMut, + { + vec_znx_fill_uniform_ref(basek, res, res_col, source) + } +} + +unsafe impl VecZnxFillNormalImpl for FFT64Ref { + fn vec_znx_fill_normal_impl( + _module: &Module, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + sigma: f64, + bound: f64, + ) where + R: VecZnxToMut, + { + vec_znx_fill_normal_ref(basek, res, res_col, k, sigma, bound, source); + } +} + +unsafe impl VecZnxAddNormalImpl for FFT64Ref { + fn vec_znx_add_normal_impl( + _module: &Module, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + sigma: f64, + bound: f64, + ) where + R: VecZnxToMut, + { + vec_znx_add_normal_ref(basek, res, res_col, k, sigma, bound, source); + } +} diff --git a/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs b/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs new file mode 100644 index 0000000..d5c4960 --- /dev/null +++ b/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs @@ -0,0 +1,332 @@ +use crate::cpu_fft64_ref::FFT64Ref; +use poulpy_hal::{ + api::{TakeSlice, VecZnxBigAutomorphismInplaceTmpBytes, VecZnxBigNormalizeTmpBytes}, + layouts::{ + Backend, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef, + ZnxInfos, ZnxView, ZnxViewMut, + }, + oep::{ + TakeSliceImpl, VecZnxBigAddImpl, VecZnxBigAddInplaceImpl, VecZnxBigAddNormalImpl, VecZnxBigAddSmallImpl, + VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl, + VecZnxBigAutomorphismInplaceImpl, VecZnxBigAutomorphismInplaceTmpBytesImpl, VecZnxBigFromBytesImpl, + VecZnxBigFromSmallImpl, 
VecZnxBigNegateImpl, VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl, + VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubABInplaceImpl, VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl, + VecZnxBigSubSmallAImpl, VecZnxBigSubSmallAInplaceImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl, + }, + reference::{ + fft64::vec_znx_big::{ + vec_znx_big_add, vec_znx_big_add_inplace, vec_znx_big_add_normal_ref, vec_znx_big_add_small, + vec_znx_big_add_small_inplace, vec_znx_big_automorphism, vec_znx_big_automorphism_inplace, + vec_znx_big_automorphism_inplace_tmp_bytes, vec_znx_big_negate, vec_znx_big_negate_inplace, vec_znx_big_normalize, + vec_znx_big_normalize_tmp_bytes, vec_znx_big_sub, vec_znx_big_sub_ab_inplace, vec_znx_big_sub_ba_inplace, + vec_znx_big_sub_small_a, vec_znx_big_sub_small_a_inplace, vec_znx_big_sub_small_b, vec_znx_big_sub_small_b_inplace, + }, + znx::{znx_copy_ref, znx_zero_ref}, + }, + source::Source, +}; + +unsafe impl VecZnxBigAllocBytesImpl for FFT64Ref { + fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize { + Self::layout_big_word_count() * n * cols * size * size_of::() + } +} + +unsafe impl VecZnxBigAllocImpl for FFT64Ref { + fn vec_znx_big_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxBigOwned { + VecZnxBig::alloc(n, cols, size) + } +} + +unsafe impl VecZnxBigFromBytesImpl for FFT64Ref { + fn vec_znx_big_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec) -> VecZnxBigOwned { + VecZnxBig::from_bytes(n, cols, size, bytes) + } +} + +unsafe impl VecZnxBigFromSmallImpl for FFT64Ref { + fn vec_znx_big_from_small_impl(res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef, + { + let mut res: VecZnxBig<&mut [u8], FFT64Ref> = res.to_mut(); + let a: VecZnx<&[u8]> = a.to_ref(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.n(), a.n()); + } + + let res_size: usize = res.size(); + let a_size: usize = a.size(); + + let min_size: usize = res_size.min(a_size); + + for j in 0..min_size { + znx_copy_ref(res.at_mut(res_col, j), a.at(a_col, j)); + } + + for j in min_size..res_size { + znx_zero_ref(res.at_mut(res_col, j)); + } + } +} + +unsafe impl VecZnxBigAddNormalImpl for FFT64Ref { + fn add_normal_impl>( + _module: &Module, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + sigma: f64, + bound: f64, + ) { + vec_znx_big_add_normal_ref(basek, res, res_col, k, sigma, bound, source); + } +} + +unsafe impl VecZnxBigAddImpl for FFT64Ref { + /// Adds `a` to `b` and stores the result on `c`. + fn vec_znx_big_add_impl( + _module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + b: &B, + b_col: usize, + ) where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + B: VecZnxBigToRef, + { + vec_znx_big_add(res, res_col, a, a_col, b, b_col); + } +} + +unsafe impl VecZnxBigAddInplaceImpl for FFT64Ref { + /// Adds `a` to `b` and stores the result on `b`. + fn vec_znx_big_add_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + { + vec_znx_big_add_inplace(res, res_col, a, a_col); + } +} + +unsafe impl VecZnxBigAddSmallImpl for FFT64Ref { + /// Adds `a` to `b` and stores the result on `c`. 
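+ /// (`c` refers to the `res` argument; `b` is the small `VecZnx` operand added to the big `a`.)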
+ fn vec_znx_big_add_small_impl( + _module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + b: &B, + b_col: usize, + ) where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + B: VecZnxToRef, + { + vec_znx_big_add_small(res, res_col, a, a_col, b, b_col); + } +} + +unsafe impl VecZnxBigAddSmallInplaceImpl for FFT64Ref { + /// Adds `a` to `b` and stores the result on `b`. + fn vec_znx_big_add_small_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef, + { + vec_znx_big_add_small_inplace(res, res_col, a, a_col); + } +} + +unsafe impl VecZnxBigSubImpl for FFT64Ref { + /// Subtracts `a` to `b` and stores the result on `c`. + fn vec_znx_big_sub_impl( + _module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + b: &B, + b_col: usize, + ) where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + B: VecZnxBigToRef, + { + vec_znx_big_sub(res, res_col, a, a_col, b, b_col); + } +} + +unsafe impl VecZnxBigSubABInplaceImpl for FFT64Ref { + /// Subtracts `a` from `b` and stores the result on `b`. + fn vec_znx_big_sub_ab_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + { + vec_znx_big_sub_ab_inplace(res, res_col, a, a_col); + } +} + +unsafe impl VecZnxBigSubBAInplaceImpl for FFT64Ref { + /// Subtracts `b` from `a` and stores the result on `b`. + fn vec_znx_big_sub_ba_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + { + vec_znx_big_sub_ba_inplace(res, res_col, a, a_col); + } +} + +unsafe impl VecZnxBigSubSmallAImpl for FFT64Ref { + /// Subtracts `b` from `a` and stores the result on `c`. + fn vec_znx_big_sub_small_a_impl( + _module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + b: &B, + b_col: usize, + ) where + R: VecZnxBigToMut, + A: VecZnxToRef, + B: VecZnxBigToRef, + { + vec_znx_big_sub_small_a(res, res_col, a, a_col, b, b_col); + } +} + +unsafe impl VecZnxBigSubSmallAInplaceImpl for FFT64Ref { + /// Subtracts `a` from `res` and stores the result on `res`. + fn vec_znx_big_sub_small_a_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef, + { + vec_znx_big_sub_small_a_inplace(res, res_col, a, a_col); + } +} + +unsafe impl VecZnxBigSubSmallBImpl for FFT64Ref { + /// Subtracts `b` from `a` and stores the result on `c`. + fn vec_znx_big_sub_small_b_impl( + _module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + b: &B, + b_col: usize, + ) where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + B: VecZnxToRef, + { + vec_znx_big_sub_small_b(res, res_col, a, a_col, b, b_col); + } +} + +unsafe impl VecZnxBigSubSmallBInplaceImpl for FFT64Ref { + /// Subtracts `res` from `a` and stores the result on `res`. 
+ fn vec_znx_big_sub_small_b_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef, + { + vec_znx_big_sub_small_b_inplace(res, res_col, a, a_col); + } +} + +unsafe impl VecZnxBigNegateImpl for FFT64Ref { + fn vec_znx_big_negate_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + { + vec_znx_big_negate(res, res_col, a, a_col); + } +} + +unsafe impl VecZnxBigNegateInplaceImpl for FFT64Ref { + fn vec_znx_big_negate_inplace_impl(_module: &Module, res: &mut R, res_col: usize) + where + R: VecZnxBigToMut, + { + vec_znx_big_negate_inplace(res, res_col); + } +} + +unsafe impl VecZnxBigNormalizeTmpBytesImpl for FFT64Ref { + fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module) -> usize { + vec_znx_big_normalize_tmp_bytes(module.n()) + } +} + +unsafe impl VecZnxBigNormalizeImpl for FFT64Ref +where + Self: TakeSliceImpl, +{ + fn vec_znx_big_normalize_impl( + module: &Module, + basek: usize, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxToMut, + A: VecZnxBigToRef, + { + let (carry, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes() / size_of::()); + vec_znx_big_normalize(basek, res, res_col, a, a_col, carry); + } +} + +unsafe impl VecZnxBigAutomorphismImpl for FFT64Ref { + /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`. + fn vec_znx_big_automorphism_impl(_module: &Module, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + { + vec_znx_big_automorphism(p, res, res_col, a, a_col); + } +} + +unsafe impl VecZnxBigAutomorphismInplaceTmpBytesImpl for FFT64Ref { + fn vec_znx_big_automorphism_inplace_tmp_bytes_impl(module: &Module) -> usize { + vec_znx_big_automorphism_inplace_tmp_bytes(module.n()) + } +} + +unsafe impl VecZnxBigAutomorphismInplaceImpl for FFT64Ref +where + Module: VecZnxBigAutomorphismInplaceTmpBytes, +{ + /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`. 
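+ /// (Here `k` is the `p` argument and `a` is the `res` argument of the in-place variant.)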
+ fn vec_znx_big_automorphism_inplace_impl( + module: &Module, + p: i64, + res: &mut R, + res_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxBigToMut, + { + let (tmp, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes() / size_of::()); + vec_znx_big_automorphism_inplace(p, res, res_col, tmp); + } +} diff --git a/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs b/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs new file mode 100644 index 0000000..646cbca --- /dev/null +++ b/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs @@ -0,0 +1,186 @@ +use poulpy_hal::{ + layouts::{ + Backend, Data, Module, Scratch, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef, + VecZnxToRef, + }, + oep::{ + VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftApplyImpl, + VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl, VecZnxDftSubImpl, + VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl, VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl, + }, + reference::fft64::vec_znx_dft::{ + vec_znx_dft_add, vec_znx_dft_add_inplace, vec_znx_dft_apply, vec_znx_dft_copy, vec_znx_dft_sub, + vec_znx_dft_sub_ab_inplace, vec_znx_dft_sub_ba_inplace, vec_znx_dft_zero, vec_znx_idft_apply, vec_znx_idft_apply_consume, + vec_znx_idft_apply_tmpa, + }, +}; + +use crate::cpu_fft64_ref::{FFT64Ref, module::FFT64ModuleHandle}; + +unsafe impl VecZnxDftFromBytesImpl for FFT64Ref { + fn vec_znx_dft_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec) -> VecZnxDftOwned { + VecZnxDft::, Self>::from_bytes(n, cols, size, bytes) + } +} + +unsafe impl VecZnxDftAllocBytesImpl for FFT64Ref { + fn vec_znx_dft_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize { + Self::layout_prep_word_count() * n * cols * size * size_of::<::ScalarPrep>() + } +} + +unsafe impl VecZnxDftAllocImpl for FFT64Ref { + fn vec_znx_dft_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxDftOwned { + VecZnxDftOwned::alloc(n, cols, size) + } +} + +unsafe impl VecZnxIdftApplyTmpBytesImpl for FFT64Ref { + fn vec_znx_idft_apply_tmp_bytes_impl(_module: &Module) -> usize { + 0 + } +} + +unsafe impl VecZnxIdftApplyImpl for FFT64Ref { + fn vec_znx_idft_apply_impl( + module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + _scratch: &mut Scratch, + ) where + R: VecZnxBigToMut, + A: VecZnxDftToRef, + { + vec_znx_idft_apply(module.get_ifft_table(), res, res_col, a, a_col); + } +} + +unsafe impl VecZnxIdftApplyTmpAImpl for FFT64Ref { + fn vec_znx_idft_apply_tmpa_impl(module: &Module, res: &mut R, res_col: usize, a: &mut A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxDftToMut, + { + vec_znx_idft_apply_tmpa(module.get_ifft_table(), res, res_col, a, a_col); + } +} + +unsafe impl VecZnxIdftApplyConsumeImpl for FFT64Ref { + fn vec_znx_idft_apply_consume_impl(module: &Module, res: VecZnxDft) -> VecZnxBig + where + VecZnxDft: VecZnxDftToMut, + { + vec_znx_idft_apply_consume(module.get_ifft_table(), res) + } +} + +unsafe impl VecZnxDftApplyImpl for FFT64Ref { + fn vec_znx_dft_apply_impl( + module: &Module, + step: usize, + offset: usize, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + ) where + R: VecZnxDftToMut, + A: VecZnxToRef, + { + vec_znx_dft_apply(module.get_fft_table(), step, offset, res, res_col, a, a_col); + } +} + +unsafe impl VecZnxDftAddImpl for FFT64Ref { + fn vec_znx_dft_add_impl( + _module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: 
usize, + b: &B, + b_col: usize, + ) where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + B: VecZnxDftToRef, + { + vec_znx_dft_add(res, res_col, a, a_col, b, b_col); + } +} + +unsafe impl VecZnxDftAddInplaceImpl for FFT64Ref { + fn vec_znx_dft_add_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + { + vec_znx_dft_add_inplace(res, res_col, a, a_col); + } +} + +unsafe impl VecZnxDftSubImpl for FFT64Ref { + fn vec_znx_dft_sub_impl( + _module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + b: &B, + b_col: usize, + ) where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + B: VecZnxDftToRef, + { + vec_znx_dft_sub(res, res_col, a, a_col, b, b_col); + } +} + +unsafe impl VecZnxDftSubABInplaceImpl for FFT64Ref { + fn vec_znx_dft_sub_ab_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + { + vec_znx_dft_sub_ab_inplace(res, res_col, a, a_col); + } +} + +unsafe impl VecZnxDftSubBAInplaceImpl for FFT64Ref { + fn vec_znx_dft_sub_ba_inplace_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + { + vec_znx_dft_sub_ba_inplace(res, res_col, a, a_col); + } +} + +unsafe impl VecZnxDftCopyImpl for FFT64Ref { + fn vec_znx_dft_copy_impl( + _module: &Module, + step: usize, + offset: usize, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + ) where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + { + vec_znx_dft_copy(step, offset, res, res_col, a, a_col); + } +} + +unsafe impl VecZnxDftZeroImpl for FFT64Ref { + fn vec_znx_dft_zero_impl(_module: &Module, res: &mut R) + where + R: VecZnxDftToMut, + { + vec_znx_dft_zero(res); + } +} diff --git a/poulpy-backend/src/cpu_fft64_ref/vmp.rs b/poulpy-backend/src/cpu_fft64_ref/vmp.rs new file mode 100644 index 0000000..2286de5 --- /dev/null +++ b/poulpy-backend/src/cpu_fft64_ref/vmp.rs @@ -0,0 +1,143 @@ +use poulpy_hal::{ + api::{TakeSlice, VmpPrepareTmpBytes}, + layouts::{ + Backend, MatZnx, MatZnxToRef, Module, Scratch, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, VmpPMat, VmpPMatOwned, + VmpPMatToMut, VmpPMatToRef, ZnxInfos, + }, + oep::{ + VmpApplyDftToDftAddImpl, VmpApplyDftToDftAddTmpBytesImpl, VmpApplyDftToDftImpl, VmpApplyDftToDftTmpBytesImpl, + VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPrepareImpl, VmpPrepareTmpBytesImpl, + }, + reference::fft64::vmp::{ + vmp_apply_dft_to_dft, vmp_apply_dft_to_dft_add, vmp_apply_dft_to_dft_tmp_bytes, vmp_prepare, vmp_prepare_tmp_bytes, + }, +}; + +use crate::cpu_fft64_ref::{FFT64Ref, module::FFT64ModuleHandle}; + +unsafe impl VmpPMatAllocBytesImpl for FFT64Ref { + fn vmp_pmat_alloc_bytes_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { + Self::layout_prep_word_count() * n * rows * cols_in * cols_out * size * size_of::() + } +} + +unsafe impl VmpPMatAllocImpl for FFT64Ref { + fn vmp_pmat_alloc_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned { + VmpPMatOwned::alloc(n, rows, cols_in, cols_out, size) + } +} + +unsafe impl VmpApplyDftToDftImpl for FFT64Ref +where + Scratch: TakeSlice, + FFT64Ref: VmpApplyDftToDftTmpBytesImpl, +{ + fn vmp_apply_dft_to_dft_impl(module: &Module, res: &mut R, a: &A, pmat: &C, scratch: &mut Scratch) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + C: VmpPMatToRef, + { + let mut res: VecZnxDft<&mut [u8], Self> = res.to_mut(); + let a: VecZnxDft<&[u8], Self> = a.to_ref(); + let pmat: 
VmpPMat<&[u8], Self> = pmat.to_ref(); + + let (tmp, _) = scratch.take_slice( + Self::vmp_apply_dft_to_dft_tmp_bytes_impl( + module, + res.size(), + a.size(), + pmat.rows(), + pmat.cols_in(), + pmat.cols_out(), + pmat.size(), + ) / size_of::(), + ); + vmp_apply_dft_to_dft(&mut res, &a, &pmat, tmp); + } +} + +unsafe impl VmpApplyDftToDftAddImpl for FFT64Ref +where + Scratch: TakeSlice, + FFT64Ref: VmpApplyDftToDftTmpBytesImpl, +{ + fn vmp_apply_dft_to_dft_add_impl( + module: &Module, + res: &mut R, + a: &A, + pmat: &C, + limb_offset: usize, + scratch: &mut Scratch, + ) where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + C: VmpPMatToRef, + { + let mut res: VecZnxDft<&mut [u8], Self> = res.to_mut(); + let a: VecZnxDft<&[u8], Self> = a.to_ref(); + let pmat: VmpPMat<&[u8], Self> = pmat.to_ref(); + + let (tmp, _) = scratch.take_slice( + Self::vmp_apply_dft_to_dft_tmp_bytes_impl( + module, + res.size(), + a.size(), + pmat.rows(), + pmat.cols_in(), + pmat.cols_out(), + pmat.size(), + ) / size_of::(), + ); + vmp_apply_dft_to_dft_add(&mut res, &a, &pmat, limb_offset * pmat.cols_out(), tmp); + } +} + +unsafe impl VmpPrepareTmpBytesImpl for FFT64Ref { + fn vmp_prepare_tmp_bytes_impl(module: &Module, _rows: usize, _cols_in: usize, _cols_out: usize, _size: usize) -> usize { + vmp_prepare_tmp_bytes(module.n()) + } +} + +unsafe impl VmpPrepareImpl for FFT64Ref { + fn vmp_prepare_impl(module: &Module, res: &mut R, a: &A, scratch: &mut Scratch) + where + R: VmpPMatToMut, + A: MatZnxToRef, + { + {} + let mut res: VmpPMat<&mut [u8], Self> = res.to_mut(); + let a: MatZnx<&[u8]> = a.to_ref(); + let (tmp, _) = + scratch.take_slice(module.vmp_prepare_tmp_bytes(a.rows(), a.cols_in(), a.cols_out(), a.size()) / size_of::()); + vmp_prepare(module.get_fft_table(), &mut res, &a, tmp); + } +} + +unsafe impl VmpApplyDftToDftTmpBytesImpl for FFT64Ref { + fn vmp_apply_dft_to_dft_tmp_bytes_impl( + _module: &Module, + _res_size: usize, + a_size: usize, + b_rows: usize, + b_cols_in: usize, + _b_cols_out: usize, + _b_size: usize, + ) -> usize { + vmp_apply_dft_to_dft_tmp_bytes(a_size, b_rows, b_cols_in) + } +} + +unsafe impl VmpApplyDftToDftAddTmpBytesImpl for FFT64Ref { + fn vmp_apply_dft_to_dft_add_tmp_bytes_impl( + _module: &Module, + _res_size: usize, + a_size: usize, + b_rows: usize, + b_cols_in: usize, + _b_cols_out: usize, + _b_size: usize, + ) -> usize { + vmp_apply_dft_to_dft_tmp_bytes(a_size, b_rows, b_cols_in) + } +} diff --git a/poulpy-backend/src/cpu_fft64_ref/zn.rs b/poulpy-backend/src/cpu_fft64_ref/zn.rs new file mode 100644 index 0000000..995094b --- /dev/null +++ b/poulpy-backend/src/cpu_fft64_ref/zn.rs @@ -0,0 +1,73 @@ +use poulpy_hal::{ + api::TakeSlice, + layouts::{Scratch, ZnToMut}, + oep::{TakeSliceImpl, ZnAddNormalImpl, ZnFillNormalImpl, ZnFillUniformImpl, ZnNormalizeInplaceImpl, ZnNormalizeTmpBytesImpl}, + reference::zn::{zn_add_normal, zn_fill_normal, zn_fill_uniform, zn_normalize_inplace, zn_normalize_tmp_bytes}, + source::Source, +}; + +use crate::cpu_fft64_ref::FFT64Ref; + +unsafe impl ZnNormalizeTmpBytesImpl for FFT64Ref { + fn zn_normalize_tmp_bytes_impl(n: usize) -> usize { + zn_normalize_tmp_bytes(n) + } +} + +unsafe impl ZnNormalizeInplaceImpl for FFT64Ref +where + Self: TakeSliceImpl, +{ + fn zn_normalize_inplace_impl(n: usize, basek: usize, res: &mut R, res_col: usize, scratch: &mut Scratch) + where + R: ZnToMut, + { + let (carry, _) = scratch.take_slice(n); + zn_normalize_inplace::(n, basek, res, res_col, carry); + } +} + +unsafe impl ZnFillUniformImpl for FFT64Ref { + fn 
zn_fill_uniform_impl(n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source) + where + R: ZnToMut, + { + zn_fill_uniform(n, basek, res, res_col, source); + } +} + +unsafe impl ZnFillNormalImpl for FFT64Ref { + #[allow(clippy::too_many_arguments)] + fn zn_fill_normal_impl( + n: usize, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + sigma: f64, + bound: f64, + ) where + R: ZnToMut, + { + zn_fill_normal(n, basek, res, res_col, k, source, sigma, bound); + } +} + +unsafe impl ZnAddNormalImpl for FFT64Ref { + #[allow(clippy::too_many_arguments)] + fn zn_add_normal_impl( + n: usize, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + sigma: f64, + bound: f64, + ) where + R: ZnToMut, + { + zn_add_normal(n, basek, res, res_col, k, source, sigma, bound); + } +} diff --git a/poulpy-backend/src/cpu_fft64_ref/znx.rs b/poulpy-backend/src/cpu_fft64_ref/znx.rs new file mode 100644 index 0000000..f248624 --- /dev/null +++ b/poulpy-backend/src/cpu_fft64_ref/znx.rs @@ -0,0 +1,152 @@ +use poulpy_hal::reference::znx::{ + ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxNormalizeFinalStep, + ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeFirstStepInplace, + ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, ZnxRotate, ZnxSub, ZnxSubABInplace, + ZnxSubBAInplace, ZnxSwitchRing, ZnxZero, znx_add_inplace_ref, znx_add_ref, znx_automorphism_ref, znx_copy_ref, + znx_negate_inplace_ref, znx_negate_ref, znx_normalize_final_step_inplace_ref, znx_normalize_final_step_ref, + znx_normalize_first_step_carry_only_ref, znx_normalize_first_step_inplace_ref, znx_normalize_first_step_ref, + znx_normalize_middle_step_carry_only_ref, znx_normalize_middle_step_inplace_ref, znx_normalize_middle_step_ref, znx_rotate, + znx_sub_ab_inplace_ref, znx_sub_ba_inplace_ref, znx_sub_ref, znx_switch_ring_ref, znx_zero_ref, +}; + +use crate::cpu_fft64_ref::FFT64Ref; + +impl ZnxAdd for FFT64Ref { + #[inline(always)] + fn znx_add(res: &mut [i64], a: &[i64], b: &[i64]) { + znx_add_ref(res, a, b); + } +} + +impl ZnxAddInplace for FFT64Ref { + #[inline(always)] + fn znx_add_inplace(res: &mut [i64], a: &[i64]) { + znx_add_inplace_ref(res, a); + } +} + +impl ZnxSub for FFT64Ref { + #[inline(always)] + fn znx_sub(res: &mut [i64], a: &[i64], b: &[i64]) { + znx_sub_ref(res, a, b); + } +} + +impl ZnxSubABInplace for FFT64Ref { + #[inline(always)] + fn znx_sub_ab_inplace(res: &mut [i64], a: &[i64]) { + znx_sub_ab_inplace_ref(res, a); + } +} + +impl ZnxSubBAInplace for FFT64Ref { + #[inline(always)] + fn znx_sub_ba_inplace(res: &mut [i64], a: &[i64]) { + znx_sub_ba_inplace_ref(res, a); + } +} + +impl ZnxAutomorphism for FFT64Ref { + #[inline(always)] + fn znx_automorphism(p: i64, res: &mut [i64], a: &[i64]) { + znx_automorphism_ref(p, res, a); + } +} + +impl ZnxCopy for FFT64Ref { + #[inline(always)] + fn znx_copy(res: &mut [i64], a: &[i64]) { + znx_copy_ref(res, a); + } +} + +impl ZnxNegate for FFT64Ref { + #[inline(always)] + fn znx_negate(res: &mut [i64], src: &[i64]) { + znx_negate_ref(res, src); + } +} + +impl ZnxNegateInplace for FFT64Ref { + #[inline(always)] + fn znx_negate_inplace(res: &mut [i64]) { + znx_negate_inplace_ref(res); + } +} + +impl ZnxRotate for FFT64Ref { + #[inline(always)] + fn znx_rotate(p: i64, res: &mut [i64], src: &[i64]) { + znx_rotate::(p, res, src); + } +} + +impl ZnxZero for FFT64Ref { + #[inline(always)] + fn 
znx_zero(res: &mut [i64]) { + znx_zero_ref(res); + } +} + +impl ZnxSwitchRing for FFT64Ref { + #[inline(always)] + fn znx_switch_ring(res: &mut [i64], a: &[i64]) { + znx_switch_ring_ref(res, a); + } +} + +impl ZnxNormalizeFinalStep for FFT64Ref { + #[inline(always)] + fn znx_normalize_final_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { + znx_normalize_final_step_ref(basek, lsh, x, a, carry); + } +} + +impl ZnxNormalizeFinalStepInplace for FFT64Ref { + #[inline(always)] + fn znx_normalize_final_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { + znx_normalize_final_step_inplace_ref(basek, lsh, x, carry); + } +} + +impl ZnxNormalizeFirstStep for FFT64Ref { + #[inline(always)] + fn znx_normalize_first_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { + znx_normalize_first_step_ref(basek, lsh, x, a, carry); + } +} + +impl ZnxNormalizeFirstStepCarryOnly for FFT64Ref { + #[inline(always)] + fn znx_normalize_first_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) { + znx_normalize_first_step_carry_only_ref(basek, lsh, x, carry); + } +} + +impl ZnxNormalizeFirstStepInplace for FFT64Ref { + #[inline(always)] + fn znx_normalize_first_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { + znx_normalize_first_step_inplace_ref(basek, lsh, x, carry); + } +} + +impl ZnxNormalizeMiddleStep for FFT64Ref { + #[inline(always)] + fn znx_normalize_middle_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { + znx_normalize_middle_step_ref(basek, lsh, x, a, carry); + } +} + +impl ZnxNormalizeMiddleStepCarryOnly for FFT64Ref { + #[inline(always)] + fn znx_normalize_middle_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) { + znx_normalize_middle_step_carry_only_ref(basek, lsh, x, carry); + } +} + +impl ZnxNormalizeMiddleStepInplace for FFT64Ref { + #[inline(always)] + fn znx_normalize_middle_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { + znx_normalize_middle_step_inplace_ref(basek, lsh, x, carry); + } +} diff --git a/poulpy-backend/src/cpu_spqlios/ffi/mod.rs b/poulpy-backend/src/cpu_spqlios/ffi/mod.rs index af417ec..df792fb 100644 --- a/poulpy-backend/src/cpu_spqlios/ffi/mod.rs +++ b/poulpy-backend/src/cpu_spqlios/ffi/mod.rs @@ -1,6 +1,8 @@ #[allow(non_camel_case_types)] pub mod module; #[allow(non_camel_case_types)] +pub mod reim; +#[allow(non_camel_case_types)] pub mod svp; #[allow(non_camel_case_types)] pub mod vec_znx; diff --git a/poulpy-backend/src/cpu_spqlios/ffi/reim.rs b/poulpy-backend/src/cpu_spqlios/ffi/reim.rs new file mode 100644 index 0000000..eb2fae7 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/ffi/reim.rs @@ -0,0 +1,172 @@ +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct reim_fft_precomp { + _unused: [u8; 0], +} +pub type REIM_FFT_PRECOMP = reim_fft_precomp; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct reim_ifft_precomp { + _unused: [u8; 0], +} +pub type REIM_IFFT_PRECOMP = reim_ifft_precomp; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct reim_mul_precomp { + _unused: [u8; 0], +} +pub type REIM_FFTVEC_MUL_PRECOMP = reim_mul_precomp; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct reim_addmul_precomp { + _unused: [u8; 0], +} +pub type REIM_FFTVEC_ADDMUL_PRECOMP = reim_addmul_precomp; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct reim_from_znx32_precomp { + _unused: [u8; 0], +} +pub type REIM_FROM_ZNX32_PRECOMP = reim_from_znx32_precomp; +#[repr(C)] 
+#[derive(Debug, Copy, Clone)] +pub struct reim_from_znx64_precomp { + _unused: [u8; 0], +} +pub type REIM_FROM_ZNX64_PRECOMP = reim_from_znx64_precomp; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct reim_from_tnx32_precomp { + _unused: [u8; 0], +} +pub type REIM_FROM_TNX32_PRECOMP = reim_from_tnx32_precomp; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct reim_to_tnx32_precomp { + _unused: [u8; 0], +} +pub type REIM_TO_TNX32_PRECOMP = reim_to_tnx32_precomp; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct reim_to_tnx_precomp { + _unused: [u8; 0], +} +pub type REIM_TO_TNX_PRECOMP = reim_to_tnx_precomp; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct reim_to_znx64_precomp { + _unused: [u8; 0], +} +pub type REIM_TO_ZNX64_PRECOMP = reim_to_znx64_precomp; +unsafe extern "C" { + pub unsafe fn new_reim_fft_precomp(m: u32, num_buffers: u32) -> *mut REIM_FFT_PRECOMP; +} +unsafe extern "C" { + pub unsafe fn reim_fft_precomp_get_buffer(tables: *const REIM_FFT_PRECOMP, buffer_index: u32) -> *mut f64; +} +unsafe extern "C" { + pub unsafe fn new_reim_fft_buffer(m: u32) -> *mut f64; +} +unsafe extern "C" { + pub unsafe fn delete_reim_fft_buffer(buffer: *mut f64); +} +unsafe extern "C" { + pub unsafe fn reim_fft(tables: *const REIM_FFT_PRECOMP, data: *mut f64); +} +unsafe extern "C" { + pub unsafe fn new_reim_ifft_precomp(m: u32, num_buffers: u32) -> *mut REIM_IFFT_PRECOMP; +} +unsafe extern "C" { + pub unsafe fn reim_ifft_precomp_get_buffer(tables: *const REIM_IFFT_PRECOMP, buffer_index: u32) -> *mut f64; +} +unsafe extern "C" { + pub unsafe fn reim_ifft(tables: *const REIM_IFFT_PRECOMP, data: *mut f64); +} +unsafe extern "C" { + pub unsafe fn new_reim_fftvec_mul_precomp(m: u32) -> *mut REIM_FFTVEC_MUL_PRECOMP; +} +unsafe extern "C" { + pub unsafe fn reim_fftvec_mul(tables: *const REIM_FFTVEC_MUL_PRECOMP, r: *mut f64, a: *const f64, b: *const f64); +} +unsafe extern "C" { + pub unsafe fn new_reim_fftvec_addmul_precomp(m: u32) -> *mut REIM_FFTVEC_ADDMUL_PRECOMP; +} +unsafe extern "C" { + pub unsafe fn reim_fftvec_addmul(tables: *const REIM_FFTVEC_ADDMUL_PRECOMP, r: *mut f64, a: *const f64, b: *const f64); +} +unsafe extern "C" { + pub unsafe fn new_reim_from_znx32_precomp(m: u32, log2bound: u32) -> *mut REIM_FROM_ZNX32_PRECOMP; +} +unsafe extern "C" { + pub unsafe fn reim_from_znx32(tables: *const REIM_FROM_ZNX32_PRECOMP, r: *mut ::std::os::raw::c_void, a: *const i32); +} +unsafe extern "C" { + pub unsafe fn reim_from_znx64(tables: *const REIM_FROM_ZNX64_PRECOMP, r: *mut ::std::os::raw::c_void, a: *const i64); +} +unsafe extern "C" { + pub unsafe fn new_reim_from_znx64_precomp(m: u32, maxbnd: u32) -> *mut REIM_FROM_ZNX64_PRECOMP; +} +unsafe extern "C" { + pub unsafe fn reim_from_znx64_simple(m: u32, log2bound: u32, r: *mut ::std::os::raw::c_void, a: *const i64); +} +unsafe extern "C" { + pub unsafe fn new_reim_from_tnx32_precomp(m: u32) -> *mut REIM_FROM_TNX32_PRECOMP; +} +unsafe extern "C" { + pub unsafe fn reim_from_tnx32(tables: *const REIM_FROM_TNX32_PRECOMP, r: *mut ::std::os::raw::c_void, a: *const i32); +} +unsafe extern "C" { + pub unsafe fn new_reim_to_tnx32_precomp(m: u32, divisor: f64, log2overhead: u32) -> *mut REIM_TO_TNX32_PRECOMP; +} +unsafe extern "C" { + pub unsafe fn reim_to_tnx32(tables: *const REIM_TO_TNX32_PRECOMP, r: *mut i32, a: *const ::std::os::raw::c_void); +} +unsafe extern "C" { + pub unsafe fn new_reim_to_tnx_precomp(m: u32, divisor: f64, log2overhead: u32) -> *mut REIM_TO_TNX_PRECOMP; +} +unsafe extern "C" { + pub unsafe fn reim_to_tnx(tables: 
*const REIM_TO_TNX_PRECOMP, r: *mut f64, a: *const f64); +} +unsafe extern "C" { + pub unsafe fn reim_to_tnx_simple(m: u32, divisor: f64, log2overhead: u32, r: *mut f64, a: *const f64); +} +unsafe extern "C" { + pub unsafe fn new_reim_to_znx64_precomp(m: u32, divisor: f64, log2bound: u32) -> *mut REIM_TO_ZNX64_PRECOMP; +} +unsafe extern "C" { + pub unsafe fn reim_to_znx64(precomp: *const REIM_TO_ZNX64_PRECOMP, r: *mut i64, a: *const ::std::os::raw::c_void); +} +unsafe extern "C" { + pub unsafe fn reim_to_znx64_simple(m: u32, divisor: f64, log2bound: u32, r: *mut i64, a: *const ::std::os::raw::c_void); +} +unsafe extern "C" { + pub unsafe fn reim_fft_simple(m: u32, data: *mut ::std::os::raw::c_void); +} +unsafe extern "C" { + pub unsafe fn reim_ifft_simple(m: u32, data: *mut ::std::os::raw::c_void); +} +unsafe extern "C" { + pub unsafe fn reim_fftvec_mul_simple( + m: u32, + r: *mut ::std::os::raw::c_void, + a: *const ::std::os::raw::c_void, + b: *const ::std::os::raw::c_void, + ); +} +unsafe extern "C" { + pub unsafe fn reim_fftvec_addmul_simple( + m: u32, + r: *mut ::std::os::raw::c_void, + a: *const ::std::os::raw::c_void, + b: *const ::std::os::raw::c_void, + ); +} +unsafe extern "C" { + pub unsafe fn reim_from_znx32_simple(m: u32, log2bound: u32, r: *mut ::std::os::raw::c_void, x: *const i32); +} +unsafe extern "C" { + pub unsafe fn reim_from_tnx32_simple(m: u32, r: *mut ::std::os::raw::c_void, x: *const i32); +} +unsafe extern "C" { + pub unsafe fn reim_to_tnx32_simple(m: u32, divisor: f64, log2overhead: u32, r: *mut i32, x: *const ::std::os::raw::c_void); +} diff --git a/poulpy-backend/src/cpu_spqlios/fft64/mod.rs b/poulpy-backend/src/cpu_spqlios/fft64/mod.rs index 3ca4713..6790625 100644 --- a/poulpy-backend/src/cpu_spqlios/fft64/mod.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/mod.rs @@ -7,10 +7,4 @@ mod vec_znx_dft; mod vmp_pmat; mod zn; -pub use module::FFT64; - -/// For external documentation -pub use vec_znx::{ - vec_znx_copy_ref, vec_znx_lsh_inplace_ref, vec_znx_merge_ref, vec_znx_rsh_inplace_ref, vec_znx_split_ref, - vec_znx_switch_degree_ref, -}; +pub struct FFT64Spqlios; diff --git a/poulpy-backend/src/cpu_spqlios/fft64/module.rs b/poulpy-backend/src/cpu_spqlios/fft64/module.rs index 7bd4ff6..fbb3939 100644 --- a/poulpy-backend/src/cpu_spqlios/fft64/module.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/module.rs @@ -3,13 +3,23 @@ use std::ptr::NonNull; use poulpy_hal::{ layouts::{Backend, Module}, oep::ModuleNewImpl, + reference::znx::{ + ZnxCopy, ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, + ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, + ZnxRotate, ZnxSwitchRing, ZnxZero, znx_copy_ref, znx_normalize_final_step_inplace_ref, znx_normalize_final_step_ref, + znx_normalize_first_step_carry_only_ref, znx_normalize_first_step_inplace_ref, znx_normalize_first_step_ref, + znx_normalize_middle_step_carry_only_ref, znx_normalize_middle_step_inplace_ref, znx_normalize_middle_step_ref, + znx_switch_ring_ref, znx_zero_ref, + }, }; -use crate::cpu_spqlios::ffi::module::{MODULE, delete_module_info, new_module_info}; +use crate::cpu_spqlios::{ + FFT64Spqlios, + ffi::module::{MODULE, delete_module_info, new_module_info}, + znx::znx_rotate_i64, +}; -pub struct FFT64; - -impl Backend for FFT64 { +impl Backend for FFT64Spqlios { type ScalarPrep = f64; type ScalarBig = i64; type Handle = MODULE; @@ -26,8 +36,90 @@ impl Backend for FFT64 { } } -unsafe impl 
ModuleNewImpl for FFT64 { +unsafe impl ModuleNewImpl for FFT64Spqlios { fn new_impl(n: u64) -> Module { unsafe { Module::from_raw_parts(new_module_info(n, 0), n) } } } + +impl ZnxCopy for FFT64Spqlios { + fn znx_copy(res: &mut [i64], a: &[i64]) { + znx_copy_ref(res, a); + } +} + +impl ZnxZero for FFT64Spqlios { + fn znx_zero(res: &mut [i64]) { + znx_zero_ref(res); + } +} + +impl ZnxSwitchRing for FFT64Spqlios { + fn znx_switch_ring(res: &mut [i64], a: &[i64]) { + znx_switch_ring_ref(res, a); + } +} + +impl ZnxRotate for FFT64Spqlios { + fn znx_rotate(p: i64, res: &mut [i64], src: &[i64]) { + unsafe { + znx_rotate_i64(res.len() as u64, p, res.as_mut_ptr(), src.as_ptr()); + } + } +} + +impl ZnxNormalizeFinalStep for FFT64Spqlios { + #[inline(always)] + fn znx_normalize_final_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { + znx_normalize_final_step_ref(basek, lsh, x, a, carry); + } +} + +impl ZnxNormalizeFinalStepInplace for FFT64Spqlios { + #[inline(always)] + fn znx_normalize_final_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { + znx_normalize_final_step_inplace_ref(basek, lsh, x, carry); + } +} + +impl ZnxNormalizeFirstStep for FFT64Spqlios { + #[inline(always)] + fn znx_normalize_first_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { + znx_normalize_first_step_ref(basek, lsh, x, a, carry); + } +} + +impl ZnxNormalizeFirstStepCarryOnly for FFT64Spqlios { + #[inline(always)] + fn znx_normalize_first_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) { + znx_normalize_first_step_carry_only_ref(basek, lsh, x, carry); + } +} + +impl ZnxNormalizeFirstStepInplace for FFT64Spqlios { + #[inline(always)] + fn znx_normalize_first_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { + znx_normalize_first_step_inplace_ref(basek, lsh, x, carry); + } +} + +impl ZnxNormalizeMiddleStep for FFT64Spqlios { + #[inline(always)] + fn znx_normalize_middle_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { + znx_normalize_middle_step_ref(basek, lsh, x, a, carry); + } +} + +impl ZnxNormalizeMiddleStepCarryOnly for FFT64Spqlios { + #[inline(always)] + fn znx_normalize_middle_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) { + znx_normalize_middle_step_carry_only_ref(basek, lsh, x, carry); + } +} + +impl ZnxNormalizeMiddleStepInplace for FFT64Spqlios { + #[inline(always)] + fn znx_normalize_middle_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { + znx_normalize_middle_step_inplace_ref(basek, lsh, x, carry); + } +} diff --git a/poulpy-backend/src/cpu_spqlios/fft64/scratch.rs b/poulpy-backend/src/cpu_spqlios/fft64/scratch.rs index 43ff74e..1013df8 100644 --- a/poulpy-backend/src/cpu_spqlios/fft64/scratch.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/scratch.rs @@ -12,9 +12,9 @@ use poulpy_hal::{ }, }; -use crate::cpu_spqlios::FFT64; +use crate::cpu_spqlios::FFT64Spqlios; -unsafe impl ScratchOwnedAllocImpl for FFT64 { +unsafe impl ScratchOwnedAllocImpl for FFT64Spqlios { fn scratch_owned_alloc_impl(size: usize) -> ScratchOwned { let data: Vec = alloc_aligned(size); ScratchOwned { @@ -24,7 +24,7 @@ unsafe impl ScratchOwnedAllocImpl for FFT64 { } } -unsafe impl ScratchOwnedBorrowImpl for FFT64 +unsafe impl ScratchOwnedBorrowImpl for FFT64Spqlios where B: ScratchFromBytesImpl, { @@ -33,13 +33,13 @@ where } } -unsafe impl ScratchFromBytesImpl for FFT64 { +unsafe impl ScratchFromBytesImpl for FFT64Spqlios { fn 
scratch_from_bytes_impl(data: &mut [u8]) -> &mut Scratch { unsafe { &mut *(data as *mut [u8] as *mut Scratch) } } } -unsafe impl ScratchAvailableImpl for FFT64 { +unsafe impl ScratchAvailableImpl for FFT64Spqlios { fn scratch_available_impl(scratch: &Scratch) -> usize { let ptr: *const u8 = scratch.data.as_ptr(); let self_len: usize = scratch.data.len(); @@ -48,7 +48,7 @@ unsafe impl ScratchAvailableImpl for FFT64 { } } -unsafe impl TakeSliceImpl for FFT64 +unsafe impl TakeSliceImpl for FFT64Spqlios where B: ScratchFromBytesImpl, { @@ -64,7 +64,7 @@ where } } -unsafe impl TakeScalarZnxImpl for FFT64 +unsafe impl TakeScalarZnxImpl for FFT64Spqlios where B: ScratchFromBytesImpl, { @@ -77,7 +77,7 @@ where } } -unsafe impl TakeSvpPPolImpl for FFT64 +unsafe impl TakeSvpPPolImpl for FFT64Spqlios where B: SvpPPolAllocBytesImpl + ScratchFromBytesImpl, { @@ -90,7 +90,7 @@ where } } -unsafe impl TakeVecZnxImpl for FFT64 +unsafe impl TakeVecZnxImpl for FFT64Spqlios where B: ScratchFromBytesImpl, { @@ -103,7 +103,7 @@ where } } -unsafe impl TakeVecZnxBigImpl for FFT64 +unsafe impl TakeVecZnxBigImpl for FFT64Spqlios where B: VecZnxBigAllocBytesImpl + ScratchFromBytesImpl, { @@ -124,7 +124,7 @@ where } } -unsafe impl TakeVecZnxDftImpl for FFT64 +unsafe impl TakeVecZnxDftImpl for FFT64Spqlios where B: VecZnxDftAllocBytesImpl + ScratchFromBytesImpl, { @@ -146,7 +146,7 @@ where } } -unsafe impl TakeVecZnxDftSliceImpl for FFT64 +unsafe impl TakeVecZnxDftSliceImpl for FFT64Spqlios where B: VecZnxDftAllocBytesImpl + ScratchFromBytesImpl + TakeVecZnxDftImpl, { @@ -168,7 +168,7 @@ where } } -unsafe impl TakeVecZnxSliceImpl for FFT64 +unsafe impl TakeVecZnxSliceImpl for FFT64Spqlios where B: ScratchFromBytesImpl + TakeVecZnxImpl, { @@ -190,7 +190,7 @@ where } } -unsafe impl TakeVmpPMatImpl for FFT64 +unsafe impl TakeVmpPMatImpl for FFT64Spqlios where B: VmpPMatAllocBytesImpl + ScratchFromBytesImpl, { @@ -213,7 +213,7 @@ where } } -unsafe impl TakeMatZnxImpl for FFT64 +unsafe impl TakeMatZnxImpl for FFT64Spqlios where B: ScratchFromBytesImpl, { diff --git a/poulpy-backend/src/cpu_spqlios/fft64/svp_ppol.rs b/poulpy-backend/src/cpu_spqlios/fft64/svp_ppol.rs index 88b1a5b..b917400 100644 --- a/poulpy-backend/src/cpu_spqlios/fft64/svp_ppol.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/svp_ppol.rs @@ -3,33 +3,36 @@ use poulpy_hal::{ Backend, Module, ScalarZnxToRef, SvpPPol, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, ZnxInfos, ZnxView, ZnxViewMut, }, - oep::{SvpApplyImpl, SvpApplyInplaceImpl, SvpPPolAllocBytesImpl, SvpPPolAllocImpl, SvpPPolFromBytesImpl, SvpPrepareImpl}, + oep::{ + SvpApplyDftToDftImpl, SvpApplyDftToDftInplaceImpl, SvpPPolAllocBytesImpl, SvpPPolAllocImpl, SvpPPolFromBytesImpl, + SvpPrepareImpl, + }, }; use crate::cpu_spqlios::{ - FFT64, + FFT64Spqlios, ffi::{svp, vec_znx_dft::vec_znx_dft_t}, }; -unsafe impl SvpPPolFromBytesImpl for FFT64 { +unsafe impl SvpPPolFromBytesImpl for FFT64Spqlios { fn svp_ppol_from_bytes_impl(n: usize, cols: usize, bytes: Vec) -> SvpPPolOwned { SvpPPolOwned::from_bytes(n, cols, bytes) } } -unsafe impl SvpPPolAllocImpl for FFT64 { +unsafe impl SvpPPolAllocImpl for FFT64Spqlios { fn svp_ppol_alloc_impl(n: usize, cols: usize) -> SvpPPolOwned { SvpPPolOwned::alloc(n, cols) } } -unsafe impl SvpPPolAllocBytesImpl for FFT64 { +unsafe impl SvpPPolAllocBytesImpl for FFT64Spqlios { fn svp_ppol_alloc_bytes_impl(n: usize, cols: usize) -> usize { - FFT64::layout_prep_word_count() * n * cols * size_of::() + FFT64Spqlios::layout_prep_word_count() * 
n * cols * size_of::() } } -unsafe impl SvpPrepareImpl for FFT64 { +unsafe impl SvpPrepareImpl for FFT64Spqlios { fn svp_prepare_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: SvpPPolToMut, @@ -45,9 +48,16 @@ unsafe impl SvpPrepareImpl for FFT64 { } } -unsafe impl SvpApplyImpl for FFT64 { - fn svp_apply_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) - where +unsafe impl SvpApplyDftToDftImpl for FFT64Spqlios { + fn svp_apply_dft_to_dft_impl( + module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + b: &B, + b_col: usize, + ) where R: VecZnxDftToMut, A: SvpPPolToRef, B: VecZnxDftToRef, @@ -70,8 +80,8 @@ unsafe impl SvpApplyImpl for FFT64 { } } -unsafe impl SvpApplyInplaceImpl for FFT64 { - fn svp_apply_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) +unsafe impl SvpApplyDftToDftInplaceImpl for FFT64Spqlios { + fn svp_apply_dft_to_dft_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxDftToMut, A: SvpPPolToRef, diff --git a/poulpy-backend/src/cpu_spqlios/fft64/vec_znx.rs b/poulpy-backend/src/cpu_spqlios/fft64/vec_znx.rs index 2aa8f39..ff46b27 100644 --- a/poulpy-backend/src/cpu_spqlios/fft64/vec_znx.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/vec_znx.rs @@ -1,39 +1,44 @@ -use itertools::izip; -use rand_distr::Normal; - use poulpy_hal::{ - api::{ - TakeSlice, TakeVecZnx, VecZnxAddDistF64, VecZnxCopy, VecZnxFillDistF64, VecZnxNormalizeTmpBytes, VecZnxRotate, - VecZnxRotateInplace, VecZnxSwithcDegree, - }, + api::{TakeSlice, VecZnxMergeRingsTmpBytes, VecZnxNormalizeTmpBytes, VecZnxSplitRingTmpBytes}, layouts::{ - Backend, Module, ScalarZnx, ScalarZnxToRef, Scratch, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxSliceSize, ZnxView, - ZnxViewMut, ZnxZero, + Module, ScalarZnx, ScalarZnxToRef, Scratch, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, }, oep::{ - TakeSliceImpl, TakeVecZnxImpl, VecZnxAddDistF64Impl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, - VecZnxAddScalarInplaceImpl, VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl, VecZnxFillDistF64Impl, - VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxLshInplaceImpl, VecZnxMergeImpl, VecZnxMulXpMinusOneImpl, - VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl, - VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl, VecZnxRshInplaceImpl, - VecZnxSplitImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarInplaceImpl, - VecZnxSwithcDegreeImpl, + TakeSliceImpl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, VecZnxAddScalarImpl, VecZnxAddScalarInplaceImpl, + VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxAutomorphismInplaceTmpBytesImpl, VecZnxCopyImpl, + VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxLshImpl, VecZnxLshInplaceImpl, VecZnxLshTmpBytesImpl, + VecZnxMergeRingsImpl, VecZnxMergeRingsTmpBytesImpl, VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl, + VecZnxMulXpMinusOneInplaceTmpBytesImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl, + VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl, + VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl, + VecZnxSplitRingTmpBytesImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, 
VecZnxSubImpl, VecZnxSubScalarImpl, + VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl, + }, + reference::{ + vec_znx::{ + vec_znx_add_normal_ref, vec_znx_automorphism_inplace_tmp_bytes, vec_znx_copy, vec_znx_fill_normal_ref, + vec_znx_fill_uniform_ref, vec_znx_lsh, vec_znx_lsh_inplace, vec_znx_lsh_tmp_bytes, vec_znx_merge_rings, + vec_znx_merge_rings_tmp_bytes, vec_znx_mul_xp_minus_one_inplace_tmp_bytes, vec_znx_rotate_inplace_tmp_bytes, + vec_znx_rsh, vec_znx_rsh_inplace, vec_znx_rsh_tmp_bytes, vec_znx_split_ring, vec_znx_split_ring_tmp_bytes, + vec_znx_switch_ring, + }, + znx::{znx_copy_ref, znx_zero_ref}, }, source::Source, }; use crate::cpu_spqlios::{ - FFT64, + FFT64Spqlios, ffi::{module::module_info_t, vec_znx, znx}, }; -unsafe impl VecZnxNormalizeTmpBytesImpl for FFT64 { +unsafe impl VecZnxNormalizeTmpBytesImpl for FFT64Spqlios { fn vec_znx_normalize_tmp_bytes_impl(module: &Module) -> usize { unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(module.ptr() as *const module_info_t) as usize } } } -unsafe impl VecZnxNormalizeImpl for FFT64 +unsafe impl VecZnxNormalizeImpl for FFT64Spqlios where Self: TakeSliceImpl + VecZnxNormalizeTmpBytesImpl, { @@ -75,7 +80,7 @@ where } } -unsafe impl VecZnxNormalizeInplaceImpl for FFT64 +unsafe impl VecZnxNormalizeInplaceImpl for FFT64Spqlios where Self: TakeSliceImpl + VecZnxNormalizeTmpBytesImpl, { @@ -108,7 +113,7 @@ where } } -unsafe impl VecZnxAddImpl for FFT64 { +unsafe impl VecZnxAddImpl for FFT64Spqlios { fn vec_znx_add_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize) where R: VecZnxToMut, @@ -142,7 +147,7 @@ unsafe impl VecZnxAddImpl for FFT64 { } } -unsafe impl VecZnxAddInplaceImpl for FFT64 { +unsafe impl VecZnxAddInplaceImpl for FFT64Spqlios { fn vec_znx_add_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, @@ -172,7 +177,7 @@ unsafe impl VecZnxAddInplaceImpl for FFT64 { } } -unsafe impl VecZnxAddScalarInplaceImpl for FFT64 { +unsafe impl VecZnxAddScalarInplaceImpl for FFT64Spqlios { fn vec_znx_add_scalar_inplace_impl( module: &Module, res: &mut R, @@ -209,7 +214,60 @@ unsafe impl VecZnxAddScalarInplaceImpl for FFT64 { } } -unsafe impl VecZnxSubImpl for FFT64 { +unsafe impl VecZnxAddScalarImpl for FFT64Spqlios { + fn vec_znx_add_scalar_impl( + module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + b: &B, + b_col: usize, + b_limb: usize, + ) where + R: VecZnxToMut, + A: ScalarZnxToRef, + B: VecZnxToRef, + { + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + let a: ScalarZnx<&[u8]> = a.to_ref(); + let b: VecZnx<&[u8]> = b.to_ref(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), res.n()); + } + + let min_size: usize = b.size().min(res.size()); + + unsafe { + vec_znx::vec_znx_add( + module.ptr() as *const module_info_t, + res.at_mut_ptr(res_col, b_limb), + 1_u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + b.at_ptr(b_col, b_limb), + 1_u64, + b.sl() as u64, + ); + + for j in 0..min_size { + if j != b_limb { + znx_copy_ref(res.at_mut(res_col, j), b.at(b_col, j)); + } + } + + for j in min_size..res.size() { + znx_zero_ref(res.at_mut(res_col, j)); + } + } + } +} + +unsafe impl VecZnxSubImpl for FFT64Spqlios { fn vec_znx_sub_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize) where R: VecZnxToMut, @@ -243,7 +301,7 @@ unsafe impl VecZnxSubImpl for FFT64 { } } -unsafe impl VecZnxSubABInplaceImpl for FFT64 { +unsafe impl VecZnxSubABInplaceImpl for 
FFT64Spqlios { fn vec_znx_sub_ab_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, @@ -272,7 +330,7 @@ unsafe impl VecZnxSubABInplaceImpl for FFT64 { } } -unsafe impl VecZnxSubBAInplaceImpl for FFT64 { +unsafe impl VecZnxSubBAInplaceImpl for FFT64Spqlios { fn vec_znx_sub_ba_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, @@ -301,7 +359,60 @@ unsafe impl VecZnxSubBAInplaceImpl for FFT64 { } } -unsafe impl VecZnxSubScalarInplaceImpl for FFT64 { +unsafe impl VecZnxSubScalarImpl for FFT64Spqlios { + fn vec_znx_sub_scalar_impl( + module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + b: &B, + b_col: usize, + b_limb: usize, + ) where + R: VecZnxToMut, + A: ScalarZnxToRef, + B: VecZnxToRef, + { + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + let a: ScalarZnx<&[u8]> = a.to_ref(); + let b: VecZnx<&[u8]> = b.to_ref(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), res.n()); + } + + let min_size: usize = b.size().min(res.size()); + + unsafe { + vec_znx::vec_znx_sub( + module.ptr() as *const module_info_t, + res.at_mut_ptr(res_col, b_limb), + 1_u64, + res.sl() as u64, + b.at_ptr(b_col, b_limb), + 1_u64, + b.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ); + + for j in 0..min_size { + if j != b_limb { + res.at_mut(res_col, j).copy_from_slice(b.at(b_col, j)) + } + } + + for j in min_size..res.size() { + znx_zero_ref(res.at_mut(res_col, j)); + } + } + } +} + +unsafe impl VecZnxSubScalarInplaceImpl for FFT64Spqlios { fn vec_znx_sub_scalar_inplace_impl( module: &Module, res: &mut R, @@ -327,18 +438,18 @@ unsafe impl VecZnxSubScalarInplaceImpl for FFT64 { res.at_mut_ptr(res_col, res_limb), 1_u64, res.sl() as u64, - a.at_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, res.at_ptr(res_col, res_limb), 1_u64, res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, ) } } } -unsafe impl VecZnxNegateImpl for FFT64 { +unsafe impl VecZnxNegateImpl for FFT64Spqlios { fn vec_znx_negate_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, @@ -364,7 +475,7 @@ unsafe impl VecZnxNegateImpl for FFT64 { } } -unsafe impl VecZnxNegateInplaceImpl for FFT64 { +unsafe impl VecZnxNegateInplaceImpl for FFT64Spqlios { fn vec_znx_negate_inplace_impl(module: &Module, a: &mut A, a_col: usize) where A: VecZnxToMut, @@ -384,92 +495,105 @@ unsafe impl VecZnxNegateInplaceImpl for FFT64 { } } -unsafe impl VecZnxLshInplaceImpl for FFT64 { - fn vec_znx_lsh_inplace_impl(_module: &Module, basek: usize, k: usize, a: &mut A) - where +unsafe impl VecZnxLshTmpBytesImpl for FFT64Spqlios { + fn vec_znx_lsh_tmp_bytes_impl(module: &Module) -> usize { + vec_znx_lsh_tmp_bytes(module.n()) + } +} + +unsafe impl VecZnxRshTmpBytesImpl for FFT64Spqlios { + fn vec_znx_rsh_tmp_bytes_impl(module: &Module) -> usize { + vec_znx_rsh_tmp_bytes(module.n()) + } +} + +unsafe impl VecZnxLshImpl for FFT64Spqlios +where + Module: VecZnxNormalizeTmpBytes, + Scratch: TakeSlice, +{ + fn vec_znx_lsh_inplace_impl( + module: &Module, + basek: usize, + k: usize, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxToMut, + A: VecZnxToRef, + { + let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::()); + vec_znx_lsh::<_, _, FFT64Spqlios>(basek, k, res, res_col, a, a_col, carry) + } +} + +unsafe impl VecZnxLshInplaceImpl for FFT64Spqlios +where + Module: 
VecZnxNormalizeTmpBytes, + Scratch: TakeSlice, +{ + fn vec_znx_lsh_inplace_impl( + module: &Module, + basek: usize, + k: usize, + a: &mut A, + a_col: usize, + scratch: &mut Scratch, + ) where A: VecZnxToMut, { - vec_znx_lsh_inplace_ref(basek, k, a) + let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::()); + vec_znx_lsh_inplace::<_, FFT64Spqlios>(basek, k, a, a_col, carry) } } -pub fn vec_znx_lsh_inplace_ref(basek: usize, k: usize, a: &mut A) +unsafe impl VecZnxRshImpl for FFT64Spqlios where - A: VecZnxToMut, + Module: VecZnxNormalizeTmpBytes, + Scratch: TakeSlice, { - let mut a: VecZnx<&mut [u8]> = a.to_mut(); - - let n: usize = a.n(); - let cols: usize = a.cols(); - let size: usize = a.size(); - let steps: usize = k / basek; - - a.raw_mut().rotate_left(n * steps * cols); - (0..cols).for_each(|i| { - (size - steps..size).for_each(|j| { - a.zero_at(i, j); - }) - }); - - let k_rem: usize = k % basek; - - if k_rem != 0 { - let shift: usize = i64::BITS as usize - k_rem; - (0..cols).for_each(|i| { - (0..steps).for_each(|j| { - a.at_mut(i, j).iter_mut().for_each(|xi| { - *xi <<= shift; - }); - }); - }); + fn vec_znx_rsh_inplace_impl( + module: &Module, + basek: usize, + k: usize, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxToMut, + A: VecZnxToRef, + { + let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::()); + vec_znx_rsh::<_, _, FFT64Spqlios>(basek, k, res, res_col, a, a_col, carry) } } -unsafe impl VecZnxRshInplaceImpl for FFT64 { - fn vec_znx_rsh_inplace_impl(_module: &Module, basek: usize, k: usize, a: &mut A) - where +unsafe impl VecZnxRshInplaceImpl for FFT64Spqlios +where + Module: VecZnxNormalizeTmpBytes, + Scratch: TakeSlice, +{ + fn vec_znx_rsh_inplace_impl( + module: &Module, + basek: usize, + k: usize, + a: &mut A, + a_col: usize, + scratch: &mut Scratch, + ) where A: VecZnxToMut, { - vec_znx_rsh_inplace_ref(basek, k, a) + let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::()); + vec_znx_rsh_inplace::<_, FFT64Spqlios>(basek, k, a, a_col, carry) } } -pub fn vec_znx_rsh_inplace_ref(basek: usize, k: usize, a: &mut A) -where - A: VecZnxToMut, -{ - let mut a: VecZnx<&mut [u8]> = a.to_mut(); - let n: usize = a.n(); - let cols: usize = a.cols(); - let size: usize = a.size(); - let steps: usize = k / basek; - - a.raw_mut().rotate_right(n * steps * cols); - (0..cols).for_each(|i| { - (0..steps).for_each(|j| { - a.zero_at(i, j); - }) - }); - - let k_rem: usize = k % basek; - - if k_rem != 0 { - let mut carry: Vec = vec![0i64; n]; // ALLOC (but small so OK) - let shift: usize = i64::BITS as usize - k_rem; - (0..cols).for_each(|i| { - carry.fill(0); - (steps..size).for_each(|j| { - izip!(carry.iter_mut(), a.at_mut(i, j).iter_mut()).for_each(|(ci, xi)| { - *xi += *ci << basek; - *ci = (*xi << shift) >> shift; - *xi = (*xi - *ci) >> k_rem; - }); - }); - }) - } -} - -unsafe impl VecZnxRotateImpl for FFT64 { +unsafe impl VecZnxRotateImpl for FFT64Spqlios { fn vec_znx_rotate_impl(_module: &Module, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, @@ -482,7 +606,8 @@ unsafe impl VecZnxRotateImpl for FFT64 { assert_eq!(res.n(), a.n()); } unsafe { - (0..a.size()).for_each(|j| { + let min_size = res.size().min(a.size()); + (0..min_size).for_each(|j| { znx::znx_rotate_i64( a.n() as u64, k, @@ -490,12 +615,28 @@ unsafe impl VecZnxRotateImpl for FFT64 { a.at_ptr(a_col, j), ); }); + + (min_size..res.size()).for_each(|j| { 
+ znx_zero_ref(res.at_mut(res_col, j)); + }) } } } -unsafe impl VecZnxRotateInplaceImpl for FFT64 { - fn vec_znx_rotate_inplace_impl(_module: &Module, k: i64, a: &mut A, a_col: usize) +unsafe impl VecZnxRotateInplaceTmpBytesImpl for FFT64Spqlios +where + Scratch: TakeSlice, +{ + fn vec_znx_rotate_inplace_tmp_bytes_impl(module: &Module) -> usize { + vec_znx_rotate_inplace_tmp_bytes(module.n()) + } +} + +unsafe impl VecZnxRotateInplaceImpl for FFT64Spqlios +where + Scratch: TakeSlice, +{ + fn vec_znx_rotate_inplace_impl(_module: &Module, k: i64, a: &mut A, a_col: usize, _scratch: &mut Scratch) where A: VecZnxToMut, { @@ -508,7 +649,7 @@ unsafe impl VecZnxRotateInplaceImpl for FFT64 { } } -unsafe impl VecZnxAutomorphismImpl for FFT64 { +unsafe impl VecZnxAutomorphismImpl for FFT64Spqlios { fn vec_znx_automorphism_impl(module: &Module, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, @@ -535,8 +676,14 @@ unsafe impl VecZnxAutomorphismImpl for FFT64 { } } -unsafe impl VecZnxAutomorphismInplaceImpl for FFT64 { - fn vec_znx_automorphism_inplace_impl(module: &Module, k: i64, a: &mut A, a_col: usize) +unsafe impl VecZnxAutomorphismInplaceTmpBytesImpl for FFT64Spqlios { + fn vec_znx_automorphism_inplace_tmp_bytes_impl(module: &Module) -> usize { + vec_znx_automorphism_inplace_tmp_bytes(module.n()) + } +} + +unsafe impl VecZnxAutomorphismInplaceImpl for FFT64Spqlios { + fn vec_znx_automorphism_inplace_impl(module: &Module, k: i64, a: &mut A, a_col: usize, _scratch: &mut Scratch) where A: VecZnxToMut, { @@ -564,7 +711,7 @@ unsafe impl VecZnxAutomorphismInplaceImpl for FFT64 { } } -unsafe impl VecZnxMulXpMinusOneImpl for FFT64 { +unsafe impl VecZnxMulXpMinusOneImpl for FFT64Spqlios { fn vec_znx_mul_xp_minus_one_impl(module: &Module, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, @@ -592,9 +739,20 @@ unsafe impl VecZnxMulXpMinusOneImpl for FFT64 { } } -unsafe impl VecZnxMulXpMinusOneInplaceImpl for FFT64 { - fn vec_znx_mul_xp_minus_one_inplace_impl(module: &Module, p: i64, res: &mut R, res_col: usize) - where +unsafe impl VecZnxMulXpMinusOneInplaceTmpBytesImpl for FFT64Spqlios { + fn vec_znx_mul_xp_minus_one_inplace_tmp_bytes_impl(module: &Module) -> usize { + vec_znx_mul_xp_minus_one_inplace_tmp_bytes(module.n()) + } +} + +unsafe impl VecZnxMulXpMinusOneInplaceImpl for FFT64Spqlios { + fn vec_znx_mul_xp_minus_one_inplace_impl( + module: &Module, + p: i64, + res: &mut R, + res_col: usize, + _scratch: &mut Scratch, + ) where R: VecZnxToMut, { let mut res: VecZnx<&mut [u8]> = res.to_mut(); @@ -617,15 +775,18 @@ unsafe impl VecZnxMulXpMinusOneInplaceImpl for FFT64 { } } -unsafe impl VecZnxSplitImpl for FFT64 +unsafe impl VecZnxSplitRingTmpBytesImpl for FFT64Spqlios { + fn vec_znx_split_ring_tmp_bytes_impl(module: &Module) -> usize { + vec_znx_split_ring_tmp_bytes(module.n()) + } +} + +unsafe impl VecZnxSplitRingImpl for FFT64Spqlios where - Self: TakeVecZnxImpl - + TakeVecZnxImpl - + VecZnxSwithcDegreeImpl - + VecZnxRotateImpl - + VecZnxRotateInplaceImpl, + Module: VecZnxSplitRingTmpBytes, + Scratch: TakeSlice, { - fn vec_znx_split_impl( + fn vec_znx_split_ring_impl( module: &Module, res: &mut [R], res_col: usize, @@ -636,287 +797,72 @@ where R: VecZnxToMut, A: VecZnxToRef, { - vec_znx_split_ref(module, res, res_col, a, a_col, scratch) + let (tmp, _) = scratch.take_slice(module.vec_znx_split_ring_tmp_bytes() / size_of::()); + vec_znx_split_ring::<_, _, FFT64Spqlios>(res, res_col, a, a_col, tmp); } } -pub fn vec_znx_split_ref( - module: 
&Module, - res: &mut [R], - res_col: usize, - a: &A, - a_col: usize, - scratch: &mut Scratch, -) where - B: Backend + TakeVecZnxImpl + VecZnxSwithcDegreeImpl + VecZnxRotateImpl + VecZnxRotateInplaceImpl, - R: VecZnxToMut, - A: VecZnxToRef, -{ - let a: VecZnx<&[u8]> = a.to_ref(); - - let (n_in, n_out) = (a.n(), res[0].to_mut().n()); - - let (mut buf, _) = scratch.take_vec_znx(n_in.max(n_out), 1, a.size()); - - debug_assert!( - n_out < n_in, - "invalid a: output ring degree should be smaller" - ); - res[1..].iter_mut().for_each(|bi| { - debug_assert_eq!( - bi.to_mut().n(), - n_out, - "invalid input a: all VecZnx must have the same degree" - ) - }); - - res.iter_mut().enumerate().for_each(|(i, bi)| { - if i == 0 { - module.vec_znx_switch_degree(bi, res_col, &a, a_col); - module.vec_znx_rotate(-1, &mut buf, 0, &a, a_col); - } else { - module.vec_znx_switch_degree(bi, res_col, &buf, a_col); - module.vec_znx_rotate_inplace(-1, &mut buf, a_col); - } - }) +unsafe impl VecZnxMergeRingsTmpBytesImpl for FFT64Spqlios { + fn vec_znx_merge_rings_tmp_bytes_impl(module: &Module) -> usize { + vec_znx_merge_rings_tmp_bytes(module.n()) + } } -unsafe impl VecZnxMergeImpl for FFT64 +unsafe impl VecZnxMergeRingsImpl for FFT64Spqlios where - Self: VecZnxSwithcDegreeImpl + VecZnxRotateInplaceImpl, + Module: VecZnxMergeRingsTmpBytes, { - fn vec_znx_merge_impl(module: &Module, res: &mut R, res_col: usize, a: &[A], a_col: usize) - where + fn vec_znx_merge_rings_impl( + module: &Module, + res: &mut R, + res_col: usize, + a: &[A], + a_col: usize, + scratch: &mut Scratch, + ) where R: VecZnxToMut, A: VecZnxToRef, { - vec_znx_merge_ref(module, res, res_col, a, a_col) + let (tmp, _) = scratch.take_slice(module.vec_znx_merge_rings_tmp_bytes() / size_of::()); + vec_znx_merge_rings::<_, _, FFT64Spqlios>(res, res_col, a, a_col, tmp); } } -pub fn vec_znx_merge_ref(module: &Module, res: &mut R, res_col: usize, a: &[A], a_col: usize) -where - B: Backend + VecZnxSwithcDegreeImpl + VecZnxRotateInplaceImpl, - R: VecZnxToMut, - A: VecZnxToRef, -{ - let mut res: VecZnx<&mut [u8]> = res.to_mut(); - - let (n_in, n_out) = (res.n(), a[0].to_ref().n()); - - debug_assert!( - n_out < n_in, - "invalid a: output ring degree should be smaller" - ); - a[1..].iter().for_each(|ai| { - debug_assert_eq!( - ai.to_ref().n(), - n_out, - "invalid input a: all VecZnx must have the same degree" - ) - }); - - a.iter().for_each(|ai| { - module.vec_znx_switch_degree(&mut res, res_col, ai, a_col); - module.vec_znx_rotate_inplace(-1, &mut res, res_col); - }); - - module.vec_znx_rotate_inplace(a.len() as i64, &mut res, res_col); -} - -unsafe impl VecZnxSwithcDegreeImpl for FFT64 +unsafe impl VecZnxSwitchRingImpl for FFT64Spqlios where Self: VecZnxCopyImpl, { - fn vec_znx_switch_degree_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_switch_ring_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, A: VecZnxToRef, { - vec_znx_switch_degree_ref(module, res, res_col, a, a_col) + vec_znx_switch_ring::<_, _, FFT64Spqlios>(res, res_col, a, a_col); } } -pub fn vec_znx_switch_degree_ref(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) -where - B: Backend + VecZnxCopyImpl, - R: VecZnxToMut, - A: VecZnxToRef, -{ - let a: VecZnx<&[u8]> = a.to_ref(); - let mut res: VecZnx<&mut [u8]> = res.to_mut(); - - let (n_in, n_out) = (a.n(), res.n()); - - if n_in == n_out { - module.vec_znx_copy(&mut res, res_col, &a, a_col); - return; - } - - let (gap_in, gap_out): (usize, usize); - if 
n_in > n_out { - (gap_in, gap_out) = (n_in / n_out, 1) - } else { - (gap_in, gap_out) = (1, n_out / n_in); - res.zero(); - } - - let size: usize = a.size().min(res.size()); - - (0..size).for_each(|i| { - izip!( - a.at(a_col, i).iter().step_by(gap_in), - res.at_mut(res_col, i).iter_mut().step_by(gap_out) - ) - .for_each(|(x_in, x_out)| *x_out = *x_in); - }); -} - -unsafe impl VecZnxCopyImpl for FFT64 { +unsafe impl VecZnxCopyImpl for FFT64Spqlios { fn vec_znx_copy_impl(_module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, A: VecZnxToRef, { - vec_znx_copy_ref(res, res_col, a, a_col) + vec_znx_copy::<_, _, FFT64Spqlios>(res, res_col, a, a_col) } } -pub fn vec_znx_copy_ref(res: &mut R, res_col: usize, a: &A, a_col: usize) -where - R: VecZnxToMut, - A: VecZnxToRef, -{ - let mut res_mut: VecZnx<&mut [u8]> = res.to_mut(); - let a_ref: VecZnx<&[u8]> = a.to_ref(); - - let min_size: usize = res_mut.size().min(a_ref.size()); - - (0..min_size).for_each(|j| { - res_mut - .at_mut(res_col, j) - .copy_from_slice(a_ref.at(a_col, j)); - }); - (min_size..res_mut.size()).for_each(|j| { - res_mut.zero_at(res_col, j); - }) -} - -unsafe impl VecZnxFillUniformImpl for FFT64 { - fn vec_znx_fill_uniform_impl( - _module: &Module, - basek: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - ) where +unsafe impl VecZnxFillUniformImpl for FFT64Spqlios { + fn vec_znx_fill_uniform_impl(_module: &Module, basek: usize, res: &mut R, res_col: usize, source: &mut Source) + where R: VecZnxToMut, { - let mut a: VecZnx<&mut [u8]> = res.to_mut(); - let base2k: u64 = 1 << basek; - let mask: u64 = base2k - 1; - let base2k_half: i64 = (base2k >> 1) as i64; - (0..k.div_ceil(basek)).for_each(|j| { - a.at_mut(res_col, j) - .iter_mut() - .for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half); - }) + vec_znx_fill_uniform_ref(basek, res, res_col, source) } } -unsafe impl VecZnxFillDistF64Impl for FFT64 { - fn vec_znx_fill_dist_f64_impl>( - _module: &Module, - basek: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - dist: D, - bound: f64, - ) where - R: VecZnxToMut, - { - let mut a: VecZnx<&mut [u8]> = res.to_mut(); - assert!( - (bound.log2().ceil() as i64) < 64, - "invalid bound: ceil(log2(bound))={} > 63", - (bound.log2().ceil() as i64) - ); - - let limb: usize = k.div_ceil(basek) - 1; - let basek_rem: usize = (limb + 1) * basek - k; - - if basek_rem != 0 { - a.at_mut(res_col, limb).iter_mut().for_each(|a| { - let mut dist_f64: f64 = dist.sample(source); - while dist_f64.abs() > bound { - dist_f64 = dist.sample(source) - } - *a = (dist_f64.round() as i64) << basek_rem; - }); - } else { - a.at_mut(res_col, limb).iter_mut().for_each(|a| { - let mut dist_f64: f64 = dist.sample(source); - while dist_f64.abs() > bound { - dist_f64 = dist.sample(source) - } - *a = dist_f64.round() as i64 - }); - } - } -} - -unsafe impl VecZnxAddDistF64Impl for FFT64 { - fn vec_znx_add_dist_f64_impl>( - _module: &Module, - basek: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - dist: D, - bound: f64, - ) where - R: VecZnxToMut, - { - let mut a: VecZnx<&mut [u8]> = res.to_mut(); - assert!( - (bound.log2().ceil() as i64) < 64, - "invalid bound: ceil(log2(bound))={} > 63", - (bound.log2().ceil() as i64) - ); - - let limb: usize = k.div_ceil(basek) - 1; - let basek_rem: usize = (limb + 1) * basek - k; - - if basek_rem != 0 { - a.at_mut(res_col, limb).iter_mut().for_each(|a| { - let mut dist_f64: f64 = dist.sample(source); - while 
dist_f64.abs() > bound { - dist_f64 = dist.sample(source) - } - *a += (dist_f64.round() as i64) << basek_rem; - }); - } else { - a.at_mut(res_col, limb).iter_mut().for_each(|a| { - let mut dist_f64: f64 = dist.sample(source); - while dist_f64.abs() > bound { - dist_f64 = dist.sample(source) - } - *a += dist_f64.round() as i64 - }); - } - } -} - -unsafe impl VecZnxFillNormalImpl for FFT64 -where - Self: VecZnxFillDistF64Impl, -{ +unsafe impl VecZnxFillNormalImpl for FFT64Spqlios { fn vec_znx_fill_normal_impl( - module: &Module, + _module: &Module, basek: usize, res: &mut R, res_col: usize, @@ -927,24 +873,13 @@ where ) where R: VecZnxToMut, { - module.vec_znx_fill_dist_f64( - basek, - res, - res_col, - k, - source, - Normal::new(0.0, sigma).unwrap(), - bound, - ); + vec_znx_fill_normal_ref(basek, res, res_col, k, sigma, bound, source); } } -unsafe impl VecZnxAddNormalImpl for FFT64 -where - Self: VecZnxAddDistF64Impl, -{ +unsafe impl VecZnxAddNormalImpl for FFT64Spqlios { fn vec_znx_add_normal_impl( - module: &Module, + _module: &Module, basek: usize, res: &mut R, res_col: usize, @@ -955,14 +890,6 @@ where ) where R: VecZnxToMut, { - module.vec_znx_add_dist_f64( - basek, - res, - res_col, - k, - source, - Normal::new(0.0, sigma).unwrap(), - bound, - ); + vec_znx_add_normal_ref(basek, res, res_col, k, sigma, bound, source); } } diff --git a/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_big.rs b/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_big.rs index 5338e8d..5cf8efa 100644 --- a/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_big.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_big.rs @@ -1,170 +1,98 @@ -use rand_distr::{Distribution, Normal}; - -use crate::cpu_spqlios::{FFT64, ffi::vec_znx}; +use crate::cpu_spqlios::{FFT64Spqlios, ffi::vec_znx}; use poulpy_hal::{ - api::{TakeSlice, VecZnxBigAddDistF64, VecZnxBigFillDistF64, VecZnxBigNormalizeTmpBytes}, + api::{TakeSlice, VecZnxBigNormalizeTmpBytes}, layouts::{ Backend, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, }, oep::{ - TakeSliceImpl, VecZnxBigAddDistF64Impl, VecZnxBigAddImpl, VecZnxBigAddInplaceImpl, VecZnxBigAddNormalImpl, - VecZnxBigAddSmallImpl, VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, - VecZnxBigAutomorphismImpl, VecZnxBigAutomorphismInplaceImpl, VecZnxBigFillDistF64Impl, VecZnxBigFillNormalImpl, - VecZnxBigFromBytesImpl, VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl, VecZnxBigNormalizeTmpBytesImpl, - VecZnxBigSubABInplaceImpl, VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl, VecZnxBigSubSmallAImpl, - VecZnxBigSubSmallAInplaceImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl, + TakeSliceImpl, VecZnxBigAddImpl, VecZnxBigAddInplaceImpl, VecZnxBigAddNormalImpl, VecZnxBigAddSmallImpl, + VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl, + VecZnxBigAutomorphismInplaceImpl, VecZnxBigAutomorphismInplaceTmpBytesImpl, VecZnxBigFromBytesImpl, + VecZnxBigFromSmallImpl, VecZnxBigNegateImpl, VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl, + VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubABInplaceImpl, VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl, + VecZnxBigSubSmallAImpl, VecZnxBigSubSmallAInplaceImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl, + }, + reference::{ + vec_znx::vec_znx_add_normal_ref, + znx::{znx_copy_ref, znx_zero_ref}, }, source::Source, }; -unsafe impl VecZnxBigAllocBytesImpl for FFT64 { +unsafe impl 
VecZnxBigAllocBytesImpl for FFT64Spqlios { fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize { Self::layout_big_word_count() * n * cols * size * size_of::() } } -unsafe impl VecZnxBigAllocImpl for FFT64 { +unsafe impl VecZnxBigAllocImpl for FFT64Spqlios { fn vec_znx_big_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxBigOwned { VecZnxBig::alloc(n, cols, size) } } -unsafe impl VecZnxBigFromBytesImpl for FFT64 { +unsafe impl VecZnxBigFromBytesImpl for FFT64Spqlios { fn vec_znx_big_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec) -> VecZnxBigOwned { VecZnxBig::from_bytes(n, cols, size, bytes) } } -unsafe impl VecZnxBigAddDistF64Impl for FFT64 { - fn add_dist_f64_impl, D: Distribution>( - _module: &Module, - basek: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - dist: D, - bound: f64, - ) { +unsafe impl VecZnxBigFromSmallImpl for FFT64Spqlios { + fn vec_znx_big_from_small_impl(res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef, + { let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut(); - assert!( - (bound.log2().ceil() as i64) < 64, - "invalid bound: ceil(log2(bound))={} > 63", - (bound.log2().ceil() as i64) - ); + let a: VecZnx<&[u8]> = a.to_ref(); - let limb: usize = k.div_ceil(basek) - 1; - let basek_rem: usize = (limb + 1) * basek - k; + #[cfg(debug_assertions)] + { + assert_eq!(res.n(), a.n()); + } - if basek_rem != 0 { - res.at_mut(res_col, limb).iter_mut().for_each(|x| { - let mut dist_f64: f64 = dist.sample(source); - while dist_f64.abs() > bound { - dist_f64 = dist.sample(source) - } - *x += (dist_f64.round() as i64) << basek_rem; - }); - } else { - res.at_mut(res_col, limb).iter_mut().for_each(|x| { - let mut dist_f64: f64 = dist.sample(source); - while dist_f64.abs() > bound { - dist_f64 = dist.sample(source) - } - *x += dist_f64.round() as i64 - }); + let res_size: usize = res.size(); + let a_size: usize = a.size(); + + let min_size: usize = res_size.min(a_size); + + for j in 0..min_size { + znx_copy_ref(res.at_mut(res_col, j), a.at(a_col, j)); + } + + for j in min_size..res_size { + znx_zero_ref(res.at_mut(res_col, j)); } } } -unsafe impl VecZnxBigAddNormalImpl for FFT64 { +unsafe impl VecZnxBigAddNormalImpl for FFT64Spqlios { fn add_normal_impl>( - module: &Module, - basek: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - sigma: f64, - bound: f64, - ) { - module.vec_znx_big_add_dist_f64( - basek, - res, - res_col, - k, - source, - Normal::new(0.0, sigma).unwrap(), - bound, - ); - } -} - -unsafe impl VecZnxBigFillDistF64Impl for FFT64 { - fn fill_dist_f64_impl, D: Distribution>( _module: &Module, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source, - dist: D, - bound: f64, - ) { - let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut(); - assert!( - (bound.log2().ceil() as i64) < 64, - "invalid bound: ceil(log2(bound))={} > 63", - (bound.log2().ceil() as i64) - ); - - let limb: usize = k.div_ceil(basek) - 1; - let basek_rem: usize = (limb + 1) * basek - k; - - if basek_rem != 0 { - res.at_mut(res_col, limb).iter_mut().for_each(|x| { - let mut dist_f64: f64 = dist.sample(source); - while dist_f64.abs() > bound { - dist_f64 = dist.sample(source) - } - *x = (dist_f64.round() as i64) << basek_rem; - }); - } else { - res.at_mut(res_col, limb).iter_mut().for_each(|x| { - let mut dist_f64: f64 = dist.sample(source); - while dist_f64.abs() > bound { - dist_f64 = dist.sample(source) - } - *x = dist_f64.round() as i64 - 
}); - } - } -} - -unsafe impl VecZnxBigFillNormalImpl for FFT64 { - fn fill_normal_impl>( - module: &Module, - basek: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, sigma: f64, bound: f64, ) { - module.vec_znx_big_fill_dist_f64( - basek, - res, - res_col, - k, - source, - Normal::new(0.0, sigma).unwrap(), - bound, - ); + let res: VecZnxBig<&mut [u8], FFT64Spqlios> = res.to_mut(); + + let mut res_znx: VecZnx<&mut [u8]> = VecZnx { + data: res.data, + n: res.n, + cols: res.cols, + size: res.size, + max_size: res.max_size, + }; + + vec_znx_add_normal_ref(basek, &mut res_znx, res_col, k, sigma, bound, source); } } -unsafe impl VecZnxBigAddImpl for FFT64 { +unsafe impl VecZnxBigAddImpl for FFT64Spqlios { /// Adds `a` to `b` and stores the result on `c`. fn vec_znx_big_add_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) where @@ -199,7 +127,7 @@ unsafe impl VecZnxBigAddImpl for FFT64 { } } -unsafe impl VecZnxBigAddInplaceImpl for FFT64 { +unsafe impl VecZnxBigAddInplaceImpl for FFT64Spqlios { /// Adds `a` to `b` and stores the result on `b`. fn vec_znx_big_add_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where @@ -230,7 +158,7 @@ unsafe impl VecZnxBigAddInplaceImpl for FFT64 { } } -unsafe impl VecZnxBigAddSmallImpl for FFT64 { +unsafe impl VecZnxBigAddSmallImpl for FFT64Spqlios { /// Adds `a` to `b` and stores the result on `c`. fn vec_znx_big_add_small_impl( module: &Module, @@ -272,7 +200,7 @@ unsafe impl VecZnxBigAddSmallImpl for FFT64 { } } -unsafe impl VecZnxBigAddSmallInplaceImpl for FFT64 { +unsafe impl VecZnxBigAddSmallInplaceImpl for FFT64Spqlios { /// Adds `a` to `b` and stores the result on `b`. fn vec_znx_big_add_small_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where @@ -303,7 +231,7 @@ unsafe impl VecZnxBigAddSmallInplaceImpl for FFT64 { } } -unsafe impl VecZnxBigSubImpl for FFT64 { +unsafe impl VecZnxBigSubImpl for FFT64Spqlios { /// Subtracts `a` to `b` and stores the result on `c`. fn vec_znx_big_sub_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) where @@ -338,7 +266,7 @@ unsafe impl VecZnxBigSubImpl for FFT64 { } } -unsafe impl VecZnxBigSubABInplaceImpl for FFT64 { +unsafe impl VecZnxBigSubABInplaceImpl for FFT64Spqlios { /// Subtracts `a` from `b` and stores the result on `b`. fn vec_znx_big_sub_ab_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where @@ -369,7 +297,7 @@ unsafe impl VecZnxBigSubABInplaceImpl for FFT64 { } } -unsafe impl VecZnxBigSubBAInplaceImpl for FFT64 { +unsafe impl VecZnxBigSubBAInplaceImpl for FFT64Spqlios { /// Subtracts `b` from `a` and stores the result on `b`. fn vec_znx_big_sub_ba_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where @@ -400,7 +328,7 @@ unsafe impl VecZnxBigSubBAInplaceImpl for FFT64 { } } -unsafe impl VecZnxBigSubSmallAImpl for FFT64 { +unsafe impl VecZnxBigSubSmallAImpl for FFT64Spqlios { /// Subtracts `b` from `a` and stores the result on `c`. fn vec_znx_big_sub_small_a_impl( module: &Module, @@ -442,7 +370,7 @@ unsafe impl VecZnxBigSubSmallAImpl for FFT64 { } } -unsafe impl VecZnxBigSubSmallAInplaceImpl for FFT64 { +unsafe impl VecZnxBigSubSmallAInplaceImpl for FFT64Spqlios { /// Subtracts `a` from `res` and stores the result on `res`. 
fn vec_znx_big_sub_small_a_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where @@ -473,7 +401,7 @@ unsafe impl VecZnxBigSubSmallAInplaceImpl for FFT64 { } } -unsafe impl VecZnxBigSubSmallBImpl for FFT64 { +unsafe impl VecZnxBigSubSmallBImpl for FFT64Spqlios { /// Subtracts `b` from `a` and stores the result on `c`. fn vec_znx_big_sub_small_b_impl( module: &Module, @@ -515,7 +443,7 @@ unsafe impl VecZnxBigSubSmallBImpl for FFT64 { } } -unsafe impl VecZnxBigSubSmallBInplaceImpl for FFT64 { +unsafe impl VecZnxBigSubSmallBInplaceImpl for FFT64Spqlios { /// Subtracts `res` from `a` and stores the result on `res`. fn vec_znx_big_sub_small_b_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where @@ -546,7 +474,29 @@ unsafe impl VecZnxBigSubSmallBInplaceImpl for FFT64 { } } -unsafe impl VecZnxBigNegateInplaceImpl for FFT64 { +unsafe impl VecZnxBigNegateImpl for FFT64Spqlios { + fn vec_znx_big_negate_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + { + let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut(); + let a: VecZnxBig<&[u8], Self> = a.to_ref(); + unsafe { + vec_znx::vec_znx_negate( + module.ptr(), + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) + } + } +} + +unsafe impl VecZnxBigNegateInplaceImpl for FFT64Spqlios { fn vec_znx_big_negate_inplace_impl(module: &Module, a: &mut A, a_col: usize) where A: VecZnxBigToMut, @@ -566,13 +516,13 @@ unsafe impl VecZnxBigNegateInplaceImpl for FFT64 { } } -unsafe impl VecZnxBigNormalizeTmpBytesImpl for FFT64 { +unsafe impl VecZnxBigNormalizeTmpBytesImpl for FFT64Spqlios { fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module) -> usize { unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(module.ptr()) as usize } } } -unsafe impl VecZnxBigNormalizeImpl for FFT64 +unsafe impl VecZnxBigNormalizeImpl for FFT64Spqlios where Self: TakeSliceImpl, { @@ -613,7 +563,7 @@ where } } -unsafe impl VecZnxBigAutomorphismImpl for FFT64 { +unsafe impl VecZnxBigAutomorphismImpl for FFT64Spqlios { /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`. fn vec_znx_big_automorphism_impl(module: &Module, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) where @@ -642,10 +592,21 @@ unsafe impl VecZnxBigAutomorphismImpl for FFT64 { } } -unsafe impl VecZnxBigAutomorphismInplaceImpl for FFT64 { +unsafe impl VecZnxBigAutomorphismInplaceTmpBytesImpl for FFT64Spqlios { + fn vec_znx_big_automorphism_inplace_tmp_bytes_impl(_module: &Module) -> usize { + 0 + } +} + +unsafe impl VecZnxBigAutomorphismInplaceImpl for FFT64Spqlios { /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`. 
- fn vec_znx_big_automorphism_inplace_impl(module: &Module, k: i64, a: &mut A, a_col: usize) - where + fn vec_znx_big_automorphism_inplace_impl( + module: &Module, + k: i64, + a: &mut A, + a_col: usize, + _scratch: &mut Scratch, + ) where A: VecZnxBigToMut, { let mut a: VecZnxBig<&mut [u8], Self> = a.to_mut(); diff --git a/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_dft.rs b/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_dft.rs index 3fe9a60..8e72bf0 100644 --- a/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_dft.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_dft.rs @@ -1,60 +1,73 @@ use poulpy_hal::{ - api::{TakeSlice, VecZnxIDFTTmpBytes}, + api::{TakeSlice, VecZnxIdftApplyTmpBytes}, layouts::{ Backend, Data, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut, - VecZnxDftToRef, VecZnxToRef, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero, + VecZnxDftToRef, VecZnxToRef, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, }, oep::{ - DFTImpl, IDFTConsumeImpl, IDFTImpl, IDFTTmpAImpl, VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, - VecZnxDftAllocImpl, VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl, - VecZnxDftSubImpl, VecZnxDftZeroImpl, VecZnxIDFTTmpBytesImpl, + VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftApplyImpl, + VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl, VecZnxDftSubImpl, + VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl, VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl, + }, + reference::{ + fft64::{ + reim::{ReimCopy, ReimZero, reim_copy_ref, reim_negate_inplace_ref, reim_negate_ref, reim_zero_ref}, + vec_znx_dft::vec_znx_dft_copy, + }, + znx::znx_zero_ref, }, }; use crate::cpu_spqlios::{ - FFT64, + FFT64Spqlios, ffi::{vec_znx_big, vec_znx_dft}, }; -unsafe impl VecZnxDftFromBytesImpl for FFT64 { +unsafe impl VecZnxDftFromBytesImpl for FFT64Spqlios { fn vec_znx_dft_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec) -> VecZnxDftOwned { - VecZnxDft::, FFT64>::from_bytes(n, cols, size, bytes) + VecZnxDft::, Self>::from_bytes(n, cols, size, bytes) } } -unsafe impl VecZnxDftAllocBytesImpl for FFT64 { +unsafe impl VecZnxDftAllocBytesImpl for FFT64Spqlios { fn vec_znx_dft_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize { - FFT64::layout_prep_word_count() * n * cols * size * size_of::<::ScalarPrep>() + Self::layout_prep_word_count() * n * cols * size * size_of::<::ScalarPrep>() } } -unsafe impl VecZnxDftAllocImpl for FFT64 { +unsafe impl VecZnxDftAllocImpl for FFT64Spqlios { fn vec_znx_dft_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxDftOwned { VecZnxDftOwned::alloc(n, cols, size) } } -unsafe impl VecZnxIDFTTmpBytesImpl for FFT64 { - fn vec_znx_idft_tmp_bytes_impl(module: &Module) -> usize { +unsafe impl VecZnxIdftApplyTmpBytesImpl for FFT64Spqlios { + fn vec_znx_idft_apply_tmp_bytes_impl(module: &Module) -> usize { unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(module.ptr()) as usize } } } -unsafe impl IDFTImpl for FFT64 { - fn idft_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) - where +unsafe impl VecZnxIdftApplyImpl for FFT64Spqlios { + fn vec_znx_idft_apply_impl( + module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + scratch: &mut Scratch, + ) where R: VecZnxBigToMut, A: VecZnxDftToRef, { - let mut res: VecZnxBig<&mut [u8], FFT64> = 
res.to_mut(); - let a: VecZnxDft<&[u8], FFT64> = a.to_ref(); + let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut(); + let a: VecZnxDft<&[u8], Self> = a.to_ref(); #[cfg(debug_assertions)] { assert_eq!(res.n(), a.n()) } - let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_idft_tmp_bytes()); + let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_idft_apply_tmp_bytes()); let min_size: usize = res.size().min(a.size()); @@ -69,47 +82,43 @@ unsafe impl IDFTImpl for FFT64 { tmp_bytes.as_mut_ptr(), ) }); - (min_size..res.size()).for_each(|j| { - res.zero_at(res_col, j); - }); + (min_size..res.size()).for_each(|j| znx_zero_ref(res.at_mut(res_col, j))); } } } -unsafe impl IDFTTmpAImpl for FFT64 { - fn idft_tmp_a_impl(module: &Module, res: &mut R, res_col: usize, a: &mut A, a_col: usize) +unsafe impl VecZnxIdftApplyTmpAImpl for FFT64Spqlios { + fn vec_znx_idft_apply_tmpa_impl(module: &Module, res: &mut R, res_col: usize, a: &mut A, a_col: usize) where R: VecZnxBigToMut, A: VecZnxDftToMut, { - let mut res_mut: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); - let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut(); + let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut(); + let mut a_mut: VecZnxDft<&mut [u8], Self> = a.to_mut(); - let min_size: usize = res_mut.size().min(a_mut.size()); + let min_size: usize = res.size().min(a_mut.size()); unsafe { (0..min_size).for_each(|j| { vec_znx_dft::vec_znx_idft_tmp_a( module.ptr(), - res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t, + res.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t, 1_u64, a_mut.at_mut_ptr(a_col, j) as *mut vec_znx_dft::vec_znx_dft_t, 1_u64, ) }); - (min_size..res_mut.size()).for_each(|j| { - res_mut.zero_at(res_col, j); - }) + (min_size..res.size()).for_each(|j| znx_zero_ref(res.at_mut(res_col, j))) } } } -unsafe impl IDFTConsumeImpl for FFT64 { - fn idft_consume_impl(module: &Module, mut a: VecZnxDft) -> VecZnxBig +unsafe impl VecZnxIdftApplyConsumeImpl for FFT64Spqlios { + fn vec_znx_idft_apply_consume_impl(module: &Module, mut a: VecZnxDft) -> VecZnxBig where - VecZnxDft: VecZnxDftToMut, + VecZnxDft: VecZnxDftToMut, { - let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut(); + let mut a_mut: VecZnxDft<&mut [u8], Self> = a.to_mut(); unsafe { // Rev col and rows because ZnxDft.sl() >= ZnxBig.sl() @@ -130,89 +139,129 @@ unsafe impl IDFTConsumeImpl for FFT64 { } } -unsafe impl DFTImpl for FFT64 { - fn dft_impl(module: &Module, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize) - where +unsafe impl VecZnxDftApplyImpl for FFT64Spqlios { + fn vec_znx_dft_apply_impl( + module: &Module, + step: usize, + offset: usize, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + ) where R: VecZnxDftToMut, A: VecZnxToRef, { - let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); - let a_ref: VecZnx<&[u8]> = a.to_ref(); - let steps: usize = a_ref.size().div_ceil(step); - let min_steps: usize = res_mut.size().min(steps); + let mut res: VecZnxDft<&mut [u8], Self> = res.to_mut(); + let a: VecZnx<&[u8]> = a.to_ref(); + let steps: usize = a.size().div_ceil(step); + let min_steps: usize = res.size().min(steps); unsafe { (0..min_steps).for_each(|j| { let limb: usize = offset + j * step; - if limb < a_ref.size() { + if limb < a.size() { vec_znx_dft::vec_znx_dft( module.ptr(), - res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, + res.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, 1_u64, - a_ref.at_ptr(a_col, limb), + a.at_ptr(a_col, limb), 1_u64, - a_ref.sl() as 
u64, + a.sl() as u64, ) } }); - (min_steps..res_mut.size()).for_each(|j| { - res_mut.zero_at(res_col, j); - }); + (min_steps..res.size()).for_each(|j| reim_zero_ref(res.at_mut(res_col, j))); } } } -unsafe impl VecZnxDftAddImpl for FFT64 { +unsafe impl VecZnxDftAddImpl for FFT64Spqlios { fn vec_znx_dft_add_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize) where R: VecZnxDftToMut, A: VecZnxDftToRef, D: VecZnxDftToRef, { - let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); - let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); - let b_ref: VecZnxDft<&[u8], FFT64> = b.to_ref(); + let mut res: VecZnxDft<&mut [u8], Self> = res.to_mut(); + let a: VecZnxDft<&[u8], Self> = a.to_ref(); + let b: VecZnxDft<&[u8], Self> = b.to_ref(); - let min_size: usize = res_mut.size().min(a_ref.size()).min(b_ref.size()); + let res_size: usize = res.size(); + let a_size: usize = a.size(); + let b_size: usize = b.size(); unsafe { - (0..min_size).for_each(|j| { - vec_znx_dft::vec_dft_add( - module.ptr(), - res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, - 1, - a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, - 1, - b_ref.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t, - 1, - ); - }); + if a_size <= b_size { + let sum_size: usize = a_size.min(res_size); + let cpy_size: usize = b_size.min(res_size); + + (0..sum_size).for_each(|j| { + vec_znx_dft::vec_dft_add( + module.ptr(), + res.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, + 1, + a.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + b.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + ); + }); + + for j in sum_size..cpy_size { + reim_copy_ref(res.at_mut(res_col, j), b.at(b_col, j)); + } + + for j in cpy_size..res_size { + reim_zero_ref(res.at_mut(res_col, j)); + } + } else { + let sum_size: usize = b_size.min(res_size); + let cpy_size: usize = a_size.min(res_size); + + (0..sum_size).for_each(|j| { + vec_znx_dft::vec_dft_add( + module.ptr(), + res.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, + 1, + a.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + b.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + ); + }); + + for j in sum_size..cpy_size { + reim_copy_ref(res.at_mut(res_col, j), a.at(a_col, j)); + } + + for j in cpy_size..res_size { + reim_zero_ref(res.at_mut(res_col, j)); + } + } } - (min_size..res_mut.size()).for_each(|j| { - res_mut.zero_at(res_col, j); - }) } } -unsafe impl VecZnxDftAddInplaceImpl for FFT64 { +unsafe impl VecZnxDftAddInplaceImpl for FFT64Spqlios { fn vec_znx_dft_add_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxDftToMut, A: VecZnxDftToRef, { - let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); - let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); + let mut res: VecZnxDft<&mut [u8], Self> = res.to_mut(); + let a: VecZnxDft<&[u8], Self> = a.to_ref(); - let min_size: usize = res_mut.size().min(a_ref.size()); + let min_size: usize = res.size().min(a.size()); unsafe { (0..min_size).for_each(|j| { vec_znx_dft::vec_dft_add( module.ptr(), - res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, + res.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, 1, - res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t, + res.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t, 1, - a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, + a.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, 1, ); }); @@
-220,58 +269,93 @@ unsafe impl VecZnxDftAddInplaceImpl for FFT64 { } } -unsafe impl VecZnxDftSubImpl for FFT64 { +unsafe impl VecZnxDftSubImpl for FFT64Spqlios { fn vec_znx_dft_sub_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize) where R: VecZnxDftToMut, A: VecZnxDftToRef, D: VecZnxDftToRef, { - let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); - let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); - let b_ref: VecZnxDft<&[u8], FFT64> = b.to_ref(); - - let min_size: usize = res_mut.size().min(a_ref.size()).min(b_ref.size()); + let mut res: VecZnxDft<&mut [u8], Self> = res.to_mut(); + let a: VecZnxDft<&[u8], Self> = a.to_ref(); + let b: VecZnxDft<&[u8], Self> = b.to_ref(); unsafe { - (0..min_size).for_each(|j| { - vec_znx_dft::vec_dft_sub( - module.ptr(), - res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, - 1, - a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, - 1, - b_ref.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t, - 1, - ); - }); + let res_size: usize = res.size(); + let a_size: usize = a.size(); + let b_size: usize = b.size(); + + if a_size <= b_size { + let sum_size: usize = a_size.min(res_size); + let cpy_size: usize = b_size.min(res_size); + + (0..sum_size).for_each(|j| { + vec_znx_dft::vec_dft_sub( + module.ptr(), + res.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, + 1, + a.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + b.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + ); + }); + + for j in sum_size..cpy_size { + reim_negate_ref(res.at_mut(res_col, j), b.at(b_col, j)); + } + + for j in cpy_size..res_size { + reim_zero_ref(res.at_mut(res_col, j)); + } + } else { + let sum_size: usize = b_size.min(res_size); + let cpy_size: usize = a_size.min(res_size); + + (0..sum_size).for_each(|j| { + vec_znx_dft::vec_dft_sub( + module.ptr(), + res.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, + 1, + a.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + b.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + ); + }); + + for j in sum_size..cpy_size { + reim_copy_ref(res.at_mut(res_col, j), a.at(a_col, j)); + } + + for j in cpy_size..res_size { + reim_zero_ref(res.at_mut(res_col, j)); + } + } } - (min_size..res_mut.size()).for_each(|j| { - res_mut.zero_at(res_col, j); - }) } } -unsafe impl VecZnxDftSubABInplaceImpl for FFT64 { +unsafe impl VecZnxDftSubABInplaceImpl for FFT64Spqlios { fn vec_znx_dft_sub_ab_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxDftToMut, A: VecZnxDftToRef, { - let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); - let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); + let mut res: VecZnxDft<&mut [u8], Self> = res.to_mut(); + let a: VecZnxDft<&[u8], Self> = a.to_ref(); - let min_size: usize = res_mut.size().min(a_ref.size()); + let min_size: usize = res.size().min(a.size()); unsafe { (0..min_size).for_each(|j| { vec_znx_dft::vec_dft_sub( module.ptr(), - res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, + res.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, 1, - res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t, + res.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t, 1, - a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, + a.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, 1, ); }); @@ -279,34 +363,38 @@ unsafe impl VecZnxDftSubABInplaceImpl for FFT64 { } } -unsafe impl VecZnxDftSubBAInplaceImpl for 
FFT64 { +unsafe impl VecZnxDftSubBAInplaceImpl for FFT64Spqlios { fn vec_znx_dft_sub_ba_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxDftToMut, A: VecZnxDftToRef, { - let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); - let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); + let mut res: VecZnxDft<&mut [u8], Self> = res.to_mut(); + let a: VecZnxDft<&[u8], Self> = a.to_ref(); - let min_size: usize = res_mut.size().min(a_ref.size()); + let min_size: usize = res.size().min(a.size()); unsafe { (0..min_size).for_each(|j| { vec_znx_dft::vec_dft_sub( module.ptr(), - res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, + res.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, 1, - a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, + a.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, 1, - res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t, + res.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t, 1, ); }); + + for j in min_size..res.size() { + reim_negate_inplace_ref(res.at_mut(res_col, j)); + } } } } -unsafe impl VecZnxDftCopyImpl for FFT64 { +unsafe impl VecZnxDftCopyImpl for FFT64Spqlios { fn vec_znx_dft_copy_impl( _module: &Module, step: usize, @@ -319,27 +407,25 @@ unsafe impl VecZnxDftCopyImpl for FFT64 { R: VecZnxDftToMut, A: VecZnxDftToRef, { - let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); - let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); - - let steps: usize = a_ref.size().div_ceil(step); - let min_steps: usize = res_mut.size().min(steps); - - (0..min_steps).for_each(|j| { - let limb: usize = offset + j * step; - if limb < a_ref.size() { - res_mut - .at_mut(res_col, j) - .copy_from_slice(a_ref.at(a_col, limb)); - } - }); - (min_steps..res_mut.size()).for_each(|j| { - res_mut.zero_at(res_col, j); - }) + vec_znx_dft_copy(step, offset, res, res_col, a, a_col); } } -unsafe impl VecZnxDftZeroImpl for FFT64 { +impl ReimCopy for FFT64Spqlios { + #[inline(always)] + fn reim_copy(res: &mut [f64], a: &[f64]) { + reim_copy_ref(res, a); + } +} + +impl ReimZero for FFT64Spqlios { + #[inline(always)] + fn reim_zero(res: &mut [f64]) { + reim_zero_ref(res); + } +} + +unsafe impl VecZnxDftZeroImpl for FFT64Spqlios { fn vec_znx_dft_zero_impl(_module: &Module, res: &mut R) where R: VecZnxDftToMut, diff --git a/poulpy-backend/src/cpu_spqlios/fft64/vmp_pmat.rs b/poulpy-backend/src/cpu_spqlios/fft64/vmp_pmat.rs index e1a405f..ca64992 100644 --- a/poulpy-backend/src/cpu_spqlios/fft64/vmp_pmat.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/vmp_pmat.rs @@ -6,22 +6,22 @@ use poulpy_hal::{ }, oep::{ VmpApplyDftToDftAddImpl, VmpApplyDftToDftAddTmpBytesImpl, VmpApplyDftToDftImpl, VmpApplyDftToDftTmpBytesImpl, - VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPMatFromBytesImpl, VmpPMatPrepareImpl, VmpPrepareTmpBytesImpl, + VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPMatFromBytesImpl, VmpPrepareImpl, VmpPrepareTmpBytesImpl, }, }; use crate::cpu_spqlios::{ - FFT64, + FFT64Spqlios, ffi::{vec_znx_dft::vec_znx_dft_t, vmp}, }; -unsafe impl VmpPMatAllocBytesImpl for FFT64 { +unsafe impl VmpPMatAllocBytesImpl for FFT64Spqlios { fn vmp_pmat_alloc_bytes_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { - FFT64::layout_prep_word_count() * n * rows * cols_in * cols_out * size * size_of::() + Self::layout_prep_word_count() * n * rows * cols_in * cols_out * size * size_of::() } } -unsafe impl VmpPMatFromBytesImpl for FFT64 { +unsafe impl VmpPMatFromBytesImpl for FFT64Spqlios { fn 
vmp_pmat_from_bytes_impl( n: usize, rows: usize, @@ -29,19 +29,19 @@ unsafe impl VmpPMatFromBytesImpl for FFT64 { cols_out: usize, size: usize, bytes: Vec, - ) -> VmpPMatOwned { + ) -> VmpPMatOwned { VmpPMatOwned::from_bytes(n, rows, cols_in, cols_out, size, bytes) } } -unsafe impl VmpPMatAllocImpl for FFT64 { - fn vmp_pmat_alloc_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned { +unsafe impl VmpPMatAllocImpl for FFT64Spqlios { + fn vmp_pmat_alloc_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned { VmpPMatOwned::alloc(n, rows, cols_in, cols_out, size) } } -unsafe impl VmpPrepareTmpBytesImpl for FFT64 { - fn vmp_prepare_tmp_bytes_impl(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { +unsafe impl VmpPrepareTmpBytesImpl for FFT64Spqlios { + fn vmp_prepare_tmp_bytes_impl(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { unsafe { vmp::vmp_prepare_tmp_bytes( module.ptr(), @@ -52,13 +52,13 @@ unsafe impl VmpPrepareTmpBytesImpl for FFT64 { } } -unsafe impl VmpPMatPrepareImpl for FFT64 { - fn vmp_prepare_impl(module: &Module, res: &mut R, a: &A, scratch: &mut Scratch) +unsafe impl VmpPrepareImpl for FFT64Spqlios { + fn vmp_prepare_impl(module: &Module, res: &mut R, a: &A, scratch: &mut Scratch) where - R: VmpPMatToMut, + R: VmpPMatToMut, A: MatZnxToRef, { - let mut res: VmpPMat<&mut [u8], FFT64> = res.to_mut(); + let mut res: VmpPMat<&mut [u8], Self> = res.to_mut(); let a: MatZnx<&[u8]> = a.to_ref(); #[cfg(debug_assertions)] @@ -109,9 +109,9 @@ unsafe impl VmpPMatPrepareImpl for FFT64 { } } -unsafe impl VmpApplyDftToDftTmpBytesImpl for FFT64 { +unsafe impl VmpApplyDftToDftTmpBytesImpl for FFT64Spqlios { fn vmp_apply_dft_to_dft_tmp_bytes_impl( - module: &Module, + module: &Module, res_size: usize, a_size: usize, b_rows: usize, @@ -131,12 +131,12 @@ unsafe impl VmpApplyDftToDftTmpBytesImpl for FFT64 { } } -unsafe impl VmpApplyDftToDftImpl for FFT64 { - fn vmp_apply_dft_to_dft_impl(module: &Module, res: &mut R, a: &A, b: &C, scratch: &mut Scratch) +unsafe impl VmpApplyDftToDftImpl for FFT64Spqlios { + fn vmp_apply_dft_to_dft_impl(module: &Module, res: &mut R, a: &A, b: &C, scratch: &mut Scratch) where - R: VecZnxDftToMut, - A: VecZnxDftToRef, - C: VmpPMatToRef, + R: VecZnxDftToMut, + A: VecZnxDftToRef, + C: VmpPMatToRef, { let mut res: VecZnxDft<&mut [u8], _> = res.to_mut(); let a: VecZnxDft<&[u8], _> = a.to_ref(); @@ -186,9 +186,9 @@ unsafe impl VmpApplyDftToDftImpl for FFT64 { } } -unsafe impl VmpApplyDftToDftAddTmpBytesImpl for FFT64 { +unsafe impl VmpApplyDftToDftAddTmpBytesImpl for FFT64Spqlios { fn vmp_apply_dft_to_dft_add_tmp_bytes_impl( - module: &Module, + module: &Module, res_size: usize, a_size: usize, b_rows: usize, @@ -208,18 +208,18 @@ unsafe impl VmpApplyDftToDftAddTmpBytesImpl for FFT64 { } } -unsafe impl VmpApplyDftToDftAddImpl for FFT64 { +unsafe impl VmpApplyDftToDftAddImpl for FFT64Spqlios { fn vmp_apply_dft_to_dft_add_impl( - module: &Module, + module: &Module, res: &mut R, a: &A, b: &C, scale: usize, - scratch: &mut Scratch, + scratch: &mut Scratch, ) where - R: VecZnxDftToMut, - A: VecZnxDftToRef, - C: VmpPMatToRef, + R: VecZnxDftToMut, + A: VecZnxDftToRef, + C: VmpPMatToRef, { let mut res: VecZnxDft<&mut [u8], _> = res.to_mut(); let a: VecZnxDft<&[u8], _> = a.to_ref(); diff --git a/poulpy-backend/src/cpu_spqlios/fft64/zn.rs b/poulpy-backend/src/cpu_spqlios/fft64/zn.rs index 80a2f03..b2d0f42 100644 --- 
a/poulpy-backend/src/cpu_spqlios/fft64/zn.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/zn.rs @@ -1,17 +1,14 @@ use poulpy_hal::{ api::TakeSlice, layouts::{Scratch, Zn, ZnToMut, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut}, - oep::{ - TakeSliceImpl, ZnAddDistF64Impl, ZnAddNormalImpl, ZnFillDistF64Impl, ZnFillNormalImpl, ZnFillUniformImpl, - ZnNormalizeInplaceImpl, - }, + oep::{TakeSliceImpl, ZnAddNormalImpl, ZnFillNormalImpl, ZnFillUniformImpl, ZnNormalizeInplaceImpl}, + reference::zn::{zn_add_normal, zn_fill_normal, zn_fill_uniform}, source::Source, }; -use rand_distr::Normal; -use crate::cpu_spqlios::{FFT64, ffi::zn64}; +use crate::cpu_spqlios::{FFT64Spqlios, ffi::zn64}; -unsafe impl ZnNormalizeInplaceImpl for FFT64 +unsafe impl ZnNormalizeInplaceImpl for FFT64Spqlios where Self: TakeSliceImpl, { @@ -39,113 +36,17 @@ where } } -unsafe impl ZnFillUniformImpl for FFT64 { - fn zn_fill_uniform_impl(n: usize, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source) +unsafe impl ZnFillUniformImpl for FFT64Spqlios { + fn zn_fill_uniform_impl(n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source) where R: ZnToMut, { - let mut a: Zn<&mut [u8]> = res.to_mut(); - let base2k: u64 = 1 << basek; - let mask: u64 = base2k - 1; - let base2k_half: i64 = (base2k >> 1) as i64; - (0..k.div_ceil(basek)).for_each(|j| { - a.at_mut(res_col, j)[..n] - .iter_mut() - .for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half); - }) + zn_fill_uniform(n, basek, res, res_col, source); } } -unsafe impl ZnFillDistF64Impl for FFT64 { - fn zn_fill_dist_f64_impl>( - n: usize, - basek: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - dist: D, - bound: f64, - ) where - R: ZnToMut, - { - let mut a: Zn<&mut [u8]> = res.to_mut(); - assert!( - (bound.log2().ceil() as i64) < 64, - "invalid bound: ceil(log2(bound))={} > 63", - (bound.log2().ceil() as i64) - ); - - let limb: usize = k.div_ceil(basek) - 1; - let basek_rem: usize = (limb + 1) * basek - k; - - if basek_rem != 0 { - a.at_mut(res_col, limb)[..n].iter_mut().for_each(|a| { - let mut dist_f64: f64 = dist.sample(source); - while dist_f64.abs() > bound { - dist_f64 = dist.sample(source) - } - *a = (dist_f64.round() as i64) << basek_rem; - }); - } else { - a.at_mut(res_col, limb)[..n].iter_mut().for_each(|a| { - let mut dist_f64: f64 = dist.sample(source); - while dist_f64.abs() > bound { - dist_f64 = dist.sample(source) - } - *a = dist_f64.round() as i64 - }); - } - } -} - -unsafe impl ZnAddDistF64Impl for FFT64 { - fn zn_add_dist_f64_impl>( - n: usize, - basek: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - dist: D, - bound: f64, - ) where - R: ZnToMut, - { - let mut a: Zn<&mut [u8]> = res.to_mut(); - assert!( - (bound.log2().ceil() as i64) < 64, - "invalid bound: ceil(log2(bound))={} > 63", - (bound.log2().ceil() as i64) - ); - - let limb: usize = k.div_ceil(basek) - 1; - let basek_rem: usize = (limb + 1) * basek - k; - - if basek_rem != 0 { - a.at_mut(res_col, limb)[..n].iter_mut().for_each(|a| { - let mut dist_f64: f64 = dist.sample(source); - while dist_f64.abs() > bound { - dist_f64 = dist.sample(source) - } - *a += (dist_f64.round() as i64) << basek_rem; - }); - } else { - a.at_mut(res_col, limb)[..n].iter_mut().for_each(|a| { - let mut dist_f64: f64 = dist.sample(source); - while dist_f64.abs() > bound { - dist_f64 = dist.sample(source) - } - *a += dist_f64.round() as i64 - }); - } - } -} - -unsafe impl ZnFillNormalImpl for FFT64 -where - Self: 
ZnFillDistF64Impl, -{ +unsafe impl ZnFillNormalImpl for FFT64Spqlios { + #[allow(clippy::too_many_arguments)] fn zn_fill_normal_impl( n: usize, basek: usize, @@ -158,23 +59,12 @@ where ) where R: ZnToMut, { - Self::zn_fill_dist_f64_impl( - n, - basek, - res, - res_col, - k, - source, - Normal::new(0.0, sigma).unwrap(), - bound, - ); + zn_fill_normal(n, basek, res, res_col, k, source, sigma, bound); } } -unsafe impl ZnAddNormalImpl for FFT64 -where - Self: ZnAddDistF64Impl, -{ +unsafe impl ZnAddNormalImpl for FFT64Spqlios { + #[allow(clippy::too_many_arguments)] fn zn_add_normal_impl( n: usize, basek: usize, @@ -187,15 +77,6 @@ where ) where R: ZnToMut, { - Self::zn_add_dist_f64_impl( - n, - basek, - res, - res_col, - k, - source, - Normal::new(0.0, sigma).unwrap(), - bound, - ); + zn_add_normal(n, basek, res, res_col, k, source, sigma, bound); } } diff --git a/poulpy-backend/src/cpu_spqlios/mod.rs b/poulpy-backend/src/cpu_spqlios/mod.rs index 40baf00..6a34dec 100644 --- a/poulpy-backend/src/cpu_spqlios/mod.rs +++ b/poulpy-backend/src/cpu_spqlios/mod.rs @@ -3,7 +3,8 @@ mod fft64; mod ntt120; #[cfg(test)] -mod test; +mod tests; +pub use ffi::*; pub use fft64::*; pub use ntt120::*; diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic index 708e5d7..b6938df 160000 --- a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic @@ -1 +1 @@ -Subproject commit 708e5d7e867abba60f029794eea58aa2735e1f15 +Subproject commit b6938df774d629e89e8ce3645f4c33df4b1144d1 diff --git a/poulpy-backend/src/cpu_spqlios/test/mod.rs b/poulpy-backend/src/cpu_spqlios/test/mod.rs deleted file mode 100644 index 3146d6e..0000000 --- a/poulpy-backend/src/cpu_spqlios/test/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -mod vec_znx_fft64; -mod vmp_pmat_fft64; diff --git a/poulpy-backend/src/cpu_spqlios/test/vec_znx_fft64.rs b/poulpy-backend/src/cpu_spqlios/test/vec_znx_fft64.rs deleted file mode 100644 index 9e378dc..0000000 --- a/poulpy-backend/src/cpu_spqlios/test/vec_znx_fft64.rs +++ /dev/null @@ -1,19 +0,0 @@ -use poulpy_hal::{ - api::ModuleNew, - layouts::Module, - tests::vec_znx::{test_vec_znx_add_normal, test_vec_znx_fill_uniform}, -}; - -use crate::cpu_spqlios::FFT64; - -#[test] -fn test_vec_znx_fill_uniform_fft64() { - let module: Module = Module::::new(1 << 12); - test_vec_znx_fill_uniform(&module); -} - -#[test] -fn test_vec_znx_add_normal_fft64() { - let module: Module = Module::::new(1 << 12); - test_vec_znx_add_normal(&module); -} diff --git a/poulpy-backend/src/cpu_spqlios/test/vmp_pmat_fft64.rs b/poulpy-backend/src/cpu_spqlios/test/vmp_pmat_fft64.rs deleted file mode 100644 index 7354d73..0000000 --- a/poulpy-backend/src/cpu_spqlios/test/vmp_pmat_fft64.rs +++ /dev/null @@ -1,8 +0,0 @@ -use poulpy_hal::tests::vmp_pmat::test_vmp_apply; - -use crate::cpu_spqlios::FFT64; - -#[test] -fn vmp_apply() { - test_vmp_apply::(); -} diff --git a/poulpy-backend/src/cpu_spqlios/tests.rs b/poulpy-backend/src/cpu_spqlios/tests.rs new file mode 100644 index 0000000..3c30b6f --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/tests.rs @@ -0,0 +1,117 @@ +use poulpy_hal::{backend_test_suite, cross_backend_test_suite}; + +cross_backend_test_suite! 
{ + mod vec_znx, + backend_ref = crate::cpu_fft64_ref::FFT64Ref, + backend_test = crate::cpu_spqlios::FFT64Spqlios, + size = 1 << 5, + basek = 12, + tests = { + test_vec_znx_add => poulpy_hal::test_suite::vec_znx::test_vec_znx_add, + test_vec_znx_add_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_inplace, + test_vec_znx_add_scalar => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_scalar, + test_vec_znx_add_scalar_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_scalar_inplace, + test_vec_znx_sub => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub, + test_vec_znx_sub_ab_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_ab_inplace, + test_vec_znx_sub_ba_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_ba_inplace, + test_vec_znx_sub_scalar => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_scalar, + test_vec_znx_sub_scalar_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_scalar_inplace, + test_vec_znx_rsh => poulpy_hal::test_suite::vec_znx::test_vec_znx_rsh, + test_vec_znx_rsh_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_rsh_inplace, + test_vec_znx_lsh => poulpy_hal::test_suite::vec_znx::test_vec_znx_lsh, + test_vec_znx_lsh_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_lsh_inplace, + test_vec_znx_negate => poulpy_hal::test_suite::vec_znx::test_vec_znx_negate, + test_vec_znx_negate_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_negate_inplace, + test_vec_znx_rotate => poulpy_hal::test_suite::vec_znx::test_vec_znx_rotate, + test_vec_znx_rotate_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_rotate_inplace, + test_vec_znx_automorphism => poulpy_hal::test_suite::vec_znx::test_vec_znx_automorphism, + test_vec_znx_automorphism_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_automorphism_inplace, + test_vec_znx_mul_xp_minus_one => poulpy_hal::test_suite::vec_znx::test_vec_znx_mul_xp_minus_one, + test_vec_znx_mul_xp_minus_one_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_mul_xp_minus_one_inplace, + test_vec_znx_normalize => poulpy_hal::test_suite::vec_znx::test_vec_znx_normalize, + test_vec_znx_normalize_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_normalize_inplace, + test_vec_znx_switch_ring => poulpy_hal::test_suite::vec_znx::test_vec_znx_switch_ring, + test_vec_znx_split_ring => poulpy_hal::test_suite::vec_znx::test_vec_znx_split_ring, + test_vec_znx_copy => poulpy_hal::test_suite::vec_znx::test_vec_znx_copy, + } +} + +cross_backend_test_suite! { + mod svp, + backend_ref = crate::cpu_fft64_ref::FFT64Ref, + backend_test = crate::cpu_spqlios::FFT64Spqlios, + size = 1 << 5, + basek = 12, + tests = { + test_svp_apply_dft_to_dft => poulpy_hal::test_suite::svp::test_svp_apply_dft_to_dft, + test_svp_apply_dft_to_dft_inplace => poulpy_hal::test_suite::svp::test_svp_apply_dft_to_dft_inplace, + } +} + +cross_backend_test_suite! 
{ + mod vec_znx_big, + backend_ref = crate::cpu_fft64_ref::FFT64Ref, + backend_test = crate::cpu_spqlios::FFT64Spqlios, + size = 1 << 5, + basek = 12, + tests = { + test_vec_znx_big_add => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add, + test_vec_znx_big_add_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add_inplace, + test_vec_znx_big_add_small => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add_small, + test_vec_znx_big_add_small_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add_small_inplace, + test_vec_znx_big_sub => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub, + test_vec_znx_big_sub_ab_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_ab_inplace, + test_vec_znx_big_automorphism => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_automorphism, + test_vec_znx_big_automorphism_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_automorphism_inplace, + test_vec_znx_big_negate => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_negate, + test_vec_znx_big_negate_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_negate_inplace, + test_vec_znx_big_normalize => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_normalize, + test_vec_znx_big_sub_ba_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_ba_inplace, + test_vec_znx_big_sub_small_a => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_a, + test_vec_znx_big_sub_small_a_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_a_inplace, + test_vec_znx_big_sub_small_b => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_b, + test_vec_znx_big_sub_small_b_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_b_inplace, + } +} + +cross_backend_test_suite! { + mod vec_znx_dft, + backend_ref = crate::cpu_fft64_ref::FFT64Ref, + backend_test = crate::cpu_spqlios::FFT64Spqlios, + size = 1 << 5, + basek = 12, + tests = { + test_vec_znx_dft_add => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_add, + test_vec_znx_dft_add_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_add_inplace, + test_vec_znx_dft_sub => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub, + test_vec_znx_dft_sub_ab_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub_ab_inplace, + test_vec_znx_dft_sub_ba_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub_ba_inplace, + test_vec_znx_idft_apply => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_idft_apply, + test_vec_znx_idft_apply_consume => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_idft_apply_consume, + test_vec_znx_idft_apply_tmpa => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_idft_apply_tmpa, + } +} + +cross_backend_test_suite! { + mod vmp, + backend_ref = crate::cpu_fft64_ref::FFT64Ref, + backend_test = crate::cpu_spqlios::FFT64Spqlios, + size = 1 << 5, + basek = 12, + tests = { + test_vmp_apply_dft_to_dft => poulpy_hal::test_suite::vmp::test_vmp_apply_dft_to_dft, + test_vmp_apply_dft_to_dft_add => poulpy_hal::test_suite::vmp::test_vmp_apply_dft_to_dft_add, + } +} + +backend_test_suite! 
{ + mod sampling, + backend = crate::cpu_spqlios::FFT64Spqlios, + size = 1 << 12, + tests = { + test_vec_znx_fill_uniform => poulpy_hal::test_suite::vec_znx::test_vec_znx_fill_uniform, + test_vec_znx_fill_normal => poulpy_hal::test_suite::vec_znx::test_vec_znx_fill_normal, + test_vec_znx_add_normal => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_normal, + test_vec_znx_big_add_normal => poulpy_hal::reference::fft64::vec_znx_big::test_vec_znx_big_add_normal, + } +} diff --git a/poulpy-backend/src/lib.rs b/poulpy-backend/src/lib.rs index 15632e0..2d4771b 100644 --- a/poulpy-backend/src/lib.rs +++ b/poulpy-backend/src/lib.rs @@ -1 +1,7 @@ +pub mod cpu_fft64_avx; +pub mod cpu_fft64_ref; pub mod cpu_spqlios; + +pub use cpu_fft64_avx::FFT64Avx; +pub use cpu_fft64_ref::FFT64Ref; +pub use cpu_spqlios::FFT64Spqlios; diff --git a/poulpy-core/Cargo.toml b/poulpy-core/Cargo.toml index c17549e..47f14bc 100644 --- a/poulpy-core/Cargo.toml +++ b/poulpy-core/Cargo.toml @@ -15,6 +15,7 @@ poulpy-hal = {path="../poulpy-hal"} poulpy-backend = {path="../poulpy-backend"} itertools = {workspace = true} byteorder = {workspace = true} +once_cell = {workspace = true} [[bench]] name = "external_product_glwe_fft64" diff --git a/poulpy-core/benches/external_product_glwe_fft64.rs b/poulpy-core/benches/external_product_glwe_fft64.rs index ab3959e..b7f6814 100644 --- a/poulpy-core/benches/external_product_glwe_fft64.rs +++ b/poulpy-core/benches/external_product_glwe_fft64.rs @@ -6,7 +6,7 @@ use std::hint::black_box; use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; -use poulpy_backend::cpu_spqlios::FFT64; +use poulpy_backend::cpu_spqlios::FFT64Spqlios; use poulpy_hal::{ api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow}, layouts::{Module, ScalarZnx, ScratchOwned}, @@ -26,7 +26,7 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) { } fn runner(p: Params) -> impl FnMut() { - let module: Module = Module::::new(1 << p.log_n); + let module: Module = Module::::new(1 << p.log_n); let n: usize = module.n(); let basek: usize = p.basek; @@ -43,7 +43,7 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) { let mut ct_glwe_out: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct_out, rank); let pt_rgsw: ScalarZnx> = ScalarZnx::alloc(n, 1); - let mut scratch: ScratchOwned = ScratchOwned::alloc( + let mut scratch: ScratchOwned = ScratchOwned::alloc( GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, ct_ggsw.k(), rank) | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_glwe_in.k()) | GLWECiphertext::external_product_scratch_space( @@ -63,7 +63,7 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) { let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_dft: GLWESecretPrepared, FFT64> = sk.prepare_alloc(&module, scratch.borrow()); + let sk_dft: GLWESecretPrepared, FFT64Spqlios> = sk.prepare_alloc(&module, scratch.borrow()); ct_ggsw.encrypt_sk( &module, @@ -82,7 +82,7 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) { scratch.borrow(), ); - let ggsw_prepared: GGSWCiphertextPrepared, FFT64> = ct_ggsw.prepare_alloc(&module, scratch.borrow()); + let ggsw_prepared: GGSWCiphertextPrepared, FFT64Spqlios> = ct_ggsw.prepare_alloc(&module, scratch.borrow()); move || { ct_glwe_out.external_product(&module, &ct_glwe_in, &ggsw_prepared, scratch.borrow()); @@ -120,7 +120,7 @@ fn bench_external_product_glwe_inplace_fft64(c: &mut Criterion) { } fn runner(p: Params) -> impl FnMut() { - let module: Module =
Module::::new(1 << p.log_n); + let module: Module = Module::::new(1 << p.log_n); let n = module.n(); let basek: usize = p.basek; @@ -135,7 +135,7 @@ fn bench_external_product_glwe_inplace_fft64(c: &mut Criterion) { let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_glwe, rank); let pt_rgsw: ScalarZnx> = ScalarZnx::alloc(n, 1); - let mut scratch: ScratchOwned = ScratchOwned::alloc( + let mut scratch: ScratchOwned = ScratchOwned::alloc( GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, ct_ggsw.k(), rank) | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_glwe.k()) | GLWECiphertext::external_product_inplace_scratch_space(&module, basek, ct_glwe.k(), ct_ggsw.k(), digits, rank), @@ -147,7 +147,7 @@ fn bench_external_product_glwe_inplace_fft64(c: &mut Criterion) { let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_dft: GLWESecretPrepared, FFT64> = sk.prepare_alloc(&module, scratch.borrow()); + let sk_dft: GLWESecretPrepared, FFT64Spqlios> = sk.prepare_alloc(&module, scratch.borrow()); ct_ggsw.encrypt_sk( &module, @@ -166,7 +166,7 @@ fn bench_external_product_glwe_inplace_fft64(c: &mut Criterion) { scratch.borrow(), ); - let ggsw_prepared: GGSWCiphertextPrepared, FFT64> = ct_ggsw.prepare_alloc(&module, scratch.borrow()); + let ggsw_prepared: GGSWCiphertextPrepared, FFT64Spqlios> = ct_ggsw.prepare_alloc(&module, scratch.borrow()); move || { let scratch_borrow = scratch.borrow(); diff --git a/poulpy-core/benches/keyswitch_glwe_fft64.rs b/poulpy-core/benches/keyswitch_glwe_fft64.rs index 37f3326..baa8860 100644 --- a/poulpy-core/benches/keyswitch_glwe_fft64.rs +++ b/poulpy-core/benches/keyswitch_glwe_fft64.rs @@ -5,7 +5,7 @@ use poulpy_core::layouts::{ use std::{hint::black_box, time::Duration}; use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; -use poulpy_backend::cpu_spqlios::FFT64; +use poulpy_backend::cpu_spqlios::FFT64Spqlios; use poulpy_hal::{ api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow}, layouts::{Module, ScratchOwned}, @@ -27,7 +27,7 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { } fn runner(p: Params) -> impl FnMut() { - let module: Module = Module::::new(1 << p.log_n); + let module: Module = Module::::new(1 << p.log_n); let n = module.n(); let basek: usize = p.basek; @@ -44,7 +44,7 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_rlwe_in, rank_in); let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_rlwe_out, rank_out); - let mut scratch: ScratchOwned = ScratchOwned::alloc( + let mut scratch: ScratchOwned = ScratchOwned::alloc( GGLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ksk.k(), rank_in, rank_out) | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_in.k()) | GLWECiphertext::keyswitch_scratch_space( @@ -65,7 +65,7 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank_in); sk_in.fill_ternary_prob(0.5, &mut source_xs); - let sk_in_dft: GLWESecretPrepared, FFT64> = sk_in.prepare_alloc(&module, scratch.borrow()); + let sk_in_dft: GLWESecretPrepared, FFT64Spqlios> = sk_in.prepare_alloc(&module, scratch.borrow()); let mut sk_out: GLWESecret> = GLWESecret::alloc(n, rank_out); sk_out.fill_ternary_prob(0.5, &mut source_xs); @@ -132,7 +132,7 @@ fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) { } fn runner(p: Params) -> impl FnMut() { - let module: Module = Module::::new(1 << p.log_n); + let 
module: Module = Module::::new(1 << p.log_n); let n = module.n(); let basek: usize = p.basek; @@ -146,7 +146,7 @@ fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) { let mut ksk: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(n, basek, k_ksk, rows, digits, rank, rank); let mut ct: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct, rank); - let mut scratch: ScratchOwned = ScratchOwned::alloc( + let mut scratch: ScratchOwned = ScratchOwned::alloc( GGLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ksk.k(), rank, rank) | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k()) | GLWECiphertext::keyswitch_inplace_scratch_space(&module, basek, ct.k(), ksk.k(), digits, rank), @@ -158,7 +158,7 @@ fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) { let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank); sk_in.fill_ternary_prob(0.5, &mut source_xs); - let sk_in_dft: GLWESecretPrepared, FFT64> = sk_in.prepare_alloc(&module, scratch.borrow()); + let sk_in_dft: GLWESecretPrepared, FFT64Spqlios> = sk_in.prepare_alloc(&module, scratch.borrow()); let mut sk_out: GLWESecret> = GLWESecret::alloc(n, rank); sk_out.fill_ternary_prob(0.5, &mut source_xs); @@ -180,7 +180,7 @@ fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) { scratch.borrow(), ); - let ksk_prepared: GGLWESwitchingKeyPrepared, FFT64> = ksk.prepare_alloc(&module, scratch.borrow()); + let ksk_prepared: GGLWESwitchingKeyPrepared, FFT64Spqlios> = ksk.prepare_alloc(&module, scratch.borrow()); move || { ct.keyswitch_inplace(&module, &ksk_prepared, scratch.borrow()); diff --git a/poulpy-core/examples/encryption.rs b/poulpy-core/examples/encryption.rs index df6e944..169f6f4 100644 --- a/poulpy-core/examples/encryption.rs +++ b/poulpy-core/examples/encryption.rs @@ -1,4 +1,4 @@ -use poulpy_backend::cpu_spqlios::FFT64; +use poulpy_backend::cpu_spqlios::FFT64Spqlios; use poulpy_core::{ GLWEOperations, SIGMA, layouts::{ @@ -31,7 +31,7 @@ fn main() { let rank: usize = 1; // Instantiate Module (DFT Tables) - let module: Module = Module::::new(n as u64); + let module: Module = Module::::new(n as u64); // Allocates ciphertext & plaintexts let mut ct: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct, rank); @@ -44,7 +44,7 @@ fn main() { let mut source_xa: Source = Source::new([2u8; 32]); // Scratch space - let mut scratch: ScratchOwned = ScratchOwned::alloc( + let mut scratch: ScratchOwned = ScratchOwned::alloc( GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k()) | GLWECiphertext::decrypt_scratch_space(&module, basek, ct.k()), ); @@ -54,10 +54,10 @@ fn main() { sk.fill_ternary_prob(0.5, &mut source_xs); // Backend-prepared secret - let sk_prepared: GLWESecretPrepared, FFT64> = sk.prepare_alloc(&module, scratch.borrow()); + let sk_prepared: GLWESecretPrepared, FFT64Spqlios> = sk.prepare_alloc(&module, scratch.borrow()); // Uniform plaintext - module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, k_pt, &mut source_xa); + module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, &mut source_xa); // Encryption ct.encrypt_sk( diff --git a/poulpy-core/src/automorphism/gglwe_atk.rs b/poulpy-core/src/automorphism/gglwe_atk.rs index 4c3ed3f..30ae17b 100644 --- a/poulpy-core/src/automorphism/gglwe_atk.rs +++ b/poulpy-core/src/automorphism/gglwe_atk.rs @@ -1,8 +1,8 @@ use poulpy_hal::{ api::{ - DFT, IDFTConsume, ScratchAvailable, TakeVecZnxDft, VecZnxAutomorphism, VecZnxAutomorphismInplace, - VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VmpApplyDftToDft, - 
VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + ScratchAvailable, TakeVecZnxDft, VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, + VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, + VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero}, }; @@ -54,12 +54,12 @@ impl GGLWEAutomorphismKey { + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd - + DFT - + IDFTConsume + + VecZnxDftApply + + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize + VecZnxAutomorphism - + VecZnxAutomorphismInplace, + + VecZnxAutomorphismInplace, Scratch: ScratchAvailable + TakeVecZnxDft, { #[cfg(debug_assertions)] @@ -72,7 +72,7 @@ impl GGLWEAutomorphismKey { lhs.rank_in() ); assert_eq!( - lhs.rank_out(), + self.rank_out(), rhs.rank_in(), "ksk_in output rank: {} != ksk_apply input rank: {}", self.rank_out(), @@ -113,7 +113,7 @@ impl GGLWEAutomorphismKey { // Applies back the automorphism X^{-k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) to (-pi^{-1}_{k'+k}(s)a + s, a) (0..cols_out).for_each(|i| { - module.vec_znx_automorphism_inplace(p_inv, &mut res_ct.data, i); + module.vec_znx_automorphism_inplace(p_inv, &mut res_ct.data, i, scratch); }); }); }); @@ -138,17 +138,56 @@ impl GGLWEAutomorphismKey { + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd - + DFT - + IDFTConsume + + VecZnxDftApply + + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize + VecZnxAutomorphism - + VecZnxAutomorphismInplace, + + VecZnxAutomorphismInplace, Scratch: ScratchAvailable + TakeVecZnxDft, { - unsafe { - let self_ptr: *mut GGLWEAutomorphismKey = self as *mut GGLWEAutomorphismKey; - self.automorphism(module, &*self_ptr, rhs, scratch); + #[cfg(debug_assertions)] + { + assert_eq!( + self.rank_out(), + rhs.rank_in(), + "ksk_in output rank: {} != ksk_apply input rank: {}", + self.rank_out(), + rhs.rank_in() + ); + assert_eq!( + self.rank_out(), + rhs.rank_out(), + "ksk_out output rank: {} != ksk_apply output rank: {}", + self.rank_out(), + rhs.rank_out() + ); } + + let cols_out: usize = rhs.rank_out() + 1; + + let p: i64 = self.p(); + let p_inv = module.galois_element_inv(p); + + (0..self.rank_in()).for_each(|col_i| { + (0..self.rows()).for_each(|row_j| { + let mut res_ct: GLWECiphertext<&mut [u8]> = self.at_mut(row_j, col_i); + + // Reverts the automorphism X^{-k}: (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a) + (0..cols_out).for_each(|i| { + module.vec_znx_automorphism_inplace(p_inv, &mut res_ct.data, i, scratch); + }); + + // Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) + res_ct.keyswitch_inplace(module, &rhs.key, scratch); + + // Applies back the automorphism X^{-k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) to (-pi^{-1}_{k'+k}(s)a + s, a) + (0..cols_out).for_each(|i| { + module.vec_znx_automorphism_inplace(p_inv, &mut res_ct.data, i, scratch); + }); + }); + }); + + self.p = (self.p * rhs.p) % (module.cyclotomic_order() as i64); } } diff --git a/poulpy-core/src/automorphism/ggsw_ct.rs b/poulpy-core/src/automorphism/ggsw_ct.rs index a67be34..d23bc7b 100644 --- a/poulpy-core/src/automorphism/ggsw_ct.rs +++ b/poulpy-core/src/automorphism/ggsw_ct.rs @@ -1,8 +1,8 @@ use poulpy_hal::{ api::{ - DFT, IDFTConsume, IDFTTmpA, ScratchAvailable, TakeVecZnxBig, TakeVecZnxDft, VecZnxAutomorphismInplace, - VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, 
VecZnxDftAddInplace, - VecZnxDftAllocBytes, VecZnxDftCopy, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, + ScratchAvailable, TakeVecZnxBig, TakeVecZnxDft, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, + VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy, + VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataRef, Module, Scratch}, @@ -79,16 +79,16 @@ impl GGSWCiphertext { + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd - + DFT - + IDFTConsume + + VecZnxDftApply + + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize - + VecZnxAutomorphismInplace + + VecZnxAutomorphismInplace + VecZnxBigAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftCopy + VecZnxDftAddInplace - + IDFTTmpA, + + VecZnxIdftApplyTmpA, Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnxBig, { #[cfg(debug_assertions)] @@ -133,7 +133,13 @@ impl GGSWCiphertext { ) }; - self.automorphism_internal(module, lhs, auto_key, scratch); + // Keyswitch the j-th row of the col 0 + (0..lhs.rows()).for_each(|row_i| { + // Key-switch column 0, i.e. + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2) + self.at_mut(row_i, 0) + .automorphism(module, &lhs.at(row_i, 0), auto_key, scratch); + }); self.expand_row(module, tensor_key, scratch); } @@ -149,49 +155,25 @@ impl GGSWCiphertext { + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd - + DFT - + IDFTConsume + + VecZnxDftApply + + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize - + VecZnxAutomorphismInplace + + VecZnxAutomorphismInplace + VecZnxBigAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftCopy + VecZnxDftAddInplace - + IDFTTmpA, + + VecZnxIdftApplyTmpA, Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnxBig, - { - unsafe { - let self_ptr: *mut GGSWCiphertext = self as *mut GGSWCiphertext; - self.automorphism(module, &*self_ptr, auto_key, tensor_key, scratch); - } - } - - fn automorphism_internal( - &mut self, - module: &Module, - lhs: &GGSWCiphertext, - auto_key: &GGLWEAutomorphismKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + DFT - + IDFTConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxAutomorphismInplace, - Scratch: TakeVecZnxDft + ScratchAvailable, { // Keyswitch the j-th row of the col 0 - (0..lhs.rows()).for_each(|row_i| { + (0..self.rows()).for_each(|row_i| { // Key-switch column 0, i.e. 
// col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2) self.at_mut(row_i, 0) - .automorphism(module, &lhs.at(row_i, 0), auto_key, scratch); + .automorphism_inplace(module, auto_key, scratch); }); + self.expand_row(module, tensor_key, scratch); } } diff --git a/poulpy-core/src/automorphism/glwe_ct.rs b/poulpy-core/src/automorphism/glwe_ct.rs index dd9484b..1d8077a 100644 --- a/poulpy-core/src/automorphism/glwe_ct.rs +++ b/poulpy-core/src/automorphism/glwe_ct.rs @@ -1,8 +1,9 @@ use poulpy_hal::{ api::{ - DFT, IDFTConsume, ScratchAvailable, TakeVecZnxDft, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, - VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallAInplace, - VecZnxBigSubSmallBInplace, VecZnxDftAllocBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + ScratchAvailable, TakeVecZnxDft, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigAutomorphismInplace, + VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallAInplace, VecZnxBigSubSmallBInplace, + VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VmpApplyDftToDft, VmpApplyDftToDftAdd, + VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnxBig}, }; @@ -54,16 +55,16 @@ impl GLWECiphertext { + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd - + DFT - + IDFTConsume + + VecZnxDftApply + + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize - + VecZnxAutomorphismInplace, + + VecZnxAutomorphismInplace, Scratch: TakeVecZnxDft + ScratchAvailable, { self.keyswitch(module, lhs, &rhs.key, scratch); (0..self.rank() + 1).for_each(|i| { - module.vec_znx_automorphism_inplace(rhs.p(), &mut self.data, i); + module.vec_znx_automorphism_inplace(rhs.p(), &mut self.data, i, scratch); }) } @@ -78,16 +79,16 @@ impl GLWECiphertext { + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd - + DFT - + IDFTConsume + + VecZnxDftApply + + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize - + VecZnxAutomorphismInplace, + + VecZnxAutomorphismInplace, Scratch: TakeVecZnxDft + ScratchAvailable, { self.keyswitch_inplace(module, &rhs.key, scratch); (0..self.rank() + 1).for_each(|i| { - module.vec_znx_automorphism_inplace(rhs.p(), &mut self.data, i); + module.vec_znx_automorphism_inplace(rhs.p(), &mut self.data, i, scratch); }) } @@ -103,8 +104,8 @@ impl GLWECiphertext { + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd - + DFT - + IDFTConsume + + VecZnxDftApply + + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize + VecZnxBigAutomorphismInplace, @@ -114,12 +115,12 @@ impl GLWECiphertext { { self.assert_keyswitch(module, lhs, &rhs.key, scratch); } - let (res_dft, scratch1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // TODO: optimise size - let mut res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, &rhs.key, scratch1); + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // TODO: optimise size + let mut res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, &rhs.key, scratch_1); (0..self.cols()).for_each(|i| { - module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i); + module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1); module.vec_znx_big_add_small_inplace(&mut res_big, i, &lhs.data, i); - module.vec_znx_big_normalize(self.basek(), &mut 
self.data, i, &res_big, i, scratch1); + module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch_1); }) } @@ -134,17 +135,24 @@ impl GLWECiphertext { + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd - + DFT - + IDFTConsume + + VecZnxDftApply + + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize + VecZnxBigAutomorphismInplace, Scratch: TakeVecZnxDft + ScratchAvailable, { - unsafe { - let self_ptr: *mut GLWECiphertext = self as *mut GLWECiphertext; - self.automorphism_add(module, &*self_ptr, rhs, scratch); + #[cfg(debug_assertions)] + { + self.assert_keyswitch_inplace(module, &rhs.key, scratch); } + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // TODO: optimise size + let mut res_big: VecZnxBig<_, B> = self.keyswitch_internal(module, res_dft, &rhs.key, scratch_1); + (0..self.cols()).for_each(|i| { + module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1); + module.vec_znx_big_add_small_inplace(&mut res_big, i, &self.data, i); + module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch_1); + }) } pub fn automorphism_sub_ab( @@ -159,8 +167,8 @@ impl GLWECiphertext { + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd - + DFT - + IDFTConsume + + VecZnxDftApply + + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize + VecZnxBigAutomorphismInplace @@ -171,12 +179,12 @@ impl GLWECiphertext { { self.assert_keyswitch(module, lhs, &rhs.key, scratch); } - let (res_dft, scratch1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // TODO: optimise size - let mut res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, &rhs.key, scratch1); + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // TODO: optimise size + let mut res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, &rhs.key, scratch_1); (0..self.cols()).for_each(|i| { - module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i); + module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1); module.vec_znx_big_sub_small_a_inplace(&mut res_big, i, &lhs.data, i); - module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch1); + module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch_1); }) } @@ -191,18 +199,25 @@ impl GLWECiphertext { + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd - + DFT - + IDFTConsume + + VecZnxDftApply + + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize + VecZnxBigAutomorphismInplace + VecZnxBigSubSmallAInplace, Scratch: TakeVecZnxDft + ScratchAvailable, { - unsafe { - let self_ptr: *mut GLWECiphertext = self as *mut GLWECiphertext; - self.automorphism_sub_ab(module, &*self_ptr, rhs, scratch); + #[cfg(debug_assertions)] + { + self.assert_keyswitch_inplace(module, &rhs.key, scratch); } + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // TODO: optimise size + let mut res_big: VecZnxBig<_, B> = self.keyswitch_internal(module, res_dft, &rhs.key, scratch_1); + (0..self.cols()).for_each(|i| { + module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1); + module.vec_znx_big_sub_small_a_inplace(&mut res_big, i, &self.data, i); + module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch_1); + }) } pub fn automorphism_sub_ba( @@ -217,8 +232,8 @@ impl GLWECiphertext { + 
VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd - + DFT - + IDFTConsume + + VecZnxDftApply + + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize + VecZnxBigAutomorphismInplace @@ -229,12 +244,12 @@ impl GLWECiphertext { { self.assert_keyswitch(module, lhs, &rhs.key, scratch); } - let (res_dft, scratch1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // TODO: optimise size - let mut res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, &rhs.key, scratch1); + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // TODO: optimise size + let mut res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, &rhs.key, scratch_1); (0..self.cols()).for_each(|i| { - module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i); + module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1); module.vec_znx_big_sub_small_b_inplace(&mut res_big, i, &lhs.data, i); - module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch1); + module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch_1); }) } @@ -249,17 +264,24 @@ impl GLWECiphertext { + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd - + DFT - + IDFTConsume + + VecZnxDftApply + + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize + VecZnxBigAutomorphismInplace + VecZnxBigSubSmallBInplace, Scratch: TakeVecZnxDft + ScratchAvailable, { - unsafe { - let self_ptr: *mut GLWECiphertext = self as *mut GLWECiphertext; - self.automorphism_sub_ba(module, &*self_ptr, rhs, scratch); + #[cfg(debug_assertions)] + { + self.assert_keyswitch_inplace(module, &rhs.key, scratch); } + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // TODO: optimise size + let mut res_big: VecZnxBig<_, B> = self.keyswitch_internal(module, res_dft, &rhs.key, scratch_1); + (0..self.cols()).for_each(|i| { + module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1); + module.vec_znx_big_sub_small_b_inplace(&mut res_big, i, &self.data, i); + module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch_1); + }) } } diff --git a/poulpy-core/src/conversion/glwe_to_lwe.rs b/poulpy-core/src/conversion/glwe_to_lwe.rs index 244035d..037eb6f 100644 --- a/poulpy-core/src/conversion/glwe_to_lwe.rs +++ b/poulpy-core/src/conversion/glwe_to_lwe.rs @@ -1,7 +1,8 @@ use poulpy_hal::{ api::{ - DFT, IDFTConsume, ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, - VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, + VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VmpApplyDftToDft, VmpApplyDftToDftAdd, + VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut, ZnxZero}, }; @@ -60,8 +61,8 @@ impl LWECiphertext { + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd - + DFT - + IDFTConsume + + VecZnxDftApply + + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize, Scratch: ScratchAvailable + TakeVecZnxDft + TakeGLWECt, @@ -71,8 +72,8 @@ impl LWECiphertext { assert_eq!(self.basek(), a.basek()); assert_eq!(a.n(), ks.n()); } - let (mut tmp_glwe, scratch1) = scratch.take_glwe_ct(a.n(), a.basek(), self.k(), 1); - 
tmp_glwe.keyswitch(module, a, &ks.0, scratch1); + let (mut tmp_glwe, scratch_1) = scratch.take_glwe_ct(a.n(), a.basek(), self.k(), 1); + tmp_glwe.keyswitch(module, a, &ks.0, scratch_1); self.sample_extract(&tmp_glwe); } } diff --git a/poulpy-core/src/conversion/lwe_to_glwe.rs b/poulpy-core/src/conversion/lwe_to_glwe.rs index 1a723cf..538256b 100644 --- a/poulpy-core/src/conversion/lwe_to_glwe.rs +++ b/poulpy-core/src/conversion/lwe_to_glwe.rs @@ -1,7 +1,8 @@ use poulpy_hal::{ api::{ - DFT, IDFTConsume, ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, - VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, + VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VmpApplyDftToDft, VmpApplyDftToDftAdd, + VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut, ZnxZero}, }; @@ -43,8 +44,8 @@ impl GLWECiphertext { + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd - + DFT - + IDFTConsume + + VecZnxDftApply + + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize, Scratch: ScratchAvailable + TakeVecZnxDft + TakeGLWECt, @@ -55,7 +56,7 @@ impl GLWECiphertext { assert_eq!(self.basek(), self.basek()); } - let (mut glwe, scratch1) = scratch.take_glwe_ct(ksk.n(), lwe.basek(), lwe.k(), 1); + let (mut glwe, scratch_1) = scratch.take_glwe_ct(ksk.n(), lwe.basek(), lwe.k(), 1); glwe.data.zero(); let n_lwe: usize = lwe.n(); @@ -66,6 +67,6 @@ impl GLWECiphertext { glwe.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]); }); - self.keyswitch(module, &glwe, &ksk.0, scratch1); + self.keyswitch(module, &glwe, &ksk.0, scratch_1); } } diff --git a/poulpy-core/src/decryption/glwe_ct.rs b/poulpy-core/src/decryption/glwe_ct.rs index 9dca49d..c69fc56 100644 --- a/poulpy-core/src/decryption/glwe_ct.rs +++ b/poulpy-core/src/decryption/glwe_ct.rs @@ -1,7 +1,7 @@ use poulpy_hal::{ api::{ - DFT, IDFTConsume, SvpApplyInplace, TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, - VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxNormalizeTmpBytes, + SvpApplyDftToDftInplace, TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigNormalize, + VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalizeTmpBytes, }, layouts::{Backend, DataMut, DataRef, DataViewMut, Module, Scratch}, }; @@ -26,9 +26,9 @@ impl GLWECiphertext { sk: &GLWESecretPrepared, scratch: &mut Scratch, ) where - Module: DFT - + SvpApplyInplace - + IDFTConsume + Module: VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + VecZnxBigAddInplace + VecZnxBigAddSmallInplace + VecZnxBigNormalize, @@ -50,9 +50,9 @@ impl GLWECiphertext { (1..cols).for_each(|i| { // ci_dft = DFT(a[i]) * DFT(s[i]) let (mut ci_dft, _) = scratch_1.take_vec_znx_dft(self.n(), 1, self.size()); // TODO optimize size when pt << ct - module.dft(1, 0, &mut ci_dft, 0, &self.data, i); - module.svp_apply_inplace(&mut ci_dft, 0, &sk.data, i - 1); - let ci_big = module.vec_znx_idft_consume(ci_dft); + module.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, &self.data, i); + module.svp_apply_dft_to_dft_inplace(&mut ci_dft, 0, &sk.data, i - 1); + let ci_big = module.vec_znx_idft_apply_consume(ci_dft); // c0_big += a[i] * s[i] module.vec_znx_big_add_inplace(&mut c0_big, 0, &ci_big, 0); diff --git 
a/poulpy-core/src/encryption/compressed/gglwe_atk.rs b/poulpy-core/src/encryption/compressed/gglwe_atk.rs index 4336030..5439a50 100644 --- a/poulpy-core/src/encryption/compressed/gglwe_atk.rs +++ b/poulpy-core/src/encryption/compressed/gglwe_atk.rs @@ -1,9 +1,9 @@ use poulpy_hal::{ api::{ - DFT, IDFTConsume, ScratchAvailable, SvpApplyInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, - TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxBigNormalize, - VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, - VecZnxSubABInplace, VecZnxSwithcDegree, + ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, + VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxBigNormalize, VecZnxDftAllocBytes, + VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, + VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing, }, layouts::{Backend, DataMut, DataRef, Module, Scratch}, source::Source, @@ -41,12 +41,12 @@ impl GGLWEAutomorphismKeyCompressed { Module: VecZnxAutomorphism + SvpPrepare + SvpPPolAllocBytes - + VecZnxSwithcDegree + + VecZnxSwitchRing + VecZnxDftAllocBytes + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform + VecZnxSubABInplace diff --git a/poulpy-core/src/encryption/compressed/gglwe_ct.rs b/poulpy-core/src/encryption/compressed/gglwe_ct.rs index fd71478..1951bf4 100644 --- a/poulpy-core/src/encryption/compressed/gglwe_ct.rs +++ b/poulpy-core/src/encryption/compressed/gglwe_ct.rs @@ -1,8 +1,8 @@ use poulpy_hal::{ api::{ - DFT, IDFTConsume, ScratchAvailable, SvpApplyInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, - VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize, - VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, + ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, + VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, + VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, }, layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, ZnxZero}, source::Source, @@ -37,9 +37,9 @@ impl GGLWECiphertextCompressed { Module: VecZnxAddScalarInplace + VecZnxDftAllocBytes + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform + VecZnxSubABInplace diff --git a/poulpy-core/src/encryption/compressed/gglwe_ksk.rs b/poulpy-core/src/encryption/compressed/gglwe_ksk.rs index ce2a9a7..7f4d81f 100644 --- a/poulpy-core/src/encryption/compressed/gglwe_ksk.rs +++ b/poulpy-core/src/encryption/compressed/gglwe_ksk.rs @@ -1,9 +1,9 @@ use poulpy_hal::{ api::{ - DFT, IDFTConsume, ScratchAvailable, SvpApplyInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, - TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, - VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, - 
VecZnxSwithcDegree, + ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, + VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, + VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, + VecZnxSubABInplace, VecZnxSwitchRing, }, layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch}, source::Source, @@ -44,12 +44,12 @@ impl GGLWESwitchingKeyCompressed { ) where Module: SvpPrepare + SvpPPolAllocBytes - + VecZnxSwithcDegree + + VecZnxSwitchRing + VecZnxDftAllocBytes + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform + VecZnxSubABInplace @@ -90,9 +90,9 @@ impl GGLWESwitchingKeyCompressed { let n: usize = sk_in.n().max(sk_out.n()); - let (mut sk_in_tmp, scratch1) = scratch.take_scalar_znx(n, sk_in.rank()); + let (mut sk_in_tmp, scratch_1) = scratch.take_scalar_znx(n, sk_in.rank()); (0..sk_in.rank()).for_each(|i| { - module.vec_znx_switch_degree( + module.vec_znx_switch_ring( &mut sk_in_tmp.as_vec_znx_mut(), i, &sk_in.data.as_vec_znx(), @@ -100,11 +100,11 @@ impl GGLWESwitchingKeyCompressed { ); }); - let (mut sk_out_tmp, scratch2) = scratch1.take_glwe_secret_prepared(n, sk_out.rank()); + let (mut sk_out_tmp, scratch_2) = scratch_1.take_glwe_secret_prepared(n, sk_out.rank()); { - let (mut tmp, _) = scratch2.take_scalar_znx(n, 1); + let (mut tmp, _) = scratch_2.take_scalar_znx(n, 1); (0..sk_out.rank()).for_each(|i| { - module.vec_znx_switch_degree(&mut tmp.as_vec_znx_mut(), 0, &sk_out.data.as_vec_znx(), i); + module.vec_znx_switch_ring(&mut tmp.as_vec_znx_mut(), 0, &sk_out.data.as_vec_znx(), i); module.svp_prepare(&mut sk_out_tmp.data, i, &tmp, 0); }); } @@ -115,7 +115,7 @@ impl GGLWESwitchingKeyCompressed { &sk_out_tmp, seed_xa, source_xe, - scratch2, + scratch_2, ); self.sk_in_n = sk_in.n(); self.sk_out_n = sk_out.n(); diff --git a/poulpy-core/src/encryption/compressed/gglwe_tsk.rs b/poulpy-core/src/encryption/compressed/gglwe_tsk.rs index 5c38271..52c4ff1 100644 --- a/poulpy-core/src/encryption/compressed/gglwe_tsk.rs +++ b/poulpy-core/src/encryption/compressed/gglwe_tsk.rs @@ -1,9 +1,9 @@ use poulpy_hal::{ api::{ - DFT, IDFTConsume, IDFTTmpA, ScratchAvailable, SvpApply, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, - TakeScalarZnx, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, - VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwithcDegree, + ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, + TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAllocBytes, + VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, + VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing, }, layouts::{Backend, DataMut, DataRef, Module, Scratch}, source::Source, @@ -33,13 +33,13 @@ impl GGLWETensorKeyCompressed { source_xe: &mut Source, scratch: &mut Scratch, ) where - Module: SvpApply - + IDFTTmpA + Module: SvpApplyDftToDft + + 
VecZnxIdftApplyTmpA + VecZnxDftAllocBytes + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform + VecZnxSubABInplace @@ -48,7 +48,7 @@ impl GGLWETensorKeyCompressed { + VecZnxAddNormal + VecZnxNormalize + VecZnxSub - + VecZnxSwithcDegree + + VecZnxSwitchRing + VecZnxAddScalarInplace + SvpPrepare + SvpPPolAllocBytes @@ -70,39 +70,39 @@ impl GGLWETensorKeyCompressed { let n: usize = sk.n(); let rank: usize = self.rank(); - let (mut sk_dft_prep, scratch1) = scratch.take_glwe_secret_prepared(n, rank); - sk_dft_prep.prepare(module, sk, scratch1); + let (mut sk_dft_prep, scratch_1) = scratch.take_glwe_secret_prepared(n, rank); + sk_dft_prep.prepare(module, sk, scratch_1); - let (mut sk_dft, scratch2) = scratch1.take_vec_znx_dft(n, rank, 1); + let (mut sk_dft, scratch_2) = scratch_1.take_vec_znx_dft(n, rank, 1); (0..rank).for_each(|i| { - module.dft(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i); + module.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i); }); - let (mut sk_ij_big, scratch3) = scratch2.take_vec_znx_big(n, 1, 1); - let (mut sk_ij, scratch4) = scratch3.take_glwe_secret(n, 1); - let (mut sk_ij_dft, scratch5) = scratch4.take_vec_znx_dft(n, 1, 1); + let (mut sk_ij_big, scratch_3) = scratch_2.take_vec_znx_big(n, 1, 1); + let (mut sk_ij, scratch_4) = scratch_3.take_glwe_secret(n, 1); + let (mut sk_ij_dft, scratch_5) = scratch_4.take_vec_znx_dft(n, 1, 1); let mut source_xa: Source = Source::new(seed_xa); (0..rank).for_each(|i| { (i..rank).for_each(|j| { - module.svp_apply(&mut sk_ij_dft, 0, &sk_dft_prep.data, j, &sk_dft, i); + module.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_dft_prep.data, j, &sk_dft, i); - module.idft_tmp_a(&mut sk_ij_big, 0, &mut sk_ij_dft, 0); + module.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0); module.vec_znx_big_normalize( self.basek(), &mut sk_ij.data.as_vec_znx_mut(), 0, &sk_ij_big, 0, - scratch5, + scratch_5, ); let (seed_xa_tmp, _) = source_xa.branch(); self.at_mut(i, j) - .encrypt_sk(module, &sk_ij, sk, seed_xa_tmp, source_xe, scratch5); + .encrypt_sk(module, &sk_ij, sk, seed_xa_tmp, source_xe, scratch_5); }); }) } diff --git a/poulpy-core/src/encryption/compressed/ggsw_ct.rs b/poulpy-core/src/encryption/compressed/ggsw_ct.rs index d398b52..9d4efa2 100644 --- a/poulpy-core/src/encryption/compressed/ggsw_ct.rs +++ b/poulpy-core/src/encryption/compressed/ggsw_ct.rs @@ -1,8 +1,8 @@ use poulpy_hal::{ api::{ - DFT, IDFTConsume, ScratchAvailable, SvpApplyInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, - VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize, - VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, + ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, + VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, + VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, }, layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, ZnxZero}, source::Source, @@ -37,9 +37,9 @@ impl GGSWCiphertextCompressed { Module: VecZnxAddScalarInplace + VecZnxDftAllocBytes + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + 
VecZnxFillUniform + VecZnxSubABInplace diff --git a/poulpy-core/src/encryption/compressed/glwe_ct.rs b/poulpy-core/src/encryption/compressed/glwe_ct.rs index dbec125..6b20eba 100644 --- a/poulpy-core/src/encryption/compressed/glwe_ct.rs +++ b/poulpy-core/src/encryption/compressed/glwe_ct.rs @@ -1,8 +1,8 @@ use poulpy_hal::{ api::{ - DFT, IDFTConsume, ScratchAvailable, SvpApplyInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, - VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, + ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, + VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, + VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, }, layouts::{Backend, DataMut, DataRef, Module, Scratch}, source::Source, @@ -35,9 +35,9 @@ impl GLWECiphertextCompressed { ) where Module: VecZnxDftAllocBytes + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform + VecZnxSubABInplace @@ -63,9 +63,9 @@ impl GLWECiphertextCompressed { ) where Module: VecZnxDftAllocBytes + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform + VecZnxSubABInplace diff --git a/poulpy-core/src/encryption/gglwe_atk.rs b/poulpy-core/src/encryption/gglwe_atk.rs index b31234c..6eab79a 100644 --- a/poulpy-core/src/encryption/gglwe_atk.rs +++ b/poulpy-core/src/encryption/gglwe_atk.rs @@ -1,9 +1,9 @@ use poulpy_hal::{ api::{ - DFT, IDFTConsume, ScratchAvailable, SvpApplyInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, - TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxBigNormalize, - VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, - VecZnxSubABInplace, VecZnxSwithcDegree, + ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, + VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxBigNormalize, VecZnxDftAllocBytes, + VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, + VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing, }, layouts::{Backend, DataMut, DataRef, Module, Scratch}, source::Source, @@ -41,9 +41,9 @@ impl GGLWEAutomorphismKey { Module: VecZnxAddScalarInplace + VecZnxDftAllocBytes + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform + VecZnxSubABInplace @@ -53,7 +53,7 @@ impl GGLWEAutomorphismKey { + VecZnxNormalize + VecZnxSub + SvpPrepare - + VecZnxSwithcDegree + + VecZnxSwitchRing + SvpPPolAllocBytes + VecZnxAutomorphism, Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared, diff --git a/poulpy-core/src/encryption/gglwe_ct.rs b/poulpy-core/src/encryption/gglwe_ct.rs index 7d8781a..50dca97 100644 --- a/poulpy-core/src/encryption/gglwe_ct.rs +++ b/poulpy-core/src/encryption/gglwe_ct.rs @@ -1,8 +1,8 @@ use poulpy_hal::{ api::{ - DFT, 
IDFTConsume, ScratchAvailable, SvpApplyInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, - VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize, - VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, + ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, + VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, + VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, }, layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, ZnxZero}, source::Source, @@ -41,9 +41,9 @@ impl GGLWECiphertext { Module: VecZnxAddScalarInplace + VecZnxDftAllocBytes + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform + VecZnxSubABInplace diff --git a/poulpy-core/src/encryption/gglwe_ksk.rs b/poulpy-core/src/encryption/gglwe_ksk.rs index 0f9e3f1..daf8e2e 100644 --- a/poulpy-core/src/encryption/gglwe_ksk.rs +++ b/poulpy-core/src/encryption/gglwe_ksk.rs @@ -1,9 +1,9 @@ use poulpy_hal::{ api::{ - DFT, IDFTConsume, ScratchAvailable, SvpApplyInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, - TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, - VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, - VecZnxSwithcDegree, + ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, + VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, + VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, + VecZnxSubABInplace, VecZnxSwitchRing, }, layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch}, source::Source, @@ -55,9 +55,9 @@ impl GGLWESwitchingKey { Module: VecZnxAddScalarInplace + VecZnxDftAllocBytes + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform + VecZnxSubABInplace @@ -67,7 +67,7 @@ impl GGLWESwitchingKey { + VecZnxNormalize + VecZnxSub + SvpPrepare - + VecZnxSwithcDegree + + VecZnxSwitchRing + SvpPPolAllocBytes, Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared, { @@ -100,9 +100,9 @@ impl GGLWESwitchingKey { let n: usize = sk_in.n().max(sk_out.n()); - let (mut sk_in_tmp, scratch1) = scratch.take_scalar_znx(n, sk_in.rank()); + let (mut sk_in_tmp, scratch_1) = scratch.take_scalar_znx(n, sk_in.rank()); (0..sk_in.rank()).for_each(|i| { - module.vec_znx_switch_degree( + module.vec_znx_switch_ring( &mut sk_in_tmp.as_vec_znx_mut(), i, &sk_in.data.as_vec_znx(), @@ -110,11 +110,11 @@ impl GGLWESwitchingKey { ); }); - let (mut sk_out_tmp, scratch2) = scratch1.take_glwe_secret_prepared(n, sk_out.rank()); + let (mut sk_out_tmp, scratch_2) = scratch_1.take_glwe_secret_prepared(n, sk_out.rank()); { - let (mut tmp, _) = scratch2.take_scalar_znx(n, 1); + let (mut tmp, _) = scratch_2.take_scalar_znx(n, 1); (0..sk_out.rank()).for_each(|i| { - module.vec_znx_switch_degree(&mut tmp.as_vec_znx_mut(), 0, &sk_out.data.as_vec_znx(), i); + 
module.vec_znx_switch_ring(&mut tmp.as_vec_znx_mut(), 0, &sk_out.data.as_vec_znx(), i); module.svp_prepare(&mut sk_out_tmp.data, i, &tmp, 0); }); } @@ -125,7 +125,7 @@ impl GGLWESwitchingKey { &sk_out_tmp, source_xa, source_xe, - scratch2, + scratch_2, ); self.sk_in_n = sk_in.n(); self.sk_out_n = sk_out.n(); diff --git a/poulpy-core/src/encryption/gglwe_tsk.rs b/poulpy-core/src/encryption/gglwe_tsk.rs index 6b288f6..2032af5 100644 --- a/poulpy-core/src/encryption/gglwe_tsk.rs +++ b/poulpy-core/src/encryption/gglwe_tsk.rs @@ -1,9 +1,9 @@ use poulpy_hal::{ api::{ - DFT, IDFTConsume, IDFTTmpA, ScratchAvailable, SvpApply, SvpApplyInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, - TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAllocBytes, - VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwithcDegree, + ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, + TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAllocBytes, + VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, + VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing, }, layouts::{Backend, DataMut, DataRef, Module, Scratch}, source::Source, @@ -41,14 +41,14 @@ impl GGLWETensorKey { source_xe: &mut Source, scratch: &mut Scratch, ) where - Module: SvpApply - + IDFTTmpA + Module: SvpApplyDftToDft + + VecZnxIdftApplyTmpA + VecZnxAddScalarInplace + VecZnxDftAllocBytes + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform + VecZnxSubABInplace @@ -58,7 +58,7 @@ impl GGLWETensorKey { + VecZnxNormalize + VecZnxSub + SvpPrepare - + VecZnxSwithcDegree + + VecZnxSwitchRing + SvpPPolAllocBytes, Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared + TakeVecZnxBig, @@ -73,35 +73,35 @@ impl GGLWETensorKey { let rank: usize = self.rank(); - let (mut sk_dft_prep, scratch1) = scratch.take_glwe_secret_prepared(n, rank); - sk_dft_prep.prepare(module, sk, scratch1); + let (mut sk_dft_prep, scratch_1) = scratch.take_glwe_secret_prepared(n, rank); + sk_dft_prep.prepare(module, sk, scratch_1); - let (mut sk_dft, scratch2) = scratch1.take_vec_znx_dft(n, rank, 1); + let (mut sk_dft, scratch_2) = scratch_1.take_vec_znx_dft(n, rank, 1); (0..rank).for_each(|i| { - module.dft(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i); + module.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i); }); - let (mut sk_ij_big, scratch3) = scratch2.take_vec_znx_big(n, 1, 1); - let (mut sk_ij, scratch4) = scratch3.take_glwe_secret(n, 1); - let (mut sk_ij_dft, scratch5) = scratch4.take_vec_znx_dft(n, 1, 1); + let (mut sk_ij_big, scratch_3) = scratch_2.take_vec_znx_big(n, 1, 1); + let (mut sk_ij, scratch_4) = scratch_3.take_glwe_secret(n, 1); + let (mut sk_ij_dft, scratch_5) = scratch_4.take_vec_znx_dft(n, 1, 1); (0..rank).for_each(|i| { (i..rank).for_each(|j| { - module.svp_apply(&mut sk_ij_dft, 0, &sk_dft_prep.data, j, &sk_dft, i); + module.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_dft_prep.data, j, &sk_dft, i); - module.idft_tmp_a(&mut sk_ij_big, 0, &mut sk_ij_dft, 0); + 
module.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0); module.vec_znx_big_normalize( self.basek(), &mut sk_ij.data.as_vec_znx_mut(), 0, &sk_ij_big, 0, - scratch5, + scratch_5, ); self.at_mut(i, j) - .encrypt_sk(module, &sk_ij, sk, source_xa, source_xe, scratch5); + .encrypt_sk(module, &sk_ij, sk, source_xa, source_xe, scratch_5); }); }) } diff --git a/poulpy-core/src/encryption/ggsw_ct.rs b/poulpy-core/src/encryption/ggsw_ct.rs index 6b8484c..5995a50 100644 --- a/poulpy-core/src/encryption/ggsw_ct.rs +++ b/poulpy-core/src/encryption/ggsw_ct.rs @@ -1,8 +1,8 @@ use poulpy_hal::{ api::{ - DFT, IDFTConsume, ScratchAvailable, SvpApplyInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, - VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize, - VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, + ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, + VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, + VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, }, layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, VecZnx, ZnxZero}, source::Source, @@ -40,9 +40,9 @@ impl GGSWCiphertext { Module: VecZnxAddScalarInplace + VecZnxDftAllocBytes + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform + VecZnxSubABInplace @@ -67,14 +67,14 @@ impl GGSWCiphertext { let rank: usize = self.rank(); let digits: usize = self.digits(); - let (mut tmp_pt, scratch1) = scratch.take_glwe_pt(self.n(), basek, k); + let (mut tmp_pt, scratch_1) = scratch.take_glwe_pt(self.n(), basek, k); (0..self.rows()).for_each(|row_i| { tmp_pt.data.zero(); // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (digits - 1) + row_i * digits, pt, 0); - module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scratch1); + module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scratch_1); (0..rank + 1).for_each(|col_j| { // rlwe encrypt of vec_znx_pt into vec_znx_ct @@ -85,7 +85,7 @@ impl GGSWCiphertext { sk, source_xa, source_xe, - scratch1, + scratch_1, ); }); }); diff --git a/poulpy-core/src/encryption/glwe_ct.rs b/poulpy-core/src/encryption/glwe_ct.rs index 7d47785..ddb202e 100644 --- a/poulpy-core/src/encryption/glwe_ct.rs +++ b/poulpy-core/src/encryption/glwe_ct.rs @@ -1,9 +1,9 @@ use poulpy_hal::{ api::{ - DFT, IDFTConsume, ScratchAvailable, SvpApply, SvpApplyInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeSvpPPol, + ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeSvpPPol, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxBigAddNormal, VecZnxBigAddSmallInplace, - VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, + VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, + VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, }, layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, VecZnx, VecZnxBig, ZnxInfos, ZnxZero}, 
source::Source, @@ -53,9 +53,9 @@ impl GLWECiphertext { ) where Module: VecZnxDftAllocBytes + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform + VecZnxSubABInplace @@ -92,9 +92,9 @@ impl GLWECiphertext { ) where Module: VecZnxDftAllocBytes + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform + VecZnxSubABInplace @@ -138,9 +138,9 @@ impl GLWECiphertext { ) where Module: VecZnxDftAllocBytes + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform + VecZnxSubABInplace @@ -179,8 +179,8 @@ impl GLWECiphertext { scratch: &mut Scratch, ) where Module: SvpPrepare - + SvpApply - + IDFTConsume + + SvpApplyDftToDft + + VecZnxIdftApplyConsume + VecZnxBigAddNormal + VecZnxBigAddSmallInplace + VecZnxBigNormalize, @@ -198,8 +198,8 @@ impl GLWECiphertext { scratch: &mut Scratch, ) where Module: SvpPrepare - + SvpApply - + IDFTConsume + + SvpApplyDftToDft + + VecZnxIdftApplyConsume + VecZnxBigAddNormal + VecZnxBigAddSmallInplace + VecZnxBigNormalize, @@ -226,8 +226,8 @@ impl GLWECiphertext { scratch: &mut Scratch, ) where Module: SvpPrepare - + SvpApply - + IDFTConsume + + SvpApplyDftToDft + + VecZnxIdftApplyConsume + VecZnxBigAddNormal + VecZnxBigAddSmallInplace + VecZnxBigNormalize, @@ -273,10 +273,10 @@ impl GLWECiphertext { (0..cols).for_each(|i| { let (mut ci_dft, scratch_2) = scratch_1.take_vec_znx_dft(self.n(), 1, size_pk); // ci_dft = DFT(u) * DFT(pk[i]) - module.svp_apply(&mut ci_dft, 0, &u_dft, 0, &pk.data, i); + module.svp_apply_dft_to_dft(&mut ci_dft, 0, &u_dft, 0, &pk.data, i); // ci_big = u * p[i] - let mut ci_big = module.vec_znx_idft_consume(ci_dft); + let mut ci_big = module.vec_znx_idft_apply_consume(ci_dft); // ci_big = u * pk[i] + e module.vec_znx_big_add_normal(basek, &mut ci_big, 0, pk.k(), source_xe, SIGMA, SIGMA_BOUND); @@ -311,9 +311,9 @@ pub(crate) fn glwe_encrypt_sk_internal: VecZnxDftAllocBytes + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform + VecZnxSubABInplace @@ -350,7 +350,7 @@ pub(crate) fn glwe_encrypt_sk_internal = module.vec_znx_idft_consume(ci_dft); + module.svp_apply_dft_to_dft_inplace(&mut ci_dft, 0, &sk.data, i - 1); + let ci_big: VecZnxBig<&mut [u8], B> = module.vec_znx_idft_apply_consume(ci_dft); // use c[0] as buffer, which is overwritten later by the normalization step module.vec_znx_big_normalize(basek, &mut ci, 0, &ci_big, 0, scratch_3); diff --git a/poulpy-core/src/encryption/glwe_pk.rs b/poulpy-core/src/encryption/glwe_pk.rs index d2cfc15..23820ff 100644 --- a/poulpy-core/src/encryption/glwe_pk.rs +++ b/poulpy-core/src/encryption/glwe_pk.rs @@ -1,7 +1,7 @@ use poulpy_hal::{ api::{ - DFT, IDFTConsume, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, VecZnxAddInplace, VecZnxAddNormal, - VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, VecZnxAddInplace, VecZnxAddNormal, VecZnxBigNormalize, + VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, 
VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, }, layouts::{Backend, DataMut, DataRef, Module, ScratchOwned}, @@ -22,9 +22,9 @@ impl GLWEPublicKey { Module:, Module: VecZnxDftAllocBytes + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform + VecZnxSubABInplace diff --git a/poulpy-core/src/encryption/glwe_to_lwe_ksk.rs b/poulpy-core/src/encryption/glwe_to_lwe_ksk.rs index 7aee17b..9c72d78 100644 --- a/poulpy-core/src/encryption/glwe_to_lwe_ksk.rs +++ b/poulpy-core/src/encryption/glwe_to_lwe_ksk.rs @@ -1,9 +1,9 @@ use poulpy_hal::{ api::{ - DFT, IDFTConsume, ScratchAvailable, SvpApplyInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, - TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxBigNormalize, - VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, - VecZnxSubABInplace, VecZnxSwithcDegree, + ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, + VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxBigNormalize, + VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, + VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing, }, layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut, ZnxZero}, source::Source, @@ -38,13 +38,13 @@ impl GLWEToLWESwitchingKey { ) where DLwe: DataRef, DGlwe: DataRef, - Module: VecZnxAutomorphismInplace + Module: VecZnxAutomorphismInplace + VecZnxAddScalarInplace + VecZnxDftAllocBytes + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform + VecZnxSubABInplace @@ -54,7 +54,7 @@ impl GLWEToLWESwitchingKey { + VecZnxNormalize + VecZnxSub + SvpPrepare - + VecZnxSwithcDegree + + VecZnxSwitchRing + SvpPPolAllocBytes, Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared, { @@ -63,10 +63,10 @@ impl GLWEToLWESwitchingKey { assert!(sk_lwe.n() <= module.n()); } - let (mut sk_lwe_as_glwe, scratch1) = scratch.take_glwe_secret(sk_glwe.n(), 1); + let (mut sk_lwe_as_glwe, scratch_1) = scratch.take_glwe_secret(sk_glwe.n(), 1); sk_lwe_as_glwe.data.zero(); sk_lwe_as_glwe.data.at_mut(0, 0)[..sk_lwe.n()].copy_from_slice(sk_lwe.data.at(0, 0)); - module.vec_znx_automorphism_inplace(-1, &mut sk_lwe_as_glwe.data.as_vec_znx_mut(), 0); + module.vec_znx_automorphism_inplace(-1, &mut sk_lwe_as_glwe.data.as_vec_znx_mut(), 0, scratch_1); self.0.encrypt_sk( module, @@ -74,7 +74,7 @@ impl GLWEToLWESwitchingKey { &sk_lwe_as_glwe, source_xa, source_xe, - scratch1, + scratch_1, ); } } diff --git a/poulpy-core/src/encryption/lwe_ct.rs b/poulpy-core/src/encryption/lwe_ct.rs index 26400cc..15a9a65 100644 --- a/poulpy-core/src/encryption/lwe_ct.rs +++ b/poulpy-core/src/encryption/lwe_ct.rs @@ -32,7 +32,7 @@ impl LWECiphertext { let basek: usize = self.basek(); let k: usize = self.k(); - module.zn_fill_uniform(self.n() + 1, basek, &mut self.data, 0, k, source_xa); + module.zn_fill_uniform(self.n() + 1, basek, &mut self.data, 0, source_xa); let mut tmp_znx: Zn> = Zn::alloc(1, 1, self.size()); diff --git a/poulpy-core/src/encryption/lwe_ksk.rs 
b/poulpy-core/src/encryption/lwe_ksk.rs index e6b24db..2df695f 100644 --- a/poulpy-core/src/encryption/lwe_ksk.rs +++ b/poulpy-core/src/encryption/lwe_ksk.rs @@ -1,9 +1,9 @@ use poulpy_hal::{ api::{ - DFT, IDFTConsume, ScratchAvailable, SvpApplyInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, - TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxBigNormalize, - VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, - VecZnxSubABInplace, VecZnxSwithcDegree, + ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, + VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxBigNormalize, + VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, + VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing, }, layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut}, source::Source, @@ -38,13 +38,13 @@ impl LWESwitchingKey { ) where DIn: DataRef, DOut: DataRef, - Module: VecZnxAutomorphismInplace + Module: VecZnxAutomorphismInplace + VecZnxAddScalarInplace + VecZnxDftAllocBytes + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform + VecZnxSubABInplace @@ -54,7 +54,7 @@ impl LWESwitchingKey { + VecZnxNormalize + VecZnxSub + SvpPrepare - + VecZnxSwithcDegree + + VecZnxSwitchRing + SvpPPolAllocBytes, Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared, { @@ -65,16 +65,16 @@ impl LWESwitchingKey { assert!(self.n() <= module.n()); } - let (mut sk_in_glwe, scratch1) = scratch.take_glwe_secret(self.n(), 1); - let (mut sk_out_glwe, scratch2) = scratch1.take_glwe_secret(self.n(), 1); + let (mut sk_in_glwe, scratch_1) = scratch.take_glwe_secret(self.n(), 1); + let (mut sk_out_glwe, scratch_2) = scratch_1.take_glwe_secret(self.n(), 1); sk_out_glwe.data.at_mut(0, 0)[..sk_lwe_out.n()].copy_from_slice(sk_lwe_out.data.at(0, 0)); sk_out_glwe.data.at_mut(0, 0)[sk_lwe_out.n()..].fill(0); - module.vec_znx_automorphism_inplace(-1, &mut sk_out_glwe.data.as_vec_znx_mut(), 0); + module.vec_znx_automorphism_inplace(-1, &mut sk_out_glwe.data.as_vec_znx_mut(), 0, scratch_2); sk_in_glwe.data.at_mut(0, 0)[..sk_lwe_in.n()].copy_from_slice(sk_lwe_in.data.at(0, 0)); sk_in_glwe.data.at_mut(0, 0)[sk_lwe_in.n()..].fill(0); - module.vec_znx_automorphism_inplace(-1, &mut sk_in_glwe.data.as_vec_znx_mut(), 0); + module.vec_znx_automorphism_inplace(-1, &mut sk_in_glwe.data.as_vec_znx_mut(), 0, scratch_2); self.0.encrypt_sk( module, @@ -82,7 +82,7 @@ impl LWESwitchingKey { &sk_out_glwe, source_xa, source_xe, - scratch2, + scratch_2, ); } } diff --git a/poulpy-core/src/encryption/lwe_to_glwe_ksk.rs b/poulpy-core/src/encryption/lwe_to_glwe_ksk.rs index ac661e2..95ea310 100644 --- a/poulpy-core/src/encryption/lwe_to_glwe_ksk.rs +++ b/poulpy-core/src/encryption/lwe_to_glwe_ksk.rs @@ -1,9 +1,9 @@ use poulpy_hal::{ api::{ - DFT, IDFTConsume, ScratchAvailable, SvpApplyInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, - TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxBigNormalize, - VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, 
VecZnxNormalizeTmpBytes, VecZnxSub, - VecZnxSubABInplace, VecZnxSwithcDegree, + ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, + VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxBigNormalize, + VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, + VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing, }, layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut}, source::Source, @@ -36,13 +36,13 @@ impl LWEToGLWESwitchingKey { ) where DLwe: DataRef, DGlwe: DataRef, - Module: VecZnxAutomorphismInplace + Module: VecZnxAutomorphismInplace + VecZnxAddScalarInplace + VecZnxDftAllocBytes + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform + VecZnxSubABInplace @@ -52,7 +52,7 @@ impl LWEToGLWESwitchingKey { + VecZnxNormalize + VecZnxSub + SvpPrepare - + VecZnxSwithcDegree + + VecZnxSwitchRing + SvpPPolAllocBytes, Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared, { @@ -61,10 +61,10 @@ impl LWEToGLWESwitchingKey { assert!(sk_lwe.n() <= module.n()); } - let (mut sk_lwe_as_glwe, scratch1) = scratch.take_glwe_secret(sk_glwe.n(), 1); + let (mut sk_lwe_as_glwe, scratch_1) = scratch.take_glwe_secret(sk_glwe.n(), 1); sk_lwe_as_glwe.data.at_mut(0, 0)[..sk_lwe.n()].copy_from_slice(sk_lwe.data.at(0, 0)); sk_lwe_as_glwe.data.at_mut(0, 0)[sk_lwe.n()..].fill(0); - module.vec_znx_automorphism_inplace(-1, &mut sk_lwe_as_glwe.data.as_vec_znx_mut(), 0); + module.vec_znx_automorphism_inplace(-1, &mut sk_lwe_as_glwe.data.as_vec_znx_mut(), 0, scratch_1); self.0.encrypt_sk( module, @@ -72,7 +72,7 @@ impl LWEToGLWESwitchingKey { sk_glwe, source_xa, source_xe, - scratch1, + scratch_1, ); } } diff --git a/poulpy-core/src/external_product/gglwe_atk.rs b/poulpy-core/src/external_product/gglwe_atk.rs index 058eb65..23a48c2 100644 --- a/poulpy-core/src/external_product/gglwe_atk.rs +++ b/poulpy-core/src/external_product/gglwe_atk.rs @@ -1,7 +1,7 @@ use poulpy_hal::{ api::{ - DFT, IDFTConsume, ScratchAvailable, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxNormalizeTmpBytes, - VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + ScratchAvailable, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, + VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataRef, Module, Scratch}, }; @@ -51,10 +51,10 @@ impl GGLWEAutomorphismKey { Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes - + DFT + + VecZnxDftApply + VmpApplyDftToDft + VmpApplyDftToDftAdd - + IDFTConsume + + VecZnxIdftApplyConsume + VecZnxBigNormalize, Scratch: TakeVecZnxDft + ScratchAvailable, { @@ -70,10 +70,10 @@ impl GGLWEAutomorphismKey { Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes - + DFT + + VecZnxDftApply + VmpApplyDftToDft + VmpApplyDftToDftAdd - + IDFTConsume + + VecZnxIdftApplyConsume + VecZnxBigNormalize, Scratch: TakeVecZnxDft + ScratchAvailable, { diff --git a/poulpy-core/src/external_product/gglwe_ksk.rs b/poulpy-core/src/external_product/gglwe_ksk.rs index 30f202a..8a07977 100644 --- a/poulpy-core/src/external_product/gglwe_ksk.rs +++ 
b/poulpy-core/src/external_product/gglwe_ksk.rs @@ -1,7 +1,7 @@ use poulpy_hal::{ api::{ - DFT, IDFTConsume, ScratchAvailable, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxNormalizeTmpBytes, - VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + ScratchAvailable, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, + VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero}, }; @@ -51,10 +51,10 @@ impl GGLWESwitchingKey { Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes - + DFT + + VecZnxDftApply + VmpApplyDftToDft + VmpApplyDftToDftAdd - + IDFTConsume + + VecZnxIdftApplyConsume + VecZnxBigNormalize, Scratch: TakeVecZnxDft + ScratchAvailable, { @@ -106,10 +106,10 @@ impl GGLWESwitchingKey { Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes - + DFT + + VecZnxDftApply + VmpApplyDftToDft + VmpApplyDftToDftAdd - + IDFTConsume + + VecZnxIdftApplyConsume + VecZnxBigNormalize, Scratch: TakeVecZnxDft + ScratchAvailable, { diff --git a/poulpy-core/src/external_product/ggsw_ct.rs b/poulpy-core/src/external_product/ggsw_ct.rs index f28c7aa..0b72877 100644 --- a/poulpy-core/src/external_product/ggsw_ct.rs +++ b/poulpy-core/src/external_product/ggsw_ct.rs @@ -1,7 +1,7 @@ use poulpy_hal::{ api::{ - DFT, IDFTConsume, ScratchAvailable, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxNormalizeTmpBytes, - VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + ScratchAvailable, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, + VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero}, }; @@ -51,10 +51,10 @@ impl GGSWCiphertext { Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes - + DFT + + VecZnxDftApply + VmpApplyDftToDft + VmpApplyDftToDftAdd - + IDFTConsume + + VecZnxIdftApplyConsume + VecZnxBigNormalize, Scratch: TakeVecZnxDft + ScratchAvailable, { @@ -116,10 +116,10 @@ impl GGSWCiphertext { Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes - + DFT + + VecZnxDftApply + VmpApplyDftToDft + VmpApplyDftToDftAdd - + IDFTConsume + + VecZnxIdftApplyConsume + VecZnxBigNormalize, Scratch: TakeVecZnxDft + ScratchAvailable, { diff --git a/poulpy-core/src/external_product/glwe_ct.rs b/poulpy-core/src/external_product/glwe_ct.rs index 8290e2e..9164b96 100644 --- a/poulpy-core/src/external_product/glwe_ct.rs +++ b/poulpy-core/src/external_product/glwe_ct.rs @@ -1,7 +1,7 @@ use poulpy_hal::{ api::{ - DFT, IDFTConsume, ScratchAvailable, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxNormalizeTmpBytes, - VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + ScratchAvailable, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, + VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataRef, DataViewMut, Module, Scratch, VecZnxBig}, }; @@ -65,10 +65,10 @@ impl GLWECiphertext { Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes - + DFT + + VecZnxDftApply + VmpApplyDftToDft + VmpApplyDftToDftAdd - + IDFTConsume + + VecZnxIdftApplyConsume + VecZnxBigNormalize, Scratch: TakeVecZnxDft + 
ScratchAvailable, { @@ -101,8 +101,8 @@ impl GLWECiphertext { let cols: usize = rhs.rank() + 1; let digits: usize = rhs.digits(); - let (mut res_dft, scratch1) = scratch.take_vec_znx_dft(self.n(), cols, rhs.size()); // Todo optimise - let (mut a_dft, scratch2) = scratch1.take_vec_znx_dft(self.n(), cols, lhs.size().div_ceil(digits)); + let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), cols, rhs.size()); // Todo optimise + let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self.n(), cols, lhs.size().div_ceil(digits)); a_dft.data_mut().fill(0); @@ -121,21 +121,21 @@ impl GLWECiphertext { res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize); (0..cols).for_each(|col_i| { - module.dft(digits, digits - 1 - di, &mut a_dft, col_i, &lhs.data, col_i); + module.vec_znx_dft_apply(digits, digits - 1 - di, &mut a_dft, col_i, &lhs.data, col_i); }); if di == 0 { - module.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch2); + module.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_2); } else { - module.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch2); + module.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_2); } }); } - let res_big: VecZnxBig<&mut [u8], B> = module.vec_znx_idft_consume(res_dft); + let res_big: VecZnxBig<&mut [u8], B> = module.vec_znx_idft_apply_consume(res_dft); (0..cols).for_each(|i| { - module.vec_znx_big_normalize(basek, &mut self.data, i, &res_big, i, scratch1); + module.vec_znx_big_normalize(basek, &mut self.data, i, &res_big, i, scratch_1); }); } @@ -148,16 +148,81 @@ impl GLWECiphertext { Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes - + DFT + + VecZnxDftApply + VmpApplyDftToDft + VmpApplyDftToDftAdd - + IDFTConsume + + VecZnxIdftApplyConsume + VecZnxBigNormalize, Scratch: TakeVecZnxDft + ScratchAvailable, { - unsafe { - let self_ptr: *mut GLWECiphertext = self as *mut GLWECiphertext; - self.external_product(module, &*self_ptr, rhs, scratch); + let basek: usize = self.basek(); + + #[cfg(debug_assertions)] + { + use poulpy_hal::api::ScratchAvailable; + + assert_eq!(rhs.rank(), self.rank()); + assert_eq!(self.basek(), basek); + assert_eq!(rhs.n(), self.n()); + assert!( + scratch.available() + >= GLWECiphertext::external_product_scratch_space( + module, + self.basek(), + self.k(), + self.k(), + rhs.k(), + rhs.digits(), + rhs.rank(), + ) + ); } + + let cols: usize = rhs.rank() + 1; + let digits: usize = rhs.digits(); + + let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), cols, rhs.size()); // Todo optimise + let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self.n(), cols, self.size().div_ceil(digits)); + + a_dft.data_mut().fill(0); + + { + (0..digits).for_each(|di| { + // (self.size() + di) / digits = (a - (digits - di - 1)).div_ceil(digits) + a_dft.set_size((self.size() + di) / digits); + + // Small optimization for digits > 2 + // VMP produces some error e, and since we aggregate vmp * 2^{di * B}, then + // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}. + // As such, we can safely ignore the last digits-2 limbs of the sum of vmp products. + // It is possible to further ignore the last digits-1 limbs, but this introduces + // ~0.5 to 1 bit of additional noise, and is thus not chosen here, to ensure that the same + // noise is kept with respect to the ideal functionality.
+ res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize); + + (0..cols).for_each(|col_i| { + module.vec_znx_dft_apply( + digits, + digits - 1 - di, + &mut a_dft, + col_i, + &self.data, + col_i, + ); + }); + + if di == 0 { + module.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_2); + } else { + module.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_2); + } + }); + } + + let res_big: VecZnxBig<&mut [u8], B> = module.vec_znx_idft_apply_consume(res_dft); + + (0..cols).for_each(|i| { + module.vec_znx_big_normalize(basek, &mut self.data, i, &res_big, i, scratch_1); + }); } } diff --git a/poulpy-core/src/glwe_packing.rs b/poulpy-core/src/glwe_packing.rs index 0c63b71..c7c1f64 100644 --- a/poulpy-core/src/glwe_packing.rs +++ b/poulpy-core/src/glwe_packing.rs @@ -2,9 +2,9 @@ use std::collections::HashMap; use poulpy_hal::{ api::{ - DFT, IDFTConsume, ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAutomorphismInplace, - VecZnxBigAddSmallInplace, VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, - VecZnxBigSubSmallBInplace, VecZnxCopy, VecZnxDftAllocBytes, VecZnxNegateInplace, VecZnxNormalizeInplace, VecZnxRotate, + ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, + VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace, VecZnxCopy, + VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNegateInplace, VecZnxNormalizeInplace, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubABInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, @@ -126,20 +126,20 @@ impl GLWEPacker { + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd - + DFT - + IDFTConsume + + VecZnxDftApply + + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize + VecZnxCopy - + VecZnxRotateInplace + + VecZnxRotateInplace + VecZnxSub + VecZnxNegateInplace - + VecZnxRshInplace + + VecZnxRshInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxSubABInplace + VecZnxRotate - + VecZnxAutomorphismInplace + + VecZnxAutomorphismInplace + VecZnxBigSubSmallBInplace + VecZnxBigAutomorphismInplace, Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, @@ -204,20 +204,20 @@ fn pack_core( + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd - + DFT - + IDFTConsume + + VecZnxDftApply + + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize + VecZnxCopy - + VecZnxRotateInplace + + VecZnxRotateInplace + VecZnxSub + VecZnxNegateInplace - + VecZnxRshInplace + + VecZnxRshInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxSubABInplace + VecZnxRotate - + VecZnxAutomorphismInplace + + VecZnxAutomorphismInplace + VecZnxBigSubSmallBInplace + VecZnxBigAutomorphismInplace, Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, @@ -301,20 +301,20 @@ fn combine( + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd - + DFT - + IDFTConsume + + VecZnxDftApply + + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize + VecZnxCopy - + VecZnxRotateInplace + + VecZnxRotateInplace + VecZnxSub + VecZnxNegateInplace - + VecZnxRshInplace + + VecZnxRshInplace + VecZnxAddInplace + VecZnxNormalizeInplace + VecZnxSubABInplace + VecZnxRotate - + VecZnxAutomorphismInplace + + VecZnxAutomorphismInplace + VecZnxBigSubSmallBInplace + VecZnxBigAutomorphismInplace, Scratch: TakeVecZnxDft + 
ScratchAvailable + TakeVecZnx, @@ -349,15 +349,15 @@ fn combine( let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(n, basek, k, rank); // a = a * X^-t - a.rotate_inplace(module, -t); + a.rotate_inplace(module, -t, scratch_1); // tmp_b = a * X^-t - b tmp_b.sub(module, a, b); - tmp_b.rsh(module, 1); + tmp_b.rsh(module, 1, scratch_1); // a = a * X^-t + b a.add_inplace(module, b); - a.rsh(module, 1); + a.rsh(module, 1, scratch_1); tmp_b.normalize_inplace(module, scratch_1); @@ -375,9 +375,9 @@ fn combine( // a = a + b * X^t - phi(a * X^-t - b) * X^t // = a + b * X^t - phi(a * X^-t - b) * - phi(X^t) // = a + b * X^t + phi(a - b * X^t) - a.rotate_inplace(module, t); + a.rotate_inplace(module, t, scratch_1); } else { - a.rsh(module, 1); + a.rsh(module, 1, scratch); // a = a + phi(a) if let Some(key) = auto_keys.get(&gal_el) { a.automorphism_add_inplace(module, key, scratch); @@ -388,7 +388,7 @@ fn combine( } else if let Some(b) = b { let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(n, basek, k, rank); tmp_b.rotate(module, 1 << (log_n - i - 1), b); - tmp_b.rsh(module, 1); + tmp_b.rsh(module, 1, scratch_1); // a = (b* X^t - phi(b* X^t)) if let Some(key) = auto_keys.get(&gal_el) { diff --git a/poulpy-core/src/glwe_trace.rs b/poulpy-core/src/glwe_trace.rs index 4baf115..1c5c428 100644 --- a/poulpy-core/src/glwe_trace.rs +++ b/poulpy-core/src/glwe_trace.rs @@ -2,9 +2,9 @@ use std::collections::HashMap; use poulpy_hal::{ api::{ - DFT, IDFTConsume, ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigAutomorphismInplace, - VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAllocBytes, VecZnxRshInplace, VmpApplyDftToDft, - VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigAutomorphismInplace, VecZnxBigNormalize, + VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxRshInplace, + VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataRef, Module, Scratch}, }; @@ -73,12 +73,12 @@ impl GLWECiphertext { + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd - + DFT - + IDFTConsume + + VecZnxDftApply + + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize + VecZnxBigAutomorphismInplace - + VecZnxRshInplace + + VecZnxRshInplace + VecZnxCopy, Scratch: TakeVecZnxDft + ScratchAvailable, { @@ -99,16 +99,16 @@ impl GLWECiphertext { + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd - + DFT - + IDFTConsume + + VecZnxDftApply + + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize + VecZnxBigAutomorphismInplace - + VecZnxRshInplace, + + VecZnxRshInplace, Scratch: TakeVecZnxDft + ScratchAvailable, { (start..end).for_each(|i| { - self.rsh(module, 1); + self.rsh(module, 1, scratch); let p: i64 = if i == 0 { -1 diff --git a/poulpy-core/src/keyswitching/gglwe_ct.rs b/poulpy-core/src/keyswitching/gglwe_ct.rs index 8f1890e..21ea399 100644 --- a/poulpy-core/src/keyswitching/gglwe_ct.rs +++ b/poulpy-core/src/keyswitching/gglwe_ct.rs @@ -1,7 +1,8 @@ use poulpy_hal::{ api::{ - DFT, IDFTConsume, ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, - VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, + VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, 
VmpApplyDftToDft, VmpApplyDftToDftAdd, + VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero}, }; @@ -56,8 +57,8 @@ impl GGLWEAutomorphismKey { + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd - + DFT - + IDFTConsume + + VecZnxDftApply + + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize, Scratch: TakeVecZnxDft + ScratchAvailable, @@ -76,8 +77,8 @@ impl GGLWEAutomorphismKey { + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd - + DFT - + IDFTConsume + + VecZnxDftApply + + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize, Scratch: TakeVecZnxDft + ScratchAvailable, @@ -132,8 +133,8 @@ impl GGLWESwitchingKey { + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd - + DFT - + IDFTConsume + + VecZnxDftApply + + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize, Scratch: ScratchAvailable + TakeVecZnxDft, @@ -161,6 +162,12 @@ impl GGLWESwitchingKey { self.rank_out(), rhs.rank_out() ); + assert!( + self.rows() <= lhs.rows(), + "self.rows()={} > lhs.rows()={}", + self.rows(), + lhs.rows() + ); } (0..self.rank_in()).for_each(|col_i| { @@ -188,8 +195,8 @@ impl GGLWESwitchingKey { + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd - + DFT - + IDFTConsume + + VecZnxDftApply + + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize, Scratch: ScratchAvailable + TakeVecZnxDft, diff --git a/poulpy-core/src/keyswitching/ggsw_ct.rs b/poulpy-core/src/keyswitching/ggsw_ct.rs index 5fcd1b2..6e9d52e 100644 --- a/poulpy-core/src/keyswitching/ggsw_ct.rs +++ b/poulpy-core/src/keyswitching/ggsw_ct.rs @@ -1,8 +1,8 @@ use poulpy_hal::{ api::{ - DFT, IDFTConsume, IDFTTmpA, ScratchAvailable, TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAddSmallInplace, - VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAddInplace, - VecZnxDftAllocBytes, VecZnxDftCopy, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, + ScratchAvailable, TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, + VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy, + VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx, VmpPMat, ZnxInfos}, @@ -114,13 +114,13 @@ impl GGSWCiphertext { + VmpApplyDftToDftTmpBytes + VecZnxBigAllocBytes + VecZnxNormalizeTmpBytes - + DFT + + VecZnxDftApply + VecZnxDftCopy + VmpApplyDftToDft + VmpApplyDftToDftAdd + VecZnxDftAddInplace + VecZnxBigNormalize - + IDFTTmpA, + + VecZnxIdftApplyTmpA, Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnxBig, { #[cfg(debug_assertions)] @@ -150,8 +150,8 @@ impl GGSWCiphertext { + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd - + DFT - + IDFTConsume + + VecZnxDftApply + + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize + VecZnxDftAllocBytes @@ -159,10 +159,15 @@ impl GGSWCiphertext { + VecZnxNormalizeTmpBytes + VecZnxDftCopy + VecZnxDftAddInplace - + IDFTTmpA, + + VecZnxIdftApplyTmpA, Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnxBig, { - self.keyswitch_internal(module, lhs, ksk, scratch); + (0..lhs.rows()).for_each(|row_i| { + // Key-switch column 0, i.e. 
+ // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2) + self.at_mut(row_i, 0) + .keyswitch(module, &lhs.at(row_i, 0), ksk, scratch); + }); self.expand_row(module, tsk, scratch); } @@ -178,8 +183,8 @@ impl GGSWCiphertext { + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd - + DFT - + IDFTConsume + + VecZnxDftApply + + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize + VecZnxDftAllocBytes @@ -187,13 +192,16 @@ impl GGSWCiphertext { + VecZnxNormalizeTmpBytes + VecZnxDftCopy + VecZnxDftAddInplace - + IDFTTmpA, + + VecZnxIdftApplyTmpA, Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnxBig, { - unsafe { - let self_ptr: *mut GGSWCiphertext = self as *mut GGSWCiphertext; - self.keyswitch(module, &*self_ptr, ksk, tsk, scratch); - } + (0..self.rows()).for_each(|row_i| { + // Key-switch column 0, i.e. + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2) + self.at_mut(row_i, 0) + .keyswitch_inplace(module, ksk, scratch); + }); + self.expand_row(module, tsk, scratch); } pub fn expand_row( @@ -206,13 +214,13 @@ impl GGSWCiphertext { + VmpApplyDftToDftTmpBytes + VecZnxBigAllocBytes + VecZnxNormalizeTmpBytes - + DFT + + VecZnxDftApply + VecZnxDftCopy + VmpApplyDftToDft + VmpApplyDftToDftAdd + VecZnxDftAddInplace + VecZnxBigNormalize - + IDFTTmpA, + + VecZnxIdftApplyTmpA, Scratch: ScratchAvailable + TakeVecZnxDft + TakeVecZnxBig, { assert!( @@ -234,9 +242,9 @@ impl GGSWCiphertext { // Keyswitch the j-th row of the col 0 (0..self.rows()).for_each(|row_i| { // Pre-compute DFT of (a0, a1, a2) - let (mut ci_dft, scratch1) = scratch.take_vec_znx_dft(n, cols, self.size()); + let (mut ci_dft, scratch_1) = scratch.take_vec_znx_dft(n, cols, self.size()); (0..cols).for_each(|i| { - module.dft(1, 0, &mut ci_dft, i, &self.at(row_i, 0).data, i); + module.vec_znx_dft_apply(1, 0, &mut ci_dft, i, &self.at(row_i, 0).data, i); }); (1..cols).for_each(|col_j| { @@ -262,8 +270,8 @@ impl GGSWCiphertext { let digits: usize = tsk.digits(); - let (mut tmp_dft_i, scratch2) = scratch1.take_vec_znx_dft(n, cols, tsk.size()); - let (mut tmp_a, scratch3) = scratch2.take_vec_znx_dft(n, 1, ci_dft.size().div_ceil(digits)); + let (mut tmp_dft_i, scratch_2) = scratch_1.take_vec_znx_dft(n, cols, tsk.size()); + let (mut tmp_a, scratch_3) = scratch_2.take_vec_znx_dft(n, 1, ci_dft.size().div_ceil(digits)); { // Performs a key-switch for each combination of s[i]*s[j], i.e. 
for a0, a1, a2 @@ -295,9 +303,9 @@ impl GGSWCiphertext { module.vec_znx_dft_copy(digits, digits - 1 - di, &mut tmp_a, 0, &ci_dft, col_i); if di == 0 && col_i == 1 { - module.vmp_apply_dft_to_dft(&mut tmp_dft_i, &tmp_a, pmat, scratch3); + module.vmp_apply_dft_to_dft(&mut tmp_dft_i, &tmp_a, pmat, scratch_3); } else { - module.vmp_apply_dft_to_dft_add(&mut tmp_dft_i, &tmp_a, pmat, di, scratch3); + module.vmp_apply_dft_to_dft_add(&mut tmp_dft_i, &tmp_a, pmat, di, scratch_3); } }); }); @@ -313,46 +321,19 @@ impl GGSWCiphertext { // = // (-(x0s0 + x1s1 + x2s2), x0 + M[i], x1, x2) module.vec_znx_dft_add_inplace(&mut tmp_dft_i, col_j, &ci_dft, 0); - let (mut tmp_idft, scratch3) = scratch2.take_vec_znx_big(n, 1, tsk.size()); + let (mut tmp_idft, scratch_3) = scratch_2.take_vec_znx_big(n, 1, tsk.size()); (0..cols).for_each(|i| { - module.idft_tmp_a(&mut tmp_idft, 0, &mut tmp_dft_i, i); + module.vec_znx_idft_apply_tmpa(&mut tmp_idft, 0, &mut tmp_dft_i, i); module.vec_znx_big_normalize( self.basek(), &mut self.at_mut(row_i, col_j).data, i, &tmp_idft, 0, - scratch3, + scratch_3, ); }); }) }) } - - fn keyswitch_internal( - &mut self, - module: &Module, - lhs: &GGSWCiphertext, - ksk: &GGLWESwitchingKeyPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + DFT - + IDFTConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize, - Scratch: ScratchAvailable + TakeVecZnxDft, - { - // Keyswitch the j-th row of the col 0 - (0..lhs.rows()).for_each(|row_i| { - // Key-switch column 0, i.e. - // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2) - self.at_mut(row_i, 0) - .keyswitch(module, &lhs.at(row_i, 0), ksk, scratch); - }) - } } diff --git a/poulpy-core/src/keyswitching/glwe_ct.rs b/poulpy-core/src/keyswitching/glwe_ct.rs index e7e1ecd..14e23e0 100644 --- a/poulpy-core/src/keyswitching/glwe_ct.rs +++ b/poulpy-core/src/keyswitching/glwe_ct.rs @@ -1,7 +1,8 @@ use poulpy_hal::{ api::{ - DFT, IDFTConsume, ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, - VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, + VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VmpApplyDftToDft, VmpApplyDftToDftAdd, + VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataRef, DataViewMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, ZnxInfos}, }; @@ -117,6 +118,63 @@ impl GLWECiphertext { ) ); } + + #[allow(dead_code)] + pub(crate) fn assert_keyswitch_inplace( + &self, + module: &Module, + rhs: &GGLWESwitchingKeyPrepared, + scratch: &Scratch, + ) where + DataRhs: DataRef, + Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes, + Scratch: ScratchAvailable, + { + let basek: usize = self.basek(); + assert_eq!( + self.rank(), + rhs.rank_out(), + "self.rank(): {} != rhs.rank_out(): {}", + self.rank(), + rhs.rank_out() + ); + assert_eq!(self.basek(), basek); + assert_eq!(rhs.n(), self.n()); + assert!( + scratch.available() + >= GLWECiphertext::keyswitch_scratch_space( + module, + self.basek(), + self.k(), + self.k(), + rhs.k(), + rhs.digits(), + rhs.rank_in(), + rhs.rank_out(), + ), + "scratch.available()={} < GLWECiphertext::keyswitch_scratch_space( + module, + self.basek(), + self.k(), + self.k(), + 
rhs.k(), + rhs.digits(), + rhs.rank_in(), + rhs.rank_out(), + )={}", + scratch.available(), + GLWECiphertext::keyswitch_scratch_space( + module, + self.basek(), + self.k(), + self.k(), + rhs.k(), + rhs.digits(), + rhs.rank_in(), + rhs.rank_out(), + ) + ); + } } impl GLWECiphertext { @@ -130,11 +188,10 @@ impl GLWECiphertext { Module: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDftTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd - + DFT - + IDFTConsume + + VecZnxDftApply + + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize, Scratch: ScratchAvailable + TakeVecZnxDft, @@ -143,10 +200,10 @@ impl GLWECiphertext { { self.assert_keyswitch(module, lhs, rhs, scratch); } - let (res_dft, scratch1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // Todo optimise - let res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, rhs, scratch1); + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // Todo optimise + let res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, rhs, scratch_1); (0..self.cols()).for_each(|i| { - module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch1); + module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch_1); }) } @@ -162,16 +219,21 @@ impl GLWECiphertext { + VmpApplyDftToDftTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd - + DFT - + IDFTConsume + + VecZnxDftApply + + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize, Scratch: ScratchAvailable + TakeVecZnxDft, { - unsafe { - let self_ptr: *mut GLWECiphertext = self as *mut GLWECiphertext; - self.keyswitch(module, &*self_ptr, rhs, scratch); + #[cfg(debug_assertions)] + { + self.assert_keyswitch_inplace(module, rhs, scratch); } + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // Todo optimise + let res_big: VecZnxBig<_, B> = self.keyswitch_internal(module, res_dft, rhs, scratch_1); + (0..self.cols()).for_each(|i| { + module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch_1); + }) } } @@ -192,8 +254,8 @@ impl GLWECiphertext { + VmpApplyDftToDftTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd - + DFT - + IDFTConsume + + VecZnxDftApply + + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize, Scratch: TakeVecZnxDft, @@ -224,16 +286,17 @@ where DataRes: DataMut, DataIn: DataRef, DataVmp: DataRef, - Module: VecZnxDftAllocBytes + DFT + VmpApplyDftToDft + IDFTConsume + VecZnxBigAddSmallInplace, + Module: + VecZnxDftAllocBytes + VecZnxDftApply + VmpApplyDftToDft + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace, Scratch: TakeVecZnxDft, { let cols: usize = a.cols(); - let (mut ai_dft, scratch1) = scratch.take_vec_znx_dft(a.n(), cols - 1, a.size()); + let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(a.n(), cols - 1, a.size()); (0..cols - 1).for_each(|col_i| { - module.dft(1, 0, &mut ai_dft, col_i, a, col_i + 1); + module.vec_znx_dft_apply(1, 0, &mut ai_dft, col_i, a, col_i + 1); }); - module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch1); - let mut res_big: VecZnxBig = module.vec_znx_idft_consume(res_dft); + module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch_1); + let mut res_big: VecZnxBig = module.vec_znx_idft_apply_consume(res_dft); module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0); res_big } @@ -251,16 +314,16 @@ where DataIn: DataRef, DataVmp: DataRef, Module: VecZnxDftAllocBytes - + 
DFT + + VecZnxDftApply + VmpApplyDftToDft + VmpApplyDftToDftAdd - + IDFTConsume + + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace, Scratch: TakeVecZnxDft, { let cols: usize = a.cols(); let size: usize = a.size(); - let (mut ai_dft, scratch1) = scratch.take_vec_znx_dft(a.n(), cols - 1, size.div_ceil(digits)); + let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(a.n(), cols - 1, size.div_ceil(digits)); ai_dft.data_mut().fill(0); @@ -277,18 +340,18 @@ where res_dft.set_size(mat.size() - ((digits - di) as isize - 2).max(0) as usize); (0..cols - 1).for_each(|col_i| { - module.dft(digits, digits - di - 1, &mut ai_dft, col_i, a, col_i + 1); + module.vec_znx_dft_apply(digits, digits - di - 1, &mut ai_dft, col_i, a, col_i + 1); }); if di == 0 { - module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch1); + module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch_1); } else { - module.vmp_apply_dft_to_dft_add(&mut res_dft, &ai_dft, mat, di, scratch1); + module.vmp_apply_dft_to_dft_add(&mut res_dft, &ai_dft, mat, di, scratch_1); } }); res_dft.set_size(res_dft.max_size()); - let mut res_big: VecZnxBig = module.vec_znx_idft_consume(res_dft); + let mut res_big: VecZnxBig = module.vec_znx_idft_apply_consume(res_dft); module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0); res_big } diff --git a/poulpy-core/src/keyswitching/lwe_ct.rs b/poulpy-core/src/keyswitching/lwe_ct.rs index 7802e66..7588cb7 100644 --- a/poulpy-core/src/keyswitching/lwe_ct.rs +++ b/poulpy-core/src/keyswitching/lwe_ct.rs @@ -1,7 +1,8 @@ use poulpy_hal::{ api::{ - DFT, IDFTConsume, ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, - VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, + VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VmpApplyDftToDft, VmpApplyDftToDftAdd, + VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut, ZnxZero}, }; @@ -26,8 +27,8 @@ impl LWECiphertext> { + VmpApplyDftToDftTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd - + DFT - + IDFTConsume + + VecZnxDftApply + + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize, { @@ -51,8 +52,8 @@ impl LWECiphertext { + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd - + DFT - + IDFTConsume + + VecZnxDftApply + + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize, Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, @@ -67,7 +68,7 @@ impl LWECiphertext { let max_k: usize = self.k().max(a.k()); let basek: usize = self.basek(); - let (mut glwe, scratch1) = scratch.take_glwe_ct(ksk.n(), basek, max_k, 1); + let (mut glwe, scratch_1) = scratch.take_glwe_ct(ksk.n(), basek, max_k, 1); glwe.data.zero(); let n_lwe: usize = a.n(); @@ -78,7 +79,7 @@ impl LWECiphertext { glwe.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]); }); - glwe.keyswitch_inplace(module, &ksk.0, scratch1); + glwe.keyswitch_inplace(module, &ksk.0, scratch_1); self.sample_extract(&glwe); } diff --git a/poulpy-core/src/layouts/compressed/gglwe_atk.rs b/poulpy-core/src/layouts/compressed/gglwe_atk.rs index e3bee47..8c4c8c9 100644 --- a/poulpy-core/src/layouts/compressed/gglwe_atk.rs +++ b/poulpy-core/src/layouts/compressed/gglwe_atk.rs @@ -24,8 +24,8 @@ impl fmt::Debug for GGLWEAutomorphismKeyCompressed { } impl FillUniform for 
GGLWEAutomorphismKeyCompressed { - fn fill_uniform(&mut self, source: &mut Source) { - self.key.fill_uniform(source); + fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { + self.key.fill_uniform(log_bound, source); } } diff --git a/poulpy-core/src/layouts/compressed/gglwe_ct.rs b/poulpy-core/src/layouts/compressed/gglwe_ct.rs index 8dc710a..0f6db41 100644 --- a/poulpy-core/src/layouts/compressed/gglwe_ct.rs +++ b/poulpy-core/src/layouts/compressed/gglwe_ct.rs @@ -28,8 +28,8 @@ impl fmt::Debug for GGLWECiphertextCompressed { } impl FillUniform for GGLWECiphertextCompressed { - fn fill_uniform(&mut self, source: &mut Source) { - self.data.fill_uniform(source); + fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { + self.data.fill_uniform(log_bound, source); } } diff --git a/poulpy-core/src/layouts/compressed/gglwe_ksk.rs b/poulpy-core/src/layouts/compressed/gglwe_ksk.rs index ab62073..72070b3 100644 --- a/poulpy-core/src/layouts/compressed/gglwe_ksk.rs +++ b/poulpy-core/src/layouts/compressed/gglwe_ksk.rs @@ -25,8 +25,8 @@ impl fmt::Debug for GGLWESwitchingKeyCompressed { } impl FillUniform for GGLWESwitchingKeyCompressed { - fn fill_uniform(&mut self, source: &mut Source) { - self.key.fill_uniform(source); + fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { + self.key.fill_uniform(log_bound, source); } } diff --git a/poulpy-core/src/layouts/compressed/gglwe_tsk.rs b/poulpy-core/src/layouts/compressed/gglwe_tsk.rs index 3eaee4f..08917cd 100644 --- a/poulpy-core/src/layouts/compressed/gglwe_tsk.rs +++ b/poulpy-core/src/layouts/compressed/gglwe_tsk.rs @@ -23,10 +23,10 @@ impl fmt::Debug for GGLWETensorKeyCompressed { } impl FillUniform for GGLWETensorKeyCompressed { - fn fill_uniform(&mut self, source: &mut Source) { + fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { self.keys .iter_mut() - .for_each(|key: &mut GGLWESwitchingKeyCompressed| key.fill_uniform(source)) + .for_each(|key: &mut GGLWESwitchingKeyCompressed| key.fill_uniform(log_bound, source)) } } diff --git a/poulpy-core/src/layouts/compressed/ggsw_ct.rs b/poulpy-core/src/layouts/compressed/ggsw_ct.rs index 446439d..5cb9d2d 100644 --- a/poulpy-core/src/layouts/compressed/ggsw_ct.rs +++ b/poulpy-core/src/layouts/compressed/ggsw_ct.rs @@ -49,8 +49,8 @@ impl Reset for GGSWCiphertextCompressed { } impl FillUniform for GGSWCiphertextCompressed { - fn fill_uniform(&mut self, source: &mut Source) { - self.data.fill_uniform(source); + fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { + self.data.fill_uniform(log_bound, source); } } diff --git a/poulpy-core/src/layouts/compressed/glwe_ct.rs b/poulpy-core/src/layouts/compressed/glwe_ct.rs index 5c0f9f8..8c8eaf9 100644 --- a/poulpy-core/src/layouts/compressed/glwe_ct.rs +++ b/poulpy-core/src/layouts/compressed/glwe_ct.rs @@ -48,8 +48,8 @@ impl Reset for GLWECiphertextCompressed { } impl FillUniform for GLWECiphertextCompressed { - fn fill_uniform(&mut self, source: &mut Source) { - self.data.fill_uniform(source); + fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { + self.data.fill_uniform(log_bound, source); } } @@ -160,7 +160,8 @@ impl GLWECiphertext { { #[cfg(debug_assertions)] { - assert_eq!(self.rank(), other.rank()) + assert_eq!(self.rank(), other.rank()); + debug_assert_eq!(self.size(), other.size()); } let k: usize = other.k; @@ -168,7 +169,7 @@ impl GLWECiphertext { let cols: usize = other.rank() + 1; module.vec_znx_copy(&mut self.data, 0, &other.data, 0); (1..cols).for_each(|i| { - 
module.vec_znx_fill_uniform(basek, &mut self.data, i, k, source); + module.vec_znx_fill_uniform(basek, &mut self.data, i, source); }); self.basek = basek; diff --git a/poulpy-core/src/layouts/compressed/glwe_to_lwe_ksk.rs b/poulpy-core/src/layouts/compressed/glwe_to_lwe_ksk.rs index d9dbdbf..fb7f959 100644 --- a/poulpy-core/src/layouts/compressed/glwe_to_lwe_ksk.rs +++ b/poulpy-core/src/layouts/compressed/glwe_to_lwe_ksk.rs @@ -2,9 +2,9 @@ use std::fmt; use poulpy_hal::{ api::{ - DFT, IDFTConsume, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, - VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, + SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, + VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, + VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, }, layouts::{Backend, Data, DataMut, DataRef, FillUniform, MatZnx, Module, ReaderFrom, Reset, WriterTo}, source::Source, @@ -22,8 +22,8 @@ impl fmt::Debug for GLWEToLWESwitchingKeyCompressed { } impl FillUniform for GLWEToLWESwitchingKeyCompressed { - fn fill_uniform(&mut self, source: &mut Source) { - self.0.fill_uniform(source); + fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { + self.0.fill_uniform(log_bound, source); } } @@ -96,9 +96,9 @@ impl GLWEToLWESwitchingKeyCompressed> { where Module: VecZnxDftAllocBytes + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform + VecZnxSubABInplace diff --git a/poulpy-core/src/layouts/compressed/lwe_ct.rs b/poulpy-core/src/layouts/compressed/lwe_ct.rs index d58a24d..159b107 100644 --- a/poulpy-core/src/layouts/compressed/lwe_ct.rs +++ b/poulpy-core/src/layouts/compressed/lwe_ct.rs @@ -50,8 +50,8 @@ where } impl FillUniform for LWECiphertextCompressed { - fn fill_uniform(&mut self, source: &mut Source) { - self.data.fill_uniform(source); + fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { + self.data.fill_uniform(log_bound, source); } } @@ -124,15 +124,9 @@ where Module: ZnFillUniform, { fn decompress(&mut self, module: &Module, other: &LWECiphertextCompressed) { + debug_assert_eq!(self.size(), other.size()); let mut source: Source = Source::new(other.seed); - module.zn_fill_uniform( - self.n(), - other.basek(), - &mut self.data, - 0, - other.k(), - &mut source, - ); + module.zn_fill_uniform(self.n(), other.basek(), &mut self.data, 0, &mut source); (0..self.size()).for_each(|i| { self.data.at_mut(0, i)[0] = other.data.at(0, i)[0]; }); diff --git a/poulpy-core/src/layouts/compressed/lwe_ksk.rs b/poulpy-core/src/layouts/compressed/lwe_ksk.rs index 0f4d603..23ee722 100644 --- a/poulpy-core/src/layouts/compressed/lwe_ksk.rs +++ b/poulpy-core/src/layouts/compressed/lwe_ksk.rs @@ -1,8 +1,8 @@ use poulpy_hal::{ api::{ - DFT, IDFTConsume, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, - VecZnxBigNormalize, VecZnxCopy, VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, + SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, + VecZnxBigNormalize, VecZnxCopy, 
VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, + VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, }, layouts::{Backend, Data, DataMut, DataRef, FillUniform, MatZnx, Module, ReaderFrom, Reset, WriterTo}, source::Source, @@ -24,8 +24,8 @@ impl fmt::Debug for LWESwitchingKeyCompressed { } impl FillUniform for LWESwitchingKeyCompressed { - fn fill_uniform(&mut self, source: &mut Source) { - self.0.fill_uniform(source); + fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { + self.0.fill_uniform(log_bound, source); } } @@ -98,9 +98,9 @@ impl LWESwitchingKeyCompressed> { where Module: VecZnxDftAllocBytes + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform + VecZnxSubABInplace diff --git a/poulpy-core/src/layouts/compressed/lwe_to_glwe_ksk.rs b/poulpy-core/src/layouts/compressed/lwe_to_glwe_ksk.rs index f930824..e9023c8 100644 --- a/poulpy-core/src/layouts/compressed/lwe_to_glwe_ksk.rs +++ b/poulpy-core/src/layouts/compressed/lwe_to_glwe_ksk.rs @@ -1,8 +1,8 @@ use poulpy_hal::{ api::{ - DFT, IDFTConsume, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, - VecZnxBigNormalize, VecZnxCopy, VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, + SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, + VecZnxBigNormalize, VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, + VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, }, layouts::{Backend, Data, DataMut, DataRef, FillUniform, MatZnx, Module, ReaderFrom, Reset, WriterTo}, source::Source, @@ -24,8 +24,8 @@ impl fmt::Debug for LWEToGLWESwitchingKeyCompressed { } impl FillUniform for LWEToGLWESwitchingKeyCompressed { - fn fill_uniform(&mut self, source: &mut Source) { - self.0.fill_uniform(source); + fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { + self.0.fill_uniform(log_bound, source); } } @@ -98,9 +98,9 @@ impl LWEToGLWESwitchingKeyCompressed> { where Module: VecZnxDftAllocBytes + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform + VecZnxSubABInplace diff --git a/poulpy-core/src/layouts/gglwe_atk.rs b/poulpy-core/src/layouts/gglwe_atk.rs index 2318307..785e854 100644 --- a/poulpy-core/src/layouts/gglwe_atk.rs +++ b/poulpy-core/src/layouts/gglwe_atk.rs @@ -21,8 +21,8 @@ impl fmt::Debug for GGLWEAutomorphismKey { } impl FillUniform for GGLWEAutomorphismKey { - fn fill_uniform(&mut self, source: &mut Source) { - self.key.fill_uniform(source); + fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { + self.key.fill_uniform(log_bound, source); } } diff --git a/poulpy-core/src/layouts/gglwe_ct.rs b/poulpy-core/src/layouts/gglwe_ct.rs index e1e8c53..5e64873 100644 --- a/poulpy-core/src/layouts/gglwe_ct.rs +++ b/poulpy-core/src/layouts/gglwe_ct.rs @@ -23,8 +23,8 @@ impl fmt::Debug for GGLWECiphertext { } impl FillUniform for GGLWECiphertext { - fn fill_uniform(&mut self, source: &mut Source) { - self.data.fill_uniform(source); + fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { + 
self.data.fill_uniform(log_bound, source); } } diff --git a/poulpy-core/src/layouts/gglwe_ksk.rs b/poulpy-core/src/layouts/gglwe_ksk.rs index 611be36..5652904 100644 --- a/poulpy-core/src/layouts/gglwe_ksk.rs +++ b/poulpy-core/src/layouts/gglwe_ksk.rs @@ -32,8 +32,8 @@ impl fmt::Display for GGLWESwitchingKey { } impl FillUniform for GGLWESwitchingKey { - fn fill_uniform(&mut self, source: &mut Source) { - self.key.fill_uniform(source); + fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { + self.key.fill_uniform(log_bound, source); } } diff --git a/poulpy-core/src/layouts/gglwe_tsk.rs b/poulpy-core/src/layouts/gglwe_tsk.rs index f3173b3..47c2993 100644 --- a/poulpy-core/src/layouts/gglwe_tsk.rs +++ b/poulpy-core/src/layouts/gglwe_tsk.rs @@ -20,10 +20,10 @@ impl fmt::Debug for GGLWETensorKey { } impl FillUniform for GGLWETensorKey { - fn fill_uniform(&mut self, source: &mut Source) { + fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { self.keys .iter_mut() - .for_each(|key: &mut GGLWESwitchingKey| key.fill_uniform(source)) + .for_each(|key: &mut GGLWESwitchingKey| key.fill_uniform(log_bound, source)) } } diff --git a/poulpy-core/src/layouts/ggsw_ct.rs b/poulpy-core/src/layouts/ggsw_ct.rs index db85e8d..39e3cfc 100644 --- a/poulpy-core/src/layouts/ggsw_ct.rs +++ b/poulpy-core/src/layouts/ggsw_ct.rs @@ -40,8 +40,8 @@ impl Reset for GGSWCiphertext { } impl FillUniform for GGSWCiphertext { - fn fill_uniform(&mut self, source: &mut Source) { - self.data.fill_uniform(source); + fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { + self.data.fill_uniform(log_bound, source); } } diff --git a/poulpy-core/src/layouts/glwe_ct.rs b/poulpy-core/src/layouts/glwe_ct.rs index ff78f32..a19deb9 100644 --- a/poulpy-core/src/layouts/glwe_ct.rs +++ b/poulpy-core/src/layouts/glwe_ct.rs @@ -55,8 +55,8 @@ where } impl FillUniform for GLWECiphertext { - fn fill_uniform(&mut self, source: &mut Source) { - self.data.fill_uniform(source); + fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { + self.data.fill_uniform(log_bound, source); } } diff --git a/poulpy-core/src/layouts/glwe_to_lwe_ksk.rs b/poulpy-core/src/layouts/glwe_to_lwe_ksk.rs index 9622336..8194a0a 100644 --- a/poulpy-core/src/layouts/glwe_to_lwe_ksk.rs +++ b/poulpy-core/src/layouts/glwe_to_lwe_ksk.rs @@ -18,8 +18,8 @@ impl fmt::Debug for GLWEToLWESwitchingKey { } impl FillUniform for GLWEToLWESwitchingKey { - fn fill_uniform(&mut self, source: &mut Source) { - self.0.fill_uniform(source); + fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { + self.0.fill_uniform(log_bound, source); } } diff --git a/poulpy-core/src/layouts/lwe_ct.rs b/poulpy-core/src/layouts/lwe_ct.rs index 7ceeed6..3b7c48c 100644 --- a/poulpy-core/src/layouts/lwe_ct.rs +++ b/poulpy-core/src/layouts/lwe_ct.rs @@ -54,8 +54,8 @@ impl FillUniform for LWECiphertext where Zn: FillUniform, { - fn fill_uniform(&mut self, source: &mut Source) { - self.data.fill_uniform(source); + fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { + self.data.fill_uniform(log_bound, source); } } diff --git a/poulpy-core/src/layouts/lwe_ksk.rs b/poulpy-core/src/layouts/lwe_ksk.rs index 08852ca..632b43f 100644 --- a/poulpy-core/src/layouts/lwe_ksk.rs +++ b/poulpy-core/src/layouts/lwe_ksk.rs @@ -23,8 +23,8 @@ impl fmt::Debug for LWESwitchingKey { } impl FillUniform for LWESwitchingKey { - fn fill_uniform(&mut self, source: &mut Source) { - self.0.fill_uniform(source); + fn fill_uniform(&mut self, log_bound: usize, 
source: &mut Source) { + self.0.fill_uniform(log_bound, source); } } diff --git a/poulpy-core/src/layouts/lwe_to_glwe_ksk.rs b/poulpy-core/src/layouts/lwe_to_glwe_ksk.rs index 769193a..af27cda 100644 --- a/poulpy-core/src/layouts/lwe_to_glwe_ksk.rs +++ b/poulpy-core/src/layouts/lwe_to_glwe_ksk.rs @@ -17,8 +17,8 @@ impl fmt::Debug for LWEToGLWESwitchingKey { } impl FillUniform for LWEToGLWESwitchingKey { - fn fill_uniform(&mut self, source: &mut Source) { - self.0.fill_uniform(source); + fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { + self.0.fill_uniform(log_bound, source); } } diff --git a/poulpy-core/src/layouts/prepared/glwe_pk.rs b/poulpy-core/src/layouts/prepared/glwe_pk.rs index 774e9ff..9b976c7 100644 --- a/poulpy-core/src/layouts/prepared/glwe_pk.rs +++ b/poulpy-core/src/layouts/prepared/glwe_pk.rs @@ -1,5 +1,5 @@ use poulpy_hal::{ - api::{DFT, VecZnxDftAlloc, VecZnxDftAllocBytes}, + api::{VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply}, layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VecZnxDft}, }; @@ -64,7 +64,7 @@ impl GLWEPublicKeyPrepared, B> { impl PrepareAlloc, B>> for GLWEPublicKey where - Module: VecZnxDftAlloc + DFT, + Module: VecZnxDftAlloc + VecZnxDftApply, { fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> GLWEPublicKeyPrepared, B> { let mut pk_prepared: GLWEPublicKeyPrepared, B> = @@ -76,7 +76,7 @@ where impl Prepare> for GLWEPublicKeyPrepared where - Module: DFT, + Module: VecZnxDftApply, { fn prepare(&mut self, module: &Module, other: &GLWEPublicKey, _scratch: &mut Scratch) { #[cfg(debug_assertions)] @@ -86,7 +86,7 @@ where } (0..self.cols()).for_each(|i| { - module.dft(1, 0, &mut self.data, i, &other.data, i); + module.vec_znx_dft_apply(1, 0, &mut self.data, i, &other.data, i); }); self.k = other.k; self.basek = other.basek; diff --git a/poulpy-core/src/noise/gglwe_ct.rs b/poulpy-core/src/noise/gglwe_ct.rs index bb81ef6..8de8716 100644 --- a/poulpy-core/src/noise/gglwe_ct.rs +++ b/poulpy-core/src/noise/gglwe_ct.rs @@ -1,7 +1,8 @@ use poulpy_hal::{ api::{ - DFT, IDFTConsume, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, - VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxNormalizeTmpBytes, VecZnxSubScalarInplace, + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, + VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, + VecZnxNormalizeTmpBytes, VecZnxSubScalarInplace, }, layouts::{Backend, DataRef, Module, ScalarZnx, ScratchOwned, ZnxZero}, oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl}, @@ -21,9 +22,9 @@ impl GGLWECiphertext { DataWant: DataRef, Module: VecZnxDftAllocBytes + VecZnxBigAllocBytes - + DFT - + SvpApplyInplace - + IDFTConsume + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + VecZnxBigAddInplace + VecZnxBigAddSmallInplace + VecZnxBigNormalize diff --git a/poulpy-core/src/noise/ggsw_ct.rs b/poulpy-core/src/noise/ggsw_ct.rs index 489b025..ab8c2c4 100644 --- a/poulpy-core/src/noise/ggsw_ct.rs +++ b/poulpy-core/src/noise/ggsw_ct.rs @@ -1,8 +1,9 @@ use poulpy_hal::{ api::{ - DFT, IDFTConsume, IDFTTmpA, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, VecZnxAddScalarInplace, - VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigNormalize, - VecZnxBigNormalizeTmpBytes, VecZnxDftAlloc, VecZnxDftAllocBytes, 
VecZnxNormalizeTmpBytes, VecZnxSubABInplace, + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, VecZnxAddScalarInplace, VecZnxBigAddInplace, + VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, + VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, + VecZnxNormalizeTmpBytes, VecZnxSubABInplace, }, layouts::{Backend, DataRef, Module, ScalarZnx, ScratchOwned, VecZnxBig, VecZnxDft, ZnxZero}, oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl}, @@ -22,9 +23,9 @@ impl GGSWCiphertext { DataScalar: DataRef, Module: VecZnxDftAllocBytes + VecZnxBigAllocBytes - + DFT - + SvpApplyInplace - + IDFTConsume + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + VecZnxBigAddInplace + VecZnxBigAddSmallInplace + VecZnxBigNormalize @@ -32,7 +33,7 @@ impl GGSWCiphertext { + VecZnxBigAlloc + VecZnxDftAlloc + VecZnxBigNormalizeTmpBytes - + IDFTTmpA + + VecZnxIdftApplyTmpA + VecZnxAddScalarInplace + VecZnxSubABInplace, B: Backend + TakeVecZnxDftImpl + TakeVecZnxBigImpl + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, @@ -56,9 +57,9 @@ impl GGSWCiphertext { // mul with sk[col_j-1] if col_j > 0 { - module.dft(1, 0, &mut pt_dft, 0, &pt.data, 0); - module.svp_apply_inplace(&mut pt_dft, 0, &sk_prepared.data, col_j - 1); - module.idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); + module.vec_znx_dft_apply(1, 0, &mut pt_dft, 0, &pt.data, 0); + module.svp_apply_dft_to_dft_inplace(&mut pt_dft, 0, &sk_prepared.data, col_j - 1); + module.vec_znx_idft_apply_tmpa(&mut pt_big, 0, &mut pt_dft, 0); module.vec_znx_big_normalize(basek, &mut pt.data, 0, &pt_big, 0, scratch.borrow()); } @@ -89,9 +90,9 @@ impl GGSWCiphertext { DataScalar: DataRef, Module: VecZnxDftAllocBytes + VecZnxBigAllocBytes - + DFT - + SvpApplyInplace - + IDFTConsume + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + VecZnxBigAddInplace + VecZnxBigAddSmallInplace + VecZnxBigNormalize @@ -99,7 +100,7 @@ impl GGSWCiphertext { + VecZnxBigAlloc + VecZnxDftAlloc + VecZnxBigNormalizeTmpBytes - + IDFTTmpA + + VecZnxIdftApplyTmpA + VecZnxAddScalarInplace + VecZnxSubABInplace, B: Backend + TakeVecZnxDftImpl + TakeVecZnxBigImpl + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, @@ -122,9 +123,9 @@ impl GGSWCiphertext { // mul with sk[col_j-1] if col_j > 0 { - module.dft(1, 0, &mut pt_dft, 0, &pt.data, 0); - module.svp_apply_inplace(&mut pt_dft, 0, &sk_prepared.data, col_j - 1); - module.idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); + module.vec_znx_dft_apply(1, 0, &mut pt_dft, 0, &pt.data, 0); + module.svp_apply_dft_to_dft_inplace(&mut pt_dft, 0, &sk_prepared.data, col_j - 1); + module.vec_znx_idft_apply_tmpa(&mut pt_big, 0, &mut pt_dft, 0); module.vec_znx_big_normalize(basek, &mut pt.data, 0, &pt_big, 0, scratch.borrow()); } diff --git a/poulpy-core/src/noise/glwe_ct.rs b/poulpy-core/src/noise/glwe_ct.rs index e4b08c0..bde5b15 100644 --- a/poulpy-core/src/noise/glwe_ct.rs +++ b/poulpy-core/src/noise/glwe_ct.rs @@ -1,8 +1,8 @@ use poulpy_hal::{ api::{ - DFT, IDFTConsume, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, - VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, - VecZnxSubABInplace, + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, + VecZnxBigAllocBytes, VecZnxBigNormalize, 
VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, + VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSubABInplace, }, layouts::{Backend, DataRef, Module, ScratchOwned}, oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl}, @@ -26,9 +26,9 @@ impl GLWECiphertext { DataPt: DataRef, Module: VecZnxDftAllocBytes + VecZnxBigAllocBytes - + DFT - + SvpApplyInplace - + IDFTConsume + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + VecZnxBigAddInplace + VecZnxBigAddSmallInplace + VecZnxBigNormalize diff --git a/poulpy-core/src/operations/glwe.rs b/poulpy-core/src/operations/glwe.rs index f6081c1..e977762 100644 --- a/poulpy-core/src/operations/glwe.rs +++ b/poulpy-core/src/operations/glwe.rs @@ -202,14 +202,14 @@ pub trait GLWEOperations: GLWECiphertextToMut + SetMetaData + Sized { self.set_k(set_k_unary(self, a)) } - fn rotate_inplace(&mut self, module: &Module, k: i64) + fn rotate_inplace(&mut self, module: &Module, k: i64, scratch: &mut Scratch) where - Module: VecZnxRotateInplace, + Module: VecZnxRotateInplace, { let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut(); (0..self_mut.rank() + 1).for_each(|i| { - module.vec_znx_rotate_inplace(k, &mut self_mut.data, i); + module.vec_znx_rotate_inplace(k, &mut self_mut.data, i, scratch); }); } @@ -235,14 +235,14 @@ pub trait GLWEOperations: GLWECiphertextToMut + SetMetaData + Sized { self.set_k(set_k_unary(self, a)) } - fn mul_xp_minus_one_inplace(&mut self, module: &Module, k: i64) + fn mul_xp_minus_one_inplace(&mut self, module: &Module, k: i64, scratch: &mut Scratch) where - Module: VecZnxMulXpMinusOneInplace, + Module: VecZnxMulXpMinusOneInplace, { let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut(); (0..self_mut.rank() + 1).for_each(|i| { - module.vec_znx_mul_xp_minus_one_inplace(k, &mut self_mut.data, i); + module.vec_znx_mul_xp_minus_one_inplace(k, &mut self_mut.data, i, scratch); }); } @@ -268,12 +268,14 @@ pub trait GLWEOperations: GLWECiphertextToMut + SetMetaData + Sized { self.set_basek(a.basek()); } - fn rsh(&mut self, module: &Module, k: usize) + fn rsh(&mut self, module: &Module, k: usize, scratch: &mut Scratch) where - Module: VecZnxRshInplace, + Module: VecZnxRshInplace, { let basek: usize = self.basek(); - module.vec_znx_rsh_inplace(basek, k, &mut self.to_mut().data); + (0..self.cols()).for_each(|i| { + module.vec_znx_rsh_inplace(basek, k, &mut self.to_mut().data, i, scratch); + }) } fn normalize(&mut self, module: &Module, a: &A, scratch: &mut Scratch) diff --git a/poulpy-core/src/tests/generics/automorphism/gglwe_atk.rs b/poulpy-core/src/tests/generics/automorphism/gglwe_atk.rs deleted file mode 100644 index 5e9d245..0000000 --- a/poulpy-core/src/tests/generics/automorphism/gglwe_atk.rs +++ /dev/null @@ -1,356 +0,0 @@ -use poulpy_hal::{ - api::{ - DFT, IDFTConsume, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, - VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace, - VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, - VecZnxCopy, VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, - VecZnxSub, VecZnxSubABInplace, VecZnxSubScalarInplace, VecZnxSwithcDegree, VmpApplyDftToDft, VmpApplyDftToDftAdd, - VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, - }, - layouts::{Backend, Module, ScratchOwned}, - oep::{ - 
ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, - TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, - }, - source::Source, -}; - -use crate::{ - encryption::SIGMA, - layouts::{ - GGLWEAutomorphismKey, GLWEPlaintext, GLWESecret, Infos, - prepared::{GGLWEAutomorphismKeyPrepared, GLWESecretPrepared, Prepare, PrepareAlloc}, - }, - noise::log2_std_noise_gglwe_product, -}; - -#[allow(clippy::too_many_arguments)] -pub fn test_gglwe_automorphism_key_automorphism( - module: &Module, - p0: i64, - p1: i64, - basek: usize, - digits: usize, - k_in: usize, - k_out: usize, - k_apply: usize, - rank: usize, -) where - Module: VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + DFT - + IDFTConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxAutomorphism - + VecZnxAutomorphismInplace - + SvpPPolAllocBytes - + VecZnxDftAllocBytes - + VecZnxNormalizeTmpBytes - + VmpPMatAlloc - + VmpPrepare - + SvpPrepare - + SvpApplyInplace - + VecZnxAddScalarInplace - + VecZnxFillUniform - + VecZnxSubABInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + VecZnxSwithcDegree - + SvpPPolAlloc - + VecZnxBigAddInplace - + VecZnxSubScalarInplace, - B: Backend - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxDftImpl - + TakeVecZnxImpl - + TakeSvpPPolImpl - + TakeVecZnxBigImpl, -{ - let n: usize = module.n(); - let digits_in: usize = 1; - - let rows_in: usize = k_in / (basek * digits); - let rows_apply: usize = k_in.div_ceil(basek * digits); - - let mut auto_key_in: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(n, basek, k_in, rows_in, digits_in, rank); - let mut auto_key_out: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(n, basek, k_out, rows_in, digits_in, rank); - let mut auto_key_apply: GGLWEAutomorphismKey> = - GGLWEAutomorphismKey::alloc(n, basek, k_apply, rows_apply, digits, rank); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, basek, k_apply, rank) - | GGLWEAutomorphismKey::automorphism_scratch_space(module, basek, k_out, k_in, k_apply, digits, rank), - ); - - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); - sk.fill_ternary_prob(0.5, &mut source_xs); - - // gglwe_{s1}(s0) = s0 -> s1 - auto_key_in.encrypt_sk( - module, - p0, - &sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - - // gglwe_{s2}(s1) -> s1 -> s2 - auto_key_apply.encrypt_sk( - module, - p1, - &sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - - let mut auto_key_apply_prepared: GGLWEAutomorphismKeyPrepared, B> = - GGLWEAutomorphismKeyPrepared::alloc(module, basek, k_apply, rows_apply, digits, rank); - - auto_key_apply_prepared.prepare(module, &auto_key_apply, scratch.borrow()); - - // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) - auto_key_out.automorphism( - module, - &auto_key_in, - &auto_key_apply_prepared, - scratch.borrow(), - ); - - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_out); - - let mut sk_auto: GLWESecret> = GLWESecret::alloc(n, rank); - sk_auto.fill_zero(); // Necessary to avoid panic of unfilled sk - (0..rank).for_each(|i| { - module.vec_znx_automorphism( - 
module.galois_element_inv(p0 * p1), - &mut sk_auto.data.as_vec_znx_mut(), - i, - &sk.data.as_vec_znx(), - i, - ); - }); - - let sk_auto_dft: GLWESecretPrepared, B> = sk_auto.prepare_alloc(module, scratch.borrow()); - - (0..auto_key_out.rank_in()).for_each(|col_i| { - (0..auto_key_out.rows()).for_each(|row_i| { - auto_key_out - .at(row_i, col_i) - .decrypt(module, &mut pt, &sk_auto_dft, scratch.borrow()); - - module.vec_znx_sub_scalar_inplace( - &mut pt.data, - 0, - (digits_in - 1) + row_i * digits_in, - &sk.data, - col_i, - ); - - let noise_have: f64 = pt.data.std(basek, 0).log2(); - let noise_want: f64 = log2_std_noise_gglwe_product( - n as f64, - basek * digits, - 0.5, - 0.5, - 0f64, - SIGMA * SIGMA, - 0f64, - rank as f64, - k_out, - k_apply, - ); - - assert!( - noise_have < noise_want + 0.5, - "{} {}", - noise_have, - noise_want - ); - }); - }); -} - -#[allow(clippy::too_many_arguments)] -pub fn test_gglwe_automorphism_key_automorphism_inplace( - module: &Module, - p0: i64, - p1: i64, - basek: usize, - digits: usize, - k_in: usize, - k_apply: usize, - rank: usize, -) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubABInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxDftAllocBytes - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + DFT - + IDFTConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxAutomorphism - + VecZnxSwithcDegree - + VecZnxAddScalarInplace - + VecZnxAutomorphism - + VecZnxAutomorphismInplace - + VecZnxDftAllocBytes - + VecZnxBigAllocBytes - + DFT - + SvpApplyInplace - + IDFTConsume - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalizeTmpBytes - + VecZnxSubScalarInplace - + VecZnxCopy - + VmpPMatAlloc - + VmpPrepare, - B: Backend - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxDftImpl - + TakeVecZnxImpl - + TakeSvpPPolImpl - + TakeVecZnxBigImpl, -{ - let n: usize = module.n(); - let digits_in: usize = 1; - - let rows_in: usize = k_in / (basek * digits); - let rows_apply: usize = k_in.div_ceil(basek * digits); - - let mut auto_key: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(n, basek, k_in, rows_in, digits_in, rank); - let mut auto_key_apply: GGLWEAutomorphismKey> = - GGLWEAutomorphismKey::alloc(n, basek, k_apply, rows_apply, digits, rank); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, basek, k_apply, rank) - | GGLWEAutomorphismKey::automorphism_inplace_scratch_space(module, basek, k_in, k_apply, digits, rank), - ); - - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); - sk.fill_ternary_prob(0.5, &mut source_xs); - - // gglwe_{s1}(s0) = s0 -> s1 - auto_key.encrypt_sk( - module, - p0, - &sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - - // gglwe_{s2}(s1) -> s1 -> s2 - auto_key_apply.encrypt_sk( - module, - p1, - &sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - - let mut auto_key_apply_prepared: GGLWEAutomorphismKeyPrepared, B> = - 
GGLWEAutomorphismKeyPrepared::alloc(module, basek, k_apply, rows_apply, digits, rank); - - auto_key_apply_prepared.prepare(module, &auto_key_apply, scratch.borrow()); - - // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) - auto_key.automorphism_inplace(module, &auto_key_apply_prepared, scratch.borrow()); - - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_in); - - let mut sk_auto: GLWESecret> = GLWESecret::alloc(n, rank); - sk_auto.fill_zero(); // Necessary to avoid panic of unfilled sk - - (0..rank).for_each(|i| { - module.vec_znx_automorphism( - module.galois_element_inv(p0 * p1), - &mut sk_auto.data.as_vec_znx_mut(), - i, - &sk.data.as_vec_znx(), - i, - ); - }); - - let sk_auto_dft: GLWESecretPrepared, B> = sk_auto.prepare_alloc(module, scratch.borrow()); - - (0..auto_key.rank_in()).for_each(|col_i| { - (0..auto_key.rows()).for_each(|row_i| { - auto_key - .at(row_i, col_i) - .decrypt(module, &mut pt, &sk_auto_dft, scratch.borrow()); - module.vec_znx_sub_scalar_inplace( - &mut pt.data, - 0, - (digits_in - 1) + row_i * digits_in, - &sk.data, - col_i, - ); - - let noise_have: f64 = pt.data.std(basek, 0).log2(); - let noise_want: f64 = log2_std_noise_gglwe_product( - n as f64, - basek * digits, - 0.5, - 0.5, - 0f64, - SIGMA * SIGMA, - 0f64, - rank as f64, - k_in, - k_apply, - ); - - assert!( - noise_have < noise_want + 0.5, - "{} {}", - noise_have, - noise_want - ); - }); - }); -} diff --git a/poulpy-core/src/tests/generics/automorphism/ggsw_ct.rs b/poulpy-core/src/tests/generics/automorphism/ggsw_ct.rs deleted file mode 100644 index 9aab7b3..0000000 --- a/poulpy-core/src/tests/generics/automorphism/ggsw_ct.rs +++ /dev/null @@ -1,326 +0,0 @@ -use poulpy_hal::{ - api::{ - DFT, IDFTConsume, IDFTTmpA, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApply, SvpApplyInplace, SvpPPolAlloc, - SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, - VecZnxAutomorphismInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAllocBytes, - VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftAllocBytes, - VecZnxDftCopy, VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, - VecZnxSubABInplace, VecZnxSwithcDegree, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, - VmpPrepare, - }, - layouts::{Backend, Module, ScalarZnx, ScratchOwned}, - oep::{ - ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, - TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, VecZnxBigAllocBytesImpl, VecZnxDftAllocBytesImpl, - }, - source::Source, -}; - -use crate::{ - encryption::SIGMA, - layouts::{ - GGLWEAutomorphismKey, GGLWETensorKey, GGSWCiphertext, GLWESecret, - prepared::{GGLWEAutomorphismKeyPrepared, GGLWETensorKeyPrepared, GLWESecretPrepared, Prepare, PrepareAlloc}, - }, - noise::noise_ggsw_keyswitch, -}; - -#[allow(clippy::too_many_arguments)] -pub fn test_ggsw_automorphism( - p: i64, - module: &Module, - basek: usize, - k_out: usize, - k_in: usize, - k_ksk: usize, - k_tsk: usize, - digits: usize, - rank: usize, -) where - Module: VecZnxDftAllocBytes - + VecZnxBigAllocBytes - + DFT - + SvpApplyInplace - + IDFTConsume - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalizeTmpBytes - + VecZnxBigAlloc - + VecZnxDftAlloc - + VecZnxBigNormalizeTmpBytes - + IDFTTmpA - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc 
- + VecZnxAddScalarInplace - + VecZnxCopy - + VecZnxSubABInplace - + VmpPMatAlloc - + VmpPrepare - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftCopy - + VecZnxDftAddInplace - + VecZnxFillUniform - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpApply - + VecZnxSwithcDegree - + VecZnxAutomorphismInplace - + VecZnxAutomorphism, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl - + VecZnxDftAllocBytesImpl - + VecZnxBigAllocBytesImpl - + TakeSvpPPolImpl, -{ - let n: usize = module.n(); - let rows: usize = k_in.div_ceil(basek * digits); - let rows_in: usize = k_in.div_euclid(basek * digits); - - let digits_in: usize = 1; - - let mut ct_in: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_in, rows_in, digits_in, rank); - let mut ct_out: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_out, rows_in, digits_in, rank); - let mut tensor_key: GGLWETensorKey> = GGLWETensorKey::alloc(n, basek, k_tsk, rows, digits, rank); - let mut auto_key: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(n, basek, k_ksk, rows, digits, rank); - let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_in, rank) - | GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, basek, k_ksk, rank) - | GGLWETensorKey::encrypt_sk_scratch_space(module, basek, k_tsk, rank) - | GGSWCiphertext::automorphism_scratch_space( - module, basek, k_out, k_in, k_ksk, digits, k_tsk, digits, rank, - ), - ); - - let var_xs: f64 = 0.5; - - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); - sk.fill_ternary_prob(var_xs, &mut source_xs); - let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); - - auto_key.encrypt_sk( - module, - p, - &sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - tensor_key.encrypt_sk( - module, - &sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - - pt_scalar.fill_ternary_hw(0, n, &mut source_xs); - - ct_in.encrypt_sk( - module, - &pt_scalar, - &sk_prepared, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - - let mut auto_key_prepared: GGLWEAutomorphismKeyPrepared, B> = - GGLWEAutomorphismKeyPrepared::alloc(module, basek, k_ksk, rows, digits, rank); - auto_key_prepared.prepare(module, &auto_key, scratch.borrow()); - - let mut tsk_prepared: GGLWETensorKeyPrepared, B> = - GGLWETensorKeyPrepared::alloc(module, basek, k_tsk, rows, digits, rank); - tsk_prepared.prepare(module, &tensor_key, scratch.borrow()); - - ct_out.automorphism( - module, - &ct_in, - &auto_key_prepared, - &tsk_prepared, - scratch.borrow(), - ); - - module.vec_znx_automorphism_inplace(p, &mut pt_scalar.as_vec_znx_mut(), 0); - - let max_noise = |col_j: usize| -> f64 { - noise_ggsw_keyswitch( - n as f64, - basek * digits, - col_j, - var_xs, - 0f64, - SIGMA * SIGMA, - 0f64, - rank as f64, - k_in, - k_ksk, - k_tsk, - ) + 0.5 - }; - - ct_out.assert_noise(module, &sk_prepared, &pt_scalar, max_noise); -} - -#[allow(clippy::too_many_arguments)] -pub fn test_ggsw_automorphism_inplace( - p: i64, - module: &Module, - basek: usize, - k_ct: usize, - k_ksk: usize, - k_tsk: usize, - digits: 
usize, - rank: usize, -) where - Module: VecZnxDftAllocBytes - + VecZnxBigAllocBytes - + DFT - + SvpApplyInplace - + IDFTConsume - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalizeTmpBytes - + VecZnxBigAlloc - + VecZnxDftAlloc - + VecZnxBigNormalizeTmpBytes - + IDFTTmpA - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxAddScalarInplace - + VecZnxCopy - + VecZnxSubABInplace - + VmpPMatAlloc - + VmpPrepare - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxBigAddSmallInplace - + VecZnxDftCopy - + VecZnxDftAddInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + VecZnxFillUniform - + SvpApply - + VecZnxSwithcDegree - + VecZnxAutomorphismInplace - + VecZnxAutomorphism, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl - + VecZnxDftAllocBytesImpl - + VecZnxBigAllocBytesImpl - + TakeSvpPPolImpl, -{ - let n: usize = module.n(); - let rows: usize = k_ct.div_ceil(digits * basek); - let rows_in: usize = k_ct.div_euclid(basek * digits); - let digits_in: usize = 1; - - let mut ct: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ct, rows_in, digits_in, rank); - let mut tensor_key: GGLWETensorKey> = GGLWETensorKey::alloc(n, basek, k_tsk, rows, digits, rank); - let mut auto_key: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(n, basek, k_ksk, rows, digits, rank); - let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_ct, rank) - | GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, basek, k_ksk, rank) - | GGLWETensorKey::encrypt_sk_scratch_space(module, basek, k_tsk, rank) - | GGSWCiphertext::automorphism_inplace_scratch_space(module, basek, k_ct, k_ksk, digits, k_tsk, digits, rank), - ); - - let var_xs: f64 = 0.5; - - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); - sk.fill_ternary_prob(var_xs, &mut source_xs); - let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); - - auto_key.encrypt_sk( - module, - p, - &sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - tensor_key.encrypt_sk( - module, - &sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - - pt_scalar.fill_ternary_hw(0, n, &mut source_xs); - - ct.encrypt_sk( - module, - &pt_scalar, - &sk_prepared, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - - let mut auto_key_prepared: GGLWEAutomorphismKeyPrepared, B> = - GGLWEAutomorphismKeyPrepared::alloc(module, basek, k_ksk, rows, digits, rank); - auto_key_prepared.prepare(module, &auto_key, scratch.borrow()); - - let mut tsk_prepared: GGLWETensorKeyPrepared, B> = - GGLWETensorKeyPrepared::alloc(module, basek, k_tsk, rows, digits, rank); - tsk_prepared.prepare(module, &tensor_key, scratch.borrow()); - - ct.automorphism_inplace(module, &auto_key_prepared, &tsk_prepared, scratch.borrow()); - - module.vec_znx_automorphism_inplace(p, &mut pt_scalar.as_vec_znx_mut(), 0); - - let max_noise = |col_j: usize| -> f64 { - noise_ggsw_keyswitch( - n as f64, - basek * digits, - col_j, - var_xs, - 0f64, - SIGMA * SIGMA, - 0f64, - rank as f64, - k_ct, - k_ksk, - k_tsk, - ) + 0.5 
- }; - - ct.assert_noise(module, &sk_prepared, &pt_scalar, max_noise); -} diff --git a/poulpy-core/src/tests/generics/automorphism/glwe_ct.rs b/poulpy-core/src/tests/generics/automorphism/glwe_ct.rs deleted file mode 100644 index 5565947..0000000 --- a/poulpy-core/src/tests/generics/automorphism/glwe_ct.rs +++ /dev/null @@ -1,265 +0,0 @@ -use poulpy_hal::{ - api::{ - DFT, IDFTConsume, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, - VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace, - VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, - VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, - VecZnxSubABInplace, VecZnxSwithcDegree, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, - VmpPrepare, - }, - layouts::{Backend, Module, ScratchOwned}, - oep::{ - ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, - TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, - }, - source::Source, -}; - -use crate::{ - encryption::SIGMA, - layouts::{ - GGLWEAutomorphismKey, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, - prepared::{GGLWEAutomorphismKeyPrepared, GLWESecretPrepared, Prepare, PrepareAlloc}, - }, - noise::log2_std_noise_gglwe_product, -}; - -#[allow(clippy::too_many_arguments)] -pub fn test_glwe_automorphism( - module: &Module, - basek: usize, - p: i64, - k_out: usize, - k_in: usize, - k_ksk: usize, - digits: usize, - rank: usize, -) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubABInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxAutomorphism - + VecZnxSwithcDegree - + VecZnxAddScalarInplace - + VecZnxAutomorphismInplace - + VmpPMatAlloc - + VmpPrepare, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + TakeSvpPPolImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl, -{ - let n: usize = module.n(); - let rows: usize = k_in.div_ceil(basek * digits); - - let mut autokey: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(n, basek, k_ksk, rows, digits, rank); - let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_in, rank); - let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_out, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_in); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, k_in, &mut source_xa); - - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, basek, autokey.k(), rank) - | GLWECiphertext::decrypt_scratch_space(module, basek, ct_out.k()) - | GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct_in.k()) - | GLWECiphertext::automorphism_scratch_space( - module, - basek, - 
ct_out.k(), - ct_in.k(), - autokey.k(), - digits, - rank, - ), - ); - - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); - sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); - - autokey.encrypt_sk( - module, - p, - &sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - - ct_in.encrypt_sk( - module, - &pt_want, - &sk_prepared, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - - let mut autokey_prepared: GGLWEAutomorphismKeyPrepared, B> = - GGLWEAutomorphismKeyPrepared::alloc(module, basek, k_ksk, rows, digits, rank); - autokey_prepared.prepare(module, &autokey, scratch.borrow()); - - ct_out.automorphism(module, &ct_in, &autokey_prepared, scratch.borrow()); - - let max_noise: f64 = log2_std_noise_gglwe_product( - module.n() as f64, - basek * digits, - 0.5, - 0.5, - 0f64, - SIGMA * SIGMA, - 0f64, - rank as f64, - k_in, - k_ksk, - ); - - module.vec_znx_automorphism_inplace(p, &mut pt_want.data, 0); - - ct_out.assert_noise(module, &sk_prepared, &pt_want, max_noise + 1.0); -} - -#[allow(clippy::too_many_arguments)] -pub fn test_glwe_automorphism_inplace( - module: &Module, - basek: usize, - p: i64, - k_ct: usize, - k_ksk: usize, - digits: usize, - rank: usize, -) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubABInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxAutomorphism - + VecZnxSwithcDegree - + VecZnxAddScalarInplace - + VecZnxAutomorphismInplace - + VmpPMatAlloc - + VmpPrepare, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + TakeSvpPPolImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl, -{ - let n: usize = module.n(); - let rows: usize = k_ct.div_ceil(basek * digits); - - let mut autokey: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(n, basek, k_ksk, rows, digits, rank); - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_ct); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, k_ct, &mut source_xa); - - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, basek, autokey.k(), rank) - | GLWECiphertext::decrypt_scratch_space(module, basek, ct.k()) - | GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct.k()) - | GLWECiphertext::automorphism_inplace_scratch_space(module, basek, ct.k(), autokey.k(), digits, rank), - ); - - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); - sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); - - autokey.encrypt_sk( - module, - p, - &sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - - ct.encrypt_sk( - module, - &pt_want, - &sk_prepared, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - - let mut 
autokey_prepared: GGLWEAutomorphismKeyPrepared, B> = - GGLWEAutomorphismKeyPrepared::alloc(module, basek, k_ksk, rows, digits, rank); - autokey_prepared.prepare(module, &autokey, scratch.borrow()); - - ct.automorphism_inplace(module, &autokey_prepared, scratch.borrow()); - - let max_noise: f64 = log2_std_noise_gglwe_product( - module.n() as f64, - basek * digits, - 0.5, - 0.5, - 0f64, - SIGMA * SIGMA, - 0f64, - rank as f64, - k_ct, - k_ksk, - ); - - module.vec_znx_automorphism_inplace(p, &mut pt_want.data, 0); - - ct.assert_noise(module, &sk_prepared, &pt_want, max_noise + 1.0); -} diff --git a/poulpy-core/src/tests/generics/encryption/gglwe_atk.rs b/poulpy-core/src/tests/generics/encryption/gglwe_atk.rs deleted file mode 100644 index 40e9658..0000000 --- a/poulpy-core/src/tests/generics/encryption/gglwe_atk.rs +++ /dev/null @@ -1,202 +0,0 @@ -use poulpy_hal::{ - api::{ - DFT, IDFTConsume, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, - VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace, - VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, - VecZnxCopy, VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, - VecZnxSub, VecZnxSubABInplace, VecZnxSubScalarInplace, VecZnxSwithcDegree, VmpApplyDftToDft, VmpApplyDftToDftAdd, - VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, - }, - layouts::{Backend, Module, ScratchOwned}, - oep::{ - ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, - TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, - }, - source::Source, -}; - -use crate::{ - encryption::SIGMA, - layouts::{ - GGLWEAutomorphismKey, GLWESecret, - compressed::{Decompress, GGLWEAutomorphismKeyCompressed}, - prepared::{GLWESecretPrepared, PrepareAlloc}, - }, -}; - -pub fn test_gglwe_automorphisk_key_encrypt_sk(module: &Module, basek: usize, k_ksk: usize, digits: usize, rank: usize) -where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubABInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxBigAddSmallInplace - + VecZnxAutomorphism - + VecZnxSwithcDegree - + VecZnxAddScalarInplace - + VecZnxAutomorphismInplace - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxSubScalarInplace - + VecZnxCopy - + VmpPMatAlloc - + VmpPrepare, - B: Backend - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxDftImpl - + TakeVecZnxImpl - + TakeSvpPPolImpl - + TakeVecZnxBigImpl, -{ - let n: usize = module.n(); - let rows: usize = (k_ksk - digits * basek) / (digits * basek); - - let mut atk: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(n, basek, k_ksk, rows, digits, rank); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWEAutomorphismKey::encrypt_sk_scratch_space( - module, basek, k_ksk, rank, - )); - - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); - 
sk.fill_ternary_prob(0.5, &mut source_xs); - - let p = -5; - - atk.encrypt_sk( - module, - p, - &sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - - let mut sk_out: GLWESecret> = sk.clone(); - (0..atk.rank()).for_each(|i| { - module.vec_znx_automorphism( - module.galois_element_inv(p), - &mut sk_out.data.as_vec_znx_mut(), - i, - &sk.data.as_vec_znx(), - i, - ); - }); - let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); - - atk.key - .key - .assert_noise(module, &sk_out_prepared, &sk.data, SIGMA); -} - -pub fn test_gglwe_automorphisk_key_compressed_encrypt_sk( - module: &Module, - basek: usize, - k_ksk: usize, - digits: usize, - rank: usize, -) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubABInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxBigAddSmallInplace - + VecZnxAutomorphism - + VecZnxSwithcDegree - + VecZnxAddScalarInplace - + VecZnxAutomorphismInplace - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxSubScalarInplace - + VecZnxCopy - + VmpPMatAlloc - + VmpPrepare, - B: Backend - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxDftImpl - + TakeVecZnxImpl - + TakeSvpPPolImpl - + TakeVecZnxBigImpl, -{ - let n: usize = module.n(); - let rows: usize = (k_ksk - digits * basek) / (digits * basek); - - let mut atk_compressed: GGLWEAutomorphismKeyCompressed> = - GGLWEAutomorphismKeyCompressed::alloc(n, basek, k_ksk, rows, digits, rank); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWEAutomorphismKey::encrypt_sk_scratch_space( - module, basek, k_ksk, rank, - )); - - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); - sk.fill_ternary_prob(0.5, &mut source_xs); - - let p = -5; - - let seed_xa: [u8; 32] = [1u8; 32]; - - atk_compressed.encrypt_sk(module, p, &sk, seed_xa, &mut source_xe, scratch.borrow()); - - let mut sk_out: GLWESecret> = sk.clone(); - (0..atk_compressed.rank()).for_each(|i| { - module.vec_znx_automorphism( - module.galois_element_inv(p), - &mut sk_out.data.as_vec_znx_mut(), - i, - &sk.data.as_vec_znx(), - i, - ); - }); - let sk_out_prepared = sk_out.prepare_alloc(module, scratch.borrow()); - - let mut atk: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(n, basek, k_ksk, rows, digits, rank); - atk.decompress(module, &atk_compressed); - - atk.key - .key - .assert_noise(module, &sk_out_prepared, &sk.data, SIGMA); -} diff --git a/poulpy-core/src/tests/generics/encryption/gglwe_ct.rs b/poulpy-core/src/tests/generics/encryption/gglwe_ct.rs deleted file mode 100644 index d771455..0000000 --- a/poulpy-core/src/tests/generics/encryption/gglwe_ct.rs +++ /dev/null @@ -1,185 +0,0 @@ -use poulpy_hal::{ - api::{ - DFT, IDFTConsume, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, - VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, - VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxCopy, VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize, - VecZnxNormalizeInplace, 
VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSubScalarInplace, - VecZnxSwithcDegree, VmpPMatAlloc, VmpPrepare, - }, - layouts::{Backend, Module, ScratchOwned}, - oep::{ - ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, - TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, VecZnxBigAllocBytesImpl, VecZnxDftAllocBytesImpl, - }, - source::Source, -}; - -use crate::{ - encryption::SIGMA, - layouts::{ - GGLWESwitchingKey, GLWESecret, - compressed::{Decompress, GGLWESwitchingKeyCompressed}, - prepared::{GLWESecretPrepared, PrepareAlloc}, - }, -}; - -pub fn test_gglwe_switching_key_encrypt_sk( - module: &Module, - basek: usize, - k_ksk: usize, - digits: usize, - rank_in: usize, - rank_out: usize, -) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubABInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAddSmallInplace - + VecZnxSwithcDegree - + VecZnxAddScalarInplace - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxSubScalarInplace - + VecZnxCopy - + VmpPMatAlloc - + VmpPrepare, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl - + VecZnxDftAllocBytesImpl - + VecZnxBigAllocBytesImpl - + TakeSvpPPolImpl, -{ - let n: usize = module.n(); - let rows: usize = (k_ksk - digits * basek) / (digits * basek); - - let mut ksk: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(n, basek, k_ksk, rows, digits, rank_in, rank_out); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWESwitchingKey::encrypt_sk_scratch_space( - module, basek, k_ksk, rank_in, rank_out, - )); - - let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank_in); - sk_in.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk_out: GLWESecret> = GLWESecret::alloc(n, rank_out); - sk_out.fill_ternary_prob(0.5, &mut source_xs); - let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); - - ksk.encrypt_sk( - module, - &sk_in, - &sk_out, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - - ksk.key - .assert_noise(module, &sk_out_prepared, &sk_in.data, SIGMA); -} - -pub fn test_gglwe_switching_key_compressed_encrypt_sk( - module: &Module, - basek: usize, - k_ksk: usize, - digits: usize, - rank_in: usize, - rank_out: usize, -) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubABInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAddSmallInplace - + VecZnxSwithcDegree - + VecZnxAddScalarInplace - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxSubScalarInplace - + VecZnxCopy - + VmpPMatAlloc - + VmpPrepare, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl - + VecZnxDftAllocBytesImpl - + VecZnxBigAllocBytesImpl - + 
TakeSvpPPolImpl, -{ - let n: usize = module.n(); - let rows: usize = (k_ksk - digits * basek) / (digits * basek); - - let mut ksk_compressed: GGLWESwitchingKeyCompressed> = - GGLWESwitchingKeyCompressed::alloc(n, basek, k_ksk, rows, digits, rank_in, rank_out); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWESwitchingKeyCompressed::encrypt_sk_scratch_space( - module, basek, k_ksk, rank_in, rank_out, - )); - - let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank_in); - sk_in.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk_out: GLWESecret> = GLWESecret::alloc(n, rank_out); - sk_out.fill_ternary_prob(0.5, &mut source_xs); - let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); - - let seed_xa = [1u8; 32]; - - ksk_compressed.encrypt_sk( - module, - &sk_in, - &sk_out, - seed_xa, - &mut source_xe, - scratch.borrow(), - ); - - let mut ksk: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(n, basek, k_ksk, rows, digits, rank_in, rank_out); - ksk.decompress(module, &ksk_compressed); - - ksk.key - .assert_noise(module, &sk_out_prepared, &sk_in.data, SIGMA); -} diff --git a/poulpy-core/src/tests/generics/encryption/ggsw_ct.rs b/poulpy-core/src/tests/generics/encryption/ggsw_ct.rs deleted file mode 100644 index 2299961..0000000 --- a/poulpy-core/src/tests/generics/encryption/ggsw_ct.rs +++ /dev/null @@ -1,181 +0,0 @@ -use poulpy_hal::{ - api::{ - DFT, IDFTConsume, IDFTTmpA, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, - SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, - VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAlloc, - VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, - VecZnxSubABInplace, VmpPMatAlloc, VmpPrepare, - }, - layouts::{Backend, Module, ScalarZnx, ScratchOwned}, - oep::{ - ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, - TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, VecZnxBigAllocBytesImpl, VecZnxDftAllocBytesImpl, - }, - source::Source, -}; - -use crate::{ - encryption::SIGMA, - layouts::{ - GGSWCiphertext, GLWESecret, - compressed::{Decompress, GGSWCiphertextCompressed}, - prepared::{GLWESecretPrepared, PrepareAlloc}, - }, -}; - -pub fn test_ggsw_encrypt_sk(module: &Module, basek: usize, k: usize, digits: usize, rank: usize) -where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubABInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAddSmallInplace - + VecZnxAddScalarInplace - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxCopy - + VmpPMatAlloc - + VmpPrepare - + VecZnxBigAlloc - + VecZnxDftAlloc - + VecZnxBigNormalizeTmpBytes - + IDFTTmpA, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + TakeSvpPPolImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl - + VecZnxDftAllocBytesImpl - + VecZnxBigAllocBytesImpl - + TakeSvpPPolImpl, -{ - let n: usize = module.n(); - let rows: usize = (k - digits * 
basek) / (digits * basek); - - let mut ct: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k, rows, digits, rank); - - let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - pt_scalar.fill_ternary_hw(0, n, &mut source_xs); - - let mut scratch: ScratchOwned = ScratchOwned::alloc(GGSWCiphertext::encrypt_sk_scratch_space( - module, basek, k, rank, - )); - - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); - sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); - - ct.encrypt_sk( - module, - &pt_scalar, - &sk_prepared, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - - let noise_f = |_col_i: usize| -(k as f64) + SIGMA.log2() + 0.5; - - ct.assert_noise(module, &sk_prepared, &pt_scalar, noise_f); -} - -pub fn test_ggsw_compressed_encrypt_sk(module: &Module, basek: usize, k: usize, digits: usize, rank: usize) -where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubABInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAddSmallInplace - + VecZnxAddScalarInplace - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxCopy - + VmpPMatAlloc - + VmpPrepare - + VecZnxBigAlloc - + VecZnxDftAlloc - + VecZnxBigNormalizeTmpBytes - + IDFTTmpA, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl - + VecZnxDftAllocBytesImpl - + VecZnxBigAllocBytesImpl - + TakeSvpPPolImpl, -{ - let n: usize = module.n(); - let rows: usize = (k - digits * basek) / (digits * basek); - - let mut ct_compressed: GGSWCiphertextCompressed> = GGSWCiphertextCompressed::alloc(n, basek, k, rows, digits, rank); - - let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - - pt_scalar.fill_ternary_hw(0, n, &mut source_xs); - - let mut scratch: ScratchOwned = ScratchOwned::alloc(GGSWCiphertextCompressed::encrypt_sk_scratch_space( - module, basek, k, rank, - )); - - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); - sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); - - let seed_xa: [u8; 32] = [1u8; 32]; - - ct_compressed.encrypt_sk( - module, - &pt_scalar, - &sk_prepared, - seed_xa, - &mut source_xe, - scratch.borrow(), - ); - - let noise_f = |_col_i: usize| -(k as f64) + SIGMA.log2() + 0.5; - - let mut ct: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k, rows, digits, rank); - ct.decompress(module, &ct_compressed); - - ct.assert_noise(module, &sk_prepared, &pt_scalar, noise_f); -} diff --git a/poulpy-core/src/tests/generics/encryption/glwe_ct.rs b/poulpy-core/src/tests/generics/encryption/glwe_ct.rs deleted file mode 100644 index 553888b..0000000 --- a/poulpy-core/src/tests/generics/encryption/glwe_ct.rs +++ /dev/null @@ -1,371 +0,0 @@ -use poulpy_hal::{ - api::{ - DFT, IDFTConsume, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApply, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, - SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, 
VecZnxBigAddInplace, VecZnxBigAddNormal, VecZnxBigAddSmallInplace, - VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxCopy, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxFillUniform, - VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, - }, - layouts::{Backend, Module, ScratchOwned}, - oep::{ - ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, - TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, - }, - source::Source, -}; - -use crate::{ - encryption::SIGMA, - layouts::{ - GLWECiphertext, GLWEPlaintext, GLWEPublicKey, GLWESecret, Infos, - compressed::{Decompress, GLWECiphertextCompressed}, - prepared::{GLWEPublicKeyPrepared, GLWESecretPrepared, PrepareAlloc}, - }, - operations::GLWEOperations, -}; - -pub fn test_glwe_encrypt_sk(module: &Module, basek: usize, k_ct: usize, k_pt: usize, rank: usize) -where - Module: VecZnxDftAllocBytes - + VecZnxBigAllocBytes - + DFT - + SvpApplyInplace - + IDFTConsume - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalizeTmpBytes - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxDftAllocBytes - + VecZnxBigAllocBytes - + SvpPPolAllocBytes - + SvpPrepare - + SvpApply - + IDFTConsume - + VecZnxBigAddNormal - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalizeTmpBytes - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubABInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + TakeSvpPPolImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl, -{ - let n = module.n(); - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_pt); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_pt); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct.k()) - | GLWECiphertext::decrypt_scratch_space(module, basek, ct.k()), - ); - - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); - sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); - - module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, k_pt, &mut source_xa); - - ct.encrypt_sk( - module, - &pt_want, - &sk_prepared, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - - ct.decrypt(module, &mut pt_have, &sk_prepared, scratch.borrow()); - - pt_want.sub_inplace_ab(module, &pt_have); - - let noise_have: f64 = pt_want.data.std(basek, 0) * (ct.k() as f64).exp2(); - let noise_want: f64 = SIGMA; - - assert!(noise_have <= noise_want + 0.2); -} - -pub fn test_glwe_compressed_encrypt_sk(module: &Module, basek: usize, k_ct: usize, k_pt: usize, rank: usize) -where - Module: VecZnxDftAllocBytes - + VecZnxBigAllocBytes - + DFT - + SvpApplyInplace - + IDFTConsume - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalizeTmpBytes - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxDftAllocBytes - + VecZnxBigAllocBytes - + 
SvpPPolAllocBytes - + SvpPrepare - + SvpApply - + IDFTConsume - + VecZnxBigAddNormal - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalizeTmpBytes - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubABInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + VecZnxCopy, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + TakeSvpPPolImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl, -{ - let n = module.n(); - let mut ct_compressed: GLWECiphertextCompressed> = GLWECiphertextCompressed::alloc(n, basek, k_ct, rank); - - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_pt); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_pt); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertextCompressed::encrypt_sk_scratch_space(module, basek, k_ct) - | GLWECiphertext::decrypt_scratch_space(module, basek, k_ct), - ); - - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); - sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); - - module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, k_pt, &mut source_xa); - - let seed_xa: [u8; 32] = [1u8; 32]; - - ct_compressed.encrypt_sk( - module, - &pt_want, - &sk_prepared, - seed_xa, - &mut source_xe, - scratch.borrow(), - ); - - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct, rank); - ct.decompress(module, &ct_compressed); - - ct.decrypt(module, &mut pt_have, &sk_prepared, scratch.borrow()); - - pt_want.sub_inplace_ab(module, &pt_have); - - let noise_have: f64 = pt_want.data.std(basek, 0) * (ct.k() as f64).exp2(); - let noise_want: f64 = SIGMA; - - assert!( - noise_have <= noise_want + 0.2, - "{} <= {}", - noise_have, - noise_want + 0.2 - ); -} - -pub fn test_glwe_encrypt_zero_sk(module: &Module, basek: usize, k_ct: usize, rank: usize) -where - Module: VecZnxDftAllocBytes - + VecZnxBigAllocBytes - + DFT - + SvpApplyInplace - + IDFTConsume - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalizeTmpBytes - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxDftAllocBytes - + VecZnxBigAllocBytes - + SvpPPolAllocBytes - + SvpPrepare - + SvpApply - + IDFTConsume - + VecZnxBigAddNormal - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalizeTmpBytes - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubABInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + TakeSvpPPolImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl, -{ - let n = module.n(); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_ct); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([1u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::alloc( - 
GLWECiphertext::decrypt_scratch_space(module, basek, k_ct) - | GLWECiphertext::encrypt_sk_scratch_space(module, basek, k_ct), - ); - - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); - sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); - - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct, rank); - - ct.encrypt_zero_sk( - module, - &sk_prepared, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - ct.decrypt(module, &mut pt, &sk_prepared, scratch.borrow()); - - assert!((SIGMA - pt.data.std(basek, 0) * (k_ct as f64).exp2()) <= 0.2); -} - -pub fn test_glwe_encrypt_pk(module: &Module, basek: usize, k_ct: usize, k_pk: usize, rank: usize) -where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubABInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAddSmallInplace - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxCopy - + VecZnxDftAlloc - + SvpApply - + VecZnxBigAddNormal, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + TakeSvpPPolImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl, -{ - let n: usize = module.n(); - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct, rank); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_ct); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_ct); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - let mut source_xu: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct.k()) - | GLWECiphertext::decrypt_scratch_space(module, basek, ct.k()) - | GLWECiphertext::encrypt_pk_scratch_space(module, basek, k_pk), - ); - - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); - sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); - - let mut pk: GLWEPublicKey> = GLWEPublicKey::alloc(n, basek, k_pk, rank); - pk.generate_from_sk(module, &sk_prepared, &mut source_xa, &mut source_xe); - - module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, k_ct, &mut source_xa); - - let pk_prepared: GLWEPublicKeyPrepared, B> = pk.prepare_alloc(module, scratch.borrow()); - - ct.encrypt_pk( - module, - &pt_want, - &pk_prepared, - &mut source_xu, - &mut source_xe, - scratch.borrow(), - ); - - ct.decrypt(module, &mut pt_have, &sk_prepared, scratch.borrow()); - - pt_want.sub_inplace_ab(module, &pt_have); - - let noise_have: f64 = pt_want.data.std(basek, 0).log2(); - let noise_want: f64 = ((((rank as f64) + 1.0) * n as f64 * 0.5 * SIGMA * SIGMA).sqrt()).log2() - (k_ct as f64); - - assert!( - noise_have <= noise_want + 0.2, - "{} {}", - noise_have, - noise_want - ); -} diff --git a/poulpy-core/src/tests/generics/encryption/glwe_tsk.rs b/poulpy-core/src/tests/generics/encryption/glwe_tsk.rs deleted file mode 100644 index 206c922..0000000 --- a/poulpy-core/src/tests/generics/encryption/glwe_tsk.rs +++ /dev/null @@ -1,241 +0,0 @@ -use poulpy_hal::{ - api::{ - DFT, IDFTConsume, IDFTTmpA, ScratchOwnedAlloc, ScratchOwnedBorrow, 
SvpApply, SvpApplyInplace, SvpPPolAlloc, - SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, - VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxCopy, VecZnxDftAlloc, - VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, - VecZnxSubABInplace, VecZnxSubScalarInplace, VecZnxSwithcDegree, - }, - layouts::{Backend, Module, ScratchOwned, VecZnxDft}, - oep::{ - ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, - TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, VecZnxBigAllocBytesImpl, VecZnxDftAllocBytesImpl, - }, - source::Source, -}; - -use crate::{ - encryption::SIGMA, - layouts::{ - GGLWETensorKey, GLWEPlaintext, GLWESecret, Infos, - compressed::{Decompress, GGLWETensorKeyCompressed}, - prepared::{GLWESecretPrepared, PrepareAlloc}, - }, -}; - -pub fn test_glwe_tensor_key_encrypt_sk(module: &Module, basek: usize, k: usize, rank: usize) -where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubABInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAddSmallInplace - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxCopy - + VecZnxDftAlloc - + SvpApply - + VecZnxBigAlloc - + IDFTTmpA - + VecZnxAddScalarInplace - + VecZnxSwithcDegree - + VecZnxSubScalarInplace, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl - + VecZnxDftAllocBytesImpl - + VecZnxBigAllocBytesImpl - + TakeSvpPPolImpl, -{ - let n: usize = module.n(); - let rows: usize = k / basek; - - let mut tensor_key: GGLWETensorKey> = GGLWETensorKey::alloc(n, basek, k, rows, 1, rank); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWETensorKey::encrypt_sk_scratch_space( - module, - basek, - tensor_key.k(), - rank, - )); - - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); - sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); - - tensor_key.encrypt_sk( - module, - &sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k); - - let mut sk_ij_dft = module.vec_znx_dft_alloc(1, 1); - let mut sk_ij_big = module.vec_znx_big_alloc(1, 1); - let mut sk_ij: GLWESecret> = GLWESecret::alloc(n, 1); - let mut sk_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(rank, 1); - - (0..rank).for_each(|i| { - module.dft(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i); - }); - - (0..rank).for_each(|i| { - (0..rank).for_each(|j| { - module.svp_apply(&mut sk_ij_dft, 0, &sk_prepared.data, j, &sk_dft, i); - module.idft_tmp_a(&mut sk_ij_big, 0, &mut sk_ij_dft, 0); - module.vec_znx_big_normalize( - basek, - &mut sk_ij.data.as_vec_znx_mut(), - 0, - &sk_ij_big, - 0, - scratch.borrow(), - ); - (0..tensor_key.rank_in()).for_each(|col_i| { - (0..tensor_key.rows()).for_each(|row_i| { - tensor_key - .at(i, j) - .at(row_i, col_i) - .decrypt(module, &mut pt, &sk_prepared, 
scratch.borrow()); - - module.vec_znx_sub_scalar_inplace(&mut pt.data, 0, row_i, &sk_ij.data, col_i); - - let std_pt: f64 = pt.data.std(basek, 0) * (k as f64).exp2(); - assert!((SIGMA - std_pt).abs() <= 0.5, "{} {}", SIGMA, std_pt); - }); - }); - }) - }) -} - -pub fn test_glwe_tensor_key_compressed_encrypt_sk(module: &Module, basek: usize, k: usize, rank: usize) -where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubABInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAddSmallInplace - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxCopy - + VecZnxDftAlloc - + SvpApply - + VecZnxBigAlloc - + IDFTTmpA - + VecZnxAddScalarInplace - + VecZnxSwithcDegree - + VecZnxSubScalarInplace, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl - + VecZnxDftAllocBytesImpl - + VecZnxBigAllocBytesImpl - + TakeSvpPPolImpl, -{ - let n: usize = module.n(); - let rows: usize = k / basek; - - let mut tensor_key_compressed: GGLWETensorKeyCompressed> = - GGLWETensorKeyCompressed::alloc(n, basek, k, rows, 1, rank); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWETensorKeyCompressed::encrypt_sk_scratch_space( - module, - basek, - tensor_key_compressed.k(), - rank, - )); - - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); - sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); - - let seed_xa: [u8; 32] = [1u8; 32]; - - tensor_key_compressed.encrypt_sk(module, &sk, seed_xa, &mut source_xe, scratch.borrow()); - - let mut tensor_key: GGLWETensorKey> = GGLWETensorKey::alloc(n, basek, k, rows, 1, rank); - tensor_key.decompress(module, &tensor_key_compressed); - - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k); - - let mut sk_ij_dft = module.vec_znx_dft_alloc(1, 1); - let mut sk_ij_big = module.vec_znx_big_alloc(1, 1); - let mut sk_ij: GLWESecret> = GLWESecret::alloc(n, 1); - let mut sk_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(rank, 1); - - (0..rank).for_each(|i| { - module.dft(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i); - }); - - (0..rank).for_each(|i| { - (0..rank).for_each(|j| { - module.svp_apply(&mut sk_ij_dft, 0, &sk_prepared.data, j, &sk_dft, i); - module.idft_tmp_a(&mut sk_ij_big, 0, &mut sk_ij_dft, 0); - module.vec_znx_big_normalize( - basek, - &mut sk_ij.data.as_vec_znx_mut(), - 0, - &sk_ij_big, - 0, - scratch.borrow(), - ); - (0..tensor_key.rank_in()).for_each(|col_i| { - (0..tensor_key.rows()).for_each(|row_i| { - tensor_key - .at(i, j) - .at(row_i, col_i) - .decrypt(module, &mut pt, &sk_prepared, scratch.borrow()); - - module.vec_znx_sub_scalar_inplace(&mut pt.data, 0, row_i, &sk_ij.data, col_i); - - let std_pt: f64 = pt.data.std(basek, 0) * (k as f64).exp2(); - assert!((SIGMA - std_pt).abs() <= 0.5, "{} {}", SIGMA, std_pt); - }); - }); - }) - }) -} diff --git a/poulpy-core/src/tests/generics/external_product/gglwe_ksk.rs b/poulpy-core/src/tests/generics/external_product/gglwe_ksk.rs deleted file mode 100644 index 655258e..0000000 --- a/poulpy-core/src/tests/generics/external_product/gglwe_ksk.rs +++ /dev/null @@ -1,304 
+0,0 @@ -use poulpy_hal::{ - api::{ - DFT, IDFTConsume, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, - VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, - VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxCopy, VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize, - VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotateInplace, VecZnxSub, VecZnxSubABInplace, - VecZnxSubScalarInplace, VecZnxSwithcDegree, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, - VmpPMatAlloc, VmpPrepare, - }, - layouts::{Backend, Module, ScalarZnx, ScalarZnxToMut, ScratchOwned, ZnxViewMut}, - oep::{ - ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, - TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, VecZnxBigAllocBytesImpl, VecZnxDftAllocBytesImpl, - }, - source::Source, -}; - -use crate::{ - encryption::SIGMA, - layouts::{ - GGLWESwitchingKey, GGSWCiphertext, GLWESecret, - prepared::{GGSWCiphertextPrepared, GLWESecretPrepared, PrepareAlloc}, - }, - noise::noise_ggsw_product, -}; - -#[allow(clippy::too_many_arguments)] -pub fn test_gglwe_switching_key_external_product( - module: &Module, - basek: usize, - k_out: usize, - k_in: usize, - k_ggsw: usize, - digits: usize, - rank_in: usize, - rank_out: usize, -) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubABInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxSwithcDegree - + VecZnxAddScalarInplace - + VecZnxSubScalarInplace - + VecZnxCopy - + VmpPMatAlloc - + VecZnxRotateInplace - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VmpPrepare, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl - + VecZnxDftAllocBytesImpl - + VecZnxBigAllocBytesImpl - + TakeSvpPPolImpl, -{ - let n: usize = module.n(); - let rows: usize = k_in.div_ceil(basek * digits); - let digits_in: usize = 1; - - let mut ct_gglwe_in: GGLWESwitchingKey> = - GGLWESwitchingKey::alloc(n, basek, k_in, rows, digits_in, rank_in, rank_out); - let mut ct_gglwe_out: GGLWESwitchingKey> = - GGLWESwitchingKey::alloc(n, basek, k_out, rows, digits_in, rank_in, rank_out); - let mut ct_rgsw: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ggsw, rows, digits, rank_out); - - let mut pt_rgsw: ScalarZnx> = ScalarZnx::alloc(n, 1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k_in, rank_in, rank_out) - | GGLWESwitchingKey::external_product_scratch_space(module, basek, k_out, k_in, k_ggsw, digits, rank_out) - | GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_ggsw, rank_out), - ); - - let r: usize = 1; - - pt_rgsw.to_mut().raw_mut()[r] = 1; // X^{r} - - let var_xs: f64 = 0.5; - - let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank_in); - sk_in.fill_ternary_prob(var_xs, &mut source_xs); - - let mut sk_out: GLWESecret> = 
GLWESecret::alloc(n, rank_out); - sk_out.fill_ternary_prob(var_xs, &mut source_xs); - let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); - - // gglwe_{s1}(s0) = s0 -> s1 - ct_gglwe_in.encrypt_sk( - module, - &sk_in, - &sk_out, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - - ct_rgsw.encrypt_sk( - module, - &pt_rgsw, - &sk_out_prepared, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - - let ct_rgsw_prepared: GGSWCiphertextPrepared, B> = ct_rgsw.prepare_alloc(module, scratch.borrow()); - - // gglwe_(m) (x) RGSW_(X^k) = gglwe_(m * X^k) - ct_gglwe_out.external_product(module, &ct_gglwe_in, &ct_rgsw_prepared, scratch.borrow()); - - (0..rank_in).for_each(|i| { - module.vec_znx_rotate_inplace(r as i64, &mut sk_in.data.as_vec_znx_mut(), i); // * X^{r} - }); - - let var_gct_err_lhs: f64 = SIGMA * SIGMA; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / n as f64; // X^{k} - let var_a0_err: f64 = SIGMA * SIGMA; - let var_a1_err: f64 = 1f64 / 12f64; - - let max_noise: f64 = noise_ggsw_product( - n as f64, - basek * digits, - var_xs, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - rank_out as f64, - k_in, - k_ggsw, - ); - - ct_gglwe_out - .key - .assert_noise(module, &sk_out_prepared, &sk_in.data, max_noise + 0.5); -} - -#[allow(clippy::too_many_arguments)] -pub fn test_gglwe_switching_key_external_product_inplace( - module: &Module, - basek: usize, - k_ct: usize, - k_ggsw: usize, - digits: usize, - rank_in: usize, - rank_out: usize, -) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubABInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxSwithcDegree - + VecZnxAddScalarInplace - + VecZnxSubScalarInplace - + VecZnxCopy - + VmpPMatAlloc - + VecZnxRotateInplace - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VmpPrepare, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl - + VecZnxDftAllocBytesImpl - + VecZnxBigAllocBytesImpl - + TakeSvpPPolImpl, -{ - let n: usize = module.n(); - let rows: usize = k_ct.div_ceil(basek * digits); - - let digits_in: usize = 1; - - let mut ct_gglwe: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(n, basek, k_ct, rows, digits_in, rank_in, rank_out); - let mut ct_rgsw: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ggsw, rows, digits, rank_out); - - let mut pt_rgsw: ScalarZnx> = ScalarZnx::alloc(n, 1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k_ct, rank_in, rank_out) - | GGLWESwitchingKey::external_product_inplace_scratch_space(module, basek, k_ct, k_ggsw, digits, rank_out) - | GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_ggsw, rank_out), - ); - - let r: usize = 1; - - pt_rgsw.to_mut().raw_mut()[r] = 1; // X^{r} - - let var_xs: f64 = 0.5; - - let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank_in); - sk_in.fill_ternary_prob(var_xs, &mut 
source_xs); - - let mut sk_out: GLWESecret> = GLWESecret::alloc(n, rank_out); - sk_out.fill_ternary_prob(var_xs, &mut source_xs); - let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); - - // gglwe_{s1}(s0) = s0 -> s1 - ct_gglwe.encrypt_sk( - module, - &sk_in, - &sk_out, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - - ct_rgsw.encrypt_sk( - module, - &pt_rgsw, - &sk_out_prepared, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - - let ct_rgsw_prepared: GGSWCiphertextPrepared, B> = ct_rgsw.prepare_alloc(module, scratch.borrow()); - - // gglwe_(m) (x) RGSW_(X^k) = gglwe_(m * X^k) - ct_gglwe.external_product_inplace(module, &ct_rgsw_prepared, scratch.borrow()); - - (0..rank_in).for_each(|i| { - module.vec_znx_rotate_inplace(r as i64, &mut sk_in.data.as_vec_znx_mut(), i); // * X^{r} - }); - - let var_gct_err_lhs: f64 = SIGMA * SIGMA; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / n as f64; // X^{k} - let var_a0_err: f64 = SIGMA * SIGMA; - let var_a1_err: f64 = 1f64 / 12f64; - - let max_noise: f64 = noise_ggsw_product( - n as f64, - basek * digits, - var_xs, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - rank_out as f64, - k_ct, - k_ggsw, - ); - - ct_gglwe - .key - .assert_noise(module, &sk_out_prepared, &sk_in.data, max_noise + 0.5); -} diff --git a/poulpy-core/src/tests/generics/external_product/ggsw_ct.rs b/poulpy-core/src/tests/generics/external_product/ggsw_ct.rs deleted file mode 100644 index a957aef..0000000 --- a/poulpy-core/src/tests/generics/external_product/ggsw_ct.rs +++ /dev/null @@ -1,290 +0,0 @@ -use poulpy_hal::{ - api::{ - DFT, IDFTConsume, IDFTTmpA, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, - SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, - VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAlloc, - VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, - VecZnxRotateInplace, VecZnxSub, VecZnxSubABInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, - VmpPMatAlloc, VmpPrepare, - }, - layouts::{Backend, Module, ScalarZnx, ScalarZnxToMut, ScratchOwned, ZnxViewMut}, - oep::{ - ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, - TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, VecZnxBigAllocBytesImpl, VecZnxDftAllocBytesImpl, - }, - source::Source, -}; - -use crate::{ - encryption::SIGMA, - layouts::{ - GGSWCiphertext, GLWESecret, - prepared::{GGSWCiphertextPrepared, GLWESecretPrepared, PrepareAlloc}, - }, - noise::noise_ggsw_product, -}; - -#[allow(clippy::too_many_arguments)] -pub fn test_ggsw_external_product( - module: &Module, - basek: usize, - k_in: usize, - k_out: usize, - k_ggsw: usize, - digits: usize, - rank: usize, -) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubABInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxAddScalarInplace - + VecZnxCopy - + VmpPMatAlloc - + VecZnxRotateInplace - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + 
VmpApplyDftToDftAdd - + VmpPrepare - + VecZnxBigAlloc - + VecZnxDftAlloc - + VecZnxBigNormalizeTmpBytes - + IDFTTmpA, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl - + VecZnxDftAllocBytesImpl - + VecZnxBigAllocBytesImpl - + TakeSvpPPolImpl, -{ - let n: usize = module.n(); - let rows: usize = k_in.div_ceil(basek * digits); - let rows_in: usize = k_in.div_euclid(basek * digits); - let digits_in: usize = 1; - - let mut ct_ggsw_lhs_in: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_in, rows_in, digits_in, rank); - let mut ct_ggsw_lhs_out: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_out, rows_in, digits_in, rank); - let mut ct_ggsw_rhs: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ggsw, rows, digits, rank); - let mut pt_ggsw_lhs: ScalarZnx> = ScalarZnx::alloc(n, 1); - let mut pt_ggsw_rhs: ScalarZnx> = ScalarZnx::alloc(n, 1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - pt_ggsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs); - - let k: usize = 1; - - pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} - - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_ggsw, rank) - | GGSWCiphertext::external_product_scratch_space(module, basek, k_out, k_in, k_ggsw, digits, rank), - ); - - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); - sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); - - ct_ggsw_rhs.encrypt_sk( - module, - &pt_ggsw_rhs, - &sk_prepared, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - - ct_ggsw_lhs_in.encrypt_sk( - module, - &pt_ggsw_lhs, - &sk_prepared, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - - let ct_rhs_prepared: GGSWCiphertextPrepared, B> = ct_ggsw_rhs.prepare_alloc(module, scratch.borrow()); - - ct_ggsw_lhs_out.external_product(module, &ct_ggsw_lhs_in, &ct_rhs_prepared, scratch.borrow()); - - module.vec_znx_rotate_inplace(k as i64, &mut pt_ggsw_lhs.as_vec_znx_mut(), 0); - - let var_gct_err_lhs: f64 = SIGMA * SIGMA; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / n as f64; // X^{k} - let var_a0_err: f64 = SIGMA * SIGMA; - let var_a1_err: f64 = 1f64 / 12f64; - - let max_noise = |_col_j: usize| -> f64 { - noise_ggsw_product( - n as f64, - basek * digits, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - rank as f64, - k_in, - k_ggsw, - ) + 0.5 - }; - - ct_ggsw_lhs_out.assert_noise(module, &sk_prepared, &pt_ggsw_lhs, max_noise); -} - -#[allow(clippy::too_many_arguments)] -pub fn test_ggsw_external_product_inplace( - module: &Module, - basek: usize, - k_ct: usize, - k_ggsw: usize, - digits: usize, - rank: usize, -) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubABInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxAddScalarInplace - + VecZnxCopy - + VmpPMatAlloc - + VecZnxRotateInplace - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VmpPrepare - + 
VecZnxBigAlloc - + VecZnxDftAlloc - + VecZnxBigNormalizeTmpBytes - + IDFTTmpA, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl - + VecZnxDftAllocBytesImpl - + VecZnxBigAllocBytesImpl - + TakeSvpPPolImpl, -{ - let n: usize = module.n(); - let rows: usize = k_ct.div_ceil(digits * basek); - let rows_in: usize = k_ct.div_euclid(basek * digits); - let digits_in: usize = 1; - - let mut ct_ggsw_lhs: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ct, rows_in, digits_in, rank); - let mut ct_ggsw_rhs: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ggsw, rows, digits, rank); - - let mut pt_ggsw_lhs: ScalarZnx> = ScalarZnx::alloc(n, 1); - let mut pt_ggsw_rhs: ScalarZnx> = ScalarZnx::alloc(n, 1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - pt_ggsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs); - - let k: usize = 1; - - pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} - - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_ggsw, rank) - | GGSWCiphertext::external_product_inplace_scratch_space(module, basek, k_ct, k_ggsw, digits, rank), - ); - - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); - sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); - - ct_ggsw_rhs.encrypt_sk( - module, - &pt_ggsw_rhs, - &sk_prepared, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - - ct_ggsw_lhs.encrypt_sk( - module, - &pt_ggsw_lhs, - &sk_prepared, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - - let ct_rhs_prepared: GGSWCiphertextPrepared, B> = ct_ggsw_rhs.prepare_alloc(module, scratch.borrow()); - - ct_ggsw_lhs.external_product_inplace(module, &ct_rhs_prepared, scratch.borrow()); - - module.vec_znx_rotate_inplace(k as i64, &mut pt_ggsw_lhs.as_vec_znx_mut(), 0); - - let var_gct_err_lhs: f64 = SIGMA * SIGMA; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / n as f64; // X^{k} - let var_a0_err: f64 = SIGMA * SIGMA; - let var_a1_err: f64 = 1f64 / 12f64; - - let max_noise = |_col_j: usize| -> f64 { - noise_ggsw_product( - n as f64, - basek * digits, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - rank as f64, - k_ct, - k_ggsw, - ) + 0.5 - }; - - ct_ggsw_lhs.assert_noise(module, &sk_prepared, &pt_ggsw_lhs, max_noise); -} diff --git a/poulpy-core/src/tests/generics/external_product/glwe_ct.rs b/poulpy-core/src/tests/generics/external_product/glwe_ct.rs deleted file mode 100644 index cdfcbf4..0000000 --- a/poulpy-core/src/tests/generics/external_product/glwe_ct.rs +++ /dev/null @@ -1,282 +0,0 @@ -use poulpy_hal::{ - api::{ - DFT, IDFTConsume, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, - VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, - VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxRotateInplace, VecZnxSub, VecZnxSubABInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd, - VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, - }, - layouts::{Backend, Module, ScalarZnx, ScratchOwned, ZnxViewMut}, - oep::{ - ScratchAvailableImpl, 
ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, - TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, - }, - source::Source, -}; - -use crate::{ - encryption::SIGMA, - layouts::{ - GGSWCiphertext, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, - prepared::{GGSWCiphertextPrepared, GLWESecretPrepared, PrepareAlloc}, - }, - noise::noise_ggsw_product, -}; - -#[allow(clippy::too_many_arguments)] -pub fn test_glwe_external_product( - module: &Module, - basek: usize, - k_out: usize, - k_in: usize, - k_ggsw: usize, - digits: usize, - rank: usize, -) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume - + VecZnxFillUniform - + VecZnxSubABInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxNormalizeTmpBytes - + VecZnxAddScalarInplace - + VecZnxRotateInplace - + VmpPMatAlloc - + VmpPrepare - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + TakeSvpPPolImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl, -{ - let n: usize = module.n(); - let rows: usize = k_in.div_ceil(basek * digits); - - let mut ct_ggsw: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ggsw, rows, digits, rank); - let mut ct_glwe_in: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_in, rank); - let mut ct_glwe_out: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_out, rank); - let mut pt_rgsw: ScalarZnx> = ScalarZnx::alloc(n, 1); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_in); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, k_in, &mut source_xa); - - pt_want.data.at_mut(0, 0)[1] = 1; - - let k: usize = 1; - - pt_rgsw.raw_mut()[k] = 1; // X^{k} - - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, basek, ct_ggsw.k(), rank) - | GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct_glwe_in.k()) - | GLWECiphertext::external_product_scratch_space( - module, - basek, - ct_glwe_out.k(), - ct_glwe_in.k(), - ct_ggsw.k(), - digits, - rank, - ), - ); - - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); - sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); - - ct_ggsw.encrypt_sk( - module, - &pt_rgsw, - &sk_prepared, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - - ct_glwe_in.encrypt_sk( - module, - &pt_want, - &sk_prepared, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - - let ct_ggsw_prepared: GGSWCiphertextPrepared, B> = ct_ggsw.prepare_alloc(module, scratch.borrow()); - - ct_glwe_out.external_product(module, &ct_glwe_in, &ct_ggsw_prepared, scratch.borrow()); - - module.vec_znx_rotate_inplace(k as i64, &mut pt_want.data, 0); - - let var_gct_err_lhs: f64 = SIGMA * SIGMA; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / n as f64; // X^{k} - let var_a0_err: f64 = SIGMA * SIGMA; - let var_a1_err: f64 = 1f64 / 12f64; - - let max_noise: f64 = noise_ggsw_product( - n as f64, - basek * 
digits, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - rank as f64, - k_in, - k_ggsw, - ); - - ct_glwe_out.assert_noise(module, &sk_prepared, &pt_want, max_noise + 0.5); -} - -#[allow(clippy::too_many_arguments)] -pub fn test_glwe_external_product_inplace( - module: &Module, - basek: usize, - k_ct: usize, - k_ggsw: usize, - digits: usize, - rank: usize, -) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume - + VecZnxFillUniform - + VecZnxSubABInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxNormalizeTmpBytes - + VecZnxAddScalarInplace - + VecZnxRotateInplace - + VmpPMatAlloc - + VmpPrepare - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + TakeSvpPPolImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl, -{ - let n: usize = module.n(); - let rows: usize = k_ct.div_ceil(basek * digits); - - let mut ct_ggsw: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ggsw, rows, digits, rank); - let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct, rank); - let mut pt_rgsw: ScalarZnx> = ScalarZnx::alloc(n, 1); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_ct); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - // Random input plaintext - module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, k_ct, &mut source_xa); - - pt_want.data.at_mut(0, 0)[1] = 1; - - let k: usize = 1; - - pt_rgsw.raw_mut()[k] = 1; // X^{k} - - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, basek, ct_ggsw.k(), rank) - | GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct_glwe.k()) - | GLWECiphertext::external_product_inplace_scratch_space(module, basek, ct_glwe.k(), ct_ggsw.k(), digits, rank), - ); - - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); - sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); - - ct_ggsw.encrypt_sk( - module, - &pt_rgsw, - &sk_prepared, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - - ct_glwe.encrypt_sk( - module, - &pt_want, - &sk_prepared, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - - let ct_ggsw_prepared: GGSWCiphertextPrepared, B> = ct_ggsw.prepare_alloc(module, scratch.borrow()); - - ct_glwe.external_product_inplace(module, &ct_ggsw_prepared, scratch.borrow()); - - module.vec_znx_rotate_inplace(k as i64, &mut pt_want.data, 0); - - let var_gct_err_lhs: f64 = SIGMA * SIGMA; - let var_gct_err_rhs: f64 = 0f64; - - let var_msg: f64 = 1f64 / n as f64; // X^{k} - let var_a0_err: f64 = SIGMA * SIGMA; - let var_a1_err: f64 = 1f64 / 12f64; - - let max_noise: f64 = noise_ggsw_product( - n as f64, - basek * digits, - 0.5, - var_msg, - var_a0_err, - var_a1_err, - var_gct_err_lhs, - var_gct_err_rhs, - rank as f64, - k_ct, - k_ggsw, - ); - - ct_glwe.assert_noise(module, &sk_prepared, &pt_want, max_noise + 0.5); -} diff --git a/poulpy-core/src/tests/generics/keyswitch/gglwe_ct.rs b/poulpy-core/src/tests/generics/keyswitch/gglwe_ct.rs 
deleted file mode 100644 index 57cd36f..0000000 --- a/poulpy-core/src/tests/generics/keyswitch/gglwe_ct.rs +++ /dev/null @@ -1,303 +0,0 @@ -use poulpy_hal::{ - api::{ - DFT, IDFTConsume, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, - VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, - VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VecZnxFillUniform, - VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSubScalarInplace, - VecZnxSwithcDegree, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, - }, - layouts::{Backend, Module, ScratchOwned}, - oep::{ - ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, - TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, - }, - source::Source, -}; - -use crate::{ - encryption::SIGMA, - layouts::{ - GGLWESwitchingKey, GLWESecret, - prepared::{GGLWESwitchingKeyPrepared, GLWESecretPrepared, PrepareAlloc}, - }, - noise::log2_std_noise_gglwe_product, -}; - -#[allow(clippy::too_many_arguments)] -pub fn test_gglwe_switching_key_keyswitch( - module: &Module, - basek: usize, - k_out: usize, - k_in: usize, - k_ksk: usize, - digits: usize, - rank_in_s0s1: usize, - rank_out_s0s1: usize, - rank_out_s1s2: usize, -) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume - + VecZnxFillUniform - + VecZnxSubABInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxNormalizeTmpBytes - + VecZnxAddScalarInplace - + VmpPMatAlloc - + VmpPrepare - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxBigNormalizeTmpBytes - + VecZnxSwithcDegree - + VecZnxSubScalarInplace, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + TakeSvpPPolImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl, -{ - let n: usize = module.n(); - let rows: usize = k_in.div_ceil(basek * digits); - let digits_in: usize = 1; - - let mut ct_gglwe_s0s1: GGLWESwitchingKey> = - GGLWESwitchingKey::alloc(n, basek, k_in, rows, digits_in, rank_in_s0s1, rank_out_s0s1); - let mut ct_gglwe_s1s2: GGLWESwitchingKey> = - GGLWESwitchingKey::alloc(n, basek, k_ksk, rows, digits, rank_out_s0s1, rank_out_s1s2); - let mut ct_gglwe_s0s2: GGLWESwitchingKey> = GGLWESwitchingKey::alloc( - n, - basek, - k_out, - rows, - digits_in, - rank_in_s0s1, - rank_out_s1s2, - ); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch_enc: ScratchOwned = ScratchOwned::alloc(GGLWESwitchingKey::encrypt_sk_scratch_space( - module, - basek, - k_ksk, - rank_in_s0s1 | rank_out_s0s1, - rank_out_s0s1 | rank_out_s1s2, - )); - let mut scratch_apply: ScratchOwned = ScratchOwned::alloc(GGLWESwitchingKey::keyswitch_scratch_space( - module, - basek, - k_out, - k_in, - k_ksk, - digits, - ct_gglwe_s1s2.rank_in(), - ct_gglwe_s1s2.rank_out(), - )); - - let mut sk0: GLWESecret> = GLWESecret::alloc(n, rank_in_s0s1); - sk0.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk1: GLWESecret> = 
GLWESecret::alloc(n, rank_out_s0s1); - sk1.fill_ternary_prob(0.5, &mut source_xs); - - let mut sk2: GLWESecret> = GLWESecret::alloc(n, rank_out_s1s2); - sk2.fill_ternary_prob(0.5, &mut source_xs); - let sk2_prepared: GLWESecretPrepared, B> = sk2.prepare_alloc(module, scratch_apply.borrow()); - - // gglwe_{s1}(s0) = s0 -> s1 - ct_gglwe_s0s1.encrypt_sk( - module, - &sk0, - &sk1, - &mut source_xa, - &mut source_xe, - scratch_enc.borrow(), - ); - - // gglwe_{s2}(s1) -> s1 -> s2 - ct_gglwe_s1s2.encrypt_sk( - module, - &sk1, - &sk2, - &mut source_xa, - &mut source_xe, - scratch_enc.borrow(), - ); - - let ct_gglwe_s1s2_prepared: GGLWESwitchingKeyPrepared, B> = - ct_gglwe_s1s2.prepare_alloc(module, scratch_apply.borrow()); - - // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) - ct_gglwe_s0s2.keyswitch( - module, - &ct_gglwe_s0s1, - &ct_gglwe_s1s2_prepared, - scratch_apply.borrow(), - ); - - let max_noise: f64 = log2_std_noise_gglwe_product( - n as f64, - basek * digits, - 0.5, - 0.5, - 0f64, - SIGMA * SIGMA, - 0f64, - rank_out_s0s1 as f64, - k_in, - k_ksk, - ); - - ct_gglwe_s0s2 - .key - .assert_noise(module, &sk2_prepared, &sk0.data, max_noise + 0.5); -} - -#[allow(clippy::too_many_arguments)] -pub fn test_gglwe_switching_key_keyswitch_inplace( - module: &Module, - basek: usize, - k_ct: usize, - k_ksk: usize, - digits: usize, - rank_in: usize, - rank_out: usize, -) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume - + VecZnxFillUniform - + VecZnxSubABInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxNormalizeTmpBytes - + VecZnxAddScalarInplace - + VmpPMatAlloc - + VmpPrepare - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxBigNormalizeTmpBytes - + VecZnxSwithcDegree - + VecZnxSubScalarInplace, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + TakeSvpPPolImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl, -{ - let n: usize = module.n(); - let rows: usize = k_ct.div_ceil(basek * digits); - let digits_in: usize = 1; - - let mut ct_gglwe_s0s1: GGLWESwitchingKey> = - GGLWESwitchingKey::alloc(n, basek, k_ct, rows, digits_in, rank_in, rank_out); - let mut ct_gglwe_s1s2: GGLWESwitchingKey> = - GGLWESwitchingKey::alloc(n, basek, k_ksk, rows, digits, rank_out, rank_out); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch_enc: ScratchOwned = ScratchOwned::alloc(GGLWESwitchingKey::encrypt_sk_scratch_space( - module, - basek, - k_ksk, - rank_in | rank_out, - rank_out, - )); - let mut scratch_apply: ScratchOwned = ScratchOwned::alloc(GGLWESwitchingKey::keyswitch_inplace_scratch_space( - module, basek, k_ct, k_ksk, digits, rank_out, - )); - - let var_xs: f64 = 0.5; - - let mut sk0: GLWESecret> = GLWESecret::alloc(n, rank_in); - sk0.fill_ternary_prob(var_xs, &mut source_xs); - - let mut sk1: GLWESecret> = GLWESecret::alloc(n, rank_out); - sk1.fill_ternary_prob(var_xs, &mut source_xs); - - let mut sk2: GLWESecret> = GLWESecret::alloc(n, rank_out); - sk2.fill_ternary_prob(var_xs, &mut source_xs); - let sk2_prepared: GLWESecretPrepared, B> = sk2.prepare_alloc(module, scratch_apply.borrow()); - - // gglwe_{s1}(s0) = 
s0 -> s1 - ct_gglwe_s0s1.encrypt_sk( - module, - &sk0, - &sk1, - &mut source_xa, - &mut source_xe, - scratch_enc.borrow(), - ); - - // gglwe_{s2}(s1) -> s1 -> s2 - ct_gglwe_s1s2.encrypt_sk( - module, - &sk1, - &sk2, - &mut source_xa, - &mut source_xe, - scratch_enc.borrow(), - ); - - let ct_gglwe_s1s2_prepared: GGLWESwitchingKeyPrepared, B> = - ct_gglwe_s1s2.prepare_alloc(module, scratch_apply.borrow()); - - // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) - ct_gglwe_s0s1.keyswitch_inplace(module, &ct_gglwe_s1s2_prepared, scratch_apply.borrow()); - - let ct_gglwe_s0s2: GGLWESwitchingKey> = ct_gglwe_s0s1; - - let max_noise: f64 = log2_std_noise_gglwe_product( - n as f64, - basek * digits, - var_xs, - var_xs, - 0f64, - SIGMA * SIGMA, - 0f64, - rank_out as f64, - k_ct, - k_ksk, - ); - - ct_gglwe_s0s2 - .key - .assert_noise(module, &sk2_prepared, &sk0.data, max_noise + 0.5); -} diff --git a/poulpy-core/src/tests/generics/keyswitch/ggsw_ct.rs b/poulpy-core/src/tests/generics/keyswitch/ggsw_ct.rs deleted file mode 100644 index c6d700e..0000000 --- a/poulpy-core/src/tests/generics/keyswitch/ggsw_ct.rs +++ /dev/null @@ -1,305 +0,0 @@ -use poulpy_hal::{ - api::{ - DFT, IDFTConsume, IDFTTmpA, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApply, SvpApplyInplace, SvpPPolAlloc, - SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, - VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, - VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftCopy, VecZnxFillUniform, VecZnxNormalize, - VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwithcDegree, VmpApplyDftToDft, - VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, - }, - layouts::{Backend, Module, ScalarZnx, ScratchOwned}, - oep::{ - ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, - TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, - }, - source::Source, -}; - -use crate::{ - encryption::SIGMA, - layouts::{ - GGLWESwitchingKey, GGLWETensorKey, GGSWCiphertext, GLWESecret, - prepared::{GGLWESwitchingKeyPrepared, GGLWETensorKeyPrepared, GLWESecretPrepared, PrepareAlloc}, - }, - noise::noise_ggsw_keyswitch, -}; - -#[allow(clippy::too_many_arguments)] -pub fn test_ggsw_keyswitch( - module: &Module, - basek: usize, - k_out: usize, - k_in: usize, - k_ksk: usize, - k_tsk: usize, - digits: usize, - rank: usize, -) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume - + VecZnxFillUniform - + VecZnxSubABInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxNormalizeTmpBytes - + VecZnxAddScalarInplace - + VmpPMatAlloc - + VmpPrepare - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxBigNormalizeTmpBytes - + VecZnxSwithcDegree - + SvpApply - + IDFTTmpA - + VecZnxDftCopy - + VecZnxDftAddInplace - + VecZnxBigAlloc - + VecZnxDftAlloc, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + TakeSvpPPolImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl, -{ - let n: usize = module.n(); - let rows: usize = k_in.div_ceil(digits * basek); - - let digits_in: usize = 1; - - let 
mut ct_in: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_in, rows, digits_in, rank); - let mut ct_out: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_out, rows, digits_in, rank); - let mut tsk: GGLWETensorKey> = GGLWETensorKey::alloc(n, basek, k_ksk, rows, digits, rank); - let mut ksk: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(n, basek, k_ksk, rows, digits, rank, rank); - let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_in, rank) - | GGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k_ksk, rank, rank) - | GGLWETensorKey::encrypt_sk_scratch_space(module, basek, k_tsk, rank) - | GGSWCiphertext::keyswitch_scratch_space( - module, basek, k_out, k_in, k_ksk, digits, k_tsk, digits, rank, - ), - ); - - let var_xs: f64 = 0.5; - - let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank); - sk_in.fill_ternary_prob(var_xs, &mut source_xs); - let sk_in_dft: GLWESecretPrepared, B> = sk_in.prepare_alloc(module, scratch.borrow()); - - let mut sk_out: GLWESecret> = GLWESecret::alloc(n, rank); - sk_out.fill_ternary_prob(var_xs, &mut source_xs); - let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); - - ksk.encrypt_sk( - module, - &sk_in, - &sk_out, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - tsk.encrypt_sk( - module, - &sk_out, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - - pt_scalar.fill_ternary_hw(0, n, &mut source_xs); - - ct_in.encrypt_sk( - module, - &pt_scalar, - &sk_in_dft, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - - let ksk_prepared: GGLWESwitchingKeyPrepared, B> = ksk.prepare_alloc(module, scratch.borrow()); - let tsk_prepared: GGLWETensorKeyPrepared, B> = tsk.prepare_alloc(module, scratch.borrow()); - - ct_out.keyswitch( - module, - &ct_in, - &ksk_prepared, - &tsk_prepared, - scratch.borrow(), - ); - - let max_noise = |col_j: usize| -> f64 { - noise_ggsw_keyswitch( - n as f64, - basek * digits, - col_j, - var_xs, - 0f64, - SIGMA * SIGMA, - 0f64, - rank as f64, - k_in, - k_ksk, - k_tsk, - ) + 0.5 - }; - - ct_out.assert_noise(module, &sk_out_prepared, &pt_scalar, max_noise); -} - -#[allow(clippy::too_many_arguments)] -pub fn test_ggsw_keyswitch_inplace( - module: &Module, - basek: usize, - k_ct: usize, - k_ksk: usize, - k_tsk: usize, - digits: usize, - rank: usize, -) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume - + VecZnxFillUniform - + VecZnxSubABInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxNormalizeTmpBytes - + VecZnxAddScalarInplace - + VmpPMatAlloc - + VmpPrepare - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxBigNormalizeTmpBytes - + VecZnxSwithcDegree - + SvpApply - + IDFTTmpA - + VecZnxDftCopy - + VecZnxDftAddInplace - + VecZnxBigAlloc - + VecZnxDftAlloc, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + TakeSvpPPolImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl, -{ - let n: usize = module.n(); - let rows: 
usize = k_ct.div_ceil(digits * basek); - - let digits_in: usize = 1; - - let mut ct: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ct, rows, digits_in, rank); - let mut tsk: GGLWETensorKey> = GGLWETensorKey::alloc(n, basek, k_tsk, rows, digits, rank); - let mut ksk: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(n, basek, k_ksk, rows, digits, rank, rank); - let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_ct, rank) - | GGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k_ksk, rank, rank) - | GGLWETensorKey::encrypt_sk_scratch_space(module, basek, k_tsk, rank) - | GGSWCiphertext::keyswitch_inplace_scratch_space(module, basek, k_ct, k_ksk, digits, k_tsk, digits, rank), - ); - - let var_xs: f64 = 0.5; - - let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank); - sk_in.fill_ternary_prob(var_xs, &mut source_xs); - let sk_in_dft: GLWESecretPrepared, B> = sk_in.prepare_alloc(module, scratch.borrow()); - - let mut sk_out: GLWESecret> = GLWESecret::alloc(n, rank); - sk_out.fill_ternary_prob(var_xs, &mut source_xs); - let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); - - ksk.encrypt_sk( - module, - &sk_in, - &sk_out, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - tsk.encrypt_sk( - module, - &sk_out, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - - pt_scalar.fill_ternary_hw(0, n, &mut source_xs); - - ct.encrypt_sk( - module, - &pt_scalar, - &sk_in_dft, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - - let ksk_prepared: GGLWESwitchingKeyPrepared, B> = ksk.prepare_alloc(module, scratch.borrow()); - let tsk_prepared: GGLWETensorKeyPrepared, B> = tsk.prepare_alloc(module, scratch.borrow()); - - ct.keyswitch_inplace(module, &ksk_prepared, &tsk_prepared, scratch.borrow()); - - let max_noise = |col_j: usize| -> f64 { - noise_ggsw_keyswitch( - n as f64, - basek * digits, - col_j, - var_xs, - 0f64, - SIGMA * SIGMA, - 0f64, - rank as f64, - k_ct, - k_ksk, - k_tsk, - ) + 0.5 - }; - - ct.assert_noise(module, &sk_out_prepared, &pt_scalar, max_noise); -} diff --git a/poulpy-core/src/tests/generics/keyswitch/glwe_ct.rs b/poulpy-core/src/tests/generics/keyswitch/glwe_ct.rs deleted file mode 100644 index f80e7c7..0000000 --- a/poulpy-core/src/tests/generics/keyswitch/glwe_ct.rs +++ /dev/null @@ -1,251 +0,0 @@ -use poulpy_hal::{ - api::{ - DFT, IDFTConsume, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, - VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, - VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VecZnxFillUniform, - VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwithcDegree, - VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, - }, - layouts::{Backend, Module, ScratchOwned}, - oep::{ - ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, - TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, - }, - source::Source, -}; - -use crate::{ - encryption::SIGMA, - layouts::{ - GGLWESwitchingKey, GLWECiphertext, GLWEPlaintext, 
GLWESecret, Infos, - prepared::{GGLWESwitchingKeyPrepared, GLWESecretPrepared, PrepareAlloc}, - }, - noise::log2_std_noise_gglwe_product, -}; - -#[allow(clippy::too_many_arguments)] -pub fn test_glwe_keyswitch( - module: &Module, - basek: usize, - k_out: usize, - k_in: usize, - k_ksk: usize, - digits: usize, - rank_in: usize, - rank_out: usize, -) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume - + VecZnxFillUniform - + VecZnxSubABInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxNormalizeTmpBytes - + VecZnxAddScalarInplace - + VmpPMatAlloc - + VmpPrepare - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxBigNormalizeTmpBytes - + VecZnxSwithcDegree, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + TakeSvpPPolImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl, -{ - let n = module.n(); - let rows: usize = k_in.div_ceil(basek * digits); - - let mut ksk: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(n, basek, k_ksk, rows, digits, rank_in, rank_out); - let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_in, rank_in); - let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_out, rank_out); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_in); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, k_in, &mut source_xa); - - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, ksk.k(), rank_in, rank_out) - | GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct_in.k()) - | GLWECiphertext::keyswitch_scratch_space( - module, - basek, - ct_out.k(), - ct_in.k(), - ksk.k(), - digits, - rank_in, - rank_out, - ), - ); - - let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank_in); - sk_in.fill_ternary_prob(0.5, &mut source_xs); - let sk_in_prepared: GLWESecretPrepared, B> = sk_in.prepare_alloc(module, scratch.borrow()); - - let mut sk_out: GLWESecret> = GLWESecret::alloc(n, rank_out); - sk_out.fill_ternary_prob(0.5, &mut source_xs); - let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); - - ksk.encrypt_sk( - module, - &sk_in, - &sk_out, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - - ct_in.encrypt_sk( - module, - &pt_want, - &sk_in_prepared, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - - let ksk_prepared: GGLWESwitchingKeyPrepared, B> = ksk.prepare_alloc(module, scratch.borrow()); - - ct_out.keyswitch(module, &ct_in, &ksk_prepared, scratch.borrow()); - - let max_noise: f64 = log2_std_noise_gglwe_product( - module.n() as f64, - basek * digits, - 0.5, - 0.5, - 0f64, - SIGMA * SIGMA, - 0f64, - rank_in as f64, - k_in, - k_ksk, - ); - - ct_out.assert_noise(module, &sk_out_prepared, &pt_want, max_noise + 0.5); -} - -pub fn test_glwe_keyswitch_inplace(module: &Module, basek: usize, k_ct: usize, k_ksk: usize, digits: usize, rank: usize) -where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume - + VecZnxFillUniform - + VecZnxSubABInplace - + 
VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxNormalizeTmpBytes - + VecZnxAddScalarInplace - + VmpPMatAlloc - + VmpPrepare - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxBigNormalizeTmpBytes - + VecZnxSwithcDegree, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + TakeSvpPPolImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl, -{ - let n: usize = module.n(); - let rows: usize = k_ct.div_ceil(basek * digits); - - let mut ksk: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(n, basek, k_ksk, rows, digits, rank, rank); - let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_ct); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, k_ct, &mut source_xa); - - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, ksk.k(), rank, rank) - | GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct_glwe.k()) - | GLWECiphertext::keyswitch_inplace_scratch_space(module, basek, ct_glwe.k(), ksk.k(), digits, rank), - ); - - let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank); - sk_in.fill_ternary_prob(0.5, &mut source_xs); - let sk_in_prepared: GLWESecretPrepared, B> = sk_in.prepare_alloc(module, scratch.borrow()); - - let mut sk_out: GLWESecret> = GLWESecret::alloc(n, rank); - sk_out.fill_ternary_prob(0.5, &mut source_xs); - let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); - - ksk.encrypt_sk( - module, - &sk_in, - &sk_out, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - - ct_glwe.encrypt_sk( - module, - &pt_want, - &sk_in_prepared, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - - let ksk_prepared: GGLWESwitchingKeyPrepared, B> = ksk.prepare_alloc(module, scratch.borrow()); - - ct_glwe.keyswitch_inplace(module, &ksk_prepared, scratch.borrow()); - - let max_noise: f64 = log2_std_noise_gglwe_product( - module.n() as f64, - basek * digits, - 0.5, - 0.5, - 0f64, - SIGMA * SIGMA, - 0f64, - rank as f64, - k_ct, - k_ksk, - ); - - ct_glwe.assert_noise(module, &sk_out_prepared, &pt_want, max_noise + 0.5); -} diff --git a/poulpy-core/src/tests/generics/trace.rs b/poulpy-core/src/tests/generics/trace.rs deleted file mode 100644 index e94d997..0000000 --- a/poulpy-core/src/tests/generics/trace.rs +++ /dev/null @@ -1,167 +0,0 @@ -use std::collections::HashMap; - -use poulpy_hal::{ - api::{ - DFT, IDFTConsume, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, - VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxBigAddInplace, - VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigAutomorphismInplace, VecZnxBigNormalize, - VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace, VecZnxCopy, VecZnxDftAllocBytes, VecZnxFillUniform, - VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, - VecZnxSubABInplace, VecZnxSwithcDegree, VmpApplyDftToDft, VmpApplyDftToDftAdd, 
VmpApplyDftToDftTmpBytes, VmpPMatAlloc, - VmpPrepare, - }, - layouts::{Backend, Module, ScratchOwned, ZnxView, ZnxViewMut}, - oep::{ - ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, - TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, - }, - source::Source, -}; - -use crate::{ - encryption::SIGMA, - layouts::{ - GGLWEAutomorphismKey, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, - prepared::{GGLWEAutomorphismKeyPrepared, GLWESecretPrepared, PrepareAlloc}, - }, - noise::var_noise_gglwe_product, -}; - -pub fn test_glwe_trace_inplace(module: &Module, basek: usize, k: usize, rank: usize) -where - Module: VecZnxDftAllocBytes - + VecZnxAutomorphism - + VecZnxBigAutomorphismInplace - + VecZnxBigSubSmallBInplace - + VecZnxRshInplace - + VecZnxRotateInplace - + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume - + VecZnxFillUniform - + VecZnxSubABInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc - + VecZnxBigAllocBytes - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxNormalizeTmpBytes - + VecZnxAddScalarInplace - + VmpPMatAlloc - + VmpPrepare - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxBigNormalizeTmpBytes - + VecZnxSwithcDegree - + VecZnxCopy, - B: Backend - + TakeVecZnxDftImpl - + TakeVecZnxBigImpl - + TakeSvpPPolImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + ScratchAvailableImpl - + TakeScalarZnxImpl - + TakeVecZnxImpl, -{ - let n: usize = module.n(); - let k_autokey: usize = k + basek; - - let digits: usize = 1; - let rows: usize = k.div_ceil(basek * digits); - - let mut ct: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k, rank); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k); - - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct.k()) - | GLWECiphertext::decrypt_scratch_space(module, basek, ct.k()) - | GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, basek, k_autokey, rank) - | GLWECiphertext::trace_inplace_scratch_space(module, basek, ct.k(), k_autokey, digits, rank), - ); - - let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); - sk.fill_ternary_prob(0.5, &mut source_xs); - let sk_dft: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); - - let mut data_want: Vec = vec![0i64; n]; - - data_want - .iter_mut() - .for_each(|x| *x = source_xa.next_i64() & 0xFF); - - module.vec_znx_fill_uniform(basek, &mut pt_have.data, 0, k, &mut source_xa); - - ct.encrypt_sk( - module, - &pt_have, - &sk_dft, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - - let mut auto_keys: HashMap, B>> = HashMap::new(); - let gal_els: Vec = GLWECiphertext::trace_galois_elements(module); - let mut tmp: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(n, basek, k_autokey, rows, digits, rank); - gal_els.iter().for_each(|gal_el| { - tmp.encrypt_sk( - module, - *gal_el, - &sk, - &mut source_xa, - &mut source_xe, - scratch.borrow(), - ); - let atk_prepared: GGLWEAutomorphismKeyPrepared, B> = tmp.prepare_alloc(module, scratch.borrow()); - auto_keys.insert(*gal_el, atk_prepared); - }); - - ct.trace_inplace(module, 0, 
5, &auto_keys, scratch.borrow()); - ct.trace_inplace(module, 5, module.log_n(), &auto_keys, scratch.borrow()); - - (0..pt_want.size()).for_each(|i| pt_want.data.at_mut(0, i)[0] = pt_have.data.at(0, i)[0]); - - ct.decrypt(module, &mut pt_have, &sk_dft, scratch.borrow()); - - module.vec_znx_sub_ab_inplace(&mut pt_want.data, 0, &pt_have.data, 0); - module.vec_znx_normalize_inplace(basek, &mut pt_want.data, 0, scratch.borrow()); - - let noise_have: f64 = pt_want.std().log2(); - - let mut noise_want: f64 = var_noise_gglwe_product( - n as f64, - basek, - 0.5, - 0.5, - 1.0 / 12.0, - SIGMA * SIGMA, - 0.0, - rank as f64, - k, - k_autokey, - ); - noise_want += SIGMA * SIGMA * (-2.0 * (k) as f64).exp2(); - noise_want += n as f64 * 1.0 / 12.0 * 0.5 * rank as f64 * (-2.0 * (k) as f64).exp2(); - noise_want = noise_want.sqrt().log2(); - - assert!( - (noise_have - noise_want).abs() < 1.0, - "{} > {}", - noise_have, - noise_want - ); -} diff --git a/poulpy-core/src/tests/implementation/cpu_spqlios/fft64/gglwe.rs b/poulpy-core/src/tests/implementation/cpu_spqlios/fft64/gglwe.rs deleted file mode 100644 index 13e476b..0000000 --- a/poulpy-core/src/tests/implementation/cpu_spqlios/fft64/gglwe.rs +++ /dev/null @@ -1,246 +0,0 @@ -use poulpy_backend::cpu_spqlios::FFT64; -use poulpy_hal::{api::ModuleNew, layouts::Module}; - -use crate::tests::generics::{ - automorphism::{test_gglwe_automorphism_key_automorphism, test_gglwe_automorphism_key_automorphism_inplace}, - encryption::{ - test_gglwe_automorphisk_key_compressed_encrypt_sk, test_gglwe_automorphisk_key_encrypt_sk, - test_gglwe_switching_key_compressed_encrypt_sk, test_gglwe_switching_key_encrypt_sk, - test_glwe_tensor_key_compressed_encrypt_sk, test_glwe_tensor_key_encrypt_sk, - }, - external_product::{test_gglwe_switching_key_external_product, test_gglwe_switching_key_external_product_inplace}, - keyswitch::{test_gglwe_switching_key_keyswitch, test_gglwe_switching_key_keyswitch_inplace}, -}; - -#[test] -fn gglwe_switching_key_encrypt_sk() { - let log_n: usize = 8; - let module = Module::::new(1 << log_n); - let basek: usize = 12; - let k_ksk: usize = 54; - let digits: usize = k_ksk / basek; - (1..4).for_each(|rank_in| { - (1..4).for_each(|rank_out| { - (1..digits + 1).for_each(|di| { - println!( - "test_gglwe_switching_key_encrypt_sk digits: {} ranks: ({} {})", - di, rank_in, rank_out - ); - test_gglwe_switching_key_encrypt_sk(&module, basek, k_ksk, di, rank_in, rank_out); - }); - }); - }); -} - -#[test] -fn gglwe_switching_key_compressed_encrypt_sk() { - let log_n: usize = 8; - let module: Module = Module::::new(1 << log_n); - let basek: usize = 12; - let k_ksk: usize = 54; - let digits: usize = k_ksk / basek; - (1..4).for_each(|rank_in| { - (1..4).for_each(|rank_out| { - (1..digits + 1).for_each(|di| { - println!( - "test_gglwe_switching_key_compressed_encrypt_sk digits: {} ranks: ({} {})", - di, rank_in, rank_out - ); - test_gglwe_switching_key_compressed_encrypt_sk(&module, basek, k_ksk, di, rank_in, rank_out); - }); - }); - }); -} - -#[test] -fn gglwe_switching_key_keyswitch() { - let log_n: usize = 8; - let module: Module = Module::::new(1 << log_n); - let basek: usize = 12; - let k_in: usize = 60; - let digits: usize = k_in.div_ceil(basek); - (1..4).for_each(|rank_in_s0s1| { - (1..4).for_each(|rank_out_s0s1| { - (1..4).for_each(|rank_out_s1s2| { - (1..digits + 1).for_each(|di| { - let k_ksk: usize = k_in + basek * di; - println!( - "test_gglwe_switching_key_keyswitch digits: {} ranks: ({},{},{})", - di, rank_in_s0s1, rank_out_s0s1, rank_out_s1s2 - ); 
- let k_out: usize = k_ksk; // Better capture noise. - test_gglwe_switching_key_keyswitch( - &module, - basek, - k_out, - k_in, - k_ksk, - di, - rank_in_s0s1, - rank_out_s0s1, - rank_out_s1s2, - ); - }) - }) - }); - }); -} - -#[test] -fn gglwe_switching_key_keyswitch_inplace() { - let log_n: usize = 8; - let module: Module = Module::::new(1 << log_n); - let basek: usize = 12; - let k_ct: usize = 60; - let digits: usize = k_ct.div_ceil(basek); - (1..4).for_each(|rank_in_s0s1| { - (1..4).for_each(|rank_out_s0s1| { - (1..digits + 1).for_each(|di| { - let k_ksk: usize = k_ct + basek * di; - println!( - "test_gglwe_switching_key_keyswitch_inplace digits: {} ranks: ({},{})", - di, rank_in_s0s1, rank_out_s0s1 - ); - test_gglwe_switching_key_keyswitch_inplace(&module, basek, k_ct, k_ksk, di, rank_in_s0s1, rank_out_s0s1); - }); - }); - }); -} - -#[test] -fn gglwe_switching_key_external_product() { - let log_n: usize = 8; - let module: Module = Module::::new(1 << log_n); - let basek: usize = 12; - let k_in: usize = 60; - let digits: usize = k_in.div_ceil(basek); - (1..4).for_each(|rank_in| { - (1..4).for_each(|rank_out| { - (1..digits + 1).for_each(|di| { - let k_ggsw: usize = k_in + basek * di; - println!( - "test_gglwe_switching_key_external_product digits: {} ranks: ({} {})", - di, rank_in, rank_out - ); - let k_out: usize = k_in; // Better capture noise. - test_gglwe_switching_key_external_product(&module, basek, k_out, k_in, k_ggsw, di, rank_in, rank_out); - }); - }); - }); -} - -#[test] -fn gglwe_switching_key_external_product_inplace() { - let log_n: usize = 5; - let module: Module = Module::::new(1 << log_n); - let basek: usize = 12; - let k_ct: usize = 60; - let digits: usize = k_ct.div_ceil(basek); - (1..4).for_each(|rank_in| { - (1..4).for_each(|rank_out| { - (1..digits).for_each(|di| { - let k_ggsw: usize = k_ct + basek * di; - println!( - "test_gglwe_switching_key_external_product_inplace digits: {} ranks: ({} {})", - di, rank_in, rank_out - ); - test_gglwe_switching_key_external_product_inplace(&module, basek, k_ct, k_ggsw, di, rank_in, rank_out); - }); - }); - }); -} - -#[test] -fn gglwe_automorphisk_key_encrypt_sk() { - let log_n: usize = 8; - let module: Module = Module::::new(1 << log_n); - let basek: usize = 12; - let k: usize = 60; - let digits: usize = k.div_ceil(basek) - 1; - (1..4).for_each(|rank| { - (2..digits + 1).for_each(|di| { - println!( - "test_gglwe_automorphisk_key_encrypt_sk digits: {} rank: {}", - di, rank - ); - test_gglwe_automorphisk_key_encrypt_sk(&module, basek, k, di, rank); - }); - }); -} - -#[test] -fn gglwe_automorphisk_key_compressed_encrypt_sk() { - let log_n: usize = 8; - let module: Module = Module::::new(1 << log_n); - let basek: usize = 12; - let k: usize = 60; - let digits: usize = k.div_ceil(basek) - 1; - (1..4).for_each(|rank| { - (2..digits + 1).for_each(|di| { - println!( - "test_gglwe_automorphisk_key_compressed_encrypt_sk digits: {} rank: {}", - di, rank - ); - test_gglwe_automorphisk_key_compressed_encrypt_sk(&module, basek, k, di, rank); - }); - }); -} - -#[test] -fn gglwe_automorphism_key_automorphism() { - let log_n: usize = 8; - let module: Module = Module::::new(1 << log_n); - let basek: usize = 12; - let k_in: usize = 60; - let k_out: usize = 40; - let digits: usize = k_in.div_ceil(basek); - (1..4).for_each(|rank| { - (2..digits + 1).for_each(|di| { - println!( - "test_gglwe_automorphism_key_automorphism: {} rank: {}", - di, rank - ); - let k_apply: usize = (digits + di) * basek; - test_gglwe_automorphism_key_automorphism(&module, -1, 
5, basek, di, k_in, k_out, k_apply, rank); - }); - }); -} - -#[test] -fn gglwe_automorphism_key_automorphism_inplace() { - let log_n: usize = 8; - let module: Module = Module::::new(1 << log_n); - let basek: usize = 12; - let k_in: usize = 60; - let digits: usize = k_in.div_ceil(basek); - (1..4).for_each(|rank| { - (2..digits + 1).for_each(|di| { - println!( - "test_gglwe_automorphism_key_automorphism_inplace: {} rank: {}", - di, rank - ); - let k_apply: usize = (digits + di) * basek; - test_gglwe_automorphism_key_automorphism_inplace(&module, -1, 5, basek, di, k_in, k_apply, rank); - }); - }); -} - -#[test] -fn glwe_tensor_key_encrypt_sk() { - let log_n: usize = 8; - let module: Module = Module::::new(1 << log_n); - (1..4).for_each(|rank| { - println!("test_glwe_tensor_key_encrypt_sk rank: {}", rank); - test_glwe_tensor_key_encrypt_sk(&module, 16, 54, rank); - }); -} - -#[test] -fn glwe_tensor_key_compressed_encrypt_sk() { - let log_n: usize = 8; - let module: Module = Module::::new(1 << log_n); - (1..4).for_each(|rank| { - println!("test_glwe_tensor_key_compressed_encrypt_sk rank: {}", rank); - test_glwe_tensor_key_compressed_encrypt_sk(&module, 16, 54, rank); - }); -} diff --git a/poulpy-core/src/tests/implementation/cpu_spqlios/fft64/ggws.rs b/poulpy-core/src/tests/implementation/cpu_spqlios/fft64/ggws.rs deleted file mode 100644 index 7d2b0a7..0000000 --- a/poulpy-core/src/tests/implementation/cpu_spqlios/fft64/ggws.rs +++ /dev/null @@ -1,148 +0,0 @@ -use poulpy_backend::cpu_spqlios::FFT64; -use poulpy_hal::{api::ModuleNew, layouts::Module}; - -use crate::tests::generics::{ - automorphism::{test_ggsw_automorphism, test_ggsw_automorphism_inplace}, - encryption::{test_ggsw_compressed_encrypt_sk, test_ggsw_encrypt_sk}, - external_product::{test_ggsw_external_product, test_ggsw_external_product_inplace}, - keyswitch::{test_ggsw_keyswitch, test_ggsw_keyswitch_inplace}, -}; - -#[test] -fn ggsw_encrypt_sk() { - let log_n: usize = 8; - let module: Module = Module::::new(1 << log_n); - let basek: usize = 12; - let k_ct: usize = 54; - let digits: usize = k_ct / basek; - (1..4).for_each(|rank| { - (1..digits + 1).for_each(|di| { - println!("test_ggsw_encrypt_sk digits: {} rank: {}", di, rank); - test_ggsw_encrypt_sk(&module, basek, k_ct, di, rank); - }); - }); -} - -#[test] -fn ggsw_compressed_encrypt_sk() { - let log_n: usize = 8; - let module: Module = Module::::new(1 << log_n); - let basek: usize = 12; - let k_ct: usize = 54; - let digits: usize = k_ct / basek; - (1..4).for_each(|rank| { - (1..digits + 1).for_each(|di| { - println!( - "test_ggsw_compressed_encrypt_sk digits: {} rank: {}", - di, rank - ); - test_ggsw_compressed_encrypt_sk(&module, basek, k_ct, di, rank); - }); - }); -} - -#[test] -fn ggsw_keyswitch() { - let log_n: usize = 8; - let module: Module = Module::::new(1 << log_n); - let basek: usize = 12; - let k_in: usize = 54; - let digits: usize = k_in.div_ceil(basek); - (1..4).for_each(|rank| { - (1..digits + 1).for_each(|di| { - let k_ksk: usize = k_in + basek * di; - let k_tsk: usize = k_ksk; - println!("test_ggsw_keyswitch digits: {} rank: {}", di, rank); - let k_out: usize = k_ksk; // Better capture noise. 
- test_ggsw_keyswitch(&module, basek, k_out, k_in, k_ksk, k_tsk, di, rank); - }); - }); -} - -#[test] -fn ggsw_keyswitch_inplace() { - let log_n: usize = 8; - let module: Module = Module::::new(1 << log_n); - let basek: usize = 12; - let k_ct: usize = 54; - let digits: usize = k_ct.div_ceil(basek); - (1..4).for_each(|rank| { - (1..digits + 1).for_each(|di| { - let k_ksk: usize = k_ct + basek * di; - let k_tsk: usize = k_ksk; - println!("test_ggsw_keyswitch_inplace digits: {} rank: {}", di, rank); - test_ggsw_keyswitch_inplace(&module, basek, k_ct, k_ksk, k_tsk, di, rank); - }); - }); -} - -#[test] -fn ggsw_automorphism() { - let log_n: usize = 8; - let module: Module = Module::::new(1 << log_n); - let basek: usize = 12; - let k_in: usize = 54; - let digits: usize = k_in.div_ceil(basek); - (1..4).for_each(|rank| { - (1..digits + 1).for_each(|di| { - let k_ksk: usize = k_in + basek * di; - let k_tsk: usize = k_ksk; - println!("test_ggsw_automorphism rank: {}", rank); - let k_out: usize = k_ksk; // Better capture noise. - test_ggsw_automorphism(-5, &module, basek, k_out, k_in, k_ksk, k_tsk, di, rank); - }); - }); -} - -#[test] -fn ggsw_automorphism_inplace() { - let log_n: usize = 8; - let module: Module = Module::::new(1 << log_n); - let basek: usize = 12; - let k_ct: usize = 54; - let digits: usize = k_ct.div_ceil(basek); - (1..4).for_each(|rank| { - (1..digits + 1).for_each(|di| { - let k_ksk: usize = k_ct + basek * di; - let k_tsk: usize = k_ksk; - println!("test_ggsw_automorphism_inplace rank: {}", rank); - test_ggsw_automorphism_inplace(-5, &module, basek, k_ct, k_ksk, k_tsk, di, rank); - }); - }); -} - -#[test] -fn ggsw_external_product() { - let log_n: usize = 8; - let module: Module = Module::::new(1 << log_n); - let basek: usize = 12; - let k_in: usize = 60; - let digits: usize = k_in.div_ceil(basek); - (1..4).for_each(|rank| { - (1..digits + 1).for_each(|di| { - let k_ggsw: usize = k_in + basek * di; - println!("test external_product digits: {} ranks: {}", di, rank); - let k_out: usize = k_in; // Better capture noise. 
- test_ggsw_external_product(&module, basek, k_in, k_out, k_ggsw, di, rank); - }); - }); -} - -#[test] -fn ggsw_external_product_inplace() { - let log_n: usize = 8; - let module: Module = Module::::new(1 << log_n); - let basek: usize = 12; - let k_ct: usize = 60; - let digits: usize = k_ct.div_ceil(basek); - (1..4).for_each(|rank| { - (1..digits).for_each(|di| { - let k_ggsw: usize = k_ct + basek * di; - println!( - "test_ggsw_external_product_inplace digits: {} rank: {}", - di, rank - ); - test_ggsw_external_product_inplace(&module, basek, k_ct, k_ggsw, di, rank); - }); - }); -} diff --git a/poulpy-core/src/tests/implementation/cpu_spqlios/fft64/glwe.rs b/poulpy-core/src/tests/implementation/cpu_spqlios/fft64/glwe.rs deleted file mode 100644 index 0736b40..0000000 --- a/poulpy-core/src/tests/implementation/cpu_spqlios/fft64/glwe.rs +++ /dev/null @@ -1,177 +0,0 @@ -use poulpy_backend::cpu_spqlios::FFT64; -use poulpy_hal::{api::ModuleNew, layouts::Module}; - -use crate::tests::generics::{ - automorphism::{test_glwe_automorphism, test_glwe_automorphism_inplace}, - encryption::{test_glwe_compressed_encrypt_sk, test_glwe_encrypt_pk, test_glwe_encrypt_sk, test_glwe_encrypt_zero_sk}, - external_product::{test_glwe_external_product, test_glwe_external_product_inplace}, - keyswitch::{test_glwe_keyswitch, test_glwe_keyswitch_inplace}, - test_glwe_packing, test_glwe_trace_inplace, -}; - -#[test] -fn glwe_encrypt_sk() { - let log_n: usize = 8; - let module: Module = Module::::new(1 << log_n); - (1..4).for_each(|rank| { - println!("test_glwe_encrypt_sk rank: {}", rank); - test_glwe_encrypt_sk(&module, 8, 54, 30, rank); - }); -} - -#[test] -fn glwe_compressed_encrypt_sk() { - let log_n: usize = 8; - let module: Module = Module::::new(1 << log_n); - (1..4).for_each(|rank| { - println!("test_glwe_compressed_encrypt_sk rank: {}", rank); - test_glwe_compressed_encrypt_sk(&module, 8, 54, 30, rank); - }); -} - -#[test] -fn glwe_encrypt_zero_sk() { - let log_n: usize = 8; - let module: Module = Module::::new(1 << log_n); - (1..4).for_each(|rank| { - println!("test_glwe_encrypt_zero_sk rank: {}", rank); - test_glwe_encrypt_zero_sk(&module, 8, 64, rank); - }); -} - -#[test] -fn glwe_encrypt_pk() { - let log_n: usize = 8; - let module: Module = Module::::new(1 << log_n); - (1..4).for_each(|rank| { - println!("test_glwe_encrypt_pk rank: {}", rank); - test_glwe_encrypt_pk(&module, 8, 64, 64, rank) - }); -} - -#[test] -fn glwe_keyswitch() { - let log_n: usize = 8; - let module: Module = Module::::new(1 << log_n); - let basek: usize = 12; - let k_in: usize = 45; - let digits: usize = k_in.div_ceil(basek); - (1..4).for_each(|rank_in| { - (1..4).for_each(|rank_out| { - (1..digits + 1).for_each(|di| { - let k_ksk: usize = k_in + basek * di; - let k_out: usize = k_ksk; // better capture noise - println!( - "test_glwe_keyswitch digits: {} rank_in: {} rank_out: {}", - di, rank_in, rank_out - ); - test_glwe_keyswitch(&module, basek, k_out, k_in, k_ksk, di, rank_in, rank_out); - }) - }); - }); -} - -#[test] -fn glwe_keyswitch_inplace() { - let log_n: usize = 8; - let module: Module = Module::::new(1 << log_n); - let basek: usize = 12; - let k_ct: usize = 45; - let digits: usize = k_ct.div_ceil(basek); - (1..4).for_each(|rank| { - (1..digits + 1).for_each(|di| { - let k_ksk: usize = k_ct + basek * di; - println!("test_glwe_keyswitch_inplace digits: {} rank: {}", di, rank); - test_glwe_keyswitch_inplace(&module, basek, k_ct, k_ksk, di, rank); - }); - }); -} - -#[test] -fn glwe_automorphism() { - let log_n: usize = 8; - let 
module: Module = Module::::new(1 << log_n); - let basek: usize = 12; - let k_in: usize = 60; - let digits: usize = k_in.div_ceil(basek); - (1..4).for_each(|rank| { - (1..digits + 1).for_each(|di| { - let k_ksk: usize = k_in + basek * di; - let k_out: usize = k_ksk; // Better capture noise. - println!("test_glwe_automorphism digits: {} rank: {}", di, rank); - test_glwe_automorphism(&module, basek, -5, k_out, k_in, k_ksk, di, rank); - }) - }); -} - -#[test] -fn glwe_automorphism_inplace() { - let log_n: usize = 8; - let module: Module = Module::::new(1 << log_n); - let basek: usize = 12; - let k_ct: usize = 60; - let digits: usize = k_ct.div_ceil(basek); - (1..4).for_each(|rank| { - (1..digits + 1).for_each(|di| { - let k_ksk: usize = k_ct + basek * di; - println!( - "test_glwe_automorphism_inplace digits: {} rank: {}", - di, rank - ); - test_glwe_automorphism_inplace(&module, basek, -5, k_ct, k_ksk, di, rank); - }); - }); -} - -#[test] -fn glwe_external_product() { - let log_n: usize = 8; - let module: Module = Module::::new(1 << log_n); - let basek: usize = 12; - let k_in: usize = 45; - let digits: usize = k_in.div_ceil(basek); - (1..4).for_each(|rank| { - (1..digits + 1).for_each(|di| { - let k_ggsw: usize = k_in + basek * di; - let k_out: usize = k_ggsw; // Better capture noise - println!("test_glwe_external_product digits: {} rank: {}", di, rank); - test_glwe_external_product(&module, basek, k_out, k_in, k_ggsw, di, rank); - }); - }); -} - -#[test] -fn glwe_external_product_inplace() { - let log_n: usize = 8; - let module: Module = Module::::new(1 << log_n); - let basek: usize = 12; - let k_ct: usize = 60; - let digits: usize = k_ct.div_ceil(basek); - (1..4).for_each(|rank| { - (1..digits + 1).for_each(|di| { - let k_ggsw: usize = k_ct + basek * di; - println!( - "test_glwe_external_product_inplace digits: {} rank: {}", - di, rank - ); - test_glwe_external_product_inplace(&module, basek, k_ct, k_ggsw, di, rank); - }); - }); -} - -#[test] -fn glwe_trace_inplace() { - let log_n: usize = 8; - let module: Module = Module::::new(1 << log_n); - (1..4).for_each(|rank| { - println!("test_glwe_trace_inplace rank: {}", rank); - test_glwe_trace_inplace(&module, 8, 54, rank); - }); -} - -#[test] -fn glwe_packing() { - let log_n: usize = 5; - let module: Module = Module::::new(1 << log_n); - test_glwe_packing(&module); -} diff --git a/poulpy-core/src/tests/implementation/cpu_spqlios/fft64/lwe.rs b/poulpy-core/src/tests/implementation/cpu_spqlios/fft64/lwe.rs deleted file mode 100644 index 5213d0c..0000000 --- a/poulpy-core/src/tests/implementation/cpu_spqlios/fft64/lwe.rs +++ /dev/null @@ -1,24 +0,0 @@ -use crate::tests::generics::{keyswitch::test_lwe_keyswitch, test_glwe_to_lwe, test_lwe_to_glwe}; -use poulpy_backend::cpu_spqlios::FFT64; -use poulpy_hal::{api::ModuleNew, layouts::Module}; - -#[test] -fn lwe_to_glwe() { - let log_n: usize = 5; - let module: Module = Module::::new(1 << log_n); - test_lwe_to_glwe(&module) -} - -#[test] -fn glwe_to_lwe() { - let log_n: usize = 5; - let module: Module = Module::::new(1 << log_n); - test_glwe_to_lwe(&module) -} - -#[test] -fn lwe_keyswitch() { - let log_n: usize = 5; - let module: Module = Module::::new(1 << log_n); - test_lwe_keyswitch(&module) -} diff --git a/poulpy-core/src/tests/implementation/cpu_spqlios/fft64/mod.rs b/poulpy-core/src/tests/implementation/cpu_spqlios/fft64/mod.rs deleted file mode 100644 index 444225a..0000000 --- a/poulpy-core/src/tests/implementation/cpu_spqlios/fft64/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -mod gglwe; -mod ggws; 
-mod glwe; -mod lwe; diff --git a/poulpy-core/src/tests/implementation/cpu_spqlios/mod.rs b/poulpy-core/src/tests/implementation/cpu_spqlios/mod.rs deleted file mode 100644 index aebaafb..0000000 --- a/poulpy-core/src/tests/implementation/cpu_spqlios/mod.rs +++ /dev/null @@ -1 +0,0 @@ -mod fft64; diff --git a/poulpy-core/src/tests/implementation/mod.rs b/poulpy-core/src/tests/implementation/mod.rs deleted file mode 100644 index f2bc1d4..0000000 --- a/poulpy-core/src/tests/implementation/mod.rs +++ /dev/null @@ -1 +0,0 @@ -mod cpu_spqlios; diff --git a/poulpy-core/src/tests/mod.rs b/poulpy-core/src/tests/mod.rs index 4120448..d031f55 100644 --- a/poulpy-core/src/tests/mod.rs +++ b/poulpy-core/src/tests/mod.rs @@ -1,7 +1,181 @@ -pub mod generics; - -#[cfg(test)] -mod implementation; +pub mod test_suite; #[cfg(test)] mod serialization; + +#[allow(unused_imports)] +use poulpy_hal::backend_test_suite; + +#[cfg(test)] +backend_test_suite!( + mod cpu_spqlios, + backend = poulpy_backend::cpu_spqlios::FFT64Spqlios, + size = 1<<8, + tests = { + // GLWE Encryption + glwe_encrypt_sk => crate::tests::test_suite::encryption::test_glwe_encrypt_sk, + glwe_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_glwe_compressed_encrypt_sk, + glwe_encrypt_zero_sk => crate::tests::test_suite::encryption::test_glwe_encrypt_zero_sk, + glwe_encrypt_pk => crate::tests::test_suite::encryption::test_glwe_encrypt_pk, + // GLWE Keyswitch + glwe_keyswitch => crate::tests::test_suite::keyswitch::test_glwe_keyswitch, + glwe_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_glwe_keyswitch_inplace, + // GLWE Automorphism + glwe_automorphism => crate::tests::test_suite::automorphism::test_glwe_automorphism, + glwe_automorphism_inplace => crate::tests::test_suite::automorphism::test_glwe_automorphism_inplace, + // GLWE External Product + glwe_external_product => crate::tests::test_suite::external_product::test_glwe_external_product, + glwe_external_product_inplace => crate::tests::test_suite::external_product::test_glwe_external_product_inplace, + // GLWE Trace + glwe_trace_inplace => crate::tests::test_suite::test_glwe_trace_inplace, + glwe_packing => crate::tests::test_suite::test_glwe_packing, + // GGLWE Encryption + gglwe_switching_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_switching_key_encrypt_sk, + gglwe_switching_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_switching_key_compressed_encrypt_sk, + gglwe_automorphisk_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_automorphisk_key_encrypt_sk, + gglwe_automorphisk_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_automorphisk_key_compressed_encrypt_sk, + gglwe_tensor_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_tensor_key_encrypt_sk, + gglwe_tensor_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_tensor_key_compressed_encrypt_sk, + // GGLWE Keyswitching + gglwe_switching_key_keyswitch => crate::tests::test_suite::keyswitch::test_gglwe_switching_key_keyswitch, + gglwe_switching_key_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_gglwe_switching_key_keyswitch_inplace, + // GGLWE External Product + gglwe_switching_key_external_product => crate::tests::test_suite::external_product::test_gglwe_switching_key_external_product, + gglwe_switching_key_external_product_inplace => crate::tests::test_suite::external_product::test_gglwe_switching_key_external_product_inplace, + // GGLWE 
Automorphism + gglwe_automorphism_key_automorphism => crate::tests::test_suite::automorphism::test_gglwe_automorphism_key_automorphism, + gglwe_automorphism_key_automorphism_inplace => crate::tests::test_suite::automorphism::test_gglwe_automorphism_key_automorphism_inplace, + // GGSW Encryption + ggsw_encrypt_sk => crate::tests::test_suite::encryption::test_ggsw_encrypt_sk, + ggsw_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_ggsw_compressed_encrypt_sk, + // GGSW Keyswitching + ggsw_keyswitch => crate::tests::test_suite::keyswitch::test_ggsw_keyswitch, + ggsw_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_ggsw_keyswitch_inplace, + // GGSW External Product + ggsw_external_product => crate::tests::test_suite::external_product::test_ggsw_external_product, + ggsw_external_product_inplace => crate::tests::test_suite::external_product::test_ggsw_external_product_inplace, + // GGSW Automorphism + ggsw_automorphism => crate::tests::test_suite::automorphism::test_ggsw_automorphism, + ggsw_automorphism_inplace => crate::tests::test_suite::automorphism::test_ggsw_automorphism_inplace, + // LWE + lwe_keyswitch => crate::tests::test_suite::keyswitch::test_lwe_keyswitch, + glwe_to_lwe => crate::tests::test_suite::test_glwe_to_lwe, + lwe_to_glwe => crate::tests::test_suite::test_lwe_to_glwe, + } +); + +#[cfg(test)] +backend_test_suite!( + mod cpu_ref, + backend = poulpy_backend::cpu_fft64_ref::FFT64Ref, + size = 1<<8, + tests = { + // GLWE Encryption + glwe_encrypt_sk => crate::tests::test_suite::encryption::test_glwe_encrypt_sk, + glwe_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_glwe_compressed_encrypt_sk, + glwe_encrypt_zero_sk => crate::tests::test_suite::encryption::test_glwe_encrypt_zero_sk, + glwe_encrypt_pk => crate::tests::test_suite::encryption::test_glwe_encrypt_pk, + // GLWE Keyswitch + glwe_keyswitch => crate::tests::test_suite::keyswitch::test_glwe_keyswitch, + glwe_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_glwe_keyswitch_inplace, + // GLWE Automorphism + glwe_automorphism => crate::tests::test_suite::automorphism::test_glwe_automorphism, + glwe_automorphism_inplace => crate::tests::test_suite::automorphism::test_glwe_automorphism_inplace, + // GLWE External Product + glwe_external_product => crate::tests::test_suite::external_product::test_glwe_external_product, + glwe_external_product_inplace => crate::tests::test_suite::external_product::test_glwe_external_product_inplace, + // GLWE Trace + glwe_trace_inplace => crate::tests::test_suite::test_glwe_trace_inplace, + glwe_packing => crate::tests::test_suite::test_glwe_packing, + // GGLWE Encryption + gglwe_switching_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_switching_key_encrypt_sk, + gglwe_switching_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_switching_key_compressed_encrypt_sk, + gglwe_automorphisk_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_automorphisk_key_encrypt_sk, + gglwe_automorphisk_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_automorphisk_key_compressed_encrypt_sk, + gglwe_tensor_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_tensor_key_encrypt_sk, + gglwe_tensor_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_tensor_key_compressed_encrypt_sk, + // GGLWE Keyswitching + gglwe_switching_key_keyswitch => crate::tests::test_suite::keyswitch::test_gglwe_switching_key_keyswitch, + 
gglwe_switching_key_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_gglwe_switching_key_keyswitch_inplace, + // GGLWE External Product + gglwe_switching_key_external_product => crate::tests::test_suite::external_product::test_gglwe_switching_key_external_product, + gglwe_switching_key_external_product_inplace => crate::tests::test_suite::external_product::test_gglwe_switching_key_external_product_inplace, + // GGLWE Automorphism + gglwe_automorphism_key_automorphism => crate::tests::test_suite::automorphism::test_gglwe_automorphism_key_automorphism, + gglwe_automorphism_key_automorphism_inplace => crate::tests::test_suite::automorphism::test_gglwe_automorphism_key_automorphism_inplace, + // GGSW Encryption + ggsw_encrypt_sk => crate::tests::test_suite::encryption::test_ggsw_encrypt_sk, + ggsw_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_ggsw_compressed_encrypt_sk, + // GGSW Keyswitching + ggsw_keyswitch => crate::tests::test_suite::keyswitch::test_ggsw_keyswitch, + ggsw_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_ggsw_keyswitch_inplace, + // GGSW External Product + ggsw_external_product => crate::tests::test_suite::external_product::test_ggsw_external_product, + ggsw_external_product_inplace => crate::tests::test_suite::external_product::test_ggsw_external_product_inplace, + // GGSW Automorphism + ggsw_automorphism => crate::tests::test_suite::automorphism::test_ggsw_automorphism, + ggsw_automorphism_inplace => crate::tests::test_suite::automorphism::test_ggsw_automorphism_inplace, + // LWE + lwe_keyswitch => crate::tests::test_suite::keyswitch::test_lwe_keyswitch, + glwe_to_lwe => crate::tests::test_suite::test_glwe_to_lwe, + lwe_to_glwe => crate::tests::test_suite::test_lwe_to_glwe, + } +); + +#[cfg(test)] +backend_test_suite!( + mod cpu_avx, + backend = poulpy_backend::cpu_fft64_avx::FFT64Avx, + size = 1<<8, + tests = { + // GLWE Encryption + glwe_encrypt_sk => crate::tests::test_suite::encryption::test_glwe_encrypt_sk, + glwe_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_glwe_compressed_encrypt_sk, + glwe_encrypt_zero_sk => crate::tests::test_suite::encryption::test_glwe_encrypt_zero_sk, + glwe_encrypt_pk => crate::tests::test_suite::encryption::test_glwe_encrypt_pk, + // GLWE Keyswitch + glwe_keyswitch => crate::tests::test_suite::keyswitch::test_glwe_keyswitch, + glwe_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_glwe_keyswitch_inplace, + // GLWE Automorphism + glwe_automorphism => crate::tests::test_suite::automorphism::test_glwe_automorphism, + glwe_automorphism_inplace => crate::tests::test_suite::automorphism::test_glwe_automorphism_inplace, + // GLWE External Product + glwe_external_product => crate::tests::test_suite::external_product::test_glwe_external_product, + glwe_external_product_inplace => crate::tests::test_suite::external_product::test_glwe_external_product_inplace, + // GLWE Trace + glwe_trace_inplace => crate::tests::test_suite::test_glwe_trace_inplace, + glwe_packing => crate::tests::test_suite::test_glwe_packing, + // GGLWE Encryption + gglwe_switching_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_switching_key_encrypt_sk, + gglwe_switching_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_switching_key_compressed_encrypt_sk, + gglwe_automorphisk_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_automorphisk_key_encrypt_sk, + gglwe_automorphisk_key_compressed_encrypt_sk => 
crate::tests::test_suite::encryption::test_gglwe_automorphisk_key_compressed_encrypt_sk, + gglwe_tensor_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_tensor_key_encrypt_sk, + gglwe_tensor_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_tensor_key_compressed_encrypt_sk, + // GGLWE Keyswitching + gglwe_switching_key_keyswitch => crate::tests::test_suite::keyswitch::test_gglwe_switching_key_keyswitch, + gglwe_switching_key_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_gglwe_switching_key_keyswitch_inplace, + // GGLWE External Product + gglwe_switching_key_external_product => crate::tests::test_suite::external_product::test_gglwe_switching_key_external_product, + gglwe_switching_key_external_product_inplace => crate::tests::test_suite::external_product::test_gglwe_switching_key_external_product_inplace, + // GGLWE Automorphism + gglwe_automorphism_key_automorphism => crate::tests::test_suite::automorphism::test_gglwe_automorphism_key_automorphism, + gglwe_automorphism_key_automorphism_inplace => crate::tests::test_suite::automorphism::test_gglwe_automorphism_key_automorphism_inplace, + // GGSW Encryption + ggsw_encrypt_sk => crate::tests::test_suite::encryption::test_ggsw_encrypt_sk, + ggsw_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_ggsw_compressed_encrypt_sk, + // GGSW Keyswitching + ggsw_keyswitch => crate::tests::test_suite::keyswitch::test_ggsw_keyswitch, + ggsw_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_ggsw_keyswitch_inplace, + // GGSW External Product + ggsw_external_product => crate::tests::test_suite::external_product::test_ggsw_external_product, + ggsw_external_product_inplace => crate::tests::test_suite::external_product::test_ggsw_external_product_inplace, + // GGSW Automorphism + ggsw_automorphism => crate::tests::test_suite::automorphism::test_ggsw_automorphism, + ggsw_automorphism_inplace => crate::tests::test_suite::automorphism::test_ggsw_automorphism_inplace, + // LWE + lwe_keyswitch => crate::tests::test_suite::keyswitch::test_lwe_keyswitch, + glwe_to_lwe => crate::tests::test_suite::test_glwe_to_lwe, + lwe_to_glwe => crate::tests::test_suite::test_lwe_to_glwe, + } +); diff --git a/poulpy-core/src/tests/serialization.rs b/poulpy-core/src/tests/serialization.rs index 14f8177..afc4c5a 100644 --- a/poulpy-core/src/tests/serialization.rs +++ b/poulpy-core/src/tests/serialization.rs @@ -1,4 +1,4 @@ -use poulpy_hal::tests::serialization::test_reader_writer_interface; +use poulpy_hal::test_suite::serialization::test_reader_writer_interface; use crate::layouts::{ GGLWEAutomorphismKey, GGLWECiphertext, GGLWESwitchingKey, GGLWETensorKey, GGSWCiphertext, GLWECiphertext, @@ -21,7 +21,7 @@ const DIGITS: usize = 1; #[test] fn glwe_serialization() { let original: GLWECiphertext> = GLWECiphertext::alloc(N_GLWE, BASEK, K, RANK); - poulpy_hal::tests::serialization::test_reader_writer_interface(original); + poulpy_hal::test_suite::serialization::test_reader_writer_interface(original); } #[test] diff --git a/poulpy-core/src/tests/test_suite/automorphism/gglwe_atk.rs b/poulpy-core/src/tests/test_suite/automorphism/gglwe_atk.rs new file mode 100644 index 0000000..6349445 --- /dev/null +++ b/poulpy-core/src/tests/test_suite/automorphism/gglwe_atk.rs @@ -0,0 +1,370 @@ +use poulpy_hal::{ + api::{ + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, + VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, 
VecZnxAutomorphismInplace, + VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, + VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, + VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSubScalarInplace, VecZnxSwitchRing, + VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, + }, + layouts::{Backend, Module, ScratchOwned}, + oep::{ + ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, + TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, + }, + source::Source, +}; + +use crate::{ + encryption::SIGMA, + layouts::{ + GGLWEAutomorphismKey, GLWEPlaintext, GLWESecret, Infos, + prepared::{GGLWEAutomorphismKeyPrepared, GLWESecretPrepared, Prepare, PrepareAlloc}, + }, + noise::log2_std_noise_gglwe_product, +}; + +#[allow(clippy::too_many_arguments)] +pub fn test_gglwe_automorphism_key_automorphism(module: &Module) +where + Module: VecZnxDftAllocBytes + + VmpApplyDftToDftTmpBytes + + VecZnxBigNormalizeTmpBytes + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + VecZnxDftApply + + VecZnxIdftApplyConsume + + VecZnxBigAddSmallInplace + + VecZnxBigNormalize + + VecZnxAutomorphism + + VecZnxAutomorphismInplace + + SvpPPolAllocBytes + + VecZnxDftAllocBytes + + VecZnxNormalizeTmpBytes + + VmpPMatAlloc + + VmpPrepare + + SvpPrepare + + SvpApplyDftToDftInplace + + VecZnxAddScalarInplace + + VecZnxFillUniform + + VecZnxSubABInplace + + VecZnxAddInplace + + VecZnxNormalizeInplace + + VecZnxAddNormal + + VecZnxNormalize + + VecZnxSub + + VecZnxSwitchRing + + SvpPPolAlloc + + VecZnxBigAddInplace + + VecZnxSubScalarInplace, + B: Backend + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + ScratchAvailableImpl + + TakeScalarZnxImpl + + TakeVecZnxDftImpl + + TakeVecZnxImpl + + TakeSvpPPolImpl + + TakeVecZnxBigImpl, +{ + let basek: usize = 12; + let k_in: usize = 60; + let k_out: usize = 40; + let digits: usize = k_in.div_ceil(basek); + let p0 = -1; + let p1 = -5; + (1..3).for_each(|rank| { + (1..digits + 1).for_each(|di| { + let k_apply: usize = (digits + di) * basek; + + let n: usize = module.n(); + let digits_in: usize = 1; + + let rows_in: usize = k_in / (basek * di); + let rows_out: usize = k_out / (basek * di); + let rows_apply: usize = k_in.div_ceil(basek * di); + + let mut auto_key_in: GGLWEAutomorphismKey> = + GGLWEAutomorphismKey::alloc(n, basek, k_in, rows_in, digits_in, rank); + let mut auto_key_out: GGLWEAutomorphismKey> = + GGLWEAutomorphismKey::alloc(n, basek, k_out, rows_out, digits_in, rank); + let mut auto_key_apply: GGLWEAutomorphismKey> = + GGLWEAutomorphismKey::alloc(n, basek, k_apply, rows_apply, di, rank); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, basek, k_apply, rank) + | GGLWEAutomorphismKey::automorphism_scratch_space(module, basek, k_out, k_in, k_apply, di, rank), + ); + + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + + // gglwe_{s1}(s0) = s0 -> s1 + auto_key_in.encrypt_sk( + module, + p0, + &sk, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + // gglwe_{s2}(s1) -> s1 -> s2 + auto_key_apply.encrypt_sk( + module, + p1, + &sk, + &mut 
source_xa, + &mut source_xe, + scratch.borrow(), + ); + + let mut auto_key_apply_prepared: GGLWEAutomorphismKeyPrepared, B> = + GGLWEAutomorphismKeyPrepared::alloc(module, basek, k_apply, rows_apply, di, rank); + + auto_key_apply_prepared.prepare(module, &auto_key_apply, scratch.borrow()); + + // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) + auto_key_out.automorphism( + module, + &auto_key_in, + &auto_key_apply_prepared, + scratch.borrow(), + ); + + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_out); + + let mut sk_auto: GLWESecret> = GLWESecret::alloc(n, rank); + sk_auto.fill_zero(); // Necessary to avoid panic of unfilled sk + (0..rank).for_each(|i| { + module.vec_znx_automorphism( + module.galois_element_inv(p0 * p1), + &mut sk_auto.data.as_vec_znx_mut(), + i, + &sk.data.as_vec_znx(), + i, + ); + }); + + let sk_auto_dft: GLWESecretPrepared, B> = sk_auto.prepare_alloc(module, scratch.borrow()); + + (0..auto_key_out.rank_in()).for_each(|col_i| { + (0..auto_key_out.rows()).for_each(|row_i| { + auto_key_out + .at(row_i, col_i) + .decrypt(module, &mut pt, &sk_auto_dft, scratch.borrow()); + + module.vec_znx_sub_scalar_inplace( + &mut pt.data, + 0, + (digits_in - 1) + row_i * digits_in, + &sk.data, + col_i, + ); + + let noise_have: f64 = pt.data.std(basek, 0).log2(); + let noise_want: f64 = log2_std_noise_gglwe_product( + n as f64, + basek * di, + 0.5, + 0.5, + 0f64, + SIGMA * SIGMA, + 0f64, + rank as f64, + k_out, + k_apply, + ); + + assert!( + noise_have < noise_want + 0.5, + "{} {}", + noise_have, + noise_want + ); + }); + }); + }); + }); +} + +#[allow(clippy::too_many_arguments)] +pub fn test_gglwe_automorphism_key_automorphism_inplace(module: &Module) +where + Module: VecZnxDftAllocBytes + + VecZnxBigNormalize + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxNormalizeTmpBytes + + VecZnxFillUniform + + VecZnxSubABInplace + + VecZnxAddInplace + + VecZnxNormalizeInplace + + VecZnxAddNormal + + VecZnxNormalize + + VecZnxSub + + SvpPrepare + + SvpPPolAllocBytes + + SvpPPolAlloc + + VecZnxDftAllocBytes + + VmpApplyDftToDftTmpBytes + + VecZnxBigNormalizeTmpBytes + + VmpApplyDftToDftTmpBytes + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + VecZnxDftApply + + VecZnxIdftApplyConsume + + VecZnxBigAddSmallInplace + + VecZnxBigNormalize + + VecZnxAutomorphism + + VecZnxSwitchRing + + VecZnxAddScalarInplace + + VecZnxAutomorphism + + VecZnxAutomorphismInplace + + VecZnxDftAllocBytes + + VecZnxBigAllocBytes + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxBigAddInplace + + VecZnxBigAddSmallInplace + + VecZnxBigNormalize + + VecZnxNormalizeTmpBytes + + VecZnxSubScalarInplace + + VecZnxCopy + + VmpPMatAlloc + + VmpPrepare, + B: Backend + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + ScratchAvailableImpl + + TakeScalarZnxImpl + + TakeVecZnxDftImpl + + TakeVecZnxImpl + + TakeSvpPPolImpl + + TakeVecZnxBigImpl, +{ + let basek: usize = 12; + let k_in: usize = 60; + let digits: usize = k_in.div_ceil(basek); + let p0: i64 = -1; + let p1: i64 = -5; + (1..3).for_each(|rank| { + (1..digits + 1).for_each(|di| { + println!( + "test_gglwe_automorphism_key_automorphism_inplace: {} rank: {}", + di, rank + ); + let k_apply: usize = (digits + di) * basek; + + let n: usize = module.n(); + let digits_in: usize = 1; + + let rows_in: usize = k_in / (basek * di); + let rows_apply: usize = k_in.div_ceil(basek * di); + + let mut auto_key: GGLWEAutomorphismKey> = + GGLWEAutomorphismKey::alloc(n, basek, k_in, rows_in, digits_in, 
rank); + let mut auto_key_apply: GGLWEAutomorphismKey> = + GGLWEAutomorphismKey::alloc(n, basek, k_apply, rows_apply, di, rank); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, basek, k_apply, rank) + | GGLWEAutomorphismKey::automorphism_inplace_scratch_space(module, basek, k_in, k_apply, di, rank), + ); + + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + + // gglwe_{s1}(s0) = s0 -> s1 + auto_key.encrypt_sk( + module, + p0, + &sk, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + // gglwe_{s2}(s1) -> s1 -> s2 + auto_key_apply.encrypt_sk( + module, + p1, + &sk, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + let mut auto_key_apply_prepared: GGLWEAutomorphismKeyPrepared, B> = + GGLWEAutomorphismKeyPrepared::alloc(module, basek, k_apply, rows_apply, di, rank); + + auto_key_apply_prepared.prepare(module, &auto_key_apply, scratch.borrow()); + + // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) + auto_key.automorphism_inplace(module, &auto_key_apply_prepared, scratch.borrow()); + + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_in); + + let mut sk_auto: GLWESecret> = GLWESecret::alloc(n, rank); + sk_auto.fill_zero(); // Necessary to avoid panic of unfilled sk + + (0..rank).for_each(|i| { + module.vec_znx_automorphism( + module.galois_element_inv(p0 * p1), + &mut sk_auto.data.as_vec_znx_mut(), + i, + &sk.data.as_vec_znx(), + i, + ); + }); + + let sk_auto_dft: GLWESecretPrepared, B> = sk_auto.prepare_alloc(module, scratch.borrow()); + + (0..auto_key.rank_in()).for_each(|col_i| { + (0..auto_key.rows()).for_each(|row_i| { + auto_key + .at(row_i, col_i) + .decrypt(module, &mut pt, &sk_auto_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace( + &mut pt.data, + 0, + (digits_in - 1) + row_i * digits_in, + &sk.data, + col_i, + ); + + let noise_have: f64 = pt.data.std(basek, 0).log2(); + let noise_want: f64 = log2_std_noise_gglwe_product( + n as f64, + basek * di, + 0.5, + 0.5, + 0f64, + SIGMA * SIGMA, + 0f64, + rank as f64, + k_in, + k_apply, + ); + + assert!( + noise_have < noise_want + 0.5, + "{} {}", + noise_have, + noise_want + ); + }); + }); + }); + }); +} diff --git a/poulpy-core/src/tests/test_suite/automorphism/ggsw_ct.rs b/poulpy-core/src/tests/test_suite/automorphism/ggsw_ct.rs new file mode 100644 index 0000000..a342cd3 --- /dev/null +++ b/poulpy-core/src/tests/test_suite/automorphism/ggsw_ct.rs @@ -0,0 +1,330 @@ +use poulpy_hal::{ + api::{ + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, + SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace, + VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigNormalize, + VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, + VecZnxDftCopy, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeInplace, + VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, + VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, + }, + layouts::{Backend, Module, ScalarZnx, ScratchOwned}, + oep::{ + 
ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, + TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, VecZnxBigAllocBytesImpl, VecZnxDftAllocBytesImpl, + }, + source::Source, +}; + +use crate::{ + encryption::SIGMA, + layouts::{ + GGLWEAutomorphismKey, GGLWETensorKey, GGSWCiphertext, GLWESecret, + prepared::{GGLWEAutomorphismKeyPrepared, GGLWETensorKeyPrepared, GLWESecretPrepared, Prepare, PrepareAlloc}, + }, + noise::noise_ggsw_keyswitch, +}; + +#[allow(clippy::too_many_arguments)] +pub fn test_ggsw_automorphism(module: &Module) +where + Module: VecZnxDftAllocBytes + + VecZnxBigAllocBytes + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxBigAddInplace + + VecZnxBigAddSmallInplace + + VecZnxBigNormalize + + VecZnxNormalizeTmpBytes + + VecZnxBigAlloc + + VecZnxDftAlloc + + VecZnxBigNormalizeTmpBytes + + VecZnxIdftApplyTmpA + + SvpPrepare + + SvpPPolAllocBytes + + SvpPPolAlloc + + VecZnxAddScalarInplace + + VecZnxCopy + + VecZnxSubABInplace + + VmpPMatAlloc + + VmpPrepare + + VmpApplyDftToDftTmpBytes + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + VecZnxDftCopy + + VecZnxDftAddInplace + + VecZnxFillUniform + + VecZnxAddInplace + + VecZnxNormalizeInplace + + VecZnxAddNormal + + VecZnxNormalize + + VecZnxSub + + SvpApplyDftToDft + + VecZnxSwitchRing + + VecZnxAutomorphismInplace + + VecZnxAutomorphism, + B: Backend + + TakeVecZnxDftImpl + + TakeVecZnxBigImpl + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + ScratchAvailableImpl + + TakeScalarZnxImpl + + TakeVecZnxImpl + + VecZnxDftAllocBytesImpl + + VecZnxBigAllocBytesImpl + + TakeSvpPPolImpl, +{ + let basek: usize = 12; + let k_in: usize = 54; + let digits: usize = k_in.div_ceil(basek); + let p: i64 = -5; + (1..3).for_each(|rank| { + (1..digits + 1).for_each(|di| { + let k_ksk: usize = k_in + basek * di; + let k_tsk: usize = k_ksk; + let k_out: usize = k_ksk; // Better capture noise. 
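+ // Test outline (added note): encrypt a ternary ScalarZnx plaintext into a GGSW ciphertext under `sk`,
+ // apply the automorphism X^i -> X^(p*i) (p = -5) homomorphically via the prepared automorphism key and
+ // tensor key, apply the same automorphism to the plaintext in the clear, then check the per-column noise
+ // of `ct_out` against `noise_ggsw_keyswitch` with a 0.5-bit margin.
+ // `k_ksk = k_in + basek * di` gives the keys one extra digit of precision, and `k_out = k_ksk` keeps the
+ // output wide enough to expose that noise.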
+ + let n: usize = module.n(); + let rows: usize = k_in.div_ceil(basek * di); + let rows_in: usize = k_in.div_euclid(basek * di); + + let digits_in: usize = 1; + + let mut ct_in: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_in, rows_in, digits_in, rank); + let mut ct_out: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_out, rows_in, digits_in, rank); + let mut tensor_key: GGLWETensorKey> = GGLWETensorKey::alloc(n, basek, k_tsk, rows, di, rank); + let mut auto_key: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(n, basek, k_ksk, rows, di, rank); + let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_in, rank) + | GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, basek, k_ksk, rank) + | GGLWETensorKey::encrypt_sk_scratch_space(module, basek, k_tsk, rank) + | GGSWCiphertext::automorphism_scratch_space(module, basek, k_out, k_in, k_ksk, di, k_tsk, di, rank), + ); + + let var_xs: f64 = 0.5; + + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + sk.fill_ternary_prob(var_xs, &mut source_xs); + let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); + + auto_key.encrypt_sk( + module, + p, + &sk, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + tensor_key.encrypt_sk( + module, + &sk, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + pt_scalar.fill_ternary_hw(0, n, &mut source_xs); + + ct_in.encrypt_sk( + module, + &pt_scalar, + &sk_prepared, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + let mut auto_key_prepared: GGLWEAutomorphismKeyPrepared, B> = + GGLWEAutomorphismKeyPrepared::alloc(module, basek, k_ksk, rows, di, rank); + auto_key_prepared.prepare(module, &auto_key, scratch.borrow()); + + let mut tsk_prepared: GGLWETensorKeyPrepared, B> = + GGLWETensorKeyPrepared::alloc(module, basek, k_tsk, rows, di, rank); + tsk_prepared.prepare(module, &tensor_key, scratch.borrow()); + + ct_out.automorphism( + module, + &ct_in, + &auto_key_prepared, + &tsk_prepared, + scratch.borrow(), + ); + + module.vec_znx_automorphism_inplace(p, &mut pt_scalar.as_vec_znx_mut(), 0, scratch.borrow()); + + let max_noise = |col_j: usize| -> f64 { + noise_ggsw_keyswitch( + n as f64, + basek * di, + col_j, + var_xs, + 0f64, + SIGMA * SIGMA, + 0f64, + rank as f64, + k_in, + k_ksk, + k_tsk, + ) + 0.5 + }; + + ct_out.assert_noise(module, &sk_prepared, &pt_scalar, max_noise); + }); + }); +} + +#[allow(clippy::too_many_arguments)] +pub fn test_ggsw_automorphism_inplace(module: &Module) +where + Module: VecZnxDftAllocBytes + + VecZnxBigAllocBytes + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxBigAddInplace + + VecZnxBigAddSmallInplace + + VecZnxBigNormalize + + VecZnxNormalizeTmpBytes + + VecZnxBigAlloc + + VecZnxDftAlloc + + VecZnxBigNormalizeTmpBytes + + VecZnxIdftApplyTmpA + + SvpPrepare + + SvpPPolAllocBytes + + SvpPPolAlloc + + VecZnxAddScalarInplace + + VecZnxCopy + + VecZnxSubABInplace + + VmpPMatAlloc + + VmpPrepare + + VmpApplyDftToDftTmpBytes + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + VecZnxBigAddSmallInplace + + VecZnxDftCopy + + VecZnxDftAddInplace + + VecZnxAddInplace + + VecZnxNormalizeInplace + + VecZnxAddNormal + + VecZnxNormalize + + VecZnxSub + + VecZnxFillUniform + + SvpApplyDftToDft + + 
VecZnxSwitchRing + + VecZnxAutomorphismInplace + + VecZnxAutomorphism, + B: Backend + + TakeVecZnxDftImpl + + TakeVecZnxBigImpl + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + ScratchAvailableImpl + + TakeScalarZnxImpl + + TakeVecZnxImpl + + VecZnxDftAllocBytesImpl + + VecZnxBigAllocBytesImpl + + TakeSvpPPolImpl, +{ + let basek: usize = 12; + let k_ct: usize = 54; + let digits: usize = k_ct.div_ceil(basek); + let p = -1; + (1..3).for_each(|rank| { + (1..digits + 1).for_each(|di| { + let k_ksk: usize = k_ct + basek * di; + let k_tsk: usize = k_ksk; + + let n: usize = module.n(); + let rows: usize = k_ct.div_ceil(di * basek); + let rows_in: usize = k_ct.div_euclid(basek * di); + let digits_in: usize = 1; + + let mut ct: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ct, rows_in, digits_in, rank); + let mut tensor_key: GGLWETensorKey> = GGLWETensorKey::alloc(n, basek, k_tsk, rows, di, rank); + let mut auto_key: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(n, basek, k_ksk, rows, di, rank); + let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_ct, rank) + | GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, basek, k_ksk, rank) + | GGLWETensorKey::encrypt_sk_scratch_space(module, basek, k_tsk, rank) + | GGSWCiphertext::automorphism_inplace_scratch_space(module, basek, k_ct, k_ksk, di, k_tsk, di, rank), + ); + + let var_xs: f64 = 0.5; + + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + sk.fill_ternary_prob(var_xs, &mut source_xs); + let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); + + auto_key.encrypt_sk( + module, + p, + &sk, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + tensor_key.encrypt_sk( + module, + &sk, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + pt_scalar.fill_ternary_hw(0, n, &mut source_xs); + + ct.encrypt_sk( + module, + &pt_scalar, + &sk_prepared, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + let mut auto_key_prepared: GGLWEAutomorphismKeyPrepared, B> = + GGLWEAutomorphismKeyPrepared::alloc(module, basek, k_ksk, rows, di, rank); + auto_key_prepared.prepare(module, &auto_key, scratch.borrow()); + + let mut tsk_prepared: GGLWETensorKeyPrepared, B> = + GGLWETensorKeyPrepared::alloc(module, basek, k_tsk, rows, di, rank); + tsk_prepared.prepare(module, &tensor_key, scratch.borrow()); + + ct.automorphism_inplace(module, &auto_key_prepared, &tsk_prepared, scratch.borrow()); + + module.vec_znx_automorphism_inplace(p, &mut pt_scalar.as_vec_znx_mut(), 0, scratch.borrow()); + + let max_noise = |col_j: usize| -> f64 { + noise_ggsw_keyswitch( + n as f64, + basek * di, + col_j, + var_xs, + 0f64, + SIGMA * SIGMA, + 0f64, + rank as f64, + k_ct, + k_ksk, + k_tsk, + ) + 0.5 + }; + + ct.assert_noise(module, &sk_prepared, &pt_scalar, max_noise); + }); + }); +} diff --git a/poulpy-core/src/tests/test_suite/automorphism/glwe_ct.rs b/poulpy-core/src/tests/test_suite/automorphism/glwe_ct.rs new file mode 100644 index 0000000..7726fd5 --- /dev/null +++ b/poulpy-core/src/tests/test_suite/automorphism/glwe_ct.rs @@ -0,0 +1,275 @@ +use poulpy_hal::{ + api::{ + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, + VecZnxAddInplace, VecZnxAddNormal, 
VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace, + VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, + VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, + VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, + VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, + }, + layouts::{Backend, Module, ScratchOwned}, + oep::{ + ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, + TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, + }, + source::Source, +}; + +use crate::{ + encryption::SIGMA, + layouts::{ + GGLWEAutomorphismKey, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, + prepared::{GGLWEAutomorphismKeyPrepared, GLWESecretPrepared, Prepare, PrepareAlloc}, + }, + noise::log2_std_noise_gglwe_product, +}; + +#[allow(clippy::too_many_arguments)] +pub fn test_glwe_automorphism(module: &Module) +where + Module: VecZnxDftAllocBytes + + VecZnxBigNormalize + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxNormalizeTmpBytes + + VecZnxFillUniform + + VecZnxSubABInplace + + VecZnxAddInplace + + VecZnxNormalizeInplace + + VecZnxAddNormal + + VecZnxNormalize + + VecZnxSub + + SvpPrepare + + SvpPPolAllocBytes + + SvpPPolAlloc + + VecZnxBigAllocBytes + + VecZnxBigAddInplace + + VecZnxBigAddSmallInplace + + VmpApplyDftToDftTmpBytes + + VecZnxBigNormalizeTmpBytes + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + VecZnxAutomorphism + + VecZnxSwitchRing + + VecZnxAddScalarInplace + + VecZnxAutomorphismInplace + + VmpPMatAlloc + + VmpPrepare, + B: Backend + + TakeVecZnxDftImpl + + TakeVecZnxBigImpl + + TakeSvpPPolImpl + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + ScratchAvailableImpl + + TakeScalarZnxImpl + + TakeVecZnxImpl, +{ + let basek: usize = 12; + let k_in: usize = 60; + let digits: usize = k_in.div_ceil(basek); + let p: i64 = -5; + (1..4).for_each(|rank| { + (1..digits + 1).for_each(|di| { + let k_ksk: usize = k_in + basek * di; + let k_out: usize = k_ksk; // Better capture noise. 
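+ // Test outline (added note): a uniform plaintext is encrypted under `sk`, the automorphism
+ // X^i -> X^(p*i) (p = -5) is applied to the ciphertext through the prepared GGLWEAutomorphismKey, the
+ // same map is applied to `pt_want` in the clear with `vec_znx_automorphism_inplace`, and the decryption
+ // noise of `ct_out` is checked against `log2_std_noise_gglwe_product` plus a 1-bit margin.
+ // As in the GGSW test above, `k_out = k_ksk = k_in + basek * di` so the output precision is large enough
+ // to capture the key-switching noise.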
+ + let n: usize = module.n(); + let rows: usize = k_in.div_ceil(basek * digits); + + let mut autokey: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(n, basek, k_ksk, rows, digits, rank); + let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_in, rank); + let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_out, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_in); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, basek, autokey.k(), rank) + | GLWECiphertext::decrypt_scratch_space(module, basek, ct_out.k()) + | GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct_in.k()) + | GLWECiphertext::automorphism_scratch_space( + module, + basek, + ct_out.k(), + ct_in.k(), + autokey.k(), + digits, + rank, + ), + ); + + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); + + autokey.encrypt_sk( + module, + p, + &sk, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + ct_in.encrypt_sk( + module, + &pt_want, + &sk_prepared, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + let mut autokey_prepared: GGLWEAutomorphismKeyPrepared, B> = + GGLWEAutomorphismKeyPrepared::alloc(module, basek, k_ksk, rows, digits, rank); + autokey_prepared.prepare(module, &autokey, scratch.borrow()); + + ct_out.automorphism(module, &ct_in, &autokey_prepared, scratch.borrow()); + + let max_noise: f64 = log2_std_noise_gglwe_product( + module.n() as f64, + basek * digits, + 0.5, + 0.5, + 0f64, + SIGMA * SIGMA, + 0f64, + rank as f64, + k_in, + k_ksk, + ); + + module.vec_znx_automorphism_inplace(p, &mut pt_want.data, 0, scratch.borrow()); + + ct_out.assert_noise(module, &sk_prepared, &pt_want, max_noise + 1.0); + }) + }); +} + +#[allow(clippy::too_many_arguments)] +pub fn test_glwe_automorphism_inplace(module: &Module) +where + Module: VecZnxDftAllocBytes + + VecZnxBigNormalize + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxNormalizeTmpBytes + + VecZnxFillUniform + + VecZnxSubABInplace + + VecZnxAddInplace + + VecZnxNormalizeInplace + + VecZnxAddNormal + + VecZnxNormalize + + VecZnxSub + + SvpPrepare + + SvpPPolAllocBytes + + SvpPPolAlloc + + VecZnxBigAllocBytes + + VecZnxBigAddInplace + + VecZnxBigAddSmallInplace + + VmpApplyDftToDftTmpBytes + + VecZnxBigNormalizeTmpBytes + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + VecZnxAutomorphism + + VecZnxSwitchRing + + VecZnxAddScalarInplace + + VecZnxAutomorphismInplace + + VmpPMatAlloc + + VmpPrepare, + B: Backend + + TakeVecZnxDftImpl + + TakeVecZnxBigImpl + + TakeSvpPPolImpl + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + ScratchAvailableImpl + + TakeScalarZnxImpl + + TakeVecZnxImpl, +{ + let basek: usize = 12; + let k_ct: usize = 60; + let digits: usize = k_ct.div_ceil(basek); + let p = -5; + (1..4).for_each(|rank| { + (1..digits + 1).for_each(|di| { + let k_ksk: usize = k_ct + basek * di; + println!( + "test_glwe_automorphism_inplace digits: {} rank: {}", + di, rank + ); + + let n: usize = module.n(); + let rows: usize = k_ct.div_ceil(basek * digits); + + let mut autokey: GGLWEAutomorphismKey> = 
GGLWEAutomorphismKey::alloc(n, basek, k_ksk, rows, digits, rank); + let mut ct: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_ct); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, basek, autokey.k(), rank) + | GLWECiphertext::decrypt_scratch_space(module, basek, ct.k()) + | GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct.k()) + | GLWECiphertext::automorphism_inplace_scratch_space(module, basek, ct.k(), autokey.k(), digits, rank), + ); + + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); + + autokey.encrypt_sk( + module, + p, + &sk, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + ct.encrypt_sk( + module, + &pt_want, + &sk_prepared, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + let mut autokey_prepared: GGLWEAutomorphismKeyPrepared, B> = + GGLWEAutomorphismKeyPrepared::alloc(module, basek, k_ksk, rows, digits, rank); + autokey_prepared.prepare(module, &autokey, scratch.borrow()); + + ct.automorphism_inplace(module, &autokey_prepared, scratch.borrow()); + + let max_noise: f64 = log2_std_noise_gglwe_product( + module.n() as f64, + basek * digits, + 0.5, + 0.5, + 0f64, + SIGMA * SIGMA, + 0f64, + rank as f64, + k_ct, + k_ksk, + ); + + module.vec_znx_automorphism_inplace(p, &mut pt_want.data, 0, scratch.borrow()); + + ct.assert_noise(module, &sk_prepared, &pt_want, max_noise + 1.0); + }); + }); +} diff --git a/poulpy-core/src/tests/generics/automorphism/mod.rs b/poulpy-core/src/tests/test_suite/automorphism/mod.rs similarity index 100% rename from poulpy-core/src/tests/generics/automorphism/mod.rs rename to poulpy-core/src/tests/test_suite/automorphism/mod.rs diff --git a/poulpy-core/src/tests/generics/conversion.rs b/poulpy-core/src/tests/test_suite/conversion.rs similarity index 90% rename from poulpy-core/src/tests/generics/conversion.rs rename to poulpy-core/src/tests/test_suite/conversion.rs index 44c935c..5d4dcbb 100644 --- a/poulpy-core/src/tests/generics/conversion.rs +++ b/poulpy-core/src/tests/test_suite/conversion.rs @@ -1,11 +1,11 @@ use poulpy_hal::{ api::{ - DFT, IDFTConsume, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, - VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, - VecZnxSwithcDegree, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, - ZnAddNormal, ZnFillUniform, ZnNormalizeInplace, + VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, + VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, + 
VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, ZnAddNormal, ZnFillUniform, ZnNormalizeInplace, }, layouts::{Backend, Module, ScratchOwned, ZnxView}, oep::{ @@ -25,9 +25,9 @@ pub fn test_lwe_to_glwe(module: &Module) where Module: VecZnxDftAllocBytes + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + VecZnxFillUniform + VecZnxSubABInplace + VecZnxAddInplace @@ -49,8 +49,8 @@ where + VmpApplyDftToDft + VmpApplyDftToDftAdd + VecZnxBigNormalizeTmpBytes - + VecZnxSwithcDegree - + VecZnxAutomorphismInplace + + VecZnxSwitchRing + + VecZnxAutomorphismInplace + ZnNormalizeInplace + ZnFillUniform + ZnAddNormal, @@ -130,9 +130,9 @@ pub fn test_glwe_to_lwe(module: &Module) where Module: VecZnxDftAllocBytes + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + VecZnxFillUniform + VecZnxSubABInplace + VecZnxAddInplace @@ -154,8 +154,8 @@ where + VmpApplyDftToDft + VmpApplyDftToDftAdd + VecZnxBigNormalizeTmpBytes - + VecZnxSwithcDegree - + VecZnxAutomorphismInplace + + VecZnxSwitchRing + + VecZnxAutomorphismInplace + ZnNormalizeInplace, B: Backend + TakeVecZnxDftImpl diff --git a/poulpy-core/src/tests/test_suite/encryption/gglwe_atk.rs b/poulpy-core/src/tests/test_suite/encryption/gglwe_atk.rs new file mode 100644 index 0000000..c263faf --- /dev/null +++ b/poulpy-core/src/tests/test_suite/encryption/gglwe_atk.rs @@ -0,0 +1,211 @@ +use poulpy_hal::{ + api::{ + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, + VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace, + VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, + VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, + VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSubScalarInplace, VecZnxSwitchRing, + VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, + }, + layouts::{Backend, Module, ScratchOwned}, + oep::{ + ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, + TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, + }, + source::Source, +}; + +use crate::{ + encryption::SIGMA, + layouts::{ + GGLWEAutomorphismKey, GLWESecret, + compressed::{Decompress, GGLWEAutomorphismKeyCompressed}, + prepared::{GLWESecretPrepared, PrepareAlloc}, + }, +}; + +pub fn test_gglwe_automorphisk_key_encrypt_sk(module: &Module) +where + Module: VecZnxDftAllocBytes + + VecZnxBigNormalize + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxNormalizeTmpBytes + + VecZnxFillUniform + + VecZnxSubABInplace + + VecZnxAddInplace + + VecZnxNormalizeInplace + + VecZnxAddNormal + + VecZnxNormalize + + VecZnxSub + + SvpPrepare + + SvpPPolAllocBytes + + SvpPPolAlloc + + VmpApplyDftToDftTmpBytes + + VecZnxBigNormalizeTmpBytes + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + VecZnxBigAddSmallInplace + + VecZnxAutomorphism + + VecZnxSwitchRing + + VecZnxAddScalarInplace + + VecZnxAutomorphismInplace + + VecZnxBigAllocBytes + + VecZnxBigAddInplace + + VecZnxSubScalarInplace + + VecZnxCopy + + VmpPMatAlloc + + VmpPrepare, + B: Backend + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + ScratchAvailableImpl + + 
TakeScalarZnxImpl + + TakeVecZnxDftImpl + + TakeVecZnxImpl + + TakeSvpPPolImpl + + TakeVecZnxBigImpl, +{ + let basek: usize = 12; + let k_ksk: usize = 60; + let digits: usize = k_ksk.div_ceil(basek) - 1; + (1..3).for_each(|rank| { + (1..digits + 1).for_each(|di| { + let n: usize = module.n(); + let rows: usize = (k_ksk - di * basek) / (di * basek); + + let mut atk: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(n, basek, k_ksk, rows, di, rank); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWEAutomorphismKey::encrypt_sk_scratch_space( + module, basek, k_ksk, rank, + )); + + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let p = -5; + + atk.encrypt_sk( + module, + p, + &sk, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + let mut sk_out: GLWESecret> = sk.clone(); + (0..atk.rank()).for_each(|i| { + module.vec_znx_automorphism( + module.galois_element_inv(p), + &mut sk_out.data.as_vec_znx_mut(), + i, + &sk.data.as_vec_znx(), + i, + ); + }); + let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); + + atk.key + .key + .assert_noise(module, &sk_out_prepared, &sk.data, SIGMA); + }); + }); +} + +pub fn test_gglwe_automorphisk_key_compressed_encrypt_sk(module: &Module) +where + Module: VecZnxDftAllocBytes + + VecZnxBigNormalize + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxNormalizeTmpBytes + + VecZnxFillUniform + + VecZnxSubABInplace + + VecZnxAddInplace + + VecZnxNormalizeInplace + + VecZnxAddNormal + + VecZnxNormalize + + VecZnxSub + + SvpPrepare + + SvpPPolAllocBytes + + SvpPPolAlloc + + VmpApplyDftToDftTmpBytes + + VecZnxBigNormalizeTmpBytes + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + VecZnxBigAddSmallInplace + + VecZnxAutomorphism + + VecZnxSwitchRing + + VecZnxAddScalarInplace + + VecZnxAutomorphismInplace + + VecZnxBigAllocBytes + + VecZnxBigAddInplace + + VecZnxSubScalarInplace + + VecZnxCopy + + VmpPMatAlloc + + VmpPrepare, + B: Backend + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + ScratchAvailableImpl + + TakeScalarZnxImpl + + TakeVecZnxDftImpl + + TakeVecZnxImpl + + TakeSvpPPolImpl + + TakeVecZnxBigImpl, +{ + let basek: usize = 12; + let k_ksk: usize = 60; + let digits: usize = k_ksk.div_ceil(basek) - 1; + (1..3).for_each(|rank| { + (1..digits + 1).for_each(|di| { + let n: usize = module.n(); + let rows: usize = (k_ksk - di * basek) / (di * basek); + + let mut atk_compressed: GGLWEAutomorphismKeyCompressed> = + GGLWEAutomorphismKeyCompressed::alloc(n, basek, k_ksk, rows, di, rank); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWEAutomorphismKey::encrypt_sk_scratch_space( + module, basek, k_ksk, rank, + )); + + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let p = -5; + + let seed_xa: [u8; 32] = [1u8; 32]; + + atk_compressed.encrypt_sk(module, p, &sk, seed_xa, &mut source_xe, scratch.borrow()); + + let mut sk_out: GLWESecret> = sk.clone(); + (0..atk_compressed.rank()).for_each(|i| { + module.vec_znx_automorphism( + module.galois_element_inv(p), + &mut sk_out.data.as_vec_znx_mut(), + i, + &sk.data.as_vec_znx(), + i, + ); + }); + let sk_out_prepared = 
sk_out.prepare_alloc(module, scratch.borrow()); + + let mut atk: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(n, basek, k_ksk, rows, di, rank); + atk.decompress(module, &atk_compressed); + + atk.key + .key + .assert_noise(module, &sk_out_prepared, &sk.data, SIGMA); + }); + }); +} diff --git a/poulpy-core/src/tests/test_suite/encryption/gglwe_ct.rs b/poulpy-core/src/tests/test_suite/encryption/gglwe_ct.rs new file mode 100644 index 0000000..ca4fb02 --- /dev/null +++ b/poulpy-core/src/tests/test_suite/encryption/gglwe_ct.rs @@ -0,0 +1,191 @@ +use poulpy_hal::{ + api::{ + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, + VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, + VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, + VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, + VecZnxSubScalarInplace, VecZnxSwitchRing, VmpPMatAlloc, VmpPrepare, + }, + layouts::{Backend, Module, ScratchOwned}, + oep::{ + ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, + TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, VecZnxBigAllocBytesImpl, VecZnxDftAllocBytesImpl, + }, + source::Source, +}; + +use crate::{ + encryption::SIGMA, + layouts::{ + GGLWESwitchingKey, GLWESecret, + compressed::{Decompress, GGLWESwitchingKeyCompressed}, + prepared::{GLWESecretPrepared, PrepareAlloc}, + }, +}; + +pub fn test_gglwe_switching_key_encrypt_sk(module: &Module) +where + Module: VecZnxDftAllocBytes + + VecZnxBigNormalize + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxNormalizeTmpBytes + + VecZnxFillUniform + + VecZnxSubABInplace + + VecZnxAddInplace + + VecZnxNormalizeInplace + + VecZnxAddNormal + + VecZnxNormalize + + VecZnxSub + + SvpPrepare + + SvpPPolAllocBytes + + SvpPPolAlloc + + VecZnxBigAddSmallInplace + + VecZnxSwitchRing + + VecZnxAddScalarInplace + + VecZnxBigAllocBytes + + VecZnxBigAddInplace + + VecZnxSubScalarInplace + + VecZnxCopy + + VmpPMatAlloc + + VmpPrepare, + B: Backend + + TakeVecZnxDftImpl + + TakeVecZnxBigImpl + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + ScratchAvailableImpl + + TakeScalarZnxImpl + + TakeVecZnxImpl + + VecZnxDftAllocBytesImpl + + VecZnxBigAllocBytesImpl + + TakeSvpPPolImpl, +{ + let basek: usize = 12; + let k_ksk: usize = 54; + let digits: usize = k_ksk / basek; + (1..3).for_each(|rank_in| { + (1..3).for_each(|rank_out| { + (1..digits + 1).for_each(|di| { + let n: usize = module.n(); + let rows: usize = (k_ksk - di * basek) / (di * basek); + + let mut ksk: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(n, basek, k_ksk, rows, di, rank_in, rank_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWESwitchingKey::encrypt_sk_scratch_space( + module, basek, k_ksk, rank_in, rank_out, + )); + + let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank_in); + sk_in.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_out: GLWESecret> = GLWESecret::alloc(n, rank_out); + sk_out.fill_ternary_prob(0.5, &mut source_xs); + let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); + + ksk.encrypt_sk( + module, + &sk_in, + &sk_out, + &mut 
source_xa, + &mut source_xe, + scratch.borrow(), + ); + + ksk.key + .assert_noise(module, &sk_out_prepared, &sk_in.data, SIGMA); + }); + }); + }); +} + +pub fn test_gglwe_switching_key_compressed_encrypt_sk(module: &Module) +where + Module: VecZnxDftAllocBytes + + VecZnxBigNormalize + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxNormalizeTmpBytes + + VecZnxFillUniform + + VecZnxSubABInplace + + VecZnxAddInplace + + VecZnxNormalizeInplace + + VecZnxAddNormal + + VecZnxNormalize + + VecZnxSub + + SvpPrepare + + SvpPPolAllocBytes + + SvpPPolAlloc + + VecZnxBigAddSmallInplace + + VecZnxSwitchRing + + VecZnxAddScalarInplace + + VecZnxBigAllocBytes + + VecZnxBigAddInplace + + VecZnxSubScalarInplace + + VecZnxCopy + + VmpPMatAlloc + + VmpPrepare, + B: Backend + + TakeVecZnxDftImpl + + TakeVecZnxBigImpl + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + ScratchAvailableImpl + + TakeScalarZnxImpl + + TakeVecZnxImpl + + VecZnxDftAllocBytesImpl + + VecZnxBigAllocBytesImpl + + TakeSvpPPolImpl, +{ + let basek: usize = 12; + let k_ksk: usize = 54; + let digits: usize = k_ksk / basek; + (1..3).for_each(|rank_in| { + (1..3).for_each(|rank_out| { + (1..digits + 1).for_each(|di| { + let n: usize = module.n(); + let rows: usize = (k_ksk - di * basek) / (di * basek); + + let mut ksk_compressed: GGLWESwitchingKeyCompressed> = + GGLWESwitchingKeyCompressed::alloc(n, basek, k_ksk, rows, di, rank_in, rank_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWESwitchingKeyCompressed::encrypt_sk_scratch_space( + module, basek, k_ksk, rank_in, rank_out, + )); + + let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank_in); + sk_in.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_out: GLWESecret> = GLWESecret::alloc(n, rank_out); + sk_out.fill_ternary_prob(0.5, &mut source_xs); + let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); + + let seed_xa = [1u8; 32]; + + ksk_compressed.encrypt_sk( + module, + &sk_in, + &sk_out, + seed_xa, + &mut source_xe, + scratch.borrow(), + ); + + let mut ksk: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(n, basek, k_ksk, rows, di, rank_in, rank_out); + ksk.decompress(module, &ksk_compressed); + + ksk.key + .assert_noise(module, &sk_out_prepared, &sk_in.data, SIGMA); + }); + }); + }); +} diff --git a/poulpy-core/src/tests/test_suite/encryption/ggsw_ct.rs b/poulpy-core/src/tests/test_suite/encryption/ggsw_ct.rs new file mode 100644 index 0000000..1739ffb --- /dev/null +++ b/poulpy-core/src/tests/test_suite/encryption/ggsw_ct.rs @@ -0,0 +1,196 @@ +use poulpy_hal::{ + api::{ + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, + VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, + VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAlloc, VecZnxDftAllocBytes, + VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeInplace, + VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VmpPMatAlloc, VmpPrepare, + }, + layouts::{Backend, Module, ScalarZnx, ScratchOwned}, + oep::{ + ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, + TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, VecZnxBigAllocBytesImpl, 
VecZnxDftAllocBytesImpl, + }, + source::Source, +}; + +use crate::{ + encryption::SIGMA, + layouts::{ + GGSWCiphertext, GLWESecret, + compressed::{Decompress, GGSWCiphertextCompressed}, + prepared::{GLWESecretPrepared, PrepareAlloc}, + }, +}; + +pub fn test_ggsw_encrypt_sk(module: &Module) +where + Module: VecZnxDftAllocBytes + + VecZnxBigNormalize + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxNormalizeTmpBytes + + VecZnxFillUniform + + VecZnxSubABInplace + + VecZnxAddInplace + + VecZnxNormalizeInplace + + VecZnxAddNormal + + VecZnxNormalize + + VecZnxSub + + SvpPrepare + + SvpPPolAllocBytes + + SvpPPolAlloc + + VecZnxBigAddSmallInplace + + VecZnxAddScalarInplace + + VecZnxBigAllocBytes + + VecZnxBigAddInplace + + VecZnxCopy + + VmpPMatAlloc + + VmpPrepare + + VecZnxBigAlloc + + VecZnxDftAlloc + + VecZnxBigNormalizeTmpBytes + + VecZnxIdftApplyTmpA, + B: Backend + + TakeVecZnxDftImpl + + TakeVecZnxBigImpl + + TakeSvpPPolImpl + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + ScratchAvailableImpl + + TakeScalarZnxImpl + + TakeVecZnxImpl + + VecZnxDftAllocBytesImpl + + VecZnxBigAllocBytesImpl + + TakeSvpPPolImpl, +{ + let basek: usize = 12; + let k: usize = 54; + let digits: usize = k / basek; + (1..3).for_each(|rank| { + (1..digits + 1).for_each(|di| { + let n: usize = module.n(); + let rows: usize = (k - di * basek) / (di * basek); + + let mut ct: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k, rows, di, rank); + + let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + pt_scalar.fill_ternary_hw(0, n, &mut source_xs); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(GGSWCiphertext::encrypt_sk_scratch_space( + module, basek, k, rank, + )); + + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); + + ct.encrypt_sk( + module, + &pt_scalar, + &sk_prepared, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + let noise_f = |_col_i: usize| -(k as f64) + SIGMA.log2() + 0.5; + + ct.assert_noise(module, &sk_prepared, &pt_scalar, noise_f); + }); + }); +} + +pub fn test_ggsw_compressed_encrypt_sk(module: &Module) +where + Module: VecZnxDftAllocBytes + + VecZnxBigNormalize + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxNormalizeTmpBytes + + VecZnxFillUniform + + VecZnxSubABInplace + + VecZnxAddInplace + + VecZnxNormalizeInplace + + VecZnxAddNormal + + VecZnxNormalize + + VecZnxSub + + SvpPrepare + + SvpPPolAllocBytes + + SvpPPolAlloc + + VecZnxBigAddSmallInplace + + VecZnxAddScalarInplace + + VecZnxBigAllocBytes + + VecZnxBigAddInplace + + VecZnxCopy + + VmpPMatAlloc + + VmpPrepare + + VecZnxBigAlloc + + VecZnxDftAlloc + + VecZnxBigNormalizeTmpBytes + + VecZnxIdftApplyTmpA, + B: Backend + + TakeVecZnxDftImpl + + TakeVecZnxBigImpl + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + ScratchAvailableImpl + + TakeScalarZnxImpl + + TakeVecZnxImpl + + VecZnxDftAllocBytesImpl + + VecZnxBigAllocBytesImpl + + TakeSvpPPolImpl, +{ + let basek: usize = 12; + let k: usize = 54; + let digits: usize = k / basek; + (1..3).for_each(|rank| { + (1..digits + 1).for_each(|di| { + let n: usize = module.n(); + let rows: usize = (k - di * basek) / (di * basek); + + let mut ct_compressed: GGSWCiphertextCompressed> = + 
GGSWCiphertextCompressed::alloc(n, basek, k, rows, di, rank); + + let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + + pt_scalar.fill_ternary_hw(0, n, &mut source_xs); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(GGSWCiphertextCompressed::encrypt_sk_scratch_space( + module, basek, k, rank, + )); + + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); + + let seed_xa: [u8; 32] = [1u8; 32]; + + ct_compressed.encrypt_sk( + module, + &pt_scalar, + &sk_prepared, + seed_xa, + &mut source_xe, + scratch.borrow(), + ); + + let noise_f = |_col_i: usize| -(k as f64) + SIGMA.log2() + 0.5; + + let mut ct: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k, rows, di, rank); + ct.decompress(module, &ct_compressed); + + ct.assert_noise(module, &sk_prepared, &pt_scalar, noise_f); + }); + }); +} diff --git a/poulpy-core/src/tests/test_suite/encryption/glwe_ct.rs b/poulpy-core/src/tests/test_suite/encryption/glwe_ct.rs new file mode 100644 index 0000000..9f77baa --- /dev/null +++ b/poulpy-core/src/tests/test_suite/encryption/glwe_ct.rs @@ -0,0 +1,395 @@ +use poulpy_hal::{ + api::{ + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, + SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxBigAddInplace, VecZnxBigAddNormal, VecZnxBigAddSmallInplace, + VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxCopy, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, + VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, + VecZnxSubABInplace, + }, + layouts::{Backend, Module, ScratchOwned}, + oep::{ + ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, + TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, + }, + source::Source, +}; + +use crate::{ + encryption::SIGMA, + layouts::{ + GLWECiphertext, GLWEPlaintext, GLWEPublicKey, GLWESecret, Infos, + compressed::{Decompress, GLWECiphertextCompressed}, + prepared::{GLWEPublicKeyPrepared, GLWESecretPrepared, PrepareAlloc}, + }, + operations::GLWEOperations, +}; + +pub fn test_glwe_encrypt_sk(module: &Module) +where + Module: VecZnxDftAllocBytes + + VecZnxBigAllocBytes + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxBigAddInplace + + VecZnxBigAddSmallInplace + + VecZnxBigNormalize + + VecZnxNormalizeTmpBytes + + SvpPrepare + + SvpPPolAllocBytes + + SvpPPolAlloc + + VecZnxDftAllocBytes + + VecZnxBigAllocBytes + + SvpPPolAllocBytes + + SvpPrepare + + SvpApplyDftToDft + + VecZnxIdftApplyConsume + + VecZnxBigAddNormal + + VecZnxBigAddSmallInplace + + VecZnxBigNormalize + + VecZnxNormalizeTmpBytes + + VecZnxDftAllocBytes + + VecZnxBigNormalize + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxNormalizeTmpBytes + + VecZnxFillUniform + + VecZnxSubABInplace + + VecZnxAddInplace + + VecZnxNormalizeInplace + + VecZnxAddNormal + + VecZnxNormalize + + VecZnxSub, + B: Backend + + TakeVecZnxDftImpl + + TakeVecZnxBigImpl + + TakeSvpPPolImpl + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + ScratchAvailableImpl + + TakeScalarZnxImpl + + TakeVecZnxImpl, +{ + let basek: usize = 8; + let k_ct: usize = 54; + let k_pt: usize = 30; + + for rank in 1..3 { + let n = module.n(); + let 
mut ct: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_pt); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_pt); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct.k()) + | GLWECiphertext::decrypt_scratch_space(module, basek, ct.k()), + ); + + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); + + module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, &mut source_xa); + + ct.encrypt_sk( + module, + &pt_want, + &sk_prepared, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + ct.decrypt(module, &mut pt_have, &sk_prepared, scratch.borrow()); + + pt_want.sub_inplace_ab(module, &pt_have); + + let noise_have: f64 = pt_want.data.std(basek, 0) * (ct.k() as f64).exp2(); + let noise_want: f64 = SIGMA; + + assert!(noise_have <= noise_want + 0.2); + } +} + +pub fn test_glwe_compressed_encrypt_sk(module: &Module) +where + Module: VecZnxDftAllocBytes + + VecZnxBigAllocBytes + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxBigAddInplace + + VecZnxBigAddSmallInplace + + VecZnxBigNormalize + + VecZnxNormalizeTmpBytes + + SvpPrepare + + SvpPPolAllocBytes + + SvpPPolAlloc + + VecZnxDftAllocBytes + + VecZnxBigAllocBytes + + SvpPPolAllocBytes + + SvpPrepare + + SvpApplyDftToDft + + VecZnxIdftApplyConsume + + VecZnxBigAddNormal + + VecZnxBigAddSmallInplace + + VecZnxBigNormalize + + VecZnxNormalizeTmpBytes + + VecZnxDftAllocBytes + + VecZnxBigNormalize + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxNormalizeTmpBytes + + VecZnxFillUniform + + VecZnxSubABInplace + + VecZnxAddInplace + + VecZnxNormalizeInplace + + VecZnxAddNormal + + VecZnxNormalize + + VecZnxSub + + VecZnxCopy, + B: Backend + + TakeVecZnxDftImpl + + TakeVecZnxBigImpl + + TakeSvpPPolImpl + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + ScratchAvailableImpl + + TakeScalarZnxImpl + + TakeVecZnxImpl, +{ + let basek: usize = 8; + let k_ct: usize = 54; + let k_pt: usize = 30; + + for rank in 1..3 { + let n = module.n(); + let mut ct_compressed: GLWECiphertextCompressed> = GLWECiphertextCompressed::alloc(n, basek, k_ct, rank); + + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_pt); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_pt); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GLWECiphertextCompressed::encrypt_sk_scratch_space(module, basek, k_ct) + | GLWECiphertext::decrypt_scratch_space(module, basek, k_ct), + ); + + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); + + module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, &mut source_xa); + + let seed_xa: [u8; 32] = [1u8; 32]; + + ct_compressed.encrypt_sk( + module, + &pt_want, + &sk_prepared, + seed_xa, + &mut source_xe, + scratch.borrow(), + ); + + let mut ct: GLWECiphertext> = 
GLWECiphertext::alloc(n, basek, k_ct, rank); + ct.decompress(module, &ct_compressed); + + ct.decrypt(module, &mut pt_have, &sk_prepared, scratch.borrow()); + + pt_want.sub_inplace_ab(module, &pt_have); + + let noise_have: f64 = pt_want.data.std(basek, 0) * (ct.k() as f64).exp2(); + let noise_want: f64 = SIGMA; + + assert!( + noise_have <= noise_want + 0.2, + "{} <= {}", + noise_have, + noise_want + 0.2 + ); + } +} + +pub fn test_glwe_encrypt_zero_sk(module: &Module) +where + Module: VecZnxDftAllocBytes + + VecZnxBigAllocBytes + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxBigAddInplace + + VecZnxBigAddSmallInplace + + VecZnxBigNormalize + + VecZnxNormalizeTmpBytes + + SvpPrepare + + SvpPPolAllocBytes + + SvpPPolAlloc + + VecZnxDftAllocBytes + + VecZnxBigAllocBytes + + SvpPPolAllocBytes + + SvpPrepare + + SvpApplyDftToDft + + VecZnxIdftApplyConsume + + VecZnxBigAddNormal + + VecZnxBigAddSmallInplace + + VecZnxBigNormalize + + VecZnxNormalizeTmpBytes + + VecZnxDftAllocBytes + + VecZnxBigNormalize + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxNormalizeTmpBytes + + VecZnxFillUniform + + VecZnxSubABInplace + + VecZnxAddInplace + + VecZnxNormalizeInplace + + VecZnxAddNormal + + VecZnxNormalize + + VecZnxSub, + B: Backend + + TakeVecZnxDftImpl + + TakeVecZnxBigImpl + + TakeSvpPPolImpl + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + ScratchAvailableImpl + + TakeScalarZnxImpl + + TakeVecZnxImpl, +{ + let basek: usize = 8; + let k_ct: usize = 54; + + for rank in 1..3 { + let n = module.n(); + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_ct); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([1u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GLWECiphertext::decrypt_scratch_space(module, basek, k_ct) + | GLWECiphertext::encrypt_sk_scratch_space(module, basek, k_ct), + ); + + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); + + let mut ct: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct, rank); + + ct.encrypt_zero_sk( + module, + &sk_prepared, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + ct.decrypt(module, &mut pt, &sk_prepared, scratch.borrow()); + + assert!((SIGMA - pt.data.std(basek, 0) * (k_ct as f64).exp2()) <= 0.2); + } +} + +pub fn test_glwe_encrypt_pk(module: &Module) +where + Module: VecZnxDftAllocBytes + + VecZnxBigNormalize + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxNormalizeTmpBytes + + VecZnxFillUniform + + VecZnxSubABInplace + + VecZnxAddInplace + + VecZnxNormalizeInplace + + VecZnxAddNormal + + VecZnxNormalize + + VecZnxSub + + SvpPrepare + + SvpPPolAllocBytes + + SvpPPolAlloc + + VecZnxBigAddSmallInplace + + VecZnxBigAllocBytes + + VecZnxBigAddInplace + + VecZnxCopy + + VecZnxDftAlloc + + SvpApplyDftToDft + + VecZnxBigAddNormal, + B: Backend + + TakeVecZnxDftImpl + + TakeVecZnxBigImpl + + TakeSvpPPolImpl + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + ScratchAvailableImpl + + TakeScalarZnxImpl + + TakeVecZnxImpl, +{ + let basek: usize = 8; + let k_ct: usize = 54; + let k_pk: usize = 54; + + for rank in 1..3 { + let n: usize = module.n(); + let mut ct: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct, rank); + let mut pt_have: GLWEPlaintext> = 
GLWEPlaintext::alloc(n, basek, k_ct); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_ct); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + let mut source_xu: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct.k()) + | GLWECiphertext::decrypt_scratch_space(module, basek, ct.k()) + | GLWECiphertext::encrypt_pk_scratch_space(module, basek, k_pk), + ); + + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); + + let mut pk: GLWEPublicKey> = GLWEPublicKey::alloc(n, basek, k_pk, rank); + pk.generate_from_sk(module, &sk_prepared, &mut source_xa, &mut source_xe); + + module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, &mut source_xa); + + let pk_prepared: GLWEPublicKeyPrepared, B> = pk.prepare_alloc(module, scratch.borrow()); + + ct.encrypt_pk( + module, + &pt_want, + &pk_prepared, + &mut source_xu, + &mut source_xe, + scratch.borrow(), + ); + + ct.decrypt(module, &mut pt_have, &sk_prepared, scratch.borrow()); + + pt_want.sub_inplace_ab(module, &pt_have); + + let noise_have: f64 = pt_want.data.std(basek, 0).log2(); + let noise_want: f64 = ((((rank as f64) + 1.0) * n as f64 * 0.5 * SIGMA * SIGMA).sqrt()).log2() - (k_ct as f64); + + assert!( + noise_have <= noise_want + 0.2, + "{} {}", + noise_have, + noise_want + ); + } +} diff --git a/poulpy-core/src/tests/test_suite/encryption/glwe_tsk.rs b/poulpy-core/src/tests/test_suite/encryption/glwe_tsk.rs new file mode 100644 index 0000000..d653b17 --- /dev/null +++ b/poulpy-core/src/tests/test_suite/encryption/glwe_tsk.rs @@ -0,0 +1,250 @@ +use poulpy_hal::{ + api::{ + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, + SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, + VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxCopy, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, + VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeInplace, + VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSubScalarInplace, VecZnxSwitchRing, + }, + layouts::{Backend, Module, ScratchOwned, VecZnxDft}, + oep::{ + ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, + TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, VecZnxBigAllocBytesImpl, VecZnxDftAllocBytesImpl, + }, + source::Source, +}; + +use crate::{ + encryption::SIGMA, + layouts::{ + GGLWETensorKey, GLWEPlaintext, GLWESecret, Infos, + compressed::{Decompress, GGLWETensorKeyCompressed}, + prepared::{GLWESecretPrepared, PrepareAlloc}, + }, +}; + +pub fn test_gglwe_tensor_key_encrypt_sk(module: &Module) +where + Module: VecZnxDftAllocBytes + + VecZnxBigNormalize + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxNormalizeTmpBytes + + VecZnxFillUniform + + VecZnxSubABInplace + + VecZnxAddInplace + + VecZnxNormalizeInplace + + VecZnxAddNormal + + VecZnxNormalize + + VecZnxSub + + SvpPrepare + + SvpPPolAllocBytes + + SvpPPolAlloc + + VecZnxBigAddSmallInplace + + VecZnxBigAllocBytes + + VecZnxBigAddInplace + + VecZnxCopy + + VecZnxDftAlloc + + SvpApplyDftToDft + + 
VecZnxBigAlloc + + VecZnxIdftApplyTmpA + + VecZnxAddScalarInplace + + VecZnxSwitchRing + + VecZnxSubScalarInplace, + B: Backend + + TakeVecZnxDftImpl + + TakeVecZnxBigImpl + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + ScratchAvailableImpl + + TakeScalarZnxImpl + + TakeVecZnxImpl + + VecZnxDftAllocBytesImpl + + VecZnxBigAllocBytesImpl + + TakeSvpPPolImpl, +{ + let basek: usize = 8; + let k: usize = 54; + + (1..3).for_each(|rank| { + let n: usize = module.n(); + let rows: usize = k / basek; + + let mut tensor_key: GGLWETensorKey> = GGLWETensorKey::alloc(n, basek, k, rows, 1, rank); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWETensorKey::encrypt_sk_scratch_space( + module, + basek, + tensor_key.k(), + rank, + )); + + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); + + tensor_key.encrypt_sk( + module, + &sk, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k); + + let mut sk_ij_dft = module.vec_znx_dft_alloc(1, 1); + let mut sk_ij_big = module.vec_znx_big_alloc(1, 1); + let mut sk_ij: GLWESecret> = GLWESecret::alloc(n, 1); + let mut sk_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(rank, 1); + + (0..rank).for_each(|i| { + module.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i); + }); + + (0..rank).for_each(|i| { + (0..rank).for_each(|j| { + module.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_prepared.data, j, &sk_dft, i); + module.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0); + module.vec_znx_big_normalize( + basek, + &mut sk_ij.data.as_vec_znx_mut(), + 0, + &sk_ij_big, + 0, + scratch.borrow(), + ); + (0..tensor_key.rank_in()).for_each(|col_i| { + (0..tensor_key.rows()).for_each(|row_i| { + tensor_key + .at(i, j) + .at(row_i, col_i) + .decrypt(module, &mut pt, &sk_prepared, scratch.borrow()); + + module.vec_znx_sub_scalar_inplace(&mut pt.data, 0, row_i, &sk_ij.data, col_i); + + let std_pt: f64 = pt.data.std(basek, 0) * (k as f64).exp2(); + assert!((SIGMA - std_pt).abs() <= 0.5, "{} {}", SIGMA, std_pt); + }); + }); + }); + }); + }); +} + +pub fn test_gglwe_tensor_key_compressed_encrypt_sk(module: &Module) +where + Module: VecZnxDftAllocBytes + + VecZnxBigNormalize + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxNormalizeTmpBytes + + VecZnxFillUniform + + VecZnxSubABInplace + + VecZnxAddInplace + + VecZnxNormalizeInplace + + VecZnxAddNormal + + VecZnxNormalize + + VecZnxSub + + SvpPrepare + + SvpPPolAllocBytes + + SvpPPolAlloc + + VecZnxBigAddSmallInplace + + VecZnxBigAllocBytes + + VecZnxBigAddInplace + + VecZnxCopy + + VecZnxDftAlloc + + SvpApplyDftToDft + + VecZnxBigAlloc + + VecZnxIdftApplyTmpA + + VecZnxAddScalarInplace + + VecZnxSwitchRing + + VecZnxSubScalarInplace, + B: Backend + + TakeVecZnxDftImpl + + TakeVecZnxBigImpl + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + ScratchAvailableImpl + + TakeScalarZnxImpl + + TakeVecZnxImpl + + VecZnxDftAllocBytesImpl + + VecZnxBigAllocBytesImpl + + TakeSvpPPolImpl, +{ + let basek = 8; + let k = 54; + (1..3).for_each(|rank| { + let n: usize = module.n(); + let rows: usize = k / basek; + + let mut tensor_key_compressed: GGLWETensorKeyCompressed> = + 
GGLWETensorKeyCompressed::alloc(n, basek, k, rows, 1, rank); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWETensorKeyCompressed::encrypt_sk_scratch_space( + module, + basek, + tensor_key_compressed.k(), + rank, + )); + + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); + + let seed_xa: [u8; 32] = [1u8; 32]; + + tensor_key_compressed.encrypt_sk(module, &sk, seed_xa, &mut source_xe, scratch.borrow()); + + let mut tensor_key: GGLWETensorKey> = GGLWETensorKey::alloc(n, basek, k, rows, 1, rank); + tensor_key.decompress(module, &tensor_key_compressed); + + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k); + + let mut sk_ij_dft = module.vec_znx_dft_alloc(1, 1); + let mut sk_ij_big = module.vec_znx_big_alloc(1, 1); + let mut sk_ij: GLWESecret> = GLWESecret::alloc(n, 1); + let mut sk_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(rank, 1); + + (0..rank).for_each(|i| { + module.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i); + }); + + (0..rank).for_each(|i| { + (0..rank).for_each(|j| { + module.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_prepared.data, j, &sk_dft, i); + module.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0); + module.vec_znx_big_normalize( + basek, + &mut sk_ij.data.as_vec_znx_mut(), + 0, + &sk_ij_big, + 0, + scratch.borrow(), + ); + (0..tensor_key.rank_in()).for_each(|col_i| { + (0..tensor_key.rows()).for_each(|row_i| { + tensor_key + .at(i, j) + .at(row_i, col_i) + .decrypt(module, &mut pt, &sk_prepared, scratch.borrow()); + + module.vec_znx_sub_scalar_inplace(&mut pt.data, 0, row_i, &sk_ij.data, col_i); + + let std_pt: f64 = pt.data.std(basek, 0) * (k as f64).exp2(); + assert!((SIGMA - std_pt).abs() <= 0.5, "{} {}", SIGMA, std_pt); + }); + }); + }); + }); + }); +} diff --git a/poulpy-core/src/tests/generics/encryption/mod.rs b/poulpy-core/src/tests/test_suite/encryption/mod.rs similarity index 100% rename from poulpy-core/src/tests/generics/encryption/mod.rs rename to poulpy-core/src/tests/test_suite/encryption/mod.rs diff --git a/poulpy-core/src/tests/test_suite/external_product/gglwe_ksk.rs b/poulpy-core/src/tests/test_suite/external_product/gglwe_ksk.rs new file mode 100644 index 0000000..86f5c28 --- /dev/null +++ b/poulpy-core/src/tests/test_suite/external_product/gglwe_ksk.rs @@ -0,0 +1,323 @@ +use poulpy_hal::{ + api::{ + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, + VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, + VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, + VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotateInplace, VecZnxSub, + VecZnxSubABInplace, VecZnxSubScalarInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, + VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, + }, + layouts::{Backend, Module, ScalarZnx, ScalarZnxToMut, ScratchOwned, ZnxViewMut}, + oep::{ + ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, + TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, VecZnxBigAllocBytesImpl, VecZnxDftAllocBytesImpl, + }, + source::Source, +}; + +use crate::{ + 
encryption::SIGMA, + layouts::{ + GGLWESwitchingKey, GGSWCiphertext, GLWESecret, + prepared::{GGSWCiphertextPrepared, GLWESecretPrepared, PrepareAlloc}, + }, + noise::noise_ggsw_product, +}; + +#[allow(clippy::too_many_arguments)] +pub fn test_gglwe_switching_key_external_product(module: &Module) +where + Module: VecZnxDftAllocBytes + + VecZnxBigNormalize + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxNormalizeTmpBytes + + VecZnxFillUniform + + VecZnxSubABInplace + + VecZnxAddInplace + + VecZnxNormalizeInplace + + VecZnxAddNormal + + VecZnxNormalize + + VecZnxSub + + SvpPrepare + + SvpPPolAllocBytes + + SvpPPolAlloc + + VecZnxBigAllocBytes + + VecZnxBigAddInplace + + VecZnxBigAddSmallInplace + + VecZnxSwitchRing + + VecZnxAddScalarInplace + + VecZnxSubScalarInplace + + VecZnxCopy + + VmpPMatAlloc + + VecZnxRotateInplace + + VmpApplyDftToDftTmpBytes + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + VmpPrepare, + B: Backend + + TakeVecZnxDftImpl + + TakeVecZnxBigImpl + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + ScratchAvailableImpl + + TakeScalarZnxImpl + + TakeVecZnxImpl + + VecZnxDftAllocBytesImpl + + VecZnxBigAllocBytesImpl + + TakeSvpPPolImpl, +{ + let basek: usize = 12; + let k_in: usize = 60; + let digits: usize = k_in.div_ceil(basek); + (1..3).for_each(|rank_in| { + (1..3).for_each(|rank_out| { + (1..digits + 1).for_each(|di| { + let k_ggsw: usize = k_in + basek * di; + let k_out: usize = k_in; // Better capture noise. + + let n: usize = module.n(); + let rows: usize = k_in.div_ceil(basek * di); + let digits_in: usize = 1; + + let mut ct_gglwe_in: GGLWESwitchingKey> = + GGLWESwitchingKey::alloc(n, basek, k_in, rows, digits_in, rank_in, rank_out); + let mut ct_gglwe_out: GGLWESwitchingKey> = + GGLWESwitchingKey::alloc(n, basek, k_out, rows, digits_in, rank_in, rank_out); + let mut ct_rgsw: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ggsw, rows, di, rank_out); + + let mut pt_rgsw: ScalarZnx> = ScalarZnx::alloc(n, 1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k_in, rank_in, rank_out) + | GGLWESwitchingKey::external_product_scratch_space(module, basek, k_out, k_in, k_ggsw, di, rank_out) + | GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_ggsw, rank_out), + ); + + let r: usize = 1; + + pt_rgsw.to_mut().raw_mut()[r] = 1; // X^{r} + + let var_xs: f64 = 0.5; + + let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank_in); + sk_in.fill_ternary_prob(var_xs, &mut source_xs); + + let mut sk_out: GLWESecret> = GLWESecret::alloc(n, rank_out); + sk_out.fill_ternary_prob(var_xs, &mut source_xs); + let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); + + // gglwe_{s1}(s0) = s0 -> s1 + ct_gglwe_in.encrypt_sk( + module, + &sk_in, + &sk_out, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + ct_rgsw.encrypt_sk( + module, + &pt_rgsw, + &sk_out_prepared, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + let ct_rgsw_prepared: GGSWCiphertextPrepared, B> = ct_rgsw.prepare_alloc(module, scratch.borrow()); + + // gglwe_(m) (x) RGSW_(X^k) = gglwe_(m * X^k) + ct_gglwe_out.external_product(module, &ct_gglwe_in, &ct_rgsw_prepared, scratch.borrow()); + + (0..rank_in).for_each(|i| { + module.vec_znx_rotate_inplace( + r as i64, + &mut 
sk_in.data.as_vec_znx_mut(), + i, + scratch.borrow(), + ); // * X^{r} + }); + + let var_gct_err_lhs: f64 = SIGMA * SIGMA; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / n as f64; // X^{k} + let var_a0_err: f64 = SIGMA * SIGMA; + let var_a1_err: f64 = 1f64 / 12f64; + + let max_noise: f64 = noise_ggsw_product( + n as f64, + basek * di, + var_xs, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + rank_out as f64, + k_in, + k_ggsw, + ); + + ct_gglwe_out + .key + .assert_noise(module, &sk_out_prepared, &sk_in.data, max_noise + 0.5); + }); + }); + }); +} + +#[allow(clippy::too_many_arguments)] +pub fn test_gglwe_switching_key_external_product_inplace(module: &Module) +where + Module: VecZnxDftAllocBytes + + VecZnxBigNormalize + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxNormalizeTmpBytes + + VecZnxFillUniform + + VecZnxSubABInplace + + VecZnxAddInplace + + VecZnxNormalizeInplace + + VecZnxAddNormal + + VecZnxNormalize + + VecZnxSub + + SvpPrepare + + SvpPPolAllocBytes + + SvpPPolAlloc + + VecZnxBigAllocBytes + + VecZnxBigAddInplace + + VecZnxBigAddSmallInplace + + VecZnxSwitchRing + + VecZnxAddScalarInplace + + VecZnxSubScalarInplace + + VecZnxCopy + + VmpPMatAlloc + + VecZnxRotateInplace + + VmpApplyDftToDftTmpBytes + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + VmpPrepare, + B: Backend + + TakeVecZnxDftImpl + + TakeVecZnxBigImpl + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + ScratchAvailableImpl + + TakeScalarZnxImpl + + TakeVecZnxImpl + + VecZnxDftAllocBytesImpl + + VecZnxBigAllocBytesImpl + + TakeSvpPPolImpl, +{ + let basek: usize = 12; + let k_ct: usize = 60; + let digits: usize = k_ct.div_ceil(basek); + (1..3).for_each(|rank_in| { + (1..3).for_each(|rank_out| { + (1..digits).for_each(|di| { + let k_ggsw: usize = k_ct + basek * di; + + let n: usize = module.n(); + let rows: usize = k_ct.div_ceil(basek * di); + + let digits_in: usize = 1; + + let mut ct_gglwe: GGLWESwitchingKey> = + GGLWESwitchingKey::alloc(n, basek, k_ct, rows, digits_in, rank_in, rank_out); + let mut ct_rgsw: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ggsw, rows, di, rank_out); + + let mut pt_rgsw: ScalarZnx> = ScalarZnx::alloc(n, 1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k_ct, rank_in, rank_out) + | GGLWESwitchingKey::external_product_inplace_scratch_space(module, basek, k_ct, k_ggsw, di, rank_out) + | GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_ggsw, rank_out), + ); + + let r: usize = 1; + + pt_rgsw.to_mut().raw_mut()[r] = 1; // X^{r} + + let var_xs: f64 = 0.5; + + let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank_in); + sk_in.fill_ternary_prob(var_xs, &mut source_xs); + + let mut sk_out: GLWESecret> = GLWESecret::alloc(n, rank_out); + sk_out.fill_ternary_prob(var_xs, &mut source_xs); + let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); + + // gglwe_{s1}(s0) = s0 -> s1 + ct_gglwe.encrypt_sk( + module, + &sk_in, + &sk_out, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + ct_rgsw.encrypt_sk( + module, + &pt_rgsw, + &sk_out_prepared, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + let ct_rgsw_prepared: GGSWCiphertextPrepared, B> = ct_rgsw.prepare_alloc(module, scratch.borrow()); + + // 
gglwe_(m) (x) RGSW_(X^k) = gglwe_(m * X^k) + ct_gglwe.external_product_inplace(module, &ct_rgsw_prepared, scratch.borrow()); + + (0..rank_in).for_each(|i| { + module.vec_znx_rotate_inplace( + r as i64, + &mut sk_in.data.as_vec_znx_mut(), + i, + scratch.borrow(), + ); // * X^{r} + }); + + let var_gct_err_lhs: f64 = SIGMA * SIGMA; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / n as f64; // X^{k} + let var_a0_err: f64 = SIGMA * SIGMA; + let var_a1_err: f64 = 1f64 / 12f64; + + let max_noise: f64 = noise_ggsw_product( + n as f64, + basek * di, + var_xs, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + rank_out as f64, + k_ct, + k_ggsw, + ); + + ct_gglwe + .key + .assert_noise(module, &sk_out_prepared, &sk_in.data, max_noise + 0.5); + }); + }); + }); +} diff --git a/poulpy-core/src/tests/test_suite/external_product/ggsw_ct.rs b/poulpy-core/src/tests/test_suite/external_product/ggsw_ct.rs new file mode 100644 index 0000000..5fa02f5 --- /dev/null +++ b/poulpy-core/src/tests/test_suite/external_product/ggsw_ct.rs @@ -0,0 +1,307 @@ +use poulpy_hal::{ + api::{ + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, + VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, + VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAlloc, VecZnxDftAllocBytes, + VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeInplace, + VecZnxNormalizeTmpBytes, VecZnxRotateInplace, VecZnxSub, VecZnxSubABInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd, + VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, + }, + layouts::{Backend, Module, ScalarZnx, ScalarZnxToMut, ScratchOwned, ZnxViewMut}, + oep::{ + ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, + TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, VecZnxBigAllocBytesImpl, VecZnxDftAllocBytesImpl, + }, + source::Source, +}; + +use crate::{ + encryption::SIGMA, + layouts::{ + GGSWCiphertext, GLWESecret, + prepared::{GGSWCiphertextPrepared, GLWESecretPrepared, PrepareAlloc}, + }, + noise::noise_ggsw_product, +}; + +#[allow(clippy::too_many_arguments)] +pub fn test_ggsw_external_product(module: &Module) +where + Module: VecZnxDftAllocBytes + + VecZnxBigNormalize + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxNormalizeTmpBytes + + VecZnxFillUniform + + VecZnxSubABInplace + + VecZnxAddInplace + + VecZnxNormalizeInplace + + VecZnxAddNormal + + VecZnxNormalize + + VecZnxSub + + SvpPrepare + + SvpPPolAllocBytes + + SvpPPolAlloc + + VecZnxBigAllocBytes + + VecZnxBigAddInplace + + VecZnxBigAddSmallInplace + + VecZnxAddScalarInplace + + VecZnxCopy + + VmpPMatAlloc + + VecZnxRotateInplace + + VmpApplyDftToDftTmpBytes + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + VmpPrepare + + VecZnxBigAlloc + + VecZnxDftAlloc + + VecZnxBigNormalizeTmpBytes + + VecZnxIdftApplyTmpA, + B: Backend + + TakeVecZnxDftImpl + + TakeVecZnxBigImpl + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + ScratchAvailableImpl + + TakeScalarZnxImpl + + TakeVecZnxImpl + + VecZnxDftAllocBytesImpl + + VecZnxBigAllocBytesImpl + + TakeSvpPPolImpl, +{ + let basek: usize = 12; + let k_in: usize = 60; + let digits: usize = k_in.div_ceil(basek); + (1..3).for_each(|rank| { + (1..digits + 1).for_each(|di| { + let k_ggsw: usize = k_in + basek * di; + 
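For orientation in the digit loops of this test, the gadget geometry follows directly from the two definitions just above; a tiny self-contained sketch (the helper name is illustrative only, not part of the crate):

    // Geometry of the GGSW operand as a function of the digit count `di`.
    fn gadget_geometry(basek: usize, k_in: usize, di: usize) -> (usize, usize) {
        let k_ggsw = k_in + basek * di;        // total precision (bits) of the GGSW operand
        let rows = k_in.div_ceil(basek * di);  // gadget rows needed to cover k_in bits
        (k_ggsw, rows)
    }

    // With basek = 12 and k_in = 60 as in this test:
    //   di = 1 -> (72, 5),  di = 2 -> (84, 3),  di = 5 -> (120, 1)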
println!("test external_product digits: {} ranks: {}", di, rank); + let k_out: usize = k_in; // Better capture noise. + + let n: usize = module.n(); + let rows: usize = k_in.div_ceil(basek * di); + let rows_in: usize = k_in.div_euclid(basek * di); + let digits_in: usize = 1; + + let mut ct_ggsw_lhs_in: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_in, rows_in, digits_in, rank); + let mut ct_ggsw_lhs_out: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_out, rows_in, digits_in, rank); + let mut ct_ggsw_rhs: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ggsw, rows, di, rank); + let mut pt_ggsw_lhs: ScalarZnx> = ScalarZnx::alloc(n, 1); + let mut pt_ggsw_rhs: ScalarZnx> = ScalarZnx::alloc(n, 1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + pt_ggsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs); + + let k: usize = 1; + + pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_ggsw, rank) + | GGSWCiphertext::external_product_scratch_space(module, basek, k_out, k_in, k_ggsw, di, rank), + ); + + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); + + ct_ggsw_rhs.encrypt_sk( + module, + &pt_ggsw_rhs, + &sk_prepared, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + ct_ggsw_lhs_in.encrypt_sk( + module, + &pt_ggsw_lhs, + &sk_prepared, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + let ct_rhs_prepared: GGSWCiphertextPrepared, B> = ct_ggsw_rhs.prepare_alloc(module, scratch.borrow()); + + ct_ggsw_lhs_out.external_product(module, &ct_ggsw_lhs_in, &ct_rhs_prepared, scratch.borrow()); + + module.vec_znx_rotate_inplace( + k as i64, + &mut pt_ggsw_lhs.as_vec_znx_mut(), + 0, + scratch.borrow(), + ); + + let var_gct_err_lhs: f64 = SIGMA * SIGMA; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / n as f64; // X^{k} + let var_a0_err: f64 = SIGMA * SIGMA; + let var_a1_err: f64 = 1f64 / 12f64; + + let max_noise = |_col_j: usize| -> f64 { + noise_ggsw_product( + n as f64, + basek * di, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + rank as f64, + k_in, + k_ggsw, + ) + 0.5 + }; + + ct_ggsw_lhs_out.assert_noise(module, &sk_prepared, &pt_ggsw_lhs, max_noise); + }); + }); +} + +#[allow(clippy::too_many_arguments)] +pub fn test_ggsw_external_product_inplace(module: &Module) +where + Module: VecZnxDftAllocBytes + + VecZnxBigNormalize + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxNormalizeTmpBytes + + VecZnxFillUniform + + VecZnxSubABInplace + + VecZnxAddInplace + + VecZnxNormalizeInplace + + VecZnxAddNormal + + VecZnxNormalize + + VecZnxSub + + SvpPrepare + + SvpPPolAllocBytes + + SvpPPolAlloc + + VecZnxBigAllocBytes + + VecZnxBigAddInplace + + VecZnxBigAddSmallInplace + + VecZnxAddScalarInplace + + VecZnxCopy + + VmpPMatAlloc + + VecZnxRotateInplace + + VmpApplyDftToDftTmpBytes + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + VmpPrepare + + VecZnxBigAlloc + + VecZnxDftAlloc + + VecZnxBigNormalizeTmpBytes + + VecZnxIdftApplyTmpA, + B: Backend + + TakeVecZnxDftImpl + + TakeVecZnxBigImpl + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + ScratchAvailableImpl + + TakeScalarZnxImpl + + TakeVecZnxImpl + + 
VecZnxDftAllocBytesImpl + + VecZnxBigAllocBytesImpl + + TakeSvpPPolImpl, +{ + let basek: usize = 12; + let k_ct: usize = 60; + let digits: usize = k_ct.div_ceil(basek); + (1..3).for_each(|rank| { + (1..digits).for_each(|di| { + let k_ggsw: usize = k_ct + basek * di; + + let n: usize = module.n(); + let rows: usize = k_ct.div_ceil(di * basek); + let rows_in: usize = k_ct.div_euclid(basek * di); + let digits_in: usize = 1; + + let mut ct_ggsw_lhs: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ct, rows_in, digits_in, rank); + let mut ct_ggsw_rhs: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ggsw, rows, di, rank); + + let mut pt_ggsw_lhs: ScalarZnx> = ScalarZnx::alloc(n, 1); + let mut pt_ggsw_rhs: ScalarZnx> = ScalarZnx::alloc(n, 1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + pt_ggsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs); + + let k: usize = 1; + + pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_ggsw, rank) + | GGSWCiphertext::external_product_inplace_scratch_space(module, basek, k_ct, k_ggsw, di, rank), + ); + + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); + + ct_ggsw_rhs.encrypt_sk( + module, + &pt_ggsw_rhs, + &sk_prepared, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + ct_ggsw_lhs.encrypt_sk( + module, + &pt_ggsw_lhs, + &sk_prepared, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + let ct_rhs_prepared: GGSWCiphertextPrepared, B> = ct_ggsw_rhs.prepare_alloc(module, scratch.borrow()); + + ct_ggsw_lhs.external_product_inplace(module, &ct_rhs_prepared, scratch.borrow()); + + module.vec_znx_rotate_inplace( + k as i64, + &mut pt_ggsw_lhs.as_vec_znx_mut(), + 0, + scratch.borrow(), + ); + + let var_gct_err_lhs: f64 = SIGMA * SIGMA; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / n as f64; // X^{k} + let var_a0_err: f64 = SIGMA * SIGMA; + let var_a1_err: f64 = 1f64 / 12f64; + + let max_noise = |_col_j: usize| -> f64 { + noise_ggsw_product( + n as f64, + basek * di, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + rank as f64, + k_ct, + k_ggsw, + ) + 0.5 + }; + + ct_ggsw_lhs.assert_noise(module, &sk_prepared, &pt_ggsw_lhs, max_noise); + }); + }); +} diff --git a/poulpy-core/src/tests/test_suite/external_product/glwe_ct.rs b/poulpy-core/src/tests/test_suite/external_product/glwe_ct.rs new file mode 100644 index 0000000..93842be --- /dev/null +++ b/poulpy-core/src/tests/test_suite/external_product/glwe_ct.rs @@ -0,0 +1,295 @@ +use poulpy_hal::{ + api::{ + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, + VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, + VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, + VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotateInplace, VecZnxSub, VecZnxSubABInplace, + VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, + }, + layouts::{Backend, Module, ScalarZnx, ScratchOwned, ZnxViewMut}, + oep::{ + ScratchAvailableImpl, 
ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, + TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, + }, + source::Source, +}; + +use crate::{ + encryption::SIGMA, + layouts::{ + GGSWCiphertext, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, + prepared::{GGSWCiphertextPrepared, GLWESecretPrepared, PrepareAlloc}, + }, + noise::noise_ggsw_product, +}; + +#[allow(clippy::too_many_arguments)] +pub fn test_glwe_external_product(module: &Module) +where + Module: VecZnxDftAllocBytes + + VecZnxBigNormalize + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxFillUniform + + VecZnxSubABInplace + + VecZnxAddInplace + + VecZnxNormalizeInplace + + VecZnxAddNormal + + VecZnxNormalize + + VecZnxSub + + SvpPrepare + + SvpPPolAllocBytes + + SvpPPolAlloc + + VecZnxBigAllocBytes + + VecZnxBigAddInplace + + VecZnxBigAddSmallInplace + + VecZnxNormalizeTmpBytes + + VecZnxAddScalarInplace + + VecZnxRotateInplace + + VmpPMatAlloc + + VmpPrepare + + VmpApplyDftToDftTmpBytes + + VmpApplyDftToDft + + VmpApplyDftToDftAdd, + B: Backend + + TakeVecZnxDftImpl + + TakeVecZnxBigImpl + + TakeSvpPPolImpl + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + ScratchAvailableImpl + + TakeScalarZnxImpl + + TakeVecZnxImpl, +{ + let basek: usize = 12; + let k_in: usize = 45; + let digits: usize = k_in.div_ceil(basek); + (1..3).for_each(|rank| { + (1..digits + 1).for_each(|di| { + let k_ggsw: usize = k_in + basek * di; + let k_out: usize = k_ggsw; // Better capture noise + + let n: usize = module.n(); + let rows: usize = k_in.div_ceil(basek * digits); + + let mut ct_ggsw: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ggsw, rows, digits, rank); + let mut ct_glwe_in: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_in, rank); + let mut ct_glwe_out: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_out, rank); + let mut pt_rgsw: ScalarZnx> = ScalarZnx::alloc(n, 1); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_in); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, &mut source_xa); + + pt_want.data.at_mut(0, 0)[1] = 1; + + let k: usize = 1; + + pt_rgsw.raw_mut()[k] = 1; // X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GGSWCiphertext::encrypt_sk_scratch_space(module, basek, ct_ggsw.k(), rank) + | GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct_glwe_in.k()) + | GLWECiphertext::external_product_scratch_space( + module, + basek, + ct_glwe_out.k(), + ct_glwe_in.k(), + ct_ggsw.k(), + digits, + rank, + ), + ); + + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); + + ct_ggsw.encrypt_sk( + module, + &pt_rgsw, + &sk_prepared, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + ct_glwe_in.encrypt_sk( + module, + &pt_want, + &sk_prepared, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + let ct_ggsw_prepared: GGSWCiphertextPrepared, B> = ct_ggsw.prepare_alloc(module, scratch.borrow()); + + ct_glwe_out.external_product(module, &ct_glwe_in, &ct_ggsw_prepared, scratch.borrow()); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_want.data, 0, scratch.borrow()); + + let var_gct_err_lhs: f64 = SIGMA * SIGMA; + let var_gct_err_rhs: f64 = 0f64; + + 
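The variance constants assigned next are worth unpacking: the GGSW message is the monomial X^k, that is one coefficient of magnitude 1 out of n, hence var_msg = 1/n, and var_a1_err = 1/12 is consistent with a rounding error uniform on [-1/2, 1/2). A small self-contained check of the latter value (names are illustrative, not part of the crate):

    // Numerically confirm Var(U(-1/2, 1/2)) = 1/12, the value used for var_a1_err.
    fn main() {
        let n = 1_000_000usize;
        let var: f64 = (0..n)
            .map(|i| (i as f64 + 0.5) / n as f64 - 0.5) // midpoints of a uniform grid on [-1/2, 1/2)
            .map(|x| x * x)
            .sum::<f64>()
            / n as f64;
        assert!((var - 1.0 / 12.0).abs() < 1e-6);
    }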
let var_msg: f64 = 1f64 / n as f64; // X^{k} + let var_a0_err: f64 = SIGMA * SIGMA; + let var_a1_err: f64 = 1f64 / 12f64; + + let max_noise: f64 = noise_ggsw_product( + n as f64, + basek * digits, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + rank as f64, + k_in, + k_ggsw, + ); + + ct_glwe_out.assert_noise(module, &sk_prepared, &pt_want, max_noise + 0.5); + }); + }); +} + +#[allow(clippy::too_many_arguments)] +pub fn test_glwe_external_product_inplace(module: &Module) +where + Module: VecZnxDftAllocBytes + + VecZnxBigNormalize + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxFillUniform + + VecZnxSubABInplace + + VecZnxAddInplace + + VecZnxNormalizeInplace + + VecZnxAddNormal + + VecZnxNormalize + + VecZnxSub + + SvpPrepare + + SvpPPolAllocBytes + + SvpPPolAlloc + + VecZnxBigAllocBytes + + VecZnxBigAddInplace + + VecZnxBigAddSmallInplace + + VecZnxNormalizeTmpBytes + + VecZnxAddScalarInplace + + VecZnxRotateInplace + + VmpPMatAlloc + + VmpPrepare + + VmpApplyDftToDftTmpBytes + + VmpApplyDftToDft + + VmpApplyDftToDftAdd, + B: Backend + + TakeVecZnxDftImpl + + TakeVecZnxBigImpl + + TakeSvpPPolImpl + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + ScratchAvailableImpl + + TakeScalarZnxImpl + + TakeVecZnxImpl, +{ + let basek: usize = 12; + let k_ct: usize = 60; + let digits: usize = k_ct.div_ceil(basek); + (1..3).for_each(|rank| { + (1..digits + 1).for_each(|di| { + let k_ggsw: usize = k_ct + basek * di; + + let n: usize = module.n(); + let rows: usize = k_ct.div_ceil(basek * digits); + + let mut ct_ggsw: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ggsw, rows, digits, rank); + let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct, rank); + let mut pt_rgsw: ScalarZnx> = ScalarZnx::alloc(n, 1); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_ct); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, &mut source_xa); + + pt_want.data.at_mut(0, 0)[1] = 1; + + let k: usize = 1; + + pt_rgsw.raw_mut()[k] = 1; // X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GGSWCiphertext::encrypt_sk_scratch_space(module, basek, ct_ggsw.k(), rank) + | GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct_glwe.k()) + | GLWECiphertext::external_product_inplace_scratch_space( + module, + basek, + ct_glwe.k(), + ct_ggsw.k(), + digits, + rank, + ), + ); + + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); + + ct_ggsw.encrypt_sk( + module, + &pt_rgsw, + &sk_prepared, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + ct_glwe.encrypt_sk( + module, + &pt_want, + &sk_prepared, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + let ct_ggsw_prepared: GGSWCiphertextPrepared, B> = ct_ggsw.prepare_alloc(module, scratch.borrow()); + + ct_glwe.external_product_inplace(module, &ct_ggsw_prepared, scratch.borrow()); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_want.data, 0, scratch.borrow()); + + let var_gct_err_lhs: f64 = SIGMA * SIGMA; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / n as f64; // X^{k} + let var_a0_err: f64 = SIGMA * SIGMA; + let var_a1_err: f64 = 1f64 / 12f64; + + let max_noise: f64 = 
noise_ggsw_product( + n as f64, + basek * digits, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + rank as f64, + k_ct, + k_ggsw, + ); + + ct_glwe.assert_noise(module, &sk_prepared, &pt_want, max_noise + 0.5); + }); + }); +} diff --git a/poulpy-core/src/tests/generics/external_product/mod.rs b/poulpy-core/src/tests/test_suite/external_product/mod.rs similarity index 100% rename from poulpy-core/src/tests/generics/external_product/mod.rs rename to poulpy-core/src/tests/test_suite/external_product/mod.rs diff --git a/poulpy-core/src/tests/test_suite/keyswitch/gglwe_ct.rs b/poulpy-core/src/tests/test_suite/keyswitch/gglwe_ct.rs new file mode 100644 index 0000000..a0b35bc --- /dev/null +++ b/poulpy-core/src/tests/test_suite/keyswitch/gglwe_ct.rs @@ -0,0 +1,322 @@ +use poulpy_hal::{ + api::{ + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, + VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, + VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VecZnxDftApply, + VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, + VecZnxSubABInplace, VecZnxSubScalarInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, + VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, + }, + layouts::{Backend, Module, ScratchOwned}, + oep::{ + ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, + TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, + }, + source::Source, +}; + +use crate::{ + encryption::SIGMA, + layouts::{ + GGLWESwitchingKey, GLWESecret, + prepared::{GGLWESwitchingKeyPrepared, GLWESecretPrepared, PrepareAlloc}, + }, + noise::log2_std_noise_gglwe_product, +}; + +#[allow(clippy::too_many_arguments)] +pub fn test_gglwe_switching_key_keyswitch(module: &Module) +where + Module: VecZnxDftAllocBytes + + VecZnxBigNormalize + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxFillUniform + + VecZnxSubABInplace + + VecZnxAddInplace + + VecZnxNormalizeInplace + + VecZnxAddNormal + + VecZnxNormalize + + VecZnxSub + + SvpPrepare + + SvpPPolAllocBytes + + SvpPPolAlloc + + VecZnxBigAllocBytes + + VecZnxBigAddInplace + + VecZnxBigAddSmallInplace + + VecZnxNormalizeTmpBytes + + VecZnxAddScalarInplace + + VmpPMatAlloc + + VmpPrepare + + VmpApplyDftToDftTmpBytes + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + VecZnxBigNormalizeTmpBytes + + VecZnxSwitchRing + + VecZnxSubScalarInplace, + B: Backend + + TakeVecZnxDftImpl + + TakeVecZnxBigImpl + + TakeSvpPPolImpl + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + ScratchAvailableImpl + + TakeScalarZnxImpl + + TakeVecZnxImpl, +{ + let basek: usize = 12; + let k_in: usize = 60; + let digits: usize = k_in.div_ceil(basek); + + (1..3).for_each(|rank_in_s0s1| { + (1..3).for_each(|rank_out_s0s1| { + (1..3).for_each(|rank_out_s1s2| { + (1..digits + 1).for_each(|di| { + let k_ksk: usize = k_in + basek * di; + let k_out: usize = k_ksk; // Better capture noise. 
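The test below composes two switching keys and checks the composition against the theoretical noise bound; in outline (a condensed sketch of the calls that follow, not additional test code):

    // ksk_s0s1 = GGLWE_{s1}(s0)            ct_gglwe_s0s1.encrypt_sk(module, &sk0, &sk1, ...)
    // ksk_s1s2 = GGLWE_{s2}(s1)            ct_gglwe_s1s2.encrypt_sk(module, &sk1, &sk2, ...),
    //                                      then ct_gglwe_s1s2.prepare_alloc(module, ...)
    // ksk_s0s2 = ksk_s0s1 (x) ksk_s1s2     ct_gglwe_s0s2.keyswitch(module, &ct_gglwe_s0s1, &prepared, ...)
    //
    // Correctness: the result should decrypt under s2 back to s0, with log2 standard
    // deviation at most log2_std_noise_gglwe_product(...) + 0.5:
    //   ct_gglwe_s0s2.key.assert_noise(module, &sk2_prepared, &sk0.data, max_noise + 0.5);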
+ + let n: usize = module.n(); + let rows: usize = k_in / basek; + let rows_apply: usize = k_in.div_ceil(basek * di); + let digits_in: usize = 1; + + let mut ct_gglwe_s0s1: GGLWESwitchingKey> = + GGLWESwitchingKey::alloc(n, basek, k_in, rows, digits_in, rank_in_s0s1, rank_out_s0s1); + let mut ct_gglwe_s1s2: GGLWESwitchingKey> = GGLWESwitchingKey::alloc( + n, + basek, + k_ksk, + rows_apply, + di, + rank_out_s0s1, + rank_out_s1s2, + ); + let mut ct_gglwe_s0s2: GGLWESwitchingKey> = GGLWESwitchingKey::alloc( + n, + basek, + k_out, + rows, + digits_in, + rank_in_s0s1, + rank_out_s1s2, + ); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch_enc: ScratchOwned = ScratchOwned::alloc(GGLWESwitchingKey::encrypt_sk_scratch_space( + module, + basek, + k_ksk, + rank_in_s0s1 | rank_out_s0s1, + rank_out_s0s1 | rank_out_s1s2, + )); + let mut scratch_apply: ScratchOwned = ScratchOwned::alloc(GGLWESwitchingKey::keyswitch_scratch_space( + module, + basek, + k_out, + k_in, + k_ksk, + di, + ct_gglwe_s1s2.rank_in(), + ct_gglwe_s1s2.rank_out(), + )); + + let mut sk0: GLWESecret> = GLWESecret::alloc(n, rank_in_s0s1); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1: GLWESecret> = GLWESecret::alloc(n, rank_out_s0s1); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk2: GLWESecret> = GLWESecret::alloc(n, rank_out_s1s2); + sk2.fill_ternary_prob(0.5, &mut source_xs); + let sk2_prepared: GLWESecretPrepared, B> = sk2.prepare_alloc(module, scratch_apply.borrow()); + + // gglwe_{s1}(s0) = s0 -> s1 + ct_gglwe_s0s1.encrypt_sk( + module, + &sk0, + &sk1, + &mut source_xa, + &mut source_xe, + scratch_enc.borrow(), + ); + + // gglwe_{s2}(s1) -> s1 -> s2 + ct_gglwe_s1s2.encrypt_sk( + module, + &sk1, + &sk2, + &mut source_xa, + &mut source_xe, + scratch_enc.borrow(), + ); + + let ct_gglwe_s1s2_prepared: GGLWESwitchingKeyPrepared, B> = + ct_gglwe_s1s2.prepare_alloc(module, scratch_apply.borrow()); + + // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) + ct_gglwe_s0s2.keyswitch( + module, + &ct_gglwe_s0s1, + &ct_gglwe_s1s2_prepared, + scratch_apply.borrow(), + ); + + let max_noise: f64 = log2_std_noise_gglwe_product( + n as f64, + basek * di, + 0.5, + 0.5, + 0f64, + SIGMA * SIGMA, + 0f64, + rank_out_s0s1 as f64, + k_in, + k_ksk, + ); + + ct_gglwe_s0s2 + .key + .assert_noise(module, &sk2_prepared, &sk0.data, max_noise + 0.5); + }); + }); + }); + }); +} + +#[allow(clippy::too_many_arguments)] +pub fn test_gglwe_switching_key_keyswitch_inplace(module: &Module) +where + Module: VecZnxDftAllocBytes + + VecZnxBigNormalize + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxFillUniform + + VecZnxSubABInplace + + VecZnxAddInplace + + VecZnxNormalizeInplace + + VecZnxAddNormal + + VecZnxNormalize + + VecZnxSub + + SvpPrepare + + SvpPPolAllocBytes + + SvpPPolAlloc + + VecZnxBigAllocBytes + + VecZnxBigAddInplace + + VecZnxBigAddSmallInplace + + VecZnxNormalizeTmpBytes + + VecZnxAddScalarInplace + + VmpPMatAlloc + + VmpPrepare + + VmpApplyDftToDftTmpBytes + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + VecZnxBigNormalizeTmpBytes + + VecZnxSwitchRing + + VecZnxSubScalarInplace, + B: Backend + + TakeVecZnxDftImpl + + TakeVecZnxBigImpl + + TakeSvpPPolImpl + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + ScratchAvailableImpl + + TakeScalarZnxImpl + + TakeVecZnxImpl, +{ + let basek: usize = 12; + let k_ct: usize = 60; + let digits: usize = 
k_ct.div_ceil(basek); + (1..3).for_each(|rank_in| { + (1..3).for_each(|rank_out| { + (1..digits + 1).for_each(|di| { + let k_ksk: usize = k_ct + basek * di; + + let n: usize = module.n(); + let rows: usize = k_ct.div_ceil(basek * di); + let digits_in: usize = 1; + + let mut ct_gglwe_s0s1: GGLWESwitchingKey> = + GGLWESwitchingKey::alloc(n, basek, k_ct, rows, digits_in, rank_in, rank_out); + let mut ct_gglwe_s1s2: GGLWESwitchingKey> = + GGLWESwitchingKey::alloc(n, basek, k_ksk, rows, di, rank_out, rank_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch_enc: ScratchOwned = ScratchOwned::alloc(GGLWESwitchingKey::encrypt_sk_scratch_space( + module, + basek, + k_ksk, + rank_in | rank_out, + rank_out, + )); + let mut scratch_apply: ScratchOwned = ScratchOwned::alloc(GGLWESwitchingKey::keyswitch_inplace_scratch_space( + module, basek, k_ct, k_ksk, di, rank_out, + )); + + let var_xs: f64 = 0.5; + + let mut sk0: GLWESecret> = GLWESecret::alloc(n, rank_in); + sk0.fill_ternary_prob(var_xs, &mut source_xs); + + let mut sk1: GLWESecret> = GLWESecret::alloc(n, rank_out); + sk1.fill_ternary_prob(var_xs, &mut source_xs); + + let mut sk2: GLWESecret> = GLWESecret::alloc(n, rank_out); + sk2.fill_ternary_prob(var_xs, &mut source_xs); + let sk2_prepared: GLWESecretPrepared, B> = sk2.prepare_alloc(module, scratch_apply.borrow()); + + // gglwe_{s1}(s0) = s0 -> s1 + ct_gglwe_s0s1.encrypt_sk( + module, + &sk0, + &sk1, + &mut source_xa, + &mut source_xe, + scratch_enc.borrow(), + ); + + // gglwe_{s2}(s1) -> s1 -> s2 + ct_gglwe_s1s2.encrypt_sk( + module, + &sk1, + &sk2, + &mut source_xa, + &mut source_xe, + scratch_enc.borrow(), + ); + + let ct_gglwe_s1s2_prepared: GGLWESwitchingKeyPrepared, B> = + ct_gglwe_s1s2.prepare_alloc(module, scratch_apply.borrow()); + + // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) + ct_gglwe_s0s1.keyswitch_inplace(module, &ct_gglwe_s1s2_prepared, scratch_apply.borrow()); + + let ct_gglwe_s0s2: GGLWESwitchingKey> = ct_gglwe_s0s1; + + let max_noise: f64 = log2_std_noise_gglwe_product( + n as f64, + basek * di, + var_xs, + var_xs, + 0f64, + SIGMA * SIGMA, + 0f64, + rank_out as f64, + k_ct, + k_ksk, + ); + + ct_gglwe_s0s2 + .key + .assert_noise(module, &sk2_prepared, &sk0.data, max_noise + 0.5); + }); + }); + }); +} diff --git a/poulpy-core/src/tests/test_suite/keyswitch/ggsw_ct.rs b/poulpy-core/src/tests/test_suite/keyswitch/ggsw_ct.rs new file mode 100644 index 0000000..9b82ba9 --- /dev/null +++ b/poulpy-core/src/tests/test_suite/keyswitch/ggsw_ct.rs @@ -0,0 +1,309 @@ +use poulpy_hal::{ + api::{ + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, + SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, + VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAddInplace, VecZnxDftAlloc, + VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, + VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing, + VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, + }, + layouts::{Backend, Module, ScalarZnx, ScratchOwned}, + oep::{ + ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, + 
TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, + }, + source::Source, +}; + +use crate::{ + encryption::SIGMA, + layouts::{ + GGLWESwitchingKey, GGLWETensorKey, GGSWCiphertext, GLWESecret, + prepared::{GGLWESwitchingKeyPrepared, GGLWETensorKeyPrepared, GLWESecretPrepared, PrepareAlloc}, + }, + noise::noise_ggsw_keyswitch, +}; + +#[allow(clippy::too_many_arguments)] +pub fn test_ggsw_keyswitch(module: &Module) +where + Module: VecZnxDftAllocBytes + + VecZnxBigNormalize + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxFillUniform + + VecZnxSubABInplace + + VecZnxAddInplace + + VecZnxNormalizeInplace + + VecZnxAddNormal + + VecZnxNormalize + + VecZnxSub + + SvpPrepare + + SvpPPolAllocBytes + + SvpPPolAlloc + + VecZnxBigAllocBytes + + VecZnxBigAddInplace + + VecZnxBigAddSmallInplace + + VecZnxNormalizeTmpBytes + + VecZnxAddScalarInplace + + VmpPMatAlloc + + VmpPrepare + + VmpApplyDftToDftTmpBytes + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + VecZnxBigNormalizeTmpBytes + + VecZnxSwitchRing + + SvpApplyDftToDft + + VecZnxIdftApplyTmpA + + VecZnxDftCopy + + VecZnxDftAddInplace + + VecZnxBigAlloc + + VecZnxDftAlloc, + B: Backend + + TakeVecZnxDftImpl + + TakeVecZnxBigImpl + + TakeSvpPPolImpl + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + ScratchAvailableImpl + + TakeScalarZnxImpl + + TakeVecZnxImpl, +{ + let basek: usize = 12; + let k_in: usize = 54; + let digits: usize = k_in.div_ceil(basek); + (1..4).for_each(|rank| { + (1..digits + 1).for_each(|di| { + let k_ksk: usize = k_in + basek * di; + let k_tsk: usize = k_ksk; + let k_out: usize = k_ksk; // Better capture noise. + + let n: usize = module.n(); + let rows: usize = k_in.div_ceil(di * basek); + + let digits_in: usize = 1; + + let mut ct_in: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_in, rows, digits_in, rank); + let mut ct_out: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_out, rows, digits_in, rank); + let mut tsk: GGLWETensorKey> = GGLWETensorKey::alloc(n, basek, k_ksk, rows, di, rank); + let mut ksk: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(n, basek, k_ksk, rows, di, rank, rank); + let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_in, rank) + | GGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k_ksk, rank, rank) + | GGLWETensorKey::encrypt_sk_scratch_space(module, basek, k_tsk, rank) + | GGSWCiphertext::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, di, k_tsk, di, rank), + ); + + let var_xs: f64 = 0.5; + + let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank); + sk_in.fill_ternary_prob(var_xs, &mut source_xs); + let sk_in_dft: GLWESecretPrepared, B> = sk_in.prepare_alloc(module, scratch.borrow()); + + let mut sk_out: GLWESecret> = GLWESecret::alloc(n, rank); + sk_out.fill_ternary_prob(var_xs, &mut source_xs); + let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); + + ksk.encrypt_sk( + module, + &sk_in, + &sk_out, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + tsk.encrypt_sk( + module, + &sk_out, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + pt_scalar.fill_ternary_hw(0, n, &mut source_xs); + + ct_in.encrypt_sk( + module, + &pt_scalar, + &sk_in_dft, + &mut source_xa, + &mut source_xe, + 
scratch.borrow(), + ); + + let ksk_prepared: GGLWESwitchingKeyPrepared, B> = ksk.prepare_alloc(module, scratch.borrow()); + let tsk_prepared: GGLWETensorKeyPrepared, B> = tsk.prepare_alloc(module, scratch.borrow()); + + ct_out.keyswitch( + module, + &ct_in, + &ksk_prepared, + &tsk_prepared, + scratch.borrow(), + ); + + let max_noise = |col_j: usize| -> f64 { + noise_ggsw_keyswitch( + n as f64, + basek * di, + col_j, + var_xs, + 0f64, + SIGMA * SIGMA, + 0f64, + rank as f64, + k_in, + k_ksk, + k_tsk, + ) + 0.5 + }; + + ct_out.assert_noise(module, &sk_out_prepared, &pt_scalar, max_noise); + }); + }); +} + +#[allow(clippy::too_many_arguments)] +pub fn test_ggsw_keyswitch_inplace(module: &Module) +where + Module: VecZnxDftAllocBytes + + VecZnxBigNormalize + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxFillUniform + + VecZnxSubABInplace + + VecZnxAddInplace + + VecZnxNormalizeInplace + + VecZnxAddNormal + + VecZnxNormalize + + VecZnxSub + + SvpPrepare + + SvpPPolAllocBytes + + SvpPPolAlloc + + VecZnxBigAllocBytes + + VecZnxBigAddInplace + + VecZnxBigAddSmallInplace + + VecZnxNormalizeTmpBytes + + VecZnxAddScalarInplace + + VmpPMatAlloc + + VmpPrepare + + VmpApplyDftToDftTmpBytes + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + VecZnxBigNormalizeTmpBytes + + VecZnxSwitchRing + + SvpApplyDftToDft + + VecZnxIdftApplyTmpA + + VecZnxDftCopy + + VecZnxDftAddInplace + + VecZnxBigAlloc + + VecZnxDftAlloc, + B: Backend + + TakeVecZnxDftImpl + + TakeVecZnxBigImpl + + TakeSvpPPolImpl + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + ScratchAvailableImpl + + TakeScalarZnxImpl + + TakeVecZnxImpl, +{ + let basek: usize = 12; + let k_ct: usize = 54; + let digits: usize = k_ct.div_ceil(basek); + (1..4).for_each(|rank| { + (1..digits + 1).for_each(|di| { + let k_ksk: usize = k_ct + basek * di; + let k_tsk: usize = k_ksk; + + let n: usize = module.n(); + let rows: usize = k_ct.div_ceil(di * basek); + + let digits_in: usize = 1; + + let mut ct: GGSWCiphertext> = GGSWCiphertext::alloc(n, basek, k_ct, rows, digits_in, rank); + let mut tsk: GGLWETensorKey> = GGLWETensorKey::alloc(n, basek, k_tsk, rows, di, rank); + let mut ksk: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(n, basek, k_ksk, rows, di, rank, rank); + let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_ct, rank) + | GGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k_ksk, rank, rank) + | GGLWETensorKey::encrypt_sk_scratch_space(module, basek, k_tsk, rank) + | GGSWCiphertext::keyswitch_inplace_scratch_space(module, basek, k_ct, k_ksk, di, k_tsk, di, rank), + ); + + let var_xs: f64 = 0.5; + + let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank); + sk_in.fill_ternary_prob(var_xs, &mut source_xs); + let sk_in_dft: GLWESecretPrepared, B> = sk_in.prepare_alloc(module, scratch.borrow()); + + let mut sk_out: GLWESecret> = GLWESecret::alloc(n, rank); + sk_out.fill_ternary_prob(var_xs, &mut source_xs); + let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); + + ksk.encrypt_sk( + module, + &sk_in, + &sk_out, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + tsk.encrypt_sk( + module, + &sk_out, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + 
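One structural point these GGSW key-switching tests leave implicit: besides the GGLWE switching key, the operation takes a GGLWETensorKey, because, roughly, rebuilding the GGSW rows under the output secret requires encryptions of the products s_out[i]*s_out[j], which is what the tensor key holds (and what the glwe_tsk.rs test above verifies). The call shape used next is then simply:

    // let ksk_prepared = ksk.prepare_alloc(module, scratch.borrow());
    // let tsk_prepared = tsk.prepare_alloc(module, scratch.borrow());
    // ct.keyswitch_inplace(module, &ksk_prepared, &tsk_prepared, scratch.borrow());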
pt_scalar.fill_ternary_hw(0, n, &mut source_xs); + + ct.encrypt_sk( + module, + &pt_scalar, + &sk_in_dft, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + let ksk_prepared: GGLWESwitchingKeyPrepared, B> = ksk.prepare_alloc(module, scratch.borrow()); + let tsk_prepared: GGLWETensorKeyPrepared, B> = tsk.prepare_alloc(module, scratch.borrow()); + + ct.keyswitch_inplace(module, &ksk_prepared, &tsk_prepared, scratch.borrow()); + + let max_noise = |col_j: usize| -> f64 { + noise_ggsw_keyswitch( + n as f64, + basek * di, + col_j, + var_xs, + 0f64, + SIGMA * SIGMA, + 0f64, + rank as f64, + k_ct, + k_ksk, + k_tsk, + ) + 0.5 + }; + + ct.assert_noise(module, &sk_out_prepared, &pt_scalar, max_noise); + }); + }); +} diff --git a/poulpy-core/src/tests/test_suite/keyswitch/glwe_ct.rs b/poulpy-core/src/tests/test_suite/keyswitch/glwe_ct.rs new file mode 100644 index 0000000..456828b --- /dev/null +++ b/poulpy-core/src/tests/test_suite/keyswitch/glwe_ct.rs @@ -0,0 +1,268 @@ +use poulpy_hal::{ + api::{ + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, + VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, + VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VecZnxDftApply, + VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, + VecZnxSubABInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, + VmpPrepare, + }, + layouts::{Backend, Module, ScratchOwned}, + oep::{ + ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, + TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, + }, + source::Source, +}; + +use crate::{ + encryption::SIGMA, + layouts::{ + GGLWESwitchingKey, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, + prepared::{GGLWESwitchingKeyPrepared, GLWESecretPrepared, PrepareAlloc}, + }, + noise::log2_std_noise_gglwe_product, +}; + +#[allow(clippy::too_many_arguments)] +pub fn test_glwe_keyswitch(module: &Module) +where + Module: VecZnxDftAllocBytes + + VecZnxBigNormalize + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxFillUniform + + VecZnxSubABInplace + + VecZnxAddInplace + + VecZnxNormalizeInplace + + VecZnxAddNormal + + VecZnxNormalize + + VecZnxSub + + SvpPrepare + + SvpPPolAllocBytes + + SvpPPolAlloc + + VecZnxBigAllocBytes + + VecZnxBigAddInplace + + VecZnxBigAddSmallInplace + + VecZnxNormalizeTmpBytes + + VecZnxAddScalarInplace + + VmpPMatAlloc + + VmpPrepare + + VmpApplyDftToDftTmpBytes + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + VecZnxBigNormalizeTmpBytes + + VecZnxSwitchRing, + B: Backend + + TakeVecZnxDftImpl + + TakeVecZnxBigImpl + + TakeSvpPPolImpl + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + ScratchAvailableImpl + + TakeScalarZnxImpl + + TakeVecZnxImpl, +{ + let basek: usize = 12; + let k_in: usize = 45; + let digits: usize = k_in.div_ceil(basek); + + (1..3).for_each(|rank_in| { + (1..3).for_each(|rank_out| { + (1..digits + 1).for_each(|di| { + let k_ksk: usize = k_in + basek * di; + let k_out: usize = k_ksk; // better capture noise + + let n: usize = module.n(); + let rows: usize = k_in.div_ceil(basek * digits); + + let mut ksk: GGLWESwitchingKey> = + GGLWESwitchingKey::alloc(n, basek, k_ksk, rows, digits, rank_in, rank_out); + let mut ct_in: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_in, 
rank_in); + let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_out, rank_out); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_in); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, ksk.k(), rank_in, rank_out) + | GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct_in.k()) + | GLWECiphertext::keyswitch_scratch_space( + module, + basek, + ct_out.k(), + ct_in.k(), + ksk.k(), + digits, + rank_in, + rank_out, + ), + ); + + let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank_in); + sk_in.fill_ternary_prob(0.5, &mut source_xs); + let sk_in_prepared: GLWESecretPrepared, B> = sk_in.prepare_alloc(module, scratch.borrow()); + + let mut sk_out: GLWESecret> = GLWESecret::alloc(n, rank_out); + sk_out.fill_ternary_prob(0.5, &mut source_xs); + let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); + + ksk.encrypt_sk( + module, + &sk_in, + &sk_out, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + ct_in.encrypt_sk( + module, + &pt_want, + &sk_in_prepared, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + let ksk_prepared: GGLWESwitchingKeyPrepared, B> = ksk.prepare_alloc(module, scratch.borrow()); + + ct_out.keyswitch(module, &ct_in, &ksk_prepared, scratch.borrow()); + + let max_noise: f64 = log2_std_noise_gglwe_product( + module.n() as f64, + basek * digits, + 0.5, + 0.5, + 0f64, + SIGMA * SIGMA, + 0f64, + rank_in as f64, + k_in, + k_ksk, + ); + + ct_out.assert_noise(module, &sk_out_prepared, &pt_want, max_noise + 0.5); + }) + }); + }); +} + +pub fn test_glwe_keyswitch_inplace(module: &Module) +where + Module: VecZnxDftAllocBytes + + VecZnxBigNormalize + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxFillUniform + + VecZnxSubABInplace + + VecZnxAddInplace + + VecZnxNormalizeInplace + + VecZnxAddNormal + + VecZnxNormalize + + VecZnxSub + + SvpPrepare + + SvpPPolAllocBytes + + SvpPPolAlloc + + VecZnxBigAllocBytes + + VecZnxBigAddInplace + + VecZnxBigAddSmallInplace + + VecZnxNormalizeTmpBytes + + VecZnxAddScalarInplace + + VmpPMatAlloc + + VmpPrepare + + VmpApplyDftToDftTmpBytes + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + VecZnxBigNormalizeTmpBytes + + VecZnxSwitchRing, + B: Backend + + TakeVecZnxDftImpl + + TakeVecZnxBigImpl + + TakeSvpPPolImpl + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + ScratchAvailableImpl + + TakeScalarZnxImpl + + TakeVecZnxImpl, +{ + let basek: usize = 12; + let k_ct: usize = 45; + let digits: usize = k_ct.div_ceil(basek); + + (1..3).for_each(|rank| { + (1..digits + 1).for_each(|di| { + let k_ksk: usize = k_ct + basek * di; + + let n: usize = module.n(); + let rows: usize = k_ct.div_ceil(basek * digits); + + let mut ksk: GGLWESwitchingKey> = GGLWESwitchingKey::alloc(n, basek, k_ksk, rows, digits, rank, rank); + let mut ct_glwe: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k_ct); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, &mut source_xa); + + let mut 
scratch: ScratchOwned = ScratchOwned::alloc( + GGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, ksk.k(), rank, rank) + | GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct_glwe.k()) + | GLWECiphertext::keyswitch_inplace_scratch_space(module, basek, ct_glwe.k(), ksk.k(), digits, rank), + ); + + let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank); + sk_in.fill_ternary_prob(0.5, &mut source_xs); + let sk_in_prepared: GLWESecretPrepared, B> = sk_in.prepare_alloc(module, scratch.borrow()); + + let mut sk_out: GLWESecret> = GLWESecret::alloc(n, rank); + sk_out.fill_ternary_prob(0.5, &mut source_xs); + let sk_out_prepared: GLWESecretPrepared, B> = sk_out.prepare_alloc(module, scratch.borrow()); + + ksk.encrypt_sk( + module, + &sk_in, + &sk_out, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + ct_glwe.encrypt_sk( + module, + &pt_want, + &sk_in_prepared, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + let ksk_prepared: GGLWESwitchingKeyPrepared, B> = ksk.prepare_alloc(module, scratch.borrow()); + + ct_glwe.keyswitch_inplace(module, &ksk_prepared, scratch.borrow()); + + let max_noise: f64 = log2_std_noise_gglwe_product( + module.n() as f64, + basek * digits, + 0.5, + 0.5, + 0f64, + SIGMA * SIGMA, + 0f64, + rank as f64, + k_ct, + k_ksk, + ); + + ct_glwe.assert_noise(module, &sk_out_prepared, &pt_want, max_noise + 0.5); + }); + }); +} diff --git a/poulpy-core/src/tests/generics/keyswitch/lwe_ct.rs b/poulpy-core/src/tests/test_suite/keyswitch/lwe_ct.rs similarity index 86% rename from poulpy-core/src/tests/generics/keyswitch/lwe_ct.rs rename to poulpy-core/src/tests/test_suite/keyswitch/lwe_ct.rs index bc86cc1..9432ee8 100644 --- a/poulpy-core/src/tests/generics/keyswitch/lwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/keyswitch/lwe_ct.rs @@ -1,11 +1,11 @@ use poulpy_hal::{ api::{ - DFT, IDFTConsume, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, - VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, - VecZnxSwithcDegree, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, - ZnAddNormal, ZnFillUniform, ZnNormalizeInplace, + VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, + VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, + VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, ZnAddNormal, ZnFillUniform, ZnNormalizeInplace, }, layouts::{Backend, Module, ScratchOwned, ZnxView}, oep::{ @@ -24,9 +24,9 @@ pub fn test_lwe_keyswitch(module: &Module) where Module: VecZnxDftAllocBytes + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + VecZnxFillUniform + VecZnxSubABInplace + VecZnxAddInplace @@ -48,8 +48,8 @@ where + VmpApplyDftToDft + VmpApplyDftToDftAdd + VecZnxBigNormalizeTmpBytes - + VecZnxSwithcDegree - + VecZnxAutomorphismInplace + + VecZnxSwitchRing + + VecZnxAutomorphismInplace + ZnNormalizeInplace + ZnFillUniform + ZnAddNormal, diff --git 
a/poulpy-core/src/tests/generics/keyswitch/mod.rs b/poulpy-core/src/tests/test_suite/keyswitch/mod.rs similarity index 100% rename from poulpy-core/src/tests/generics/keyswitch/mod.rs rename to poulpy-core/src/tests/test_suite/keyswitch/mod.rs diff --git a/poulpy-core/src/tests/generics/mod.rs b/poulpy-core/src/tests/test_suite/mod.rs similarity index 100% rename from poulpy-core/src/tests/generics/mod.rs rename to poulpy-core/src/tests/test_suite/mod.rs diff --git a/poulpy-core/src/tests/generics/packing.rs b/poulpy-core/src/tests/test_suite/packing.rs similarity index 87% rename from poulpy-core/src/tests/generics/packing.rs rename to poulpy-core/src/tests/test_suite/packing.rs index 9f93326..09b5591 100644 --- a/poulpy-core/src/tests/generics/packing.rs +++ b/poulpy-core/src/tests/test_suite/packing.rs @@ -2,13 +2,13 @@ use std::collections::HashMap; use poulpy_hal::{ api::{ - DFT, IDFTConsume, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigAutomorphismInplace, VecZnxBigNormalize, - VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace, VecZnxCopy, VecZnxDftAllocBytes, VecZnxFillUniform, - VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, - VecZnxRshInplace, VecZnxSub, VecZnxSubABInplace, VecZnxSwithcDegree, VmpApplyDftToDft, VmpApplyDftToDftAdd, - VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, + VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace, VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, + VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, + VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubABInplace, + VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, }, layouts::{Backend, Module, ScratchOwned}, oep::{ @@ -33,13 +33,13 @@ where + VecZnxBigAutomorphismInplace + VecZnxBigSubSmallBInplace + VecZnxNegateInplace - + VecZnxRshInplace - + VecZnxRotateInplace + + VecZnxRshInplace + + VecZnxRotateInplace + VecZnxBigNormalize - + DFT + + VecZnxDftApply + VecZnxRotate - + SvpApplyInplace - + IDFTConsume + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + VecZnxFillUniform + VecZnxSubABInplace + VecZnxAddInplace @@ -61,8 +61,8 @@ where + VmpApplyDftToDft + VmpApplyDftToDftAdd + VecZnxBigNormalizeTmpBytes - + VecZnxSwithcDegree - + VecZnxAutomorphismInplace + + VecZnxSwitchRing + + VecZnxAutomorphismInplace + VecZnxCopy, B: Backend + TakeVecZnxDftImpl @@ -150,7 +150,7 @@ where scratch.borrow(), ); - pt.rotate_inplace(module, -(1 << log_batch)); // X^-batch * pt + pt.rotate_inplace(module, -(1 << log_batch), scratch.borrow()); // X^-batch * pt if reverse_bits_msb(i, log_n as u32).is_multiple_of(5) { packer.add(module, Some(&ct), &auto_keys, scratch.borrow()); diff --git a/poulpy-core/src/tests/test_suite/trace.rs b/poulpy-core/src/tests/test_suite/trace.rs new file mode 100644 index 0000000..b369626 --- /dev/null +++ b/poulpy-core/src/tests/test_suite/trace.rs @@ -0,0 +1,172 @@ +use std::collections::HashMap; + +use poulpy_hal::{ + api::{ + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, 
SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, + VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxBigAddInplace, + VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigAutomorphismInplace, VecZnxBigNormalize, + VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace, VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, + VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, + VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing, VmpApplyDftToDft, + VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, + }, + layouts::{Backend, Module, ScratchOwned, ZnxView, ZnxViewMut}, + oep::{ + ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeScalarZnxImpl, TakeSvpPPolImpl, + TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxImpl, + }, + source::Source, +}; + +use crate::{ + encryption::SIGMA, + layouts::{ + GGLWEAutomorphismKey, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, + prepared::{GGLWEAutomorphismKeyPrepared, GLWESecretPrepared, PrepareAlloc}, + }, + noise::var_noise_gglwe_product, +}; + +pub fn test_glwe_trace_inplace(module: &Module) +where + Module: VecZnxDftAllocBytes + + VecZnxAutomorphism + + VecZnxBigAutomorphismInplace + + VecZnxBigSubSmallBInplace + + VecZnxRshInplace + + VecZnxRotateInplace + + VecZnxBigNormalize + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxFillUniform + + VecZnxSubABInplace + + VecZnxAddInplace + + VecZnxNormalizeInplace + + VecZnxAddNormal + + VecZnxNormalize + + VecZnxSub + + SvpPrepare + + SvpPPolAllocBytes + + SvpPPolAlloc + + VecZnxBigAllocBytes + + VecZnxBigAddInplace + + VecZnxBigAddSmallInplace + + VecZnxNormalizeTmpBytes + + VecZnxAddScalarInplace + + VmpPMatAlloc + + VmpPrepare + + VmpApplyDftToDftTmpBytes + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + VecZnxBigNormalizeTmpBytes + + VecZnxSwitchRing + + VecZnxCopy, + B: Backend + + TakeVecZnxDftImpl + + TakeVecZnxBigImpl + + TakeSvpPPolImpl + + ScratchOwnedAllocImpl + + ScratchOwnedBorrowImpl + + ScratchAvailableImpl + + TakeScalarZnxImpl + + TakeVecZnxImpl, +{ + let basek: usize = 8; + let k: usize = 54; + + (1..3).for_each(|rank| { + let n: usize = module.n(); + let k_autokey: usize = k + basek; + + let digits: usize = 1; + let rows: usize = k.div_ceil(basek * digits); + + let mut ct: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct.k()) + | GLWECiphertext::decrypt_scratch_space(module, basek, ct.k()) + | GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, basek, k_autokey, rank) + | GLWECiphertext::trace_inplace_scratch_space(module, basek, ct.k(), k_autokey, digits, rank), + ); + + let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let sk_dft: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); + + let mut data_want: Vec = vec![0i64; n]; + + data_want + .iter_mut() + .for_each(|x| *x = source_xa.next_i64() & 0xFF); + + module.vec_znx_fill_uniform(basek, &mut pt_have.data, 0, &mut source_xa); + + 
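+        // Encrypt the uniformly sampled plaintext under the secret key before taking the trace.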
ct.encrypt_sk( + module, + &pt_have, + &sk_dft, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + let mut auto_keys: HashMap, B>> = HashMap::new(); + let gal_els: Vec = GLWECiphertext::trace_galois_elements(module); + let mut tmp: GGLWEAutomorphismKey> = GGLWEAutomorphismKey::alloc(n, basek, k_autokey, rows, digits, rank); + gal_els.iter().for_each(|gal_el| { + tmp.encrypt_sk( + module, + *gal_el, + &sk, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + let atk_prepared: GGLWEAutomorphismKeyPrepared, B> = tmp.prepare_alloc(module, scratch.borrow()); + auto_keys.insert(*gal_el, atk_prepared); + }); + + ct.trace_inplace(module, 0, 5, &auto_keys, scratch.borrow()); + ct.trace_inplace(module, 5, module.log_n(), &auto_keys, scratch.borrow()); + + (0..pt_want.size()).for_each(|i| pt_want.data.at_mut(0, i)[0] = pt_have.data.at(0, i)[0]); + + ct.decrypt(module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_want.data, 0, &pt_have.data, 0); + module.vec_znx_normalize_inplace(basek, &mut pt_want.data, 0, scratch.borrow()); + + let noise_have: f64 = pt_want.std().log2(); + + let mut noise_want: f64 = var_noise_gglwe_product( + n as f64, + basek, + 0.5, + 0.5, + 1.0 / 12.0, + SIGMA * SIGMA, + 0.0, + rank as f64, + k, + k_autokey, + ); + noise_want += SIGMA * SIGMA * (-2.0 * (k) as f64).exp2(); + noise_want += n as f64 * 1.0 / 12.0 * 0.5 * rank as f64 * (-2.0 * (k) as f64).exp2(); + noise_want = noise_want.sqrt().log2(); + + assert!( + (noise_have - noise_want).abs() < 1.0, + "{} > {}", + noise_have, + noise_want + ); + }); +} diff --git a/poulpy-hal/Cargo.toml b/poulpy-hal/Cargo.toml index 07d01f3..dda503a 100644 --- a/poulpy-hal/Cargo.toml +++ b/poulpy-hal/Cargo.toml @@ -17,7 +17,10 @@ rand = {workspace = true} rand_distr = {workspace = true} rand_core = {workspace = true} byteorder = {workspace = true} +once_cell = {workspace = true} rand_chacha = "0.9.0" +bytemuck = "1.23.2" + [build-dependencies] cmake = "0.1.54" diff --git a/poulpy-hal/src/api/svp_ppol.rs b/poulpy-hal/src/api/svp_ppol.rs index bc915ae..5a72367 100644 --- a/poulpy-hal/src/api/svp_ppol.rs +++ b/poulpy-hal/src/api/svp_ppol.rs @@ -1,4 +1,6 @@ -use crate::layouts::{Backend, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef}; +use crate::layouts::{ + Backend, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, +}; /// Allocates as [crate::layouts::SvpPPol]. pub trait SvpPPolAlloc { @@ -25,8 +27,26 @@ pub trait SvpPrepare { } /// Apply a scalar-vector product between `a[a_col]` and `b[b_col]` and stores the result on `res[res_col]`. -pub trait SvpApply { - fn svp_apply(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize) +pub trait SvpApplyDft { + fn svp_apply_dft(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize) + where + R: VecZnxDftToMut, + A: SvpPPolToRef, + C: VecZnxToRef; +} + +/// Apply a scalar-vector product between `a[a_col]` and `b[b_col]` and stores the result on `res[res_col]`. +pub trait SvpApplyDftToDft { + fn svp_apply_dft_to_dft(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize) + where + R: VecZnxDftToMut, + A: SvpPPolToRef, + C: VecZnxDftToRef; +} + +/// Apply a scalar-vector product between `a[a_col]` and `b[b_col]` and adds the result on `res[res_col]`. 
+pub trait SvpApplyDftToDftAdd { + fn svp_apply_dft_to_dft_add(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize) where R: VecZnxDftToMut, A: SvpPPolToRef, @@ -34,8 +54,8 @@ pub trait SvpApply { } /// Apply a scalar-vector product between `res[res_col]` and `a[a_col]` and stores the result on `res[res_col]`. -pub trait SvpApplyInplace { - fn svp_apply_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) +pub trait SvpApplyDftToDftInplace { + fn svp_apply_dft_to_dft_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxDftToMut, A: SvpPPolToRef; diff --git a/poulpy-hal/src/api/vec_znx.rs b/poulpy-hal/src/api/vec_znx.rs index f69a7a3..0c8c7bc 100644 --- a/poulpy-hal/src/api/vec_znx.rs +++ b/poulpy-hal/src/api/vec_znx.rs @@ -1,5 +1,3 @@ -use rand_distr::Distribution; - use crate::{ layouts::{Backend, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef}, source::Source, @@ -42,6 +40,16 @@ pub trait VecZnxAddInplace { A: VecZnxToRef; } +pub trait VecZnxAddScalar { + /// Adds the selected column of `a` on the selected column and limb of `b` and writes the result on the selected column of `res`. + #[allow(clippy::too_many_arguments)] + fn vec_znx_add_scalar(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize, b_limb: usize) + where + R: VecZnxToMut, + A: ScalarZnxToRef, + B: VecZnxToRef; +} + pub trait VecZnxAddScalarInplace { /// Adds the selected column of `a` on the selected column and limb of `res`. fn vec_znx_add_scalar_inplace(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize) @@ -79,6 +87,16 @@ pub trait VecZnxSubBAInplace { A: VecZnxToRef; } +pub trait VecZnxSubScalar { + /// Subtracts the selected column of `a` on the selected column and limb of `b` and writes the result on the selected column of `res`. + #[allow(clippy::too_many_arguments)] + fn vec_znx_sub_scalar(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize, b_limb: usize) + where + R: VecZnxToMut, + A: ScalarZnxToRef, + B: VecZnxToRef; +} + pub trait VecZnxSubScalarInplace { /// Subtracts the selected column of `a` on the selected column and limb of `res`. fn vec_znx_sub_scalar_inplace(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize) @@ -102,31 +120,61 @@ pub trait VecZnxNegateInplace { A: VecZnxToMut; } -pub trait VecZnxLshInplace { +pub trait VecZnxLshTmpBytes { + fn vec_znx_lsh_tmp_bytes(&self) -> usize; +} + +pub trait VecZnxLsh { /// Left shift by k bits all columns of `a`. - fn vec_znx_lsh_inplace(&self, basek: usize, k: usize, a: &mut A) + #[allow(clippy::too_many_arguments)] + fn vec_znx_lsh(&self, basek: usize, k: usize, r: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) + where + R: VecZnxToMut, + A: VecZnxToRef; +} + +pub trait VecZnxRshTmpBytes { + fn vec_znx_rsh_tmp_bytes(&self) -> usize; +} + +pub trait VecZnxRsh { + /// Right shift by k bits all columns of `a`. + #[allow(clippy::too_many_arguments)] + fn vec_znx_rsh(&self, basek: usize, k: usize, r: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) + where + R: VecZnxToMut, + A: VecZnxToRef; +} + +pub trait VecZnxLshInplace { + /// Left shift by k bits all columns of `a`. + fn vec_znx_lsh_inplace(&self, basek: usize, k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) where A: VecZnxToMut; } -pub trait VecZnxRshInplace { +pub trait VecZnxRshInplace { /// Right shift by k bits all columns of `a`. 
- fn vec_znx_rsh_inplace(&self, basek: usize, k: usize, a: &mut A) + fn vec_znx_rsh_inplace(&self, basek: usize, k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) where A: VecZnxToMut; } pub trait VecZnxRotate { /// Multiplies the selected column of `a` by X^k and stores the result in `res_col` of `res`. - fn vec_znx_rotate(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_rotate(&self, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, A: VecZnxToRef; } -pub trait VecZnxRotateInplace { +pub trait VecZnxRotateInplaceTmpBytes { + fn vec_znx_rotate_inplace_tmp_bytes(&self) -> usize; +} + +pub trait VecZnxRotateInplace { /// Multiplies the selected column of `a` by X^k. - fn vec_znx_rotate_inplace(&self, k: i64, a: &mut A, a_col: usize) + fn vec_znx_rotate_inplace(&self, p: i64, a: &mut A, a_col: usize, scratch: &mut Scratch) where A: VecZnxToMut; } @@ -139,54 +187,70 @@ pub trait VecZnxAutomorphism { A: VecZnxToRef; } -pub trait VecZnxAutomorphismInplace { +pub trait VecZnxAutomorphismInplaceTmpBytes { + fn vec_znx_automorphism_inplace_tmp_bytes(&self) -> usize; +} + +pub trait VecZnxAutomorphismInplace { /// Applies the automorphism X^i -> X^ik on the selected column of `a`. - fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) + fn vec_znx_automorphism_inplace(&self, k: i64, res: &mut R, res_col: usize, scratch: &mut Scratch) where - A: VecZnxToMut; + R: VecZnxToMut; } pub trait VecZnxMulXpMinusOne { - fn vec_znx_mul_xp_minus_one(&self, p: i64, r: &mut R, r_col: usize, a: &A, a_col: usize) + fn vec_znx_mul_xp_minus_one(&self, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, A: VecZnxToRef; } -pub trait VecZnxMulXpMinusOneInplace { - fn vec_znx_mul_xp_minus_one_inplace(&self, p: i64, r: &mut R, r_col: usize) +pub trait VecZnxMulXpMinusOneInplaceTmpBytes { + fn vec_znx_mul_xp_minus_one_inplace_tmp_bytes(&self) -> usize; +} + +pub trait VecZnxMulXpMinusOneInplace { + fn vec_znx_mul_xp_minus_one_inplace(&self, p: i64, res: &mut R, res_col: usize, scratch: &mut Scratch) where R: VecZnxToMut; } -pub trait VecZnxSplit { +pub trait VecZnxSplitRingTmpBytes { + fn vec_znx_split_ring_tmp_bytes(&self) -> usize; +} + +pub trait VecZnxSplitRing { /// Splits the selected columns of `b` into subrings and copies them them into the selected column of `res`. /// /// # Panics /// /// This method requires that all [crate::layouts::VecZnx] of b have the same ring degree /// and that b.n() * b.len() <= a.n() - fn vec_znx_split(&self, res: &mut [R], res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) + fn vec_znx_split_ring(&self, res: &mut [R], res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) where R: VecZnxToMut, A: VecZnxToRef; } -pub trait VecZnxMerge { +pub trait VecZnxMergeRingsTmpBytes { + fn vec_znx_merge_rings_tmp_bytes(&self) -> usize; +} + +pub trait VecZnxMergeRings { /// Merges the subrings of the selected column of `a` into the selected column of `res`. 
/// /// # Panics /// /// This method requires that all [crate::layouts::VecZnx] of a have the same ring degree /// and that a.n() * a.len() <= b.n() - fn vec_znx_merge(&self, res: &mut R, res_col: usize, a: &[A], a_col: usize) + fn vec_znx_merge_rings(&self, res: &mut R, res_col: usize, a: &[A], a_col: usize, scratch: &mut Scratch) where R: VecZnxToMut, A: VecZnxToRef; } -pub trait VecZnxSwithcDegree { - fn vec_znx_switch_degree(&self, res: &mut R, res_col: usize, a: &A, col_a: usize) +pub trait VecZnxSwitchRing { + fn vec_znx_switch_ring(&self, res: &mut R, res_col: usize, a: &A, col_a: usize) where R: VecZnxToMut, A: VecZnxToRef; @@ -201,42 +265,11 @@ pub trait VecZnxCopy { pub trait VecZnxFillUniform { /// Fills the first `size` size with uniform values in \[-2^{basek-1}, 2^{basek-1}\] - fn vec_znx_fill_uniform(&self, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source) + fn vec_znx_fill_uniform(&self, basek: usize, res: &mut R, res_col: usize, source: &mut Source) where R: VecZnxToMut; } -#[allow(clippy::too_many_arguments)] -pub trait VecZnxFillDistF64 { - fn vec_znx_fill_dist_f64>( - &self, - basek: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - dist: D, - bound: f64, - ) where - R: VecZnxToMut; -} - -#[allow(clippy::too_many_arguments)] -pub trait VecZnxAddDistF64 { - /// Adds vector sampled according to the provided distribution, scaled by 2^{-k} and bounded to \[-bound, bound\]. - fn vec_znx_add_dist_f64>( - &self, - basek: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - dist: D, - bound: f64, - ) where - R: VecZnxToMut; -} - #[allow(clippy::too_many_arguments)] pub trait VecZnxFillNormal { fn vec_znx_fill_normal( diff --git a/poulpy-hal/src/api/vec_znx_big.rs b/poulpy-hal/src/api/vec_znx_big.rs index cefaaa4..09ff5b3 100644 --- a/poulpy-hal/src/api/vec_znx_big.rs +++ b/poulpy-hal/src/api/vec_znx_big.rs @@ -1,10 +1,15 @@ -use rand_distr::Distribution; - use crate::{ layouts::{Backend, Scratch, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef}, source::Source, }; +pub trait VecZnxBigFromSmall { + fn vec_znx_big_from_small(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef; +} + /// Allocates as [crate::layouts::VecZnxBig]. pub trait VecZnxBigAlloc { fn vec_znx_big_alloc(&self, cols: usize, size: usize) -> VecZnxBigOwned; @@ -45,48 +50,6 @@ pub trait VecZnxBigAddNormal { ); } -#[allow(clippy::too_many_arguments)] -pub trait VecZnxBigFillNormal { - fn vec_znx_big_fill_normal>( - &self, - basek: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - sigma: f64, - bound: f64, - ); -} - -#[allow(clippy::too_many_arguments)] -pub trait VecZnxBigFillDistF64 { - fn vec_znx_big_fill_dist_f64, D: Distribution>( - &self, - basek: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - dist: D, - bound: f64, - ); -} - -#[allow(clippy::too_many_arguments)] -pub trait VecZnxBigAddDistF64 { - fn vec_znx_big_add_dist_f64, D: Distribution>( - &self, - basek: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - dist: D, - bound: f64, - ); -} - pub trait VecZnxBigAdd { /// Adds `a` to `b` and stores the result on `c`. 
fn vec_znx_big_add(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize) @@ -180,10 +143,17 @@ pub trait VecZnxBigSubSmallBInplace { A: VecZnxToRef; } -pub trait VecZnxBigNegateInplace { - fn vec_znx_big_negate_inplace(&self, a: &mut A, a_col: usize) +pub trait VecZnxBigNegate { + fn vec_znx_big_negate(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) where - A: VecZnxBigToMut; + R: VecZnxBigToMut, + A: VecZnxBigToRef; +} + +pub trait VecZnxBigNegateInplace { + fn vec_znx_big_negate_inplace(&self, res: &mut R, res_col: usize) + where + R: VecZnxBigToMut; } pub trait VecZnxBigNormalizeTmpBytes { @@ -204,9 +174,13 @@ pub trait VecZnxBigNormalize { A: VecZnxBigToRef; } +pub trait VecZnxBigAutomorphismInplaceTmpBytes { + fn vec_znx_big_automorphism_inplace_tmp_bytes(&self) -> usize; +} + pub trait VecZnxBigAutomorphism { /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`. - fn vec_znx_big_automorphism(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_big_automorphism(&self, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxBigToMut, A: VecZnxBigToRef; @@ -214,7 +188,7 @@ pub trait VecZnxBigAutomorphism { pub trait VecZnxBigAutomorphismInplace { /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`. - fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) + fn vec_znx_big_automorphism_inplace(&self, p: i64, res: &mut R, res_col: usize, scratch: &mut Scratch) where - A: VecZnxBigToMut; + R: VecZnxBigToMut; } diff --git a/poulpy-hal/src/api/vec_znx_dft.rs b/poulpy-hal/src/api/vec_znx_dft.rs index 334db87..588ec53 100644 --- a/poulpy-hal/src/api/vec_znx_dft.rs +++ b/poulpy-hal/src/api/vec_znx_dft.rs @@ -14,33 +14,33 @@ pub trait VecZnxDftAllocBytes { fn vec_znx_dft_alloc_bytes(&self, cols: usize, size: usize) -> usize; } -pub trait DFT { - fn dft(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize) +pub trait VecZnxDftApply { + fn vec_znx_dft_apply(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxDftToMut, A: VecZnxToRef; } -pub trait VecZnxIDFTTmpBytes { - fn vec_znx_idft_tmp_bytes(&self) -> usize; +pub trait VecZnxIdftApplyTmpBytes { + fn vec_znx_idft_apply_tmp_bytes(&self) -> usize; } -pub trait IDFT { - fn idft(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) +pub trait VecZnxIdftApply { + fn vec_znx_idft_apply(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) where R: VecZnxBigToMut, A: VecZnxDftToRef; } -pub trait IDFTTmpA { - fn idft_tmp_a(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize) +pub trait VecZnxIdftApplyTmpA { + fn vec_znx_idft_apply_tmpa(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize) where R: VecZnxBigToMut, A: VecZnxDftToMut; } -pub trait IDFTConsume { - fn vec_znx_idft_consume(&self, a: VecZnxDft) -> VecZnxBig +pub trait VecZnxIdftApplyConsume { + fn vec_znx_idft_apply_consume(&self, a: VecZnxDft) -> VecZnxBig where VecZnxDft: VecZnxDftToMut; } diff --git a/poulpy-hal/src/api/vmp_pmat.rs b/poulpy-hal/src/api/vmp_pmat.rs index 972ceb1..3d0e248 100644 --- a/poulpy-hal/src/api/vmp_pmat.rs +++ b/poulpy-hal/src/api/vmp_pmat.rs @@ -1,4 +1,6 @@ -use crate::layouts::{Backend, MatZnxToRef, Scratch, VecZnxDftToMut, VecZnxDftToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef}; +use crate::layouts::{ + Backend, MatZnxToRef, Scratch, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, 
VmpPMatOwned, VmpPMatToMut, VmpPMatToRef, +}; pub trait VmpPMatAlloc { fn vmp_pmat_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned; @@ -17,12 +19,33 @@ pub trait VmpPrepareTmpBytes { } pub trait VmpPrepare { - fn vmp_prepare(&self, res: &mut R, a: &A, scratch: &mut Scratch) + fn vmp_prepare(&self, pmat: &mut R, mat: &A, scratch: &mut Scratch) where R: VmpPMatToMut, A: MatZnxToRef; } +#[allow(clippy::too_many_arguments)] +pub trait VmpApplyDftTmpBytes { + fn vmp_apply_dft_tmp_bytes( + &self, + res_size: usize, + a_size: usize, + b_rows: usize, + b_cols_in: usize, + b_cols_out: usize, + b_size: usize, + ) -> usize; +} + +pub trait VmpApplyDft { + fn vmp_apply_dft(&self, res: &mut R, a: &A, pmat: &C, scratch: &mut Scratch) + where + R: VecZnxDftToMut, + A: VecZnxToRef, + C: VmpPMatToRef; +} + #[allow(clippy::too_many_arguments)] pub trait VmpApplyDftToDftTmpBytes { fn vmp_apply_dft_to_dft_tmp_bytes( @@ -61,7 +84,7 @@ pub trait VmpApplyDftToDft { /// * `a`: the left operand [crate::layouts::VecZnxDft] of the vector matrix product. /// * `b`: the right operand [crate::layouts::VmpPMat] of the vector matrix product. /// * `buf`: scratch space, the size can be obtained with [VmpApplyTmpBytes::vmp_apply_tmp_bytes]. - fn vmp_apply_dft_to_dft(&self, res: &mut R, a: &A, b: &C, scratch: &mut Scratch) + fn vmp_apply_dft_to_dft(&self, res: &mut R, a: &A, pmat: &C, scratch: &mut Scratch) where R: VecZnxDftToMut, A: VecZnxDftToRef, @@ -82,7 +105,7 @@ pub trait VmpApplyDftToDftAddTmpBytes { } pub trait VmpApplyDftToDftAdd { - fn vmp_apply_dft_to_dft_add(&self, res: &mut R, a: &A, b: &C, scale: usize, scratch: &mut Scratch) + fn vmp_apply_dft_to_dft_add(&self, res: &mut R, a: &A, b: &C, limb_offset: usize, scratch: &mut Scratch) where R: VecZnxDftToMut, A: VecZnxDftToRef, diff --git a/poulpy-hal/src/api/zn.rs b/poulpy-hal/src/api/zn.rs index 60e8ac9..e7c4ef5 100644 --- a/poulpy-hal/src/api/zn.rs +++ b/poulpy-hal/src/api/zn.rs @@ -1,57 +1,29 @@ -use rand_distr::Distribution; - use crate::{ layouts::{Backend, Scratch, ZnToMut}, + reference::zn::zn_normalize_tmp_bytes, source::Source, }; +pub trait ZnNormalizeTmpBytes { + fn zn_normalize_tmp_bytes(&self, n: usize) -> usize { + zn_normalize_tmp_bytes(n) + } +} + pub trait ZnNormalizeInplace { /// Normalizes the selected column of `a`. - fn zn_normalize_inplace(&self, n: usize, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) + fn zn_normalize_inplace(&self, n: usize, basek: usize, res: &mut R, res_col: usize, scratch: &mut Scratch) where - A: ZnToMut; + R: ZnToMut; } pub trait ZnFillUniform { /// Fills the first `size` size with uniform values in \[-2^{basek-1}, 2^{basek-1}\] - fn zn_fill_uniform(&self, n: usize, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source) + fn zn_fill_uniform(&self, n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source) where R: ZnToMut; } -#[allow(clippy::too_many_arguments)] -pub trait ZnFillDistF64 { - fn zn_fill_dist_f64>( - &self, - n: usize, - basek: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - dist: D, - bound: f64, - ) where - R: ZnToMut; -} - -#[allow(clippy::too_many_arguments)] -pub trait ZnAddDistF64 { - /// Adds vector sampled according to the provided distribution, scaled by 2^{-k} and bounded to \[-bound, bound\]. 
- fn zn_add_dist_f64>( - &self, - n: usize, - basek: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - dist: D, - bound: f64, - ) where - R: ZnToMut; -} - #[allow(clippy::too_many_arguments)] pub trait ZnFillNormal { fn zn_fill_normal( diff --git a/poulpy-hal/src/bench_suite/mod.rs b/poulpy-hal/src/bench_suite/mod.rs new file mode 100644 index 0000000..57b60b4 --- /dev/null +++ b/poulpy-hal/src/bench_suite/mod.rs @@ -0,0 +1,5 @@ +pub mod svp; +pub mod vec_znx; +pub mod vec_znx_big; +pub mod vec_znx_dft; +pub mod vmp; diff --git a/poulpy-hal/src/bench_suite/svp.rs b/poulpy-hal/src/bench_suite/svp.rs new file mode 100644 index 0000000..7007d9a --- /dev/null +++ b/poulpy-hal/src/bench_suite/svp.rs @@ -0,0 +1,237 @@ +use std::hint::black_box; + +use criterion::{BenchmarkId, Criterion}; +use rand::RngCore; + +use crate::{ + api::{ + ModuleNew, SvpApplyDft, SvpApplyDftToDft, SvpApplyDftToDftAdd, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPrepare, + VecZnxDftAlloc, + }, + layouts::{Backend, DataViewMut, FillUniform, Module, ScalarZnx, SvpPPol, VecZnx, VecZnxDft}, + source::Source, +}; + +pub fn bench_svp_prepare(c: &mut Criterion, label: &str) +where + Module: SvpPrepare + SvpPPolAlloc + ModuleNew, + B: Backend, +{ + let group_name: String = format!("svp_prepare::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(log_n: usize) -> impl FnMut() + where + Module: SvpPrepare + SvpPPolAlloc + ModuleNew, + B: Backend, + { + let module: Module = Module::::new(1 << log_n); + + let cols: usize = 2; + + let mut svp: SvpPPol, B> = module.svp_ppol_alloc(cols); + let mut a: ScalarZnx> = ScalarZnx::alloc(module.n(), cols); + let mut source = Source::new([0u8; 32]); + a.fill_uniform(50, &mut source); + + move || { + module.svp_prepare(&mut svp, 0, &a, 0); + black_box(()); + } + } + + for log_n in [10, 11, 12, 13, 14] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}", 1 << log_n)); + let mut runner = runner::(log_n); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_svp_apply_dft(c: &mut Criterion, label: &str) +where + Module: SvpApplyDft + SvpPPolAlloc + ModuleNew + VecZnxDftAlloc, + B: Backend, +{ + let group_name: String = format!("svp_apply_dft::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: SvpApplyDft + SvpPPolAlloc + ModuleNew + VecZnxDftAlloc, + B: Backend, + { + let n: usize = 1 << params[0]; + let cols: usize = params[1]; + let size: usize = params[2]; + + let module: Module = Module::::new(n as u64); + + let mut svp: SvpPPol, B> = module.svp_ppol_alloc(cols); + let mut res: VecZnxDft, B> = module.vec_znx_dft_alloc(cols, size); + let mut a: VecZnx> = VecZnx::alloc(n, cols, size); + + let mut source = Source::new([0u8; 32]); + + source.fill_bytes(svp.data_mut()); + source.fill_bytes(res.data_mut()); + source.fill_bytes(a.data_mut()); + + move || { + for j in 0..cols { + module.svp_apply_dft(&mut res, j, &svp, j, &a, j); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 7], [13, 2, 15], [14, 2, 31]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2])); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_svp_apply_dft_to_dft(c: &mut Criterion, label: &str) +where + Module: SvpApplyDftToDft + SvpPPolAlloc + ModuleNew + 
VecZnxDftAlloc, + B: Backend, +{ + let group_name: String = format!("svp_apply_dft_to_dft::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: SvpApplyDftToDft + SvpPPolAlloc + ModuleNew + VecZnxDftAlloc, + B: Backend, + { + let n: usize = 1 << params[0]; + let cols: usize = params[1]; + let size: usize = params[2]; + + let module: Module = Module::::new(n as u64); + + let mut svp: SvpPPol, B> = module.svp_ppol_alloc(cols); + let mut res: VecZnxDft, B> = module.vec_znx_dft_alloc(cols, size); + let mut a: VecZnxDft, B> = module.vec_znx_dft_alloc(cols, size); + + let mut source = Source::new([0u8; 32]); + + source.fill_bytes(svp.data_mut()); + source.fill_bytes(res.data_mut()); + source.fill_bytes(a.data_mut()); + + move || { + for j in 0..cols { + module.svp_apply_dft_to_dft(&mut res, j, &svp, j, &a, j); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 7], [13, 2, 15], [14, 2, 31]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2])); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_svp_apply_dft_to_dft_add(c: &mut Criterion, label: &str) +where + Module: SvpApplyDftToDftAdd + SvpPPolAlloc + ModuleNew + VecZnxDftAlloc, + B: Backend, +{ + let group_name: String = format!("svp_apply_dft_to_dft_add::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: SvpApplyDftToDftAdd + SvpPPolAlloc + ModuleNew + VecZnxDftAlloc, + B: Backend, + { + let n: usize = 1 << params[0]; + let cols: usize = params[1]; + let size: usize = params[2]; + + let module: Module = Module::::new(n as u64); + + let mut svp: SvpPPol, B> = module.svp_ppol_alloc(cols); + let mut res: VecZnxDft, B> = module.vec_znx_dft_alloc(cols, size); + let mut a: VecZnxDft, B> = module.vec_znx_dft_alloc(cols, size); + + let mut source = Source::new([0u8; 32]); + + source.fill_bytes(svp.data_mut()); + source.fill_bytes(res.data_mut()); + source.fill_bytes(a.data_mut()); + + move || { + for j in 0..cols { + module.svp_apply_dft_to_dft_add(&mut res, j, &svp, j, &a, j); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 7], [13, 2, 15], [14, 2, 31]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2])); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_svp_apply_dft_to_dft_inplace(c: &mut Criterion, label: &str) +where + Module: SvpApplyDftToDftInplace + SvpPPolAlloc + ModuleNew + VecZnxDftAlloc, + B: Backend, +{ + let group_name: String = format!("svp_apply_dft_to_dft_inplace::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: SvpApplyDftToDftInplace + SvpPPolAlloc + ModuleNew + VecZnxDftAlloc, + B: Backend, + { + let n: usize = 1 << params[0]; + let cols: usize = params[1]; + let size: usize = params[2]; + + let module: Module = Module::::new(n as u64); + + let mut svp: SvpPPol, B> = module.svp_ppol_alloc(cols); + let mut res: VecZnxDft, B> = module.vec_znx_dft_alloc(cols, size); + + let mut source = Source::new([0u8; 32]); + + source.fill_bytes(svp.data_mut()); + source.fill_bytes(res.data_mut()); + + move || { + for j in 0..cols { + 
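+                // Multiply column j of `res` by the prepared scalar polynomial in the DFT domain, in place.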
module.svp_apply_dft_to_dft_inplace(&mut res, j, &svp, j); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 7], [13, 2, 15], [14, 2, 31]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2])); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} diff --git a/poulpy-hal/src/bench_suite/vec_znx.rs b/poulpy-hal/src/bench_suite/vec_znx.rs new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/poulpy-hal/src/bench_suite/vec_znx.rs @@ -0,0 +1 @@ + diff --git a/poulpy-hal/src/bench_suite/vec_znx_big.rs b/poulpy-hal/src/bench_suite/vec_znx_big.rs new file mode 100644 index 0000000..2f05b35 --- /dev/null +++ b/poulpy-hal/src/bench_suite/vec_znx_big.rs @@ -0,0 +1,641 @@ +use std::hint::black_box; + +use criterion::{BenchmarkId, Criterion}; +use rand::RngCore; + +use crate::{ + api::{ + ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxBigAdd, VecZnxBigAddInplace, VecZnxBigAddSmall, + VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAutomorphism, VecZnxBigAutomorphismInplace, + VecZnxBigAutomorphismInplaceTmpBytes, VecZnxBigNegate, VecZnxBigNegateInplace, VecZnxBigNormalize, + VecZnxBigNormalizeTmpBytes, VecZnxBigSub, VecZnxBigSubABInplace, VecZnxBigSubBAInplace, VecZnxBigSubSmallA, + VecZnxBigSubSmallB, + }, + layouts::{Backend, DataViewMut, Module, ScratchOwned, VecZnx, VecZnxBig}, + source::Source, +}; + +pub fn bench_vec_znx_big_add(c: &mut Criterion, label: &str) +where + Module: VecZnxBigAdd + ModuleNew + VecZnxBigAlloc, +{ + let group_name: String = format!("vec_znx_big_add::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: VecZnxBigAdd + ModuleNew + VecZnxBigAlloc, + { + let n: usize = params[0]; + let cols: usize = params[1]; + let size: usize = params[2]; + + let module: Module = Module::::new(n as u64); + + let mut source: Source = Source::new([0u8; 32]); + + let mut a: VecZnxBig, B> = module.vec_znx_big_alloc(cols, size); + let mut b: VecZnxBig, B> = module.vec_znx_big_alloc(cols, size); + let mut c: VecZnxBig, B> = module.vec_znx_big_alloc(cols, size); + + // Fill a with random bytes + source.fill_bytes(a.data_mut()); + source.fill_bytes(b.data_mut()); + source.fill_bytes(c.data_mut()); + + move || { + for i in 0..cols { + module.vec_znx_big_add(&mut c, i, &a, i, &b, i); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],)); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_vec_znx_big_add_inplace(c: &mut Criterion, label: &str) +where + Module: VecZnxBigAddInplace + ModuleNew + VecZnxBigAlloc, +{ + let group_name: String = format!("vec_znx_big_add_inplace::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: VecZnxBigAddInplace + ModuleNew + VecZnxBigAlloc, + { + let n: usize = params[0]; + let cols: usize = params[1]; + let size: usize = params[2]; + + let module: Module = Module::::new(n as u64); + + let mut source: Source = Source::new([0u8; 32]); + + let mut a: VecZnxBig, B> = module.vec_znx_big_alloc(cols, size); + let mut c: VecZnxBig, B> = module.vec_znx_big_alloc(cols, size); + + 
// Fill a with random i64 + source.fill_bytes(a.data_mut()); + source.fill_bytes(c.data_mut()); + + move || { + for i in 0..cols { + module.vec_znx_big_add_inplace(&mut c, i, &a, i); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],)); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_vec_znx_big_add_small(c: &mut Criterion, label: &str) +where + Module: VecZnxBigAddSmall + ModuleNew + VecZnxBigAlloc, +{ + let group_name: String = format!("vec_znx_big_add_small::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: VecZnxBigAddSmall + ModuleNew + VecZnxBigAlloc, + { + let n: usize = params[0]; + let cols: usize = params[1]; + let size: usize = params[2]; + + let module: Module = Module::::new(n as u64); + + let mut source: Source = Source::new([0u8; 32]); + + let mut a: VecZnxBig, B> = module.vec_znx_big_alloc(cols, size); + let mut b: VecZnx> = VecZnx::alloc(n, cols, size); + let mut c: VecZnxBig, B> = module.vec_znx_big_alloc(cols, size); + + // Fill a with random i64 + source.fill_bytes(a.data_mut()); + source.fill_bytes(b.data_mut()); + source.fill_bytes(c.data_mut()); + + move || { + for i in 0..cols { + module.vec_znx_big_add_small(&mut c, i, &a, i, &b, i); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],)); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_vec_znx_big_add_small_inplace(c: &mut Criterion, label: &str) +where + Module: VecZnxBigAddSmallInplace + ModuleNew + VecZnxBigAlloc, +{ + let group_name: String = format!("vec_znx_big_add_small_inplace::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: VecZnxBigAddSmallInplace + ModuleNew + VecZnxBigAlloc, + { + let n: usize = params[0]; + let cols: usize = params[1]; + let size: usize = params[2]; + + let module: Module = Module::::new(n as u64); + + let mut source: Source = Source::new([0u8; 32]); + + let mut a: VecZnx> = VecZnx::alloc(n, cols, size); + let mut c: VecZnxBig, B> = module.vec_znx_big_alloc(cols, size); + + // Fill a with random i64 + source.fill_bytes(a.data_mut()); + source.fill_bytes(c.data_mut()); + + move || { + for i in 0..cols { + module.vec_znx_big_add_small_inplace(&mut c, i, &a, i); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],)); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_vec_znx_big_automorphism(c: &mut Criterion, label: &str) +where + Module: VecZnxBigAutomorphism + ModuleNew + VecZnxBigAlloc, +{ + let group_name: String = format!("vec_znx_big_automorphism::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: VecZnxBigAutomorphism + ModuleNew + VecZnxBigAlloc, + { + let 
n: usize = params[0]; + let cols: usize = params[1]; + let size: usize = params[2]; + + let module: Module = Module::::new(n as u64); + + let mut source: Source = Source::new([0u8; 32]); + + let mut a: VecZnxBig, B> = module.vec_znx_big_alloc(cols, size); + let mut res: VecZnxBig, B> = module.vec_znx_big_alloc(cols, size); + + // Fill a with random i64 + source.fill_bytes(a.data_mut()); + source.fill_bytes(res.data_mut()); + + move || { + for i in 0..cols { + module.vec_znx_big_automorphism(-7, &mut res, i, &a, i); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],)); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_vec_znx_automorphism_inplace(c: &mut Criterion, label: &str) +where + Module: VecZnxBigAutomorphismInplace + VecZnxBigAutomorphismInplaceTmpBytes + ModuleNew + VecZnxBigAlloc, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + let group_name: String = format!("vec_znx_automorphism_inplace::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: VecZnxBigAutomorphismInplace + ModuleNew + VecZnxBigAutomorphismInplaceTmpBytes + VecZnxBigAlloc, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + { + let n: usize = params[0]; + let cols: usize = params[1]; + let size: usize = params[2]; + + let module: Module = Module::::new(n as u64); + + let mut source: Source = Source::new([0u8; 32]); + + let mut res: VecZnxBig, B> = module.vec_znx_big_alloc(cols, size); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(module.vec_znx_big_automorphism_inplace_tmp_bytes()); + + // Fill a with random i64 + source.fill_bytes(res.data_mut()); + + move || { + for i in 0..cols { + module.vec_znx_big_automorphism_inplace(-7, &mut res, i, scratch.borrow()); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],)); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_vec_znx_big_negate(c: &mut Criterion, label: &str) +where + Module: VecZnxBigNegate + ModuleNew + VecZnxBigAlloc, +{ + let group_name: String = format!("vec_znx_big_negate::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: VecZnxBigNegate + ModuleNew + VecZnxBigAlloc, + { + let module: Module = Module::::new(1 << params[0]); + + let cols: usize = params[1]; + let size: usize = params[2]; + + let mut source: Source = Source::new([0u8; 32]); + + let mut a: VecZnxBig, B> = module.vec_znx_big_alloc(cols, size); + let mut b: VecZnxBig, B> = module.vec_znx_big_alloc(cols, size); + + // Fill a with random i64 + source.fill_bytes(a.data_mut()); + source.fill_bytes(b.data_mut()); + + move || { + for i in 0..cols { + module.vec_znx_big_negate(&mut b, i, &a, i); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],)); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| 
b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_vec_znx_big_negate_inplace(c: &mut Criterion, label: &str) +where + Module: VecZnxBigNegateInplace + ModuleNew + VecZnxBigAlloc, +{ + let group_name: String = format!("vec_znx_negate_big_inplace::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: VecZnxBigNegateInplace + ModuleNew + VecZnxBigAlloc, + { + let module: Module = Module::::new(1 << params[0]); + + let cols: usize = params[1]; + let size: usize = params[2]; + + let mut source: Source = Source::new([0u8; 32]); + + let mut a: VecZnxBig, B> = module.vec_znx_big_alloc(cols, size); + + // Fill a with random i64 + source.fill_bytes(a.data_mut()); + + move || { + for i in 0..cols { + module.vec_znx_big_negate_inplace(&mut a, i); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2])); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_vec_znx_normalize(c: &mut Criterion, label: &str) +where + Module: VecZnxBigNormalize + ModuleNew + VecZnxBigNormalizeTmpBytes + VecZnxBigAlloc, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + let group_name: String = format!("vec_znx_big_normalize::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: VecZnxBigNormalize + ModuleNew + VecZnxBigNormalizeTmpBytes + VecZnxBigAlloc, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + { + let n: usize = params[0]; + let cols: usize = params[1]; + let size: usize = params[2]; + + let module: Module = Module::::new(n as u64); + + let basek: usize = 50; + + let mut source: Source = Source::new([0u8; 32]); + + let mut a: VecZnxBig, B> = module.vec_znx_big_alloc(cols, size); + let mut res: VecZnx> = VecZnx::alloc(module.n(), cols, size); + + // Fill a with random i64 + source.fill_bytes(a.data_mut()); + source.fill_bytes(res.data_mut()); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(module.vec_znx_big_normalize_tmp_bytes()); + + move || { + for i in 0..cols { + module.vec_znx_big_normalize(basek, &mut res, i, &a, i, scratch.borrow()); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],)); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_vec_znx_big_sub(c: &mut Criterion, label: &str) +where + Module: VecZnxBigSub + ModuleNew + VecZnxBigAlloc, +{ + let group_name: String = format!("vec_znx_big_sub::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: VecZnxBigSub + ModuleNew + VecZnxBigAlloc, + { + let module: Module = Module::::new(1 << params[0]); + + let cols: usize = params[1]; + let size: usize = params[2]; + + let mut source: Source = Source::new([0u8; 32]); + + let mut a: VecZnxBig, B> = module.vec_znx_big_alloc(cols, size); + let mut b: VecZnxBig, B> = module.vec_znx_big_alloc(cols, size); + let mut c: VecZnxBig, B> = module.vec_znx_big_alloc(cols, size); + + // Fill a with random bytes + 
source.fill_bytes(a.data_mut()); + source.fill_bytes(b.data_mut()); + source.fill_bytes(c.data_mut()); + + move || { + for i in 0..cols { + module.vec_znx_big_sub(&mut c, i, &a, i, &b, i); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],)); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_vec_znx_big_sub_ab_inplace(c: &mut Criterion, label: &str) +where + Module: VecZnxBigSubABInplace + ModuleNew + VecZnxBigAlloc, +{ + let group_name: String = format!("vec_znx_big_sub_inplace::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: VecZnxBigSubABInplace + ModuleNew + VecZnxBigAlloc, + { + let module: Module = Module::::new(1 << params[0]); + + let cols: usize = params[1]; + let size: usize = params[2]; + + let mut source: Source = Source::new([0u8; 32]); + + let mut a: VecZnxBig, B> = module.vec_znx_big_alloc(cols, size); + let mut c: VecZnxBig, B> = module.vec_znx_big_alloc(cols, size); + + // Fill a with random bytes + source.fill_bytes(a.data_mut()); + source.fill_bytes(c.data_mut()); + + move || { + for i in 0..cols { + module.vec_znx_big_sub_ab_inplace(&mut c, i, &a, i); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],)); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_vec_znx_big_sub_ba_inplace(c: &mut Criterion, label: &str) +where + Module: VecZnxBigSubBAInplace + ModuleNew + VecZnxBigAlloc, +{ + let group_name: String = format!("vec_znx_big_sub_inplace::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: VecZnxBigSubBAInplace + ModuleNew + VecZnxBigAlloc, + { + let module: Module = Module::::new(1 << params[0]); + + let cols: usize = params[1]; + let size: usize = params[2]; + + let mut source: Source = Source::new([0u8; 32]); + + let mut a: VecZnxBig, B> = module.vec_znx_big_alloc(cols, size); + let mut c: VecZnxBig, B> = module.vec_znx_big_alloc(cols, size); + + // Fill a with random bytes + source.fill_bytes(a.data_mut()); + source.fill_bytes(c.data_mut()); + + move || { + for i in 0..cols { + module.vec_znx_big_sub_ba_inplace(&mut c, i, &a, i); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],)); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_vec_znx_big_sub_small_a(c: &mut Criterion, label: &str) +where + Module: VecZnxBigSubSmallA + ModuleNew + VecZnxBigAlloc, +{ + let group_name: String = format!("vec_znx_big_sub_small_a::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: VecZnxBigSubSmallA + ModuleNew + VecZnxBigAlloc, + { + let module: Module = Module::::new(1 << params[0]); + + let cols: usize = params[1]; + let size: usize = params[2]; + + let 
mut source: Source = Source::new([0u8; 32]); + + let mut a: VecZnx> = VecZnx::alloc(module.n(), cols, size); + let mut b: VecZnxBig, B> = module.vec_znx_big_alloc(cols, size); + let mut c: VecZnxBig, B> = module.vec_znx_big_alloc(cols, size); + + // Fill a with random bytes + source.fill_bytes(a.data_mut()); + source.fill_bytes(b.data_mut()); + source.fill_bytes(c.data_mut()); + + move || { + for i in 0..cols { + module.vec_znx_big_sub_small_a(&mut c, i, &a, i, &b, i); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],)); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_vec_znx_big_sub_small_b(c: &mut Criterion, label: &str) +where + Module: VecZnxBigSubSmallB + ModuleNew + VecZnxBigAlloc, +{ + let group_name: String = format!("vec_znx_big_sub_small_b::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: VecZnxBigSubSmallB + ModuleNew + VecZnxBigAlloc, + { + let module: Module = Module::::new(1 << params[0]); + + let cols: usize = params[1]; + let size: usize = params[2]; + + let mut source: Source = Source::new([0u8; 32]); + + let mut a: VecZnxBig, B> = module.vec_znx_big_alloc(cols, size); + let mut b: VecZnx> = VecZnx::alloc(module.n(), cols, size); + let mut c: VecZnxBig, B> = module.vec_znx_big_alloc(cols, size); + + // Fill a with random bytes + source.fill_bytes(a.data_mut()); + source.fill_bytes(b.data_mut()); + source.fill_bytes(c.data_mut()); + + move || { + for i in 0..cols { + module.vec_znx_big_sub_small_b(&mut c, i, &a, i, &b, i); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],)); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} diff --git a/poulpy-hal/src/bench_suite/vec_znx_dft.rs b/poulpy-hal/src/bench_suite/vec_znx_dft.rs new file mode 100644 index 0000000..ac2f758 --- /dev/null +++ b/poulpy-hal/src/bench_suite/vec_znx_dft.rs @@ -0,0 +1,365 @@ +use std::hint::black_box; + +use criterion::{BenchmarkId, Criterion}; +use rand::RngCore; + +use crate::{ + api::{ + ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxBigAlloc, VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAlloc, + VecZnxDftApply, VecZnxDftSub, VecZnxDftSubABInplace, VecZnxDftSubBAInplace, VecZnxIdftApply, VecZnxIdftApplyTmpA, + VecZnxIdftApplyTmpBytes, + }, + layouts::{Backend, DataViewMut, Module, ScratchOwned, VecZnx, VecZnxBig, VecZnxDft}, + source::Source, +}; + +pub fn bench_vec_znx_dft_add(c: &mut Criterion, label: &str) +where + Module: VecZnxDftAdd + ModuleNew + VecZnxDftAlloc, +{ + let group_name: String = format!("vec_znx_dft_add::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: VecZnxDftAdd + ModuleNew + VecZnxDftAlloc, + { + let n: usize = params[0]; + let cols: usize = params[1]; + let size: usize = params[2]; + + let module: Module = Module::::new(n as u64); + + let mut source: Source = Source::new([0u8; 32]); + + let mut a: VecZnxDft, B> = module.vec_znx_dft_alloc(cols, size); + let mut b: VecZnxDft, B> = 
module.vec_znx_dft_alloc(cols, size); + let mut c: VecZnxDft, B> = module.vec_znx_dft_alloc(cols, size); + + source.fill_bytes(a.data_mut()); + source.fill_bytes(b.data_mut()); + source.fill_bytes(c.data_mut()); + + move || { + for i in 0..cols { + module.vec_znx_dft_add(&mut c, i, &a, i, &b, i); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],)); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_vec_znx_dft_add_inplace(c: &mut Criterion, label: &str) +where + Module: VecZnxDftAddInplace + ModuleNew + VecZnxDftAlloc, +{ + let group_name: String = format!("vec_znx_dft_add_inplace::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: VecZnxDftAddInplace + ModuleNew + VecZnxDftAlloc, + { + let n: usize = params[0]; + let cols: usize = params[1]; + let size: usize = params[2]; + + let module: Module = Module::::new(n as u64); + + let mut source: Source = Source::new([0u8; 32]); + + let mut a: VecZnxDft, B> = module.vec_znx_dft_alloc(cols, size); + let mut c: VecZnxDft, B> = module.vec_znx_dft_alloc(cols, size); + + // Fill a with random i64 + source.fill_bytes(a.data_mut()); + source.fill_bytes(c.data_mut()); + + move || { + for i in 0..cols { + module.vec_znx_dft_add_inplace(&mut c, i, &a, i); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],)); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_vec_znx_dft_apply(c: &mut Criterion, label: &str) +where + Module: VecZnxDftApply + ModuleNew + VecZnxDftAlloc, +{ + let group_name: String = format!("vec_znx_dft_apply::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: VecZnxDftApply + ModuleNew + VecZnxDftAlloc, + { + let n: usize = params[0]; + let cols: usize = params[1]; + let size: usize = params[2]; + + let module: Module = Module::::new(n as u64); + + let mut source: Source = Source::new([0u8; 32]); + + let mut res: VecZnxDft, B> = module.vec_znx_dft_alloc(cols, size); + let mut a: VecZnx> = VecZnx::alloc(n, cols, size); + source.fill_bytes(res.data_mut()); + source.fill_bytes(a.data_mut()); + + move || { + for i in 0..cols { + module.vec_znx_dft_apply(1, 0, &mut res, i, &a, i); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],)); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_vec_znx_idft_apply(c: &mut Criterion, label: &str) +where + Module: VecZnxIdftApply + ModuleNew + VecZnxIdftApplyTmpBytes + VecZnxDftAlloc + VecZnxBigAlloc, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + let group_name: String = format!("vec_znx_idft_apply::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: VecZnxIdftApply + 
ModuleNew + VecZnxIdftApplyTmpBytes + VecZnxDftAlloc + VecZnxBigAlloc, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + { + let n: usize = params[0]; + let cols: usize = params[1]; + let size: usize = params[2]; + + let module: Module = Module::::new(n as u64); + + let mut source: Source = Source::new([0u8; 32]); + + let mut res: VecZnxBig, B> = module.vec_znx_big_alloc(cols, size); + let mut a: VecZnxDft, B> = module.vec_znx_dft_alloc(cols, size); + source.fill_bytes(res.data_mut()); + source.fill_bytes(a.data_mut()); + + let mut scratch = ScratchOwned::alloc(module.vec_znx_idft_apply_tmp_bytes()); + + move || { + for i in 0..cols { + module.vec_znx_idft_apply(&mut res, i, &a, i, scratch.borrow()); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],)); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_vec_znx_idft_apply_tmpa(c: &mut Criterion, label: &str) +where + Module: VecZnxIdftApplyTmpA + ModuleNew + VecZnxDftAlloc + VecZnxBigAlloc, +{ + let group_name: String = format!("vec_znx_idft_apply_tmpa::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: VecZnxIdftApplyTmpA + ModuleNew + VecZnxDftAlloc + VecZnxBigAlloc, + { + let module: Module = Module::::new(1 << params[0]); + + let cols: usize = params[1]; + let size: usize = params[2]; + + let mut source: Source = Source::new([0u8; 32]); + + let mut res: VecZnxBig, B> = module.vec_znx_big_alloc(cols, size); + let mut a: VecZnxDft, B> = module.vec_znx_dft_alloc(cols, size); + source.fill_bytes(res.data_mut()); + source.fill_bytes(a.data_mut()); + + move || { + for i in 0..cols { + module.vec_znx_idft_apply_tmpa(&mut res, i, &mut a, i); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],)); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_vec_znx_dft_sub(c: &mut Criterion, label: &str) +where + Module: VecZnxDftSub + ModuleNew + VecZnxDftAlloc, +{ + let group_name: String = format!("vec_znx_dft_sub::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: VecZnxDftSub + ModuleNew + VecZnxDftAlloc, + { + let n: usize = params[0]; + let cols: usize = params[1]; + let size: usize = params[2]; + + let module: Module = Module::::new(n as u64); + + let mut source: Source = Source::new([0u8; 32]); + + let mut a: VecZnxDft, B> = module.vec_znx_dft_alloc(cols, size); + let mut b: VecZnxDft, B> = module.vec_znx_dft_alloc(cols, size); + let mut c: VecZnxDft, B> = module.vec_znx_dft_alloc(cols, size); + + source.fill_bytes(a.data_mut()); + source.fill_bytes(b.data_mut()); + source.fill_bytes(c.data_mut()); + + move || { + for i in 0..cols { + module.vec_znx_dft_sub(&mut c, i, &a, i, &b, i); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],)); + let mut runner = runner::(params); + 
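+        // The runner closure is built once per parameter set, so allocation and RNG seeding stay outside the timed iterations below.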
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_vec_znx_dft_sub_ab_inplace(c: &mut Criterion, label: &str) +where + Module: VecZnxDftSubABInplace + ModuleNew + VecZnxDftAlloc, +{ + let group_name: String = format!("vec_znx_dft_sub_ab_inplace::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: VecZnxDftSubABInplace + ModuleNew + VecZnxDftAlloc, + { + let n: usize = params[0]; + let cols: usize = params[1]; + let size: usize = params[2]; + + let module: Module = Module::::new(n as u64); + + let mut source: Source = Source::new([0u8; 32]); + + let mut a: VecZnxDft, B> = module.vec_znx_dft_alloc(cols, size); + let mut c: VecZnxDft, B> = module.vec_znx_dft_alloc(cols, size); + + // Fill a with random i64 + source.fill_bytes(a.data_mut()); + source.fill_bytes(c.data_mut()); + + move || { + for i in 0..cols { + module.vec_znx_dft_sub_ab_inplace(&mut c, i, &a, i); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],)); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_vec_znx_dft_sub_ba_inplace(c: &mut Criterion, label: &str) +where + Module: VecZnxDftSubBAInplace + ModuleNew + VecZnxDftAlloc, +{ + let group_name: String = format!("vec_znx_dft_sub_ba_inplace::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: VecZnxDftSubBAInplace + ModuleNew + VecZnxDftAlloc, + { + let n: usize = params[0]; + let cols: usize = params[1]; + let size: usize = params[2]; + + let module: Module = Module::::new(n as u64); + + let mut source: Source = Source::new([0u8; 32]); + + let mut a: VecZnxDft, B> = module.vec_znx_dft_alloc(cols, size); + let mut c: VecZnxDft, B> = module.vec_znx_dft_alloc(cols, size); + + // Fill a with random i64 + source.fill_bytes(a.data_mut()); + source.fill_bytes(c.data_mut()); + + move || { + for i in 0..cols { + module.vec_znx_dft_sub_ba_inplace(&mut c, i, &a, i); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],)); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} diff --git a/poulpy-hal/src/bench_suite/vmp.rs b/poulpy-hal/src/bench_suite/vmp.rs new file mode 100644 index 0000000..0fa2ff9 --- /dev/null +++ b/poulpy-hal/src/bench_suite/vmp.rs @@ -0,0 +1,259 @@ +use std::hint::black_box; + +use criterion::{BenchmarkId, Criterion}; +use rand::RngCore; + +use crate::{ + api::{ + ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxDftAlloc, VmpApplyDft, VmpApplyDftTmpBytes, VmpApplyDftToDft, + VmpApplyDftToDftAdd, VmpApplyDftToDftAddTmpBytes, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, VmpPrepareTmpBytes, + }, + layouts::{Backend, DataViewMut, MatZnx, Module, ScratchOwned, VecZnx, VecZnxDft, VmpPMat}, + source::Source, +}; + +pub fn bench_vmp_prepare(c: &mut Criterion, label: &str) +where + Module: ModuleNew + VmpPMatAlloc + VmpPrepare + VmpPrepareTmpBytes, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + let group_name: String = 
format!("vmp_prepare::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 5]) -> impl FnMut() + where + Module: ModuleNew + VmpPMatAlloc + VmpPrepare + VmpPrepareTmpBytes, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + { + let module: Module = Module::::new(1 << params[0]); + + let rows: usize = params[1]; + let cols_in: usize = params[2]; + let cols_out: usize = params[3]; + let size: usize = params[4]; + + let mut source: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(module.vmp_prepare_tmp_bytes(rows, cols_in, cols_out, size)); + + let mut mat: MatZnx> = MatZnx::alloc(module.n(), rows, cols_in, cols_out, size); + let mut pmat: VmpPMat, B> = module.vmp_pmat_alloc(rows, cols_in, cols_out, size); + + source.fill_bytes(mat.data_mut()); + source.fill_bytes(pmat.data_mut()); + + move || { + module.vmp_prepare(&mut pmat, &mat, scratch.borrow()); + black_box(()); + } + } + + for params in [ + [10, 2, 1, 2, 3], + [11, 4, 1, 2, 5], + [12, 7, 1, 2, 8], + [13, 15, 1, 2, 16], + [14, 31, 1, 2, 32], + ] { + let id = BenchmarkId::from_parameter(format!( + "{}x({}x{})x({}x{})", + 1 << params[0], + params[2], + params[1], + params[3], + params[4] + )); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_vmp_apply_dft(c: &mut Criterion, label: &str) +where + Module: ModuleNew + VmpApplyDftTmpBytes + VmpApplyDft + VmpPMatAlloc + VecZnxDftAlloc, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + let group_name: String = format!("vmp_apply_dft::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 5]) -> impl FnMut() + where + Module: ModuleNew + VmpApplyDftTmpBytes + VmpApplyDft + VmpPMatAlloc + VecZnxDftAlloc, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + { + let module: Module = Module::::new(1 << params[0]); + + let rows: usize = params[1]; + let cols_in: usize = params[2]; + let cols_out: usize = params[3]; + let size: usize = params[4]; + + let mut source: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(1 << 20); + + let mut res: VecZnxDft, _> = module.vec_znx_dft_alloc(cols_out, size); + let mut a: VecZnx> = VecZnx::alloc(module.n(), cols_in, size); + let mut pmat: VmpPMat, B> = module.vmp_pmat_alloc(rows, cols_in, cols_out, size); + + source.fill_bytes(pmat.data_mut()); + source.fill_bytes(res.data_mut()); + source.fill_bytes(a.data_mut()); + + move || { + module.vmp_apply_dft(&mut res, &a, &pmat, scratch.borrow()); + black_box(()); + } + } + + for params in [ + [10, 2, 1, 2, 3], + [11, 4, 1, 2, 5], + [12, 7, 1, 2, 8], + [13, 15, 1, 2, 16], + [14, 31, 1, 2, 32], + ] { + let id = BenchmarkId::from_parameter(format!( + "{}x({}x{})x({}x{})", + 1 << params[0], + params[2], + params[1], + params[3], + params[4] + )); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_vmp_apply_dft_to_dft(c: &mut Criterion, label: &str) +where + Module: ModuleNew + VecZnxDftAlloc + VmpPMatAlloc + VmpApplyDftToDft + VmpApplyDftToDftTmpBytes, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + let group_name: String = format!("vmp_apply_dft_to_dft::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 5]) -> impl FnMut() + where + Module: ModuleNew + VecZnxDftAlloc + VmpPMatAlloc + VmpApplyDftToDft + 
VmpApplyDftToDftTmpBytes, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + { + let module: Module = Module::::new(1 << params[0]); + + let rows: usize = params[1]; + let cols_in: usize = params[2]; + let cols_out: usize = params[3]; + let size: usize = params[4]; + + let mut source: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = + ScratchOwned::alloc(module.vmp_apply_dft_to_dft_tmp_bytes(size, size, rows, cols_in, cols_out, size)); + + let mut res: VecZnxDft, _> = module.vec_znx_dft_alloc(cols_out, size); + let mut a: VecZnxDft, _> = module.vec_znx_dft_alloc(cols_in, size); + + let mut pmat: VmpPMat, B> = module.vmp_pmat_alloc(rows, cols_in, cols_out, size); + + source.fill_bytes(pmat.data_mut()); + source.fill_bytes(res.data_mut()); + source.fill_bytes(a.data_mut()); + + move || { + module.vmp_apply_dft_to_dft(&mut res, &a, &pmat, scratch.borrow()); + black_box(()); + } + } + + for params in [ + [10, 2, 1, 2, 3], + [11, 4, 1, 2, 5], + [12, 7, 1, 2, 8], + [13, 15, 1, 2, 16], + [14, 31, 1, 2, 32], + ] { + let id = BenchmarkId::from_parameter(format!( + "{}x({}x{})x({}x{})", + 1 << params[0], // n + params[2], // cols_in + params[1], // size_in (=rows) + params[3], // cols_out + params[4] // size_out + )); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_vmp_apply_dft_to_dft_add(c: &mut Criterion, label: &str) +where + Module: ModuleNew + VecZnxDftAlloc + VmpPMatAlloc + VmpApplyDftToDftAdd + VmpApplyDftToDftAddTmpBytes, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + let group_name: String = format!("vmp_apply_dft_to_dft_add::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 5]) -> impl FnMut() + where + Module: ModuleNew + VecZnxDftAlloc + VmpPMatAlloc + VmpApplyDftToDftAdd + VmpApplyDftToDftAddTmpBytes, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + { + let module: Module = Module::::new(1 << params[0]); + + let rows: usize = params[1]; + let cols_in: usize = params[2]; + let cols_out: usize = params[3]; + let size: usize = params[4]; + + let mut source: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = + ScratchOwned::alloc(module.vmp_apply_dft_to_dft_add_tmp_bytes(size, size, rows, cols_in, cols_out, size)); + + let mut res: VecZnxDft, _> = module.vec_znx_dft_alloc(cols_out, size); + let mut a: VecZnxDft, _> = module.vec_znx_dft_alloc(cols_in, size); + + let mut pmat: VmpPMat, B> = module.vmp_pmat_alloc(rows, cols_in, cols_out, size); + + source.fill_bytes(pmat.data_mut()); + source.fill_bytes(res.data_mut()); + source.fill_bytes(a.data_mut()); + + move || { + module.vmp_apply_dft_to_dft_add(&mut res, &a, &pmat, 1, scratch.borrow()); + black_box(()); + } + } + + for params in [ + [10, 2, 1, 2, 3], + [11, 4, 1, 2, 5], + [12, 7, 1, 2, 8], + [13, 15, 1, 2, 16], + [14, 31, 1, 2, 32], + ] { + let id = BenchmarkId::from_parameter(format!( + "{}x({}x{})x({}x{})", + 1 << params[0], + params[2], + params[1], + params[3], + params[4] + )); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} diff --git a/poulpy-hal/src/delegates/svp_ppol.rs b/poulpy-hal/src/delegates/svp_ppol.rs index 86e4cc0..54a99b2 100644 --- a/poulpy-hal/src/delegates/svp_ppol.rs +++ b/poulpy-hal/src/delegates/svp_ppol.rs @@ -1,7 +1,15 @@ use crate::{ - api::{SvpApply, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPPolFromBytes, 
SvpPrepare}, - layouts::{Backend, Module, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef}, - oep::{SvpApplyImpl, SvpApplyInplaceImpl, SvpPPolAllocBytesImpl, SvpPPolAllocImpl, SvpPPolFromBytesImpl, SvpPrepareImpl}, + api::{ + SvpApplyDft, SvpApplyDftToDft, SvpApplyDftToDftAdd, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, + SvpPPolFromBytes, SvpPrepare, + }, + layouts::{ + Backend, Module, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, + }, + oep::{ + SvpApplyDftImpl, SvpApplyDftToDftAddImpl, SvpApplyDftToDftImpl, SvpApplyDftToDftInplaceImpl, SvpPPolAllocBytesImpl, + SvpPPolAllocImpl, SvpPPolFromBytesImpl, SvpPrepareImpl, + }, }; impl SvpPPolFromBytes for Module @@ -44,29 +52,57 @@ where } } -impl SvpApply for Module +impl SvpApplyDft for Module where - B: Backend + SvpApplyImpl, + B: Backend + SvpApplyDftImpl, { - fn svp_apply(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize) + fn svp_apply_dft(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize) + where + R: VecZnxDftToMut, + A: SvpPPolToRef, + C: VecZnxToRef, + { + B::svp_apply_dft_impl(self, res, res_col, a, a_col, b, b_col); + } +} + +impl SvpApplyDftToDft for Module +where + B: Backend + SvpApplyDftToDftImpl, +{ + fn svp_apply_dft_to_dft(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize) where R: VecZnxDftToMut, A: SvpPPolToRef, C: VecZnxDftToRef, { - B::svp_apply_impl(self, res, res_col, a, a_col, b, b_col); + B::svp_apply_dft_to_dft_impl(self, res, res_col, a, a_col, b, b_col); } } -impl SvpApplyInplace for Module +impl SvpApplyDftToDftAdd for Module where - B: Backend + SvpApplyInplaceImpl, + B: Backend + SvpApplyDftToDftAddImpl, { - fn svp_apply_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn svp_apply_dft_to_dft_add(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize) + where + R: VecZnxDftToMut, + A: SvpPPolToRef, + C: VecZnxDftToRef, + { + B::svp_apply_dft_to_dft_add_impl(self, res, res_col, a, a_col, b, b_col); + } +} + +impl SvpApplyDftToDftInplace for Module +where + B: Backend + SvpApplyDftToDftInplaceImpl, +{ + fn svp_apply_dft_to_dft_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxDftToMut, A: SvpPPolToRef, { - B::svp_apply_inplace_impl(self, res, res_col, a, a_col); + B::svp_apply_dft_to_dft_inplace_impl(self, res, res_col, a, a_col); } } diff --git a/poulpy-hal/src/delegates/vec_znx.rs b/poulpy-hal/src/delegates/vec_znx.rs index e1e3b92..a5cd36f 100644 --- a/poulpy-hal/src/delegates/vec_znx.rs +++ b/poulpy-hal/src/delegates/vec_znx.rs @@ -1,19 +1,24 @@ use crate::{ api::{ - VecZnxAdd, VecZnxAddDistF64, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, - VecZnxAutomorphismInplace, VecZnxCopy, VecZnxFillDistF64, VecZnxFillNormal, VecZnxFillUniform, VecZnxLshInplace, - VecZnxMerge, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegate, VecZnxNegateInplace, VecZnxNormalize, - VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSplit, - VecZnxSub, VecZnxSubABInplace, VecZnxSubBAInplace, VecZnxSubScalarInplace, VecZnxSwithcDegree, + VecZnxAdd, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalar, VecZnxAddScalarInplace, VecZnxAutomorphism, + VecZnxAutomorphismInplace, VecZnxAutomorphismInplaceTmpBytes, VecZnxCopy, VecZnxFillNormal, VecZnxFillUniform, VecZnxLsh, + 
VecZnxLshInplace, VecZnxLshTmpBytes, VecZnxMergeRings, VecZnxMergeRingsTmpBytes, VecZnxMulXpMinusOne, + VecZnxMulXpMinusOneInplace, VecZnxMulXpMinusOneInplaceTmpBytes, VecZnxNegate, VecZnxNegateInplace, VecZnxNormalize, + VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, + VecZnxRsh, VecZnxRshInplace, VecZnxRshTmpBytes, VecZnxSplitRing, VecZnxSplitRingTmpBytes, VecZnxSub, VecZnxSubABInplace, + VecZnxSubBAInplace, VecZnxSubScalar, VecZnxSubScalarInplace, VecZnxSwitchRing, }, layouts::{Backend, Module, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef}, oep::{ - VecZnxAddDistF64Impl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, VecZnxAddScalarInplaceImpl, - VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl, VecZnxFillDistF64Impl, VecZnxFillNormalImpl, - VecZnxFillUniformImpl, VecZnxLshInplaceImpl, VecZnxMergeImpl, VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl, - VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl, VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, - VecZnxRotateImpl, VecZnxRotateInplaceImpl, VecZnxRshInplaceImpl, VecZnxSplitImpl, VecZnxSubABInplaceImpl, - VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl, + VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, VecZnxAddScalarImpl, VecZnxAddScalarInplaceImpl, + VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxAutomorphismInplaceTmpBytesImpl, VecZnxCopyImpl, + VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxLshImpl, VecZnxLshInplaceImpl, VecZnxLshTmpBytesImpl, + VecZnxMergeRingsImpl, VecZnxMergeRingsTmpBytesImpl, VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl, + VecZnxMulXpMinusOneInplaceTmpBytesImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl, + VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl, + VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl, + VecZnxSplitRingTmpBytesImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarImpl, + VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl, }, source::Source, }; @@ -79,6 +84,20 @@ where } } +impl VecZnxAddScalar for Module +where + B: Backend + VecZnxAddScalarImpl, +{ + fn vec_znx_add_scalar(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize, b_limb: usize) + where + R: VecZnxToMut, + A: ScalarZnxToRef, + D: VecZnxToRef, + { + B::vec_znx_add_scalar_impl(self, res, res_col, a, a_col, b, b_col, b_limb) + } +} + impl VecZnxAddScalarInplace for Module where B: Backend + VecZnxAddScalarInplaceImpl, @@ -132,6 +151,20 @@ where } } +impl VecZnxSubScalar for Module +where + B: Backend + VecZnxSubScalarImpl, +{ + fn vec_znx_sub_scalar(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize, b_limb: usize) + where + R: VecZnxToMut, + A: ScalarZnxToRef, + D: VecZnxToRef, + { + B::vec_znx_sub_scalar_impl(self, res, res_col, a, a_col, b, b_col, b_limb) + } +} + impl VecZnxSubScalarInplace for Module where B: Backend + VecZnxSubScalarInplaceImpl, @@ -170,27 +203,87 @@ where } } -impl VecZnxLshInplace for Module +impl VecZnxRshTmpBytes for Module where - B: Backend + VecZnxLshInplaceImpl, + B: Backend + VecZnxRshTmpBytesImpl, { - fn vec_znx_lsh_inplace(&self, basek: usize, k: usize, a: &mut A) - where - A: VecZnxToMut, - { - B::vec_znx_lsh_inplace_impl(self, basek, k, a) + fn vec_znx_rsh_tmp_bytes(&self) -> 
usize { + B::vec_znx_rsh_tmp_bytes_impl(self) } } -impl VecZnxRshInplace for Module +impl VecZnxLshTmpBytes for Module where - B: Backend + VecZnxRshInplaceImpl, + B: Backend + VecZnxLshTmpBytesImpl, { - fn vec_znx_rsh_inplace(&self, basek: usize, k: usize, a: &mut A) + fn vec_znx_lsh_tmp_bytes(&self) -> usize { + B::vec_znx_lsh_tmp_bytes_impl(self) + } +} + +impl VecZnxLsh for Module +where + B: Backend + VecZnxLshImpl, +{ + fn vec_znx_lsh( + &self, + basek: usize, + k: usize, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxToMut, + A: VecZnxToRef, + { + B::vec_znx_lsh_inplace_impl(self, basek, k, res, res_col, a, a_col, scratch); + } +} + +impl VecZnxRsh for Module +where + B: Backend + VecZnxRshImpl, +{ + fn vec_znx_rsh( + &self, + basek: usize, + k: usize, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxToMut, + A: VecZnxToRef, + { + B::vec_znx_rsh_inplace_impl(self, basek, k, res, res_col, a, a_col, scratch); + } +} + +impl VecZnxLshInplace for Module +where + B: Backend + VecZnxLshInplaceImpl, +{ + fn vec_znx_lsh_inplace(&self, basek: usize, k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) where A: VecZnxToMut, { - B::vec_znx_rsh_inplace_impl(self, basek, k, a) + B::vec_znx_lsh_inplace_impl(self, basek, k, a, a_col, scratch) + } +} + +impl VecZnxRshInplace for Module +where + B: Backend + VecZnxRshInplaceImpl, +{ + fn vec_znx_rsh_inplace(&self, basek: usize, k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) + where + A: VecZnxToMut, + { + B::vec_znx_rsh_inplace_impl(self, basek, k, a, a_col, scratch) } } @@ -207,15 +300,24 @@ where } } -impl VecZnxRotateInplace for Module +impl VecZnxRotateInplaceTmpBytes for Module +where + B: Backend + VecZnxRotateInplaceTmpBytesImpl, +{ + fn vec_znx_rotate_inplace_tmp_bytes(&self) -> usize { + B::vec_znx_rotate_inplace_tmp_bytes_impl(self) + } +} + +impl VecZnxRotateInplace for Module where B: Backend + VecZnxRotateInplaceImpl, { - fn vec_znx_rotate_inplace(&self, k: i64, a: &mut A, a_col: usize) + fn vec_znx_rotate_inplace(&self, k: i64, a: &mut A, a_col: usize, scratch: &mut Scratch) where A: VecZnxToMut, { - B::vec_znx_rotate_inplace_impl(self, k, a, a_col) + B::vec_znx_rotate_inplace_impl(self, k, a, a_col, scratch) } } @@ -232,15 +334,24 @@ where } } -impl VecZnxAutomorphismInplace for Module +impl VecZnxAutomorphismInplaceTmpBytes for Module +where + B: Backend + VecZnxAutomorphismInplaceTmpBytesImpl, +{ + fn vec_znx_automorphism_inplace_tmp_bytes(&self) -> usize { + B::vec_znx_automorphism_inplace_tmp_bytes_impl(self) + } +} + +impl VecZnxAutomorphismInplace for Module where B: Backend + VecZnxAutomorphismInplaceImpl, { - fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) + fn vec_znx_automorphism_inplace(&self, k: i64, res: &mut R, res_col: usize, scratch: &mut Scratch) where - A: VecZnxToMut, + R: VecZnxToMut, { - B::vec_znx_automorphism_inplace_impl(self, k, a, a_col) + B::vec_znx_automorphism_inplace_impl(self, k, res, res_col, scratch) } } @@ -257,54 +368,81 @@ where } } -impl VecZnxMulXpMinusOneInplace for Module +impl VecZnxMulXpMinusOneInplaceTmpBytes for Module +where + B: Backend + VecZnxMulXpMinusOneInplaceTmpBytesImpl, +{ + fn vec_znx_mul_xp_minus_one_inplace_tmp_bytes(&self) -> usize { + B::vec_znx_mul_xp_minus_one_inplace_tmp_bytes_impl(self) + } +} + +impl VecZnxMulXpMinusOneInplace for Module where B: Backend + VecZnxMulXpMinusOneInplaceImpl, { - fn 
vec_znx_mul_xp_minus_one_inplace(&self, p: i64, res: &mut R, res_col: usize) + fn vec_znx_mul_xp_minus_one_inplace(&self, p: i64, res: &mut R, res_col: usize, scratch: &mut Scratch) where R: VecZnxToMut, { - B::vec_znx_mul_xp_minus_one_inplace_impl(self, p, res, res_col); + B::vec_znx_mul_xp_minus_one_inplace_impl(self, p, res, res_col, scratch); } } -impl VecZnxSplit for Module +impl VecZnxSplitRingTmpBytes for Module where - B: Backend + VecZnxSplitImpl, + B: Backend + VecZnxSplitRingTmpBytesImpl, { - fn vec_znx_split(&self, res: &mut [R], res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) + fn vec_znx_split_ring_tmp_bytes(&self) -> usize { + B::vec_znx_split_ring_tmp_bytes_impl(self) + } +} + +impl VecZnxSplitRing for Module +where + B: Backend + VecZnxSplitRingImpl, +{ + fn vec_znx_split_ring(&self, res: &mut [R], res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) where R: VecZnxToMut, A: VecZnxToRef, { - B::vec_znx_split_impl(self, res, res_col, a, a_col, scratch) + B::vec_znx_split_ring_impl(self, res, res_col, a, a_col, scratch) } } -impl VecZnxMerge for Module +impl VecZnxMergeRingsTmpBytes for Module where - B: Backend + VecZnxMergeImpl, + B: Backend + VecZnxMergeRingsTmpBytesImpl, { - fn vec_znx_merge(&self, res: &mut R, res_col: usize, a: &[A], a_col: usize) + fn vec_znx_merge_rings_tmp_bytes(&self) -> usize { + B::vec_znx_merge_rings_tmp_bytes_impl(self) + } +} + +impl VecZnxMergeRings for Module +where + B: Backend + VecZnxMergeRingsImpl, +{ + fn vec_znx_merge_rings(&self, res: &mut R, res_col: usize, a: &[A], a_col: usize, scratch: &mut Scratch) where R: VecZnxToMut, A: VecZnxToRef, { - B::vec_znx_merge_impl(self, res, res_col, a, a_col) + B::vec_znx_merge_rings_impl(self, res, res_col, a, a_col, scratch) } } -impl VecZnxSwithcDegree for Module +impl VecZnxSwitchRing for Module where - B: Backend + VecZnxSwithcDegreeImpl, + B: Backend + VecZnxSwitchRingImpl, { - fn vec_znx_switch_degree(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_switch_ring(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, A: VecZnxToRef, { - B::vec_znx_switch_degree_impl(self, res, res_col, a, a_col) + B::vec_znx_switch_ring_impl(self, res, res_col, a, a_col) } } @@ -325,51 +463,11 @@ impl VecZnxFillUniform for Module where B: Backend + VecZnxFillUniformImpl, { - fn vec_znx_fill_uniform(&self, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source) + fn vec_znx_fill_uniform(&self, basek: usize, res: &mut R, res_col: usize, source: &mut Source) where R: VecZnxToMut, { - B::vec_znx_fill_uniform_impl(self, basek, res, res_col, k, source); - } -} - -impl VecZnxFillDistF64 for Module -where - B: Backend + VecZnxFillDistF64Impl, -{ - fn vec_znx_fill_dist_f64>( - &self, - basek: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - dist: D, - bound: f64, - ) where - R: VecZnxToMut, - { - B::vec_znx_fill_dist_f64_impl(self, basek, res, res_col, k, source, dist, bound); - } -} - -impl VecZnxAddDistF64 for Module -where - B: Backend + VecZnxAddDistF64Impl, -{ - fn vec_znx_add_dist_f64>( - &self, - basek: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - dist: D, - bound: f64, - ) where - R: VecZnxToMut, - { - B::vec_znx_add_dist_f64_impl(self, basek, res, res_col, k, source, dist, bound); + B::vec_znx_fill_uniform_impl(self, basek, res, res_col, source); } } diff --git a/poulpy-hal/src/delegates/vec_znx_big.rs b/poulpy-hal/src/delegates/vec_znx_big.rs index 953949f..1d0f8f1 
100644 --- a/poulpy-hal/src/delegates/vec_znx_big.rs +++ b/poulpy-hal/src/delegates/vec_znx_big.rs @@ -1,18 +1,16 @@ -use rand_distr::Distribution; - use crate::{ api::{ - VecZnxBigAdd, VecZnxBigAddDistF64, VecZnxBigAddInplace, VecZnxBigAddNormal, VecZnxBigAddSmall, VecZnxBigAddSmallInplace, - VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigAutomorphism, VecZnxBigAutomorphismInplace, VecZnxBigFillDistF64, - VecZnxBigFillNormal, VecZnxBigFromBytes, VecZnxBigNegateInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, - VecZnxBigSub, VecZnxBigSubABInplace, VecZnxBigSubBAInplace, VecZnxBigSubSmallA, VecZnxBigSubSmallAInplace, - VecZnxBigSubSmallB, VecZnxBigSubSmallBInplace, + VecZnxBigAdd, VecZnxBigAddInplace, VecZnxBigAddNormal, VecZnxBigAddSmall, VecZnxBigAddSmallInplace, VecZnxBigAlloc, + VecZnxBigAllocBytes, VecZnxBigAutomorphism, VecZnxBigAutomorphismInplace, VecZnxBigAutomorphismInplaceTmpBytes, + VecZnxBigFromBytes, VecZnxBigFromSmall, VecZnxBigNegate, VecZnxBigNegateInplace, VecZnxBigNormalize, + VecZnxBigNormalizeTmpBytes, VecZnxBigSub, VecZnxBigSubABInplace, VecZnxBigSubBAInplace, VecZnxBigSubSmallA, + VecZnxBigSubSmallAInplace, VecZnxBigSubSmallB, VecZnxBigSubSmallBInplace, }, layouts::{Backend, Module, Scratch, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef}, oep::{ - VecZnxBigAddDistF64Impl, VecZnxBigAddImpl, VecZnxBigAddInplaceImpl, VecZnxBigAddNormalImpl, VecZnxBigAddSmallImpl, - VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl, - VecZnxBigAutomorphismInplaceImpl, VecZnxBigFillDistF64Impl, VecZnxBigFillNormalImpl, VecZnxBigFromBytesImpl, + VecZnxBigAddImpl, VecZnxBigAddInplaceImpl, VecZnxBigAddNormalImpl, VecZnxBigAddSmallImpl, VecZnxBigAddSmallInplaceImpl, + VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl, VecZnxBigAutomorphismInplaceImpl, + VecZnxBigAutomorphismInplaceTmpBytesImpl, VecZnxBigFromBytesImpl, VecZnxBigFromSmallImpl, VecZnxBigNegateImpl, VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl, VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubABInplaceImpl, VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl, VecZnxBigSubSmallAImpl, VecZnxBigSubSmallAInplaceImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl, @@ -20,6 +18,19 @@ use crate::{ source::Source, }; +impl VecZnxBigFromSmall for Module +where + B: Backend + VecZnxBigFromSmallImpl, +{ + fn vec_znx_big_from_small(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef, + { + B::vec_znx_big_from_small_impl(res, res_col, a, a_col); + } +} + impl VecZnxBigAlloc for Module where B: Backend + VecZnxBigAllocImpl, @@ -47,24 +58,6 @@ where } } -impl VecZnxBigAddDistF64 for Module -where - B: Backend + VecZnxBigAddDistF64Impl, -{ - fn vec_znx_big_add_dist_f64, D: Distribution>( - &self, - basek: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - dist: D, - bound: f64, - ) { - B::add_dist_f64_impl(self, basek, res, res_col, k, source, dist, bound); - } -} - impl VecZnxBigAddNormal for Module where B: Backend + VecZnxBigAddNormalImpl, @@ -83,42 +76,6 @@ where } } -impl VecZnxBigFillDistF64 for Module -where - B: Backend + VecZnxBigFillDistF64Impl, -{ - fn vec_znx_big_fill_dist_f64, D: Distribution>( - &self, - basek: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - dist: D, - bound: f64, - ) { - B::fill_dist_f64_impl(self, basek, res, res_col, k, source, dist, bound); - } -} - -impl VecZnxBigFillNormal for Module -where - B: 
Backend + VecZnxBigFillNormalImpl, -{ - fn vec_znx_big_fill_normal>( - &self, - basek: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - sigma: f64, - bound: f64, - ) { - B::fill_normal_impl(self, basek, res, res_col, k, source, sigma, bound); - } -} - impl VecZnxBigAdd for Module where B: Backend + VecZnxBigAddImpl, @@ -267,6 +224,19 @@ where } } +impl VecZnxBigNegate for Module +where + B: Backend + VecZnxBigNegateImpl, +{ + fn vec_znx_big_negate(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + { + B::vec_znx_big_negate_impl(self, res, res_col, a, a_col); + } +} + impl VecZnxBigNegateInplace for Module where B: Backend + VecZnxBigNegateInplaceImpl, @@ -321,14 +291,23 @@ where } } +impl VecZnxBigAutomorphismInplaceTmpBytes for Module +where + B: Backend + VecZnxBigAutomorphismInplaceTmpBytesImpl, +{ + fn vec_znx_big_automorphism_inplace_tmp_bytes(&self) -> usize { + B::vec_znx_big_automorphism_inplace_tmp_bytes_impl(self) + } +} + impl VecZnxBigAutomorphismInplace for Module where B: Backend + VecZnxBigAutomorphismInplaceImpl, { - fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) + fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize, scratch: &mut Scratch) where A: VecZnxBigToMut, { - B::vec_znx_big_automorphism_inplace_impl(self, k, a, a_col); + B::vec_znx_big_automorphism_inplace_impl(self, k, a, a_col, scratch); } } diff --git a/poulpy-hal/src/delegates/vec_znx_dft.rs b/poulpy-hal/src/delegates/vec_znx_dft.rs index 28a9fdf..b486b08 100644 --- a/poulpy-hal/src/delegates/vec_znx_dft.rs +++ b/poulpy-hal/src/delegates/vec_znx_dft.rs @@ -1,16 +1,17 @@ use crate::{ api::{ - DFT, IDFT, IDFTConsume, IDFTTmpA, VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftCopy, - VecZnxDftFromBytes, VecZnxDftSub, VecZnxDftSubABInplace, VecZnxDftSubBAInplace, VecZnxDftZero, VecZnxIDFTTmpBytes, + VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy, + VecZnxDftFromBytes, VecZnxDftSub, VecZnxDftSubABInplace, VecZnxDftSubBAInplace, VecZnxDftZero, VecZnxIdftApply, + VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxIdftApplyTmpBytes, }, layouts::{ Backend, Data, Module, Scratch, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, }, oep::{ - DFTImpl, IDFTConsumeImpl, IDFTImpl, IDFTTmpAImpl, VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, - VecZnxDftAllocImpl, VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl, - VecZnxDftSubImpl, VecZnxDftZeroImpl, VecZnxIDFTTmpBytesImpl, + VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftApplyImpl, + VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl, VecZnxDftSubImpl, + VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl, VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl, }, }; @@ -41,63 +42,63 @@ where } } -impl VecZnxIDFTTmpBytes for Module +impl VecZnxIdftApplyTmpBytes for Module where - B: Backend + VecZnxIDFTTmpBytesImpl, + B: Backend + VecZnxIdftApplyTmpBytesImpl, { - fn vec_znx_idft_tmp_bytes(&self) -> usize { - B::vec_znx_idft_tmp_bytes_impl(self) + fn vec_znx_idft_apply_tmp_bytes(&self) -> usize { + B::vec_znx_idft_apply_tmp_bytes_impl(self) } } -impl IDFT for Module +impl VecZnxIdftApply for Module where - B: Backend + IDFTImpl, + B: 
Backend + VecZnxIdftApplyImpl, { - fn idft(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) + fn vec_znx_idft_apply(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) where R: VecZnxBigToMut, A: VecZnxDftToRef, { - B::idft_impl(self, res, res_col, a, a_col, scratch); + B::vec_znx_idft_apply_impl(self, res, res_col, a, a_col, scratch); } } -impl IDFTTmpA for Module +impl VecZnxIdftApplyTmpA for Module where - B: Backend + IDFTTmpAImpl, + B: Backend + VecZnxIdftApplyTmpAImpl, { - fn idft_tmp_a(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize) + fn vec_znx_idft_apply_tmpa(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize) where R: VecZnxBigToMut, A: VecZnxDftToMut, { - B::idft_tmp_a_impl(self, res, res_col, a, a_col); + B::vec_znx_idft_apply_tmpa_impl(self, res, res_col, a, a_col); } } -impl IDFTConsume for Module +impl VecZnxIdftApplyConsume for Module where - B: Backend + IDFTConsumeImpl, + B: Backend + VecZnxIdftApplyConsumeImpl, { - fn vec_znx_idft_consume(&self, a: VecZnxDft) -> VecZnxBig + fn vec_znx_idft_apply_consume(&self, a: VecZnxDft) -> VecZnxBig where VecZnxDft: VecZnxDftToMut, { - B::idft_consume_impl(self, a) + B::vec_znx_idft_apply_consume_impl(self, a) } } -impl DFT for Module +impl VecZnxDftApply for Module where - B: Backend + DFTImpl, + B: Backend + VecZnxDftApplyImpl, { - fn dft(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_dft_apply(&self, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxDftToMut, A: VecZnxToRef, { - B::dft_impl(self, step, offset, res, res_col, a, a_col); + B::vec_znx_dft_apply_impl(self, step, offset, res, res_col, a, a_col); } } diff --git a/poulpy-hal/src/delegates/vmp_pmat.rs b/poulpy-hal/src/delegates/vmp_pmat.rs index 33cbc15..a875a40 100644 --- a/poulpy-hal/src/delegates/vmp_pmat.rs +++ b/poulpy-hal/src/delegates/vmp_pmat.rs @@ -1,12 +1,16 @@ use crate::{ api::{ - VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftAddTmpBytes, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, - VmpPMatAllocBytes, VmpPMatFromBytes, VmpPrepare, VmpPrepareTmpBytes, + VmpApplyDft, VmpApplyDftTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftAddTmpBytes, + VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPMatAllocBytes, VmpPMatFromBytes, VmpPrepare, VmpPrepareTmpBytes, + }, + layouts::{ + Backend, MatZnxToRef, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, VmpPMatOwned, VmpPMatToMut, + VmpPMatToRef, }, - layouts::{Backend, MatZnxToRef, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef}, oep::{ - VmpApplyDftToDftAddImpl, VmpApplyDftToDftAddTmpBytesImpl, VmpApplyDftToDftImpl, VmpApplyDftToDftTmpBytesImpl, - VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPMatFromBytesImpl, VmpPMatPrepareImpl, VmpPrepareTmpBytesImpl, + VmpApplyDftImpl, VmpApplyDftTmpBytesImpl, VmpApplyDftToDftAddImpl, VmpApplyDftToDftAddTmpBytesImpl, VmpApplyDftToDftImpl, + VmpApplyDftToDftTmpBytesImpl, VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPMatFromBytesImpl, VmpPrepareImpl, + VmpPrepareTmpBytesImpl, }, }; @@ -48,7 +52,7 @@ where impl VmpPrepare for Module where - B: Backend + VmpPMatPrepareImpl, + B: Backend + VmpPrepareImpl, { fn vmp_prepare(&self, res: &mut R, a: &A, scratch: &mut Scratch) where @@ -59,6 +63,39 @@ where } } +impl VmpApplyDftTmpBytes for Module +where + B: Backend + VmpApplyDftTmpBytesImpl, +{ + fn vmp_apply_dft_tmp_bytes( + &self, + res_size: usize, + 
a_size: usize, + b_rows: usize, + b_cols_in: usize, + b_cols_out: usize, + b_size: usize, + ) -> usize { + B::vmp_apply_dft_tmp_bytes_impl( + self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size, + ) + } +} + +impl VmpApplyDft for Module +where + B: Backend + VmpApplyDftImpl, +{ + fn vmp_apply_dft(&self, res: &mut R, a: &A, b: &C, scratch: &mut Scratch) + where + R: VecZnxDftToMut, + A: VecZnxToRef, + C: VmpPMatToRef, + { + B::vmp_apply_dft_impl(self, res, a, b, scratch); + } +} + impl VmpApplyDftToDftTmpBytes for Module where B: Backend + VmpApplyDftToDftTmpBytesImpl, diff --git a/poulpy-hal/src/delegates/zn.rs b/poulpy-hal/src/delegates/zn.rs index e5311bb..450bdc9 100644 --- a/poulpy-hal/src/delegates/zn.rs +++ b/poulpy-hal/src/delegates/zn.rs @@ -1,10 +1,19 @@ use crate::{ - api::{ZnAddDistF64, ZnAddNormal, ZnFillDistF64, ZnFillNormal, ZnFillUniform, ZnNormalizeInplace}, + api::{ZnAddNormal, ZnFillNormal, ZnFillUniform, ZnNormalizeInplace, ZnNormalizeTmpBytes}, layouts::{Backend, Module, Scratch, ZnToMut}, - oep::{ZnAddDistF64Impl, ZnAddNormalImpl, ZnFillDistF64Impl, ZnFillNormalImpl, ZnFillUniformImpl, ZnNormalizeInplaceImpl}, + oep::{ZnAddNormalImpl, ZnFillNormalImpl, ZnFillUniformImpl, ZnNormalizeInplaceImpl, ZnNormalizeTmpBytesImpl}, source::Source, }; +impl ZnNormalizeTmpBytes for Module +where + B: Backend + ZnNormalizeTmpBytesImpl, +{ + fn zn_normalize_tmp_bytes(&self, n: usize) -> usize { + B::zn_normalize_tmp_bytes_impl(n) + } +} + impl ZnNormalizeInplace for Module where B: Backend + ZnNormalizeInplaceImpl, @@ -21,53 +30,11 @@ impl ZnFillUniform for Module where B: Backend + ZnFillUniformImpl, { - fn zn_fill_uniform(&self, n: usize, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source) + fn zn_fill_uniform(&self, n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source) where R: ZnToMut, { - B::zn_fill_uniform_impl(n, basek, res, res_col, k, source); - } -} - -impl ZnFillDistF64 for Module -where - B: Backend + ZnFillDistF64Impl, -{ - fn zn_fill_dist_f64>( - &self, - n: usize, - basek: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - dist: D, - bound: f64, - ) where - R: ZnToMut, - { - B::zn_fill_dist_f64_impl(n, basek, res, res_col, k, source, dist, bound); - } -} - -impl ZnAddDistF64 for Module -where - B: Backend + ZnAddDistF64Impl, -{ - fn zn_add_dist_f64>( - &self, - n: usize, - basek: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - dist: D, - bound: f64, - ) where - R: ZnToMut, - { - B::zn_add_dist_f64_impl(n, basek, res, res_col, k, source, dist, bound); + B::zn_fill_uniform_impl(n, basek, res, res_col, source); } } diff --git a/poulpy-hal/src/layouts/encoding.rs b/poulpy-hal/src/layouts/encoding.rs index b79002c..957d835 100644 --- a/poulpy-hal/src/layouts/encoding.rs +++ b/poulpy-hal/src/layouts/encoding.rs @@ -1,8 +1,9 @@ use itertools::izip; use rug::{Assign, Float}; -use crate::layouts::{ - DataMut, DataRef, VecZnx, VecZnxToMut, VecZnxToRef, Zn, ZnToMut, ZnToRef, ZnxInfos, ZnxView, ZnxViewMut, ZnxZero, +use crate::{ + layouts::{DataMut, DataRef, VecZnx, VecZnxToMut, VecZnxToRef, Zn, ZnToMut, ZnToRef, ZnxInfos, ZnxView, ZnxViewMut}, + reference::znx::znx_zero_ref, }; impl VecZnx { @@ -28,7 +29,7 @@ impl VecZnx { // Zeroes coefficients of the i-th column (0..a.size()).for_each(|i| { - a.zero_at(col, i); + znx_zero_ref(a.at_mut(col, i)); }); // If 2^{basek} * 2^{k_rem} < 2^{63}-1, then we can simply copy @@ -183,7 +184,7 @@ impl VecZnx { let prec: u32 = (basek * 
size) as u32; // 2^{basek} - let base = Float::with_val(prec, (1 << basek) as f64); + let base = Float::with_val(prec, (1u64 << basek) as f64); // y[i] = sum x[j][i] * 2^{-basek*j} (0..size).for_each(|i| { diff --git a/poulpy-hal/src/layouts/mat_znx.rs b/poulpy-hal/src/layouts/mat_znx.rs index c0ec62c..59a1c4c 100644 --- a/poulpy-hal/src/layouts/mat_znx.rs +++ b/poulpy-hal/src/layouts/mat_znx.rs @@ -1,17 +1,21 @@ use crate::{ alloc_aligned, layouts::{ - Data, DataMut, DataRef, DataView, DataViewMut, FillUniform, ReaderFrom, Reset, ToOwnedDeep, VecZnx, WriterTo, ZnxInfos, - ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero, + Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, FillUniform, ReaderFrom, Reset, ToOwnedDeep, VecZnx, WriterTo, + ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero, }, source::Source, }; -use std::fmt; +use std::{ + fmt, + hash::{DefaultHasher, Hasher}, +}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use rand::RngCore; -#[derive(PartialEq, Eq, Clone)] +#[repr(C)] +#[derive(PartialEq, Eq, Clone, Hash)] pub struct MatZnx { data: D, n: usize, @@ -21,6 +25,19 @@ pub struct MatZnx { cols_out: usize, } +impl DigestU64 for MatZnx { + fn digest_u64(&self) -> u64 { + let mut h: DefaultHasher = DefaultHasher::new(); + h.write(self.data.as_ref()); + h.write_usize(self.n); + h.write_usize(self.size); + h.write_usize(self.rows); + h.write_usize(self.cols_in); + h.write_usize(self.cols_out); + h.finish() + } +} + impl ToOwnedDeep for MatZnx { type Owned = MatZnx>; fn to_owned_deep(&self) -> Self::Owned { @@ -57,6 +74,10 @@ impl ZnxInfos for MatZnx { fn size(&self) -> usize { self.size } + + fn poly_count(&self) -> usize { + self.rows() * self.cols_in() * self.cols_out() * self.size() + } } impl ZnxSliceSize for MatZnx { @@ -175,8 +196,18 @@ impl MatZnx { } impl FillUniform for MatZnx { - fn fill_uniform(&mut self, source: &mut Source) { - source.fill_bytes(self.data.as_mut()); + fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { + match log_bound { + 64 => source.fill_bytes(self.data.as_mut()), + 0 => panic!("invalid log_bound, cannot be zero"), + _ => { + let mask: u64 = (1u64 << log_bound) - 1; + for x in self.raw_mut().iter_mut() { + let r = source.next_u64() & mask; + *x = ((r << (64 - log_bound)) as i64) >> (64 - log_bound); + } + } + } } } diff --git a/poulpy-hal/src/layouts/mod.rs b/poulpy-hal/src/layouts/mod.rs index 33a2c1f..d2b8fe5 100644 --- a/poulpy-hal/src/layouts/mod.rs +++ b/poulpy-hal/src/layouts/mod.rs @@ -34,3 +34,7 @@ pub trait ToOwnedDeep { type Owned; fn to_owned_deep(&self) -> Self::Owned; } + +pub trait DigestU64 { + fn digest_u64(&self) -> u64; +} diff --git a/poulpy-hal/src/layouts/module.rs b/poulpy-hal/src/layouts/module.rs index e885b13..61e312c 100644 --- a/poulpy-hal/src/layouts/module.rs +++ b/poulpy-hal/src/layouts/module.rs @@ -20,7 +20,30 @@ pub struct Module { _marker: PhantomData, } +unsafe impl Sync for Module {} +unsafe impl Send for Module {} + impl Module { + #[allow(clippy::missing_safety_doc)] + #[inline] + pub fn new_marker(n: u64) -> Self { + Self { + ptr: NonNull::dangling(), + n, + _marker: PhantomData, + } + } + + #[allow(clippy::missing_safety_doc)] + #[inline] + pub unsafe fn from_nonnull(ptr: NonNull, n: u64) -> Self { + Self { + ptr, + n, + _marker: PhantomData, + } + } + /// Construct from a raw pointer managed elsewhere. /// SAFETY: `ptr` must be non-null and remain valid for the lifetime of this Module. 
#[inline] diff --git a/poulpy-hal/src/layouts/scalar_znx.rs b/poulpy-hal/src/layouts/scalar_znx.rs index 4e83688..6baf66f 100644 --- a/poulpy-hal/src/layouts/scalar_znx.rs +++ b/poulpy-hal/src/layouts/scalar_znx.rs @@ -1,3 +1,5 @@ +use std::hash::{DefaultHasher, Hasher}; + use rand::seq::SliceRandom; use rand_core::RngCore; use rand_distr::{Distribution, weighted::WeightedIndex}; @@ -5,19 +7,30 @@ use rand_distr::{Distribution, weighted::WeightedIndex}; use crate::{ alloc_aligned, layouts::{ - Data, DataMut, DataRef, DataView, DataViewMut, FillUniform, ReaderFrom, Reset, ToOwnedDeep, VecZnx, WriterTo, ZnxInfos, - ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero, + Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, FillUniform, ReaderFrom, Reset, ToOwnedDeep, VecZnx, WriterTo, + ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero, }, source::Source, }; -#[derive(PartialEq, Eq, Debug, Clone)] +#[repr(C)] +#[derive(PartialEq, Eq, Debug, Clone, Hash)] pub struct ScalarZnx { pub data: D, pub n: usize, pub cols: usize, } +impl DigestU64 for ScalarZnx { + fn digest_u64(&self) -> u64 { + let mut h: DefaultHasher = DefaultHasher::new(); + h.write(self.data.as_ref()); + h.write_usize(self.n); + h.write_usize(self.cols); + h.finish() + } +} + impl ToOwnedDeep for ScalarZnx { type Owned = ScalarZnx>; fn to_owned_deep(&self) -> Self::Owned { @@ -145,8 +158,18 @@ impl ZnxZero for ScalarZnx { } impl FillUniform for ScalarZnx { - fn fill_uniform(&mut self, source: &mut Source) { - source.fill_bytes(self.data.as_mut()); + fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { + match log_bound { + 64 => source.fill_bytes(self.data.as_mut()), + 0 => panic!("invalid log_bound, cannot be zero"), + _ => { + let mask: u64 = (1u64 << log_bound) - 1; + for x in self.raw_mut().iter_mut() { + let r = source.next_u64() & mask; + *x = ((r << (64 - log_bound)) as i64) >> (64 - log_bound); + } + } + } } } diff --git a/poulpy-hal/src/layouts/scratch.rs b/poulpy-hal/src/layouts/scratch.rs index 695883a..2d24cfe 100644 --- a/poulpy-hal/src/layouts/scratch.rs +++ b/poulpy-hal/src/layouts/scratch.rs @@ -2,11 +2,13 @@ use std::marker::PhantomData; use crate::layouts::Backend; +#[repr(C)] pub struct ScratchOwned { pub data: Vec, pub _phantom: PhantomData, } +#[repr(C)] pub struct Scratch { pub _phantom: PhantomData, pub data: [u8], diff --git a/poulpy-hal/src/layouts/stats.rs b/poulpy-hal/src/layouts/stats.rs index 3b573f9..05dd087 100644 --- a/poulpy-hal/src/layouts/stats.rs +++ b/poulpy-hal/src/layouts/stats.rs @@ -4,7 +4,7 @@ use rug::{ ops::{AddAssignRound, DivAssignRound, SubAssignRound}, }; -use crate::layouts::{DataRef, VecZnx, ZnxInfos}; +use crate::layouts::{Backend, DataRef, VecZnx, VecZnxBig, VecZnxBigToRef, ZnxInfos}; impl VecZnx { pub fn std(&self, basek: usize, col: usize) -> f64 { @@ -27,3 +27,17 @@ impl VecZnx { std.to_f64() } } + +impl> VecZnxBig { + pub fn std(&self, basek: usize, col: usize) -> f64 { + let self_ref: VecZnxBig<&[u8], B> = self.to_ref(); + let znx: VecZnx<&[u8]> = VecZnx { + data: self_ref.data, + n: self_ref.n, + cols: self_ref.cols, + size: self_ref.size, + max_size: self_ref.max_size, + }; + znx.std(basek, col) + } +} diff --git a/poulpy-hal/src/layouts/svp_ppol.rs b/poulpy-hal/src/layouts/svp_ppol.rs index 80b0e52..428f055 100644 --- a/poulpy-hal/src/layouts/svp_ppol.rs +++ b/poulpy-hal/src/layouts/svp_ppol.rs @@ -1,12 +1,19 @@ -use std::marker::PhantomData; +use std::{ + fmt, + hash::{DefaultHasher, Hasher}, + marker::PhantomData, +}; use crate::{ alloc_aligned, - 
layouts::{Backend, Data, DataMut, DataRef, DataView, DataViewMut, ReaderFrom, WriterTo, ZnxInfos, ZnxSliceSize, ZnxView}, + layouts::{ + Backend, Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, ReaderFrom, WriterTo, ZnxInfos, ZnxSliceSize, ZnxView, + }, oep::SvpPPolAllocBytesImpl, }; -#[derive(PartialEq, Eq)] +#[repr(C)] +#[derive(PartialEq, Eq, Hash)] pub struct SvpPPol { pub data: D, pub n: usize, @@ -14,6 +21,16 @@ pub struct SvpPPol { pub _phantom: PhantomData, } +impl DigestU64 for SvpPPol { + fn digest_u64(&self) -> u64 { + let mut h: DefaultHasher = DefaultHasher::new(); + h.write(self.data.as_ref()); + h.write_usize(self.n); + h.write_usize(self.cols); + h.finish() + } +} + impl ZnxSliceSize for SvpPPol { fn sl(&self) -> usize { B::layout_prep_word_count() * self.n() @@ -153,3 +170,32 @@ impl WriterTo for SvpPPol { Ok(()) } } + +impl fmt::Display for SvpPPol { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!(f, "SvpPPol(n={}, cols={})", self.n, self.cols)?; + + for col in 0..self.cols { + writeln!(f, "Column {}:", col)?; + let coeffs = self.at(col, 0); + write!(f, "[")?; + + let max_show = 100; + let show_count = coeffs.len().min(max_show); + + for (i, &coeff) in coeffs.iter().take(show_count).enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", coeff)?; + } + + if coeffs.len() > max_show { + write!(f, ", ... ({} more)", coeffs.len() - max_show)?; + } + + writeln!(f, "]")?; + } + Ok(()) + } +} diff --git a/poulpy-hal/src/layouts/vec_znx.rs b/poulpy-hal/src/layouts/vec_znx.rs index 6564e35..0ff9454 100644 --- a/poulpy-hal/src/layouts/vec_znx.rs +++ b/poulpy-hal/src/layouts/vec_znx.rs @@ -1,10 +1,13 @@ -use std::fmt; +use std::{ + fmt, + hash::{DefaultHasher, Hasher}, +}; use crate::{ alloc_aligned, layouts::{ - Data, DataMut, DataRef, DataView, DataViewMut, FillUniform, ReaderFrom, Reset, ToOwnedDeep, WriterTo, ZnxInfos, - ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero, + Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, FillUniform, ReaderFrom, Reset, ToOwnedDeep, WriterTo, + ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero, }, source::Source, }; @@ -12,7 +15,8 @@ use crate::{ use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use rand::RngCore; -#[derive(PartialEq, Eq, Clone, Copy)] +#[repr(C)] +#[derive(PartialEq, Eq, Clone, Copy, Hash)] pub struct VecZnx { pub data: D, pub n: usize, @@ -21,6 +25,18 @@ pub struct VecZnx { pub max_size: usize, } +impl DigestU64 for VecZnx { + fn digest_u64(&self) -> u64 { + let mut h: DefaultHasher = DefaultHasher::new(); + h.write(self.data.as_ref()); + h.write_usize(self.n); + h.write_usize(self.cols); + h.write_usize(self.size); + h.write_usize(self.max_size); + h.finish() + } +} + impl ToOwnedDeep for VecZnx { type Owned = VecZnx>; fn to_owned_deep(&self) -> Self::Owned { @@ -173,8 +189,18 @@ impl fmt::Display for VecZnx { } impl FillUniform for VecZnx { - fn fill_uniform(&mut self, source: &mut Source) { - source.fill_bytes(self.data.as_mut()); + fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { + match log_bound { + 64 => source.fill_bytes(self.data.as_mut()), + 0 => panic!("invalid log_bound, cannot be zero"), + _ => { + let mask: u64 = (1u64 << log_bound) - 1; + for x in self.raw_mut().iter_mut() { + let r = source.next_u64() & mask; + *x = ((r << (64 - log_bound)) as i64) >> (64 - log_bound); + } + } + } } } diff --git a/poulpy-hal/src/layouts/vec_znx_big.rs b/poulpy-hal/src/layouts/vec_znx_big.rs index df4a507..ee5e919 100644 --- 
a/poulpy-hal/src/layouts/vec_znx_big.rs +++ b/poulpy-hal/src/layouts/vec_znx_big.rs @@ -1,15 +1,21 @@ -use std::marker::PhantomData; +use std::{ + hash::{DefaultHasher, Hasher}, + marker::PhantomData, +}; use rand_distr::num_traits::Zero; use std::fmt; use crate::{ alloc_aligned, - layouts::{Backend, Data, DataMut, DataRef, DataView, DataViewMut, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero}, + layouts::{ + Backend, Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero, + }, oep::VecZnxBigAllocBytesImpl, }; -#[derive(PartialEq, Eq)] +#[repr(C)] +#[derive(PartialEq, Eq, Hash)] pub struct VecZnxBig { pub data: D, pub n: usize, @@ -19,6 +25,18 @@ pub struct VecZnxBig { pub _phantom: PhantomData, } +impl DigestU64 for VecZnxBig { + fn digest_u64(&self) -> u64 { + let mut h: DefaultHasher = DefaultHasher::new(); + h.write(self.data.as_ref()); + h.write_usize(self.n); + h.write_usize(self.cols); + h.write_usize(self.size); + h.write_usize(self.max_size); + h.finish() + } +} + impl ZnxSliceSize for VecZnxBig { fn sl(&self) -> usize { B::layout_big_word_count() * self.n() * self.cols() diff --git a/poulpy-hal/src/layouts/vec_znx_dft.rs b/poulpy-hal/src/layouts/vec_znx_dft.rs index dbf87a3..027742c 100644 --- a/poulpy-hal/src/layouts/vec_znx_dft.rs +++ b/poulpy-hal/src/layouts/vec_znx_dft.rs @@ -1,14 +1,21 @@ -use std::{fmt, marker::PhantomData}; +use std::{ + fmt, + hash::{DefaultHasher, Hasher}, + marker::PhantomData, +}; use rand_distr::num_traits::Zero; use crate::{ alloc_aligned, layouts::{ - Backend, Data, DataMut, DataRef, DataView, DataViewMut, VecZnxBig, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero, + Backend, Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, VecZnxBig, ZnxInfos, ZnxSliceSize, ZnxView, + ZnxViewMut, ZnxZero, }, - oep::VecZnxBigAllocBytesImpl, + oep::VecZnxDftAllocBytesImpl, }; + +#[repr(C)] #[derive(PartialEq, Eq)] pub struct VecZnxDft { pub data: D, @@ -19,6 +26,18 @@ pub struct VecZnxDft { pub _phantom: PhantomData, } +impl DigestU64 for VecZnxDft { + fn digest_u64(&self) -> u64 { + let mut h: DefaultHasher = DefaultHasher::new(); + h.write(self.data.as_ref()); + h.write_usize(self.n); + h.write_usize(self.cols); + h.write_usize(self.size); + h.write_usize(self.max_size); + h.finish() + } +} + impl ZnxSliceSize for VecZnxDft { fn sl(&self) -> usize { B::layout_prep_word_count() * self.n() * self.cols() @@ -94,10 +113,10 @@ where impl>, B: Backend> VecZnxDft where - B: VecZnxBigAllocBytesImpl, + B: VecZnxDftAllocBytesImpl, { pub fn alloc(n: usize, cols: usize, size: usize) -> Self { - let data: Vec = alloc_aligned::(B::vec_znx_big_alloc_bytes_impl(n, cols, size)); + let data: Vec = alloc_aligned::(B::vec_znx_dft_alloc_bytes_impl(n, cols, size)); Self { data: data.into(), n, @@ -110,7 +129,7 @@ where pub fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into>) -> Self { let data: Vec = bytes.into(); - assert!(data.len() == B::vec_znx_big_alloc_bytes_impl(n, cols, size)); + assert!(data.len() == B::vec_znx_dft_alloc_bytes_impl(n, cols, size)); Self { data: data.into(), n, diff --git a/poulpy-hal/src/layouts/vmp_pmat.rs b/poulpy-hal/src/layouts/vmp_pmat.rs index cadbc58..ce83458 100644 --- a/poulpy-hal/src/layouts/vmp_pmat.rs +++ b/poulpy-hal/src/layouts/vmp_pmat.rs @@ -1,12 +1,16 @@ -use std::marker::PhantomData; +use std::{ + hash::{DefaultHasher, Hasher}, + marker::PhantomData, +}; use crate::{ alloc_aligned, - layouts::{Backend, Data, DataMut, DataRef, DataView, DataViewMut, 
ZnxInfos, ZnxView}, + layouts::{Backend, Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, ZnxInfos, ZnxView}, oep::VmpPMatAllocBytesImpl, }; -#[derive(PartialEq, Eq)] +#[repr(C)] +#[derive(PartialEq, Eq, Hash)] pub struct VmpPMat { data: D, n: usize, @@ -17,6 +21,19 @@ pub struct VmpPMat { _phantom: PhantomData, } +impl DigestU64 for VmpPMat { + fn digest_u64(&self) -> u64 { + let mut h: DefaultHasher = DefaultHasher::new(); + h.write(self.data.as_ref()); + h.write_usize(self.n); + h.write_usize(self.size); + h.write_usize(self.rows); + h.write_usize(self.cols_in); + h.write_usize(self.cols_out); + h.finish() + } +} + impl ZnxView for VmpPMat { type Scalar = B::ScalarPrep; } @@ -37,6 +54,10 @@ impl ZnxInfos for VmpPMat { fn size(&self) -> usize { self.size } + + fn poly_count(&self) -> usize { + self.rows() * self.cols_in() * self.size() * self.cols_out() + } } impl DataView for VmpPMat { diff --git a/poulpy-hal/src/layouts/zn.rs b/poulpy-hal/src/layouts/zn.rs index 68b0b0f..7807034 100644 --- a/poulpy-hal/src/layouts/zn.rs +++ b/poulpy-hal/src/layouts/zn.rs @@ -1,10 +1,13 @@ -use std::fmt; +use std::{ + fmt, + hash::{DefaultHasher, Hasher}, +}; use crate::{ alloc_aligned, layouts::{ - Data, DataMut, DataRef, DataView, DataViewMut, FillUniform, ReaderFrom, Reset, ToOwnedDeep, WriterTo, ZnxInfos, - ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero, + Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, FillUniform, ReaderFrom, Reset, ToOwnedDeep, WriterTo, + ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero, }, source::Source, }; @@ -12,7 +15,8 @@ use crate::{ use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use rand::RngCore; -#[derive(PartialEq, Eq, Clone, Copy)] +#[repr(C)] +#[derive(PartialEq, Eq, Clone, Copy, Hash)] pub struct Zn { pub data: D, pub n: usize, @@ -21,6 +25,18 @@ pub struct Zn { pub max_size: usize, } +impl DigestU64 for Zn { + fn digest_u64(&self) -> u64 { + let mut h: DefaultHasher = DefaultHasher::new(); + h.write(self.data.as_ref()); + h.write_usize(self.n); + h.write_usize(self.cols); + h.write_usize(self.size); + h.write_usize(self.max_size); + h.finish() + } +} + impl ToOwnedDeep for Zn { type Owned = Zn>; fn to_owned_deep(&self) -> Self::Owned { @@ -173,8 +189,18 @@ impl fmt::Display for Zn { } impl FillUniform for Zn { - fn fill_uniform(&mut self, source: &mut Source) { - source.fill_bytes(self.data.as_mut()); + fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { + match log_bound { + 64 => source.fill_bytes(self.data.as_mut()), + 0 => panic!("invalid log_bound, cannot be zero"), + _ => { + let mask: u64 = (1u64 << log_bound) - 1; + for x in self.raw_mut().iter_mut() { + let r = source.next_u64() & mask; + *x = ((r << (64 - log_bound)) as i64) >> (64 - log_bound); + } + } + } } } diff --git a/poulpy-hal/src/layouts/znx_base.rs b/poulpy-hal/src/layouts/znx_base.rs index deab5ea..6173daf 100644 --- a/poulpy-hal/src/layouts/znx_base.rs +++ b/poulpy-hal/src/layouts/znx_base.rs @@ -117,7 +117,7 @@ where } pub trait FillUniform { - fn fill_uniform(&mut self, source: &mut Source); + fn fill_uniform(&mut self, log_bound: usize, source: &mut Source); } pub trait Reset { diff --git a/poulpy-hal/src/lib.rs b/poulpy-hal/src/lib.rs index 2199c9f..ca09126 100644 --- a/poulpy-hal/src/lib.rs +++ b/poulpy-hal/src/lib.rs @@ -4,11 +4,13 @@ #![feature(trait_alias)] pub mod api; +pub mod bench_suite; pub mod delegates; pub mod layouts; pub mod oep; +pub mod reference; pub mod source; -pub mod tests; +pub mod test_suite; pub mod doc { #[doc = 
include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/docs/backend_safety_contract.md"))] @@ -85,13 +87,20 @@ fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec { /// Allocates a block of T aligned with [DEFAULTALIGN]. /// Size of T * size must be a multiple of [DEFAULTALIGN]. pub fn alloc_aligned_custom(size: usize, align: usize) -> Vec { - assert_eq!( - (size * size_of::()) % (align / size_of::()), - 0, - "size={} must be a multiple of align={}", - size, + assert!( + align.is_power_of_two(), + "Alignment must be a power of two but is {}", align ); + + assert_eq!( + (size * size_of::()) % align, + 0, + "size*size_of::()={} must be a multiple of align={}", + size * size_of::(), + align + ); + let mut vec_u8: Vec = alloc_aligned_custom_u8(size_of::() * size, align); let ptr: *mut T = vec_u8.as_mut_ptr() as *mut T; let len: usize = vec_u8.len() / size_of::(); @@ -100,11 +109,11 @@ pub fn alloc_aligned_custom(size: usize, align: usize) -> Vec { unsafe { Vec::from_raw_parts(ptr, len, cap) } } -/// Allocates an aligned vector of size equal to the smallest multiple -/// of [DEFAULTALIGN]/`size_of::`() that is equal or greater to `size`. +/// Allocates an aligned vector of at least the given size. +/// Pads so that the size in bytes is a multiple of [DEFAULTALIGN]. pub fn alloc_aligned(size: usize) -> Vec { alloc_aligned_custom::( - size + (DEFAULTALIGN - (size % (DEFAULTALIGN / size_of::()))) % DEFAULTALIGN, + (size * size_of::()).next_multiple_of(DEFAULTALIGN) / size_of::(), DEFAULTALIGN, ) } diff --git a/poulpy-hal/src/oep/svp_ppol.rs b/poulpy-hal/src/oep/svp_ppol.rs index 81668ca..b50208a 100644 --- a/poulpy-hal/src/oep/svp_ppol.rs +++ b/poulpy-hal/src/oep/svp_ppol.rs @@ -1,4 +1,6 @@ -use crate::layouts::{Backend, Module, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef}; +use crate::layouts::{ + Backend, Module, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, +}; /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See TODO for reference code. /// * See TODO for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait SvpApplyImpl { - fn svp_apply_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize) +pub unsafe trait SvpApplyDftImpl { + fn svp_apply_dft_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize) where + R: VecZnxDftToMut, + A: SvpPPolToRef, + C: VecZnxToRef; +} + +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See TODO for reference code. +/// * See TODO for corresponding public API. +/// # Safety [crate::doc::backend_safety] for safety contract. +pub unsafe trait SvpApplyDftToDftImpl { + fn svp_apply_dft_to_dft_impl( + module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + b: &C, + b_col: usize, + ) where R: VecZnxDftToMut, A: SvpPPolToRef, C: VecZnxDftToRef; } @@ -51,8 +72,27 @@ pub unsafe trait SvpApplyImpl { /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See TODO for reference code. /// * See TODO for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. 
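// Worked example of the new padding rule in `alloc_aligned` above: the element count is converted to
// bytes, rounded up to the next multiple of the alignment, then converted back to elements
// (DEFAULTALIGN = 64 is an assumption here; the arithmetic is what the patched code computes).
use std::mem::size_of;
fn padded_len<T>(size: usize, align: usize) -> usize {
    (size * size_of::<T>()).next_multiple_of(align) / size_of::<T>()
}
// padded_len::<i64>(5, 64) == 8   (40 bytes -> 64 bytes -> 8 elements)
// padded_len::<u8>(100, 64) == 128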
-pub unsafe trait SvpApplyInplaceImpl: Backend { - fn svp_apply_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) +pub unsafe trait SvpApplyDftToDftAddImpl { + fn svp_apply_dft_to_dft_add_impl( + module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + b: &C, + b_col: usize, + ) where + R: VecZnxDftToMut, + A: SvpPPolToRef, + C: VecZnxDftToRef; +} + +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See TODO for reference code. +/// * See TODO for corresponding public API. +/// # Safety [crate::doc::backend_safety] for safety contract. +pub unsafe trait SvpApplyDftToDftInplaceImpl: Backend { + fn svp_apply_dft_to_dft_inplace_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxDftToMut, A: SvpPPolToRef; } diff --git a/poulpy-hal/src/oep/vec_znx.rs b/poulpy-hal/src/oep/vec_znx.rs index ddfe6fe..268e46b 100644 --- a/poulpy-hal/src/oep/vec_znx.rs +++ b/poulpy-hal/src/oep/vec_znx.rs @@ -1,5 +1,3 @@ -use rand_distr::Distribution; - use crate::{ layouts::{Backend, Module, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef}, source::Source, @@ -64,6 +62,28 @@ pub unsafe trait VecZnxAddInplaceImpl { A: VecZnxToRef; } +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See [vec_znx_add_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L86) for reference code. +/// * See [crate::api::VecZnxAddScalar] for corresponding public API. +/// # Safety [crate::doc::backend_safety] for safety contract. +pub unsafe trait VecZnxAddScalarImpl { + /// Adds the selected column of `a` on the selected column and limb of `b` and writes the result on the selected column of `res`. + #[allow(clippy::too_many_arguments)] + fn vec_znx_add_scalar_impl( + module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + b: &B, + b_col: usize, + b_limb: usize, + ) where + R: VecZnxToMut, + A: ScalarZnxToRef, + B: VecZnxToRef; +} + /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See [vec_znx_add_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L86) for reference code. /// * See [crate::api::VecZnxAddScalarInplace] for corresponding public API. @@ -115,6 +135,28 @@ pub unsafe trait VecZnxSubBAInplaceImpl { A: VecZnxToRef; } +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See TODO. +/// * See [crate::api::VecZnxSubScalar] for corresponding public API. +/// # Safety [crate::doc::backend_safety] for safety contract. +pub unsafe trait VecZnxSubScalarImpl { + /// Subtracts the selected column of `a` from the selected column and limb of `b` and writes the result on the selected column of `res`. + #[allow(clippy::too_many_arguments)] + fn vec_znx_sub_scalar_impl( + module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + b: &B, + b_col: usize, + b_limb: usize, + ) where + R: VecZnxToMut, + A: ScalarZnxToRef, + B: VecZnxToRef; +} + /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See [vec_znx_sub_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L125) for reference code. /// * See [crate::api::VecZnxSubScalarInplace] for corresponding public API. 
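// Hedged sketch (illustration only, not taken from the patch) of what the new VecZnxAddScalarImpl /
// VecZnxSubScalarImpl are expected to compute given the signatures above: the scalar polynomial `a`
// touches exactly one limb (`b_limb`) of `b`; every other limb is copied through unchanged.
fn add_scalar_to_limb(res: &mut [Vec<i64>], a: &[i64], b: &[Vec<i64>], b_limb: usize) {
    for (j, (r, b_j)) in res.iter_mut().zip(b.iter()).enumerate() {
        r.copy_from_slice(b_j);
        if j == b_limb {
            for (ri, &ai) in r.iter_mut().zip(a.iter()) {
                *ri = ri.wrapping_add(ai);
            }
        }
    }
}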
@@ -153,14 +195,76 @@ pub unsafe trait VecZnxNegateInplaceImpl { A: VecZnxToMut; } +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See [crate::reference::vec_znx::shift::vec_znx_rsh_tmp_bytes] for reference code. +/// * See [crate::api::VecZnxRshTmpBytes] for corresponding public API. +/// # Safety [crate::doc::backend_safety] for safety contract. +pub unsafe trait VecZnxRshTmpBytesImpl { + fn vec_znx_rsh_tmp_bytes_impl(module: &Module) -> usize; +} + +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See [crate::reference::vec_znx::shift::vec_znx_rsh_inplace] for reference code. +/// * See [crate::api::VecZnxRsh] for corresponding public API. +/// # Safety [crate::doc::backend_safety] for safety contract. +pub unsafe trait VecZnxRshImpl { + #[allow(clippy::too_many_arguments)] + fn vec_znx_rsh_inplace_impl( + module: &Module, + basek: usize, + k: usize, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxToMut, + A: VecZnxToRef; +} + +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See [crate::reference::vec_znx::shift::vec_znx_lsh_tmp_bytes] for reference code. +/// * See [crate::api::VecZnxLshTmpBytes] for corresponding public API. +/// # Safety [crate::doc::backend_safety] for safety contract. +pub unsafe trait VecZnxLshTmpBytesImpl { + fn vec_znx_lsh_tmp_bytes_impl(module: &Module) -> usize; +} + +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See [crate::reference::vec_znx::shift::vec_znx_lsh_inplace] for reference code. +/// * See [crate::api::VecZnxLsh] for corresponding public API. +/// # Safety [crate::doc::backend_safety] for safety contract. +pub unsafe trait VecZnxLshImpl { + #[allow(clippy::too_many_arguments)] + fn vec_znx_lsh_inplace_impl( + module: &Module, + basek: usize, + k: usize, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxToMut, + A: VecZnxToRef; +} + /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See [crate::cpu_spqlios::vec_znx::vec_znx_rsh_inplace_ref] for reference code. /// * See [crate::api::VecZnxRshInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxRshInplaceImpl { - fn vec_znx_rsh_inplace_impl(module: &Module, basek: usize, k: usize, a: &mut A) - where - A: VecZnxToMut; + fn vec_znx_rsh_inplace_impl( + module: &Module, + basek: usize, + k: usize, + res: &mut R, + res_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxToMut; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) @@ -168,9 +272,15 @@ pub unsafe trait VecZnxRshInplaceImpl { /// * See [crate::api::VecZnxLshInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxLshInplaceImpl { - fn vec_znx_lsh_inplace_impl(module: &Module, basek: usize, k: usize, a: &mut A) - where - A: VecZnxToMut; + fn vec_znx_lsh_inplace_impl( + module: &Module, + basek: usize, + k: usize, + res: &mut R, + res_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxToMut; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) @@ -184,12 +294,20 @@ pub unsafe trait VecZnxRotateImpl { A: VecZnxToRef; } +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See TODO; +/// * See [crate::api::VecZnxRotateInplaceTmpBytes] for corresponding public API. +/// # Safety [crate::doc::backend_safety] for safety contract. 
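// Illustration (mine, under the assumption that each limb carries basek bits) of why the new
// VecZnxRsh/VecZnxLsh entry points ask for scratch space: a shift by k bits splits into a whole-limb
// shift by k / basek plus an intra-limb shift by k % basek, and the latter needs a carry buffer.
fn rsh_toy(limbs: &[u8], basek: u32, k: u32) -> u8 {
    // pack the limbs into one integer and shift; the real code works limb by limb with carries.
    let value = limbs.iter().fold(0u8, |acc, &l| (acc << basek) | l);
    value >> k
}
// rsh_toy(&[0b1011, 0b0100], 4, 6) == 0b10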
+pub unsafe trait VecZnxRotateInplaceTmpBytesImpl { + fn vec_znx_rotate_inplace_tmp_bytes_impl(module: &Module) -> usize; +} + /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See [vec_znx_rotate_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L164) for reference code. /// * See [crate::api::VecZnxRotateInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxRotateInplaceImpl { - fn vec_znx_rotate_inplace_impl(module: &Module, k: i64, a: &mut A, a_col: usize) + fn vec_znx_rotate_inplace_impl(module: &Module, k: i64, a: &mut A, a_col: usize, scratch: &mut Scratch) where A: VecZnxToMut; } @@ -199,20 +317,28 @@ pub unsafe trait VecZnxRotateInplaceImpl { /// * See [crate::api::VecZnxAutomorphism] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxAutomorphismImpl { - fn vec_znx_automorphism_impl(module: &Module, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + fn vec_znx_automorphism_impl(module: &Module, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) where R: VecZnxToMut, A: VecZnxToRef; } +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See TODO; +/// * See [crate::api::VecZnxAutomorphismInplaceTmpBytes] for corresponding public API. +/// # Safety [crate::doc::backend_safety] for safety contract. +pub unsafe trait VecZnxAutomorphismInplaceTmpBytesImpl { + fn vec_znx_automorphism_inplace_tmp_bytes_impl(module: &Module) -> usize; +} + /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See [vec_znx_automorphism_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L188) for reference code. /// * See [crate::api::VecZnxAutomorphismInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxAutomorphismInplaceImpl { - fn vec_znx_automorphism_inplace_impl(module: &Module, k: i64, a: &mut A, a_col: usize) + fn vec_znx_automorphism_inplace_impl(module: &Module, k: i64, res: &mut R, res_col: usize, scratch: &mut Scratch) where - A: VecZnxToMut; + R: VecZnxToMut; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) @@ -226,34 +352,75 @@ pub unsafe trait VecZnxMulXpMinusOneImpl { A: VecZnxToRef; } +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See TODO; +/// * See [crate::api::VecZnxMulXpMinusOneInplaceTmpBytes] for corresponding public API. +/// # Safety [crate::doc::backend_safety] for safety contract. +pub unsafe trait VecZnxMulXpMinusOneInplaceTmpBytesImpl { + fn vec_znx_mul_xp_minus_one_inplace_tmp_bytes_impl(module: &Module) -> usize; +} + /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See [vec_znx_mul_xp_minus_one_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/7160f588da49712a042931ea247b4259b95cefcc/spqlios/arithmetic/vec_znx.c#L200C13-L200C41) for reference code. /// * See [crate::api::VecZnxMulXpMinusOneInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. 
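// Reference sketch (standard negacyclic arithmetic, not taken from the patch) of the map X -> X^p
// behind the `k` -> `p` rename in VecZnxAutomorphismImpl above: on Z[X]/(X^n + 1), coefficient i
// moves to position i*p mod 2n and picks up a sign when it wraps past n, since X^n = -1; `p` must
// be odd for the map to be invertible.
fn automorphism_naive(p: i64, a: &[i64]) -> Vec<i64> {
    let n = a.len();
    let two_n = 2 * n as i64;
    let mut res = vec![0i64; n];
    for (i, &ai) in a.iter().enumerate() {
        let j = ((i as i64) * p).rem_euclid(two_n) as usize;
        if j < n {
            res[j] += ai;
        } else {
            res[j - n] -= ai;
        }
    }
    res
}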
pub unsafe trait VecZnxMulXpMinusOneInplaceImpl { - fn vec_znx_mul_xp_minus_one_inplace_impl(module: &Module, p: i64, res: &mut R, res_col: usize) - where + fn vec_znx_mul_xp_minus_one_inplace_impl( + module: &Module, + p: i64, + res: &mut R, + res_col: usize, + scratch: &mut Scratch, + ) where R: VecZnxToMut; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [crate::cpu_spqlios::vec_znx::vec_znx_split_ref] for reference code. -/// * See [crate::api::VecZnxSplit] for corresponding public API. +/// * See TODO; +/// * See [crate::api::VecZnxSplitRingTmpBytes] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait VecZnxSplitImpl { - fn vec_znx_split_impl(module: &Module, res: &mut [R], res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) - where +pub unsafe trait VecZnxSplitRingTmpBytesImpl { + fn vec_znx_split_ring_tmp_bytes_impl(module: &Module) -> usize; +} + +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See [crate::cpu_spqlios::vec_znx::vec_znx_split_ref] for reference code. +/// * See [crate::api::VecZnxSplitRing] for corresponding public API. +/// # Safety [crate::doc::backend_safety] for safety contract. +pub unsafe trait VecZnxSplitRingImpl { + fn vec_znx_split_ring_impl( + module: &Module, + res: &mut [R], + res_col: usize, + a: &A, + a_col: usize, + scratch: &mut Scratch, + ) where R: VecZnxToMut, A: VecZnxToRef; } +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See TODO; +/// * See [crate::api::VecZnxMergeRingsTmpBytes] for corresponding public API. +/// # Safety [crate::doc::backend_safety] for safety contract. +pub unsafe trait VecZnxMergeRingsTmpBytesImpl { + fn vec_znx_merge_rings_tmp_bytes_impl(module: &Module) -> usize; +} + /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See [crate::cpu_spqlios::vec_znx::vec_znx_merge_ref] for reference code. /// * See [crate::api::VecZnxMerge] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait VecZnxMergeImpl { - fn vec_znx_merge_impl(module: &Module, res: &mut R, res_col: usize, a: &[A], a_col: usize) - where +pub unsafe trait VecZnxMergeRingsImpl { + fn vec_znx_merge_rings_impl( + module: &Module, + res: &mut R, + res_col: usize, + a: &[A], + a_col: usize, + scratch: &mut Scratch, + ) where R: VecZnxToMut, A: VecZnxToRef; } @@ -262,8 +429,8 @@ pub unsafe trait VecZnxMergeImpl { /// * See [crate::cpu_spqlios::vec_znx::vec_znx_switch_degree_ref] for reference code. /// * See [crate::api::VecZnxSwithcDegree] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait VecZnxSwithcDegreeImpl { - fn vec_znx_switch_degree_impl( +pub unsafe trait VecZnxSwitchRingImpl { + fn vec_znx_switch_ring_impl( module: &Module, res: &mut R, res_col: usize, @@ -287,47 +454,11 @@ pub unsafe trait VecZnxCopyImpl { /// * See [crate::api::VecZnxFillUniform] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxFillUniformImpl { - fn vec_znx_fill_uniform_impl(module: &Module, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source) + fn vec_znx_fill_uniform_impl(module: &Module, basek: usize, res: &mut R, res_col: usize, source: &mut Source) where R: VecZnxToMut; } -#[allow(clippy::too_many_arguments)] -/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [crate::api::VecZnxFillDistF64] for corresponding public API. 
-/// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait VecZnxFillDistF64Impl { - fn vec_znx_fill_dist_f64_impl>( - module: &Module, - basek: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - dist: D, - bound: f64, - ) where - R: VecZnxToMut; -} - -#[allow(clippy::too_many_arguments)] -/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [crate::api::VecZnxAddDistF64] for corresponding public API. -/// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait VecZnxAddDistF64Impl { - fn vec_znx_add_dist_f64_impl>( - module: &Module, - basek: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - dist: D, - bound: f64, - ) where - R: VecZnxToMut; -} - #[allow(clippy::too_many_arguments)] /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See [crate::api::VecZnxFillNormal] for corresponding public API. diff --git a/poulpy-hal/src/oep/vec_znx_big.rs b/poulpy-hal/src/oep/vec_znx_big.rs index 2764ef2..8398983 100644 --- a/poulpy-hal/src/oep/vec_znx_big.rs +++ b/poulpy-hal/src/oep/vec_znx_big.rs @@ -1,10 +1,19 @@ -use rand_distr::Distribution; - use crate::{ layouts::{Backend, Module, Scratch, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef}, source::Source, }; +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See TODO for reference code. +/// * See TODO for corresponding public API. +/// # Safety [crate::doc::backend_safety] for safety contract. +pub unsafe trait VecZnxBigFromSmallImpl { + fn vec_znx_big_from_small_impl(res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef; +} + /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See TODO for reference code. /// * See TODO for corresponding public API. @@ -47,60 +56,6 @@ pub unsafe trait VecZnxBigAddNormalImpl { ); } -#[allow(clippy::too_many_arguments)] -/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. -/// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait VecZnxBigFillNormalImpl { - fn fill_normal_impl>( - module: &Module, - basek: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - sigma: f64, - bound: f64, - ); -} - -#[allow(clippy::too_many_arguments)] -/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. -/// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait VecZnxBigFillDistF64Impl { - fn fill_dist_f64_impl, D: Distribution>( - module: &Module, - basek: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - dist: D, - bound: f64, - ); -} - -#[allow(clippy::too_many_arguments)] -/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See TODO for reference code. -/// * See TODO for corresponding public API. -/// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait VecZnxBigAddDistF64Impl { - fn add_dist_f64_impl, D: Distribution>( - module: &Module, - basek: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - dist: D, - bound: f64, - ); -} - /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See TODO for reference code. /// * See TODO for corresponding public API. 
@@ -248,6 +203,17 @@ pub unsafe trait VecZnxBigSubSmallBInplaceImpl { A: VecZnxToRef; } +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See TODO for reference code. +/// * See TODO for corresponding public API. +/// # Safety [crate::doc::backend_safety] for safety contract. +pub unsafe trait VecZnxBigNegateImpl { + fn vec_znx_big_negate_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef; +} + /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See TODO for reference code. /// * See TODO for corresponding public API. @@ -295,12 +261,20 @@ pub unsafe trait VecZnxBigAutomorphismImpl { A: VecZnxBigToRef; } +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See TODO for reference code. +/// * See TODO for corresponding public API. +/// # Safety [crate::doc::backend_safety] for safety contract. +pub unsafe trait VecZnxBigAutomorphismInplaceTmpBytesImpl { + fn vec_znx_big_automorphism_inplace_tmp_bytes_impl(module: &Module) -> usize; +} + /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See TODO for reference code. /// * See TODO for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxBigAutomorphismInplaceImpl { - fn vec_znx_big_automorphism_inplace_impl(module: &Module, k: i64, a: &mut A, a_col: usize) + fn vec_znx_big_automorphism_inplace_impl(module: &Module, k: i64, a: &mut A, a_col: usize, scratch: &mut Scratch) where A: VecZnxBigToMut; } diff --git a/poulpy-hal/src/oep/vec_znx_dft.rs b/poulpy-hal/src/oep/vec_znx_dft.rs index 9f55887..c4fd46a 100644 --- a/poulpy-hal/src/oep/vec_znx_dft.rs +++ b/poulpy-hal/src/oep/vec_znx_dft.rs @@ -23,9 +23,16 @@ pub unsafe trait VecZnxDftFromBytesImpl { /// * See TODO for reference code. /// * See TODO for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait DFTImpl { - fn dft_impl(module: &Module, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize) - where +pub unsafe trait VecZnxDftApplyImpl { + fn vec_znx_dft_apply_impl( + module: &Module, + step: usize, + offset: usize, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + ) where R: VecZnxDftToMut, A: VecZnxToRef; } @@ -42,17 +49,23 @@ pub unsafe trait VecZnxDftAllocBytesImpl { /// * See TODO for reference code. /// * See TODO for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait VecZnxIDFTTmpBytesImpl { - fn vec_znx_idft_tmp_bytes_impl(module: &Module) -> usize; +pub unsafe trait VecZnxIdftApplyTmpBytesImpl { + fn vec_znx_idft_apply_tmp_bytes_impl(module: &Module) -> usize; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See TODO for reference code. /// * See TODO for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait IDFTImpl { - fn idft_impl(module: &Module, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) - where +pub unsafe trait VecZnxIdftApplyImpl { + fn vec_znx_idft_apply_impl( + module: &Module, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + scratch: &mut Scratch, + ) where R: VecZnxBigToMut, A: VecZnxDftToRef; } @@ -61,8 +74,8 @@ pub unsafe trait IDFTImpl { /// * See TODO for reference code. /// * See for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. 
-pub unsafe trait IDFTTmpAImpl { - fn idft_tmp_a_impl(module: &Module, res: &mut R, res_col: usize, a: &mut A, a_col: usize) +pub unsafe trait VecZnxIdftApplyTmpAImpl { + fn vec_znx_idft_apply_tmpa_impl(module: &Module, res: &mut R, res_col: usize, a: &mut A, a_col: usize) where R: VecZnxBigToMut, A: VecZnxDftToMut; @@ -72,8 +85,8 @@ pub unsafe trait IDFTTmpAImpl { /// * See TODO for reference code. /// * See TODO for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait IDFTConsumeImpl { - fn idft_consume_impl(module: &Module, a: VecZnxDft) -> VecZnxBig +pub unsafe trait VecZnxIdftApplyConsumeImpl { + fn vec_znx_idft_apply_consume_impl(module: &Module, a: VecZnxDft) -> VecZnxBig where VecZnxDft: VecZnxDftToMut; } diff --git a/poulpy-hal/src/oep/vmp_pmat.rs b/poulpy-hal/src/oep/vmp_pmat.rs index a81671b..74d8cd2 100644 --- a/poulpy-hal/src/oep/vmp_pmat.rs +++ b/poulpy-hal/src/oep/vmp_pmat.rs @@ -1,5 +1,5 @@ use crate::layouts::{ - Backend, MatZnxToRef, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef, + Backend, MatZnxToRef, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef, }; /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) @@ -45,13 +45,42 @@ pub unsafe trait VmpPrepareTmpBytesImpl { /// * See TODO for reference code. /// * See TODO for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait VmpPMatPrepareImpl { +pub unsafe trait VmpPrepareImpl { fn vmp_prepare_impl(module: &Module, res: &mut R, a: &A, scratch: &mut Scratch) where R: VmpPMatToMut, A: MatZnxToRef; } +#[allow(clippy::too_many_arguments)] +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See TODO for reference code. +/// * See TODO for corresponding public API. +/// # Safety [crate::doc::backend_safety] for safety contract. +pub unsafe trait VmpApplyDftTmpBytesImpl { + fn vmp_apply_dft_tmp_bytes_impl( + module: &Module, + res_size: usize, + a_size: usize, + b_rows: usize, + b_cols_in: usize, + b_cols_out: usize, + b_size: usize, + ) -> usize; +} + +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See TODO for reference code. +/// * See TODO for corresponding public API. +/// # Safety [crate::doc::backend_safety] for safety contract. +pub unsafe trait VmpApplyDftImpl { + fn vmp_apply_dft_impl(module: &Module, res: &mut R, a: &A, b: &C, scratch: &mut Scratch) + where + R: VecZnxDftToMut, + A: VecZnxToRef, + C: VmpPMatToRef; +} + #[allow(clippy::too_many_arguments)] /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See TODO for reference code. diff --git a/poulpy-hal/src/oep/zn.rs b/poulpy-hal/src/oep/zn.rs index 4a35185..2a1122a 100644 --- a/poulpy-hal/src/oep/zn.rs +++ b/poulpy-hal/src/oep/zn.rs @@ -1,65 +1,35 @@ -use rand_distr::Distribution; - use crate::{ layouts::{Backend, Scratch, ZnToMut}, source::Source, }; +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See TODO +/// * See [crate::api::ZnNormalizeTmpBytes] for corresponding public API. +/// # Safety [crate::doc::backend_safety] for safety contract. +pub unsafe trait ZnNormalizeTmpBytesImpl { + fn zn_normalize_tmp_bytes_impl(n: usize) -> usize; +} + /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See [zn_normalize_base2k_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/zn64.c#L9) for reference code. 
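// Hedged reading of the new VmpApplyDftImpl above (my interpretation of the signature, not stated in
// the patch): the limbs of `a` are taken to the DFT domain and contracted against the prepared
// matrix, roughly res[j] = sum_i DFT(a[i]) * pmat[i][j]. A scalar toy model of that contraction:
fn vmp_toy(a: &[f64], pmat: &[Vec<f64>]) -> Vec<f64> {
    let cols = pmat.first().map_or(0, |row| row.len());
    let mut res = vec![0.0; cols];
    for (ai, row) in a.iter().zip(pmat.iter()) {
        for (rj, pij) in res.iter_mut().zip(row.iter()) {
            *rj += ai * pij;
        }
    }
    res
}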
-/// * See [crate::api::ZnxNormalizeInplace] for corresponding public API. +/// * See [crate::api::ZnNormalizeInplace] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait ZnNormalizeInplaceImpl { - fn zn_normalize_inplace_impl(n: usize, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) + fn zn_normalize_inplace_impl(n: usize, basek: usize, res: &mut R, res_col: usize, scratch: &mut Scratch) where - A: ZnToMut; + R: ZnToMut; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See [crate::api::ZnFillUniform] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait ZnFillUniformImpl { - fn zn_fill_uniform_impl(n: usize, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source) + fn zn_fill_uniform_impl(n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source) where R: ZnToMut; } -#[allow(clippy::too_many_arguments)] -/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [crate::api::ZnFillDistF64] for corresponding public API. -/// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait ZnFillDistF64Impl { - fn zn_fill_dist_f64_impl>( - n: usize, - basek: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - dist: D, - bound: f64, - ) where - R: ZnToMut; -} - -#[allow(clippy::too_many_arguments)] -/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [crate::api::ZnAddDistF64] for corresponding public API. -/// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait ZnAddDistF64Impl { - fn zn_add_dist_f64_impl>( - n: usize, - basek: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - dist: D, - bound: f64, - ) where - R: ZnToMut; -} - #[allow(clippy::too_many_arguments)] /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See [crate::api::ZnFillNormal] for corresponding public API. diff --git a/poulpy-hal/src/reference/fft64/mod.rs b/poulpy-hal/src/reference/fft64/mod.rs new file mode 100644 index 0000000..a1cf49a --- /dev/null +++ b/poulpy-hal/src/reference/fft64/mod.rs @@ -0,0 +1,24 @@ +pub mod reim; +pub mod reim4; +pub mod svp; +pub mod vec_znx_big; +pub mod vec_znx_dft; +pub mod vmp; + +pub(crate) fn assert_approx_eq_slice(a: &[f64], b: &[f64], tol: f64) { + assert_eq!(a.len(), b.len(), "Slices have different lengths"); + + for (i, (x, y)) in a.iter().zip(b.iter()).enumerate() { + let diff: f64 = (x - y).abs(); + let scale: f64 = x.abs().max(y.abs()).max(1.0); + assert!( + diff <= tol * scale, + "Difference at index {}: left={} right={} rel_diff={} > tol={}", + i, + x, + y, + diff / scale, + tol + ); + } +} diff --git a/poulpy-hal/src/reference/fft64/reim/conversion.rs b/poulpy-hal/src/reference/fft64/reim/conversion.rs new file mode 100644 index 0000000..c76751a --- /dev/null +++ b/poulpy-hal/src/reference/fft64/reim/conversion.rs @@ -0,0 +1,31 @@ +#[inline(always)] +pub fn reim_from_znx_i64_ref(res: &mut [f64], a: &[i64]) { + #[cfg(debug_assertions)] + { + assert_eq!(res.len(), a.len()) + } + + for i in 0..res.len() { + res[i] = a[i] as f64 + } +} + +#[inline(always)] +pub fn reim_to_znx_i64_ref(res: &mut [i64], divisor: f64, a: &[f64]) { + #[cfg(debug_assertions)] + { + assert_eq!(res.len(), a.len()) + } + let inv_div = 1. 
/ divisor; + for i in 0..res.len() { + res[i] = (a[i] * inv_div).round() as i64 + } +} + +#[inline(always)] +pub fn reim_to_znx_i64_inplace_ref(res: &mut [f64], divisor: f64) { + let inv_div = 1. / divisor; + for ri in res { + *ri = f64::from_bits(((*ri * inv_div).round() as i64) as u64) + } +} diff --git a/poulpy-hal/src/reference/fft64/reim/fft_ref.rs b/poulpy-hal/src/reference/fft64/reim/fft_ref.rs new file mode 100644 index 0000000..849a58e --- /dev/null +++ b/poulpy-hal/src/reference/fft64/reim/fft_ref.rs @@ -0,0 +1,327 @@ +use std::fmt::Debug; + +use rand_distr::num_traits::{Float, FloatConst}; + +use crate::reference::fft64::reim::{as_arr, as_arr_mut}; + +#[inline(always)] +pub fn fft_ref(m: usize, omg: &[R], data: &mut [R]) { + assert!(data.len() == 2 * m); + let (re, im) = data.split_at_mut(m); + + if m <= 16 { + match m { + 1 => {} + 2 => fft2_ref( + as_arr_mut::<2, R>(re), + as_arr_mut::<2, R>(im), + *as_arr::<2, R>(omg), + ), + 4 => fft4_ref( + as_arr_mut::<4, R>(re), + as_arr_mut::<4, R>(im), + *as_arr::<4, R>(omg), + ), + 8 => fft8_ref( + as_arr_mut::<8, R>(re), + as_arr_mut::<8, R>(im), + *as_arr::<8, R>(omg), + ), + 16 => fft16_ref( + as_arr_mut::<16, R>(re), + as_arr_mut::<16, R>(im), + *as_arr::<16, R>(omg), + ), + _ => {} + } + } else if m <= 2048 { + fft_bfs_16_ref(m, re, im, omg, 0); + } else { + fft_rec_16_ref(m, re, im, omg, 0); + } +} + +#[inline(always)] +fn fft_rec_16_ref(m: usize, re: &mut [R], im: &mut [R], omg: &[R], mut pos: usize) -> usize { + if m <= 2048 { + return fft_bfs_16_ref(m, re, im, omg, pos); + }; + + let h = m >> 1; + twiddle_fft_ref(h, re, im, as_arr::<2, R>(&omg[pos..])); + pos += 2; + pos = fft_rec_16_ref(h, re, im, omg, pos); + pos = fft_rec_16_ref(h, &mut re[h..], &mut im[h..], omg, pos); + pos +} + +#[inline(always)] +fn cplx_twiddle(ra: &mut R, ia: &mut R, rb: &mut R, ib: &mut R, omg_re: R, omg_im: R) { + let dr: R = *rb * omg_re - *ib * omg_im; + let di: R = *rb * omg_im + *ib * omg_re; + *rb = *ra - dr; + *ib = *ia - di; + *ra = *ra + dr; + *ia = *ia + di; +} + +#[inline(always)] +fn cplx_i_twiddle(ra: &mut R, ia: &mut R, rb: &mut R, ib: &mut R, omg_re: R, omg_im: R) { + let dr: R = *rb * omg_im + *ib * omg_re; + let di: R = *rb * omg_re - *ib * omg_im; + *rb = *ra + dr; + *ib = *ia - di; + *ra = *ra - dr; + *ia = *ia + di; +} + +#[inline(always)] +fn fft2_ref(re: &mut [R; 2], im: &mut [R; 2], omg: [R; 2]) { + let [ra, rb] = re; + let [ia, ib] = im; + let [romg, iomg] = omg; + cplx_twiddle(ra, ia, rb, ib, romg, iomg); +} + +#[inline(always)] +fn fft4_ref(re: &mut [R; 4], im: &mut [R; 4], omg: [R; 4]) { + let [re_0, re_1, re_2, re_3] = re; + let [im_0, im_1, im_2, im_3] = im; + + { + let omg_0 = omg[0]; + let omg_1 = omg[1]; + cplx_twiddle(re_0, im_0, re_2, im_2, omg_0, omg_1); + cplx_twiddle(re_1, im_1, re_3, im_3, omg_0, omg_1); + } + + { + let omg_0 = omg[2]; + let omg_1 = omg[3]; + cplx_twiddle(re_0, im_0, re_1, im_1, omg_0, omg_1); + cplx_i_twiddle(re_2, im_2, re_3, im_3, omg_0, omg_1); + } +} + +#[inline(always)] +fn fft8_ref(re: &mut [R; 8], im: &mut [R; 8], omg: [R; 8]) { + let [re_0, re_1, re_2, re_3, re_4, re_5, re_6, re_7] = re; + let [im_0, im_1, im_2, im_3, im_4, im_5, im_6, im_7] = im; + + { + let omg_0 = omg[0]; + let omg_1 = omg[1]; + cplx_twiddle(re_0, im_0, re_4, im_4, omg_0, omg_1); + cplx_twiddle(re_1, im_1, re_5, im_5, omg_0, omg_1); + cplx_twiddle(re_2, im_2, re_6, im_6, omg_0, omg_1); + cplx_twiddle(re_3, im_3, re_7, im_7, omg_0, omg_1); + } + + { + let omg_2 = omg[2]; + let omg_3 = omg[3]; + cplx_twiddle(re_0, im_0, 
re_2, im_2, omg_2, omg_3); + cplx_twiddle(re_1, im_1, re_3, im_3, omg_2, omg_3); + cplx_i_twiddle(re_4, im_4, re_6, im_6, omg_2, omg_3); + cplx_i_twiddle(re_5, im_5, re_7, im_7, omg_2, omg_3); + } + + { + let omg_4 = omg[4]; + let omg_5 = omg[5]; + let omg_6 = omg[6]; + let omg_7 = omg[7]; + cplx_twiddle(re_0, im_0, re_1, im_1, omg_4, omg_6); + cplx_i_twiddle(re_2, im_2, re_3, im_3, omg_4, omg_6); + cplx_twiddle(re_4, im_4, re_5, im_5, omg_5, omg_7); + cplx_i_twiddle(re_6, im_6, re_7, im_7, omg_5, omg_7); + } +} + +#[inline(always)] +fn fft16_ref(re: &mut [R; 16], im: &mut [R; 16], omg: [R; 16]) { + let [ + re_0, + re_1, + re_2, + re_3, + re_4, + re_5, + re_6, + re_7, + re_8, + re_9, + re_10, + re_11, + re_12, + re_13, + re_14, + re_15, + ] = re; + let [ + im_0, + im_1, + im_2, + im_3, + im_4, + im_5, + im_6, + im_7, + im_8, + im_9, + im_10, + im_11, + im_12, + im_13, + im_14, + im_15, + ] = im; + + { + let omg_0: R = omg[0]; + let omg_1: R = omg[1]; + cplx_twiddle(re_0, im_0, re_8, im_8, omg_0, omg_1); + cplx_twiddle(re_1, im_1, re_9, im_9, omg_0, omg_1); + cplx_twiddle(re_2, im_2, re_10, im_10, omg_0, omg_1); + cplx_twiddle(re_3, im_3, re_11, im_11, omg_0, omg_1); + + cplx_twiddle(re_4, im_4, re_12, im_12, omg_0, omg_1); + cplx_twiddle(re_5, im_5, re_13, im_13, omg_0, omg_1); + cplx_twiddle(re_6, im_6, re_14, im_14, omg_0, omg_1); + cplx_twiddle(re_7, im_7, re_15, im_15, omg_0, omg_1); + } + + { + let omg_2: R = omg[2]; + let omg_3: R = omg[3]; + cplx_twiddle(re_0, im_0, re_4, im_4, omg_2, omg_3); + cplx_twiddle(re_1, im_1, re_5, im_5, omg_2, omg_3); + cplx_twiddle(re_2, im_2, re_6, im_6, omg_2, omg_3); + cplx_twiddle(re_3, im_3, re_7, im_7, omg_2, omg_3); + + cplx_i_twiddle(re_8, im_8, re_12, im_12, omg_2, omg_3); + cplx_i_twiddle(re_9, im_9, re_13, im_13, omg_2, omg_3); + cplx_i_twiddle(re_10, im_10, re_14, im_14, omg_2, omg_3); + cplx_i_twiddle(re_11, im_11, re_15, im_15, omg_2, omg_3); + } + + { + let omg_0: R = omg[4]; + let omg_1: R = omg[5]; + let omg_2: R = omg[6]; + let omg_3: R = omg[7]; + cplx_twiddle(re_0, im_0, re_2, im_2, omg_0, omg_1); + cplx_twiddle(re_1, im_1, re_3, im_3, omg_0, omg_1); + cplx_twiddle(re_8, im_8, re_10, im_10, omg_2, omg_3); + cplx_twiddle(re_9, im_9, re_11, im_11, omg_2, omg_3); + + cplx_i_twiddle(re_4, im_4, re_6, im_6, omg_0, omg_1); + cplx_i_twiddle(re_5, im_5, re_7, im_7, omg_0, omg_1); + cplx_i_twiddle(re_12, im_12, re_14, im_14, omg_2, omg_3); + cplx_i_twiddle(re_13, im_13, re_15, im_15, omg_2, omg_3); + } + + { + let omg_0: R = omg[8]; + let omg_1: R = omg[9]; + let omg_2: R = omg[10]; + let omg_3: R = omg[11]; + let omg_4: R = omg[12]; + let omg_5: R = omg[13]; + let omg_6: R = omg[14]; + let omg_7: R = omg[15]; + cplx_twiddle(re_0, im_0, re_1, im_1, omg_0, omg_4); + cplx_twiddle(re_4, im_4, re_5, im_5, omg_1, omg_5); + cplx_twiddle(re_8, im_8, re_9, im_9, omg_2, omg_6); + cplx_twiddle(re_12, im_12, re_13, im_13, omg_3, omg_7); + + cplx_i_twiddle(re_2, im_2, re_3, im_3, omg_0, omg_4); + cplx_i_twiddle(re_6, im_6, re_7, im_7, omg_1, omg_5); + cplx_i_twiddle(re_10, im_10, re_11, im_11, omg_2, omg_6); + cplx_i_twiddle(re_14, im_14, re_15, im_15, omg_3, omg_7); + } +} + +#[inline(always)] +fn fft_bfs_16_ref(m: usize, re: &mut [R], im: &mut [R], omg: &[R], mut pos: usize) -> usize { + let log_m: usize = (usize::BITS - (m - 1).leading_zeros()) as usize; + let mut mm: usize = m; + + if !log_m.is_multiple_of(2) { + let h: usize = mm >> 1; + twiddle_fft_ref(h, re, im, as_arr::<2, R>(&omg[pos..])); + pos += 2; + mm = h + } + + while mm > 16 { + let h: 
usize = mm >> 2; + for off in (0..m).step_by(mm) { + bitwiddle_fft_ref( + h, + &mut re[off..], + &mut im[off..], + as_arr::<4, R>(&omg[pos..]), + ); + pos += 4; + } + mm = h + } + + for off in (0..m).step_by(16) { + fft16_ref( + as_arr_mut::<16, R>(&mut re[off..]), + as_arr_mut::<16, R>(&mut im[off..]), + *as_arr::<16, R>(&omg[pos..]), + ); + pos += 16; + } + + pos +} + +#[inline(always)] +fn twiddle_fft_ref(h: usize, re: &mut [R], im: &mut [R], omg: &[R; 2]) { + let romg = omg[0]; + let iomg = omg[1]; + + let (re_lhs, re_rhs) = re.split_at_mut(h); + let (im_lhs, im_rhs) = im.split_at_mut(h); + + for i in 0..h { + cplx_twiddle( + &mut re_lhs[i], + &mut im_lhs[i], + &mut re_rhs[i], + &mut im_rhs[i], + romg, + iomg, + ); + } +} + +#[inline(always)] +fn bitwiddle_fft_ref(h: usize, re: &mut [R], im: &mut [R], omg: &[R; 4]) { + let (r0, r2) = re.split_at_mut(2 * h); + let (r0, r1) = r0.split_at_mut(h); + let (r2, r3) = r2.split_at_mut(h); + + let (i0, i2) = im.split_at_mut(2 * h); + let (i0, i1) = i0.split_at_mut(h); + let (i2, i3) = i2.split_at_mut(h); + + let omg_0: R = omg[0]; + let omg_1: R = omg[1]; + let omg_2: R = omg[2]; + let omg_3: R = omg[3]; + + for i in 0..h { + cplx_twiddle(&mut r0[i], &mut i0[i], &mut r2[i], &mut i2[i], omg_0, omg_1); + cplx_twiddle(&mut r1[i], &mut i1[i], &mut r3[i], &mut i3[i], omg_0, omg_1); + } + + for i in 0..h { + cplx_twiddle(&mut r0[i], &mut i0[i], &mut r1[i], &mut i1[i], omg_2, omg_3); + cplx_i_twiddle(&mut r2[i], &mut i2[i], &mut r3[i], &mut i3[i], omg_2, omg_3); + } +} diff --git a/poulpy-hal/src/reference/fft64/reim/fft_vec.rs b/poulpy-hal/src/reference/fft64/reim/fft_vec.rs new file mode 100644 index 0000000..63b4a80 --- /dev/null +++ b/poulpy-hal/src/reference/fft64/reim/fft_vec.rs @@ -0,0 +1,156 @@ +#[inline(always)] +pub fn reim_add_ref(res: &mut [f64], a: &[f64], b: &[f64]) { + #[cfg(debug_assertions)] + { + assert_eq!(a.len(), res.len()); + assert_eq!(b.len(), res.len()); + } + + for i in 0..res.len() { + res[i] = a[i] + b[i] + } +} + +#[inline(always)] +pub fn reim_add_inplace_ref(res: &mut [f64], a: &[f64]) { + #[cfg(debug_assertions)] + { + assert_eq!(a.len(), res.len()); + } + + for i in 0..res.len() { + res[i] += a[i] + } +} + +#[inline(always)] +pub fn reim_sub_ref(res: &mut [f64], a: &[f64], b: &[f64]) { + #[cfg(debug_assertions)] + { + assert_eq!(a.len(), res.len()); + assert_eq!(b.len(), res.len()); + } + + for i in 0..res.len() { + res[i] = a[i] - b[i] + } +} + +#[inline(always)] +pub fn reim_sub_ab_inplace_ref(res: &mut [f64], a: &[f64]) { + #[cfg(debug_assertions)] + { + assert_eq!(a.len(), res.len()); + } + + for i in 0..res.len() { + res[i] -= a[i] + } +} + +#[inline(always)] +pub fn reim_sub_ba_inplace_ref(res: &mut [f64], a: &[f64]) { + #[cfg(debug_assertions)] + { + assert_eq!(a.len(), res.len()); + } + + for i in 0..res.len() { + res[i] = a[i] - res[i] + } +} + +#[inline(always)] +pub fn reim_negate_ref(res: &mut [f64], a: &[f64]) { + #[cfg(debug_assertions)] + { + assert_eq!(a.len(), res.len()); + } + + for i in 0..res.len() { + res[i] = -a[i] + } +} + +#[inline(always)] +pub fn reim_negate_inplace_ref(res: &mut [f64]) { + for ri in res { + *ri = -*ri + } +} + +#[inline(always)] +pub fn reim_addmul_ref(res: &mut [f64], a: &[f64], b: &[f64]) { + #[cfg(debug_assertions)] + { + assert_eq!(a.len(), res.len()); + assert_eq!(b.len(), res.len()); + } + + let m: usize = res.len() >> 1; + + let (rr, ri) = res.split_at_mut(m); + let (ar, ai) = a.split_at(m); + let (br, bi) = b.split_at(m); + + for i in 0..m { + let _ar: f64 = ar[i]; + 
let _ai: f64 = ai[i]; + let _br: f64 = br[i]; + let _bi: f64 = bi[i]; + let _rr: f64 = _ar * _br - _ai * _bi; + let _ri: f64 = _ar * _bi + _ai * _br; + rr[i] += _rr; + ri[i] += _ri; + } +} + +#[inline(always)] +pub fn reim_mul_inplace_ref(res: &mut [f64], a: &[f64]) { + #[cfg(debug_assertions)] + { + assert_eq!(a.len(), res.len()); + } + + let m: usize = res.len() >> 1; + + let (rr, ri) = res.split_at_mut(m); + let (ar, ai) = a.split_at(m); + + for i in 0..m { + let _ar: f64 = ar[i]; + let _ai: f64 = ai[i]; + let _br: f64 = rr[i]; + let _bi: f64 = ri[i]; + let _rr: f64 = _ar * _br - _ai * _bi; + let _ri: f64 = _ar * _bi + _ai * _br; + rr[i] = _rr; + ri[i] = _ri; + } +} + +#[inline(always)] +pub fn reim_mul_ref(res: &mut [f64], a: &[f64], b: &[f64]) { + #[cfg(debug_assertions)] + { + assert_eq!(a.len(), res.len()); + assert_eq!(b.len(), res.len()); + } + + let m: usize = res.len() >> 1; + + let (rr, ri) = res.split_at_mut(m); + let (ar, ai) = a.split_at(m); + let (br, bi) = b.split_at(m); + + for i in 0..m { + let _ar: f64 = ar[i]; + let _ai: f64 = ai[i]; + let _br: f64 = br[i]; + let _bi: f64 = bi[i]; + let _rr: f64 = _ar * _br - _ai * _bi; + let _ri: f64 = _ar * _bi + _ai * _br; + rr[i] = _rr; + ri[i] = _ri; + } +} diff --git a/poulpy-hal/src/reference/fft64/reim/ifft_ref.rs b/poulpy-hal/src/reference/fft64/reim/ifft_ref.rs new file mode 100644 index 0000000..e0fe8a2 --- /dev/null +++ b/poulpy-hal/src/reference/fft64/reim/ifft_ref.rs @@ -0,0 +1,322 @@ +use std::fmt::Debug; + +use rand_distr::num_traits::{Float, FloatConst}; + +use crate::reference::fft64::reim::{as_arr, as_arr_mut}; + +pub fn ifft_ref(m: usize, omg: &[R], data: &mut [R]) { + assert!(data.len() == 2 * m); + let (re, im) = data.split_at_mut(m); + + if m <= 16 { + match m { + 1 => {} + 2 => ifft2_ref( + as_arr_mut::<2, R>(re), + as_arr_mut::<2, R>(im), + *as_arr::<2, R>(omg), + ), + 4 => ifft4_ref( + as_arr_mut::<4, R>(re), + as_arr_mut::<4, R>(im), + *as_arr::<4, R>(omg), + ), + 8 => ifft8_ref( + as_arr_mut::<8, R>(re), + as_arr_mut::<8, R>(im), + *as_arr::<8, R>(omg), + ), + 16 => ifft16_ref( + as_arr_mut::<16, R>(re), + as_arr_mut::<16, R>(im), + *as_arr::<16, R>(omg), + ), + _ => {} + } + } else if m <= 2048 { + ifft_bfs_16_ref(m, re, im, omg, 0); + } else { + ifft_rec_16_ref(m, re, im, omg, 0); + } +} + +#[inline(always)] +fn ifft_rec_16_ref(m: usize, re: &mut [R], im: &mut [R], omg: &[R], mut pos: usize) -> usize { + if m <= 2048 { + return ifft_bfs_16_ref(m, re, im, omg, pos); + }; + let h: usize = m >> 1; + pos = ifft_rec_16_ref(h, re, im, omg, pos); + pos = ifft_rec_16_ref(h, &mut re[h..], &mut im[h..], omg, pos); + inv_twiddle_ifft_ref(h, re, im, as_arr::<2, R>(&omg[pos..])); + pos += 2; + pos +} + +#[inline(always)] +fn ifft_bfs_16_ref(m: usize, re: &mut [R], im: &mut [R], omg: &[R], mut pos: usize) -> usize { + let log_m: usize = (usize::BITS - (m - 1).leading_zeros()) as usize; + + for off in (0..m).step_by(16) { + ifft16_ref( + as_arr_mut::<16, R>(&mut re[off..]), + as_arr_mut::<16, R>(&mut im[off..]), + *as_arr::<16, R>(&omg[pos..]), + ); + pos += 16; + } + + let mut h: usize = 16; + let m_half: usize = m >> 1; + + while h < m_half { + let mm: usize = h << 2; + for off in (0..m).step_by(mm) { + inv_bitwiddle_ifft_ref( + h, + &mut re[off..], + &mut im[off..], + as_arr::<4, R>(&omg[pos..]), + ); + pos += 4; + } + h = mm; + } + + if !log_m.is_multiple_of(2) { + inv_twiddle_ifft_ref(h, re, im, as_arr::<2, R>(&omg[pos..])); + pos += 2; + } + + pos +} + +#[inline(always)] +fn inv_twiddle(ra: &mut R, ia: &mut R, rb: 
&mut R, ib: &mut R, omg_re: R, omg_im: R) { + let r_diff: R = *ra - *rb; + let i_diff: R = *ia - *ib; + *ra = *ra + *rb; + *ia = *ia + *ib; + *rb = r_diff * omg_re - i_diff * omg_im; + *ib = r_diff * omg_im + i_diff * omg_re; +} + +#[inline(always)] +fn inv_itwiddle(ra: &mut R, ia: &mut R, rb: &mut R, ib: &mut R, omg_re: R, omg_im: R) { + let r_diff: R = *ra - *rb; + let i_diff: R = *ia - *ib; + *ra = *ra + *rb; + *ia = *ia + *ib; + *rb = r_diff * omg_im + i_diff * omg_re; + *ib = -r_diff * omg_re + i_diff * omg_im; +} + +#[inline(always)] +fn ifft2_ref(re: &mut [R; 2], im: &mut [R; 2], omg: [R; 2]) { + let [ra, rb] = re; + let [ia, ib] = im; + let [romg, iomg] = omg; + inv_twiddle(ra, ia, rb, ib, romg, iomg); +} + +#[inline(always)] +fn ifft4_ref(re: &mut [R; 4], im: &mut [R; 4], omg: [R; 4]) { + let [re_0, re_1, re_2, re_3] = re; + let [im_0, im_1, im_2, im_3] = im; + + { + let omg_0: R = omg[0]; + let omg_1: R = omg[1]; + + inv_twiddle(re_0, im_0, re_1, im_1, omg_0, omg_1); + inv_itwiddle(re_2, im_2, re_3, im_3, omg_0, omg_1); + } + + { + let omg_0: R = omg[2]; + let omg_1: R = omg[3]; + inv_twiddle(re_0, im_0, re_2, im_2, omg_0, omg_1); + inv_twiddle(re_1, im_1, re_3, im_3, omg_0, omg_1); + } +} + +#[inline(always)] +fn ifft8_ref(re: &mut [R; 8], im: &mut [R; 8], omg: [R; 8]) { + let [re_0, re_1, re_2, re_3, re_4, re_5, re_6, re_7] = re; + let [im_0, im_1, im_2, im_3, im_4, im_5, im_6, im_7] = im; + + { + let omg_4: R = omg[0]; + let omg_5: R = omg[1]; + let omg_6: R = omg[2]; + let omg_7: R = omg[3]; + inv_twiddle(re_0, im_0, re_1, im_1, omg_4, omg_6); + inv_itwiddle(re_2, im_2, re_3, im_3, omg_4, omg_6); + inv_twiddle(re_4, im_4, re_5, im_5, omg_5, omg_7); + inv_itwiddle(re_6, im_6, re_7, im_7, omg_5, omg_7); + } + + { + let omg_2: R = omg[4]; + let omg_3: R = omg[5]; + inv_twiddle(re_0, im_0, re_2, im_2, omg_2, omg_3); + inv_twiddle(re_1, im_1, re_3, im_3, omg_2, omg_3); + inv_itwiddle(re_4, im_4, re_6, im_6, omg_2, omg_3); + inv_itwiddle(re_5, im_5, re_7, im_7, omg_2, omg_3); + } + + { + let omg_0: R = omg[6]; + let omg_1: R = omg[7]; + inv_twiddle(re_0, im_0, re_4, im_4, omg_0, omg_1); + inv_twiddle(re_1, im_1, re_5, im_5, omg_0, omg_1); + inv_twiddle(re_2, im_2, re_6, im_6, omg_0, omg_1); + inv_twiddle(re_3, im_3, re_7, im_7, omg_0, omg_1); + } +} + +#[inline(always)] +fn ifft16_ref(re: &mut [R; 16], im: &mut [R; 16], omg: [R; 16]) { + let [ + re_0, + re_1, + re_2, + re_3, + re_4, + re_5, + re_6, + re_7, + re_8, + re_9, + re_10, + re_11, + re_12, + re_13, + re_14, + re_15, + ] = re; + let [ + im_0, + im_1, + im_2, + im_3, + im_4, + im_5, + im_6, + im_7, + im_8, + im_9, + im_10, + im_11, + im_12, + im_13, + im_14, + im_15, + ] = im; + + { + let omg_0: R = omg[0]; + let omg_1: R = omg[1]; + let omg_2: R = omg[2]; + let omg_3: R = omg[3]; + let omg_4: R = omg[4]; + let omg_5: R = omg[5]; + let omg_6: R = omg[6]; + let omg_7: R = omg[7]; + inv_twiddle(re_0, im_0, re_1, im_1, omg_0, omg_4); + inv_itwiddle(re_2, im_2, re_3, im_3, omg_0, omg_4); + inv_twiddle(re_4, im_4, re_5, im_5, omg_1, omg_5); + inv_itwiddle(re_6, im_6, re_7, im_7, omg_1, omg_5); + inv_twiddle(re_8, im_8, re_9, im_9, omg_2, omg_6); + inv_itwiddle(re_10, im_10, re_11, im_11, omg_2, omg_6); + inv_twiddle(re_12, im_12, re_13, im_13, omg_3, omg_7); + inv_itwiddle(re_14, im_14, re_15, im_15, omg_3, omg_7); + } + + { + let omg_0: R = omg[8]; + let omg_1: R = omg[9]; + let omg_2: R = omg[10]; + let omg_3: R = omg[11]; + inv_twiddle(re_0, im_0, re_2, im_2, omg_0, omg_1); + inv_twiddle(re_1, im_1, re_3, im_3, omg_0, omg_1); 
+ inv_itwiddle(re_4, im_4, re_6, im_6, omg_0, omg_1); + inv_itwiddle(re_5, im_5, re_7, im_7, omg_0, omg_1); + inv_twiddle(re_8, im_8, re_10, im_10, omg_2, omg_3); + inv_twiddle(re_9, im_9, re_11, im_11, omg_2, omg_3); + inv_itwiddle(re_12, im_12, re_14, im_14, omg_2, omg_3); + inv_itwiddle(re_13, im_13, re_15, im_15, omg_2, omg_3); + } + + { + let omg_2: R = omg[12]; + let omg_3: R = omg[13]; + inv_twiddle(re_0, im_0, re_4, im_4, omg_2, omg_3); + inv_twiddle(re_1, im_1, re_5, im_5, omg_2, omg_3); + inv_twiddle(re_2, im_2, re_6, im_6, omg_2, omg_3); + inv_twiddle(re_3, im_3, re_7, im_7, omg_2, omg_3); + inv_itwiddle(re_8, im_8, re_12, im_12, omg_2, omg_3); + inv_itwiddle(re_9, im_9, re_13, im_13, omg_2, omg_3); + inv_itwiddle(re_10, im_10, re_14, im_14, omg_2, omg_3); + inv_itwiddle(re_11, im_11, re_15, im_15, omg_2, omg_3); + } + + { + let omg_0: R = omg[14]; + let omg_1: R = omg[15]; + inv_twiddle(re_0, im_0, re_8, im_8, omg_0, omg_1); + inv_twiddle(re_1, im_1, re_9, im_9, omg_0, omg_1); + inv_twiddle(re_2, im_2, re_10, im_10, omg_0, omg_1); + inv_twiddle(re_3, im_3, re_11, im_11, omg_0, omg_1); + inv_twiddle(re_4, im_4, re_12, im_12, omg_0, omg_1); + inv_twiddle(re_5, im_5, re_13, im_13, omg_0, omg_1); + inv_twiddle(re_6, im_6, re_14, im_14, omg_0, omg_1); + inv_twiddle(re_7, im_7, re_15, im_15, omg_0, omg_1); + } +} + +#[inline(always)] +fn inv_twiddle_ifft_ref(h: usize, re: &mut [R], im: &mut [R], omg: &[R; 2]) { + let romg = omg[0]; + let iomg = omg[1]; + + let (re_lhs, re_rhs) = re.split_at_mut(h); + let (im_lhs, im_rhs) = im.split_at_mut(h); + + for i in 0..h { + inv_twiddle( + &mut re_lhs[i], + &mut im_lhs[i], + &mut re_rhs[i], + &mut im_rhs[i], + romg, + iomg, + ); + } +} + +#[inline(always)] +fn inv_bitwiddle_ifft_ref(h: usize, re: &mut [R], im: &mut [R], omg: &[R; 4]) { + let (r0, r2) = re.split_at_mut(2 * h); + let (r0, r1) = r0.split_at_mut(h); + let (r2, r3) = r2.split_at_mut(h); + + let (i0, i2) = im.split_at_mut(2 * h); + let (i0, i1) = i0.split_at_mut(h); + let (i2, i3) = i2.split_at_mut(h); + + let omg_0: R = omg[0]; + let omg_1: R = omg[1]; + let omg_2: R = omg[2]; + let omg_3: R = omg[3]; + + for i in 0..h { + inv_twiddle(&mut r0[i], &mut i0[i], &mut r1[i], &mut i1[i], omg_0, omg_1); + inv_itwiddle(&mut r2[i], &mut i2[i], &mut r3[i], &mut i3[i], omg_0, omg_1); + } + + for i in 0..h { + inv_twiddle(&mut r0[i], &mut i0[i], &mut r2[i], &mut i2[i], omg_2, omg_3); + inv_twiddle(&mut r1[i], &mut i1[i], &mut r3[i], &mut i3[i], omg_2, omg_3); + } +} diff --git a/poulpy-hal/src/reference/fft64/reim/mod.rs b/poulpy-hal/src/reference/fft64/reim/mod.rs new file mode 100644 index 0000000..28c90b1 --- /dev/null +++ b/poulpy-hal/src/reference/fft64/reim/mod.rs @@ -0,0 +1,128 @@ +// ---------------------------------------------------------------------- +// DISCLAIMER +// +// This module contains code that has been directly ported from the +// spqlios-arithmetic library +// (https://github.com/tfhe/spqlios-arithmetic), which is licensed +// under the Apache License, Version 2.0. +// +// The porting process from C to Rust was done with minimal changes +// in order to preserve the semantics and performance characteristics +// of the original implementation. +// +// Both Poulpy and spqlios-arithmetic are distributed under the terms +// of the Apache License, Version 2.0. See the LICENSE file for details. 
+// +// ---------------------------------------------------------------------- + +#![allow(bad_asm_style)] + +mod conversion; +mod fft_ref; +mod fft_vec; +mod ifft_ref; +mod table_fft; +mod table_ifft; +mod zero; + +pub use conversion::*; +pub use fft_ref::*; +pub use fft_vec::*; +pub use ifft_ref::*; +pub use table_fft::*; +pub use table_ifft::*; +pub use zero::*; + +#[inline(always)] +pub(crate) fn as_arr(x: &[R]) -> &[R; size] { + debug_assert!(x.len() >= size); + unsafe { &*(x.as_ptr() as *const [R; size]) } +} + +#[inline(always)] +pub(crate) fn as_arr_mut(x: &mut [R]) -> &mut [R; size] { + debug_assert!(x.len() >= size); + unsafe { &mut *(x.as_mut_ptr() as *mut [R; size]) } +} + +use rand_distr::num_traits::{Float, FloatConst}; +#[inline(always)] +pub(crate) fn frac_rev_bits(x: usize) -> R { + let half: R = R::from(0.5).unwrap(); + + match x { + 0 => R::zero(), + 1 => half, + _ => { + if x.is_multiple_of(2) { + frac_rev_bits::(x >> 1) * half + } else { + frac_rev_bits::(x >> 1) * half + half + } + } + } +} + +pub trait ReimDFTExecute { + fn reim_dft_execute(table: &D, data: &mut [T]); +} + +pub trait ReimFromZnx { + fn reim_from_znx(res: &mut [f64], a: &[i64]); +} + +pub trait ReimToZnx { + fn reim_to_znx(res: &mut [i64], divisor: f64, a: &[f64]); +} + +pub trait ReimToZnxInplace { + fn reim_to_znx_inplace(res: &mut [f64], divisor: f64); +} + +pub trait ReimAdd { + fn reim_add(res: &mut [f64], a: &[f64], b: &[f64]); +} + +pub trait ReimAddInplace { + fn reim_add_inplace(res: &mut [f64], a: &[f64]); +} + +pub trait ReimSub { + fn reim_sub(res: &mut [f64], a: &[f64], b: &[f64]); +} + +pub trait ReimSubABInplace { + fn reim_sub_ab_inplace(res: &mut [f64], a: &[f64]); +} + +pub trait ReimSubBAInplace { + fn reim_sub_ba_inplace(res: &mut [f64], a: &[f64]); +} + +pub trait ReimNegate { + fn reim_negate(res: &mut [f64], a: &[f64]); +} + +pub trait ReimNegateInplace { + fn reim_negate_inplace(res: &mut [f64]); +} + +pub trait ReimMul { + fn reim_mul(res: &mut [f64], a: &[f64], b: &[f64]); +} + +pub trait ReimMulInplace { + fn reim_mul_inplace(res: &mut [f64], a: &[f64]); +} + +pub trait ReimAddMul { + fn reim_addmul(res: &mut [f64], a: &[f64], b: &[f64]); +} + +pub trait ReimCopy { + fn reim_copy(res: &mut [f64], a: &[f64]); +} + +pub trait ReimZero { + fn reim_zero(res: &mut [f64]); +} diff --git a/poulpy-hal/src/reference/fft64/reim/table_fft.rs b/poulpy-hal/src/reference/fft64/reim/table_fft.rs new file mode 100644 index 0000000..452678f --- /dev/null +++ b/poulpy-hal/src/reference/fft64/reim/table_fft.rs @@ -0,0 +1,207 @@ +use std::fmt::Debug; + +use rand_distr::num_traits::{Float, FloatConst}; + +use crate::{ + alloc_aligned, + reference::fft64::reim::{ReimDFTExecute, fft_ref, frac_rev_bits}, +}; + +pub struct ReimFFTRef; + +impl ReimDFTExecute, f64> for ReimFFTRef { + fn reim_dft_execute(table: &ReimFFTTable, data: &mut [f64]) { + fft_ref(table.m, &table.omg, data); + } +} + +pub struct ReimFFTTable { + m: usize, + omg: Vec, +} + +impl ReimFFTTable { + pub fn new(m: usize) -> Self { + assert!(m & (m - 1) == 0, "m must be a power of two but is {}", m); + let mut omg: Vec = alloc_aligned::(2 * m); + + let quarter: R = R::from(1. 
/ 4.).unwrap(); + + if m <= 16 { + match m { + 1 => {} + 2 => { + fill_fft2_omegas(quarter, &mut omg, 0); + } + 4 => { + fill_fft4_omegas(quarter, &mut omg, 0); + } + 8 => { + fill_fft8_omegas(quarter, &mut omg, 0); + } + 16 => { + fill_fft16_omegas(quarter, &mut omg, 0); + } + _ => {} + } + } else if m <= 2048 { + fill_fft_bfs_16_omegas(m, quarter, &mut omg, 0); + } else { + fill_fft_rec_16_omegas(m, quarter, &mut omg, 0); + } + + Self { m, omg } + } + + pub fn m(&self) -> usize { + self.m + } + + pub fn omg(&self) -> &[R] { + &self.omg + } +} + +#[inline(always)] +fn fill_fft2_omegas(j: R, omg: &mut [R], pos: usize) -> usize { + let omg_pos: &mut [R] = &mut omg[pos..]; + assert!(omg_pos.len() >= 2); + let angle: R = j / R::from(2).unwrap(); + let two_pi: R = R::from(2).unwrap() * R::PI(); + omg_pos[0] = R::cos(two_pi * angle); + omg_pos[1] = R::sin(two_pi * angle); + pos + 2 +} + +#[inline(always)] +fn fill_fft4_omegas(j: R, omg: &mut [R], pos: usize) -> usize { + let omg_pos: &mut [R] = &mut omg[pos..]; + assert!(omg_pos.len() >= 4); + let angle_1: R = j / R::from(2).unwrap(); + let angle_2: R = j / R::from(4).unwrap(); + let two_pi: R = R::from(2).unwrap() * R::PI(); + omg_pos[0] = R::cos(two_pi * angle_1); + omg_pos[1] = R::sin(two_pi * angle_1); + omg_pos[2] = R::cos(two_pi * angle_2); + omg_pos[3] = R::sin(two_pi * angle_2); + pos + 4 +} + +#[inline(always)] +fn fill_fft8_omegas(j: R, omg: &mut [R], pos: usize) -> usize { + let omg_pos: &mut [R] = &mut omg[pos..]; + assert!(omg_pos.len() >= 8); + let _8th: R = R::from(1. / 8.).unwrap(); + let angle_1: R = j / R::from(2).unwrap(); + let angle_2: R = j / R::from(4).unwrap(); + let angle_4: R = j / R::from(8).unwrap(); + let two_pi: R = R::from(2).unwrap() * R::PI(); + omg_pos[0] = R::cos(two_pi * angle_1); + omg_pos[1] = R::sin(two_pi * angle_1); + omg_pos[2] = R::cos(two_pi * angle_2); + omg_pos[3] = R::sin(two_pi * angle_2); + omg_pos[4] = R::cos(two_pi * angle_4); + omg_pos[5] = R::cos(two_pi * (angle_4 + _8th)); + omg_pos[6] = R::sin(two_pi * angle_4); + omg_pos[7] = R::sin(two_pi * (angle_4 + _8th)); + pos + 8 +} + +#[inline(always)] +fn fill_fft16_omegas(j: R, omg: &mut [R], pos: usize) -> usize { + let omg_pos: &mut [R] = &mut omg[pos..]; + assert!(omg_pos.len() >= 16); + let _8th: R = R::from(1. / 8.).unwrap(); + let _16th: R = R::from(1. 
/ 16.).unwrap(); + let angle_1: R = j / R::from(2).unwrap(); + let angle_2: R = j / R::from(4).unwrap(); + let angle_4: R = j / R::from(8).unwrap(); + let angle_8: R = j / R::from(16).unwrap(); + let two_pi: R = R::from(2).unwrap() * R::PI(); + omg_pos[0] = R::cos(two_pi * angle_1); + omg_pos[1] = R::sin(two_pi * angle_1); + omg_pos[2] = R::cos(two_pi * angle_2); + omg_pos[3] = R::sin(two_pi * angle_2); + omg_pos[4] = R::cos(two_pi * angle_4); + omg_pos[5] = R::sin(two_pi * angle_4); + omg_pos[6] = R::cos(two_pi * (angle_4 + _8th)); + omg_pos[7] = R::sin(two_pi * (angle_4 + _8th)); + omg_pos[8] = R::cos(two_pi * angle_8); + omg_pos[9] = R::cos(two_pi * (angle_8 + _8th)); + omg_pos[10] = R::cos(two_pi * (angle_8 + _16th)); + omg_pos[11] = R::cos(two_pi * (angle_8 + _8th + _16th)); + omg_pos[12] = R::sin(two_pi * angle_8); + omg_pos[13] = R::sin(two_pi * (angle_8 + _8th)); + omg_pos[14] = R::sin(two_pi * (angle_8 + _16th)); + omg_pos[15] = R::sin(two_pi * (angle_8 + _8th + _16th)); + pos + 16 +} + +#[inline(always)] +fn fill_fft_bfs_16_omegas(m: usize, j: R, omg: &mut [R], mut pos: usize) -> usize { + let log_m: usize = (usize::BITS - (m - 1).leading_zeros()) as usize; + let mut mm: usize = m; + let mut jj: R = j; + + let two_pi: R = R::from(2).unwrap() * R::PI(); + + if !log_m.is_multiple_of(2) { + let h = mm >> 1; + let j: R = jj * R::from(0.5).unwrap(); + omg[pos] = R::cos(two_pi * j); + omg[pos + 1] = R::sin(two_pi * j); + pos += 2; + mm = h; + jj = j + } + + while mm > 16 { + let h: usize = mm >> 2; + let j: R = jj * R::from(1. / 4.).unwrap(); + for i in (0..m).step_by(mm) { + let rs_0 = j + frac_rev_bits::(i / mm) * R::from(1. / 4.).unwrap(); + let rs_1 = R::from(2).unwrap() * rs_0; + omg[pos] = R::cos(two_pi * rs_1); + omg[pos + 1] = R::sin(two_pi * rs_1); + omg[pos + 2] = R::cos(two_pi * rs_0); + omg[pos + 3] = R::sin(two_pi * rs_0); + pos += 4; + } + mm = h; + jj = j; + } + + for i in (0..m).step_by(16) { + let j = jj + frac_rev_bits(i >> 4); + fill_fft16_omegas(j, omg, pos); + pos += 16 + } + + pos +} + +#[inline(always)] +fn fill_fft_rec_16_omegas(m: usize, j: R, omg: &mut [R], mut pos: usize) -> usize { + if m <= 2048 { + return fill_fft_bfs_16_omegas(m, j, omg, pos); + } + let h: usize = m >> 1; + let s: R = j * R::from(0.5).unwrap(); + let _2pi = R::from(2).unwrap() * R::PI(); + omg[pos] = R::cos(_2pi * s); + omg[pos + 1] = R::sin(_2pi * s); + pos += 2; + pos = fill_fft_rec_16_omegas(h, s, omg, pos); + pos = fill_fft_rec_16_omegas(h, s + R::from(0.5).unwrap(), omg, pos); + pos +} + +#[inline(always)] +fn ctwiddle_ref(ra: &mut f64, ia: &mut f64, rb: &mut f64, ib: &mut f64, omg_re: f64, omg_im: f64) { + let dr: f64 = *rb * omg_re - *ib * omg_im; + let di: f64 = *rb * omg_im + *ib * omg_re; + *rb = *ra - dr; + *ib = *ia - di; + *ra += dr; + *ia += di; +} diff --git a/poulpy-hal/src/reference/fft64/reim/table_ifft.rs b/poulpy-hal/src/reference/fft64/reim/table_ifft.rs new file mode 100644 index 0000000..929b933 --- /dev/null +++ b/poulpy-hal/src/reference/fft64/reim/table_ifft.rs @@ -0,0 +1,201 @@ +use std::fmt::Debug; + +use rand_distr::num_traits::{Float, FloatConst}; + +use crate::{ + alloc_aligned, + reference::fft64::reim::{ReimDFTExecute, frac_rev_bits, ifft_ref::ifft_ref}, +}; + +pub struct ReimIFFTRef; + +impl ReimDFTExecute, f64> for ReimIFFTRef { + fn reim_dft_execute(table: &ReimIFFTTable, data: &mut [f64]) { + ifft_ref(table.m, &table.omg, data); + } +} + +pub struct ReimIFFTTable { + m: usize, + omg: Vec, +} + +impl ReimIFFTTable { + pub fn new(m: usize) -> Self { + 
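+ // twiddle layout mirrors the forward table: sizes m <= 16 use the dedicated small kernels, m <= 2048 are filled in breadth-first order, larger sizes recursively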
assert!(m & (m - 1) == 0, "m must be a power of two but is {}", m); + let mut omg: Vec = alloc_aligned::(2 * m); + + let quarter: R = R::exp2(R::from(-2).unwrap()); + + if m <= 16 { + match m { + 1 => {} + 2 => { + fill_ifft2_omegas::(quarter, &mut omg, 0); + } + 4 => { + fill_ifft4_omegas(quarter, &mut omg, 0); + } + 8 => { + fill_ifft8_omegas(quarter, &mut omg, 0); + } + 16 => { + fill_ifft16_omegas(quarter, &mut omg, 0); + } + _ => {} + } + } else if m <= 2048 { + fill_ifft_bfs_16_omegas(m, quarter, &mut omg, 0); + } else { + fill_ifft_rec_16_omegas(m, quarter, &mut omg, 0); + } + + Self { m, omg } + } + + pub fn execute(&self, data: &mut [R]) { + ifft_ref(self.m, &self.omg, data); + } + + pub fn m(&self) -> usize { + self.m + } + + pub fn omg(&self) -> &[R] { + &self.omg + } +} + +#[inline(always)] +fn fill_ifft2_omegas(j: R, omg: &mut [R], pos: usize) -> usize { + let omg_pos: &mut [R] = &mut omg[pos..]; + assert!(omg_pos.len() >= 2); + let angle: R = j / R::exp2(R::from(2).unwrap()); + let two_pi: R = R::exp2(R::from(2).unwrap()) * R::PI(); + omg_pos[0] = R::cos(two_pi * angle); + omg_pos[1] = -R::sin(two_pi * angle); + pos + 2 +} + +#[inline(always)] +fn fill_ifft4_omegas(j: R, omg: &mut [R], pos: usize) -> usize { + let omg_pos: &mut [R] = &mut omg[pos..]; + assert!(omg_pos.len() >= 4); + let angle_1: R = j / R::from(2).unwrap(); + let angle_2: R = j / R::from(4).unwrap(); + let two_pi: R = R::from(2).unwrap() * R::PI(); + omg_pos[0] = R::cos(two_pi * angle_2); + omg_pos[1] = -R::sin(two_pi * angle_2); + omg_pos[2] = R::cos(two_pi * angle_1); + omg_pos[3] = -R::sin(two_pi * angle_1); + pos + 4 +} + +#[inline(always)] +fn fill_ifft8_omegas(j: R, omg: &mut [R], pos: usize) -> usize { + let omg_pos: &mut [R] = &mut omg[pos..]; + assert!(omg_pos.len() >= 8); + let _8th: R = R::from(1. / 8.).unwrap(); + let angle_1: R = j / R::from(2).unwrap(); + let angle_2: R = j / R::from(4).unwrap(); + let angle_4: R = j / R::from(2).unwrap(); + let two_pi: R = R::from(2).unwrap() * R::PI(); + omg_pos[0] = R::cos(two_pi * angle_4); + omg_pos[1] = R::cos(two_pi * (angle_4 + _8th)); + omg_pos[2] = -R::sin(two_pi * angle_4); + omg_pos[3] = -R::sin(two_pi * (angle_4 + _8th)); + omg_pos[4] = R::cos(two_pi * angle_2); + omg_pos[5] = -R::sin(two_pi * angle_2); + omg_pos[6] = R::cos(two_pi * angle_1); + omg_pos[7] = -R::sin(two_pi * angle_1); + pos + 8 +} + +#[inline(always)] +fn fill_ifft16_omegas(j: R, omg: &mut [R], pos: usize) -> usize { + let omg_pos: &mut [R] = &mut omg[pos..]; + assert!(omg_pos.len() >= 16); + let _8th: R = R::from(1. / 8.).unwrap(); + let _16th: R = R::from(1. 
/ 16.).unwrap(); + let angle_1: R = j / R::from(2).unwrap(); + let angle_2: R = j / R::from(4).unwrap(); + let angle_4: R = j / R::from(8).unwrap(); + let angle_8: R = j / R::from(16).unwrap(); + let two_pi: R = R::from(2).unwrap() * R::PI(); + omg_pos[0] = R::cos(two_pi * angle_8); + omg_pos[1] = R::cos(two_pi * (angle_8 + _8th)); + omg_pos[2] = R::cos(two_pi * (angle_8 + _16th)); + omg_pos[3] = R::cos(two_pi * (angle_8 + _8th + _16th)); + omg_pos[4] = -R::sin(two_pi * angle_8); + omg_pos[5] = -R::sin(two_pi * (angle_8 + _8th)); + omg_pos[6] = -R::sin(two_pi * (angle_8 + _16th)); + omg_pos[7] = -R::sin(two_pi * (angle_8 + _8th + _16th)); + omg_pos[8] = R::cos(two_pi * angle_4); + omg_pos[9] = -R::sin(two_pi * angle_4); + omg_pos[10] = R::cos(two_pi * (angle_4 + _8th)); + omg_pos[11] = -R::sin(two_pi * (angle_4 + _8th)); + omg_pos[12] = R::cos(two_pi * angle_2); + omg_pos[13] = -R::sin(two_pi * angle_2); + omg_pos[14] = R::cos(two_pi * angle_1); + omg_pos[15] = -R::sin(two_pi * angle_1); + pos + 16 +} + +#[inline(always)] +fn fill_ifft_bfs_16_omegas(m: usize, j: R, omg: &mut [R], mut pos: usize) -> usize { + let log_m: usize = (usize::BITS - (m - 1).leading_zeros()) as usize; + let mut jj: R = j * R::from(16).unwrap() / R::from(m).unwrap(); + + for i in (0..m).step_by(16) { + let j = jj + frac_rev_bits(i >> 4); + fill_ifft16_omegas(j, omg, pos); + pos += 16 + } + + let mut h: usize = 16; + let m_half: usize = m >> 1; + + let two_pi: R = R::from(2).unwrap() * R::PI(); + + while h < m_half { + let mm: usize = h << 2; + for i in (0..m).step_by(mm) { + let rs_0 = jj + frac_rev_bits::(i / mm) / R::from(4).unwrap(); + let rs_1 = R::from(2).unwrap() * rs_0; + omg[pos] = R::cos(two_pi * rs_0); + omg[pos + 1] = -R::sin(two_pi * rs_0); + omg[pos + 2] = R::cos(two_pi * rs_1); + omg[pos + 3] = -R::sin(two_pi * rs_1); + pos += 4; + } + h = mm; + jj = jj * R::from(4).unwrap(); + } + + if !log_m.is_multiple_of(2) { + omg[pos] = R::cos(two_pi * jj); + omg[pos + 1] = -R::sin(two_pi * jj); + pos += 2; + jj = jj * R::from(2).unwrap(); + } + + assert_eq!(jj, j); + + pos +} + +#[inline(always)] +fn fill_ifft_rec_16_omegas(m: usize, j: R, omg: &mut [R], mut pos: usize) -> usize { + if m <= 2048 { + return fill_ifft_bfs_16_omegas(m, j, omg, pos); + } + let h: usize = m >> 1; + let s: R = j / R::from(2).unwrap(); + pos = fill_ifft_rec_16_omegas(h, s, omg, pos); + pos = fill_ifft_rec_16_omegas(h, s + R::from(0.5).unwrap(), omg, pos); + let _2pi = R::from(2).unwrap() * R::PI(); + omg[pos] = R::cos(_2pi * s); + omg[pos + 1] = -R::sin(_2pi * s); + pos += 2; + pos +} diff --git a/poulpy-hal/src/reference/fft64/reim/zero.rs b/poulpy-hal/src/reference/fft64/reim/zero.rs new file mode 100644 index 0000000..3f52029 --- /dev/null +++ b/poulpy-hal/src/reference/fft64/reim/zero.rs @@ -0,0 +1,11 @@ +pub fn reim_zero_ref(res: &mut [f64]) { + res.fill(0.); +} + +pub fn reim_copy_ref(res: &mut [f64], a: &[f64]) { + #[cfg(debug_assertions)] + { + assert_eq!(res.len(), a.len()) + } + res.copy_from_slice(a); +} diff --git a/poulpy-hal/src/reference/fft64/reim4/arithmetic_ref.rs b/poulpy-hal/src/reference/fft64/reim4/arithmetic_ref.rs new file mode 100644 index 0000000..49817f3 --- /dev/null +++ b/poulpy-hal/src/reference/fft64/reim4/arithmetic_ref.rs @@ -0,0 +1,209 @@ +use crate::reference::fft64::reim::as_arr; + +#[inline(always)] +pub fn reim4_extract_1blk_from_reim_ref(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) { + let mut offset: usize = blk << 2; + + debug_assert!(blk < (m >> 2)); + 
debug_assert!(dst.len() >= 2 * rows * 4); + + for chunk in dst.chunks_exact_mut(4).take(2 * rows) { + chunk.copy_from_slice(&src[offset..offset + 4]); + offset += m + } +} + +#[inline(always)] +pub fn reim4_save_1blk_to_reim_ref<const OVERWRITE: bool>(m: usize, blk: usize, dst: &mut [f64], src: &[f64]) { + let mut offset: usize = blk << 2; + + debug_assert!(blk < (m >> 2)); + debug_assert!(dst.len() >= offset + m + 4); + debug_assert!(src.len() >= 8); + + let dst_off = &mut dst[offset..offset + 4]; + + if OVERWRITE { + dst_off.copy_from_slice(&src[0..4]); + } else { + dst_off[0] += src[0]; + dst_off[1] += src[1]; + dst_off[2] += src[2]; + dst_off[3] += src[3]; + } + + offset += m; + + let dst_off = &mut dst[offset..offset + 4]; + if OVERWRITE { + dst_off.copy_from_slice(&src[4..8]); + } else { + dst_off[0] += src[4]; + dst_off[1] += src[5]; + dst_off[2] += src[6]; + dst_off[3] += src[7]; + } +} + +#[inline(always)] +pub fn reim4_save_2blk_to_reim_ref<const OVERWRITE: bool>(m: usize, blk: usize, dst: &mut [f64], src: &[f64]) { + let mut offset: usize = blk << 2; + + debug_assert!(blk < (m >> 2)); + debug_assert!(dst.len() >= offset + 3 * m + 4); + debug_assert!(src.len() >= 16); + + let dst_off = &mut dst[offset..offset + 4]; + if OVERWRITE { + dst_off.copy_from_slice(&src[0..4]); + } else { + dst_off[0] += src[0]; + dst_off[1] += src[1]; + dst_off[2] += src[2]; + dst_off[3] += src[3]; + } + + offset += m; + let dst_off = &mut dst[offset..offset + 4]; + if OVERWRITE { + dst_off.copy_from_slice(&src[4..8]); + } else { + dst_off[0] += src[4]; + dst_off[1] += src[5]; + dst_off[2] += src[6]; + dst_off[3] += src[7]; + } + + offset += m; + + let dst_off = &mut dst[offset..offset + 4]; + if OVERWRITE { + dst_off.copy_from_slice(&src[8..12]); + } else { + dst_off[0] += src[8]; + dst_off[1] += src[9]; + dst_off[2] += src[10]; + dst_off[3] += src[11]; + } + + offset += m; + let dst_off = &mut dst[offset..offset + 4]; + if OVERWRITE { + dst_off.copy_from_slice(&src[12..16]); + } else { + dst_off[0] += src[12]; + dst_off[1] += src[13]; + dst_off[2] += src[14]; + dst_off[3] += src[15]; + } +} + +#[inline(always)] +pub fn reim4_vec_mat1col_product_ref( + nrows: usize, + dst: &mut [f64], // 8 doubles: [re1(4), im1(4)] + u: &[f64], // nrows * 8 doubles: [ur(4) | ui(4)] per row + v: &[f64], // nrows * 8 doubles: [ar(4) | ai(4)] per row +) { + #[cfg(debug_assertions)] + { + assert!(dst.len() >= 8, "dst must have at least 8 doubles"); + assert!(u.len() >= nrows * 8, "u must be at least nrows * 8 doubles"); + assert!(v.len() >= nrows * 8, "v must be at least nrows * 8 doubles"); + } + + let mut acc: [f64; 8] = [0f64; 8]; + let mut j = 0; + for _ in 0..nrows { + reim4_add_mul(&mut acc, as_arr(&u[j..]), as_arr(&v[j..])); + j += 8; + } + dst[0..8].copy_from_slice(&acc); +} + +#[inline(always)] +pub fn reim4_vec_mat2cols_product_ref( + nrows: usize, + dst: &mut [f64], // 16 doubles: [re1(4), im1(4), re2(4), im2(4)] + u: &[f64], // nrows * 8 doubles: [ur(4) | ui(4)] per row + v: &[f64], // nrows * 16 doubles: [ar(4) | ai(4) | br(4) | bi(4)] per row +) { + #[cfg(debug_assertions)] + { + assert_eq!(dst.len(), 16, "dst must have 16 doubles"); + assert!(u.len() >= nrows * 8, "u must be at least nrows * 8 doubles"); + assert!( + v.len() >= nrows * 16, + "v must be at least nrows * 16 doubles" + ); + } + + // zero accumulators + let mut acc_0: [f64; 8] = [0f64; 8]; + let mut acc_1: [f64; 8] = [0f64; 8]; + for i in 0..nrows
{ + let _1j: usize = i << 3; + let _2j: usize = i << 4; + let u_j: &[f64; 8] = as_arr(&u[_1j..]); + reim4_add_mul(&mut acc_0, u_j, as_arr(&v[_2j..])); + reim4_add_mul(&mut acc_1, u_j, as_arr(&v[_2j + 8..])); + } + dst[0..8].copy_from_slice(&acc_0); + dst[8..16].copy_from_slice(&acc_1); +} + +#[inline(always)] +pub fn reim4_vec_mat2cols_2ndcol_product_ref( + nrows: usize, + dst: &mut [f64], // 8 doubles: [re1(4), im1(4), re2(4), im2(4)] + u: &[f64], // nrows * 8 doubles: [ur(4) | ui(4)] per row + v: &[f64], // nrows * 16 doubles: [x | x | br(4) | bi(4)] per row +) { + #[cfg(debug_assertions)] + { + assert!( + dst.len() >= 8, + "dst must be at least 8 doubles but is {}", + dst.len() + ); + assert!( + u.len() >= nrows * 8, + "u must be at least nrows={} * 8 doubles but is {}", + nrows, + u.len() + ); + assert!( + v.len() >= nrows * 16, + "v must be at least nrows={} * 16 doubles but is {}", + nrows, + v.len() + ); + } + + // zero accumulators + let mut acc: [f64; 8] = [0f64; 8]; + for i in 0..nrows { + let _1j: usize = i << 3; + let _2j: usize = i << 4; + reim4_add_mul(&mut acc, as_arr(&u[_1j..]), as_arr(&v[_2j + 8..])); + } + dst[0..8].copy_from_slice(&acc); +} + +#[inline(always)] +pub fn reim4_add_mul(dst: &mut [f64; 8], a: &[f64; 8], b: &[f64; 8]) { + for k in 0..4 { + let ar: f64 = a[k]; + let br: f64 = b[k]; + let ai: f64 = a[k + 4]; + let bi: f64 = b[k + 4]; + dst[k] += ar * br - ai * bi; + dst[k + 4] += ar * bi + ai * br; + } +} diff --git a/poulpy-hal/src/reference/fft64/reim4/mod.rs b/poulpy-hal/src/reference/fft64/reim4/mod.rs new file mode 100644 index 0000000..04bcf9c --- /dev/null +++ b/poulpy-hal/src/reference/fft64/reim4/mod.rs @@ -0,0 +1,27 @@ +mod arithmetic_ref; + +pub use arithmetic_ref::*; + +pub trait Reim4Extract1Blk { + fn reim4_extract_1blk(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]); +} + +pub trait Reim4Save1Blk { + fn reim4_save_1blk(m: usize, blk: usize, dst: &mut [f64], src: &[f64]); +} + +pub trait Reim4Save2Blks { + fn reim4_save_2blks(m: usize, blk: usize, dst: &mut [f64], src: &[f64]); +} + +pub trait Reim4Mat1ColProd { + fn reim4_mat1col_prod(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]); +} + +pub trait Reim4Mat2ColsProd { + fn reim4_mat2cols_prod(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]); +} + +pub trait Reim4Mat2Cols2ndColProd { + fn reim4_mat2cols_2ndcol_prod(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]); +} diff --git a/poulpy-hal/src/reference/fft64/svp.rs b/poulpy-hal/src/reference/fft64/svp.rs new file mode 100644 index 0000000..f91901a --- /dev/null +++ b/poulpy-hal/src/reference/fft64/svp.rs @@ -0,0 +1,119 @@ +use crate::{ + layouts::{ + Backend, ScalarZnx, ScalarZnxToRef, SvpPPol, SvpPPolToMut, SvpPPolToRef, VecZnx, VecZnxDft, VecZnxDftToMut, + VecZnxDftToRef, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut, + }, + reference::fft64::reim::{ReimAddMul, ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimMul, ReimMulInplace, ReimZero}, +}; + +pub fn svp_prepare(table: &ReimFFTTable, res: &mut R, res_col: usize, a: &A, a_col: usize) +where + BE: Backend + ReimDFTExecute, f64> + ReimFromZnx, + R: SvpPPolToMut, + A: ScalarZnxToRef, +{ + let mut res: SvpPPol<&mut [u8], BE> = res.to_mut(); + let a: ScalarZnx<&[u8]> = a.to_ref(); + BE::reim_from_znx(res.at_mut(res_col, 0), a.at(a_col, 0)); + BE::reim_dft_execute(table, res.at_mut(res_col, 0)); +} + +pub fn svp_apply_dft( + table: &ReimFFTTable, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + b: &B, + b_col: usize, +) where + BE: Backend + ReimDFTExecute, 
f64> + ReimZero + ReimFromZnx + ReimMulInplace, + R: VecZnxDftToMut, + A: SvpPPolToRef, + B: VecZnxToRef, +{ + let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut(); + let a: SvpPPol<&[u8], BE> = a.to_ref(); + let b: VecZnx<&[u8]> = b.to_ref(); + + let res_size: usize = res.size(); + let b_size: usize = b.size(); + let min_size: usize = res_size.min(b_size); + + let ppol: &[f64] = a.at(a_col, 0); + for j in 0..min_size { + let out: &mut [f64] = res.at_mut(res_col, j); + BE::reim_from_znx(out, b.at(b_col, j)); + BE::reim_dft_execute(table, out); + BE::reim_mul_inplace(out, ppol); + } + + for j in min_size..res_size { + BE::reim_zero(res.at_mut(res_col, j)); + } +} + +pub fn svp_apply_dft_to_dft(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) +where + BE: Backend + ReimMul + ReimZero, + R: VecZnxDftToMut, + A: SvpPPolToRef, + B: VecZnxDftToRef, +{ + let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut(); + let a: SvpPPol<&[u8], BE> = a.to_ref(); + let b: VecZnxDft<&[u8], BE> = b.to_ref(); + + let res_size: usize = res.size(); + let b_size: usize = b.size(); + let min_size: usize = res_size.min(b_size); + + let ppol: &[f64] = a.at(a_col, 0); + for j in 0..min_size { + BE::reim_mul(res.at_mut(res_col, j), ppol, b.at(b_col, j)); + } + + for j in min_size..res_size { + BE::reim_zero(res.at_mut(res_col, j)); + } +} + +pub fn svp_apply_dft_to_dft_add(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) +where + BE: Backend + ReimAddMul + ReimZero, + R: VecZnxDftToMut, + A: SvpPPolToRef, + B: VecZnxDftToRef, +{ + let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut(); + let a: SvpPPol<&[u8], BE> = a.to_ref(); + let b: VecZnxDft<&[u8], BE> = b.to_ref(); + + let res_size: usize = res.size(); + let b_size: usize = b.size(); + let min_size: usize = res_size.min(b_size); + + let ppol: &[f64] = a.at(a_col, 0); + for j in 0..min_size { + BE::reim_addmul(res.at_mut(res_col, j), ppol, b.at(b_col, j)); + } + + for j in min_size..res_size { + BE::reim_zero(res.at_mut(res_col, j)); + } +} + +pub fn svp_apply_dft_to_dft_inplace(res: &mut R, res_col: usize, a: &A, a_col: usize) +where + BE: Backend + ReimMulInplace, + R: VecZnxDftToMut, + A: SvpPPolToRef, +{ + let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut(); + let a: SvpPPol<&[u8], BE> = a.to_ref(); + + let ppol: &[f64] = a.at(a_col, 0); + for j in 0..res.size() { + BE::reim_mul_inplace(res.at_mut(res_col, j), ppol); + } +} diff --git a/poulpy-hal/src/reference/fft64/vec_znx_big.rs b/poulpy-hal/src/reference/fft64/vec_znx_big.rs new file mode 100644 index 0000000..7b9ceb5 --- /dev/null +++ b/poulpy-hal/src/reference/fft64/vec_znx_big.rs @@ -0,0 +1,521 @@ +use std::f64::consts::SQRT_2; + +use crate::{ + api::VecZnxBigAddNormal, + layouts::{ + Backend, Module, VecZnx, VecZnxBig, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef, ZnxView, ZnxViewMut, + }, + oep::VecZnxBigAllocBytesImpl, + reference::{ + vec_znx::{ + vec_znx_add, vec_znx_add_inplace, vec_znx_automorphism, vec_znx_automorphism_inplace, vec_znx_negate, + vec_znx_negate_inplace, vec_znx_normalize, vec_znx_sub, vec_znx_sub_ab_inplace, vec_znx_sub_ba_inplace, + }, + znx::{ + ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxNormalizeFinalStep, + ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, + ZnxSub, ZnxSubABInplace, ZnxSubBAInplace, ZnxZero, znx_add_normal_f64_ref, + }, + }, + source::Source, +}; + +pub fn vec_znx_big_add(res: &mut R, res_col: usize, a: &A, 
a_col: usize, b: &B, b_col: usize) +where + BE: Backend + ZnxAdd + ZnxCopy + ZnxZero, + R: VecZnxBigToMut, + A: VecZnxBigToRef, + B: VecZnxBigToRef, +{ + let res: VecZnxBig<&mut [u8], BE> = res.to_mut(); + let a: VecZnxBig<&[u8], BE> = a.to_ref(); + let b: VecZnxBig<&[u8], BE> = b.to_ref(); + + let mut res_vznx: VecZnx<&mut [u8]> = VecZnx { + data: res.data, + n: res.n, + cols: res.cols, + size: res.size, + max_size: res.max_size, + }; + + let a_vznx: VecZnx<&[u8]> = VecZnx { + data: a.data, + n: a.n, + cols: a.cols, + size: a.size, + max_size: a.max_size, + }; + + let b_vznx: VecZnx<&[u8]> = VecZnx { + data: b.data, + n: b.n, + cols: b.cols, + size: b.size, + max_size: b.max_size, + }; + + vec_znx_add::<_, _, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col, &b_vznx, b_col); +} + +pub fn vec_znx_big_add_inplace(res: &mut R, res_col: usize, a: &A, a_col: usize) +where + BE: Backend + ZnxAddInplace, + R: VecZnxBigToMut, + A: VecZnxBigToRef, +{ + let res: VecZnxBig<&mut [u8], BE> = res.to_mut(); + let a: VecZnxBig<&[u8], BE> = a.to_ref(); + + let mut res_vznx: VecZnx<&mut [u8]> = VecZnx { + data: res.data, + n: res.n, + cols: res.cols, + size: res.size, + max_size: res.max_size, + }; + + let a_vznx: VecZnx<&[u8]> = VecZnx { + data: a.data, + n: a.n, + cols: a.cols, + size: a.size, + max_size: a.max_size, + }; + + vec_znx_add_inplace::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col); +} + +pub fn vec_znx_big_add_small(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) +where + BE: Backend + ZnxAdd + ZnxCopy + ZnxZero, + R: VecZnxBigToMut, + A: VecZnxBigToRef, + B: VecZnxToRef, +{ + let res: VecZnxBig<&mut [u8], BE> = res.to_mut(); + let a: VecZnxBig<&[u8], BE> = a.to_ref(); + + let mut res_vznx: VecZnx<&mut [u8]> = VecZnx { + data: res.data, + n: res.n, + cols: res.cols, + size: res.size, + max_size: res.max_size, + }; + + let a_vznx: VecZnx<&[u8]> = VecZnx { + data: a.data, + n: a.n, + cols: a.cols, + size: a.size, + max_size: a.max_size, + }; + + vec_znx_add::<_, _, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col, b, b_col); +} + +pub fn vec_znx_big_add_small_inplace(res: &mut R, res_col: usize, a: &A, a_col: usize) +where + BE: Backend + ZnxAddInplace, + R: VecZnxBigToMut, + A: VecZnxToRef, +{ + let res: VecZnxBig<&mut [u8], BE> = res.to_mut(); + let mut res_vznx: VecZnx<&mut [u8]> = VecZnx { + data: res.data, + n: res.n, + cols: res.cols, + size: res.size, + max_size: res.max_size, + }; + + vec_znx_add_inplace::<_, _, BE>(&mut res_vznx, res_col, a, a_col); +} + +pub fn vec_znx_big_automorphism_inplace_tmp_bytes(n: usize) -> usize { + n * size_of::() +} + +pub fn vec_znx_big_automorphism(p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) +where + BE: Backend + ZnxAutomorphism + ZnxZero, + R: VecZnxBigToMut, + A: VecZnxBigToRef, +{ + let res: VecZnxBig<&mut [u8], _> = res.to_mut(); + let a: VecZnxBig<&[u8], _> = a.to_ref(); + + let mut res_vznx: VecZnx<&mut [u8]> = VecZnx { + data: res.data, + n: res.n, + cols: res.cols, + size: res.size, + max_size: res.max_size, + }; + + let a_vznx: VecZnx<&[u8]> = VecZnx { + data: a.data, + n: a.n, + cols: a.cols, + size: a.size, + max_size: a.max_size, + }; + + vec_znx_automorphism::<_, _, BE>(p, &mut res_vznx, res_col, &a_vznx, a_col); +} + +pub fn vec_znx_big_automorphism_inplace(p: i64, res: &mut R, res_col: usize, tmp: &mut [i64]) +where + BE: Backend + ZnxAutomorphism + ZnxCopy, + R: VecZnxBigToMut, +{ + let res: VecZnxBig<&mut [u8], _> = res.to_mut(); + + let mut res_vznx: VecZnx<&mut [u8]> = VecZnx { + data: res.data, + n: 
res.n, + cols: res.cols, + size: res.size, + max_size: res.max_size, + }; + + vec_znx_automorphism_inplace::<_, BE>(p, &mut res_vznx, res_col, tmp); +} + +pub fn vec_znx_big_negate(res: &mut R, res_col: usize, a: &A, a_col: usize) +where + BE: Backend + ZnxNegate + ZnxZero, + R: VecZnxBigToMut, + A: VecZnxBigToRef, +{ + let res: VecZnxBig<&mut [u8], _> = res.to_mut(); + let a: VecZnxBig<&[u8], _> = a.to_ref(); + + let mut res_vznx: VecZnx<&mut [u8]> = VecZnx { + data: res.data, + n: res.n, + cols: res.cols, + size: res.size, + max_size: res.max_size, + }; + + let a_vznx: VecZnx<&[u8]> = VecZnx { + data: a.data, + n: a.n, + cols: a.cols, + size: a.size, + max_size: a.max_size, + }; + + vec_znx_negate::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col); +} + +pub fn vec_znx_big_negate_inplace(res: &mut R, res_col: usize) +where + BE: Backend + ZnxNegateInplace, + R: VecZnxBigToMut, +{ + let res: VecZnxBig<&mut [u8], _> = res.to_mut(); + + let mut res_vznx: VecZnx<&mut [u8]> = VecZnx { + data: res.data, + n: res.n, + cols: res.cols, + size: res.size, + max_size: res.max_size, + }; + + vec_znx_negate_inplace::<_, BE>(&mut res_vznx, res_col); +} + +pub fn vec_znx_big_normalize_tmp_bytes(n: usize) -> usize { + n * size_of::() +} + +pub fn vec_znx_big_normalize(basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64]) +where + R: VecZnxToMut, + A: VecZnxBigToRef, + BE: Backend + + ZnxNormalizeFirstStepCarryOnly + + ZnxNormalizeMiddleStepCarryOnly + + ZnxNormalizeMiddleStep + + ZnxNormalizeFinalStep + + ZnxNormalizeFirstStep + + ZnxZero, +{ + let a: VecZnxBig<&[u8], _> = a.to_ref(); + let a_vznx: VecZnx<&[u8]> = VecZnx { + data: a.data, + n: a.n, + cols: a.cols, + size: a.size, + max_size: a.max_size, + }; + + vec_znx_normalize::<_, _, BE>(basek, res, res_col, &a_vznx, a_col, carry); +} + +pub fn vec_znx_big_add_normal_ref>( + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + sigma: f64, + bound: f64, + source: &mut Source, +) where + R: VecZnxBigToMut, +{ + let mut res: VecZnxBig<&mut [u8], B> = res.to_mut(); + assert!( + (bound.log2().ceil() as i64) < 64, + "invalid bound: ceil(log2(bound))={} > 63", + (bound.log2().ceil() as i64) + ); + + let limb: usize = k.div_ceil(basek) - 1; + let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64; + znx_add_normal_f64_ref( + res.at_mut(res_col, limb), + sigma * scale, + bound * scale, + source, + ) +} + +pub fn test_vec_znx_big_add_normal(module: &Module) +where + Module: VecZnxBigAddNormal, + B: Backend + VecZnxBigAllocBytesImpl, +{ + let n: usize = module.n(); + let basek: usize = 17; + let k: usize = 2 * 17; + let size: usize = 5; + let sigma: f64 = 3.2; + let bound: f64 = 6.0 * sigma; + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + let zero: Vec = vec![0; n]; + let k_f64: f64 = (1u64 << k as u64) as f64; + let sqrt2: f64 = SQRT_2; + (0..cols).for_each(|col_i| { + let mut a: VecZnxBig, B> = VecZnxBig::alloc(n, cols, size); + module.vec_znx_big_add_normal(basek, &mut a, col_i, k, &mut source, sigma, bound); + module.vec_znx_big_add_normal(basek, &mut a, col_i, k, &mut source, sigma, bound); + (0..cols).for_each(|col_j| { + if col_j != col_i { + (0..size).for_each(|limb_i| { + assert_eq!(a.at(col_j, limb_i), zero); + }) + } else { + let std: f64 = a.std(basek, col_i) * k_f64; + assert!( + (std - sigma * sqrt2).abs() < 0.1, + "std={} ~!= {}", + std, + sigma * sqrt2 + ); + } + }) + }); +} + +/// R <- A - B +pub fn vec_znx_big_sub(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: 
usize) +where + BE: Backend + ZnxSub + ZnxNegate + ZnxZero + ZnxCopy, + R: VecZnxBigToMut, + A: VecZnxBigToRef, + B: VecZnxBigToRef, +{ + let res: VecZnxBig<&mut [u8], BE> = res.to_mut(); + let a: VecZnxBig<&[u8], BE> = a.to_ref(); + let b: VecZnxBig<&[u8], BE> = b.to_ref(); + + let mut res_vznx: VecZnx<&mut [u8]> = VecZnx { + data: res.data, + n: res.n, + cols: res.cols, + size: res.size, + max_size: res.max_size, + }; + + let a_vznx: VecZnx<&[u8]> = VecZnx { + data: a.data, + n: a.n, + cols: a.cols, + size: a.size, + max_size: a.max_size, + }; + + let b_vznx: VecZnx<&[u8]> = VecZnx { + data: b.data, + n: b.n, + cols: b.cols, + size: b.size, + max_size: b.max_size, + }; + + vec_znx_sub::<_, _, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col, &b_vznx, b_col); +} + +/// R <- A - B +pub fn vec_znx_big_sub_ab_inplace(res: &mut R, res_col: usize, a: &A, a_col: usize) +where + BE: Backend + ZnxSubABInplace, + R: VecZnxBigToMut, + A: VecZnxBigToRef, +{ + let res: VecZnxBig<&mut [u8], BE> = res.to_mut(); + let a: VecZnxBig<&[u8], BE> = a.to_ref(); + + let mut res_vznx: VecZnx<&mut [u8]> = VecZnx { + data: res.data, + n: res.n, + cols: res.cols, + size: res.size, + max_size: res.max_size, + }; + + let a_vznx: VecZnx<&[u8]> = VecZnx { + data: a.data, + n: a.n, + cols: a.cols, + size: a.size, + max_size: a.max_size, + }; + + vec_znx_sub_ab_inplace::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col); +} + +/// R <- B - A +pub fn vec_znx_big_sub_ba_inplace(res: &mut R, res_col: usize, a: &A, a_col: usize) +where + BE: Backend + ZnxSubBAInplace + ZnxNegateInplace, + R: VecZnxBigToMut, + A: VecZnxBigToRef, +{ + let res: VecZnxBig<&mut [u8], BE> = res.to_mut(); + let a: VecZnxBig<&[u8], BE> = a.to_ref(); + + let mut res_vznx: VecZnx<&mut [u8]> = VecZnx { + data: res.data, + n: res.n, + cols: res.cols, + size: res.size, + max_size: res.max_size, + }; + + let a_vznx: VecZnx<&[u8]> = VecZnx { + data: a.data, + n: a.n, + cols: a.cols, + size: a.size, + max_size: a.max_size, + }; + + vec_znx_sub_ba_inplace::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col); +} + +/// R <- A - B +pub fn vec_znx_big_sub_small_a(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) +where + BE: Backend + ZnxSub + ZnxNegate + ZnxZero + ZnxCopy, + R: VecZnxBigToMut, + A: VecZnxToRef, + B: VecZnxBigToRef, +{ + let res: VecZnxBig<&mut [u8], BE> = res.to_mut(); + let b: VecZnxBig<&[u8], BE> = b.to_ref(); + + let mut res_vznx: VecZnx<&mut [u8]> = VecZnx { + data: res.data, + n: res.n, + cols: res.cols, + size: res.size, + max_size: res.max_size, + }; + + let b_vznx: VecZnx<&[u8]> = VecZnx { + data: b.data, + n: b.n, + cols: b.cols, + size: b.size, + max_size: b.max_size, + }; + + vec_znx_sub::<_, _, _, BE>(&mut res_vznx, res_col, a, a_col, &b_vznx, b_col); +} + +/// R <- A - B +pub fn vec_znx_big_sub_small_b(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) +where + BE: Backend + ZnxSub + ZnxNegate + ZnxZero + ZnxCopy, + R: VecZnxBigToMut, + A: VecZnxBigToRef, + B: VecZnxToRef, +{ + let res: VecZnxBig<&mut [u8], BE> = res.to_mut(); + let a: VecZnxBig<&[u8], BE> = a.to_ref(); + + let mut res_vznx: VecZnx<&mut [u8]> = VecZnx { + data: res.data, + n: res.n, + cols: res.cols, + size: res.size, + max_size: res.max_size, + }; + + let a_vznx: VecZnx<&[u8]> = VecZnx { + data: a.data, + n: a.n, + cols: a.cols, + size: a.size, + max_size: a.max_size, + }; + + vec_znx_sub::<_, _, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col, b, b_col); +} + +/// R <- R - A +pub fn vec_znx_big_sub_small_a_inplace(res: &mut R, 
res_col: usize, a: &A, a_col: usize) +where + BE: Backend + ZnxSubABInplace, + R: VecZnxBigToMut, + A: VecZnxToRef, +{ + let res: VecZnxBig<&mut [u8], BE> = res.to_mut(); + + let mut res_vznx: VecZnx<&mut [u8]> = VecZnx { + data: res.data, + n: res.n, + cols: res.cols, + size: res.size, + max_size: res.max_size, + }; + + vec_znx_sub_ab_inplace::<_, _, BE>(&mut res_vznx, res_col, a, a_col); +} + +/// R <- A - R +pub fn vec_znx_big_sub_small_b_inplace(res: &mut R, res_col: usize, a: &A, a_col: usize) +where + BE: Backend + ZnxSubBAInplace + ZnxNegateInplace, + R: VecZnxBigToMut, + A: VecZnxToRef, +{ + let res: VecZnxBig<&mut [u8], BE> = res.to_mut(); + + let mut res_vznx: VecZnx<&mut [u8]> = VecZnx { + data: res.data, + n: res.n, + cols: res.cols, + size: res.size, + max_size: res.max_size, + }; + + vec_znx_sub_ba_inplace::<_, _, BE>(&mut res_vznx, res_col, a, a_col); +} diff --git a/poulpy-hal/src/reference/fft64/vec_znx_dft.rs b/poulpy-hal/src/reference/fft64/vec_znx_dft.rs new file mode 100644 index 0000000..5abf0da --- /dev/null +++ b/poulpy-hal/src/reference/fft64/vec_znx_dft.rs @@ -0,0 +1,369 @@ +use bytemuck::cast_slice_mut; + +use crate::{ + layouts::{ + Backend, Data, VecZnx, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, ZnxInfos, + ZnxView, ZnxViewMut, + }, + reference::{ + fft64::reim::{ + ReimAdd, ReimAddInplace, ReimCopy, ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimIFFTTable, ReimNegate, + ReimNegateInplace, ReimSub, ReimSubABInplace, ReimSubBAInplace, ReimToZnx, ReimToZnxInplace, ReimZero, + }, + znx::ZnxZero, + }, +}; + +pub fn vec_znx_dft_add(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) +where + BE: Backend + ReimAdd + ReimCopy + ReimZero, + R: VecZnxDftToMut, + A: VecZnxDftToRef, + B: VecZnxDftToRef, +{ + let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut(); + let a: VecZnxDft<&[u8], BE> = a.to_ref(); + let b: VecZnxDft<&[u8], BE> = b.to_ref(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), res.n()); + assert_eq!(b.n(), res.n()); + } + + let res_size: usize = res.size(); + let a_size: usize = a.size(); + let b_size: usize = b.size(); + + if a_size <= b_size { + let sum_size: usize = a_size.min(res_size); + let cpy_size: usize = b_size.min(res_size); + + for j in 0..sum_size { + BE::reim_add(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j)); + } + + for j in sum_size..cpy_size { + BE::reim_copy(res.at_mut(res_col, j), b.at(b_col, j)); + } + + for j in cpy_size..res_size { + BE::reim_zero(res.at_mut(res_col, j)); + } + } else { + let sum_size: usize = b_size.min(res_size); + let cpy_size: usize = a_size.min(res_size); + + for j in 0..sum_size { + BE::reim_add(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j)); + } + + for j in sum_size..cpy_size { + BE::reim_copy(res.at_mut(res_col, j), a.at(a_col, j)); + } + + for j in cpy_size..res_size { + BE::reim_zero(res.at_mut(res_col, j)); + } + } +} + +pub fn vec_znx_dft_add_inplace(res: &mut R, res_col: usize, a: &A, a_col: usize) +where + BE: Backend + ReimAddInplace, + R: VecZnxDftToMut, + A: VecZnxDftToRef, +{ + let a: VecZnxDft<&[u8], BE> = a.to_ref(); + let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), res.n()); + } + + let res_size: usize = res.size(); + let a_size: usize = a.size(); + + let sum_size: usize = a_size.min(res_size); + + for j in 0..sum_size { + BE::reim_add_inplace(res.at_mut(res_col, j), a.at(a_col, j)); + } +} + +pub fn vec_znx_dft_copy(step: usize, offset: usize, res: &mut 
R, res_col: usize, a: &A, a_col: usize) +where + BE: Backend + ReimCopy + ReimZero, + R: VecZnxDftToMut, + A: VecZnxDftToRef, +{ + let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut(); + let a: VecZnxDft<&[u8], BE> = a.to_ref(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.n(), a.n()) + } + + let steps: usize = a.size().div_ceil(step); + let min_steps: usize = res.size().min(steps); + + (0..min_steps).for_each(|j| { + let limb: usize = offset + j * step; + if limb < a.size() { + BE::reim_copy(res.at_mut(res_col, j), a.at(a_col, limb)); + } + }); + (min_steps..res.size()).for_each(|j| { + BE::reim_zero(res.at_mut(res_col, j)); + }) +} + +pub fn vec_znx_dft_apply( + table: &ReimFFTTable, + step: usize, + offset: usize, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, +) where + BE: Backend + ReimDFTExecute, f64> + ReimFromZnx + ReimZero, + R: VecZnxDftToMut, + A: VecZnxToRef, +{ + let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut(); + let a: VecZnx<&[u8]> = a.to_ref(); + + #[cfg(debug_assertions)] + { + assert!(step > 0); + assert_eq!(table.m() << 1, res.n()); + assert_eq!(a.n(), res.n()); + } + + let a_size: usize = a.size(); + let res_size: usize = res.size(); + + let steps: usize = a_size.div_ceil(step); + let min_steps: usize = res_size.min(steps); + + for j in 0..min_steps { + let limb = offset + j * step; + if limb < a_size { + BE::reim_from_znx(res.at_mut(res_col, j), a.at(a_col, limb)); + BE::reim_dft_execute(table, res.at_mut(res_col, j)); + } + } + + (min_steps..res.size()).for_each(|j| { + BE::reim_zero(res.at_mut(res_col, j)); + }); +} + +pub fn vec_znx_idft_apply(table: &ReimIFFTTable, res: &mut R, res_col: usize, a: &A, a_col: usize) +where + BE: Backend + + ReimDFTExecute, f64> + + ReimCopy + + ReimToZnxInplace + + ZnxZero, + R: VecZnxBigToMut, + A: VecZnxDftToRef, +{ + let mut res: VecZnxBig<&mut [u8], BE> = res.to_mut(); + let a: VecZnxDft<&[u8], BE> = a.to_ref(); + + #[cfg(debug_assertions)] + { + assert_eq!(table.m() << 1, res.n()); + assert_eq!(a.n(), res.n()); + } + + let res_size: usize = res.size(); + let min_size: usize = res_size.min(a.size()); + + let divisor: f64 = table.m() as f64; + + for j in 0..min_size { + let res_slice_f64: &mut [f64] = cast_slice_mut(res.at_mut(res_col, j)); + BE::reim_copy(res_slice_f64, a.at(a_col, j)); + BE::reim_dft_execute(table, res_slice_f64); + BE::reim_to_znx_inplace(res_slice_f64, divisor); + } + + for j in min_size..res_size { + BE::znx_zero(res.at_mut(res_col, j)); + } +} + +pub fn vec_znx_idft_apply_tmpa(table: &ReimIFFTTable, res: &mut R, res_col: usize, a: &mut A, a_col: usize) +where + BE: Backend + ReimDFTExecute, f64> + ReimToZnx + ZnxZero, + R: VecZnxBigToMut, + A: VecZnxDftToMut, +{ + let mut res: VecZnxBig<&mut [u8], BE> = res.to_mut(); + let mut a: VecZnxDft<&mut [u8], BE> = a.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(table.m() << 1, res.n()); + assert_eq!(a.n(), res.n()); + } + + let res_size = res.size(); + let min_size: usize = res_size.min(a.size()); + + let divisor: f64 = table.m() as f64; + + for j in 0..min_size { + BE::reim_dft_execute(table, a.at_mut(a_col, j)); + BE::reim_to_znx(res.at_mut(res_col, j), divisor, a.at(a_col, j)); + } + + for j in min_size..res_size { + BE::znx_zero(res.at_mut(res_col, j)); + } +} + +pub fn vec_znx_idft_apply_consume(table: &ReimIFFTTable, mut res: VecZnxDft) -> VecZnxBig +where + BE: Backend + ReimDFTExecute, f64> + ReimToZnxInplace, + VecZnxDft: VecZnxDftToMut, +{ + { + let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut(); + + 
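+ // the inner scope bounds this mutable view so that `res` can still be consumed by `into_big()` once the in-place IFFT is done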
#[cfg(debug_assertions)] + { + assert_eq!(table.m() << 1, res.n()); + } + + let divisor: f64 = table.m() as f64; + + for i in 0..res.cols() { + for j in 0..res.size() { + BE::reim_dft_execute(table, res.at_mut(i, j)); + BE::reim_to_znx_inplace(res.at_mut(i, j), divisor); + } + } + } + + res.into_big() +} + +pub fn vec_znx_dft_sub(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) +where + BE: Backend + ReimSub + ReimNegate + ReimZero + ReimCopy, + R: VecZnxDftToMut, + A: VecZnxDftToRef, + B: VecZnxDftToRef, +{ + let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut(); + let a: VecZnxDft<&[u8], BE> = a.to_ref(); + let b: VecZnxDft<&[u8], BE> = b.to_ref(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), res.n()); + assert_eq!(b.n(), res.n()); + } + + let res_size: usize = res.size(); + let a_size: usize = a.size(); + let b_size: usize = b.size(); + + if a_size <= b_size { + let sum_size: usize = a_size.min(res_size); + let cpy_size: usize = b_size.min(res_size); + + for j in 0..sum_size { + BE::reim_sub(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j)); + } + + for j in sum_size..cpy_size { + BE::reim_negate(res.at_mut(res_col, j), b.at(b_col, j)); + } + + for j in cpy_size..res_size { + BE::reim_zero(res.at_mut(res_col, j)); + } + } else { + let sum_size: usize = b_size.min(res_size); + let cpy_size: usize = a_size.min(res_size); + + for j in 0..sum_size { + BE::reim_sub(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j)); + } + + for j in sum_size..cpy_size { + BE::reim_copy(res.at_mut(res_col, j), a.at(a_col, j)); + } + + for j in cpy_size..res_size { + BE::reim_zero(res.at_mut(res_col, j)); + } + } +} + +pub fn vec_znx_dft_sub_ab_inplace(res: &mut R, res_col: usize, a: &A, a_col: usize) +where + BE: Backend + ReimSubABInplace, + R: VecZnxDftToMut, + A: VecZnxDftToRef, +{ + let a: VecZnxDft<&[u8], BE> = a.to_ref(); + let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), res.n()); + } + + let res_size: usize = res.size(); + let a_size: usize = a.size(); + + let sum_size: usize = a_size.min(res_size); + + for j in 0..sum_size { + BE::reim_sub_ab_inplace(res.at_mut(res_col, j), a.at(a_col, j)); + } +} + +pub fn vec_znx_dft_sub_ba_inplace(res: &mut R, res_col: usize, a: &A, a_col: usize) +where + BE: Backend + ReimSubBAInplace + ReimNegateInplace, + R: VecZnxDftToMut, + A: VecZnxDftToRef, +{ + let a: VecZnxDft<&[u8], BE> = a.to_ref(); + let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), res.n()); + } + + let res_size: usize = res.size(); + let a_size: usize = a.size(); + + let sum_size: usize = a_size.min(res_size); + + for j in 0..sum_size { + BE::reim_sub_ba_inplace(res.at_mut(res_col, j), a.at(a_col, j)); + } + + for j in sum_size..res_size { + BE::reim_negate_inplace(res.at_mut(res_col, j)); + } +} + +pub fn vec_znx_dft_zero(res: &mut R) +where + R: VecZnxDftToMut, + BE: Backend + ReimZero, +{ + BE::reim_zero(res.to_mut().raw_mut()); +} diff --git a/poulpy-hal/src/reference/fft64/vmp.rs b/poulpy-hal/src/reference/fft64/vmp.rs new file mode 100644 index 0000000..f6fb73c --- /dev/null +++ b/poulpy-hal/src/reference/fft64/vmp.rs @@ -0,0 +1,365 @@ +use crate::{ + cast_mut, + layouts::{MatZnx, MatZnxToRef, VecZnx, VecZnxToRef, VmpPMatToMut, ZnxView, ZnxViewMut}, + oep::VecZnxDftAllocBytesImpl, + reference::fft64::{ + reim::{ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimZero}, + reim4::{Reim4Extract1Blk, Reim4Mat1ColProd, Reim4Mat2Cols2ndColProd, 
Reim4Mat2ColsProd, Reim4Save1Blk, Reim4Save2Blks}, + vec_znx_dft::vec_znx_dft_apply, + }, +}; + +use crate::layouts::{Backend, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, VmpPMat, VmpPMatToRef, ZnxInfos}; + +pub fn vmp_prepare_tmp_bytes(n: usize) -> usize { + n * size_of::() +} + +pub fn vmp_prepare(table: &ReimFFTTable, pmat: &mut R, mat: &A, tmp: &mut [f64]) +where + BE: Backend + ReimDFTExecute, f64> + ReimFromZnx + Reim4Extract1Blk, + R: VmpPMatToMut, + A: MatZnxToRef, +{ + let mut res: crate::layouts::VmpPMat<&mut [u8], BE> = pmat.to_mut(); + let a: MatZnx<&[u8]> = mat.to_ref(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), res.n()); + assert_eq!( + res.cols_in(), + a.cols_in(), + "res.cols_in: {} != a.cols_in: {}", + res.cols_in(), + a.cols_in() + ); + assert_eq!( + res.rows(), + a.rows(), + "res.rows: {} != a.rows: {}", + res.rows(), + a.rows() + ); + assert_eq!( + res.cols_out(), + a.cols_out(), + "res.cols_out: {} != a.cols_out: {}", + res.cols_out(), + a.cols_out() + ); + assert_eq!( + res.size(), + a.size(), + "res.size: {} != a.size: {}", + res.size(), + a.size() + ); + } + + let nrows: usize = a.cols_in() * a.rows(); + let ncols: usize = a.cols_out() * a.size(); + vmp_prepare_core::(table, res.raw_mut(), a.raw(), nrows, ncols, tmp); +} + +pub(crate) fn vmp_prepare_core( + table: &ReimFFTTable, + pmat: &mut [f64], + mat: &[i64], + nrows: usize, + ncols: usize, + tmp: &mut [f64], +) where + REIM: ReimDFTExecute, f64> + ReimFromZnx + Reim4Extract1Blk, +{ + let m: usize = table.m(); + let n: usize = m << 1; + + #[cfg(debug_assertions)] + { + assert!(n >= 8); + assert_eq!(mat.len(), n * nrows * ncols); + assert_eq!(pmat.len(), n * nrows * ncols); + assert_eq!(tmp.len(), vmp_prepare_tmp_bytes(n) / size_of::()) + } + + let offset: usize = nrows * ncols * 8; + + for row_i in 0..nrows { + for col_i in 0..ncols { + let pos: usize = n * (row_i * ncols + col_i); + + REIM::reim_from_znx(tmp, &mat[pos..pos + n]); + REIM::reim_dft_execute(table, tmp); + + let dst: &mut [f64] = if col_i == (ncols - 1) && !ncols.is_multiple_of(2) { + &mut pmat[col_i * nrows * 8 + row_i * 8..] + } else { + &mut pmat[(col_i / 2) * (nrows * 16) + row_i * 16 + (col_i % 2) * 8..] 
+ }; + + for blk_i in 0..m >> 2 { + REIM::reim4_extract_1blk(m, 1, blk_i, &mut dst[blk_i * offset..], tmp); + } + } + } +} + +pub fn vmp_apply_dft_tmp_bytes(n: usize, a_size: usize, prows: usize, pcols_in: usize) -> usize { + let row_max: usize = (a_size).min(prows); + (16 + (n + 8) * row_max * pcols_in) * size_of::() +} + +pub fn vmp_apply_dft(table: &ReimFFTTable, res: &mut R, a: &A, pmat: &M, tmp_bytes: &mut [f64]) +where + BE: Backend + + VecZnxDftAllocBytesImpl + + ReimDFTExecute, f64> + + ReimZero + + Reim4Extract1Blk + + Reim4Mat1ColProd + + Reim4Mat2Cols2ndColProd + + Reim4Mat2ColsProd + + Reim4Save2Blks + + Reim4Save1Blk + + ReimFromZnx, + R: VecZnxDftToMut, + A: VecZnxToRef, + M: VmpPMatToRef, +{ + let a: VecZnx<&[u8]> = a.to_ref(); + let pmat: VmpPMat<&[u8], BE> = pmat.to_ref(); + + let n: usize = a.n(); + let cols: usize = pmat.cols_in(); + let size: usize = a.size().min(pmat.rows()); + + #[cfg(debug_assertions)] + { + assert!(tmp_bytes.len() >= vmp_apply_dft_tmp_bytes(n, size, pmat.rows(), cols)); + assert!(a.cols() <= cols); + } + + let (data, tmp_bytes) = tmp_bytes.split_at_mut(BE::vec_znx_dft_alloc_bytes_impl(n, cols, size)); + + let mut a_dft: VecZnxDft<&mut [u8], BE> = VecZnxDft::from_data(cast_mut(data), n, cols, size); + + let offset: usize = cols - a.cols(); + for j in 0..cols { + vec_znx_dft_apply(table, 1, 0, &mut a_dft, j, &a, offset + j); + } + + vmp_apply_dft_to_dft(res, &a_dft, &pmat, tmp_bytes); +} + +pub fn vmp_apply_dft_to_dft_tmp_bytes(a_size: usize, prows: usize, pcols_in: usize) -> usize { + let row_max: usize = (a_size).min(prows); + (16 + 8 * row_max * pcols_in) * size_of::() +} + +pub fn vmp_apply_dft_to_dft(res: &mut R, a: &A, pmat: &M, tmp_bytes: &mut [f64]) +where + BE: Backend + + ReimZero + + Reim4Extract1Blk + + Reim4Mat1ColProd + + Reim4Mat2Cols2ndColProd + + Reim4Mat2ColsProd + + Reim4Save2Blks + + Reim4Save1Blk, + R: VecZnxDftToMut, + A: VecZnxDftToRef, + M: VmpPMatToRef, +{ + use crate::layouts::{ZnxView, ZnxViewMut}; + + let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut(); + let a: VecZnxDft<&[u8], BE> = a.to_ref(); + let pmat: VmpPMat<&[u8], BE> = pmat.to_ref(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.n(), pmat.n()); + assert_eq!(a.n(), pmat.n()); + assert_eq!(res.cols(), pmat.cols_out()); + assert_eq!(a.cols(), pmat.cols_in()); + } + + let n: usize = res.n(); + let nrows: usize = pmat.cols_in() * pmat.rows(); + let ncols: usize = pmat.cols_out() * pmat.size(); + + let pmat_raw: &[f64] = pmat.raw(); + let a_raw: &[f64] = a.raw(); + let res_raw: &mut [f64] = res.raw_mut(); + + vmp_apply_dft_to_dft_core::(n, res_raw, a_raw, pmat_raw, 0, nrows, ncols, tmp_bytes) +} + +pub fn vmp_apply_dft_to_dft_add(res: &mut R, a: &A, pmat: &M, limb_offset: usize, tmp_bytes: &mut [f64]) +where + BE: Backend + + ReimZero + + Reim4Extract1Blk + + Reim4Mat1ColProd + + Reim4Mat2Cols2ndColProd + + Reim4Mat2ColsProd + + Reim4Save2Blks + + Reim4Save1Blk, + R: VecZnxDftToMut, + A: VecZnxDftToRef, + M: VmpPMatToRef, +{ + use crate::layouts::{ZnxView, ZnxViewMut}; + + let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut(); + let a: VecZnxDft<&[u8], BE> = a.to_ref(); + let pmat: VmpPMat<&[u8], BE> = pmat.to_ref(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.n(), pmat.n()); + assert_eq!(a.n(), pmat.n()); + assert_eq!(res.cols(), pmat.cols_out()); + assert_eq!(a.cols(), pmat.cols_in()); + } + + let n: usize = res.n(); + let nrows: usize = pmat.cols_in() * pmat.rows(); + let ncols: usize = pmat.cols_out() * pmat.size(); + + let pmat_raw: &[f64] = pmat.raw(); + 
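+ // unlike vmp_apply_dft_to_dft, the core call below accumulates into the existing contents of `res` rather than overwriting it, starting at column limb_offset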
let a_raw: &[f64] = a.raw(); + let res_raw: &mut [f64] = res.raw_mut(); + + vmp_apply_dft_to_dft_core::( + n, + res_raw, + a_raw, + pmat_raw, + limb_offset, + nrows, + ncols, + tmp_bytes, + ) +} + +#[allow(clippy::too_many_arguments)] +fn vmp_apply_dft_to_dft_core( + n: usize, + res: &mut [f64], + a: &[f64], + pmat: &[f64], + limb_offset: usize, + nrows: usize, + ncols: usize, + tmp_bytes: &mut [f64], +) where + REIM: ReimZero + + Reim4Extract1Blk + + Reim4Mat1ColProd + + Reim4Mat2Cols2ndColProd + + Reim4Mat2ColsProd + + Reim4Save2Blks + + Reim4Save1Blk, +{ + #[cfg(debug_assertions)] + { + assert!(n >= 8); + assert!(n.is_power_of_two()); + assert_eq!(pmat.len(), n * nrows * ncols); + assert!(res.len() & (n - 1) == 0); + assert!(a.len() & (n - 1) == 0); + } + + let a_size: usize = a.len() / n; + let res_size: usize = res.len() / n; + + let m: usize = n >> 1; + + let (mat2cols_output, extracted_blk) = tmp_bytes.split_at_mut(16); + + let row_max: usize = nrows.min(a_size); + let col_max: usize = ncols.min(res_size); + + if limb_offset >= col_max { + if OVERWRITE { + REIM::reim_zero(res); + } + return; + } + + for blk_i in 0..(m >> 2) { + let mat_blk_start: &[f64] = &pmat[blk_i * (8 * nrows * ncols)..]; + + REIM::reim4_extract_1blk(m, row_max, blk_i, extracted_blk, a); + + if limb_offset.is_multiple_of(2) { + for (col_res, col_pmat) in (0..).step_by(2).zip((limb_offset..col_max - 1).step_by(2)) { + let col_offset: usize = col_pmat * (8 * nrows); + REIM::reim4_mat2cols_prod( + row_max, + mat2cols_output, + extracted_blk, + &mat_blk_start[col_offset..], + ); + REIM::reim4_save_2blks::(m, blk_i, &mut res[col_res * n..], mat2cols_output); + } + } else { + let col_offset: usize = (limb_offset - 1) * (8 * nrows); + REIM::reim4_mat2cols_2ndcol_prod( + row_max, + mat2cols_output, + extracted_blk, + &mat_blk_start[col_offset..], + ); + + REIM::reim4_save_1blk::(m, blk_i, res, mat2cols_output); + + for (col_res, col_pmat) in (1..) 
+ .step_by(2) + .zip((limb_offset + 1..col_max - 1).step_by(2)) + { + let col_offset: usize = col_pmat * (8 * nrows); + REIM::reim4_mat2cols_prod( + row_max, + mat2cols_output, + extracted_blk, + &mat_blk_start[col_offset..], + ); + REIM::reim4_save_2blks::(m, blk_i, &mut res[col_res * n..], mat2cols_output); + } + } + + if !col_max.is_multiple_of(2) { + let last_col: usize = col_max - 1; + let col_offset: usize = last_col * (8 * nrows); + + if last_col >= limb_offset { + if ncols == col_max { + REIM::reim4_mat1col_prod( + row_max, + mat2cols_output, + extracted_blk, + &mat_blk_start[col_offset..], + ); + } else { + REIM::reim4_mat2cols_prod( + row_max, + mat2cols_output, + extracted_blk, + &mat_blk_start[col_offset..], + ); + } + REIM::reim4_save_1blk::( + m, + blk_i, + &mut res[(last_col - limb_offset) * n..], + mat2cols_output, + ); + } + } + } + + REIM::reim_zero(&mut res[col_max * n..]); +} diff --git a/poulpy-hal/src/reference/mod.rs b/poulpy-hal/src/reference/mod.rs new file mode 100644 index 0000000..9fd1500 --- /dev/null +++ b/poulpy-hal/src/reference/mod.rs @@ -0,0 +1,4 @@ +pub mod fft64; +pub mod vec_znx; +pub mod zn; +pub mod znx; diff --git a/poulpy-hal/src/reference/vec_znx/add.rs b/poulpy-hal/src/reference/vec_znx/add.rs new file mode 100644 index 0000000..56c9eb9 --- /dev/null +++ b/poulpy-hal/src/reference/vec_znx/add.rs @@ -0,0 +1,177 @@ +use std::hint::black_box; + +use criterion::{BenchmarkId, Criterion}; + +use crate::{ + api::{ModuleNew, VecZnxAdd, VecZnxAddInplace}, + layouts::{Backend, FillUniform, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut}, + reference::znx::{ZnxAdd, ZnxAddInplace, ZnxCopy, ZnxZero}, + source::Source, +}; + +pub fn vec_znx_add(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) +where + R: VecZnxToMut, + A: VecZnxToRef, + B: VecZnxToRef, + ZNXARI: ZnxAdd + ZnxCopy + ZnxZero, +{ + let a: VecZnx<&[u8]> = a.to_ref(); + let b: VecZnx<&[u8]> = b.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), res.n()); + assert_eq!(b.n(), res.n()); + } + + let res_size: usize = res.size(); + let a_size: usize = a.size(); + let b_size: usize = b.size(); + + if a_size <= b_size { + let sum_size: usize = a_size.min(res_size); + let cpy_size: usize = b_size.min(res_size); + + for j in 0..sum_size { + ZNXARI::znx_add(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j)); + } + + for j in sum_size..cpy_size { + ZNXARI::znx_copy(res.at_mut(res_col, j), b.at(b_col, j)); + } + + for j in cpy_size..res_size { + ZNXARI::znx_zero(res.at_mut(res_col, j)); + } + } else { + let sum_size: usize = b_size.min(res_size); + let cpy_size: usize = a_size.min(res_size); + + for j in 0..sum_size { + ZNXARI::znx_add(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j)); + } + + for j in sum_size..cpy_size { + ZNXARI::znx_copy(res.at_mut(res_col, j), a.at(a_col, j)); + } + + for j in cpy_size..res_size { + ZNXARI::znx_zero(res.at_mut(res_col, j)); + } + } +} + +pub fn vec_znx_add_inplace(res: &mut R, res_col: usize, a: &A, a_col: usize) +where + R: VecZnxToMut, + A: VecZnxToRef, + ZNXARI: ZnxAddInplace, +{ + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), res.n()); + } + + let res_size: usize = res.size(); + let a_size: usize = a.size(); + + let sum_size: usize = a_size.min(res_size); + + for j in 0..sum_size { + ZNXARI::znx_add_inplace(res.at_mut(res_col, j), a.at(a_col, j)); + } +} + +pub fn 
bench_vec_znx_add(c: &mut Criterion, label: &str) +where + Module: VecZnxAdd + ModuleNew, +{ + let group_name: String = format!("vec_znx_add::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: VecZnxAdd + ModuleNew, + { + let n: usize = 1 << params[0]; + let cols: usize = params[1]; + let size: usize = params[2]; + + let module: Module = Module::::new(n as u64); + + let mut source: Source = Source::new([0u8; 32]); + + let mut a: VecZnx> = VecZnx::alloc(n, cols, size); + let mut b: VecZnx> = VecZnx::alloc(n, cols, size); + let mut c: VecZnx> = VecZnx::alloc(n, cols, size); + + // Fill a with random i64 + a.fill_uniform(50, &mut source); + b.fill_uniform(50, &mut source); + + move || { + for i in 0..cols { + module.vec_znx_add(&mut c, i, &a, i, &b, i); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],)); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_vec_znx_add_inplace(c: &mut Criterion, label: &str) +where + Module: VecZnxAddInplace + ModuleNew, +{ + let group_name: String = format!("vec_znx_add_inplace::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: VecZnxAddInplace + ModuleNew, + { + let n: usize = 1 << params[0]; + let cols: usize = params[1]; + let size: usize = params[2]; + + let module: Module = Module::::new(n as u64); + + let mut source: Source = Source::new([0u8; 32]); + + let mut a: VecZnx> = VecZnx::alloc(n, cols, size); + let mut b: VecZnx> = VecZnx::alloc(n, cols, size); + + // Fill a with random i64 + a.fill_uniform(50, &mut source); + b.fill_uniform(50, &mut source); + + move || { + for i in 0..cols { + module.vec_znx_add_inplace(&mut b, i, &a, i); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2])); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} diff --git a/poulpy-hal/src/reference/vec_znx/add_scalar.rs b/poulpy-hal/src/reference/vec_znx/add_scalar.rs new file mode 100644 index 0000000..68f830e --- /dev/null +++ b/poulpy-hal/src/reference/vec_znx/add_scalar.rs @@ -0,0 +1,57 @@ +use crate::{ + layouts::{ScalarZnx, ScalarZnxToRef, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut}, + reference::znx::{ZnxAdd, ZnxAddInplace, ZnxCopy, ZnxZero}, +}; + +pub fn vec_znx_add_scalar(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize, b_limb: usize) +where + R: VecZnxToMut, + A: ScalarZnxToRef, + B: VecZnxToRef, + ZNXARI: ZnxAdd + ZnxCopy + ZnxZero, +{ + let a: ScalarZnx<&[u8]> = a.to_ref(); + let b: VecZnx<&[u8]> = b.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + + let min_size: usize = b.size().min(res.size()); + + #[cfg(debug_assertions)] + { + assert!( + b_limb < min_size, + "b_limb: {} > min_size: {}", + b_limb, + min_size + ); + } + + for j in 0..min_size { + if j == b_limb { + ZNXARI::znx_add(res.at_mut(res_col, j), a.at(a_col, 0), b.at(b_col, j)); + } else { + ZNXARI::znx_copy(res.at_mut(res_col, j), b.at(b_col, j)); + } + } + + for j in 
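+        // Any limbs of `res` beyond the last limb of `b` carry no data and are zeroed.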
min_size..res.size() { + ZNXARI::znx_zero(res.at_mut(res_col, j)); + } +} + +pub fn vec_znx_add_scalar_inplace(res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize) +where + R: VecZnxToMut, + A: ScalarZnxToRef, + ZNXARI: ZnxAddInplace, +{ + let a: ScalarZnx<&[u8]> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert!(res_limb < res.size()); + } + + ZNXARI::znx_add_inplace(res.at_mut(res_col, res_limb), a.at(a_col, 0)); +} diff --git a/poulpy-hal/src/reference/vec_znx/automorphism.rs b/poulpy-hal/src/reference/vec_znx/automorphism.rs new file mode 100644 index 0000000..91069cb --- /dev/null +++ b/poulpy-hal/src/reference/vec_znx/automorphism.rs @@ -0,0 +1,150 @@ +use std::hint::black_box; + +use criterion::{BenchmarkId, Criterion}; + +use crate::{ + api::{ + ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAutomorphism, VecZnxAutomorphismInplace, + VecZnxAutomorphismInplaceTmpBytes, + }, + layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut}, + reference::znx::{ZnxAutomorphism, ZnxCopy, ZnxZero}, + source::Source, +}; + +pub fn vec_znx_automorphism_inplace_tmp_bytes(n: usize) -> usize { + n * size_of::() +} + +pub fn vec_znx_automorphism(p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) +where + R: VecZnxToMut, + A: VecZnxToRef, + ZNXARI: ZnxAutomorphism + ZnxZero, +{ + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + + #[cfg(debug_assertions)] + { + use crate::layouts::ZnxInfos; + + assert_eq!(a.n(), res.n()); + } + + let min_size: usize = res.size().min(a.size()); + + for j in 0..min_size { + ZNXARI::znx_automorphism(p, res.at_mut(res_col, j), a.at(a_col, j)); + } + + for j in min_size..res.size() { + ZNXARI::znx_zero(res.at_mut(res_col, j)); + } +} + +pub fn vec_znx_automorphism_inplace(p: i64, res: &mut R, res_col: usize, tmp: &mut [i64]) +where + R: VecZnxToMut, + ZNXARI: ZnxAutomorphism + ZnxCopy, +{ + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + #[cfg(debug_assertions)] + { + assert_eq!(res.n(), tmp.len()); + } + for j in 0..res.size() { + ZNXARI::znx_automorphism(p, tmp, res.at(res_col, j)); + ZNXARI::znx_copy(res.at_mut(res_col, j), tmp); + } +} + +pub fn bench_vec_znx_automorphism(c: &mut Criterion, label: &str) +where + Module: VecZnxAutomorphism + ModuleNew, +{ + let group_name: String = format!("vec_znx_automorphism::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: VecZnxAutomorphism + ModuleNew, + { + let n: usize = 1 << params[0]; + let cols: usize = params[1]; + let size: usize = params[2]; + + let module: Module = Module::::new(n as u64); + + let mut source: Source = Source::new([0u8; 32]); + + let mut a: VecZnx> = VecZnx::alloc(n, cols, size); + let mut res: VecZnx> = VecZnx::alloc(n, cols, size); + + // Fill a with random i64 + a.fill_uniform(50, &mut source); + res.fill_uniform(50, &mut source); + + move || { + for i in 0..cols { + module.vec_znx_automorphism(-7, &mut res, i, &a, i); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],)); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_vec_znx_automorphism_inplace(c: &mut Criterion, label: &str) 
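+// The in-place variant needs scratch space for one temporary polynomial, so the module must
+// also expose the matching `TmpBytes` query and `ScratchOwned` alloc/borrow, as required below.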
+where + Module: VecZnxAutomorphismInplace + VecZnxAutomorphismInplaceTmpBytes + ModuleNew, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + let group_name: String = format!("vec_znx_automorphism_inplace::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: VecZnxAutomorphismInplace + ModuleNew + VecZnxAutomorphismInplaceTmpBytes, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + { + let n: usize = 1 << params[0]; + let cols: usize = params[1]; + let size: usize = params[2]; + + let module: Module = Module::::new(n as u64); + + let mut source: Source = Source::new([0u8; 32]); + + let mut res: VecZnx> = VecZnx::alloc(n, cols, size); + + let mut scratch = ScratchOwned::alloc(module.vec_znx_automorphism_inplace_tmp_bytes()); + + // Fill a with random i64 + res.fill_uniform(50, &mut source); + + move || { + for i in 0..cols { + module.vec_znx_automorphism_inplace(-7, &mut res, i, scratch.borrow()); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],)); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} diff --git a/poulpy-hal/src/reference/vec_znx/copy.rs b/poulpy-hal/src/reference/vec_znx/copy.rs new file mode 100644 index 0000000..ecc6a44 --- /dev/null +++ b/poulpy-hal/src/reference/vec_znx/copy.rs @@ -0,0 +1,32 @@ +use crate::{ + layouts::{VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut}, + reference::znx::{ZnxCopy, ZnxZero}, +}; + +pub fn vec_znx_copy(res: &mut R, res_col: usize, a: &A, a_col: usize) +where + R: VecZnxToMut, + A: VecZnxToRef, + ZNXARI: ZnxCopy + ZnxZero, +{ + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + let a: VecZnx<&[u8]> = a.to_ref(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.n(), a.n()) + } + + let res_size = res.size(); + let a_size = a.size(); + + let min_size = res_size.min(a_size); + + for j in 0..min_size { + ZNXARI::znx_copy(res.at_mut(res_col, j), a.at(a_col, j)); + } + + for j in min_size..res_size { + ZNXARI::znx_zero(res.at_mut(res_col, j)); + } +} diff --git a/poulpy-hal/src/reference/vec_znx/merge_rings.rs b/poulpy-hal/src/reference/vec_znx/merge_rings.rs new file mode 100644 index 0000000..0c9d204 --- /dev/null +++ b/poulpy-hal/src/reference/vec_znx/merge_rings.rs @@ -0,0 +1,49 @@ +use crate::{ + layouts::{VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos}, + reference::{ + vec_znx::{vec_znx_rotate_inplace, vec_znx_switch_ring}, + znx::{ZnxCopy, ZnxRotate, ZnxSwitchRing, ZnxZero}, + }, +}; + +pub fn vec_znx_merge_rings_tmp_bytes(n: usize) -> usize { + n * size_of::() +} + +pub fn vec_znx_merge_rings(res: &mut R, res_col: usize, a: &[A], a_col: usize, tmp: &mut [i64]) +where + R: VecZnxToMut, + A: VecZnxToRef, + ZNXARI: ZnxCopy + ZnxSwitchRing + ZnxRotate + ZnxZero, +{ + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + + let (n_out, n_in) = (res.n(), a[0].to_ref().n()); + + #[cfg(debug_assertions)] + { + assert_eq!(tmp.len(), res.n()); + + debug_assert!( + n_out > n_in, + "invalid a: output ring degree should be greater" + ); + a[1..].iter().for_each(|ai| { + debug_assert_eq!( + ai.to_ref().n(), + n_in, + "invalid input a: all VecZnx must have the same degree" + ) + }); + + assert!(n_out.is_multiple_of(n_in)); + assert_eq!(a.len(), n_out / n_in); + } + + a.iter().for_each(|ai| { + 
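+        // Each small-ring input is embedded into the large ring via X -> X^(n_out/n_in)
+        // (assuming `znx_switch_ring` only writes the strided coefficient positions when
+        // going up in degree), then the accumulator is rotated by X^-1 so the next input
+        // lands on the next residue class; the final rotation by `a.len()` below undoes the
+        // accumulated shift.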
vec_znx_switch_ring::<_, _, ZNXARI>(&mut res, res_col, ai, a_col); + vec_znx_rotate_inplace::<_, ZNXARI>(-1, &mut res, res_col, tmp); + }); + + vec_znx_rotate_inplace::<_, ZNXARI>(a.len() as i64, &mut res, res_col, tmp); +} diff --git a/poulpy-hal/src/reference/vec_znx/mod.rs b/poulpy-hal/src/reference/vec_znx/mod.rs new file mode 100644 index 0000000..4edb574 --- /dev/null +++ b/poulpy-hal/src/reference/vec_znx/mod.rs @@ -0,0 +1,31 @@ +mod add; +mod add_scalar; +mod automorphism; +mod copy; +mod merge_rings; +mod mul_xp_minus_one; +mod negate; +mod normalize; +mod rotate; +mod sampling; +mod shift; +mod split_ring; +mod sub; +mod sub_scalar; +mod switch_ring; + +pub use add::*; +pub use add_scalar::*; +pub use automorphism::*; +pub use copy::*; +pub use merge_rings::*; +pub use mul_xp_minus_one::*; +pub use negate::*; +pub use normalize::*; +pub use rotate::*; +pub use sampling::*; +pub use shift::*; +pub use split_ring::*; +pub use sub::*; +pub use sub_scalar::*; +pub use switch_ring::*; diff --git a/poulpy-hal/src/reference/vec_znx/mul_xp_minus_one.rs b/poulpy-hal/src/reference/vec_znx/mul_xp_minus_one.rs new file mode 100644 index 0000000..b07599d --- /dev/null +++ b/poulpy-hal/src/reference/vec_znx/mul_xp_minus_one.rs @@ -0,0 +1,136 @@ +use std::hint::black_box; + +use criterion::{BenchmarkId, Criterion}; + +use crate::{ + api::{ + ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, + VecZnxMulXpMinusOneInplaceTmpBytes, + }, + layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut}, + reference::{ + vec_znx::{vec_znx_rotate, vec_znx_sub_ab_inplace}, + znx::{ZnxNegate, ZnxRotate, ZnxSubABInplace, ZnxSubBAInplace, ZnxZero}, + }, + source::Source, +}; + +pub fn vec_znx_mul_xp_minus_one_inplace_tmp_bytes(n: usize) -> usize { + n * size_of::() +} + +pub fn vec_znx_mul_xp_minus_one(p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) +where + R: VecZnxToMut, + A: VecZnxToRef, + ZNXARI: ZnxRotate + ZnxZero + ZnxSubABInplace, +{ + vec_znx_rotate::<_, _, ZNXARI>(p, res, res_col, a, a_col); + vec_znx_sub_ab_inplace::<_, _, ZNXARI>(res, res_col, a, a_col); +} + +pub fn vec_znx_mul_xp_minus_one_inplace(p: i64, res: &mut R, res_col: usize, tmp: &mut [i64]) +where + R: VecZnxToMut, + ZNXARI: ZnxRotate + ZnxNegate + ZnxSubBAInplace, +{ + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + #[cfg(debug_assertions)] + { + assert_eq!(res.n(), tmp.len()); + } + for j in 0..res.size() { + ZNXARI::znx_rotate(p, tmp, res.at(res_col, j)); + ZNXARI::znx_sub_ba_inplace(res.at_mut(res_col, j), tmp); + } +} + +pub fn bench_vec_znx_mul_xp_minus_one(c: &mut Criterion, label: &str) +where + Module: VecZnxMulXpMinusOne + ModuleNew, +{ + let group_name: String = format!("vec_znx_mul_xp_minus_one::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: VecZnxMulXpMinusOne + ModuleNew, + { + let n: usize = 1 << params[0]; + let cols: usize = params[1]; + let size: usize = params[2]; + + let module: Module = Module::::new(n as u64); + + let mut source: Source = Source::new([0u8; 32]); + + let mut a: VecZnx> = VecZnx::alloc(n, cols, size); + let mut res: VecZnx> = VecZnx::alloc(n, cols, size); + + // Fill a with random i64 + a.fill_uniform(50, &mut source); + res.fill_uniform(50, &mut source); + + move || { + for i in 0..cols { + module.vec_znx_mul_xp_minus_one(-7, &mut res, i, &a, i); + } + black_box(()); + } + } + + for params in 
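+        // Each triple is (log2(n), cols, size): ring degrees 2^10..2^14 with 2 columns and a
+        // growing number of limbs.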
[[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],)); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_vec_znx_mul_xp_minus_one_inplace(c: &mut Criterion, label: &str) +where + Module: VecZnxMulXpMinusOneInplace + VecZnxMulXpMinusOneInplaceTmpBytes + ModuleNew, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + let group_name: String = format!("vec_znx_mul_xp_minus_one_inplace::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: VecZnxMulXpMinusOneInplace + ModuleNew + VecZnxMulXpMinusOneInplaceTmpBytes, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + { + let n: usize = 1 << params[0]; + let cols: usize = params[1]; + let size: usize = params[2]; + + let module: Module = Module::::new(n as u64); + + let mut source: Source = Source::new([0u8; 32]); + + let mut res: VecZnx> = VecZnx::alloc(n, cols, size); + + let mut scratch = ScratchOwned::alloc(module.vec_znx_mul_xp_minus_one_inplace_tmp_bytes()); + + // Fill a with random i64 + res.fill_uniform(50, &mut source); + + move || { + for i in 0..cols { + module.vec_znx_mul_xp_minus_one_inplace(-7, &mut res, i, scratch.borrow()); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],)); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} diff --git a/poulpy-hal/src/reference/vec_znx/negate.rs b/poulpy-hal/src/reference/vec_znx/negate.rs new file mode 100644 index 0000000..f446467 --- /dev/null +++ b/poulpy-hal/src/reference/vec_znx/negate.rs @@ -0,0 +1,131 @@ +use std::hint::black_box; + +use criterion::{BenchmarkId, Criterion}; + +use crate::{ + api::{ModuleNew, VecZnxNegate, VecZnxNegateInplace}, + layouts::{Backend, FillUniform, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut}, + reference::znx::{ZnxNegate, ZnxNegateInplace, ZnxZero}, + source::Source, +}; + +pub fn vec_znx_negate(res: &mut R, res_col: usize, a: &A, a_col: usize) +where + R: VecZnxToMut, + A: VecZnxToRef, + ZNXARI: ZnxNegate + ZnxZero, +{ + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), res.n()); + } + + let min_size: usize = res.size().min(a.size()); + + for j in 0..min_size { + ZNXARI::znx_negate(res.at_mut(res_col, j), a.at(a_col, j)); + } + + for j in min_size..res.size() { + ZNXARI::znx_zero(res.at_mut(res_col, j)); + } +} + +pub fn vec_znx_negate_inplace(res: &mut R, res_col: usize) +where + R: VecZnxToMut, + ZNXARI: ZnxNegateInplace, +{ + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + for j in 0..res.size() { + ZNXARI::znx_negate_inplace(res.at_mut(res_col, j)); + } +} + +pub fn bench_vec_znx_negate(c: &mut Criterion, label: &str) +where + Module: VecZnxNegate + ModuleNew, +{ + let group_name: String = format!("vec_znx_negate::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: VecZnxNegate + ModuleNew, + { + let n: usize = 1 << params[0]; + let cols: usize = params[1]; + let size: usize = params[2]; + + let module: Module 
= Module::::new(n as u64); + + let mut source: Source = Source::new([0u8; 32]); + + let mut a: VecZnx> = VecZnx::alloc(n, cols, size); + let mut b: VecZnx> = VecZnx::alloc(n, cols, size); + + // Fill a with random i64 + a.fill_uniform(50, &mut source); + b.fill_uniform(50, &mut source); + + move || { + for i in 0..cols { + module.vec_znx_negate(&mut b, i, &a, i); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],)); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_vec_znx_negate_inplace(c: &mut Criterion, label: &str) +where + Module: VecZnxNegateInplace + ModuleNew, +{ + let group_name: String = format!("vec_znx_negate_inplace::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: VecZnxNegateInplace + ModuleNew, + { + let n: usize = 1 << params[0]; + let cols: usize = params[1]; + let size: usize = params[2]; + + let module: Module = Module::::new(n as u64); + + let mut source: Source = Source::new([0u8; 32]); + + let mut a: VecZnx> = VecZnx::alloc(n, cols, size); + + // Fill a with random i64 + a.fill_uniform(50, &mut source); + move || { + for i in 0..cols { + module.vec_znx_negate_inplace(&mut a, i); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2])); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} diff --git a/poulpy-hal/src/reference/vec_znx/normalize.rs b/poulpy-hal/src/reference/vec_znx/normalize.rs new file mode 100644 index 0000000..98795b8 --- /dev/null +++ b/poulpy-hal/src/reference/vec_znx/normalize.rs @@ -0,0 +1,193 @@ +use std::hint::black_box; + +use criterion::{BenchmarkId, Criterion}; + +use crate::{ + api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes}, + layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut}, + reference::znx::{ + ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, + ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, + ZnxZero, + }, + source::Source, +}; + +pub fn vec_znx_normalize_tmp_bytes(n: usize) -> usize { + n * size_of::() +} + +pub fn vec_znx_normalize(basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64]) +where + R: VecZnxToMut, + A: VecZnxToRef, + ZNXARI: ZnxZero + + ZnxNormalizeFirstStepCarryOnly + + ZnxNormalizeMiddleStepCarryOnly + + ZnxNormalizeMiddleStep + + ZnxNormalizeFinalStep + + ZnxNormalizeFirstStep, +{ + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + let a: VecZnx<&[u8]> = a.to_ref(); + + #[cfg(debug_assertions)] + { + assert!(carry.len() >= res.n()); + } + + let res_size: usize = res.size(); + let a_size = a.size(); + + if a_size > res_size { + for j in (res_size..a_size).rev() { + if j == a_size - 1 { + ZNXARI::znx_normalize_first_step_carry_only(basek, 0, a.at(a_col, j), carry); + } else { + ZNXARI::znx_normalize_middle_step_carry_only(basek, 0, a.at(a_col, 
j), carry); + } + } + + for j in (1..res_size).rev() { + ZNXARI::znx_normalize_middle_step(basek, 0, res.at_mut(res_col, j), a.at(a_col, j), carry); + } + + ZNXARI::znx_normalize_final_step(basek, 0, res.at_mut(res_col, 0), a.at(a_col, 0), carry); + } else { + for j in (0..a_size).rev() { + if j == a_size - 1 { + ZNXARI::znx_normalize_first_step(basek, 0, res.at_mut(res_col, j), a.at(a_col, j), carry); + } else if j == 0 { + ZNXARI::znx_normalize_final_step(basek, 0, res.at_mut(res_col, j), a.at(a_col, j), carry); + } else { + ZNXARI::znx_normalize_middle_step(basek, 0, res.at_mut(res_col, j), a.at(a_col, j), carry); + } + } + + for j in a_size..res_size { + ZNXARI::znx_zero(res.at_mut(res_col, j)); + } + } +} + +pub fn vec_znx_normalize_inplace(basek: usize, res: &mut R, res_col: usize, carry: &mut [i64]) +where + ZNXARI: ZnxNormalizeFirstStepInplace + ZnxNormalizeMiddleStepInplace + ZnxNormalizeFinalStepInplace, +{ + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert!(carry.len() >= res.n()); + } + + let res_size: usize = res.size(); + + for j in (0..res_size).rev() { + if j == res_size - 1 { + ZNXARI::znx_normalize_first_step_inplace(basek, 0, res.at_mut(res_col, j), carry); + } else if j == 0 { + ZNXARI::znx_normalize_final_step_inplace(basek, 0, res.at_mut(res_col, j), carry); + } else { + ZNXARI::znx_normalize_middle_step_inplace(basek, 0, res.at_mut(res_col, j), carry); + } + } +} + +pub fn bench_vec_znx_normalize(c: &mut Criterion, label: &str) +where + Module: VecZnxNormalize + ModuleNew + VecZnxNormalizeTmpBytes, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + let group_name: String = format!("vec_znx_normalize::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: VecZnxNormalize + ModuleNew + VecZnxNormalizeTmpBytes, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + { + let n: usize = 1 << params[0]; + let cols: usize = params[1]; + let size: usize = params[2]; + + let module: Module = Module::::new(n as u64); + + let basek: usize = 50; + + let mut source: Source = Source::new([0u8; 32]); + + let mut a: VecZnx> = VecZnx::alloc(n, cols, size); + let mut res: VecZnx> = VecZnx::alloc(n, cols, size); + + // Fill a with random i64 + a.fill_uniform(50, &mut source); + res.fill_uniform(50, &mut source); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(module.vec_znx_normalize_tmp_bytes()); + + move || { + for i in 0..cols { + module.vec_znx_normalize(basek, &mut res, i, &a, i, scratch.borrow()); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],)); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_vec_znx_normalize_inplace(c: &mut Criterion, label: &str) +where + Module: VecZnxNormalizeInplace + ModuleNew + VecZnxNormalizeTmpBytes, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + let group_name: String = format!("vec_znx_normalize_inplace::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: VecZnxNormalizeInplace + ModuleNew + VecZnxNormalizeTmpBytes, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + { + let n: usize = 1 << params[0]; + let cols: usize = params[1]; + let size: usize = 
params[2]; + + let module: Module = Module::::new(n as u64); + + let basek: usize = 50; + + let mut source: Source = Source::new([0u8; 32]); + + let mut a: VecZnx> = VecZnx::alloc(n, cols, size); + + // Fill a with random i64 + a.fill_uniform(50, &mut source); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(module.vec_znx_normalize_tmp_bytes()); + + move || { + for i in 0..cols { + module.vec_znx_normalize_inplace(basek, &mut a, i, scratch.borrow()); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],)); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} diff --git a/poulpy-hal/src/reference/vec_znx/rotate.rs b/poulpy-hal/src/reference/vec_znx/rotate.rs new file mode 100644 index 0000000..78ef17c --- /dev/null +++ b/poulpy-hal/src/reference/vec_znx/rotate.rs @@ -0,0 +1,148 @@ +use std::hint::black_box; + +use criterion::{BenchmarkId, Criterion}; + +use crate::{ + api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxRotate, VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes}, + layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut}, + reference::znx::{ZnxCopy, ZnxRotate, ZnxZero}, + source::Source, +}; + +pub fn vec_znx_rotate_inplace_tmp_bytes(n: usize) -> usize { + n * size_of::() +} + +pub fn vec_znx_rotate(p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) +where + R: VecZnxToMut, + A: VecZnxToRef, + ZNXARI: ZnxRotate + ZnxZero, +{ + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + let a: VecZnx<&[u8]> = a.to_ref(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.n(), a.n()) + } + + let res_size: usize = res.size(); + let a_size: usize = a.size(); + + let min_size: usize = res_size.min(a_size); + + for j in 0..min_size { + ZNXARI::znx_rotate(p, res.at_mut(res_col, j), a.at(a_col, j)) + } + + for j in min_size..res_size { + ZNXARI::znx_zero(res.at_mut(res_col, j)); + } +} + +pub fn vec_znx_rotate_inplace(p: i64, res: &mut R, res_col: usize, tmp: &mut [i64]) +where + R: VecZnxToMut, + ZNXARI: ZnxRotate + ZnxCopy, +{ + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + #[cfg(debug_assertions)] + { + assert_eq!(res.n(), tmp.len()); + } + for j in 0..res.size() { + ZNXARI::znx_rotate(p, tmp, res.at(res_col, j)); + ZNXARI::znx_copy(res.at_mut(res_col, j), tmp); + } +} + +pub fn bench_vec_znx_rotate(c: &mut Criterion, label: &str) +where + Module: VecZnxRotate + ModuleNew, +{ + let group_name: String = format!("vec_znx_rotate::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: VecZnxRotate + ModuleNew, + { + let n: usize = 1 << params[0]; + let cols: usize = params[1]; + let size: usize = params[2]; + + let module: Module = Module::::new(n as u64); + + let mut source: Source = Source::new([0u8; 32]); + + let mut a: VecZnx> = VecZnx::alloc(n, cols, size); + let mut res: VecZnx> = VecZnx::alloc(n, cols, size); + + // Fill a with random i64 + a.fill_uniform(50, &mut source); + res.fill_uniform(50, &mut source); + + move || { + for i in 0..cols { + module.vec_znx_rotate(-7, &mut res, i, &a, i); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << 
params[0], params[1], params[2],)); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_vec_znx_rotate_inplace(c: &mut Criterion, label: &str) +where + Module: VecZnxRotateInplace + VecZnxRotateInplaceTmpBytes + ModuleNew, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + let group_name: String = format!("vec_znx_rotate_inplace::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: VecZnxRotateInplace + ModuleNew + VecZnxRotateInplaceTmpBytes, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + { + let n: usize = 1 << params[0]; + let cols: usize = params[1]; + let size: usize = params[2]; + + let module: Module = Module::::new(n as u64); + + let mut source: Source = Source::new([0u8; 32]); + + let mut res: VecZnx> = VecZnx::alloc(n, cols, size); + + let mut scratch = ScratchOwned::alloc(module.vec_znx_rotate_inplace_tmp_bytes()); + + // Fill a with random i64 + res.fill_uniform(50, &mut source); + + move || { + for i in 0..cols { + module.vec_znx_rotate_inplace(-7, &mut res, i, scratch.borrow()); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],)); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} diff --git a/poulpy-hal/src/reference/vec_znx/sampling.rs b/poulpy-hal/src/reference/vec_znx/sampling.rs new file mode 100644 index 0000000..d29edab --- /dev/null +++ b/poulpy-hal/src/reference/vec_znx/sampling.rs @@ -0,0 +1,64 @@ +use crate::{ + layouts::{VecZnx, VecZnxToMut, ZnxInfos, ZnxViewMut}, + reference::znx::{znx_add_normal_f64_ref, znx_fill_normal_f64_ref, znx_fill_uniform_ref}, + source::Source, +}; + +pub fn vec_znx_fill_uniform_ref(basek: usize, res: &mut R, res_col: usize, source: &mut Source) +where + R: VecZnxToMut, +{ + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + for j in 0..res.size() { + znx_fill_uniform_ref(basek, res.at_mut(res_col, j), source) + } +} + +pub fn vec_znx_fill_normal_ref( + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + sigma: f64, + bound: f64, + source: &mut Source, +) where + R: VecZnxToMut, +{ + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + assert!( + (bound.log2().ceil() as i64) < 64, + "invalid bound: ceil(log2(bound))={} > 63", + (bound.log2().ceil() as i64) + ); + + let limb: usize = k.div_ceil(basek) - 1; + let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64; + znx_fill_normal_f64_ref( + res.at_mut(res_col, limb), + sigma * scale, + bound * scale, + source, + ) +} + +pub fn vec_znx_add_normal_ref(basek: usize, res: &mut R, res_col: usize, k: usize, sigma: f64, bound: f64, source: &mut Source) +where + R: VecZnxToMut, +{ + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + assert!( + (bound.log2().ceil() as i64) < 64, + "invalid bound: ceil(log2(bound))={} > 63", + (bound.log2().ceil() as i64) + ); + + let limb: usize = k.div_ceil(basek) - 1; + let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64; + znx_add_normal_f64_ref( + res.at_mut(res_col, limb), + sigma * scale, + bound * scale, + source, + ) +} diff --git a/poulpy-hal/src/reference/vec_znx/shift.rs b/poulpy-hal/src/reference/vec_znx/shift.rs new file mode 100644 index 0000000..5b64d46 --- /dev/null +++ 
b/poulpy-hal/src/reference/vec_znx/shift.rs @@ -0,0 +1,672 @@ +use std::hint::black_box; + +use criterion::{BenchmarkId, Criterion}; + +use crate::{ + api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxLsh, VecZnxLshInplace, VecZnxRsh, VecZnxRshInplace}, + layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut}, + reference::{ + vec_znx::vec_znx_copy, + znx::{ + ZnxCopy, ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, + ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, + ZnxZero, + }, + }, + source::Source, +}; + +pub fn vec_znx_lsh_tmp_bytes(n: usize) -> usize { + n * size_of::() +} + +pub fn vec_znx_lsh_inplace(basek: usize, k: usize, res: &mut R, res_col: usize, carry: &mut [i64]) +where + R: VecZnxToMut, + ZNXARI: ZnxZero + + ZnxCopy + + ZnxNormalizeFirstStepInplace + + ZnxNormalizeMiddleStepInplace + + ZnxNormalizeFirstStepInplace + + ZnxNormalizeFinalStepInplace, +{ + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + + let n: usize = res.n(); + let cols: usize = res.cols(); + let size: usize = res.size(); + let steps: usize = k / basek; + let k_rem: usize = k % basek; + + if steps >= size { + for j in 0..size { + ZNXARI::znx_zero(res.at_mut(res_col, j)); + } + return; + } + + // Inplace shift of limbs by a k/basek + if steps > 0 { + let start: usize = n * res_col; + let end: usize = start + n; + let slice_size: usize = n * cols; + let res_raw: &mut [i64] = res.raw_mut(); + + (0..size - steps).for_each(|j| { + let (lhs, rhs) = res_raw.split_at_mut(slice_size * (j + steps)); + ZNXARI::znx_copy( + &mut lhs[start + j * slice_size..end + j * slice_size], + &rhs[start..end], + ); + }); + + for j in size - steps..size { + ZNXARI::znx_zero(res.at_mut(res_col, j)); + } + } + + // Inplace normalization with left shift of k % basek + if !k.is_multiple_of(basek) { + for j in (0..size - steps).rev() { + if j == size - steps - 1 { + ZNXARI::znx_normalize_first_step_inplace(basek, k_rem, res.at_mut(res_col, j), carry); + } else if j == 0 { + ZNXARI::znx_normalize_final_step_inplace(basek, k_rem, res.at_mut(res_col, j), carry); + } else { + ZNXARI::znx_normalize_middle_step_inplace(basek, k_rem, res.at_mut(res_col, j), carry); + } + } + } +} + +pub fn vec_znx_lsh(basek: usize, k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64]) +where + R: VecZnxToMut, + A: VecZnxToRef, + ZNXARI: ZnxZero + ZnxNormalizeFirstStep + ZnxNormalizeMiddleStep + ZnxNormalizeFirstStep + ZnxCopy + ZnxNormalizeFinalStep, +{ + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + let a: VecZnx<&[u8]> = a.to_ref(); + + let res_size: usize = res.size(); + let a_size = a.size(); + let steps: usize = k / basek; + let k_rem: usize = k % basek; + + if steps >= res_size.min(a_size) { + for j in 0..res_size { + ZNXARI::znx_zero(res.at_mut(res_col, j)); + } + return; + } + + let min_size: usize = a_size.min(res_size) - steps; + + // Simply a left shifted normalization of limbs + // by k/basek and intra-limb by basek - k%basek + if !k.is_multiple_of(basek) { + for j in (0..min_size).rev() { + if j == min_size - 1 { + ZNXARI::znx_normalize_first_step( + basek, + k_rem, + res.at_mut(res_col, j), + a.at(a_col, j + steps), + carry, + ); + } else if j == 0 { + ZNXARI::znx_normalize_final_step( + basek, + k_rem, + res.at_mut(res_col, j), + a.at(a_col, j + steps), + carry, + ); + } else { + ZNXARI::znx_normalize_middle_step( 
+ basek, + k_rem, + res.at_mut(res_col, j), + a.at(a_col, j + steps), + carry, + ); + } + } + } else { + // If k % basek = 0, then this is simply a copy. + for j in (0..min_size).rev() { + ZNXARI::znx_copy(res.at_mut(res_col, j), a.at(a_col, j + steps)); + } + } + + // Zeroes bottom + for j in min_size..res_size { + ZNXARI::znx_zero(res.at_mut(res_col, j)); + } +} + +pub fn vec_znx_rsh_tmp_bytes(n: usize) -> usize { + n * size_of::() +} + +pub fn vec_znx_rsh_inplace(basek: usize, k: usize, res: &mut R, res_col: usize, carry: &mut [i64]) +where + R: VecZnxToMut, + ZNXARI: ZnxZero + + ZnxCopy + + ZnxNormalizeFirstStepCarryOnly + + ZnxNormalizeMiddleStepCarryOnly + + ZnxNormalizeMiddleStep + + ZnxNormalizeMiddleStepInplace + + ZnxNormalizeFirstStepInplace + + ZnxNormalizeFinalStepInplace, +{ + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + let n: usize = res.n(); + let cols: usize = res.cols(); + let size: usize = res.size(); + + let mut steps: usize = k / basek; + let k_rem: usize = k % basek; + + if k == 0 { + return; + } + + if steps >= size { + for j in 0..size { + ZNXARI::znx_zero(res.at_mut(res_col, j)); + } + return; + } + + let start: usize = n * res_col; + let end: usize = start + n; + let slice_size: usize = n * cols; + + if !k.is_multiple_of(basek) { + // We rsh by an additional basek and then lsh by basek-k + // Allows to re-use efficient normalization code, avoids + // avoids overflows & produce output that is normalized + steps += 1; + + // All limbs of a that would fall outside of the limbs of res are discarded, + // but the carry still need to be computed. + (size - steps..size).rev().for_each(|j| { + if j == size - 1 { + ZNXARI::znx_normalize_first_step_carry_only(basek, basek - k_rem, res.at(res_col, j), carry); + } else { + ZNXARI::znx_normalize_middle_step_carry_only(basek, basek - k_rem, res.at(res_col, j), carry); + } + }); + + // Continues with shifted normalization + let res_raw: &mut [i64] = res.raw_mut(); + (steps..size).rev().for_each(|j| { + let (lhs, rhs) = res_raw.split_at_mut(slice_size * j); + let rhs_slice: &mut [i64] = &mut rhs[start..end]; + let lhs_slice: &[i64] = &lhs[(j - steps) * slice_size + start..(j - steps) * slice_size + end]; + ZNXARI::znx_normalize_middle_step(basek, basek - k_rem, rhs_slice, lhs_slice, carry); + }); + + // Propagates carry on the rest of the limbs of res + for j in (0..steps).rev() { + ZNXARI::znx_zero(res.at_mut(res_col, j)); + if j == 0 { + ZNXARI::znx_normalize_final_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry); + } else { + ZNXARI::znx_normalize_middle_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry); + } + } + } else { + // Shift by multiples of basek + let res_raw: &mut [i64] = res.raw_mut(); + (steps..size).rev().for_each(|j| { + let (lhs, rhs) = res_raw.split_at_mut(slice_size * j); + ZNXARI::znx_copy( + &mut rhs[start..end], + &lhs[(j - steps) * slice_size + start..(j - steps) * slice_size + end], + ); + }); + + // Zeroes the top + (0..steps).for_each(|j| { + ZNXARI::znx_zero(res.at_mut(res_col, j)); + }); + } +} + +pub fn vec_znx_rsh(basek: usize, k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64]) +where + R: VecZnxToMut, + A: VecZnxToRef, + ZNXARI: ZnxZero + + ZnxCopy + + ZnxNormalizeFirstStepCarryOnly + + ZnxNormalizeMiddleStepCarryOnly + + ZnxNormalizeFirstStep + + ZnxNormalizeMiddleStep + + ZnxNormalizeMiddleStepInplace + + ZnxNormalizeFirstStepInplace + + ZnxNormalizeFinalStepInplace, +{ + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + let a: 
VecZnx<&[u8]> = a.to_ref(); + + let res_size: usize = res.size(); + let a_size: usize = a.size(); + + let mut steps: usize = k / basek; + let k_rem: usize = k % basek; + + if k == 0 { + vec_znx_copy::<_, _, ZNXARI>(&mut res, res_col, &a, a_col); + return; + } + + if steps >= res_size { + for j in 0..res_size { + ZNXARI::znx_zero(res.at_mut(res_col, j)); + } + return; + } + + if !k.is_multiple_of(basek) { + // We rsh by an additional basek and then lsh by basek-k + // Allows to re-use efficient normalization code, avoids + // avoids overflows & produce output that is normalized + steps += 1; + + // All limbs of a that are moved outside of the limbs of res are discarded, + // but the carry still need to be computed. + for j in (res_size..a_size + steps).rev() { + if j == a_size + steps - 1 { + ZNXARI::znx_normalize_first_step_carry_only(basek, basek - k_rem, a.at(a_col, j - steps), carry); + } else { + ZNXARI::znx_normalize_middle_step_carry_only(basek, basek - k_rem, a.at(a_col, j - steps), carry); + } + } + + // Avoids over flow of limbs of res + let min_size: usize = res_size.min(a_size + steps); + + // Zeroes lower limbs of res if a_size + steps < res_size + (min_size..res_size).for_each(|j| { + ZNXARI::znx_zero(res.at_mut(res_col, j)); + }); + + // Continues with shifted normalization + for j in (steps..min_size).rev() { + // Case if no limb of a was previously discarded + if res_size.saturating_sub(steps) >= a_size && j == min_size - 1 { + ZNXARI::znx_normalize_first_step( + basek, + basek - k_rem, + res.at_mut(res_col, j), + a.at(a_col, j - steps), + carry, + ); + } else { + ZNXARI::znx_normalize_middle_step( + basek, + basek - k_rem, + res.at_mut(res_col, j), + a.at(a_col, j - steps), + carry, + ); + } + } + + // Propagates carry on the rest of the limbs of res + for j in (0..steps).rev() { + ZNXARI::znx_zero(res.at_mut(res_col, j)); + if j == 0 { + ZNXARI::znx_normalize_final_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry); + } else { + ZNXARI::znx_normalize_middle_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry); + } + } + } else { + let min_size: usize = res_size.min(a_size + steps); + + // Zeroes the top + (0..steps).for_each(|j| { + ZNXARI::znx_zero(res.at_mut(res_col, j)); + }); + + // Shift a into res, up to the maximum + for j in (steps..min_size).rev() { + ZNXARI::znx_copy(res.at_mut(res_col, j), a.at(a_col, j - steps)); + } + + // Zeroes bottom if a_size + steps < res_size + (min_size..res_size).for_each(|j| { + ZNXARI::znx_zero(res.at_mut(res_col, j)); + }); + } +} + +pub fn bench_vec_znx_lsh_inplace(c: &mut Criterion, label: &str) +where + Module: ModuleNew + VecZnxLshInplace, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + let group_name: String = format!("vec_znx_lsh_inplace::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: VecZnxLshInplace + ModuleNew, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + { + let n: usize = 1 << params[0]; + let cols: usize = params[1]; + let size: usize = params[2]; + + let module: Module = Module::::new(n as u64); + + let basek: usize = 50; + + let mut source: Source = Source::new([0u8; 32]); + + let mut a: VecZnx> = VecZnx::alloc(n, cols, size); + let mut b: VecZnx> = VecZnx::alloc(n, cols, size); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(n * size_of::()); + + // Fill a with random i64 + a.fill_uniform(50, &mut source); + b.fill_uniform(50, &mut source); + + move || { + for i in 0..cols 
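+                // A shift of `basek - 1` bits is not a multiple of `basek`, so the benchmark
+                // exercises the carry-propagating normalization path rather than the plain
+                // limb copy.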
{ + module.vec_znx_lsh_inplace(basek, basek - 1, &mut b, i, scratch.borrow()); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2])); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_vec_znx_lsh(c: &mut Criterion, label: &str) +where + Module: VecZnxLsh + ModuleNew, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + let group_name: String = format!("vec_znx_lsh::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: VecZnxLsh + ModuleNew, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + { + let n: usize = 1 << params[0]; + let cols: usize = params[1]; + let size: usize = params[2]; + + let module: Module = Module::::new(n as u64); + + let basek: usize = 50; + + let mut source: Source = Source::new([0u8; 32]); + + let mut a: VecZnx> = VecZnx::alloc(n, cols, size); + let mut res: VecZnx> = VecZnx::alloc(n, cols, size); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(n * size_of::()); + + // Fill a with random i64 + a.fill_uniform(50, &mut source); + res.fill_uniform(50, &mut source); + + move || { + for i in 0..cols { + module.vec_znx_lsh(basek, basek - 1, &mut res, i, &a, i, scratch.borrow()); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2])); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_vec_znx_rsh_inplace(c: &mut Criterion, label: &str) +where + Module: VecZnxRshInplace + ModuleNew, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + let group_name: String = format!("vec_znx_rsh_inplace::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: VecZnxRshInplace + ModuleNew, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + { + let n: usize = 1 << params[0]; + let cols: usize = params[1]; + let size: usize = params[2]; + + let module: Module = Module::::new(n as u64); + + let basek: usize = 50; + + let mut source: Source = Source::new([0u8; 32]); + + let mut a: VecZnx> = VecZnx::alloc(n, cols, size); + let mut b: VecZnx> = VecZnx::alloc(n, cols, size); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(n * size_of::()); + + // Fill a with random i64 + a.fill_uniform(50, &mut source); + b.fill_uniform(50, &mut source); + + move || { + for i in 0..cols { + module.vec_znx_rsh_inplace(basek, basek - 1, &mut b, i, scratch.borrow()); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2])); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_vec_znx_rsh(c: &mut Criterion, label: &str) +where + Module: VecZnxRsh + ModuleNew, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + let group_name: String = format!("vec_znx_rsh::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: 
[usize; 3]) -> impl FnMut() + where + Module: VecZnxRsh + ModuleNew, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + { + let n: usize = 1 << params[0]; + let cols: usize = params[1]; + let size: usize = params[2]; + + let module: Module = Module::::new(n as u64); + + let basek: usize = 50; + + let mut source: Source = Source::new([0u8; 32]); + + let mut a: VecZnx> = VecZnx::alloc(n, cols, size); + let mut res: VecZnx> = VecZnx::alloc(n, cols, size); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(n * size_of::()); + + // Fill a with random i64 + a.fill_uniform(50, &mut source); + res.fill_uniform(50, &mut source); + + move || { + for i in 0..cols { + module.vec_znx_rsh(basek, basek - 1, &mut res, i, &a, i, scratch.borrow()); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2])); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +#[cfg(test)] +mod tests { + use crate::{ + layouts::{FillUniform, VecZnx, ZnxView}, + reference::{ + vec_znx::{ + vec_znx_copy, vec_znx_lsh, vec_znx_lsh_inplace, vec_znx_normalize_inplace, vec_znx_rsh, vec_znx_rsh_inplace, + vec_znx_sub_ab_inplace, + }, + znx::ZnxRef, + }, + source::Source, + }; + + #[test] + fn test_vec_znx_lsh() { + let n: usize = 8; + let cols: usize = 2; + let size: usize = 7; + + let mut a: VecZnx> = VecZnx::alloc(n, cols, size); + let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, size); + let mut res_test: VecZnx> = VecZnx::alloc(n, cols, size); + + let mut source: Source = Source::new([0u8; 32]); + + let mut carry: Vec = vec![0i64; n]; + + let basek: usize = 50; + + for k in 0..256 { + a.fill_uniform(50, &mut source); + + for i in 0..cols { + vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut a, i, &mut carry); + vec_znx_copy::<_, _, ZnxRef>(&mut res_ref, i, &a, i); + } + + for i in 0..cols { + vec_znx_lsh_inplace::<_, ZnxRef>(basek, k, &mut res_ref, i, &mut carry); + vec_znx_lsh::<_, _, ZnxRef>(basek, k, &mut res_test, i, &a, i, &mut carry); + vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut res_test, i, &mut carry); + } + + assert_eq!(res_ref, res_test); + } + } + + #[test] + fn test_vec_znx_rsh() { + let n: usize = 8; + let cols: usize = 2; + + let res_size: usize = 7; + + let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + let mut carry: Vec = vec![0i64; n]; + + let basek: usize = 50; + + let mut source: Source = Source::new([0u8; 32]); + + let zero: Vec = vec![0i64; n]; + + for a_size in [res_size - 1, res_size, res_size + 1] { + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + + for k in 0..res_size * basek { + a.fill_uniform(50, &mut source); + + for i in 0..cols { + vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut a, i, &mut carry); + vec_znx_copy::<_, _, ZnxRef>(&mut res_ref, i, &a, i); + } + + res_test.fill_uniform(50, &mut source); + + for j in 0..cols { + vec_znx_rsh_inplace::<_, ZnxRef>(basek, k, &mut res_ref, j, &mut carry); + vec_znx_rsh::<_, _, ZnxRef>(basek, k, &mut res_test, j, &a, j, &mut carry); + } + + for j in 0..cols { + vec_znx_lsh_inplace::<_, ZnxRef>(basek, k, &mut res_ref, j, &mut carry); + vec_znx_lsh_inplace::<_, ZnxRef>(basek, k, &mut res_test, j, &mut carry); + } + + // Case where res has enough to fully store a right shifted without any loss + // In this case we 
can check exact equality. + if a_size + k.div_ceil(basek) <= res_size { + assert_eq!(res_ref, res_test); + + for i in 0..cols { + for j in 0..a_size { + assert_eq!(res_ref.at(i, j), a.at(i, j), "r0 {} {}", i, j); + assert_eq!(res_test.at(i, j), a.at(i, j), "r1 {} {}", i, j); + } + + for j in a_size..res_size { + assert_eq!(res_ref.at(i, j), zero, "r0 {} {}", i, j); + assert_eq!(res_test.at(i, j), zero, "r1 {} {}", i, j); + } + } + // Some loss occures, either because a initially has more precision than res + // or because the storage of the right shift of a requires more precision than + // res. + } else { + for j in 0..cols { + vec_znx_sub_ab_inplace::<_, _, ZnxRef>(&mut res_ref, j, &a, j); + vec_znx_sub_ab_inplace::<_, _, ZnxRef>(&mut res_test, j, &a, j); + + vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut res_ref, j, &mut carry); + vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut res_test, j, &mut carry); + + assert!(res_ref.std(basek, j).log2() - (k as f64) <= (k * basek) as f64); + assert!(res_test.std(basek, j).log2() - (k as f64) <= (k * basek) as f64); + } + } + } + } + } +} diff --git a/poulpy-hal/src/reference/vec_znx/split_ring.rs b/poulpy-hal/src/reference/vec_znx/split_ring.rs new file mode 100644 index 0000000..adb5e13 --- /dev/null +++ b/poulpy-hal/src/reference/vec_znx/split_ring.rs @@ -0,0 +1,62 @@ +use crate::{ + layouts::{VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut}, + reference::znx::{ZnxRotate, ZnxSwitchRing, ZnxZero}, +}; + +pub fn vec_znx_split_ring_tmp_bytes(n: usize) -> usize { + n * size_of::() +} + +pub fn vec_znx_split_ring(res: &mut [R], res_col: usize, a: &A, a_col: usize, tmp: &mut [i64]) +where + R: VecZnxToMut, + A: VecZnxToRef, + ZNXARI: ZnxSwitchRing + ZnxRotate + ZnxZero, +{ + let a: VecZnx<&[u8]> = a.to_ref(); + let a_size = a.size(); + + let (n_in, n_out) = (a.n(), res[0].to_mut().n()); + + #[cfg(debug_assertions)] + { + assert_eq!(tmp.len(), a.n()); + + assert!( + n_out < n_in, + "invalid a: output ring degree should be smaller" + ); + + res[1..].iter_mut().for_each(|bi| { + assert_eq!( + bi.to_mut().n(), + n_out, + "invalid input a: all VecZnx must have the same degree" + ) + }); + + assert!(n_in.is_multiple_of(n_out)); + assert_eq!(res.len(), n_in / n_out); + } + + res.iter_mut().enumerate().for_each(|(i, bi)| { + let mut bi: VecZnx<&mut [u8]> = bi.to_mut(); + + let min_size = bi.size().min(a_size); + + if i == 0 { + for j in 0..min_size { + ZNXARI::znx_switch_ring(bi.at_mut(res_col, j), a.at(a_col, j)); + } + } else { + for j in 0..min_size { + ZNXARI::znx_rotate(-(i as i64), tmp, a.at(a_col, j)); + ZNXARI::znx_switch_ring(bi.at_mut(res_col, j), tmp); + } + } + + for j in min_size..bi.size() { + ZNXARI::znx_zero(bi.at_mut(res_col, j)); + } + }) +} diff --git a/poulpy-hal/src/reference/vec_znx/sub.rs b/poulpy-hal/src/reference/vec_znx/sub.rs new file mode 100644 index 0000000..e9341ff --- /dev/null +++ b/poulpy-hal/src/reference/vec_znx/sub.rs @@ -0,0 +1,250 @@ +use std::hint::black_box; + +use criterion::{BenchmarkId, Criterion}; + +use crate::{ + api::{ModuleNew, VecZnxSub, VecZnxSubABInplace, VecZnxSubBAInplace}, + layouts::{Backend, FillUniform, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut}, + oep::{ModuleNewImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl}, + reference::znx::{ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxSub, ZnxSubABInplace, ZnxSubBAInplace, ZnxZero}, + source::Source, +}; + +pub fn vec_znx_sub(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: 
usize) +where + R: VecZnxToMut, + A: VecZnxToRef, + B: VecZnxToRef, + ZNXARI: ZnxSub + ZnxNegate + ZnxZero + ZnxCopy, +{ + let a: VecZnx<&[u8]> = a.to_ref(); + let b: VecZnx<&[u8]> = b.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), res.n()); + assert_eq!(b.n(), res.n()); + } + + let res_size: usize = res.size(); + let a_size: usize = a.size(); + let b_size: usize = b.size(); + + if a_size <= b_size { + let sum_size: usize = a_size.min(res_size); + let cpy_size: usize = b_size.min(res_size); + + for j in 0..sum_size { + ZNXARI::znx_sub(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j)); + } + + for j in sum_size..cpy_size { + ZNXARI::znx_negate(res.at_mut(res_col, j), b.at(b_col, j)); + } + + for j in cpy_size..res_size { + ZNXARI::znx_zero(res.at_mut(res_col, j)); + } + } else { + let sum_size: usize = b_size.min(res_size); + let cpy_size: usize = a_size.min(res_size); + + for j in 0..sum_size { + ZNXARI::znx_sub(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j)); + } + + for j in sum_size..cpy_size { + ZNXARI::znx_copy(res.at_mut(res_col, j), a.at(a_col, j)); + } + + for j in cpy_size..res_size { + ZNXARI::znx_zero(res.at_mut(res_col, j)); + } + } +} + +pub fn vec_znx_sub_ab_inplace(res: &mut R, res_col: usize, a: &A, a_col: usize) +where + R: VecZnxToMut, + A: VecZnxToRef, + ZNXARI: ZnxSubABInplace, +{ + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), res.n()); + } + + let res_size: usize = res.size(); + let a_size: usize = a.size(); + + let sum_size: usize = a_size.min(res_size); + + for j in 0..sum_size { + ZNXARI::znx_sub_ab_inplace(res.at_mut(res_col, j), a.at(a_col, j)); + } +} + +pub fn vec_znx_sub_ba_inplace(res: &mut R, res_col: usize, a: &A, a_col: usize) +where + R: VecZnxToMut, + A: VecZnxToRef, + ZNXARI: ZnxSubBAInplace + ZnxNegateInplace, +{ + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), res.n()); + } + + let res_size: usize = res.size(); + let a_size: usize = a.size(); + + let sum_size: usize = a_size.min(res_size); + + for j in 0..sum_size { + ZNXARI::znx_sub_ba_inplace(res.at_mut(res_col, j), a.at(a_col, j)); + } + + for j in sum_size..res_size { + ZNXARI::znx_negate_inplace(res.at_mut(res_col, j)); + } +} + +pub fn bench_vec_znx_sub(c: &mut Criterion, label: &str) +where + B: Backend + ModuleNewImpl + VecZnxSubImpl, +{ + let group_name: String = format!("vec_znx_sub::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: VecZnxSub + ModuleNew, + { + let n: usize = 1 << params[0]; + let cols: usize = params[1]; + let size: usize = params[2]; + + let module: Module = Module::::new(n as u64); + + let mut source: Source = Source::new([0u8; 32]); + + let mut a: VecZnx> = VecZnx::alloc(n, cols, size); + let mut b: VecZnx> = VecZnx::alloc(n, cols, size); + let mut c: VecZnx> = VecZnx::alloc(n, cols, size); + + // Fill a with random i64 + a.fill_uniform(50, &mut source); + b.fill_uniform(50, &mut source); + + move || { + for i in 0..cols { + module.vec_znx_sub(&mut c, i, &a, i, &b, i); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],)); + let mut runner = runner::(params); + 
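+        // The closure returned by `runner` owns the module and the pre-filled operands, so the
+        // measured loop only times the `vec_znx_sub` calls, not allocation or sampling.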
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_vec_znx_sub_ab_inplace(c: &mut Criterion, label: &str) +where + B: Backend + ModuleNewImpl + VecZnxSubABInplaceImpl, +{ + let group_name: String = format!("vec_znx_sub_ab_inplace::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: VecZnxSubABInplace + ModuleNew, + { + let n: usize = 1 << params[0]; + let cols: usize = params[1]; + let size: usize = params[2]; + + let module: Module = Module::::new(n as u64); + + let mut source: Source = Source::new([0u8; 32]); + + let mut a: VecZnx> = VecZnx::alloc(n, cols, size); + let mut b: VecZnx> = VecZnx::alloc(n, cols, size); + + // Fill a with random i64 + a.fill_uniform(50, &mut source); + b.fill_uniform(50, &mut source); + + move || { + for i in 0..cols { + module.vec_znx_sub_ab_inplace(&mut b, i, &a, i); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2])); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +pub fn bench_vec_znx_sub_ba_inplace(c: &mut Criterion, label: &str) +where + B: Backend + ModuleNewImpl + VecZnxSubBAInplaceImpl, +{ + let group_name: String = format!("vec_znx_sub_ba_inplace::{}", label); + + let mut group = c.benchmark_group(group_name); + + fn runner(params: [usize; 3]) -> impl FnMut() + where + Module: VecZnxSubBAInplace + ModuleNew, + { + let n: usize = 1 << params[0]; + let cols: usize = params[1]; + let size: usize = params[2]; + + let module: Module = Module::::new(n as u64); + + let mut source: Source = Source::new([0u8; 32]); + + let mut a: VecZnx> = VecZnx::alloc(n, cols, size); + let mut b: VecZnx> = VecZnx::alloc(n, cols, size); + + // Fill a with random i64 + a.fill_uniform(50, &mut source); + b.fill_uniform(50, &mut source); + + move || { + for i in 0..cols { + module.vec_znx_sub_ba_inplace(&mut b, i, &a, i); + } + black_box(()); + } + } + + for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] { + let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2])); + let mut runner = runner::(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} diff --git a/poulpy-hal/src/reference/vec_znx/sub_scalar.rs b/poulpy-hal/src/reference/vec_znx/sub_scalar.rs new file mode 100644 index 0000000..04d405e --- /dev/null +++ b/poulpy-hal/src/reference/vec_znx/sub_scalar.rs @@ -0,0 +1,58 @@ +use crate::layouts::{ScalarZnxToRef, VecZnxToMut, VecZnxToRef}; +use crate::{ + layouts::{ScalarZnx, VecZnx, ZnxInfos, ZnxView, ZnxViewMut}, + reference::znx::{ZnxSub, ZnxSubABInplace, ZnxZero}, +}; + +pub fn vec_znx_sub_scalar(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize, b_limb: usize) +where + R: VecZnxToMut, + A: ScalarZnxToRef, + B: VecZnxToRef, + ZNXARI: ZnxSub + ZnxZero, +{ + let a: ScalarZnx<&[u8]> = a.to_ref(); + let b: VecZnx<&[u8]> = b.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + + let min_size: usize = b.size().min(res.size()); + + #[cfg(debug_assertions)] + { + assert!( + b_limb < min_size, + "b_limb: {} > min_size: {}", + b_limb, + min_size + ); + } + + for j in 0..min_size { + if j == b_limb { + ZNXARI::znx_sub(res.at_mut(res_col, j), 
b.at(b_col, j), a.at(a_col, 0)); + } else { + res.at_mut(res_col, j).copy_from_slice(b.at(b_col, j)); + } + } + + for j in min_size..res.size() { + ZNXARI::znx_zero(res.at_mut(res_col, j)); + } +} + +pub fn vec_znx_sub_scalar_inplace(res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize) +where + R: VecZnxToMut, + A: ScalarZnxToRef, + ZNXARI: ZnxSubABInplace, +{ + let a: ScalarZnx<&[u8]> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert!(res_limb < res.size()); + } + + ZNXARI::znx_sub_ab_inplace(res.at_mut(res_col, res_limb), a.at(a_col, 0)); +} diff --git a/poulpy-hal/src/reference/vec_znx/switch_ring.rs b/poulpy-hal/src/reference/vec_znx/switch_ring.rs new file mode 100644 index 0000000..c275886 --- /dev/null +++ b/poulpy-hal/src/reference/vec_znx/switch_ring.rs @@ -0,0 +1,37 @@ +use crate::{ + layouts::{VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut}, + reference::{ + vec_znx::vec_znx_copy, + znx::{ZnxCopy, ZnxSwitchRing, ZnxZero}, + }, +}; + +/// Maps between negacyclic rings by changing the polynomial degree. +/// Up: Z[X]/(X^N+1) -> Z[X]/(X^{2^d N}+1) via X ↦ X^{2^d} +/// Down: Z[X]/(X^N+1) -> Z[X]/(X^{N/2^d}+1) by folding indices. +pub fn vec_znx_switch_ring(res: &mut R, res_col: usize, a: &A, a_col: usize) +where + R: VecZnxToMut, + A: VecZnxToRef, + ZNXARI: ZnxCopy + ZnxSwitchRing + ZnxZero, +{ + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + + let (n_in, n_out) = (a.n(), res.n()); + + if n_in == n_out { + vec_znx_copy::<_, _, ZNXARI>(&mut res, res_col, &a, a_col); + return; + } + + let min_size: usize = a.size().min(res.size()); + + for j in 0..min_size { + ZNXARI::znx_switch_ring(res.at_mut(res_col, j), a.at(a_col, j)); + } + + for j in min_size..res.size() { + ZNXARI::znx_zero(res.at_mut(res_col, j)); + } +} diff --git a/poulpy-hal/src/reference/zn/mod.rs b/poulpy-hal/src/reference/zn/mod.rs new file mode 100644 index 0000000..d4838c3 --- /dev/null +++ b/poulpy-hal/src/reference/zn/mod.rs @@ -0,0 +1,5 @@ +mod normalization; +mod sampling; + +pub use normalization::*; +pub use sampling::*; diff --git a/poulpy-hal/src/reference/zn/normalization.rs b/poulpy-hal/src/reference/zn/normalization.rs new file mode 100644 index 0000000..4412369 --- /dev/null +++ b/poulpy-hal/src/reference/zn/normalization.rs @@ -0,0 +1,72 @@ +use crate::{ + api::{ScratchOwnedAlloc, ScratchOwnedBorrow, ZnNormalizeInplace, ZnNormalizeTmpBytes}, + layouts::{Backend, Module, ScratchOwned, Zn, ZnToMut, ZnxInfos, ZnxView, ZnxViewMut}, + reference::znx::{ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStepInplace, ZnxRef}, + source::Source, +}; + +pub fn zn_normalize_tmp_bytes(n: usize) -> usize { + n * size_of::() +} + +pub fn zn_normalize_inplace(n: usize, basek: usize, res: &mut R, res_col: usize, carry: &mut [i64]) +where + R: ZnToMut, + ARI: ZnxNormalizeFirstStepInplace + ZnxNormalizeFinalStepInplace + ZnxNormalizeMiddleStepInplace, +{ + let mut res: Zn<&mut [u8]> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(carry.len(), res.n()); + } + + let res_size: usize = res.size(); + + for j in (0..res_size).rev() { + let out = &mut res.at_mut(res_col, j)[..n]; + + if j == res_size - 1 { + ARI::znx_normalize_first_step_inplace(basek, 0, out, carry); + } else if j == 0 { + ARI::znx_normalize_final_step_inplace(basek, 0, out, carry); + } else { + ARI::znx_normalize_middle_step_inplace(basek, 0, out, carry); + } + } +} + +pub fn 
test_zn_normalize_inplace(module: &Module) +where + Module: ZnNormalizeInplace + ZnNormalizeTmpBytes, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + let basek: usize = 12; + + let n = 33; + + let mut carry: Vec = vec![0i64; zn_normalize_tmp_bytes(n)]; + + let mut scratch: ScratchOwned = ScratchOwned::alloc(module.zn_normalize_tmp_bytes(module.n())); + + for res_size in [1, 2, 6, 11] { + let mut res_0: Zn> = Zn::alloc(n, cols, res_size); + let mut res_1: Zn> = Zn::alloc(n, cols, res_size); + + res_0 + .raw_mut() + .iter_mut() + .for_each(|x| *x = source.next_i32() as i64); + res_1.raw_mut().copy_from_slice(res_0.raw()); + + // Reference + for i in 0..cols { + zn_normalize_inplace::<_, ZnxRef>(n, basek, &mut res_0, i, &mut carry); + module.zn_normalize_inplace(n, basek, &mut res_1, i, scratch.borrow()); + } + + assert_eq!(res_0.raw(), res_1.raw()); + } +} diff --git a/poulpy-hal/src/reference/zn/sampling.rs b/poulpy-hal/src/reference/zn/sampling.rs new file mode 100644 index 0000000..9c46f7a --- /dev/null +++ b/poulpy-hal/src/reference/zn/sampling.rs @@ -0,0 +1,75 @@ +use crate::{ + layouts::{Zn, ZnToMut, ZnxInfos, ZnxViewMut}, + reference::znx::{znx_add_normal_f64_ref, znx_fill_normal_f64_ref, znx_fill_uniform_ref}, + source::Source, +}; + +pub fn zn_fill_uniform(n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source) +where + R: ZnToMut, +{ + let mut res: Zn<&mut [u8]> = res.to_mut(); + for j in 0..res.size() { + znx_fill_uniform_ref(basek, &mut res.at_mut(res_col, j)[..n], source) + } +} + +#[allow(clippy::too_many_arguments)] +pub fn zn_fill_normal( + n: usize, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + sigma: f64, + bound: f64, +) where + R: ZnToMut, +{ + let mut res: Zn<&mut [u8]> = res.to_mut(); + assert!( + (bound.log2().ceil() as i64) < 64, + "invalid bound: ceil(log2(bound))={} > 63", + (bound.log2().ceil() as i64) + ); + + let limb: usize = k.div_ceil(basek) - 1; + let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64; + znx_fill_normal_f64_ref( + &mut res.at_mut(res_col, limb)[..n], + sigma * scale, + bound * scale, + source, + ) +} + +#[allow(clippy::too_many_arguments)] +pub fn zn_add_normal( + n: usize, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + sigma: f64, + bound: f64, +) where + R: ZnToMut, +{ + let mut res: Zn<&mut [u8]> = res.to_mut(); + assert!( + (bound.log2().ceil() as i64) < 64, + "invalid bound: ceil(log2(bound))={} > 63", + (bound.log2().ceil() as i64) + ); + + let limb: usize = k.div_ceil(basek) - 1; + let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64; + znx_add_normal_f64_ref( + &mut res.at_mut(res_col, limb)[..n], + sigma * scale, + bound * scale, + source, + ) +} diff --git a/poulpy-hal/src/reference/znx/add.rs b/poulpy-hal/src/reference/znx/add.rs new file mode 100644 index 0000000..55e0a1e --- /dev/null +++ b/poulpy-hal/src/reference/znx/add.rs @@ -0,0 +1,25 @@ +#[inline(always)] +pub fn znx_add_ref(res: &mut [i64], a: &[i64], b: &[i64]) { + #[cfg(debug_assertions)] + { + assert_eq!(res.len(), a.len()); + assert_eq!(res.len(), b.len()); + } + + let n: usize = res.len(); + for i in 0..n { + res[i] = a[i] + b[i]; + } +} + +pub fn znx_add_inplace_ref(res: &mut [i64], a: &[i64]) { + #[cfg(debug_assertions)] + { + assert_eq!(res.len(), a.len()); + } + + let n: usize = res.len(); + for i in 0..n { + res[i] += a[i]; + } +} diff --git 
a/poulpy-hal/src/reference/znx/arithmetic_ref.rs b/poulpy-hal/src/reference/znx/arithmetic_ref.rs new file mode 100644 index 0000000..ba21ede --- /dev/null +++ b/poulpy-hal/src/reference/znx/arithmetic_ref.rs @@ -0,0 +1,153 @@ +use crate::reference::znx::{ + ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxNormalizeFinalStep, + ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeFirstStepInplace, + ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, ZnxSub, ZnxSubABInplace, + ZnxSubBAInplace, ZnxSwitchRing, ZnxZero, + add::{znx_add_inplace_ref, znx_add_ref}, + automorphism::znx_automorphism_ref, + copy::znx_copy_ref, + neg::{znx_negate_inplace_ref, znx_negate_ref}, + normalization::{ + znx_normalize_final_step_inplace_ref, znx_normalize_final_step_ref, znx_normalize_first_step_carry_only_ref, + znx_normalize_first_step_inplace_ref, znx_normalize_first_step_ref, znx_normalize_middle_step_carry_only_ref, + znx_normalize_middle_step_inplace_ref, znx_normalize_middle_step_ref, + }, + sub::{znx_sub_ab_inplace_ref, znx_sub_ba_inplace_ref, znx_sub_ref}, + switch_ring::znx_switch_ring_ref, + zero::znx_zero_ref, +}; + +pub struct ZnxRef {} + +impl ZnxAdd for ZnxRef { + #[inline(always)] + fn znx_add(res: &mut [i64], a: &[i64], b: &[i64]) { + znx_add_ref(res, a, b); + } +} + +impl ZnxAddInplace for ZnxRef { + #[inline(always)] + fn znx_add_inplace(res: &mut [i64], a: &[i64]) { + znx_add_inplace_ref(res, a); + } +} + +impl ZnxSub for ZnxRef { + #[inline(always)] + fn znx_sub(res: &mut [i64], a: &[i64], b: &[i64]) { + znx_sub_ref(res, a, b); + } +} + +impl ZnxSubABInplace for ZnxRef { + #[inline(always)] + fn znx_sub_ab_inplace(res: &mut [i64], a: &[i64]) { + znx_sub_ab_inplace_ref(res, a); + } +} + +impl ZnxSubBAInplace for ZnxRef { + #[inline(always)] + fn znx_sub_ba_inplace(res: &mut [i64], a: &[i64]) { + znx_sub_ba_inplace_ref(res, a); + } +} + +impl ZnxAutomorphism for ZnxRef { + #[inline(always)] + fn znx_automorphism(p: i64, res: &mut [i64], a: &[i64]) { + znx_automorphism_ref(p, res, a); + } +} + +impl ZnxCopy for ZnxRef { + #[inline(always)] + fn znx_copy(res: &mut [i64], a: &[i64]) { + znx_copy_ref(res, a); + } +} + +impl ZnxNegate for ZnxRef { + #[inline(always)] + fn znx_negate(res: &mut [i64], src: &[i64]) { + znx_negate_ref(res, src); + } +} + +impl ZnxNegateInplace for ZnxRef { + #[inline(always)] + fn znx_negate_inplace(res: &mut [i64]) { + znx_negate_inplace_ref(res); + } +} + +impl ZnxZero for ZnxRef { + #[inline(always)] + fn znx_zero(res: &mut [i64]) { + znx_zero_ref(res); + } +} + +impl ZnxSwitchRing for ZnxRef { + #[inline(always)] + fn znx_switch_ring(res: &mut [i64], a: &[i64]) { + znx_switch_ring_ref(res, a); + } +} + +impl ZnxNormalizeFinalStep for ZnxRef { + #[inline(always)] + fn znx_normalize_final_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { + znx_normalize_final_step_ref(basek, lsh, x, a, carry); + } +} + +impl ZnxNormalizeFinalStepInplace for ZnxRef { + #[inline(always)] + fn znx_normalize_final_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { + znx_normalize_final_step_inplace_ref(basek, lsh, x, carry); + } +} + +impl ZnxNormalizeFirstStep for ZnxRef { + #[inline(always)] + fn znx_normalize_first_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { + znx_normalize_first_step_ref(basek, lsh, x, a, carry); + } +} + +impl ZnxNormalizeFirstStepCarryOnly for ZnxRef { + 
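+    // `ZnxRef` is a zero-sized dispatch type: the generic reference kernels take a
+    // `ZNXARI` type parameter bounded by these `Znx*` traits and are instantiated with
+    // a turbofish, e.g. `vec_znx_sub_ab_inplace::<_, _, ZnxRef>(&mut res, j, &a, j)`,
+    // so another dispatcher (for instance an AVX-backed one, if provided) can be
+    // plugged in without changing call sites.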
#[inline(always)] + fn znx_normalize_first_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) { + znx_normalize_first_step_carry_only_ref(basek, lsh, x, carry); + } +} + +impl ZnxNormalizeFirstStepInplace for ZnxRef { + #[inline(always)] + fn znx_normalize_first_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { + znx_normalize_first_step_inplace_ref(basek, lsh, x, carry); + } +} + +impl ZnxNormalizeMiddleStep for ZnxRef { + #[inline(always)] + fn znx_normalize_middle_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { + znx_normalize_middle_step_ref(basek, lsh, x, a, carry); + } +} + +impl ZnxNormalizeMiddleStepCarryOnly for ZnxRef { + #[inline(always)] + fn znx_normalize_middle_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) { + znx_normalize_middle_step_carry_only_ref(basek, lsh, x, carry); + } +} + +impl ZnxNormalizeMiddleStepInplace for ZnxRef { + #[inline(always)] + fn znx_normalize_middle_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { + znx_normalize_middle_step_inplace_ref(basek, lsh, x, carry); + } +} diff --git a/poulpy-hal/src/reference/znx/automorphism.rs b/poulpy-hal/src/reference/znx/automorphism.rs new file mode 100644 index 0000000..6ef9e84 --- /dev/null +++ b/poulpy-hal/src/reference/znx/automorphism.rs @@ -0,0 +1,21 @@ +pub fn znx_automorphism_ref(p: i64, res: &mut [i64], a: &[i64]) { + #[cfg(debug_assertions)] + { + assert_eq!(res.len(), a.len()); + } + + let n: usize = res.len(); + let mut k: usize = 0usize; + let mask: usize = 2 * n - 1; + let p_2n = (p & mask as i64) as usize; + + res[0] = a[0]; + for ai in a.iter().take(n).skip(1) { + k = (k + p_2n) & mask; + if k < n { + res[k] = *ai + } else { + res[k - n] = -*ai + } + } +} diff --git a/poulpy-hal/src/reference/znx/copy.rs b/poulpy-hal/src/reference/znx/copy.rs new file mode 100644 index 0000000..760d162 --- /dev/null +++ b/poulpy-hal/src/reference/znx/copy.rs @@ -0,0 +1,8 @@ +#[inline(always)] +pub fn znx_copy_ref(res: &mut [i64], a: &[i64]) { + #[cfg(debug_assertions)] + { + assert_eq!(res.len(), a.len()) + } + res.copy_from_slice(a); +} diff --git a/poulpy-hal/src/reference/znx/mod.rs b/poulpy-hal/src/reference/znx/mod.rs new file mode 100644 index 0000000..9659e7d --- /dev/null +++ b/poulpy-hal/src/reference/znx/mod.rs @@ -0,0 +1,104 @@ +mod add; +mod arithmetic_ref; +mod automorphism; +mod copy; +mod neg; +mod normalization; +mod rotate; +mod sampling; +mod sub; +mod switch_ring; +mod zero; + +pub use add::*; +pub use arithmetic_ref::*; +pub use automorphism::*; +pub use copy::*; +pub use neg::*; +pub use normalization::*; +pub use rotate::*; +pub use sub::*; +pub use switch_ring::*; +pub use zero::*; + +pub use sampling::*; + +pub trait ZnxAdd { + fn znx_add(res: &mut [i64], a: &[i64], b: &[i64]); +} + +pub trait ZnxAddInplace { + fn znx_add_inplace(res: &mut [i64], a: &[i64]); +} + +pub trait ZnxSub { + fn znx_sub(res: &mut [i64], a: &[i64], b: &[i64]); +} + +pub trait ZnxSubABInplace { + fn znx_sub_ab_inplace(res: &mut [i64], a: &[i64]); +} + +pub trait ZnxSubBAInplace { + fn znx_sub_ba_inplace(res: &mut [i64], a: &[i64]); +} + +pub trait ZnxAutomorphism { + fn znx_automorphism(p: i64, res: &mut [i64], a: &[i64]); +} + +pub trait ZnxCopy { + fn znx_copy(res: &mut [i64], a: &[i64]); +} + +pub trait ZnxNegate { + fn znx_negate(res: &mut [i64], src: &[i64]); +} + +pub trait ZnxNegateInplace { + fn znx_negate_inplace(res: &mut [i64]); +} + +pub trait ZnxRotate { + fn znx_rotate(p: i64, res: 
&mut [i64], src: &[i64]); +} + +pub trait ZnxZero { + fn znx_zero(res: &mut [i64]); +} + +pub trait ZnxSwitchRing { + fn znx_switch_ring(res: &mut [i64], a: &[i64]); +} + +pub trait ZnxNormalizeFirstStepCarryOnly { + fn znx_normalize_first_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]); +} + +pub trait ZnxNormalizeFirstStepInplace { + fn znx_normalize_first_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]); +} + +pub trait ZnxNormalizeFirstStep { + fn znx_normalize_first_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]); +} + +pub trait ZnxNormalizeMiddleStepCarryOnly { + fn znx_normalize_middle_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]); +} + +pub trait ZnxNormalizeMiddleStepInplace { + fn znx_normalize_middle_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]); +} + +pub trait ZnxNormalizeMiddleStep { + fn znx_normalize_middle_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]); +} + +pub trait ZnxNormalizeFinalStepInplace { + fn znx_normalize_final_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]); +} + +pub trait ZnxNormalizeFinalStep { + fn znx_normalize_final_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]); +} diff --git a/poulpy-hal/src/reference/znx/neg.rs b/poulpy-hal/src/reference/znx/neg.rs new file mode 100644 index 0000000..f88df9b --- /dev/null +++ b/poulpy-hal/src/reference/znx/neg.rs @@ -0,0 +1,18 @@ +#[inline(always)] +pub fn znx_negate_ref(res: &mut [i64], src: &[i64]) { + #[cfg(debug_assertions)] + { + assert_eq!(res.len(), src.len()) + } + + for i in 0..res.len() { + res[i] = -src[i] + } +} + +#[inline(always)] +pub fn znx_negate_inplace_ref(res: &mut [i64]) { + for value in res { + *value = -*value + } +} diff --git a/poulpy-hal/src/reference/znx/normalization.rs b/poulpy-hal/src/reference/znx/normalization.rs new file mode 100644 index 0000000..e9f57cf --- /dev/null +++ b/poulpy-hal/src/reference/znx/normalization.rs @@ -0,0 +1,199 @@ +use itertools::izip; + +#[inline(always)] +pub fn get_digit(basek: usize, x: i64) -> i64 { + (x << (u64::BITS - basek as u32)) >> (u64::BITS - basek as u32) +} + +#[inline(always)] +pub fn get_carry(basek: usize, x: i64, digit: i64) -> i64 { + (x.wrapping_sub(digit)) >> basek +} + +#[inline(always)] +pub fn znx_normalize_first_step_carry_only_ref(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) { + #[cfg(debug_assertions)] + { + assert!(x.len() <= carry.len()); + assert!(lsh < basek); + } + + if lsh == 0 { + x.iter().zip(carry.iter_mut()).for_each(|(x, c)| { + *c = get_carry(basek, *x, get_digit(basek, *x)); + }); + } else { + let basek_lsh: usize = basek - lsh; + x.iter().zip(carry.iter_mut()).for_each(|(x, c)| { + *c = get_carry(basek_lsh, *x, get_digit(basek_lsh, *x)); + }); + } +} + +#[inline(always)] +pub fn znx_normalize_first_step_inplace_ref(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { + #[cfg(debug_assertions)] + { + assert!(x.len() <= carry.len()); + assert!(lsh < basek); + } + + if lsh == 0 { + x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| { + let digit: i64 = get_digit(basek, *x); + *c = get_carry(basek, *x, digit); + *x = digit; + }); + } else { + let basek_lsh: usize = basek - lsh; + x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| { + let digit: i64 = get_digit(basek_lsh, *x); + *c = get_carry(basek_lsh, *x, digit); + *x = digit << lsh; + }); + } +} + +#[inline(always)] +pub fn 
znx_normalize_first_step_ref(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { + #[cfg(debug_assertions)] + { + assert_eq!(x.len(), a.len()); + assert!(x.len() <= carry.len()); + assert!(lsh < basek); + } + + if lsh == 0 { + izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| { + let digit: i64 = get_digit(basek, *a); + *c = get_carry(basek, *a, digit); + *x = digit; + }); + } else { + let basek_lsh: usize = basek - lsh; + izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| { + let digit: i64 = get_digit(basek_lsh, *a); + *c = get_carry(basek_lsh, *a, digit); + *x = digit << lsh; + }); + } +} + +#[inline(always)] +pub fn znx_normalize_middle_step_carry_only_ref(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) { + #[cfg(debug_assertions)] + { + assert!(x.len() <= carry.len()); + assert!(lsh < basek); + } + if lsh == 0 { + x.iter().zip(carry.iter_mut()).for_each(|(x, c)| { + let digit: i64 = get_digit(basek, *x); + let carry: i64 = get_carry(basek, *x, digit); + let digit_plus_c: i64 = digit + *c; + *c = carry + get_carry(basek, digit_plus_c, get_digit(basek, digit_plus_c)); + }); + } else { + let basek_lsh: usize = basek - lsh; + x.iter().zip(carry.iter_mut()).for_each(|(x, c)| { + let digit: i64 = get_digit(basek_lsh, *x); + let carry: i64 = get_carry(basek_lsh, *x, digit); + let digit_plus_c: i64 = (digit << lsh) + *c; + *c = carry + get_carry(basek, digit_plus_c, get_digit(basek, digit_plus_c)); + }); + } +} + +#[inline(always)] +pub fn znx_normalize_middle_step_inplace_ref(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { + #[cfg(debug_assertions)] + { + assert!(x.len() <= carry.len()); + assert!(lsh < basek); + } + if lsh == 0 { + x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| { + let digit: i64 = get_digit(basek, *x); + let carry: i64 = get_carry(basek, *x, digit); + let digit_plus_c: i64 = digit + *c; + *x = get_digit(basek, digit_plus_c); + *c = carry + get_carry(basek, digit_plus_c, *x); + }); + } else { + let basek_lsh: usize = basek - lsh; + x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| { + let digit: i64 = get_digit(basek_lsh, *x); + let carry: i64 = get_carry(basek_lsh, *x, digit); + let digit_plus_c: i64 = (digit << lsh) + *c; + *x = get_digit(basek, digit_plus_c); + *c = carry + get_carry(basek, digit_plus_c, *x); + }); + } +} + +#[inline(always)] +pub fn znx_normalize_middle_step_ref(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { + #[cfg(debug_assertions)] + { + assert_eq!(x.len(), a.len()); + assert!(x.len() <= carry.len()); + assert!(lsh < basek); + } + if lsh == 0 { + izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| { + let digit: i64 = get_digit(basek, *a); + let carry: i64 = get_carry(basek, *a, digit); + let digit_plus_c: i64 = digit + *c; + *x = get_digit(basek, digit_plus_c); + *c = carry + get_carry(basek, digit_plus_c, *x); + }); + } else { + let basek_lsh: usize = basek - lsh; + izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| { + let digit: i64 = get_digit(basek_lsh, *a); + let carry: i64 = get_carry(basek_lsh, *a, digit); + let digit_plus_c: i64 = (digit << lsh) + *c; + *x = get_digit(basek, digit_plus_c); + *c = carry + get_carry(basek, digit_plus_c, *x); + }); + } +} + +#[inline(always)] +pub fn znx_normalize_final_step_inplace_ref(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) { + #[cfg(debug_assertions)] + { + assert!(x.len() <= carry.len()); + assert!(lsh < basek); + } + + if lsh == 0 { + 
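+        // lsh == 0: fold the incoming carry into the limb and re-reduce to a balanced
+        // base-2^basek digit; the else-branch below re-extracts the digit w.r.t.
+        // basek - lsh and shifts it left by lsh before adding the carry.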
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| { + *x = get_digit(basek, get_digit(basek, *x) + *c); + }); + } else { + let basek_lsh: usize = basek - lsh; + x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| { + *x = get_digit(basek, (get_digit(basek_lsh, *x) << lsh) + *c); + }); + } +} + +#[inline(always)] +pub fn znx_normalize_final_step_ref(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) { + #[cfg(debug_assertions)] + { + assert!(x.len() <= carry.len()); + assert!(lsh < basek); + } + if lsh == 0 { + izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| { + *x = get_digit(basek, get_digit(basek, *a) + *c); + }); + } else { + let basek_lsh: usize = basek - lsh; + izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| { + *x = get_digit(basek, (get_digit(basek_lsh, *a) << lsh) + *c); + }); + } +} diff --git a/poulpy-hal/src/reference/znx/rotate.rs b/poulpy-hal/src/reference/znx/rotate.rs new file mode 100644 index 0000000..2831dd0 --- /dev/null +++ b/poulpy-hal/src/reference/znx/rotate.rs @@ -0,0 +1,26 @@ +use crate::reference::znx::{ZnxCopy, ZnxNegate}; + +pub fn znx_rotate(p: i64, res: &mut [i64], src: &[i64]) { + #[cfg(debug_assertions)] + { + assert_eq!(res.len(), src.len()); + } + + let n: usize = res.len(); + + let mp_2n: usize = (p & (2 * n as i64 - 1)) as usize; // -p % 2n + let mp_1n: usize = mp_2n & (n - 1); // -p % n + let mp_1n_neg: usize = n - mp_1n; // p % n + let neg_first: bool = mp_2n < n; + + let (dst1, dst2) = res.split_at_mut(mp_1n); + let (src1, src2) = src.split_at(mp_1n_neg); + + if neg_first { + ZNXARI::znx_negate(dst1, src2); + ZNXARI::znx_copy(dst2, src1); + } else { + ZNXARI::znx_copy(dst1, src2); + ZNXARI::znx_negate(dst2, src1); + } +} diff --git a/poulpy-hal/src/reference/znx/sampling.rs b/poulpy-hal/src/reference/znx/sampling.rs new file mode 100644 index 0000000..feaa393 --- /dev/null +++ b/poulpy-hal/src/reference/znx/sampling.rs @@ -0,0 +1,53 @@ +use rand_distr::{Distribution, Normal}; + +use crate::source::Source; + +pub fn znx_fill_uniform_ref(basek: usize, res: &mut [i64], source: &mut Source) { + let pow2k: u64 = 1 << basek; + let mask: u64 = pow2k - 1; + let pow2k_half: i64 = (pow2k >> 1) as i64; + res.iter_mut() + .for_each(|xi| *xi = (source.next_u64n(pow2k, mask) as i64) - pow2k_half) +} + +pub fn znx_fill_dist_f64_ref>(res: &mut [i64], dist: D, bound: f64, source: &mut Source) { + res.iter_mut().for_each(|xi| { + let mut dist_f64: f64 = dist.sample(source); + while dist_f64.abs() > bound { + dist_f64 = dist.sample(source) + } + *xi = dist_f64.round() as i64 + }) +} + +pub fn znx_add_dist_f64_ref>(res: &mut [i64], dist: D, bound: f64, source: &mut Source) { + res.iter_mut().for_each(|xi| { + let mut dist_f64: f64 = dist.sample(source); + while dist_f64.abs() > bound { + dist_f64 = dist.sample(source) + } + *xi += dist_f64.round() as i64 + }) +} + +pub fn znx_fill_normal_f64_ref(res: &mut [i64], sigma: f64, bound: f64, source: &mut Source) { + let normal: Normal = Normal::new(0.0, sigma).unwrap(); + res.iter_mut().for_each(|xi| { + let mut dist_f64: f64 = normal.sample(source); + while dist_f64.abs() > bound { + dist_f64 = normal.sample(source) + } + *xi = dist_f64.round() as i64 + }) +} + +pub fn znx_add_normal_f64_ref(res: &mut [i64], sigma: f64, bound: f64, source: &mut Source) { + let normal: Normal = Normal::new(0.0, sigma).unwrap(); + res.iter_mut().for_each(|xi| { + let mut dist_f64: f64 = normal.sample(source); + while dist_f64.abs() > bound { + dist_f64 = normal.sample(source) + } 
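
The `znx_normalize_*_step` helpers above all build on the same balanced digit/carry split: `get_digit` keeps the low `basek` bits of a limb as a signed value in `[-2^(basek-1), 2^(basek-1))` and `get_carry` forwards the remainder to the next limb. A self-contained sketch of that invariant (std only, illustrative, not part of the crate):

```rust
fn get_digit(basek: usize, x: i64) -> i64 {
    // Sign-extended low `basek` bits of x.
    (x << (u64::BITS - basek as u32)) >> (u64::BITS - basek as u32)
}

fn get_carry(basek: usize, x: i64, digit: i64) -> i64 {
    // What remains after removing the digit, shifted down to the next limb.
    x.wrapping_sub(digit) >> basek
}

fn main() {
    let basek = 12usize;
    for x in [0i64, 1, -1, 4095, 4096, -4096, 123_456_789, -987_654_321] {
        let d = get_digit(basek, x);
        let c = get_carry(basek, x, d);
        // The split is exact and the digit is balanced.
        assert_eq!(d + (c << basek), x);
        assert!((-(1i64 << (basek - 1))..(1i64 << (basek - 1))).contains(&d));
        println!("x = {x:>12}, digit = {d:>6}, carry = {c:>9}");
    }
}
```
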
+ *xi += dist_f64.round() as i64 + }) +} diff --git a/poulpy-hal/src/reference/znx/sub.rs b/poulpy-hal/src/reference/znx/sub.rs new file mode 100644 index 0000000..7cb4599 --- /dev/null +++ b/poulpy-hal/src/reference/znx/sub.rs @@ -0,0 +1,36 @@ +pub fn znx_sub_ref(res: &mut [i64], a: &[i64], b: &[i64]) { + #[cfg(debug_assertions)] + { + assert_eq!(res.len(), a.len()); + assert_eq!(res.len(), b.len()); + } + + let n: usize = res.len(); + for i in 0..n { + res[i] = a[i] - b[i]; + } +} + +pub fn znx_sub_ab_inplace_ref(res: &mut [i64], a: &[i64]) { + #[cfg(debug_assertions)] + { + assert_eq!(res.len(), a.len()); + } + + let n: usize = res.len(); + for i in 0..n { + res[i] -= a[i]; + } +} + +pub fn znx_sub_ba_inplace_ref(res: &mut [i64], a: &[i64]) { + #[cfg(debug_assertions)] + { + assert_eq!(res.len(), a.len()); + } + + let n: usize = res.len(); + for i in 0..n { + res[i] = a[i] - res[i]; + } +} diff --git a/poulpy-hal/src/reference/znx/switch_ring.rs b/poulpy-hal/src/reference/znx/switch_ring.rs new file mode 100644 index 0000000..52a750f --- /dev/null +++ b/poulpy-hal/src/reference/znx/switch_ring.rs @@ -0,0 +1,29 @@ +use crate::reference::znx::{copy::znx_copy_ref, zero::znx_zero_ref}; + +pub fn znx_switch_ring_ref(res: &mut [i64], a: &[i64]) { + let (n_in, n_out) = (a.len(), res.len()); + + #[cfg(debug_assertions)] + { + assert!(n_in.is_power_of_two()); + assert!(n_in.max(n_out).is_multiple_of(n_in.min(n_out))) + } + + if n_in == n_out { + znx_copy_ref(res, a); + return; + } + + let (gap_in, gap_out): (usize, usize); + if n_in > n_out { + (gap_in, gap_out) = (n_in / n_out, 1) + } else { + (gap_in, gap_out) = (1, n_out / n_in); + znx_zero_ref(res); + } + + res.iter_mut() + .step_by(gap_out) + .zip(a.iter().step_by(gap_in)) + .for_each(|(x_out, x_in)| *x_out = *x_in); +} diff --git a/poulpy-hal/src/reference/znx/zero.rs b/poulpy-hal/src/reference/znx/zero.rs new file mode 100644 index 0000000..16303c5 --- /dev/null +++ b/poulpy-hal/src/reference/znx/zero.rs @@ -0,0 +1,3 @@ +pub fn znx_zero_ref(res: &mut [i64]) { + res.fill(0); +} diff --git a/poulpy-hal/src/source.rs b/poulpy-hal/src/source.rs index 5107525..4852b31 100644 --- a/poulpy-hal/src/source.rs +++ b/poulpy-hal/src/source.rs @@ -39,6 +39,12 @@ impl Source { min + ((self.next_u64() << 11 >> 11) as f64) / MAXF64 * (max - min) } + #[inline(always)] + pub fn next_i32(&mut self) -> i32 { + self.next_u32() as i32 + } + + #[inline(always)] pub fn next_i64(&mut self) -> i64 { self.next_u64() as i64 } diff --git a/poulpy-hal/src/test_suite/mod.rs b/poulpy-hal/src/test_suite/mod.rs new file mode 100644 index 0000000..bbdaf05 --- /dev/null +++ b/poulpy-hal/src/test_suite/mod.rs @@ -0,0 +1,68 @@ +pub mod serialization; +pub mod svp; +pub mod vec_znx; +pub mod vec_znx_big; +pub mod vec_znx_dft; +pub mod vmp; + +#[macro_export] +macro_rules! backend_test_suite { + ( + mod $modname:ident, + backend = $backend:ty, + size = $size:expr, + tests = { + $( $(#[$attr:meta])* $test_name:ident => $impl:path ),+ $(,)? + } + ) => { + mod $modname { + use poulpy_hal::{api::ModuleNew, layouts::Module}; + + use once_cell::sync::Lazy; + + static MODULE: Lazy> = + Lazy::new(|| Module::<$backend>::new($size)); + + $( + $(#[$attr])* + #[test] + fn $test_name() { + ($impl)(&*MODULE); + } + )+ + } + }; +} + +#[macro_export] +macro_rules! 
cross_backend_test_suite { + ( + mod $modname:ident, + backend_ref = $backend_ref:ty, + backend_test = $backend_test:ty, + size = $size:expr, + basek = $basek:expr, + tests = { + $( $(#[$attr:meta])* $test_name:ident => $impl:path ),+ $(,)? + } + ) => { + mod $modname { + use poulpy_hal::{api::ModuleNew, layouts::Module}; + + use once_cell::sync::Lazy; + + static MODULE_REF: Lazy> = + Lazy::new(|| Module::<$backend_ref>::new($size)); + static MODULE_TEST: Lazy> = + Lazy::new(|| Module::<$backend_test>::new($size)); + + $( + $(#[$attr])* + #[test] + fn $test_name() { + ($impl)($basek, &*MODULE_REF, &*MODULE_TEST); + } + )+ + } + }; +} diff --git a/poulpy-hal/src/tests/serialization.rs b/poulpy-hal/src/test_suite/serialization.rs similarity index 97% rename from poulpy-hal/src/tests/serialization.rs rename to poulpy-hal/src/test_suite/serialization.rs index 17300bb..f2f9ee7 100644 --- a/poulpy-hal/src/tests/serialization.rs +++ b/poulpy-hal/src/test_suite/serialization.rs @@ -14,7 +14,7 @@ where { // Fill original with uniform random data let mut source = Source::new([0u8; 32]); - original.fill_uniform(&mut source); + original.fill_uniform(50, &mut source); // Serialize into a buffer let mut buffer = Vec::new(); diff --git a/poulpy-hal/src/test_suite/svp.rs b/poulpy-hal/src/test_suite/svp.rs new file mode 100644 index 0000000..7f7fbac --- /dev/null +++ b/poulpy-hal/src/test_suite/svp.rs @@ -0,0 +1,470 @@ +use rand::RngCore; + +use crate::{ + api::{ + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDft, SvpApplyDftToDft, SvpApplyDftToDftAdd, SvpApplyDftToDftInplace, + SvpPPolAlloc, SvpPrepare, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAlloc, VecZnxDftApply, + VecZnxIdftApplyConsume, + }, + layouts::{Backend, DataViewMut, DigestU64, FillUniform, Module, ScalarZnx, ScratchOwned, SvpPPol, VecZnx, VecZnxDft}, + source::Source, +}; + +pub fn test_svp_apply_dft(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>)
+where
+    Module<BR>: SvpPrepare<BR>
+        + SvpApplyDft<BR>
+        + SvpPPolAlloc<BR>
+        + VecZnxDftAlloc<BR>
+        + VecZnxBigNormalize<BR>
+        + VecZnxIdftApplyConsume<BR>
+        + VecZnxBigNormalizeTmpBytes,
+    Module<BT>: SvpPrepare<BT>
+        + SvpApplyDft<BT>
+        + SvpPPolAlloc<BT>
+        + VecZnxDftAlloc<BT>
+        + VecZnxBigNormalize<BT>
+        + VecZnxIdftApplyConsume<BT>
+        + VecZnxBigNormalizeTmpBytes,
+    ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>
, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + assert_eq!(module_ref.n(), module_test.n()); + let n: usize = module_ref.n(); + + let cols: usize = 2; + + let mut source: Source = Source::new([0u8; 32]); + + let mut scratch_ref: ScratchOwned
= ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes()); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes()); + + let mut scalar: ScalarZnx> = ScalarZnx::alloc(n, cols); + scalar.fill_uniform(basek, &mut source); + + let scalar_digest: u64 = scalar.digest_u64(); + + let mut svp_ref: SvpPPol, BR> = module_ref.svp_ppol_alloc(cols); + let mut svp_test: SvpPPol, BT> = module_test.svp_ppol_alloc(cols); + + for j in 0..cols { + module_ref.svp_prepare(&mut svp_ref, j, &scalar, j); + module_test.svp_prepare(&mut svp_test, j, &scalar, j); + } + + assert_eq!(scalar.digest_u64(), scalar_digest); + + let svp_ref_digest: u64 = svp_ref.digest_u64(); + let svp_test_digest: u64 = svp_test.digest_u64(); + + for a_size in [1, 2, 3, 4] { + // Create a random input VecZnx + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + a.fill_uniform(basek, &mut source); + + let a_digest: u64 = a.digest_u64(); + + for res_size in [1, 2, 3, 4] { + // Allocate VecZnxDft from FFT64Ref and module to test + let mut res_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols, res_size); + let mut res_dft_test: VecZnxDft, BT> = module_test.vec_znx_dft_alloc(cols, res_size); + + // Fill output with garbage + source.fill_bytes(res_dft_ref.data_mut()); + source.fill_bytes(res_dft_test.data_mut()); + + for j in 0..cols { + module_ref.svp_apply_dft(&mut res_dft_ref, j, &svp_ref, j, &a, j); + module_test.svp_apply_dft(&mut res_dft_test, j, &svp_test, j, &a, j); + } + + // Assert no change to inputs + assert_eq!(svp_ref.digest_u64(), svp_ref_digest); + assert_eq!(svp_test.digest_u64(), svp_test_digest); + assert_eq!(a.digest_u64(), a_digest); + + let res_big_ref: crate::layouts::VecZnxBig, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref); + let res_big_test: crate::layouts::VecZnxBig, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test); + + let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + for j in 0..cols { + module_ref.vec_znx_big_normalize( + basek, + &mut res_ref, + j, + &res_big_ref, + j, + scratch_ref.borrow(), + ); + module_test.vec_znx_big_normalize( + basek, + &mut res_test, + j, + &res_big_test, + j, + scratch_test.borrow(), + ); + } + + assert_eq!(res_ref, res_test); + } + } +} + +pub fn test_svp_apply_dft_to_dft(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>)
+where
+    Module<BR>: SvpPrepare<BR>
+        + SvpApplyDftToDft<BR>
+        + SvpPPolAlloc<BR>
+        + VecZnxDftAlloc<BR>
+        + VecZnxBigNormalize<BR>
+        + VecZnxDftApply<BR>
+        + VecZnxIdftApplyConsume<BR>
+        + VecZnxBigNormalizeTmpBytes,
+    Module<BT>: SvpPrepare<BT>
+        + SvpApplyDftToDft<BT>
+        + SvpPPolAlloc<BT>
+        + VecZnxDftAlloc<BT>
+        + VecZnxBigNormalize<BT>
+        + VecZnxDftApply<BT>
+        + VecZnxIdftApplyConsume<BT>
+        + VecZnxBigNormalizeTmpBytes,
+    ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>
, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + assert_eq!(module_ref.n(), module_test.n()); + let n: usize = module_ref.n(); + + let cols: usize = 2; + + let mut source: Source = Source::new([0u8; 32]); + + let mut scratch_ref: ScratchOwned
= ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes()); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes()); + + let mut scalar: ScalarZnx> = ScalarZnx::alloc(n, cols); + scalar.fill_uniform(basek, &mut source); + + let scalar_digest: u64 = scalar.digest_u64(); + + let mut svp_ref: SvpPPol, BR> = module_ref.svp_ppol_alloc(cols); + let mut svp_test: SvpPPol, BT> = module_test.svp_ppol_alloc(cols); + + for j in 0..cols { + module_ref.svp_prepare(&mut svp_ref, j, &scalar, j); + module_test.svp_prepare(&mut svp_test, j, &scalar, j); + } + + assert_eq!(scalar.digest_u64(), scalar_digest); + + let svp_ref_digest: u64 = svp_ref.digest_u64(); + let svp_test_digest: u64 = svp_test.digest_u64(); + + for a_size in [3] { + // Create a random input VecZnx + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + a.fill_uniform(basek, &mut source); + + let a_digest: u64 = a.digest_u64(); + + let mut a_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols, a_size); + let mut a_dft_test: VecZnxDft, BT> = module_test.vec_znx_dft_alloc(cols, a_size); + + for j in 0..cols { + module_ref.vec_znx_dft_apply(1, 0, &mut a_dft_ref, j, &a, j); + module_test.vec_znx_dft_apply(1, 0, &mut a_dft_test, j, &a, j); + } + + assert_eq!(a.digest_u64(), a_digest); + + let a_dft_ref_digest: u64 = a_dft_ref.digest_u64(); + let a_dft_test_digest: u64 = a_dft_test.digest_u64(); + + for res_size in [3] { + // Allocate VecZnxDft from FFT64Ref and module to test + let mut res_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols, res_size); + let mut res_dft_test: VecZnxDft, BT> = module_test.vec_znx_dft_alloc(cols, res_size); + + // Fill output with garbage + source.fill_bytes(res_dft_ref.data_mut()); + source.fill_bytes(res_dft_test.data_mut()); + + for j in 0..cols { + module_ref.svp_apply_dft_to_dft(&mut res_dft_ref, j, &svp_ref, j, &a_dft_ref, j); + module_test.svp_apply_dft_to_dft(&mut res_dft_test, j, &svp_test, j, &a_dft_test, j); + } + + // Assert no change to inputs + assert_eq!(a_dft_ref.digest_u64(), a_dft_ref_digest); + assert_eq!(a_dft_test.digest_u64(), a_dft_test_digest); + assert_eq!(svp_ref.digest_u64(), svp_ref_digest); + assert_eq!(svp_test.digest_u64(), svp_test_digest); + assert_eq!(a.digest_u64(), a_digest); + + let res_big_ref: crate::layouts::VecZnxBig, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref); + let res_big_test: crate::layouts::VecZnxBig, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test); + + println!("res_big_ref: {}", res_big_ref); + println!("res_big_test: {}", res_big_test); + + let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + for j in 0..cols { + module_ref.vec_znx_big_normalize( + basek, + &mut res_ref, + j, + &res_big_ref, + j, + scratch_ref.borrow(), + ); + module_test.vec_znx_big_normalize( + basek, + &mut res_test, + j, + &res_big_test, + j, + scratch_test.borrow(), + ); + } + + assert_eq!(res_ref, res_test); + } + } +} + +pub fn test_svp_apply_dft_to_dft_add(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>)
+where
+    Module<BR>: SvpPrepare<BR>
+        + SvpApplyDftToDftAdd<BR>
+        + SvpPPolAlloc<BR>
+        + VecZnxDftAlloc<BR>
+        + VecZnxBigNormalize<BR>
+        + VecZnxDftApply<BR>
+        + VecZnxIdftApplyConsume<BR>
+        + VecZnxBigNormalizeTmpBytes,
+    Module<BT>: SvpPrepare<BT>
+        + SvpApplyDftToDftAdd<BT>
+        + SvpPPolAlloc<BT>
+        + VecZnxDftAlloc<BT>
+        + VecZnxBigNormalize<BT>
+        + VecZnxDftApply<BT>
+        + VecZnxIdftApplyConsume<BT>
+        + VecZnxBigNormalizeTmpBytes,
+    ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>
, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + assert_eq!(module_ref.n(), module_test.n()); + let n: usize = module_ref.n(); + + let cols: usize = 2; + + let mut source: Source = Source::new([0u8; 32]); + + let mut scratch_ref: ScratchOwned
= ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes()); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes()); + + let mut scalar: ScalarZnx> = ScalarZnx::alloc(n, cols); + scalar.fill_uniform(basek, &mut source); + + let scalar_digest: u64 = scalar.digest_u64(); + + let mut svp_ref: SvpPPol, BR> = module_ref.svp_ppol_alloc(cols); + let mut svp_test: SvpPPol, BT> = module_test.svp_ppol_alloc(cols); + + for j in 0..cols { + module_ref.svp_prepare(&mut svp_ref, j, &scalar, j); + module_test.svp_prepare(&mut svp_test, j, &scalar, j); + } + + assert_eq!(scalar.digest_u64(), scalar_digest); + + let svp_ref_digest: u64 = svp_ref.digest_u64(); + let svp_test_digest: u64 = svp_test.digest_u64(); + + for a_size in [1, 2, 3, 4] { + // Create a random input VecZnx + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + a.fill_uniform(basek, &mut source); + + let a_digest: u64 = a.digest_u64(); + + let mut a_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols, a_size); + let mut a_dft_test: VecZnxDft, BT> = module_test.vec_znx_dft_alloc(cols, a_size); + + for j in 0..cols { + module_ref.vec_znx_dft_apply(1, 0, &mut a_dft_ref, j, &a, j); + module_test.vec_znx_dft_apply(1, 0, &mut a_dft_test, j, &a, j); + } + + let a_dft_ref_digest: u64 = a_dft_ref.digest_u64(); + let a_dft_test_digest: u64 = a_dft_test.digest_u64(); + + for res_size in [1, 2, 3, 4] { + let mut res: VecZnx> = VecZnx::alloc(n, cols, res_size); + res.fill_uniform(basek, &mut source); + + let mut res_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols, res_size); + let mut res_dft_test: VecZnxDft, BT> = module_test.vec_znx_dft_alloc(cols, res_size); + + for j in 0..cols { + module_ref.vec_znx_dft_apply(1, 0, &mut res_dft_ref, j, &res, j); + module_test.vec_znx_dft_apply(1, 0, &mut res_dft_test, j, &res, j); + } + + // Fill output with garbage + source.fill_bytes(res_dft_ref.data_mut()); + source.fill_bytes(res_dft_test.data_mut()); + + for j in 0..cols { + module_ref.svp_apply_dft_to_dft_add(&mut res_dft_ref, j, &svp_ref, j, &a_dft_ref, j); + module_test.svp_apply_dft_to_dft_add(&mut res_dft_test, j, &svp_test, j, &a_dft_test, j); + } + + // Assert no change to inputs + assert_eq!(a_dft_ref.digest_u64(), a_dft_ref_digest); + assert_eq!(a_dft_test.digest_u64(), a_dft_test_digest); + assert_eq!(svp_ref.digest_u64(), svp_ref_digest); + assert_eq!(svp_test.digest_u64(), svp_test_digest); + assert_eq!(a.digest_u64(), a_digest); + + let res_big_ref: crate::layouts::VecZnxBig, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref); + let res_big_test: crate::layouts::VecZnxBig, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test); + + let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + for j in 0..cols { + module_ref.vec_znx_big_normalize( + basek, + &mut res_ref, + j, + &res_big_ref, + j, + scratch_ref.borrow(), + ); + module_test.vec_znx_big_normalize( + basek, + &mut res_test, + j, + &res_big_test, + j, + scratch_test.borrow(), + ); + } + + assert_eq!(res_ref, res_test); + } + } +} + +pub fn test_svp_apply_dft_to_dft_inplace( + basek: usize, + module_ref: &Module
<BR>,
+    module_test: &Module<BT>,
+) where
+    Module<BR>: SvpPrepare<BR>
+        + SvpApplyDftToDftInplace<BR>
+        + SvpPPolAlloc<BR>
+        + VecZnxDftAlloc<BR>
+        + VecZnxBigNormalize<BR>
+        + VecZnxDftApply<BR>
+        + VecZnxIdftApplyConsume<BR>
+        + VecZnxBigNormalizeTmpBytes,
+    Module<BT>: SvpPrepare<BT>
+        + SvpApplyDftToDftInplace<BT>
+        + SvpPPolAlloc<BT>
+        + VecZnxDftAlloc<BT>
+        + VecZnxBigNormalize<BT>
+        + VecZnxDftApply<BT>
+        + VecZnxIdftApplyConsume<BT>
+        + VecZnxBigNormalizeTmpBytes,
+    ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>
, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + assert_eq!(module_ref.n(), module_test.n()); + let n: usize = module_ref.n(); + + let cols: usize = 2; + + let mut source: Source = Source::new([0u8; 32]); + + let mut scratch_ref: ScratchOwned
= ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes()); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes()); + + let mut scalar: ScalarZnx> = ScalarZnx::alloc(n, cols); + scalar.fill_uniform(basek, &mut source); + + let scalar_digest: u64 = scalar.digest_u64(); + + let mut svp_ref: SvpPPol, BR> = module_ref.svp_ppol_alloc(cols); + let mut svp_test: SvpPPol, BT> = module_test.svp_ppol_alloc(cols); + + for j in 0..cols { + module_ref.svp_prepare(&mut svp_ref, j, &scalar, j); + module_test.svp_prepare(&mut svp_test, j, &scalar, j); + } + + assert_eq!(scalar.digest_u64(), scalar_digest); + + let svp_ref_digest: u64 = svp_ref.digest_u64(); + let svp_test_digest: u64 = svp_test.digest_u64(); + + for res_size in [1, 2, 3, 4] { + let mut res: VecZnx> = VecZnx::alloc(n, cols, res_size); + res.fill_uniform(basek, &mut source); + let res_digest: u64 = res.digest_u64(); + + let mut res_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols, res_size); + let mut res_dft_test: VecZnxDft, BT> = module_test.vec_znx_dft_alloc(cols, res_size); + + for j in 0..cols { + module_ref.vec_znx_dft_apply(1, 0, &mut res_dft_ref, j, &res, j); + module_test.vec_znx_dft_apply(1, 0, &mut res_dft_test, j, &res, j); + } + + assert_eq!(res.digest_u64(), res_digest); + + for j in 0..cols { + module_ref.svp_apply_dft_to_dft_inplace(&mut res_dft_ref, j, &svp_ref, j); + module_test.svp_apply_dft_to_dft_inplace(&mut res_dft_test, j, &svp_test, j); + } + + // Assert no change to inputs + assert_eq!(svp_ref.digest_u64(), svp_ref_digest); + assert_eq!(svp_test.digest_u64(), svp_test_digest); + + let res_big_ref: crate::layouts::VecZnxBig, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref); + let res_big_test: crate::layouts::VecZnxBig, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test); + + let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + for j in 0..cols { + module_ref.vec_znx_big_normalize( + basek, + &mut res_ref, + j, + &res_big_ref, + j, + scratch_ref.borrow(), + ); + module_test.vec_znx_big_normalize( + basek, + &mut res_test, + j, + &res_big_test, + j, + scratch_test.borrow(), + ); + } + + println!("res_ref: {}", res_ref); + println!("res_test: {}", res_test); + + assert_eq!(res_ref, res_test); + } +} diff --git a/poulpy-hal/src/test_suite/vec_znx.rs b/poulpy-hal/src/test_suite/vec_znx.rs new file mode 100644 index 0000000..f5d180e --- /dev/null +++ b/poulpy-hal/src/test_suite/vec_znx.rs @@ -0,0 +1,1255 @@ +use std::f64::consts::SQRT_2; + +use crate::{ + api::{ + ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAdd, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalar, + VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxAutomorphismInplaceTmpBytes, VecZnxCopy, + VecZnxFillNormal, VecZnxFillUniform, VecZnxLsh, VecZnxLshInplace, VecZnxLshTmpBytes, VecZnxMergeRings, + VecZnxMergeRingsTmpBytes, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxMulXpMinusOneInplaceTmpBytes, + VecZnxNegate, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, + VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, VecZnxRsh, VecZnxRshInplace, VecZnxSplitRing, VecZnxSplitRingTmpBytes, + VecZnxSub, VecZnxSubABInplace, VecZnxSubBAInplace, VecZnxSubScalar, VecZnxSubScalarInplace, VecZnxSwitchRing, + }, + layouts::{Backend, DigestU64, FillUniform, Module, ScalarZnx, ScratchOwned, VecZnx, ZnxInfos, ZnxView, 
ZnxViewMut}, + reference::znx::znx_copy_ref, + source::Source, +}; + +pub fn test_vec_znx_encode_vec_i64_lo_norm() { + let n: usize = 32; + let basek: usize = 17; + let size: usize = 5; + let k: usize = size * basek - 5; + let mut a: VecZnx> = VecZnx::alloc(n, 2, size); + let mut source: Source = Source::new([0u8; 32]); + let raw: &mut [i64] = a.raw_mut(); + raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); + (0..a.cols()).for_each(|col_i| { + let mut have: Vec = vec![i64::default(); n]; + have.iter_mut() + .for_each(|x| *x = (source.next_i64() << 56) >> 56); + a.encode_vec_i64(basek, col_i, k, &have, 10); + let mut want: Vec = vec![i64::default(); n]; + a.decode_vec_i64(basek, col_i, k, &mut want); + assert_eq!(have, want, "{:?} != {:?}", &have, &want); + }); +} + +pub fn test_vec_znx_encode_vec_i64_hi_norm() { + let n: usize = 32; + let basek: usize = 17; + let size: usize = 5; + for k in [1, basek / 2, size * basek - 5] { + let mut a: VecZnx> = VecZnx::alloc(n, 2, size); + let mut source = Source::new([0u8; 32]); + let raw: &mut [i64] = a.raw_mut(); + raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); + (0..a.cols()).for_each(|col_i| { + let mut have: Vec = vec![i64::default(); n]; + have.iter_mut().for_each(|x| { + if k < 64 { + *x = source.next_u64n(1 << k, (1 << k) - 1) as i64; + } else { + *x = source.next_i64(); + } + }); + a.encode_vec_i64(basek, col_i, k, &have, 63); + let mut want: Vec = vec![i64::default(); n]; + a.decode_vec_i64(basek, col_i, k, &mut want); + assert_eq!(have, want, "{:?} != {:?}", &have, &want); + }) + } +} + +pub fn test_vec_znx_add_scalar(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>)
+where
+    Module<BR>
: VecZnxAddScalar, + Module: VecZnxAddScalar, +{ + assert_eq!(module_ref.n(), module_test.n()); + + let n: usize = module_ref.n(); + + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + + let mut a: ScalarZnx> = ScalarZnx::alloc(n, cols); + a.fill_uniform(basek, &mut source); + let a_digest = a.digest_u64(); + + for a_size in [1, 2, 3, 4] { + let mut b: VecZnx> = VecZnx::alloc(n, cols, a_size); + b.fill_uniform(basek, &mut source); + let b_digest: u64 = b.digest_u64(); + + for res_size in [1, 2, 3, 4] { + let mut rest_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + // Set d to garbage + rest_ref.fill_uniform(basek, &mut source); + res_test.fill_uniform(basek, &mut source); + + // Reference + for i in 0..cols { + module_ref.vec_znx_add_scalar(&mut rest_ref, i, &a, i, &b, i, (res_size.min(a_size)) - 1); + module_test.vec_znx_add_scalar(&mut res_test, i, &a, i, &b, i, (res_size.min(a_size)) - 1); + } + + assert_eq!(b.digest_u64(), b_digest); + assert_eq!(a.digest_u64(), a_digest); + assert_eq!(rest_ref, res_test); + } + } +} + +pub fn test_vec_znx_add_scalar_inplace(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>)
+where
+    Module<BR>
: VecZnxAddScalarInplace, + Module: VecZnxAddScalarInplace, +{ + assert_eq!(module_ref.n(), module_test.n()); + + let n: usize = module_ref.n(); + + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + + let mut b: ScalarZnx> = ScalarZnx::alloc(n, cols); + b.fill_uniform(basek, &mut source); + let b_digest: u64 = b.digest_u64(); + + for res_size in [1, 2, 3, 4] { + let mut rest_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + rest_ref.fill_uniform(basek, &mut source); + res_test.raw_mut().copy_from_slice(rest_ref.raw()); + + for i in 0..cols { + module_ref.vec_znx_add_scalar_inplace(&mut rest_ref, i, res_size - 1, &b, i); + module_test.vec_znx_add_scalar_inplace(&mut res_test, i, res_size - 1, &b, i); + } + + assert_eq!(b.digest_u64(), b_digest); + assert_eq!(rest_ref, res_test); + } +} +pub fn test_vec_znx_add(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>)
+where
+    Module<BR>
: VecZnxAdd, + Module: VecZnxAdd, +{ + assert_eq!(module_ref.n(), module_test.n()); + let n: usize = module_ref.n(); + + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + + for a_size in [1, 2, 3, 4] { + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + a.fill_uniform(basek, &mut source); + + let a_digest: u64 = a.digest_u64(); + + for b_size in [1, 2, 3, 4] { + let mut b: VecZnx> = VecZnx::alloc(n, cols, b_size); + b.fill_uniform(basek, &mut source); + + let b_digest: u64 = b.digest_u64(); + + for res_size in [1, 2, 3, 4] { + let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + // Set d to garbage + res_ref.fill_uniform(basek, &mut source); + res_test.fill_uniform(basek, &mut source); + + // Reference + for i in 0..cols { + module_test.vec_znx_add(&mut res_ref, i, &a, i, &b, i); + module_ref.vec_znx_add(&mut res_test, i, &a, i, &b, i); + } + + assert_eq!(a.digest_u64(), a_digest); + assert_eq!(b.digest_u64(), b_digest); + + assert_eq!(res_ref, res_test); + } + } + } +} + +pub fn test_vec_znx_add_inplace(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>)
+where
+    Module<BR>
: VecZnxAddInplace, + Module: VecZnxAddInplace, +{ + assert_eq!(module_ref.n(), module_test.n()); + let n: usize = module_ref.n(); + + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + + for a_size in [1, 2, 3, 4] { + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + a.fill_uniform(basek, &mut source); + let a_digest: u64 = a.digest_u64(); + + for res_size in [1, 2, 3, 4] { + let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + res_ref.fill_uniform(basek, &mut source); + res_test.raw_mut().copy_from_slice(res_ref.raw()); + + for i in 0..cols { + module_ref.vec_znx_add_inplace(&mut res_ref, i, &a, i); + module_test.vec_znx_add_inplace(&mut res_test, i, &a, i); + } + + assert_eq!(a.digest_u64(), a_digest); + assert_eq!(res_ref, res_test); + } + } +} + +pub fn test_vec_znx_automorphism(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>)
+where
+    Module<BR>
: VecZnxAutomorphism, + Module: VecZnxAutomorphism, +{ + assert_eq!(module_ref.n(), module_test.n()); + let n: usize = module_ref.n(); + + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + + for a_size in [1, 2, 3, 4] { + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + a.fill_uniform(basek, &mut source); + let a_digest: u64 = a.digest_u64(); + + for res_size in [1, 2, 3, 4] { + let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + let p: i64 = -5; + + // Normalize on c + for i in 0..cols { + module_ref.vec_znx_automorphism(p, &mut res_ref, i, &a, i); + module_test.vec_znx_automorphism(p, &mut res_test, i, &a, i); + } + + assert_eq!(a.digest_u64(), a_digest); + assert_eq!(res_ref, res_test); + + let p: i64 = 5; + + // Normalize on c + for i in 0..cols { + module_ref.vec_znx_automorphism(p, &mut res_ref, i, &a, i); + module_test.vec_znx_automorphism(p, &mut res_test, i, &a, i); + } + + assert_eq!(a.digest_u64(), a_digest); + assert_eq!(res_ref, res_test); + } + } +} + +pub fn test_vec_znx_automorphism_inplace( + basek: usize, + module_ref: &Module
<BR>,
+    module_test: &Module<BT>,
+) where
+    Module<BR>: VecZnxAutomorphismInplace<BR> + VecZnxAutomorphismInplaceTmpBytes,
+    ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>
, + Module: VecZnxAutomorphismInplace + VecZnxAutomorphismInplaceTmpBytes, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + assert_eq!(module_ref.n(), module_test.n()); + let n: usize = module_ref.n(); + + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + + let mut scratch_ref: ScratchOwned
= ScratchOwned::alloc(module_ref.vec_znx_automorphism_inplace_tmp_bytes()); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_automorphism_inplace_tmp_bytes()); + + for size in [1, 2, 3, 4] { + let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, size); + let mut res_test: VecZnx> = VecZnx::alloc(n, cols, size); + + // Fill a with random i64 + res_ref.fill_uniform(basek, &mut source); + znx_copy_ref(res_test.raw_mut(), res_ref.raw()); + + let p: i64 = -7; + + // Normalize on c + for i in 0..cols { + module_ref.vec_znx_automorphism_inplace(p, &mut res_ref, i, scratch_ref.borrow()); + module_test.vec_znx_automorphism_inplace(p, &mut res_test, i, scratch_test.borrow()); + } + + assert_eq!(res_ref, res_test); + + let p: i64 = 7; + + // Normalize on c + for i in 0..cols { + module_ref.vec_znx_automorphism_inplace(p, &mut res_ref, i, scratch_ref.borrow()); + module_test.vec_znx_automorphism_inplace(p, &mut res_test, i, scratch_test.borrow()); + } + + assert_eq!(res_ref, res_test); + } +} + +pub fn test_vec_znx_copy(basek: usize, module_ref: &Module
, module_test: &Module) +where + Module
: VecZnxCopy, + Module: VecZnxCopy, +{ + assert_eq!(module_ref.n(), module_test.n()); + let n: usize = module_ref.n(); + + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + + for a_size in [1, 2, 3, 4] { + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + a.fill_uniform(basek, &mut source); + let a_digest: u64 = a.digest_u64(); + + for res_size in [1, 2, 3, 4] { + let mut res_0: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_1: VecZnx> = VecZnx::alloc(n, cols, res_size); + + // Set d to garbage + res_0.fill_uniform(basek, &mut source); + res_1.fill_uniform(basek, &mut source); + + // Reference + for i in 0..cols { + module_ref.vec_znx_copy(&mut res_0, i, &a, i); + module_ref.vec_znx_copy(&mut res_1, i, &a, i); + } + + assert_eq!(a.digest_u64(), a_digest); + assert_eq!(res_0, res_1); + } + } +} + +pub fn test_vec_znx_merge_rings(basek: usize, module_ref: &Module
, module_test: &Module) +where + Module
: VecZnxMergeRings
+ ModuleNew
+ VecZnxMergeRingsTmpBytes, + Module: VecZnxMergeRings + ModuleNew + VecZnxMergeRingsTmpBytes, + ScratchOwned
: ScratchOwnedAlloc
+ ScratchOwnedBorrow
, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + assert_eq!(module_ref.n(), module_test.n()); + let n: usize = module_ref.n(); + + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + let mut scratch_ref: ScratchOwned
= ScratchOwned::alloc(module_ref.vec_znx_merge_rings_tmp_bytes()); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_merge_rings_tmp_bytes()); + + for a_size in [1, 2, 3, 4] { + let mut a: [VecZnx>; 2] = [ + VecZnx::alloc(n >> 1, cols, a_size), + VecZnx::alloc(n >> 1, cols, a_size), + ]; + + a.iter_mut().for_each(|ai| { + ai.fill_uniform(basek, &mut source); + }); + + let a_digests: [u64; 2] = [a[0].digest_u64(), a[1].digest_u64()]; + + for res_size in [1, 2, 3, 4] { + let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + res_ref.fill_uniform(basek, &mut source); + res_test.fill_uniform(basek, &mut source); + + // Merge the two half-degree rings with each backend into its own result + for i in 0..cols { + module_ref.vec_znx_merge_rings(&mut res_ref, i, &a, i, scratch_ref.borrow()); + module_test.vec_znx_merge_rings(&mut res_test, i, &a, i, scratch_test.borrow()); + } + + assert_eq!([a[0].digest_u64(), a[1].digest_u64()], a_digests); + assert_eq!(res_ref, res_test); + } + } +} + +pub fn test_vec_znx_mul_xp_minus_one(basek: usize, module_ref: &Module
, module_test: &Module) +where + Module
: VecZnxMulXpMinusOne, + Module: VecZnxMulXpMinusOne, +{ + assert_eq!(module_ref.n(), module_test.n()); + let n: usize = module_ref.n(); + + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + + for a_size in [1, 2, 3, 4] { + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + a.fill_uniform(basek, &mut source); + + let a_digest: u64 = a.digest_u64(); + + for res_size in [1, 2, 3, 4] { + let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + let p: i64 = -5; + + // Multiply each column by (X^p - 1) and compare backends + for i in 0..cols { + module_ref.vec_znx_mul_xp_minus_one(p, &mut res_ref, i, &a, i); + module_test.vec_znx_mul_xp_minus_one(p, &mut res_test, i, &a, i); + } + + assert_eq!(a.digest_u64(), a_digest); + assert_eq!(res_test, res_ref); + + let p: i64 = 5; + + // Multiply each column by (X^p - 1) and compare backends + for i in 0..cols { + module_ref.vec_znx_mul_xp_minus_one(p, &mut res_ref, i, &a, i); + module_test.vec_znx_mul_xp_minus_one(p, &mut res_test, i, &a, i); + } + + assert_eq!(a.digest_u64(), a_digest); + assert_eq!(res_test, res_ref); + } + } +} + +pub fn test_vec_znx_mul_xp_minus_one_inplace( + basek: usize, + module_ref: &Module
, + module_test: &Module, +) where + Module
: VecZnxMulXpMinusOneInplace
+ VecZnxMulXpMinusOneInplaceTmpBytes, + ScratchOwned
: ScratchOwnedAlloc
+ ScratchOwnedBorrow
, + Module: VecZnxMulXpMinusOneInplace + VecZnxMulXpMinusOneInplaceTmpBytes, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + assert_eq!(module_ref.n(), module_test.n()); + let n: usize = module_ref.n(); + + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + + let mut scratch_ref: ScratchOwned
= ScratchOwned::alloc(module_ref.vec_znx_mul_xp_minus_one_inplace_tmp_bytes()); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_mul_xp_minus_one_inplace_tmp_bytes()); + + for size in [1, 2, 3, 4] { + let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, size); + let mut res_test: VecZnx> = VecZnx::alloc(n, cols, size); + + // Fill res_ref with random data and copy it into res_test + res_ref.fill_uniform(basek, &mut source); + znx_copy_ref(res_test.raw_mut(), res_ref.raw()); + + let p: i64 = -7; + + for i in 0..cols { + module_ref.vec_znx_mul_xp_minus_one_inplace(p, &mut res_ref, i, scratch_ref.borrow()); + module_test.vec_znx_mul_xp_minus_one_inplace(p, &mut res_test, i, scratch_test.borrow()); + } + + assert_eq!(res_ref, res_test); + + let p: i64 = 7; + + for i in 0..cols { + module_ref.vec_znx_mul_xp_minus_one_inplace(p, &mut res_ref, i, scratch_ref.borrow()); + module_test.vec_znx_mul_xp_minus_one_inplace(p, &mut res_test, i, scratch_test.borrow()); + } + + assert_eq!(res_ref, res_test); + } +} + +pub fn test_vec_znx_negate(basek: usize, module_ref: &Module
, module_test: &Module) +where + Module
: VecZnxNegate, + Module: VecZnxNegate, +{ + assert_eq!(module_ref.n(), module_test.n()); + let n: usize = module_ref.n(); + + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + + for a_size in [1, 2, 3, 4] { + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + a.fill_uniform(basek, &mut source); + let a_digest: u64 = a.digest_u64(); + + for res_size in [1, 2, 3, 4] { + let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + res_ref.fill_uniform(basek, &mut source); + res_test.raw_mut().copy_from_slice(res_ref.raw()); + + for i in 0..cols { + module_ref.vec_znx_negate(&mut res_ref, i, &a, i); + module_test.vec_znx_negate(&mut res_test, i, &a, i); + } + + assert_eq!(a.digest_u64(), a_digest); + assert_eq!(res_ref, res_test); + } + } +} + +pub fn test_vec_znx_negate_inplace(basek: usize, module_ref: &Module
, module_test: &Module) +where + Module
: VecZnxNegateInplace, + Module: VecZnxNegateInplace, +{ + assert_eq!(module_ref.n(), module_test.n()); + let n: usize = module_ref.n(); + + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + + for res_size in [1, 2, 3, 4] { + let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + res_ref.fill_uniform(basek, &mut source); + res_test.raw_mut().copy_from_slice(res_ref.raw()); + + for i in 0..cols { + module_ref.vec_znx_negate_inplace(&mut res_ref, i); + module_test.vec_znx_negate_inplace(&mut res_test, i); + } + + assert_eq!(res_ref, res_test); + } +} + +pub fn test_vec_znx_normalize(basek: usize, module_ref: &Module
, module_test: &Module) +where + Module
: VecZnxNormalize
+ VecZnxNormalizeTmpBytes, + ScratchOwned
: ScratchOwnedAlloc
+ ScratchOwnedBorrow
, + Module: VecZnxNormalize + VecZnxNormalizeTmpBytes, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + assert_eq!(module_ref.n(), module_test.n()); + let n: usize = module_ref.n(); + + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + + let mut scratch_ref: ScratchOwned
= ScratchOwned::alloc(module_ref.vec_znx_normalize_tmp_bytes()); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_normalize_tmp_bytes()); + + for a_size in [1, 2, 3, 4] { + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + a.fill_uniform(basek, &mut source); + let a_digest: u64 = a.digest_u64(); + + for res_size in [1, 2, 3, 4] { + let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + // Set d to garbage + res_ref.fill_uniform(basek, &mut source); + res_test.fill_uniform(basek, &mut source); + + // Reference + for i in 0..cols { + module_ref.vec_znx_normalize(basek, &mut res_ref, i, &a, i, scratch_ref.borrow()); + module_test.vec_znx_normalize(basek, &mut res_test, i, &a, i, scratch_test.borrow()); + } + + assert_eq!(a.digest_u64(), a_digest); + assert_eq!(res_ref, res_test); + } + } +} + +pub fn test_vec_znx_normalize_inplace(basek: usize, module_ref: &Module
, module_test: &Module) +where + Module
: VecZnxNormalizeInplace
+ VecZnxNormalizeTmpBytes, + ScratchOwned
: ScratchOwnedAlloc
+ ScratchOwnedBorrow
, + Module: VecZnxNormalizeInplace + VecZnxNormalizeTmpBytes, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + assert_eq!(module_ref.n(), module_test.n()); + let n: usize = module_ref.n(); + + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + + let mut scratch_ref: ScratchOwned
= ScratchOwned::alloc(module_ref.vec_znx_normalize_tmp_bytes()); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_normalize_tmp_bytes()); + + for res_size in [1, 2, 3, 4] { + let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + res_ref.fill_uniform(basek, &mut source); + res_test.raw_mut().copy_from_slice(res_ref.raw()); + + // Reference + for i in 0..cols { + module_ref.vec_znx_normalize_inplace(basek, &mut res_ref, i, scratch_ref.borrow()); + module_test.vec_znx_normalize_inplace(basek, &mut res_test, i, scratch_test.borrow()); + } + + assert_eq!(res_ref, res_test); + } +} + +pub fn test_vec_znx_rotate(basek: usize, module_ref: &Module
, module_test: &Module) +where + Module
: VecZnxRotate, + Module: VecZnxRotate, +{ + assert_eq!(module_ref.n(), module_test.n()); + let n: usize = module_ref.n(); + + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + + for a_size in [1, 2, 3, 4] { + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + a.fill_uniform(basek, &mut source); + let a_digest: u64 = a.digest_u64(); + + for res_size in [1, 2, 3, 4] { + let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + let p: i64 = -5; + + // Rotate each column by X^p and compare backends + for i in 0..cols { + module_ref.vec_znx_rotate(p, &mut res_ref, i, &a, i); + module_test.vec_znx_rotate(p, &mut res_test, i, &a, i); + } + + assert_eq!(a.digest_u64(), a_digest); + assert_eq!(res_ref, res_test); + + let p: i64 = 5; + + // Rotate each column by X^p and compare backends + for i in 0..cols { + module_ref.vec_znx_rotate(p, &mut res_ref, i, &a, i); + module_test.vec_znx_rotate(p, &mut res_test, i, &a, i); + } + + assert_eq!(a.digest_u64(), a_digest); + assert_eq!(res_ref, res_test); + } + } +} + +pub fn test_vec_znx_rotate_inplace(basek: usize, module_ref: &Module
, module_test: &Module) +where + Module
: VecZnxRotateInplace
+ VecZnxRotateInplaceTmpBytes, + ScratchOwned
: ScratchOwnedAlloc
+ ScratchOwnedBorrow
, + Module: VecZnxRotateInplace + VecZnxRotateInplaceTmpBytes, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + assert_eq!(module_ref.n(), module_test.n()); + let n: usize = module_ref.n(); + + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + + let mut scratch_ref: ScratchOwned
= ScratchOwned::alloc(module_ref.vec_znx_rotate_inplace_tmp_bytes()); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_rotate_inplace_tmp_bytes()); + + for size in [1, 2, 3, 4] { + let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, size); + let mut res_test: VecZnx> = VecZnx::alloc(n, cols, size); + + // Fill a with random i64 + res_ref.fill_uniform(basek, &mut source); + znx_copy_ref(res_test.raw_mut(), res_ref.raw()); + + let p: i64 = -5; + + // Normalize on c + for i in 0..cols { + module_ref.vec_znx_rotate_inplace(p, &mut res_ref, i, scratch_ref.borrow()); + module_test.vec_znx_rotate_inplace(p, &mut res_test, i, scratch_test.borrow()); + } + + assert_eq!(res_ref, res_test); + + let p: i64 = 5; + + // Normalize on c + for i in 0..cols { + module_ref.vec_znx_rotate_inplace(p, &mut res_ref, i, scratch_ref.borrow()); + module_test.vec_znx_rotate_inplace(p, &mut res_test, i, scratch_test.borrow()); + } + + assert_eq!(res_ref, res_test); + } +} + +pub fn test_vec_znx_fill_uniform(module: &Module) +where + Module: VecZnxFillUniform, +{ + let n: usize = module.n(); + let basek: usize = 17; + let size: usize = 5; + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + let zero: Vec = vec![0; n]; + let one_12_sqrt: f64 = 0.28867513459481287; + (0..cols).for_each(|col_i| { + let mut a: VecZnx<_> = VecZnx::alloc(n, cols, size); + module.vec_znx_fill_uniform(basek, &mut a, col_i, &mut source); + (0..cols).for_each(|col_j| { + if col_j != col_i { + (0..size).for_each(|limb_i| { + assert_eq!(a.at(col_j, limb_i), zero); + }) + } else { + let std: f64 = a.std(basek, col_i); + assert!( + (std - one_12_sqrt).abs() < 0.01, + "std={} ~!= {}", + std, + one_12_sqrt + ); + } + }) + }); +} + +pub fn test_vec_znx_fill_normal(module: &Module) +where + Module: VecZnxFillNormal, +{ + let n: usize = module.n(); + let basek: usize = 17; + let k: usize = 2 * 17; + let size: usize = 5; + let sigma: f64 = 3.2; + let bound: f64 = 6.0 * sigma; + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + let zero: Vec = vec![0; n]; + let k_f64: f64 = (1u64 << k as u64) as f64; + (0..cols).for_each(|col_i| { + let mut a: VecZnx<_> = VecZnx::alloc(n, cols, size); + module.vec_znx_fill_normal(basek, &mut a, col_i, k, &mut source, sigma, bound); + (0..cols).for_each(|col_j| { + if col_j != col_i { + (0..size).for_each(|limb_i| { + assert_eq!(a.at(col_j, limb_i), zero); + }) + } else { + let std: f64 = a.std(basek, col_i) * k_f64; + assert!((std - sigma).abs() < 0.1, "std={} ~!= {}", std, sigma); + } + }) + }); +} + +pub fn test_vec_znx_add_normal(module: &Module) +where + Module: VecZnxFillNormal + VecZnxAddNormal, +{ + let n: usize = module.n(); + let basek: usize = 17; + let k: usize = 2 * 17; + let size: usize = 5; + let sigma: f64 = 3.2; + let bound: f64 = 6.0 * sigma; + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + let zero: Vec = vec![0; n]; + let k_f64: f64 = (1u64 << k as u64) as f64; + let sqrt2: f64 = SQRT_2; + (0..cols).for_each(|col_i| { + let mut a: VecZnx<_> = VecZnx::alloc(n, cols, size); + module.vec_znx_fill_normal(basek, &mut a, col_i, k, &mut source, sigma, bound); + module.vec_znx_add_normal(basek, &mut a, col_i, k, &mut source, sigma, bound); + (0..cols).for_each(|col_j| { + if col_j != col_i { + (0..size).for_each(|limb_i| { + assert_eq!(a.at(col_j, limb_i), zero); + }) + } else { + let std: f64 = a.std(basek, col_i) * k_f64; + assert!( + (std - sigma * sqrt2).abs() < 0.1, + "std={} ~!= {}", + std, + sigma * 
sqrt2 + ); + } + }) + }); +} + +pub fn test_vec_znx_lsh(basek: usize, module_ref: &Module
, module_test: &Module) +where + Module
: VecZnxLsh
+ VecZnxLshTmpBytes, + ScratchOwned
: ScratchOwnedAlloc
+ ScratchOwnedBorrow
, + Module: VecZnxLsh + VecZnxLshTmpBytes, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + assert_eq!(module_ref.n(), module_test.n()); + let n: usize = module_ref.n(); + + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + + let mut scratch_ref: ScratchOwned
= ScratchOwned::alloc(module_ref.vec_znx_lsh_tmp_bytes()); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_lsh_tmp_bytes()); + + for a_size in [1, 2, 3, 4] { + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + a.fill_uniform(basek, &mut source); + let a_digest: u64 = a.digest_u64(); + + for res_size in [1, 2, 3, 4] { + for k in 0..res_size * basek { + let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + // Set d to garbage + res_ref.fill_uniform(basek, &mut source); + res_test.fill_uniform(basek, &mut source); + + // Reference + for i in 0..cols { + module_ref.vec_znx_lsh(basek, k, &mut res_ref, i, &a, i, scratch_ref.borrow()); + module_test.vec_znx_lsh(basek, k, &mut res_test, i, &a, i, scratch_test.borrow()); + } + + assert_eq!(a.digest_u64(), a_digest); + assert_eq!(res_ref, res_test); + } + } + } +} + +pub fn test_vec_znx_lsh_inplace(basek: usize, module_ref: &Module
, module_test: &Module) +where + Module
: VecZnxLshInplace
+ VecZnxLshTmpBytes, + ScratchOwned
: ScratchOwnedAlloc
+ ScratchOwnedBorrow
, + Module: VecZnxLshInplace + VecZnxLshTmpBytes, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + assert_eq!(module_ref.n(), module_test.n()); + let n: usize = module_ref.n(); + + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + + let mut scratch_ref: ScratchOwned
= ScratchOwned::alloc(module_ref.vec_znx_lsh_tmp_bytes()); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_lsh_tmp_bytes()); + + for res_size in [1, 2, 3, 4] { + for k in 0..basek * res_size { + let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + res_ref.fill_uniform(basek, &mut source); + res_test.raw_mut().copy_from_slice(res_ref.raw()); + + for i in 0..cols { + module_ref.vec_znx_lsh_inplace(basek, k, &mut res_ref, i, scratch_ref.borrow()); + module_test.vec_znx_lsh_inplace(basek, k, &mut res_test, i, scratch_test.borrow()); + } + + assert_eq!(res_ref, res_test); + } + } +} + +pub fn test_vec_znx_rsh(basek: usize, module_ref: &Module
, module_test: &Module) +where + Module
: VecZnxRsh
+ VecZnxLshTmpBytes, + ScratchOwned
: ScratchOwnedAlloc
+ ScratchOwnedBorrow
, + Module: VecZnxRsh + VecZnxLshTmpBytes, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + assert_eq!(module_ref.n(), module_test.n()); + let n: usize = module_ref.n(); + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + + let mut scratch_ref: ScratchOwned
= ScratchOwned::alloc(module_ref.vec_znx_lsh_tmp_bytes()); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_lsh_tmp_bytes()); + + for a_size in [1, 2, 3, 4] { + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + a.fill_uniform(basek, &mut source); + let a_digest: u64 = a.digest_u64(); + + for res_size in [1, 2, 3, 4] { + for k in 0..res_size * basek { + let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + // Set d to garbage + res_ref.fill_uniform(basek, &mut source); + res_test.fill_uniform(basek, &mut source); + + // Reference + for i in 0..cols { + module_ref.vec_znx_rsh(basek, k, &mut res_ref, i, &a, i, scratch_ref.borrow()); + module_test.vec_znx_rsh(basek, k, &mut res_test, i, &a, i, scratch_test.borrow()); + } + + assert_eq!(a.digest_u64(), a_digest); + assert_eq!(res_ref, res_test); + } + } + } +} + +pub fn test_vec_znx_rsh_inplace(basek: usize, module_ref: &Module
, module_test: &Module) +where + Module
: VecZnxRshInplace
+ VecZnxLshTmpBytes, + ScratchOwned
: ScratchOwnedAlloc
+ ScratchOwnedBorrow
, + Module: VecZnxRshInplace + VecZnxLshTmpBytes, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + assert_eq!(module_ref.n(), module_test.n()); + let n: usize = module_ref.n(); + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + + let mut scratch_ref: ScratchOwned
= ScratchOwned::alloc(module_ref.vec_znx_lsh_tmp_bytes()); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_lsh_tmp_bytes()); + + for res_size in [1, 2, 3, 4] { + for k in 0..basek * res_size { + let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + res_ref.fill_uniform(basek, &mut source); + res_test.raw_mut().copy_from_slice(res_ref.raw()); + + for i in 0..cols { + module_ref.vec_znx_rsh_inplace(basek, k, &mut res_ref, i, scratch_ref.borrow()); + module_test.vec_znx_rsh_inplace(basek, k, &mut res_test, i, scratch_test.borrow()); + } + + assert_eq!(res_ref, res_test); + } + } +} + +pub fn test_vec_znx_split_ring(basek: usize, module_ref: &Module
, module_test: &Module) +where + Module
: VecZnxSplitRing
+ ModuleNew
+ VecZnxSplitRingTmpBytes, + ScratchOwned
: ScratchOwnedAlloc
+ ScratchOwnedBorrow
, + Module: VecZnxSplitRing + ModuleNew + VecZnxSplitRingTmpBytes, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + assert_eq!(module_ref.n(), module_test.n()); + let n: usize = module_ref.n(); + + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + let mut scratch_ref: ScratchOwned
= ScratchOwned::alloc(module_ref.vec_znx_split_ring_tmp_bytes()); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_split_ring_tmp_bytes()); + + for a_size in [1, 2, 3, 4] { + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + a.fill_uniform(basek, &mut source); + let a_digest = a.digest_u64(); + + for res_size in [1, 2, 3, 4] { + let mut res_ref: [VecZnx>; 2] = [ + VecZnx::alloc(n >> 1, cols, res_size), + VecZnx::alloc(n >> 1, cols, res_size), + ]; + + let mut res_test: [VecZnx>; 2] = [ + VecZnx::alloc(n >> 1, cols, res_size), + VecZnx::alloc(n >> 1, cols, res_size), + ]; + + res_ref.iter_mut().for_each(|ri| { + ri.fill_uniform(basek, &mut source); + }); + + res_test.iter_mut().for_each(|ri| { + ri.fill_uniform(basek, &mut source); + }); + + for i in 0..cols { + module_ref.vec_znx_split_ring(&mut res_ref, i, &a, i, scratch_ref.borrow()); + module_test.vec_znx_split_ring(&mut res_test, i, &a, i, scratch_test.borrow()); + } + + assert_eq!(a.digest_u64(), a_digest); + + for (a, b) in res_ref.iter().zip(res_test.iter()) { + assert_eq!(a, b); + } + } + } +} + +pub fn test_vec_znx_sub_scalar(basek: usize, module_ref: &Module
, module_test: &Module) +where + Module
: VecZnxSubScalar, + Module: VecZnxSubScalar, +{ + assert_eq!(module_ref.n(), module_test.n()); + let n: usize = module_ref.n(); + + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + + let mut a: ScalarZnx> = ScalarZnx::alloc(n, cols); + a.fill_uniform(basek, &mut source); + let a_digest: u64 = a.digest_u64(); + + for b_size in [1, 2, 3, 4] { + let mut b: VecZnx> = VecZnx::alloc(n, cols, b_size); + b.fill_uniform(basek, &mut source); + let b_digest: u64 = b.digest_u64(); + + for res_size in [1, 2, 3, 4] { + let mut res_0: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_1: VecZnx> = VecZnx::alloc(n, cols, res_size); + + // Set d to garbage + res_0.fill_uniform(basek, &mut source); + res_1.fill_uniform(basek, &mut source); + + // Reference + for i in 0..cols { + module_ref.vec_znx_sub_scalar(&mut res_0, i, &a, i, &b, i, (res_size.min(b_size)) - 1); + module_test.vec_znx_sub_scalar(&mut res_1, i, &a, i, &b, i, (res_size.min(b_size)) - 1); + } + + assert_eq!(a.digest_u64(), a_digest); + assert_eq!(b.digest_u64(), b_digest); + assert_eq!(res_0, res_1); + } + } +} + +pub fn test_vec_znx_sub_scalar_inplace(basek: usize, module_ref: &Module
, module_test: &Module) +where + Module
: VecZnxSubScalarInplace, + Module: VecZnxSubScalarInplace, +{ + assert_eq!(module_ref.n(), module_test.n()); + let n: usize = module_ref.n(); + + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + + let mut a: ScalarZnx> = ScalarZnx::alloc(n, cols); + a.fill_uniform(basek, &mut source); + let a_digest: u64 = a.digest_u64(); + + for res_size in [1, 2, 3, 4] { + let mut res_0: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_1: VecZnx> = VecZnx::alloc(n, cols, res_size); + + res_0.fill_uniform(basek, &mut source); + res_1.raw_mut().copy_from_slice(res_0.raw()); + + for i in 0..cols { + module_ref.vec_znx_sub_scalar_inplace(&mut res_0, i, res_size - 1, &a, i); + module_test.vec_znx_sub_scalar_inplace(&mut res_1, i, res_size - 1, &a, i); + } + + assert_eq!(a.digest_u64(), a_digest); + assert_eq!(res_0, res_1); + } +} + +pub fn test_vec_znx_sub(basek: usize, module_ref: &Module
, module_test: &Module) +where + Module
: VecZnxSub, + Module: VecZnxSub, +{ + assert_eq!(module_ref.n(), module_test.n()); + let n: usize = module_ref.n(); + + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + + for a_size in [1, 2, 3, 4] { + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + a.fill_uniform(basek, &mut source); + let a_digest: u64 = a.digest_u64(); + + for b_size in [1, 2, 3, 4] { + let mut b: VecZnx> = VecZnx::alloc(n, cols, b_size); + b.fill_uniform(basek, &mut source); + let b_digest: u64 = b.digest_u64(); + + for res_size in [1, 2, 3, 4] { + let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + // Set res to garbage + res_ref.fill_uniform(basek, &mut source); + res_test.fill_uniform(basek, &mut source); + + // Compute a - b with both backends + for i in 0..cols { + module_ref.vec_znx_sub(&mut res_ref, i, &a, i, &b, i); + module_test.vec_znx_sub(&mut res_test, i, &a, i, &b, i); + } + + assert_eq!(a.digest_u64(), a_digest); + assert_eq!(b.digest_u64(), b_digest); + + assert_eq!(res_ref, res_test); + } + } + } +} + +pub fn test_vec_znx_sub_ab_inplace(basek: usize, module_ref: &Module
, module_test: &Module) +where + Module
: VecZnxSubABInplace, + Module: VecZnxSubABInplace, +{ + assert_eq!(module_ref.n(), module_test.n()); + let n: usize = module_ref.n(); + + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + + for a_size in [1, 2, 3, 4] { + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + a.fill_uniform(basek, &mut source); + let a_digest: u64 = a.digest_u64(); + + for res_size in [1, 2, 3, 4] { + let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + res_ref.fill_uniform(basek, &mut source); + res_test.raw_mut().copy_from_slice(res_ref.raw()); + + for i in 0..cols { + module_ref.vec_znx_sub_ab_inplace(&mut res_ref, i, &a, i); + module_test.vec_znx_sub_ab_inplace(&mut res_test, i, &a, i); + } + + assert_eq!(a.digest_u64(), a_digest); + assert_eq!(res_ref, res_test); + } + } +} + +pub fn test_vec_znx_sub_ba_inplace(basek: usize, module_ref: &Module
, module_test: &Module) +where + Module
: VecZnxSubBAInplace, + Module: VecZnxSubBAInplace, +{ + assert_eq!(module_ref.n(), module_test.n()); + let n: usize = module_ref.n(); + + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + + for a_size in [1, 2, 3, 4] { + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + a.fill_uniform(basek, &mut source); + let a_digest: u64 = a.digest_u64(); + + for res_size in [1, 2, 3, 4] { + let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + res_ref.fill_uniform(basek, &mut source); + res_test.raw_mut().copy_from_slice(res_ref.raw()); + + for i in 0..cols { + module_ref.vec_znx_sub_ba_inplace(&mut res_ref, i, &a, i); + module_test.vec_znx_sub_ba_inplace(&mut res_test, i, &a, i); + } + + assert_eq!(a.digest_u64(), a_digest); + assert_eq!(res_ref, res_test); + } + } +} + +pub fn test_vec_znx_switch_ring(basek: usize, module_ref: &Module
, module_test: &Module) +where + Module
: VecZnxSwitchRing, + Module: VecZnxSwitchRing, +{ + assert_eq!(module_ref.n(), module_test.n()); + let n: usize = module_ref.n(); + + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + + for a_size in [1, 2, 3, 4] { + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + + // Fill a with random i64 + a.fill_uniform(basek, &mut source); + let a_digest: u64 = a.digest_u64(); + + for res_size in [1, 2, 3, 4] { + { + let mut res_ref: VecZnx> = VecZnx::alloc(n << 1, cols, res_size); + let mut res_test: VecZnx> = VecZnx::alloc(n << 1, cols, res_size); + + res_ref.fill_uniform(basek, &mut source); + res_test.fill_uniform(basek, &mut source); + + // Normalize on c + for i in 0..cols { + module_ref.vec_znx_switch_ring(&mut res_ref, i, &a, i); + module_test.vec_znx_switch_ring(&mut res_test, i, &a, i); + } + + assert_eq!(a.digest_u64(), a_digest); + assert_eq!(res_ref, res_test); + } + + { + let mut res_ref: VecZnx> = VecZnx::alloc(n >> 1, cols, res_size); + let mut res_test: VecZnx> = VecZnx::alloc(n >> 1, cols, res_size); + + res_ref.fill_uniform(basek, &mut source); + res_test.fill_uniform(basek, &mut source); + + // Normalize on c + for i in 0..cols { + module_ref.vec_znx_switch_ring(&mut res_ref, i, &a, i); + module_test.vec_znx_switch_ring(&mut res_test, i, &a, i); + } + + assert_eq!(a.digest_u64(), a_digest); + assert_eq!(res_ref, res_test); + } + } + } +} diff --git a/poulpy-hal/src/test_suite/vec_znx_big.rs b/poulpy-hal/src/test_suite/vec_znx_big.rs new file mode 100644 index 0000000..d888403 --- /dev/null +++ b/poulpy-hal/src/test_suite/vec_znx_big.rs @@ -0,0 +1,1432 @@ +use rand::RngCore; + +use crate::{ + api::{ + ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxBigAdd, VecZnxBigAddInplace, VecZnxBigAddSmall, VecZnxBigAddSmallInplace, + VecZnxBigAlloc, VecZnxBigAutomorphism, VecZnxBigAutomorphismInplace, VecZnxBigAutomorphismInplaceTmpBytes, + VecZnxBigFromSmall, VecZnxBigNegate, VecZnxBigNegateInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, + VecZnxBigSub, VecZnxBigSubABInplace, VecZnxBigSubBAInplace, VecZnxBigSubSmallA, VecZnxBigSubSmallAInplace, + VecZnxBigSubSmallB, VecZnxBigSubSmallBInplace, + }, + layouts::{Backend, DataViewMut, DigestU64, FillUniform, Module, ScratchOwned, VecZnx, VecZnxBig}, + source::Source, +}; + +pub fn test_vec_znx_big_add(basek: usize, module_ref: &Module
, module_test: &Module) +where + Module
: + VecZnxBigAdd
+ VecZnxBigAlloc
+ VecZnxBigFromSmall
+ VecZnxBigNormalize
+ VecZnxBigNormalizeTmpBytes, + Module: + VecZnxBigAdd + VecZnxBigAlloc + VecZnxBigFromSmall + VecZnxBigNormalize + VecZnxBigNormalizeTmpBytes, + ScratchOwned
: ScratchOwnedAlloc
+ ScratchOwnedBorrow
, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + assert_eq!(module_ref.n(), module_test.n()); + let n: usize = module_ref.n(); + + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + + let mut scratch_ref: ScratchOwned
= ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes()); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes()); + + for a_size in [1, 2, 3, 4] { + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + a.fill_uniform(basek, &mut source); + let a_digest = a.digest_u64(); + + let mut a_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, a_size); + let mut a_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, a_size); + + for j in 0..cols { + module_ref.vec_znx_big_from_small(&mut a_ref, j, &a, j); + module_test.vec_znx_big_from_small(&mut a_test, j, &a, j); + } + + assert_eq!(a.digest_u64(), a_digest); + + let a_ref_digest: u64 = a_ref.digest_u64(); + let a_test_digest: u64 = a_test.digest_u64(); + + for b_size in [1, 2, 3, 4] { + let mut b: VecZnx> = VecZnx::alloc(n, cols, b_size); + b.fill_uniform(basek, &mut source); + let b_digest = b.digest_u64(); + + let mut b_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, b_size); + let mut b_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, b_size); + + for j in 0..cols { + module_ref.vec_znx_big_from_small(&mut b_ref, j, &b, j); + module_test.vec_znx_big_from_small(&mut b_test, j, &b, j); + } + + assert_eq!(b.digest_u64(), b_digest); + + let b_ref_digest: u64 = b_ref.digest_u64(); + let b_test_digest: u64 = b_test.digest_u64(); + + for res_size in [1, 2, 3, 4] { + let mut res_big_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, res_size); + let mut res_big_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, res_size); + + // Set res to garbage + source.fill_bytes(res_big_ref.data_mut()); + source.fill_bytes(res_big_test.data_mut()); + + // Reference + for i in 0..cols { + module_ref.vec_znx_big_add(&mut res_big_ref, i, &a_ref, i, &b_ref, i); + module_test.vec_znx_big_add(&mut res_big_test, i, &a_test, i, &b_test, i); + } + + assert_eq!(a_ref.digest_u64(), a_ref_digest); + assert_eq!(a_test.digest_u64(), a_test_digest); + assert_eq!(b_ref.digest_u64(), b_ref_digest); + assert_eq!(b_test.digest_u64(), b_test_digest); + + let mut res_small_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_small_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + let res_ref_digest: u64 = res_big_ref.digest_u64(); + let res_test_digest: u64 = res_big_test.digest_u64(); + + for j in 0..cols { + module_ref.vec_znx_big_normalize( + basek, + &mut res_small_ref, + j, + &res_big_ref, + j, + scratch_ref.borrow(), + ); + module_test.vec_znx_big_normalize( + basek, + &mut res_small_test, + j, + &res_big_test, + j, + scratch_test.borrow(), + ); + } + + assert_eq!(res_big_ref.digest_u64(), res_ref_digest); + assert_eq!(res_big_test.digest_u64(), res_test_digest); + + assert_eq!(res_small_ref, res_small_test); + } + } + } +} + +pub fn test_vec_znx_big_add_inplace(basek: usize, module_ref: &Module
, module_test: &Module) +where + Module
: VecZnxBigAddInplace
+ + VecZnxBigAlloc
+ + VecZnxBigFromSmall
+ + VecZnxBigNormalize
+ + VecZnxBigNormalizeTmpBytes, + Module: VecZnxBigAddInplace + + VecZnxBigAlloc + + VecZnxBigFromSmall + + VecZnxBigNormalize + + VecZnxBigNormalizeTmpBytes, + ScratchOwned
: ScratchOwnedAlloc
+ ScratchOwnedBorrow
, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + assert_eq!(module_ref.n(), module_test.n()); + let n: usize = module_ref.n(); + + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + + let mut scratch_ref: ScratchOwned
= ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes()); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes()); + + for a_size in [1, 2, 3, 4] { + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + a.fill_uniform(basek, &mut source); + + let mut a_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, a_size); + let mut a_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, a_size); + + for j in 0..cols { + module_ref.vec_znx_big_from_small(&mut a_ref, j, &a, j); + module_test.vec_znx_big_from_small(&mut a_test, j, &a, j); + } + + let a_ref_digest: u64 = a_ref.digest_u64(); + let a_test_digest: u64 = a_test.digest_u64(); + + for res_size in [1, 2, 3, 4] { + let mut res: VecZnx> = VecZnx::alloc(n, cols, res_size); + res.fill_uniform(basek, &mut source); + + let mut res_big_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, res_size); + let mut res_big_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, res_size); + + for j in 0..cols { + module_ref.vec_znx_big_from_small(&mut res_big_ref, j, &res, j); + module_test.vec_znx_big_from_small(&mut res_big_test, j, &res, j); + } + + for i in 0..cols { + module_ref.vec_znx_big_add_inplace(&mut res_big_ref, i, &a_ref, i); + module_test.vec_znx_big_add_inplace(&mut res_big_test, i, &a_test, i); + } + + assert_eq!(a_ref.digest_u64(), a_ref_digest); + assert_eq!(a_test.digest_u64(), a_test_digest); + + let mut res_small_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_small_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + let res_ref_digest: u64 = res_big_ref.digest_u64(); + let res_test_digest: u64 = res_big_test.digest_u64(); + + for j in 0..cols { + module_ref.vec_znx_big_normalize( + basek, + &mut res_small_ref, + j, + &res_big_ref, + j, + scratch_ref.borrow(), + ); + module_test.vec_znx_big_normalize( + basek, + &mut res_small_test, + j, + &res_big_test, + j, + scratch_test.borrow(), + ); + } + + assert_eq!(res_big_ref.digest_u64(), res_ref_digest); + assert_eq!(res_big_test.digest_u64(), res_test_digest); + + assert_eq!(res_small_ref, res_small_test); + } + } +} + +pub fn test_vec_znx_big_add_small(basek: usize, module_ref: &Module
, module_test: &Module) +where + Module
: + VecZnxBigAddSmall
+ VecZnxBigAlloc
+ VecZnxBigFromSmall
+ VecZnxBigNormalize
+ VecZnxBigNormalizeTmpBytes, + Module: + VecZnxBigAddSmall + VecZnxBigAlloc + VecZnxBigFromSmall + VecZnxBigNormalize + VecZnxBigNormalizeTmpBytes, + ScratchOwned
: ScratchOwnedAlloc
+ ScratchOwnedBorrow
, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + assert_eq!(module_ref.n(), module_test.n()); + let n: usize = module_ref.n(); + + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + + let mut scratch_ref: ScratchOwned
= ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes()); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes()); + + for a_size in [1, 2, 3, 4] { + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + a.fill_uniform(basek, &mut source); + + let mut a_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, a_size); + let mut a_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, a_size); + + for j in 0..cols { + module_ref.vec_znx_big_from_small(&mut a_ref, j, &a, j); + module_test.vec_znx_big_from_small(&mut a_test, j, &a, j); + } + + let a_ref_digest: u64 = a_ref.digest_u64(); + let a_test_digest: u64 = a_test.digest_u64(); + + for b_size in [1, 2, 3, 4] { + let mut b: VecZnx> = VecZnx::alloc(n, cols, b_size); + b.fill_uniform(basek, &mut source); + let b_digest: u64 = b.digest_u64(); + + for res_size in [1, 2, 3, 4] { + let mut res_big_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, res_size); + let mut res_big_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, res_size); + + // Set res to garbage + source.fill_bytes(res_big_ref.data_mut()); + source.fill_bytes(res_big_test.data_mut()); + + // Reference + for i in 0..cols { + module_ref.vec_znx_big_add_small(&mut res_big_ref, i, &a_ref, i, &b, i); + module_test.vec_znx_big_add_small(&mut res_big_test, i, &a_test, i, &b, i); + } + + assert_eq!(a_ref.digest_u64(), a_ref_digest); + assert_eq!(a_test.digest_u64(), a_test_digest); + assert_eq!(b.digest_u64(), b_digest); + + let mut res_small_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_small_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + let res_ref_digest: u64 = res_big_ref.digest_u64(); + let res_test_digest: u64 = res_big_test.digest_u64(); + + for j in 0..cols { + module_ref.vec_znx_big_normalize( + basek, + &mut res_small_ref, + j, + &res_big_ref, + j, + scratch_ref.borrow(), + ); + module_test.vec_znx_big_normalize( + basek, + &mut res_small_test, + j, + &res_big_test, + j, + scratch_test.borrow(), + ); + } + + assert_eq!(res_big_ref.digest_u64(), res_ref_digest); + assert_eq!(res_big_test.digest_u64(), res_test_digest); + + assert_eq!(res_small_ref, res_small_test); + } + } + } +} + +pub fn test_vec_znx_big_add_small_inplace( + basek: usize, + module_ref: &Module
, + module_test: &Module, +) where + Module
: VecZnxBigAddSmallInplace
+ + VecZnxBigAlloc
+ + VecZnxBigFromSmall
+ + VecZnxBigNormalize
+ + VecZnxBigNormalizeTmpBytes, + Module: VecZnxBigAddSmallInplace + + VecZnxBigAlloc + + VecZnxBigFromSmall + + VecZnxBigNormalize + + VecZnxBigNormalizeTmpBytes, + ScratchOwned
: ScratchOwnedAlloc
+ ScratchOwnedBorrow
, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + assert_eq!(module_ref.n(), module_test.n()); + let n: usize = module_ref.n(); + + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + + let mut scratch_ref: ScratchOwned
= ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes()); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes()); + + for a_size in [1, 2, 3, 4] { + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + a.fill_uniform(basek, &mut source); + + let a_digest: u64 = a.digest_u64(); + + for res_size in [1, 2, 3, 4] { + let mut res: VecZnx> = VecZnx::alloc(n, cols, res_size); + res.fill_uniform(basek, &mut source); + + let mut res_big_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, res_size); + let mut res_big_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, res_size); + + for j in 0..cols { + module_ref.vec_znx_big_from_small(&mut res_big_ref, j, &res, j); + module_test.vec_znx_big_from_small(&mut res_big_test, j, &res, j); + } + + for i in 0..cols { + module_ref.vec_znx_big_add_small_inplace(&mut res_big_ref, i, &a, i); + module_test.vec_znx_big_add_small_inplace(&mut res_big_test, i, &a, i); + } + + assert_eq!(a.digest_u64(), a_digest); + + let mut res_small_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_small_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + let res_ref_digest: u64 = res_big_ref.digest_u64(); + let res_test_digest: u64 = res_big_test.digest_u64(); + + for j in 0..cols { + module_ref.vec_znx_big_normalize( + basek, + &mut res_small_ref, + j, + &res_big_ref, + j, + scratch_ref.borrow(), + ); + module_test.vec_znx_big_normalize( + basek, + &mut res_small_test, + j, + &res_big_test, + j, + scratch_test.borrow(), + ); + } + + assert_eq!(res_big_ref.digest_u64(), res_ref_digest); + assert_eq!(res_big_test.digest_u64(), res_test_digest); + + assert_eq!(res_small_ref, res_small_test); + } + } +} + +pub fn test_vec_znx_big_automorphism(basek: usize, module_ref: &Module
, module_test: &Module) +where + Module
: VecZnxBigAutomorphism
+ + VecZnxBigAlloc
+ + VecZnxBigFromSmall
+ + VecZnxBigNormalize
+ + VecZnxBigNormalizeTmpBytes, + Module: VecZnxBigAutomorphism + + VecZnxBigAlloc + + VecZnxBigFromSmall + + VecZnxBigNormalize + + VecZnxBigNormalizeTmpBytes, + ScratchOwned
: ScratchOwnedAlloc
+ ScratchOwnedBorrow
, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + assert_eq!(module_ref.n(), module_test.n()); + let n: usize = module_ref.n(); + + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + + let mut scratch_ref: ScratchOwned
= ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes()); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes()); + + for a_size in [1, 2, 3, 4] { + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + a.fill_uniform(basek, &mut source); + + let mut a_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, a_size); + let mut a_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, a_size); + + for j in 0..cols { + module_ref.vec_znx_big_from_small(&mut a_ref, j, &a, j); + module_test.vec_znx_big_from_small(&mut a_test, j, &a, j); + } + + let a_ref_digest: u64 = a_ref.digest_u64(); + let a_test_digest: u64 = a_test.digest_u64(); + + for res_size in [1, 2, 3, 4] { + for p in [-5, 5] { + let mut res_big_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, res_size); + let mut res_big_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, res_size); + + // Set res to garbage + source.fill_bytes(res_big_ref.data_mut()); + source.fill_bytes(res_big_test.data_mut()); + + // Reference + for i in 0..cols { + module_ref.vec_znx_big_automorphism(p, &mut res_big_ref, i, &a_ref, i); + module_test.vec_znx_big_automorphism(p, &mut res_big_test, i, &a_test, i); + } + + assert_eq!(a_ref.digest_u64(), a_ref_digest); + assert_eq!(a_test.digest_u64(), a_test_digest); + + let mut res_small_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_small_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + let res_ref_digest: u64 = res_big_ref.digest_u64(); + let res_test_digest: u64 = res_big_test.digest_u64(); + + for j in 0..cols { + module_ref.vec_znx_big_normalize( + basek, + &mut res_small_ref, + j, + &res_big_ref, + j, + scratch_ref.borrow(), + ); + module_test.vec_znx_big_normalize( + basek, + &mut res_small_test, + j, + &res_big_test, + j, + scratch_test.borrow(), + ); + } + + assert_eq!(res_big_ref.digest_u64(), res_ref_digest); + assert_eq!(res_big_test.digest_u64(), res_test_digest); + + assert_eq!(res_small_ref, res_small_test); + } + } + } +} + +pub fn test_vec_znx_big_automorphism_inplace( + basek: usize, + module_ref: &Module
, + module_test: &Module, +) where + Module
: VecZnxBigAutomorphismInplace
+ + VecZnxBigAlloc
+ + VecZnxBigFromSmall
+ + VecZnxBigAutomorphismInplaceTmpBytes + + VecZnxBigNormalize
+ + VecZnxBigNormalizeTmpBytes, + Module: VecZnxBigAutomorphismInplace + + VecZnxBigAlloc + + VecZnxBigFromSmall + + VecZnxBigAutomorphismInplaceTmpBytes + + VecZnxBigNormalize + + VecZnxBigNormalizeTmpBytes, + ScratchOwned
: ScratchOwnedAlloc
+ ScratchOwnedBorrow
, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + assert_eq!(module_ref.n(), module_test.n()); + let n: usize = module_ref.n(); + + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + + let mut scratch_ref: ScratchOwned
= ScratchOwned::alloc( + module_ref.vec_znx_big_automorphism_inplace_tmp_bytes() | module_ref.vec_znx_big_normalize_tmp_bytes(), + ); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc( + module_test.vec_znx_big_automorphism_inplace_tmp_bytes() | module_test.vec_znx_big_normalize_tmp_bytes(), + ); + + for res_size in [1, 2, 3, 4] { + let mut res: VecZnx> = VecZnx::alloc(n, cols, res_size); + res.fill_uniform(basek, &mut source); + + let mut res_big_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, res_size); + let mut res_big_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, res_size); + + for p in [-5, 5] { + for j in 0..cols { + module_ref.vec_znx_big_from_small(&mut res_big_ref, j, &res, j); + module_test.vec_znx_big_from_small(&mut res_big_test, j, &res, j); + } + + for i in 0..cols { + module_ref.vec_znx_big_automorphism_inplace(p, &mut res_big_ref, i, scratch_ref.borrow()); + module_test.vec_znx_big_automorphism_inplace(p, &mut res_big_test, i, scratch_test.borrow()); + } + + let mut res_small_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_small_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + let res_ref_digest: u64 = res_big_ref.digest_u64(); + let res_test_digest: u64 = res_big_test.digest_u64(); + + for j in 0..cols { + module_ref.vec_znx_big_normalize( + basek, + &mut res_small_ref, + j, + &res_big_ref, + j, + scratch_ref.borrow(), + ); + module_test.vec_znx_big_normalize( + basek, + &mut res_small_test, + j, + &res_big_test, + j, + scratch_test.borrow(), + ); + } + + assert_eq!(res_big_ref.digest_u64(), res_ref_digest); + assert_eq!(res_big_test.digest_u64(), res_test_digest); + + assert_eq!(res_small_ref, res_small_test); + } + } +} + +pub fn test_vec_znx_big_negate(basek: usize, module_ref: &Module
, module_test: &Module) +where + Module
: + VecZnxBigNegate
+ VecZnxBigAlloc
+ VecZnxBigFromSmall
+ VecZnxBigNormalize
+ VecZnxBigNormalizeTmpBytes, + Module: + VecZnxBigNegate + VecZnxBigAlloc + VecZnxBigFromSmall + VecZnxBigNormalize + VecZnxBigNormalizeTmpBytes, + ScratchOwned
: ScratchOwnedAlloc
+ ScratchOwnedBorrow
, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + assert_eq!(module_ref.n(), module_test.n()); + let n: usize = module_ref.n(); + + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + + let mut scratch_ref: ScratchOwned
= ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes()); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes()); + + for a_size in [1, 2, 3, 4] { + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + a.fill_uniform(basek, &mut source); + + let mut a_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, a_size); + let mut a_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, a_size); + + for j in 0..cols { + module_ref.vec_znx_big_from_small(&mut a_ref, j, &a, j); + module_test.vec_znx_big_from_small(&mut a_test, j, &a, j); + } + + let a_ref_digest: u64 = a_ref.digest_u64(); + let a_test_digest: u64 = a_test.digest_u64(); + + for res_size in [1, 2, 3, 4] { + let mut res_big_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, res_size); + let mut res_big_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, res_size); + + // Set res to garbage + source.fill_bytes(res_big_ref.data_mut()); + source.fill_bytes(res_big_test.data_mut()); + + // Reference + for i in 0..cols { + module_ref.vec_znx_big_negate(&mut res_big_ref, i, &a_ref, i); + module_test.vec_znx_big_negate(&mut res_big_test, i, &a_test, i); + } + + assert_eq!(a_ref.digest_u64(), a_ref_digest); + assert_eq!(a_test.digest_u64(), a_test_digest); + + let mut res_small_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_small_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + let res_ref_digest: u64 = res_big_ref.digest_u64(); + let res_test_digest: u64 = res_big_test.digest_u64(); + + for j in 0..cols { + module_ref.vec_znx_big_normalize( + basek, + &mut res_small_ref, + j, + &res_big_ref, + j, + scratch_ref.borrow(), + ); + module_test.vec_znx_big_normalize( + basek, + &mut res_small_test, + j, + &res_big_test, + j, + scratch_test.borrow(), + ); + } + + assert_eq!(res_big_ref.digest_u64(), res_ref_digest); + assert_eq!(res_big_test.digest_u64(), res_test_digest); + + assert_eq!(res_small_ref, res_small_test); + } + } +} + +pub fn test_vec_znx_big_negate_inplace(basek: usize, module_ref: &Module
, module_test: &Module) +where + Module
: VecZnxBigNegateInplace
+ + VecZnxBigAlloc
+ + VecZnxBigFromSmall
+ + VecZnxBigAutomorphismInplaceTmpBytes + + VecZnxBigNormalize
+ + VecZnxBigNormalizeTmpBytes, + Module: VecZnxBigNegateInplace + + VecZnxBigAlloc + + VecZnxBigFromSmall + + VecZnxBigAutomorphismInplaceTmpBytes + + VecZnxBigNormalize + + VecZnxBigNormalizeTmpBytes, + ScratchOwned
: ScratchOwnedAlloc
+ ScratchOwnedBorrow
, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + assert_eq!(module_ref.n(), module_test.n()); + let n: usize = module_ref.n(); + + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + + let mut scratch_ref: ScratchOwned
= ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes()); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes()); + + for res_size in [1, 2, 3, 4] { + let mut res: VecZnx> = VecZnx::alloc(n, cols, res_size); + res.fill_uniform(basek, &mut source); + + let mut res_big_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, res_size); + let mut res_big_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, res_size); + + for j in 0..cols { + module_ref.vec_znx_big_from_small(&mut res_big_ref, j, &res, j); + module_test.vec_znx_big_from_small(&mut res_big_test, j, &res, j); + } + + for i in 0..cols { + module_ref.vec_znx_big_negate_inplace(&mut res_big_ref, i); + module_test.vec_znx_big_negate_inplace(&mut res_big_test, i); + } + + let mut res_small_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_small_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + let res_ref_digest: u64 = res_big_ref.digest_u64(); + let res_test_digest: u64 = res_big_test.digest_u64(); + + for j in 0..cols { + module_ref.vec_znx_big_normalize( + basek, + &mut res_small_ref, + j, + &res_big_ref, + j, + scratch_ref.borrow(), + ); + module_test.vec_znx_big_normalize( + basek, + &mut res_small_test, + j, + &res_big_test, + j, + scratch_test.borrow(), + ); + } + + assert_eq!(res_big_ref.digest_u64(), res_ref_digest); + assert_eq!(res_big_test.digest_u64(), res_test_digest); + + assert_eq!(res_small_ref, res_small_test); + } +} + +pub fn test_vec_znx_big_normalize(basek: usize, module_ref: &Module
, module_test: &Module) +where + Module
: VecZnxBigAlloc
+ + VecZnxBigFromSmall
+ + VecZnxBigAutomorphismInplaceTmpBytes + + VecZnxBigNormalize
+ + VecZnxBigNormalizeTmpBytes, + Module: VecZnxBigAlloc + + VecZnxBigFromSmall + + VecZnxBigAutomorphismInplaceTmpBytes + + VecZnxBigNormalize + + VecZnxBigNormalizeTmpBytes, + ScratchOwned
: ScratchOwnedAlloc
+ ScratchOwnedBorrow
, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + assert_eq!(module_ref.n(), module_test.n()); + let n: usize = module_ref.n(); + + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + + let mut scratch_ref: ScratchOwned
= ScratchOwned::alloc( + module_ref.vec_znx_big_automorphism_inplace_tmp_bytes() | module_ref.vec_znx_big_normalize_tmp_bytes(), + ); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc( + module_test.vec_znx_big_automorphism_inplace_tmp_bytes() | module_test.vec_znx_big_normalize_tmp_bytes(), + ); + + for a_size in [1, 2, 3, 4] { + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + a.fill_uniform(63, &mut source); + + let mut a_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, a_size); + let mut a_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, a_size); + + for j in 0..cols { + module_ref.vec_znx_big_from_small(&mut a_ref, j, &a, j); + module_test.vec_znx_big_from_small(&mut a_test, j, &a, j); + } + + let a_ref_digest: u64 = a_ref.digest_u64(); + let a_test_digest: u64 = a_test.digest_u64(); + + for res_size in [1, 2, 3, 4] { + let mut res_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + // Set d to garbage + source.fill_bytes(res_ref.data_mut()); + source.fill_bytes(res_test.data_mut()); + + // Reference + for j in 0..cols { + module_ref.vec_znx_big_normalize(basek, &mut res_ref, j, &a_ref, j, scratch_ref.borrow()); + module_test.vec_znx_big_normalize(basek, &mut res_test, j, &a_test, j, scratch_test.borrow()); + } + + assert_eq!(a_ref.digest_u64(), a_ref_digest); + assert_eq!(a_test.digest_u64(), a_test_digest); + + assert_eq!(res_ref, res_test); + } + } +} + +pub fn test_vec_znx_big_sub(basek: usize, module_ref: &Module
, module_test: &Module) +where + Module
: + VecZnxBigSub
+ VecZnxBigAlloc
+ VecZnxBigFromSmall
+ VecZnxBigNormalize
+ VecZnxBigNormalizeTmpBytes, + Module: + VecZnxBigSub + VecZnxBigAlloc + VecZnxBigFromSmall + VecZnxBigNormalize + VecZnxBigNormalizeTmpBytes, + ScratchOwned
: ScratchOwnedAlloc
+ ScratchOwnedBorrow
, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + assert_eq!(module_ref.n(), module_test.n()); + let n: usize = module_ref.n(); + + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + + let mut scratch_ref: ScratchOwned
= ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes()); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes()); + + for a_size in [1, 2, 3, 4] { + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + a.fill_uniform(basek, &mut source); + + let mut a_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, a_size); + let mut a_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, a_size); + + for j in 0..cols { + module_ref.vec_znx_big_from_small(&mut a_ref, j, &a, j); + module_test.vec_znx_big_from_small(&mut a_test, j, &a, j); + } + + let a_ref_digest: u64 = a_ref.digest_u64(); + let a_test_digest: u64 = a_test.digest_u64(); + + for b_size in [1, 2, 3, 4] { + let mut b: VecZnx> = VecZnx::alloc(n, cols, b_size); + b.fill_uniform(basek, &mut source); + + let mut b_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, b_size); + let mut b_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, b_size); + + for j in 0..cols { + module_ref.vec_znx_big_from_small(&mut b_ref, j, &b, j); + module_test.vec_znx_big_from_small(&mut b_test, j, &b, j); + } + + let b_ref_digest: u64 = b_ref.digest_u64(); + let b_test_digest: u64 = b_test.digest_u64(); + + for res_size in [1, 2, 3, 4] { + let mut res_big_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, res_size); + let mut res_big_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, res_size); + + // Set res to garbage + source.fill_bytes(res_big_ref.data_mut()); + source.fill_bytes(res_big_test.data_mut()); + + // Reference + for i in 0..cols { + module_ref.vec_znx_big_sub(&mut res_big_ref, i, &a_ref, i, &b_ref, i); + module_test.vec_znx_big_sub(&mut res_big_test, i, &a_test, i, &b_test, i); + } + + assert_eq!(a_ref.digest_u64(), a_ref_digest); + assert_eq!(a_test.digest_u64(), a_test_digest); + assert_eq!(b_ref.digest_u64(), b_ref_digest); + assert_eq!(b_test.digest_u64(), b_test_digest); + + let mut res_small_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_small_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + let res_ref_digest: u64 = res_big_ref.digest_u64(); + let res_test_digest: u64 = res_big_test.digest_u64(); + + for j in 0..cols { + module_ref.vec_znx_big_normalize( + basek, + &mut res_small_ref, + j, + &res_big_ref, + j, + scratch_ref.borrow(), + ); + module_test.vec_znx_big_normalize( + basek, + &mut res_small_test, + j, + &res_big_test, + j, + scratch_test.borrow(), + ); + } + + assert_eq!(res_big_ref.digest_u64(), res_ref_digest); + assert_eq!(res_big_test.digest_u64(), res_test_digest); + + assert_eq!(res_small_ref, res_small_test); + } + } + } +} + +pub fn test_vec_znx_big_sub_ab_inplace(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>)
+where
+    Module<BR>: VecZnxBigSubABInplace<BR>
+        + VecZnxBigAlloc<BR>
+        + VecZnxBigFromSmall<BR>
+        + VecZnxBigNormalize<BR>
+        + VecZnxBigNormalizeTmpBytes,
+    Module<BT>: VecZnxBigSubABInplace<BT>
+        + VecZnxBigAlloc<BT>
+        + VecZnxBigFromSmall<BT>
+        + VecZnxBigNormalize<BT>
+        + VecZnxBigNormalizeTmpBytes,
+    ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
+    ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
+{
+    assert_eq!(module_ref.n(), module_test.n());
+    let n: usize = module_ref.n();
+
+    let mut source: Source = Source::new([0u8; 32]);
+    let cols: usize = 2;
+
+    let mut scratch_ref: ScratchOwned<BR>
= ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes()); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes()); + + for a_size in [1, 2, 3, 4] { + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + a.fill_uniform(basek, &mut source); + + let mut a_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, a_size); + let mut a_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, a_size); + + for j in 0..cols { + module_ref.vec_znx_big_from_small(&mut a_ref, j, &a, j); + module_test.vec_znx_big_from_small(&mut a_test, j, &a, j); + } + + let a_ref_digest: u64 = a_ref.digest_u64(); + let a_test_digest: u64 = a_test.digest_u64(); + + for res_size in [1, 2, 3, 4] { + let mut res: VecZnx> = VecZnx::alloc(n, cols, res_size); + res.fill_uniform(basek, &mut source); + + let mut res_big_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, res_size); + let mut res_big_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, res_size); + + for j in 0..cols { + module_ref.vec_znx_big_from_small(&mut res_big_ref, j, &res, j); + module_test.vec_znx_big_from_small(&mut res_big_test, j, &res, j); + } + + for i in 0..cols { + module_ref.vec_znx_big_sub_ab_inplace(&mut res_big_ref, i, &a_ref, i); + module_test.vec_znx_big_sub_ab_inplace(&mut res_big_test, i, &a_test, i); + } + + assert_eq!(a_ref.digest_u64(), a_ref_digest); + assert_eq!(a_test.digest_u64(), a_test_digest); + + let mut res_small_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_small_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + let res_ref_digest: u64 = res_big_ref.digest_u64(); + let res_test_digest: u64 = res_big_test.digest_u64(); + + for j in 0..cols { + module_ref.vec_znx_big_normalize( + basek, + &mut res_small_ref, + j, + &res_big_ref, + j, + scratch_ref.borrow(), + ); + module_test.vec_znx_big_normalize( + basek, + &mut res_small_test, + j, + &res_big_test, + j, + scratch_test.borrow(), + ); + } + + assert_eq!(res_big_ref.digest_u64(), res_ref_digest); + assert_eq!(res_big_test.digest_u64(), res_test_digest); + + assert_eq!(res_small_ref, res_small_test); + } + } +} + +pub fn test_vec_znx_big_sub_ba_inplace(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>)
+where
+    Module<BR>: VecZnxBigSubBAInplace<BR>
+        + VecZnxBigAlloc<BR>
+        + VecZnxBigFromSmall<BR>
+        + VecZnxBigNormalize<BR>
+        + VecZnxBigNormalizeTmpBytes,
+    Module<BT>: VecZnxBigSubBAInplace<BT>
+        + VecZnxBigAlloc<BT>
+        + VecZnxBigFromSmall<BT>
+        + VecZnxBigNormalize<BT>
+        + VecZnxBigNormalizeTmpBytes,
+    ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
+    ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
+{
+    assert_eq!(module_ref.n(), module_test.n());
+    let n: usize = module_ref.n();
+
+    let mut source: Source = Source::new([0u8; 32]);
+    let cols: usize = 2;
+
+    let mut scratch_ref: ScratchOwned<BR>
= ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes()); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes()); + + for a_size in [1, 2, 3, 4] { + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + a.fill_uniform(basek, &mut source); + + let mut a_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, a_size); + let mut a_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, a_size); + + for j in 0..cols { + module_ref.vec_znx_big_from_small(&mut a_ref, j, &a, j); + module_test.vec_znx_big_from_small(&mut a_test, j, &a, j); + } + + let a_ref_digest: u64 = a_ref.digest_u64(); + let a_test_digest: u64 = a_test.digest_u64(); + + for res_size in [1, 2, 3, 4] { + let mut res: VecZnx> = VecZnx::alloc(n, cols, res_size); + res.fill_uniform(basek, &mut source); + + let mut res_big_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, res_size); + let mut res_big_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, res_size); + + for j in 0..cols { + module_ref.vec_znx_big_from_small(&mut res_big_ref, j, &res, j); + module_test.vec_znx_big_from_small(&mut res_big_test, j, &res, j); + } + + for i in 0..cols { + module_ref.vec_znx_big_sub_ba_inplace(&mut res_big_ref, i, &a_ref, i); + module_test.vec_znx_big_sub_ba_inplace(&mut res_big_test, i, &a_test, i); + } + + assert_eq!(a_ref.digest_u64(), a_ref_digest); + assert_eq!(a_test.digest_u64(), a_test_digest); + + let mut res_small_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_small_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + let res_ref_digest: u64 = res_big_ref.digest_u64(); + let res_test_digest: u64 = res_big_test.digest_u64(); + + for j in 0..cols { + module_ref.vec_znx_big_normalize( + basek, + &mut res_small_ref, + j, + &res_big_ref, + j, + scratch_ref.borrow(), + ); + module_test.vec_znx_big_normalize( + basek, + &mut res_small_test, + j, + &res_big_test, + j, + scratch_test.borrow(), + ); + } + + assert_eq!(res_big_ref.digest_u64(), res_ref_digest); + assert_eq!(res_big_test.digest_u64(), res_test_digest); + + assert_eq!(res_small_ref, res_small_test); + } + } +} + +pub fn test_vec_znx_big_sub_small_a(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>)
+where
+    Module<BR>: VecZnxBigSubSmallA<BR>
+        + VecZnxBigAlloc<BR>
+        + VecZnxBigFromSmall<BR>
+        + VecZnxBigNormalize<BR>
+        + VecZnxBigNormalizeTmpBytes,
+    Module<BT>: VecZnxBigSubSmallA<BT>
+        + VecZnxBigAlloc<BT>
+        + VecZnxBigFromSmall<BT>
+        + VecZnxBigNormalize<BT>
+        + VecZnxBigNormalizeTmpBytes,
+    ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
+    ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
+{
+    assert_eq!(module_ref.n(), module_test.n());
+    let n: usize = module_ref.n();
+
+    let mut source: Source = Source::new([0u8; 32]);
+    let cols: usize = 2;
+
+    let mut scratch_ref: ScratchOwned<BR>
= ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes()); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes()); + + for a_size in [1, 2, 3, 4] { + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + a.fill_uniform(basek, &mut source); + + let mut a_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, a_size); + let mut a_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, a_size); + + for j in 0..cols { + module_ref.vec_znx_big_from_small(&mut a_ref, j, &a, j); + module_test.vec_znx_big_from_small(&mut a_test, j, &a, j); + } + + let a_ref_digest: u64 = a_ref.digest_u64(); + let a_test_digest: u64 = a_test.digest_u64(); + + for b_size in [1, 2, 3, 4] { + let mut b: VecZnx> = VecZnx::alloc(n, cols, b_size); + b.fill_uniform(basek, &mut source); + let b_digest: u64 = b.digest_u64(); + + for res_size in [1, 2, 3, 4] { + let mut res_big_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, res_size); + let mut res_big_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, res_size); + + // Set res to garbage + source.fill_bytes(res_big_ref.data_mut()); + source.fill_bytes(res_big_test.data_mut()); + + // Reference + for i in 0..cols { + module_ref.vec_znx_big_sub_small_a(&mut res_big_ref, i, &b, i, &a_ref, i); + module_test.vec_znx_big_sub_small_a(&mut res_big_test, i, &b, i, &a_test, i); + } + + assert_eq!(a_ref.digest_u64(), a_ref_digest); + assert_eq!(a_test.digest_u64(), a_test_digest); + assert_eq!(b.digest_u64(), b_digest); + + let mut res_small_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_small_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + let res_ref_digest: u64 = res_big_ref.digest_u64(); + let res_test_digest: u64 = res_big_test.digest_u64(); + + for j in 0..cols { + module_ref.vec_znx_big_normalize( + basek, + &mut res_small_ref, + j, + &res_big_ref, + j, + scratch_ref.borrow(), + ); + module_test.vec_znx_big_normalize( + basek, + &mut res_small_test, + j, + &res_big_test, + j, + scratch_test.borrow(), + ); + } + + assert_eq!(res_big_ref.digest_u64(), res_ref_digest); + assert_eq!(res_big_test.digest_u64(), res_test_digest); + + assert_eq!(res_small_ref, res_small_test); + } + } + } +} + +pub fn test_vec_znx_big_sub_small_b(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>)
+where
+    Module<BR>: VecZnxBigSubSmallB<BR>
+        + VecZnxBigAlloc<BR>
+        + VecZnxBigFromSmall<BR>
+        + VecZnxBigNormalize<BR>
+        + VecZnxBigNormalizeTmpBytes,
+    Module<BT>: VecZnxBigSubSmallB<BT>
+        + VecZnxBigAlloc<BT>
+        + VecZnxBigFromSmall<BT>
+        + VecZnxBigNormalize<BT>
+        + VecZnxBigNormalizeTmpBytes,
+    ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
+    ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
+{
+    assert_eq!(module_ref.n(), module_test.n());
+    let n: usize = module_ref.n();
+
+    let mut source: Source = Source::new([0u8; 32]);
+    let cols: usize = 2;
+
+    let mut scratch_ref: ScratchOwned<BR>
= ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes()); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes()); + + for a_size in [1, 2, 3, 4] { + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + a.fill_uniform(basek, &mut source); + + let mut a_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, a_size); + let mut a_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, a_size); + + for j in 0..cols { + module_ref.vec_znx_big_from_small(&mut a_ref, j, &a, j); + module_test.vec_znx_big_from_small(&mut a_test, j, &a, j); + } + + let a_ref_digest: u64 = a_ref.digest_u64(); + let a_test_digest: u64 = a_test.digest_u64(); + + for b_size in [1, 2, 3, 4] { + let mut b: VecZnx> = VecZnx::alloc(n, cols, b_size); + b.fill_uniform(basek, &mut source); + let b_digest: u64 = b.digest_u64(); + + for res_size in [1, 2, 3, 4] { + let mut res_big_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, res_size); + let mut res_big_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, res_size); + + // Set res to garbage + source.fill_bytes(res_big_ref.data_mut()); + source.fill_bytes(res_big_test.data_mut()); + + // Reference + for i in 0..cols { + module_ref.vec_znx_big_sub_small_b(&mut res_big_ref, i, &a_ref, i, &b, i); + module_test.vec_znx_big_sub_small_b(&mut res_big_test, i, &a_test, i, &b, i); + } + + assert_eq!(a_ref.digest_u64(), a_ref_digest); + assert_eq!(a_test.digest_u64(), a_test_digest); + assert_eq!(b.digest_u64(), b_digest); + + let mut res_small_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_small_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + let res_ref_digest: u64 = res_big_ref.digest_u64(); + let res_test_digest: u64 = res_big_test.digest_u64(); + + for j in 0..cols { + module_ref.vec_znx_big_normalize( + basek, + &mut res_small_ref, + j, + &res_big_ref, + j, + scratch_ref.borrow(), + ); + module_test.vec_znx_big_normalize( + basek, + &mut res_small_test, + j, + &res_big_test, + j, + scratch_test.borrow(), + ); + } + + assert_eq!(res_big_ref.digest_u64(), res_ref_digest); + assert_eq!(res_big_test.digest_u64(), res_test_digest); + + assert_eq!(res_small_ref, res_small_test); + } + } + } +} + +pub fn test_vec_znx_big_sub_small_a_inplace( + basek: usize, + module_ref: &Module
<BR>,
+    module_test: &Module<BT>,
+) where
+    Module<BR>: VecZnxBigSubSmallAInplace<BR>
+        + VecZnxBigAlloc<BR>
+        + VecZnxBigFromSmall<BR>
+        + VecZnxBigNormalize<BR>
+        + VecZnxBigNormalizeTmpBytes,
+    Module<BT>: VecZnxBigSubSmallAInplace<BT>
+        + VecZnxBigAlloc<BT>
+        + VecZnxBigFromSmall<BT>
+        + VecZnxBigNormalize<BT>
+        + VecZnxBigNormalizeTmpBytes,
+    ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
+    ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
+{
+    assert_eq!(module_ref.n(), module_test.n());
+    let n: usize = module_ref.n();
+
+    let mut source: Source = Source::new([0u8; 32]);
+    let cols: usize = 2;
+
+    let mut scratch_ref: ScratchOwned<BR>
= ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes()); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes()); + + for a_size in [1, 2, 3, 4] { + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + a.fill_uniform(basek, &mut source); + + let a_digest: u64 = a.digest_u64(); + + for res_size in [1, 2, 3, 4] { + let mut res: VecZnx> = VecZnx::alloc(n, cols, res_size); + res.fill_uniform(basek, &mut source); + + let mut res_big_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, res_size); + let mut res_big_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, res_size); + + for j in 0..cols { + module_ref.vec_znx_big_from_small(&mut res_big_ref, j, &res, j); + module_test.vec_znx_big_from_small(&mut res_big_test, j, &res, j); + } + + for i in 0..cols { + module_ref.vec_znx_big_sub_small_a_inplace(&mut res_big_ref, i, &a, i); + module_test.vec_znx_big_sub_small_a_inplace(&mut res_big_test, i, &a, i); + } + + assert_eq!(a.digest_u64(), a_digest); + + let mut res_small_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_small_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + let res_ref_digest: u64 = res_big_ref.digest_u64(); + let res_test_digest: u64 = res_big_test.digest_u64(); + + for j in 0..cols { + module_ref.vec_znx_big_normalize( + basek, + &mut res_small_ref, + j, + &res_big_ref, + j, + scratch_ref.borrow(), + ); + module_test.vec_znx_big_normalize( + basek, + &mut res_small_test, + j, + &res_big_test, + j, + scratch_test.borrow(), + ); + } + + assert_eq!(res_big_ref.digest_u64(), res_ref_digest); + assert_eq!(res_big_test.digest_u64(), res_test_digest); + + assert_eq!(res_small_ref, res_small_test); + } + } +} + +pub fn test_vec_znx_big_sub_small_b_inplace( + basek: usize, + module_ref: &Module
<BR>,
+    module_test: &Module<BT>,
+) where
+    Module<BR>: VecZnxBigSubSmallBInplace<BR>
+        + VecZnxBigAlloc<BR>
+        + VecZnxBigFromSmall<BR>
+        + VecZnxBigNormalize<BR>
+        + VecZnxBigNormalizeTmpBytes,
+    Module<BT>: VecZnxBigSubSmallBInplace<BT>
+        + VecZnxBigAlloc<BT>
+        + VecZnxBigFromSmall<BT>
+        + VecZnxBigNormalize<BT>
+        + VecZnxBigNormalizeTmpBytes,
+    ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
+    ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
+{
+    assert_eq!(module_ref.n(), module_test.n());
+    let n: usize = module_ref.n();
+
+    let mut source: Source = Source::new([0u8; 32]);
+    let cols: usize = 2;
+
+    let mut scratch_ref: ScratchOwned<BR>
= ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes()); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes()); + + for a_size in [1, 2, 3, 4] { + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + a.fill_uniform(basek, &mut source); + + let a_digest: u64 = a.digest_u64(); + + for res_size in [1, 2, 3, 4] { + let mut res: VecZnx> = VecZnx::alloc(n, cols, res_size); + res.fill_uniform(basek, &mut source); + + let mut res_big_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, res_size); + let mut res_big_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, res_size); + + for j in 0..cols { + module_ref.vec_znx_big_from_small(&mut res_big_ref, j, &res, j); + module_test.vec_znx_big_from_small(&mut res_big_test, j, &res, j); + } + + for i in 0..cols { + module_ref.vec_znx_big_sub_small_b_inplace(&mut res_big_ref, i, &a, i); + module_test.vec_znx_big_sub_small_b_inplace(&mut res_big_test, i, &a, i); + } + + assert_eq!(a.digest_u64(), a_digest); + + let mut res_small_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_small_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + let res_ref_digest: u64 = res_big_ref.digest_u64(); + let res_test_digest: u64 = res_big_test.digest_u64(); + + for j in 0..cols { + module_ref.vec_znx_big_normalize( + basek, + &mut res_small_ref, + j, + &res_big_ref, + j, + scratch_ref.borrow(), + ); + module_test.vec_znx_big_normalize( + basek, + &mut res_small_test, + j, + &res_big_test, + j, + scratch_test.borrow(), + ); + } + + assert_eq!(res_big_ref.digest_u64(), res_ref_digest); + assert_eq!(res_big_test.digest_u64(), res_test_digest); + + assert_eq!(res_small_ref, res_small_test); + } + } +} diff --git a/poulpy-hal/src/test_suite/vec_znx_dft.rs b/poulpy-hal/src/test_suite/vec_znx_dft.rs new file mode 100644 index 0000000..674f1d9 --- /dev/null +++ b/poulpy-hal/src/test_suite/vec_znx_dft.rs @@ -0,0 +1,930 @@ +use rand::RngCore; + +use crate::{ + api::{ + ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxBigAlloc, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAdd, + VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftApply, VecZnxDftCopy, VecZnxDftSub, VecZnxDftSubABInplace, + VecZnxDftSubBAInplace, VecZnxIdftApply, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxIdftApplyTmpBytes, + }, + layouts::{Backend, DataViewMut, DigestU64, FillUniform, Module, ScratchOwned, VecZnx, VecZnxBig, VecZnxDft}, + source::Source, +}; + +pub fn test_vec_znx_dft_add(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>)
+where
+    Module<BR>: VecZnxDftAdd<BR>
+        + VecZnxDftAlloc<BR>
+        + VecZnxDftApply<BR>
+        + VecZnxIdftApplyConsume<BR>
+        + VecZnxBigNormalize<BR>
+        + VecZnxBigNormalizeTmpBytes,
+    Module<BT>: VecZnxDftAdd<BT>
+        + VecZnxDftAlloc<BT>
+        + VecZnxDftApply<BT>
+        + VecZnxIdftApplyConsume<BT>
+        + VecZnxBigNormalize<BT>
+        + VecZnxBigNormalizeTmpBytes,
+    ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
+    ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
+{
+    assert_eq!(module_ref.n(), module_test.n());
+    let n: usize = module_ref.n();
+
+    let mut source: Source = Source::new([0u8; 32]);
+    let cols: usize = 2;
+
+    let mut scratch_ref: ScratchOwned<BR>
= ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes()); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes()); + + for a_size in [1, 2, 3, 4] { + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + a.fill_uniform(basek, &mut source); + let a_digest = a.digest_u64(); + + let mut a_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols, a_size); + let mut a_dft_test: VecZnxDft, BT> = module_test.vec_znx_dft_alloc(cols, a_size); + + for j in 0..cols { + module_ref.vec_znx_dft_apply(1, 0, &mut a_dft_ref, j, &a, j); + module_test.vec_znx_dft_apply(1, 0, &mut a_dft_test, j, &a, j); + } + + assert_eq!(a.digest_u64(), a_digest); + + let a_dft_ref_digest: u64 = a_dft_ref.digest_u64(); + let a_dft_test_digest: u64 = a_dft_test.digest_u64(); + + for b_size in [1, 2, 3, 4] { + let mut b: VecZnx> = VecZnx::alloc(n, cols, b_size); + b.fill_uniform(basek, &mut source); + let b_digest: u64 = b.digest_u64(); + + let mut b_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols, b_size); + let mut b_dft_test: VecZnxDft, BT> = module_test.vec_znx_dft_alloc(cols, b_size); + + for j in 0..cols { + module_ref.vec_znx_dft_apply(1, 0, &mut b_dft_ref, j, &b, j); + module_test.vec_znx_dft_apply(1, 0, &mut b_dft_test, j, &b, j); + } + + assert_eq!(b.digest_u64(), b_digest); + + let b_dft_ref_digest: u64 = b_dft_ref.digest_u64(); + let b_dft_test_digest: u64 = b_dft_test.digest_u64(); + + for res_size in [1, 2, 3, 4] { + let mut res_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols, res_size); + let mut res_dft_test: VecZnxDft, BT> = module_test.vec_znx_dft_alloc(cols, res_size); + + // Set d to garbage + source.fill_bytes(res_dft_ref.data_mut()); + source.fill_bytes(res_dft_test.data_mut()); + + // Reference + for i in 0..cols { + module_ref.vec_znx_dft_add(&mut res_dft_ref, i, &a_dft_ref, i, &b_dft_ref, i); + module_test.vec_znx_dft_add(&mut res_dft_test, i, &a_dft_test, i, &b_dft_test, i); + } + + assert_eq!(a_dft_ref.digest_u64(), a_dft_ref_digest); + assert_eq!(a_dft_test.digest_u64(), a_dft_test_digest); + assert_eq!(b_dft_ref.digest_u64(), b_dft_ref_digest); + assert_eq!(b_dft_test.digest_u64(), b_dft_test_digest); + + let res_big_ref: VecZnxBig, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref); + let res_big_test: VecZnxBig, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test); + + let mut res_small_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_small_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + let res_ref_digest: u64 = res_big_ref.digest_u64(); + let res_test_digest: u64 = res_big_test.digest_u64(); + + for j in 0..cols { + module_ref.vec_znx_big_normalize( + basek, + &mut res_small_ref, + j, + &res_big_ref, + j, + scratch_ref.borrow(), + ); + module_test.vec_znx_big_normalize( + basek, + &mut res_small_test, + j, + &res_big_test, + j, + scratch_test.borrow(), + ); + } + + assert_eq!(res_big_ref.digest_u64(), res_ref_digest); + assert_eq!(res_big_test.digest_u64(), res_test_digest); + + assert_eq!(res_small_ref, res_small_test); + } + } + } +} + +pub fn test_vec_znx_dft_add_inplace(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>)
+where
+    Module<BR>: VecZnxDftAddInplace<BR>
+        + VecZnxDftAlloc<BR>
+        + VecZnxDftApply<BR>
+        + VecZnxIdftApplyConsume<BR>
+        + VecZnxBigNormalize<BR>
+        + VecZnxBigNormalizeTmpBytes,
+    Module<BT>: VecZnxDftAddInplace<BT>
+        + VecZnxDftAlloc<BT>
+        + VecZnxDftApply<BT>
+        + VecZnxIdftApplyConsume<BT>
+        + VecZnxBigNormalize<BT>
+        + VecZnxBigNormalizeTmpBytes,
+    ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
+    ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
+{
+    assert_eq!(module_ref.n(), module_test.n());
+    let n: usize = module_ref.n();
+
+    let mut source: Source = Source::new([0u8; 32]);
+    let cols: usize = 2;
+
+    let mut scratch_ref: ScratchOwned<BR>
= ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes()); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes()); + for a_size in [1, 2, 3, 4] { + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + a.fill_uniform(basek, &mut source); + let a_digest = a.digest_u64(); + + let mut a_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols, a_size); + let mut a_dft_test: VecZnxDft, BT> = module_test.vec_znx_dft_alloc(cols, a_size); + + for j in 0..cols { + module_ref.vec_znx_dft_apply(1, 0, &mut a_dft_ref, j, &a, j); + module_test.vec_znx_dft_apply(1, 0, &mut a_dft_test, j, &a, j); + } + + assert_eq!(a.digest_u64(), a_digest); + + let a_dft_ref_digest: u64 = a_dft_ref.digest_u64(); + let a_dft_test_digest: u64 = a_dft_test.digest_u64(); + + for res_size in [1, 2, 3, 4] { + let mut res: VecZnx> = VecZnx::alloc(n, cols, a_size); + res.fill_uniform(basek, &mut source); + let res_digest: u64 = res.digest_u64(); + + let mut res_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols, res_size); + let mut res_dft_test: VecZnxDft, BT> = module_test.vec_znx_dft_alloc(cols, res_size); + + for j in 0..cols { + module_ref.vec_znx_dft_apply(1, 0, &mut res_dft_ref, j, &res, j); + module_test.vec_znx_dft_apply(1, 0, &mut res_dft_test, j, &res, j); + } + + assert_eq!(res.digest_u64(), res_digest); + + // Reference + for i in 0..cols { + module_ref.vec_znx_dft_add_inplace(&mut res_dft_ref, i, &a_dft_ref, i); + module_test.vec_znx_dft_add_inplace(&mut res_dft_test, i, &a_dft_test, i); + } + + assert_eq!(a_dft_ref.digest_u64(), a_dft_ref_digest); + assert_eq!(a_dft_test.digest_u64(), a_dft_test_digest); + + let res_big_ref: VecZnxBig, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref); + let res_big_test: VecZnxBig, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test); + + let mut res_small_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_small_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + let res_ref_digest: u64 = res_big_ref.digest_u64(); + let res_test_digest: u64 = res_big_test.digest_u64(); + + for j in 0..cols { + module_ref.vec_znx_big_normalize( + basek, + &mut res_small_ref, + j, + &res_big_ref, + j, + scratch_ref.borrow(), + ); + module_test.vec_znx_big_normalize( + basek, + &mut res_small_test, + j, + &res_big_test, + j, + scratch_test.borrow(), + ); + } + + assert_eq!(res_big_ref.digest_u64(), res_ref_digest); + assert_eq!(res_big_test.digest_u64(), res_test_digest); + + assert_eq!(res_small_ref, res_small_test); + } + } +} + +pub fn test_vec_znx_copy(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>)
+where
+    Module<BR>: VecZnxDftCopy<BR>
+        + VecZnxDftAlloc<BR>
+        + VecZnxDftApply<BR>
+        + VecZnxIdftApplyConsume<BR>
+        + VecZnxBigNormalize<BR>
+        + VecZnxBigNormalizeTmpBytes,
+    Module<BT>: VecZnxDftCopy<BT>
+        + VecZnxDftAlloc<BT>
+        + VecZnxDftApply<BT>
+        + VecZnxIdftApplyConsume<BT>
+        + VecZnxBigNormalize<BT>
+        + VecZnxBigNormalizeTmpBytes,
+    ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
+    ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
+{
+    assert_eq!(module_ref.n(), module_test.n());
+    let n: usize = module_ref.n();
+
+    let mut source: Source = Source::new([0u8; 32]);
+    let cols: usize = 2;
+
+    let mut scratch_ref: ScratchOwned<BR>
= ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes()); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes()); + + for a_size in [1, 2, 6, 11] { + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + a.fill_uniform(basek, &mut source); + let a_digest = a.digest_u64(); + + let mut a_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols, a_size); + let mut a_dft_test: VecZnxDft, BT> = module_test.vec_znx_dft_alloc(cols, a_size); + + for j in 0..cols { + module_ref.vec_znx_dft_apply(1, 0, &mut a_dft_ref, j, &a, j); + module_test.vec_znx_dft_apply(1, 0, &mut a_dft_test, j, &a, j); + } + + assert_eq!(a.digest_u64(), a_digest); + + let a_dft_ref_digest: u64 = a_dft_ref.digest_u64(); + let a_dft_test_digest: u64 = a_dft_test.digest_u64(); + + for res_size in [1, 2, 6, 11] { + for params in [[1, 0], [1, 1], [1, 2], [2, 2]] { + let steps: usize = params[0]; + let offset: usize = params[1]; + + let mut res_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols, res_size); + let mut res_dft_test: VecZnxDft, BT> = module_test.vec_znx_dft_alloc(cols, res_size); + + // Set d to garbage + source.fill_bytes(res_dft_ref.data_mut()); + source.fill_bytes(res_dft_test.data_mut()); + + // Reference + for i in 0..cols { + module_ref.vec_znx_dft_copy(steps, offset, &mut res_dft_ref, i, &a_dft_ref, i); + module_test.vec_znx_dft_copy(steps, offset, &mut res_dft_test, i, &a_dft_test, i); + } + + assert_eq!(a_dft_ref.digest_u64(), a_dft_ref_digest); + assert_eq!(a_dft_test.digest_u64(), a_dft_test_digest); + + let res_big_ref: VecZnxBig, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref); + let res_big_test: VecZnxBig, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test); + + let mut res_small_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_small_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + let res_ref_digest: u64 = res_big_ref.digest_u64(); + let res_test_digest: u64 = res_big_test.digest_u64(); + + for j in 0..cols { + module_ref.vec_znx_big_normalize( + basek, + &mut res_small_ref, + j, + &res_big_ref, + j, + scratch_ref.borrow(), + ); + module_test.vec_znx_big_normalize( + basek, + &mut res_small_test, + j, + &res_big_test, + j, + scratch_test.borrow(), + ); + } + + assert_eq!(res_big_ref.digest_u64(), res_ref_digest); + assert_eq!(res_big_test.digest_u64(), res_test_digest); + + assert_eq!(res_small_ref, res_small_test); + } + } + } +} + +pub fn test_vec_znx_idft_apply(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>)
+where
+    Module<BR>: VecZnxDftApply<BR>
+        + VecZnxDftAlloc<BR>
+        + VecZnxBigAlloc<BR>
+        + VecZnxBigNormalize<BR>
+        + VecZnxBigNormalizeTmpBytes
+        + VecZnxIdftApply<BR>,
+    Module<BT>: VecZnxDftApply<BT>
+        + VecZnxDftAlloc<BT>
+        + VecZnxBigAlloc<BT>
+        + VecZnxBigNormalize<BT>
+        + VecZnxBigNormalizeTmpBytes
+        + VecZnxIdftApply<BT>,
+    ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
+    ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
+{
+    assert_eq!(module_ref.n(), module_test.n());
+    let n: usize = module_ref.n();
+
+    let mut source: Source = Source::new([0u8; 32]);
+    let cols: usize = 2;
+
+    let mut scratch_ref: ScratchOwned<BR>
= ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes()); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes()); + + for a_size in [1, 2, 3, 4] { + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + a.fill_uniform(basek, &mut source); + let a_digest: u64 = a.digest_u64(); + + for res_size in [1, 2, 3, 4] { + for params in [[1, 0], [1, 1], [1, 2], [2, 2]] { + let steps: usize = params[0]; + let offset: usize = params[1]; + + let mut res_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols, a_size); + let mut res_dft_test: VecZnxDft, BT> = module_test.vec_znx_dft_alloc(cols, a_size); + + for j in 0..cols { + module_ref.vec_znx_dft_apply(steps, offset, &mut res_dft_ref, j, &a, j); + module_test.vec_znx_dft_apply(steps, offset, &mut res_dft_test, j, &a, j); + } + + assert_eq!(a.digest_u64(), a_digest); + let res_dft_ref_digest: u64 = res_dft_ref.digest_u64(); + let rest_dft_test_digest: u64 = res_dft_test.digest_u64(); + + let mut res_big_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, res_size); + let mut res_big_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, res_size); + + for j in 0..cols { + module_ref.vec_znx_idft_apply(&mut res_big_ref, j, &res_dft_ref, j, scratch_ref.borrow()); + module_test.vec_znx_idft_apply( + &mut res_big_test, + j, + &res_dft_test, + j, + scratch_test.borrow(), + ); + } + + assert_eq!(res_dft_ref.digest_u64(), res_dft_ref_digest); + assert_eq!(res_dft_test.digest_u64(), rest_dft_test_digest); + + let mut res_small_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_small_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + let res_ref_digest: u64 = res_big_ref.digest_u64(); + let res_test_digest: u64 = res_big_test.digest_u64(); + + for j in 0..cols { + module_ref.vec_znx_big_normalize( + basek, + &mut res_small_ref, + j, + &res_big_ref, + j, + scratch_ref.borrow(), + ); + module_test.vec_znx_big_normalize( + basek, + &mut res_small_test, + j, + &res_big_test, + j, + scratch_test.borrow(), + ); + } + + assert_eq!(res_big_ref.digest_u64(), res_ref_digest); + assert_eq!(res_big_test.digest_u64(), res_test_digest); + + assert_eq!(res_small_ref, res_small_test); + } + } + } +} + +pub fn test_vec_znx_idft_apply_tmpa(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>)
+where
+    Module<BR>: VecZnxDftApply<BR>
+        + VecZnxDftAlloc<BR>
+        + VecZnxBigAlloc<BR>
+        + VecZnxBigNormalize<BR>
+        + VecZnxBigNormalizeTmpBytes
+        + VecZnxIdftApplyTmpA<BR>,
+    Module<BT>: VecZnxDftApply<BT>
+        + VecZnxDftAlloc<BT>
+        + VecZnxBigAlloc<BT>
+        + VecZnxBigNormalize<BT>
+        + VecZnxBigNormalizeTmpBytes
+        + VecZnxIdftApplyTmpA<BT>,
+    ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
+    ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
+{
+    assert_eq!(module_ref.n(), module_test.n());
+    let n: usize = module_ref.n();
+
+    let mut source: Source = Source::new([0u8; 32]);
+    let cols: usize = 2;
+
+    let mut scratch_ref: ScratchOwned<BR>
= ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes()); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes()); + + for a_size in [1, 2, 3, 4] { + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + a.fill_uniform(basek, &mut source); + let a_digest: u64 = a.digest_u64(); + + for res_size in [1, 2, 3, 4] { + for params in [[1, 0], [1, 1], [1, 2], [2, 2]] { + let steps: usize = params[0]; + let offset: usize = params[1]; + + let mut res_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols, a_size); + let mut res_dft_test: VecZnxDft, BT> = module_test.vec_znx_dft_alloc(cols, a_size); + + for j in 0..cols { + module_ref.vec_znx_dft_apply(steps, offset, &mut res_dft_ref, j, &a, j); + module_test.vec_znx_dft_apply(steps, offset, &mut res_dft_test, j, &a, j); + } + + assert_eq!(a.digest_u64(), a_digest); + + let mut res_big_ref: VecZnxBig, BR> = module_ref.vec_znx_big_alloc(cols, res_size); + let mut res_big_test: VecZnxBig, BT> = module_test.vec_znx_big_alloc(cols, res_size); + + for j in 0..cols { + module_ref.vec_znx_idft_apply_tmpa(&mut res_big_ref, j, &mut res_dft_ref, j); + module_test.vec_znx_idft_apply_tmpa(&mut res_big_test, j, &mut res_dft_test, j); + } + + let mut res_small_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_small_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + let res_ref_digest: u64 = res_big_ref.digest_u64(); + let res_test_digest: u64 = res_big_test.digest_u64(); + + for j in 0..cols { + module_ref.vec_znx_big_normalize( + basek, + &mut res_small_ref, + j, + &res_big_ref, + j, + scratch_ref.borrow(), + ); + module_test.vec_znx_big_normalize( + basek, + &mut res_small_test, + j, + &res_big_test, + j, + scratch_test.borrow(), + ); + } + + assert_eq!(res_big_ref.digest_u64(), res_ref_digest); + assert_eq!(res_big_test.digest_u64(), res_test_digest); + + assert_eq!(res_small_ref, res_small_test); + } + } + } +} + +pub fn test_vec_znx_idft_apply_consume(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>)
+where
+    Module<BR>: VecZnxDftApply<BR>
+        + VecZnxIdftApplyTmpBytes
+        + VecZnxDftAlloc<BR>
+        + VecZnxBigNormalize<BR>
+        + VecZnxBigNormalizeTmpBytes
+        + VecZnxIdftApplyConsume<BR>,
+    Module<BT>: VecZnxDftApply<BT>
+        + VecZnxIdftApplyTmpBytes
+        + VecZnxDftAlloc<BT>
+        + VecZnxBigNormalize<BT>
+        + VecZnxBigNormalizeTmpBytes
+        + VecZnxIdftApplyConsume<BT>,
+    ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
+    ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
+{
+    assert_eq!(module_ref.n(), module_test.n());
+    let n: usize = module_ref.n();
+
+    let mut source: Source = Source::new([0u8; 32]);
+    let cols: usize = 2;
+
+    let mut scratch_ref: ScratchOwned<BR>
= + ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes() | module_ref.vec_znx_idft_apply_tmp_bytes()); + let mut scratch_test: ScratchOwned = + ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes() | module_test.vec_znx_idft_apply_tmp_bytes()); + + for a_size in [1, 2, 3, 4] { + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + a.fill_uniform(basek, &mut source); + let a_digest: u64 = a.digest_u64(); + + for res_size in [1, 2, 3, 4] { + for params in [[1, 0], [1, 1], [1, 2], [2, 2]] { + let steps: usize = params[0]; + let offset: usize = params[1]; + + let mut res_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols, a_size); + let mut res_dft_test: VecZnxDft, BT> = module_test.vec_znx_dft_alloc(cols, a_size); + + for j in 0..cols { + module_ref.vec_znx_dft_apply(steps, offset, &mut res_dft_ref, j, &a, j); + module_test.vec_znx_dft_apply(steps, offset, &mut res_dft_test, j, &a, j); + } + + assert_eq!(a.digest_u64(), a_digest); + + let res_big_ref: VecZnxBig, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref); + let res_big_test: VecZnxBig, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test); + + let mut res_small_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_small_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + let res_ref_digest: u64 = res_big_ref.digest_u64(); + let res_test_digest: u64 = res_big_test.digest_u64(); + + for j in 0..cols { + module_ref.vec_znx_big_normalize( + basek, + &mut res_small_ref, + j, + &res_big_ref, + j, + scratch_ref.borrow(), + ); + module_test.vec_znx_big_normalize( + basek, + &mut res_small_test, + j, + &res_big_test, + j, + scratch_test.borrow(), + ); + } + + assert_eq!(res_big_ref.digest_u64(), res_ref_digest); + assert_eq!(res_big_test.digest_u64(), res_test_digest); + + assert_eq!(res_small_ref, res_small_test); + } + } + } +} + +pub fn test_vec_znx_dft_sub(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>)
+where
+    Module<BR>: VecZnxDftSub<BR>
+        + VecZnxDftAlloc<BR>
+        + VecZnxDftApply<BR>
+        + VecZnxIdftApplyConsume<BR>
+        + VecZnxBigNormalize<BR>
+        + VecZnxBigNormalizeTmpBytes,
+    Module<BT>: VecZnxDftSub<BT>
+        + VecZnxDftAlloc<BT>
+        + VecZnxDftApply<BT>
+        + VecZnxIdftApplyConsume<BT>
+        + VecZnxBigNormalize<BT>
+        + VecZnxBigNormalizeTmpBytes,
+    ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
+    ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
+{
+    assert_eq!(module_ref.n(), module_test.n());
+    let n: usize = module_ref.n();
+
+    let mut source: Source = Source::new([0u8; 32]);
+    let cols: usize = 2;
+
+    let mut scratch_ref: ScratchOwned<BR>
= ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes()); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes()); + + for a_size in [1, 2, 3, 4] { + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + a.fill_uniform(basek, &mut source); + let a_digest = a.digest_u64(); + + let mut a_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols, a_size); + let mut a_dft_test: VecZnxDft, BT> = module_test.vec_znx_dft_alloc(cols, a_size); + + for j in 0..cols { + module_ref.vec_znx_dft_apply(1, 0, &mut a_dft_ref, j, &a, j); + module_test.vec_znx_dft_apply(1, 0, &mut a_dft_test, j, &a, j); + } + + assert_eq!(a.digest_u64(), a_digest); + + let a_dft_ref_digest: u64 = a_dft_ref.digest_u64(); + let a_dft_test_digest: u64 = a_dft_test.digest_u64(); + + for b_size in [1, 2, 3, 4] { + let mut b: VecZnx> = VecZnx::alloc(n, cols, b_size); + b.fill_uniform(basek, &mut source); + let b_digest: u64 = b.digest_u64(); + + let mut b_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols, b_size); + let mut b_dft_test: VecZnxDft, BT> = module_test.vec_znx_dft_alloc(cols, b_size); + + for j in 0..cols { + module_ref.vec_znx_dft_apply(1, 0, &mut b_dft_ref, j, &b, j); + module_test.vec_znx_dft_apply(1, 0, &mut b_dft_test, j, &b, j); + } + + assert_eq!(b.digest_u64(), b_digest); + + let b_dft_ref_digest: u64 = b_dft_ref.digest_u64(); + let b_dft_test_digest: u64 = b_dft_test.digest_u64(); + + for res_size in [1, 2, 3, 4] { + let mut res_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols, res_size); + let mut res_dft_test: VecZnxDft, BT> = module_test.vec_znx_dft_alloc(cols, res_size); + + // Set d to garbage + source.fill_bytes(res_dft_ref.data_mut()); + source.fill_bytes(res_dft_test.data_mut()); + + // Reference + for i in 0..cols { + module_ref.vec_znx_dft_sub(&mut res_dft_ref, i, &a_dft_ref, i, &b_dft_ref, i); + module_test.vec_znx_dft_sub(&mut res_dft_test, i, &a_dft_test, i, &b_dft_test, i); + } + + assert_eq!(a_dft_ref.digest_u64(), a_dft_ref_digest); + assert_eq!(a_dft_test.digest_u64(), a_dft_test_digest); + assert_eq!(b_dft_ref.digest_u64(), b_dft_ref_digest); + assert_eq!(b_dft_test.digest_u64(), b_dft_test_digest); + + let res_big_ref: VecZnxBig, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref); + let res_big_test: VecZnxBig, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test); + + let mut res_small_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_small_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + let res_ref_digest: u64 = res_big_ref.digest_u64(); + let res_test_digest: u64 = res_big_test.digest_u64(); + + for j in 0..cols { + module_ref.vec_znx_big_normalize( + basek, + &mut res_small_ref, + j, + &res_big_ref, + j, + scratch_ref.borrow(), + ); + module_test.vec_znx_big_normalize( + basek, + &mut res_small_test, + j, + &res_big_test, + j, + scratch_test.borrow(), + ); + } + + assert_eq!(res_big_ref.digest_u64(), res_ref_digest); + assert_eq!(res_big_test.digest_u64(), res_test_digest); + + assert_eq!(res_small_ref, res_small_test); + } + } + } +} + +pub fn test_vec_znx_dft_sub_ab_inplace(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>)
+where
+    Module<BR>: VecZnxDftSubABInplace<BR>
+        + VecZnxDftAlloc<BR>
+        + VecZnxDftApply<BR>
+        + VecZnxIdftApplyConsume<BR>
+        + VecZnxBigNormalize<BR>
+        + VecZnxBigNormalizeTmpBytes,
+    Module<BT>: VecZnxDftSubABInplace<BT>
+        + VecZnxDftAlloc<BT>
+        + VecZnxDftApply<BT>
+        + VecZnxIdftApplyConsume<BT>
+        + VecZnxBigNormalize<BT>
+        + VecZnxBigNormalizeTmpBytes,
+    ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
+    ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
+{
+    assert_eq!(module_ref.n(), module_test.n());
+    let n: usize = module_ref.n();
+
+    let mut source: Source = Source::new([0u8; 32]);
+    let cols: usize = 2;
+
+    let mut scratch_ref: ScratchOwned<BR>
= ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes()); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes()); + for a_size in [1, 2, 3, 4] { + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + a.fill_uniform(basek, &mut source); + let a_digest = a.digest_u64(); + + let mut a_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols, a_size); + let mut a_dft_test: VecZnxDft, BT> = module_test.vec_znx_dft_alloc(cols, a_size); + + for j in 0..cols { + module_ref.vec_znx_dft_apply(1, 0, &mut a_dft_ref, j, &a, j); + module_test.vec_znx_dft_apply(1, 0, &mut a_dft_test, j, &a, j); + } + + assert_eq!(a.digest_u64(), a_digest); + + let a_dft_ref_digest: u64 = a_dft_ref.digest_u64(); + let a_dft_test_digest: u64 = a_dft_test.digest_u64(); + + for res_size in [1, 2, 3, 4] { + let mut res: VecZnx> = VecZnx::alloc(n, cols, a_size); + res.fill_uniform(basek, &mut source); + let res_digest: u64 = res.digest_u64(); + + let mut res_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols, res_size); + let mut res_dft_test: VecZnxDft, BT> = module_test.vec_znx_dft_alloc(cols, res_size); + + for j in 0..cols { + module_ref.vec_znx_dft_apply(1, 0, &mut res_dft_ref, j, &res, j); + module_test.vec_znx_dft_apply(1, 0, &mut res_dft_test, j, &res, j); + } + + assert_eq!(res.digest_u64(), res_digest); + + // Reference + for i in 0..cols { + module_ref.vec_znx_dft_sub_ab_inplace(&mut res_dft_ref, i, &a_dft_ref, i); + module_test.vec_znx_dft_sub_ab_inplace(&mut res_dft_test, i, &a_dft_test, i); + } + + assert_eq!(a_dft_ref.digest_u64(), a_dft_ref_digest); + assert_eq!(a_dft_test.digest_u64(), a_dft_test_digest); + + let res_big_ref: VecZnxBig, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref); + let res_big_test: VecZnxBig, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test); + + let mut res_small_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_small_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + let res_ref_digest: u64 = res_big_ref.digest_u64(); + let res_test_digest: u64 = res_big_test.digest_u64(); + + for j in 0..cols { + module_ref.vec_znx_big_normalize( + basek, + &mut res_small_ref, + j, + &res_big_ref, + j, + scratch_ref.borrow(), + ); + module_test.vec_znx_big_normalize( + basek, + &mut res_small_test, + j, + &res_big_test, + j, + scratch_test.borrow(), + ); + } + + assert_eq!(res_big_ref.digest_u64(), res_ref_digest); + assert_eq!(res_big_test.digest_u64(), res_test_digest); + + assert_eq!(res_small_ref, res_small_test); + } + } +} + +pub fn test_vec_znx_dft_sub_ba_inplace(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>)
+where
+    Module<BR>: VecZnxDftSubBAInplace<BR>
+        + VecZnxDftAlloc<BR>
+        + VecZnxDftApply<BR>
+        + VecZnxIdftApplyConsume<BR>
+        + VecZnxBigNormalize<BR>
+        + VecZnxBigNormalizeTmpBytes,
+    Module<BT>: VecZnxDftSubBAInplace<BT>
+        + VecZnxDftAlloc<BT>
+        + VecZnxDftApply<BT>
+        + VecZnxIdftApplyConsume<BT>
+        + VecZnxBigNormalize<BT>
+        + VecZnxBigNormalizeTmpBytes,
+    ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
+    ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
+{
+    assert_eq!(module_ref.n(), module_test.n());
+    let n: usize = module_ref.n();
+
+    let mut source: Source = Source::new([0u8; 32]);
+    let cols: usize = 2;
+
+    let mut scratch_ref: ScratchOwned<BR>
= ScratchOwned::alloc(module_ref.vec_znx_big_normalize_tmp_bytes()); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes()); + for a_size in [1, 2, 3, 4] { + let mut a: VecZnx> = VecZnx::alloc(n, cols, a_size); + a.fill_uniform(basek, &mut source); + let a_digest = a.digest_u64(); + + let mut a_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols, a_size); + let mut a_dft_test: VecZnxDft, BT> = module_test.vec_znx_dft_alloc(cols, a_size); + + for j in 0..cols { + module_ref.vec_znx_dft_apply(1, 0, &mut a_dft_ref, j, &a, j); + module_test.vec_znx_dft_apply(1, 0, &mut a_dft_test, j, &a, j); + } + + assert_eq!(a.digest_u64(), a_digest); + + let a_dft_ref_digest: u64 = a_dft_ref.digest_u64(); + let a_dft_test_digest: u64 = a_dft_test.digest_u64(); + + for res_size in [1, 2, 3, 4] { + let mut res: VecZnx> = VecZnx::alloc(n, cols, a_size); + res.fill_uniform(basek, &mut source); + let res_digest: u64 = res.digest_u64(); + + let mut res_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols, res_size); + let mut res_dft_test: VecZnxDft, BT> = module_test.vec_znx_dft_alloc(cols, res_size); + + for j in 0..cols { + module_ref.vec_znx_dft_apply(1, 0, &mut res_dft_ref, j, &res, j); + module_test.vec_znx_dft_apply(1, 0, &mut res_dft_test, j, &res, j); + } + + assert_eq!(res.digest_u64(), res_digest); + + // Reference + for i in 0..cols { + module_ref.vec_znx_dft_sub_ba_inplace(&mut res_dft_ref, i, &a_dft_ref, i); + module_test.vec_znx_dft_sub_ba_inplace(&mut res_dft_test, i, &a_dft_test, i); + } + + assert_eq!(a_dft_ref.digest_u64(), a_dft_ref_digest); + assert_eq!(a_dft_test.digest_u64(), a_dft_test_digest); + + let res_big_ref: VecZnxBig, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref); + let res_big_test: VecZnxBig, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test); + + let mut res_small_ref: VecZnx> = VecZnx::alloc(n, cols, res_size); + let mut res_small_test: VecZnx> = VecZnx::alloc(n, cols, res_size); + + let res_ref_digest: u64 = res_big_ref.digest_u64(); + let res_test_digest: u64 = res_big_test.digest_u64(); + + for j in 0..cols { + module_ref.vec_znx_big_normalize( + basek, + &mut res_small_ref, + j, + &res_big_ref, + j, + scratch_ref.borrow(), + ); + module_test.vec_znx_big_normalize( + basek, + &mut res_small_test, + j, + &res_big_test, + j, + scratch_test.borrow(), + ); + } + + assert_eq!(res_big_ref.digest_u64(), res_ref_digest); + assert_eq!(res_big_test.digest_u64(), res_test_digest); + + assert_eq!(res_small_ref, res_small_test); + } + } +} diff --git a/poulpy-hal/src/test_suite/vmp.rs b/poulpy-hal/src/test_suite/vmp.rs new file mode 100644 index 0000000..e46d194 --- /dev/null +++ b/poulpy-hal/src/test_suite/vmp.rs @@ -0,0 +1,384 @@ +use crate::{ + api::{ + ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxBigNormalize, VecZnxDftAlloc, VecZnxDftApply, + VecZnxIdftApplyConsume, VmpApplyDft, VmpApplyDftTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, + VmpApplyDftToDftAddTmpBytes, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, + }, + layouts::{DataViewMut, DigestU64, FillUniform, MatZnx, Module, ScratchOwned, VecZnx, VecZnxBig}, + source::Source, +}; +use rand::RngCore; + +use crate::layouts::{Backend, VecZnxDft, VmpPMat}; + +pub fn test_vmp_apply_dft(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>)
+where
+    Module<BR>: ModuleNew<BR>
+        + VmpApplyDftTmpBytes
+        + VmpApplyDft<BR>
+        + VmpPMatAlloc<BR>
+        + VecZnxDftAlloc<BR>
+        + VmpPrepare<BR>
+        + VecZnxDftAlloc<BR>
+        + VecZnxIdftApplyConsume<BR>
+        + VecZnxBigNormalize<BR>,
+    ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
+    Module<BT>: ModuleNew<BT>
+        + VmpApplyDftTmpBytes
+        + VmpApplyDft<BT>
+        + VmpPMatAlloc<BT>
+        + VecZnxDftAlloc<BT>
+        + VmpPrepare<BT>
+        + VecZnxDftAlloc<BT>
+        + VecZnxIdftApplyConsume<BT>
+        + VecZnxBigNormalize<BT>,
+    ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
+{
+    assert_eq!(module_ref.n(), module_test.n());
+    let n: usize = module_ref.n();
+
+    let max_size: usize = 4;
+    let max_cols: usize = 2;
+    let mut source: Source = Source::new([0u8; 32]);
+
+    let mut scratch_ref: ScratchOwned<BR>
= + ScratchOwned::alloc(module_ref.vmp_apply_dft_tmp_bytes(max_size, max_size, max_size, max_cols, max_cols, max_size)); + let mut scratch_test: ScratchOwned = + ScratchOwned::alloc(module_test.vmp_apply_dft_tmp_bytes(max_size, max_size, max_size, max_cols, max_cols, max_size)); + + for cols_in in 1..max_cols + 1 { + for cols_out in 1..max_cols + 1 { + for size_in in 1..max_size + 1 { + for size_out in 1..max_size + 1 { + let rows: usize = cols_in; + + let mut a: VecZnx> = VecZnx::alloc(n, cols_in, size_in); + a.fill_uniform(basek, &mut source); + let a_digest: u64 = a.digest_u64(); + + let mut mat: MatZnx> = MatZnx::alloc(n, rows, cols_in, cols_out, size_out); + mat.fill_uniform(basek, &mut source); + let mat_digest: u64 = mat.digest_u64(); + + let mut pmat_ref: VmpPMat, BR> = module_ref.vmp_pmat_alloc(rows, cols_in, cols_out, size_out); + let mut pmat_test: VmpPMat, BT> = module_test.vmp_pmat_alloc(rows, cols_in, cols_out, size_out); + + module_ref.vmp_prepare(&mut pmat_ref, &mat, scratch_ref.borrow()); + module_test.vmp_prepare(&mut pmat_test, &mat, scratch_test.borrow()); + + assert_eq!(mat.digest_u64(), mat_digest); + + let mut res_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols_out, size_out); + let mut res_dft_test: VecZnxDft, BT> = module_test.vec_znx_dft_alloc(cols_out, size_out); + + source.fill_bytes(res_dft_ref.data_mut()); + source.fill_bytes(res_dft_test.data_mut()); + + module_ref.vmp_apply_dft(&mut res_dft_ref, &a, &pmat_ref, scratch_ref.borrow()); + module_test.vmp_apply_dft(&mut res_dft_test, &a, &pmat_test, scratch_test.borrow()); + + assert_eq!(a.digest_u64(), a_digest); + + let res_big_ref: VecZnxBig, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref); + let res_big_test: VecZnxBig, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test); + + let mut res_small_ref: VecZnx> = VecZnx::alloc(n, cols_out, size_out); + let mut res_small_test: VecZnx> = VecZnx::alloc(n, cols_out, size_out); + + let res_ref_digest: u64 = res_big_ref.digest_u64(); + let res_test_digest: u64 = res_big_test.digest_u64(); + + for j in 0..cols_out { + module_ref.vec_znx_big_normalize( + basek, + &mut res_small_ref, + j, + &res_big_ref, + j, + scratch_ref.borrow(), + ); + module_test.vec_znx_big_normalize( + basek, + &mut res_small_test, + j, + &res_big_test, + j, + scratch_test.borrow(), + ); + } + + assert_eq!(res_big_ref.digest_u64(), res_ref_digest); + assert_eq!(res_big_test.digest_u64(), res_test_digest); + + assert_eq!(res_small_ref, res_small_test); + } + } + } + } +} + +pub fn test_vmp_apply_dft_to_dft(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>)
+where
+    Module<BR>: ModuleNew<BR>
+        + VmpApplyDftToDftTmpBytes
+        + VmpApplyDftToDft<BR>
+        + VmpPMatAlloc<BR>
+        + VecZnxDftAlloc<BR>
+        + VmpPrepare<BR>
+        + VecZnxIdftApplyConsume<BR>
+        + VecZnxBigNormalize<BR>
+        + VecZnxDftApply<BR>,
+    ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
+    Module<BT>: ModuleNew<BT>
+        + VmpApplyDftToDftTmpBytes
+        + VmpApplyDftToDft<BT>
+        + VmpPMatAlloc<BT>
+        + VecZnxDftAlloc<BT>
+        + VmpPrepare<BT>
+        + VecZnxIdftApplyConsume<BT>
+        + VecZnxBigNormalize<BT>
+        + VecZnxDftApply<BT>,
+    ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
+{
+    assert_eq!(module_ref.n(), module_test.n());
+    let n: usize = module_ref.n();
+
+    let max_size: usize = 4;
+    let max_cols: usize = 2;
+
+    let mut source: Source = Source::new([0u8; 32]);
+
+    let mut scratch_ref: ScratchOwned<BR>
= ScratchOwned::alloc( + module_ref.vmp_apply_dft_to_dft_tmp_bytes(max_size, max_size, max_size, max_cols, max_cols, max_size), + ); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc( + module_test.vmp_apply_dft_to_dft_tmp_bytes(max_size, max_size, max_size, max_cols, max_cols, max_size), + ); + + for cols_in in 1..max_cols + 1 { + for cols_out in 1..max_cols + 1 { + for size_in in 1..max_size + 1 { + for size_out in 1..max_size + 1 { + let rows: usize = size_in; + + let mut a: VecZnx> = VecZnx::alloc(n, cols_in, size_in); + a.fill_uniform(basek, &mut source); + let a_digest: u64 = a.digest_u64(); + + let mut a_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols_in, size_in); + let mut a_dft_test: VecZnxDft, BT> = module_test.vec_znx_dft_alloc(cols_in, size_in); + + for j in 0..cols_in { + module_ref.vec_znx_dft_apply(1, 0, &mut a_dft_ref, j, &a, j); + module_test.vec_znx_dft_apply(1, 0, &mut a_dft_test, j, &a, j); + } + + assert_eq!(a.digest_u64(), a_digest); + + let mut mat: MatZnx> = MatZnx::alloc(n, rows, cols_in, cols_out, size_out); + mat.fill_uniform(basek, &mut source); + let mat_digest: u64 = mat.digest_u64(); + + let mut pmat_ref: VmpPMat, BR> = module_ref.vmp_pmat_alloc(rows, cols_in, cols_out, size_out); + let mut pmat_test: VmpPMat, BT> = module_test.vmp_pmat_alloc(rows, cols_in, cols_out, size_out); + + module_ref.vmp_prepare(&mut pmat_ref, &mat, scratch_ref.borrow()); + module_test.vmp_prepare(&mut pmat_test, &mat, scratch_test.borrow()); + + assert_eq!(mat.digest_u64(), mat_digest); + + let mut res_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols_out, size_out); + let mut res_dft_test: VecZnxDft, BT> = module_test.vec_znx_dft_alloc(cols_out, size_out); + + source.fill_bytes(res_dft_ref.data_mut()); + source.fill_bytes(res_dft_test.data_mut()); + + module_ref.vmp_apply_dft_to_dft( + &mut res_dft_ref, + &a_dft_ref, + &pmat_ref, + scratch_ref.borrow(), + ); + module_test.vmp_apply_dft_to_dft( + &mut res_dft_test, + &a_dft_test, + &pmat_test, + scratch_test.borrow(), + ); + + let res_big_ref: VecZnxBig, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref); + let res_big_test: VecZnxBig, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test); + + let mut res_small_ref: VecZnx> = VecZnx::alloc(n, cols_out, size_out); + let mut res_small_test: VecZnx> = VecZnx::alloc(n, cols_out, size_out); + + let res_ref_digest: u64 = res_big_ref.digest_u64(); + let res_test_digest: u64 = res_big_test.digest_u64(); + + for j in 0..cols_out { + module_ref.vec_znx_big_normalize( + basek, + &mut res_small_ref, + j, + &res_big_ref, + j, + scratch_ref.borrow(), + ); + module_test.vec_znx_big_normalize( + basek, + &mut res_small_test, + j, + &res_big_test, + j, + scratch_test.borrow(), + ); + } + + assert_eq!(res_big_ref.digest_u64(), res_ref_digest); + assert_eq!(res_big_test.digest_u64(), res_test_digest); + + assert_eq!(res_small_ref, res_small_test); + } + } + } + } +} + +pub fn test_vmp_apply_dft_to_dft_add(basek: usize, module_ref: &Module
<BR>, module_test: &Module<BT>)
+where
+    Module<BR>: ModuleNew<BR>
+        + VmpApplyDftToDftAddTmpBytes
+        + VmpApplyDftToDftAdd<BR>
+        + VmpPMatAlloc<BR>
+        + VecZnxDftAlloc<BR>
+        + VmpPrepare<BR>
+        + VecZnxIdftApplyConsume<BR>
+        + VecZnxBigNormalize<BR>
+        + VecZnxDftApply<BR>,
+    ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
+    Module<BT>: ModuleNew<BT>
+        + VmpApplyDftToDftAddTmpBytes
+        + VmpApplyDftToDftAdd<BT>
+        + VmpPMatAlloc<BT>
+        + VecZnxDftAlloc<BT>
+        + VmpPrepare<BT>
+        + VecZnxIdftApplyConsume<BT>
+        + VecZnxBigNormalize<BT>
+        + VecZnxDftApply<BT>,
+    ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
+{
+    assert_eq!(module_ref.n(), module_test.n());
+    let n: usize = module_ref.n();
+
+    let max_size: usize = 4;
+    let max_cols: usize = 2;
+
+    let mut source: Source = Source::new([0u8; 32]);
+
+    let mut scratch_ref: ScratchOwned<BR>
= ScratchOwned::alloc( + module_ref.vmp_apply_dft_to_dft_add_tmp_bytes(max_size, max_size, max_size, max_cols, max_cols, max_size), + ); + let mut scratch_test: ScratchOwned = ScratchOwned::alloc( + module_test.vmp_apply_dft_to_dft_add_tmp_bytes(max_size, max_size, max_size, max_cols, max_cols, max_size), + ); + + for cols_in in 1..max_cols + 1 { + for cols_out in 1..max_cols + 1 { + for size_in in 1..max_size + 1 { + for size_out in 1..max_size + 1 { + let rows: usize = size_in; + + let mut a: VecZnx> = VecZnx::alloc(n, cols_in, size_in); + a.fill_uniform(basek, &mut source); + let a_digest: u64 = a.digest_u64(); + + let mut a_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols_in, size_in); + let mut a_dft_test: VecZnxDft, BT> = module_test.vec_znx_dft_alloc(cols_in, size_in); + + for j in 0..cols_in { + module_ref.vec_znx_dft_apply(1, 0, &mut a_dft_ref, j, &a, j); + module_test.vec_znx_dft_apply(1, 0, &mut a_dft_test, j, &a, j); + } + + assert_eq!(a.digest_u64(), a_digest); + + let mut mat: MatZnx> = MatZnx::alloc(n, rows, cols_in, cols_out, size_out); + mat.fill_uniform(basek, &mut source); + let mat_digest: u64 = mat.digest_u64(); + + let mut pmat_ref: VmpPMat, BR> = module_ref.vmp_pmat_alloc(rows, cols_in, cols_out, size_out); + let mut pmat_test: VmpPMat, BT> = module_test.vmp_pmat_alloc(rows, cols_in, cols_out, size_out); + + module_ref.vmp_prepare(&mut pmat_ref, &mat, scratch_ref.borrow()); + module_test.vmp_prepare(&mut pmat_test, &mat, scratch_test.borrow()); + + assert_eq!(mat.digest_u64(), mat_digest); + + for limb_offset in 0..size_out { + let mut res: VecZnx> = VecZnx::alloc(n, cols_out, size_out); + res.fill_uniform(basek, &mut source); + let res_digest: u64 = res.digest_u64(); + + let mut res_dft_ref: VecZnxDft, BR> = module_ref.vec_znx_dft_alloc(cols_out, size_out); + let mut res_dft_test: VecZnxDft, BT> = module_test.vec_znx_dft_alloc(cols_out, size_out); + + for j in 0..cols_out { + module_ref.vec_znx_dft_apply(1, 0, &mut res_dft_ref, j, &res, j); + module_test.vec_znx_dft_apply(1, 0, &mut res_dft_test, j, &res, j); + } + + assert_eq!(res.digest_u64(), res_digest); + + module_ref.vmp_apply_dft_to_dft_add( + &mut res_dft_ref, + &a_dft_ref, + &pmat_ref, + limb_offset * cols_out, + scratch_ref.borrow(), + ); + module_test.vmp_apply_dft_to_dft_add( + &mut res_dft_test, + &a_dft_test, + &pmat_test, + limb_offset * cols_out, + scratch_test.borrow(), + ); + + let res_big_ref: VecZnxBig, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref); + let res_big_test: VecZnxBig, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test); + + let mut res_small_ref: VecZnx> = VecZnx::alloc(n, cols_out, size_out); + let mut res_small_test: VecZnx> = VecZnx::alloc(n, cols_out, size_out); + + let res_ref_digest: u64 = res_big_ref.digest_u64(); + let res_test_digest: u64 = res_big_test.digest_u64(); + + for j in 0..cols_out { + module_ref.vec_znx_big_normalize( + basek, + &mut res_small_ref, + j, + &res_big_ref, + j, + scratch_ref.borrow(), + ); + module_test.vec_znx_big_normalize( + basek, + &mut res_small_test, + j, + &res_big_test, + j, + scratch_test.borrow(), + ); + } + + assert_eq!(res_big_ref.digest_u64(), res_ref_digest); + assert_eq!(res_big_test.digest_u64(), res_test_digest); + + assert_eq!(res_small_ref, res_small_test); + } + } + } + } + } +} diff --git a/poulpy-hal/src/tests/mod.rs b/poulpy-hal/src/tests/mod.rs deleted file mode 100644 index a805d82..0000000 --- a/poulpy-hal/src/tests/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub mod serialization; -pub mod vec_znx; 
-pub mod vmp_pmat; diff --git a/poulpy-hal/src/tests/vec_znx/encoding.rs b/poulpy-hal/src/tests/vec_znx/encoding.rs deleted file mode 100644 index 90f8c62..0000000 --- a/poulpy-hal/src/tests/vec_znx/encoding.rs +++ /dev/null @@ -1,50 +0,0 @@ -use crate::{ - layouts::{VecZnx, ZnxInfos, ZnxViewMut}, - source::Source, -}; - -pub fn test_vec_znx_encode_vec_i64_lo_norm() { - let n: usize = 32; - let basek: usize = 17; - let size: usize = 5; - let k: usize = size * basek - 5; - let mut a: VecZnx> = VecZnx::alloc(n, 2, size); - let mut source: Source = Source::new([0u8; 32]); - let raw: &mut [i64] = a.raw_mut(); - raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); - (0..a.cols()).for_each(|col_i| { - let mut have: Vec = vec![i64::default(); n]; - have.iter_mut() - .for_each(|x| *x = (source.next_i64() << 56) >> 56); - a.encode_vec_i64(basek, col_i, k, &have, 10); - let mut want: Vec = vec![i64::default(); n]; - a.decode_vec_i64(basek, col_i, k, &mut want); - assert_eq!(have, want, "{:?} != {:?}", &have, &want); - }); -} - -pub fn test_vec_znx_encode_vec_i64_hi_norm() { - let n: usize = 32; - let basek: usize = 17; - let size: usize = 5; - for k in [1, basek / 2, size * basek - 5] { - let mut a: VecZnx> = VecZnx::alloc(n, 2, size); - let mut source = Source::new([0u8; 32]); - let raw: &mut [i64] = a.raw_mut(); - raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); - (0..a.cols()).for_each(|col_i| { - let mut have: Vec = vec![i64::default(); n]; - have.iter_mut().for_each(|x| { - if k < 64 { - *x = source.next_u64n(1 << k, (1 << k) - 1) as i64; - } else { - *x = source.next_i64(); - } - }); - a.encode_vec_i64(basek, col_i, k, &have, 63); - let mut want: Vec = vec![i64::default(); n]; - a.decode_vec_i64(basek, col_i, k, &mut want); - assert_eq!(have, want, "{:?} != {:?}", &have, &want); - }) - } -} diff --git a/poulpy-hal/src/tests/vec_znx/generics.rs b/poulpy-hal/src/tests/vec_znx/generics.rs deleted file mode 100644 index 3b31c84..0000000 --- a/poulpy-hal/src/tests/vec_znx/generics.rs +++ /dev/null @@ -1,67 +0,0 @@ -use crate::{ - api::{VecZnxAddNormal, VecZnxFillUniform}, - layouts::{Backend, Module, VecZnx, ZnxView}, - source::Source, -}; - -pub fn test_vec_znx_fill_uniform(module: &Module) -where - Module: VecZnxFillUniform, -{ - let n: usize = module.n(); - let basek: usize = 17; - let size: usize = 5; - let mut source: Source = Source::new([0u8; 32]); - let cols: usize = 2; - let zero: Vec = vec![0; n]; - let one_12_sqrt: f64 = 0.28867513459481287; - (0..cols).for_each(|col_i| { - let mut a: VecZnx<_> = VecZnx::alloc(n, cols, size); - module.vec_znx_fill_uniform(basek, &mut a, col_i, size * basek, &mut source); - (0..cols).for_each(|col_j| { - if col_j != col_i { - (0..size).for_each(|limb_i| { - assert_eq!(a.at(col_j, limb_i), zero); - }) - } else { - let std: f64 = a.std(basek, col_i); - assert!( - (std - one_12_sqrt).abs() < 0.01, - "std={} ~!= {}", - std, - one_12_sqrt - ); - } - }) - }); -} - -pub fn test_vec_znx_add_normal(module: &Module) -where - Module: VecZnxAddNormal, -{ - let n: usize = module.n(); - let basek: usize = 17; - let k: usize = 2 * 17; - let size: usize = 5; - let sigma: f64 = 3.2; - let bound: f64 = 6.0 * sigma; - let mut source: Source = Source::new([0u8; 32]); - let cols: usize = 2; - let zero: Vec = vec![0; n]; - let k_f64: f64 = (1u64 << k as u64) as f64; - (0..cols).for_each(|col_i| { - let mut a: VecZnx<_> = VecZnx::alloc(n, cols, size); - module.vec_znx_add_normal(basek, &mut a, col_i, k, &mut source, sigma, bound); - 
(0..cols).for_each(|col_j| { - if col_j != col_i { - (0..size).for_each(|limb_i| { - assert_eq!(a.at(col_j, limb_i), zero); - }) - } else { - let std: f64 = a.std(basek, col_i) * k_f64; - assert!((std - sigma).abs() < 0.1, "std={} ~!= {}", std, sigma); - } - }) - }); -} diff --git a/poulpy-hal/src/tests/vec_znx/mod.rs b/poulpy-hal/src/tests/vec_znx/mod.rs deleted file mode 100644 index a87ee38..0000000 --- a/poulpy-hal/src/tests/vec_znx/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -mod generics; -pub use generics::*; - -#[cfg(test)] -mod encoding; diff --git a/poulpy-hal/src/tests/vmp_pmat/mod.rs b/poulpy-hal/src/tests/vmp_pmat/mod.rs deleted file mode 100644 index 2847cc3..0000000 --- a/poulpy-hal/src/tests/vmp_pmat/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -mod vmp_apply; - -pub use vmp_apply::*; diff --git a/poulpy-hal/src/tests/vmp_pmat/vmp_apply.rs b/poulpy-hal/src/tests/vmp_pmat/vmp_apply.rs deleted file mode 100644 index ad6511a..0000000 --- a/poulpy-hal/src/tests/vmp_pmat/vmp_apply.rs +++ /dev/null @@ -1,112 +0,0 @@ -use crate::{ - api::{ - DFT, IDFTTmpA, ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxBigAlloc, VecZnxBigNormalize, - VecZnxBigNormalizeTmpBytes, VecZnxDftAlloc, VmpApplyDftToDft, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, - }, - layouts::{MatZnx, Module, ScratchOwned, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, ZnxInfos, ZnxViewMut}, - oep::{ - DFTImpl, IDFTTmpAImpl, ModuleNewImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, VecZnxBigAllocImpl, - VecZnxBigNormalizeImpl, VecZnxBigNormalizeTmpBytesImpl, VecZnxDftAllocImpl, VmpApplyDftToDftImpl, - VmpApplyDftToDftTmpBytesImpl, VmpPMatAllocImpl, VmpPMatPrepareImpl, - }, -}; - -use crate::layouts::Backend; - -pub fn test_vmp_apply() -where - B: Backend - + ModuleNewImpl - + VmpApplyDftToDftTmpBytesImpl - + VecZnxBigNormalizeTmpBytesImpl - + VmpPMatAllocImpl - + VecZnxDftAllocImpl - + VecZnxBigAllocImpl - + VmpPMatPrepareImpl - + DFTImpl - + VmpApplyDftToDftImpl - + IDFTTmpAImpl - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + VecZnxBigNormalizeImpl, -{ - let log_n: i32 = 5; - let n: usize = 1 << log_n; - - let module: Module = Module::::new(n as u64); - let basek: usize = 15; - let a_size: usize = 5; - let mat_size: usize = 6; - let res_size: usize = a_size; - - [1, 2].iter().for_each(|cols_in| { - [1, 2].iter().for_each(|cols_out| { - let a_cols: usize = *cols_in; - let res_cols: usize = *cols_out; - - let mat_rows: usize = a_size; - let mat_cols_in: usize = a_cols; - let mat_cols_out: usize = res_cols; - - let mut scratch = ScratchOwned::alloc( - module.vmp_apply_dft_to_dft_tmp_bytes( - res_size, - a_size, - mat_rows, - mat_cols_in, - mat_cols_out, - mat_size, - ) | module.vec_znx_big_normalize_tmp_bytes(), - ); - - let mut a: VecZnx> = VecZnx::alloc(n, a_cols, a_size); - - (0..a_cols).for_each(|i| { - a.at_mut(i, a_size - 1)[i + 1] = 1; - }); - - let mut vmp: VmpPMat, B> = module.vmp_pmat_alloc(mat_rows, mat_cols_in, mat_cols_out, mat_size); - - let mut c_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(mat_cols_out, mat_size); - let mut c_big: VecZnxBig, B> = module.vec_znx_big_alloc(mat_cols_out, mat_size); - - let mut mat: MatZnx> = MatZnx::alloc(n, mat_rows, mat_cols_in, mat_cols_out, mat_size); - - // Construts a [VecZnxMatDft] that performs cyclic rotations on each submatrix. 
- (0..a.size()).for_each(|row_i| { - (0..mat_cols_in).for_each(|col_in_i| { - (0..mat_cols_out).for_each(|col_out_i| { - let idx = 1 + col_in_i * mat_cols_out + col_out_i; - mat.at_mut(row_i, col_in_i).at_mut(col_out_i, row_i)[idx] = 1_i64; // X^{idx} - }); - }); - }); - - module.vmp_prepare(&mut vmp, &mat, scratch.borrow()); - - let mut a_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(a_cols, a_size); - (0..a_cols).for_each(|i| { - module.dft(1, 0, &mut a_dft, i, &a, i); - }); - - module.vmp_apply_dft_to_dft(&mut c_dft, &a_dft, &vmp, scratch.borrow()); - - let mut res_have_vi64: Vec = vec![i64::default(); n]; - - let mut res_have: VecZnx> = VecZnx::alloc(n, res_cols, res_size); - (0..mat_cols_out).for_each(|i| { - module.idft_tmp_a(&mut c_big, i, &mut c_dft, i); - module.vec_znx_big_normalize(basek, &mut res_have, i, &c_big, i, scratch.borrow()); - }); - - (0..mat_cols_out).for_each(|col_i| { - let mut res_want_vi64: Vec = vec![i64::default(); n]; - (0..a_cols).for_each(|i| { - res_want_vi64[(i + 1) + (1 + i * mat_cols_out + col_i)] = 1; - }); - res_have.decode_vec_i64(basek, col_i, basek * a_size, &mut res_have_vi64); - assert_eq!(res_have_vi64, res_want_vi64); - }); - }); - }); -} diff --git a/poulpy-schemes/examples/circuit_bootstrapping.rs b/poulpy-schemes/examples/circuit_bootstrapping.rs index 24e323b..1852bfe 100644 --- a/poulpy-schemes/examples/circuit_bootstrapping.rs +++ b/poulpy-schemes/examples/circuit_bootstrapping.rs @@ -13,7 +13,7 @@ use poulpy_hal::{ source::Source, }; -use poulpy_backend::cpu_spqlios::FFT64; +use poulpy_backend::cpu_spqlios::FFT64Spqlios; use poulpy_schemes::tfhe::{ blind_rotation::CGGI, @@ -27,7 +27,7 @@ fn main() { let n_glwe: usize = 1024; // Module provides access to the backend arithmetic - let module: Module = Module::::new(n_glwe as u64); + let module: Module = Module::::new(n_glwe as u64); // Base 2 loga let basek: usize = 13; @@ -75,7 +75,7 @@ fn main() { let k_tsk: usize = (rows_tsk + 1) * basek; // Scratch space (4MB) - let mut scratch: ScratchOwned = ScratchOwned::alloc(1 << 22); + let mut scratch: ScratchOwned = ScratchOwned::alloc(1 << 22); // Secret key sampling source let mut source_xs: Source = Source::new([1u8; 32]); @@ -97,7 +97,7 @@ fn main() { // sk_glwe.fill_zero(); // GLWE secret prepared (opaque backend dependant write only struct) - let sk_glwe_prepared: GLWESecretPrepared, FFT64> = sk_glwe.prepare_alloc(&module, scratch.borrow()); + let sk_glwe_prepared: GLWESecretPrepared, FFT64Spqlios> = sk_glwe.prepare_alloc(&module, scratch.borrow()); // Plaintext value to circuit bootstrap let data: i64 = 1 % (1 << k_lwe_pt); @@ -142,7 +142,8 @@ fn main() { let mut res: GGSWCiphertext> = GGSWCiphertext::alloc(n_glwe, basek, k_ggsw_res, rows_ggsw_res, 1, rank); // Circuit bootstrapping key prepared (opaque backend dependant write only struct) - let cbt_prepared: CircuitBootstrappingKeyPrepared, CGGI, FFT64> = cbt_key.prepare_alloc(&module, scratch.borrow()); + let cbt_prepared: CircuitBootstrappingKeyPrepared, CGGI, FFT64Spqlios> = + cbt_key.prepare_alloc(&module, scratch.borrow()); // Apply circuit bootstrapping: LWE(data * 2^{- (k_lwe_pt + 2)}) -> GGSW(data) let now: Instant = Instant::now(); @@ -193,7 +194,7 @@ fn main() { ); // Prepare GGSW output of circuit bootstrapping (opaque backend dependant write only struct) - let res_prepared: GGSWCiphertextPrepared, FFT64> = res.prepare_alloc(&module, scratch.borrow()); + let res_prepared: GGSWCiphertextPrepared, FFT64Spqlios> = res.prepare_alloc(&module, scratch.borrow()); // Apply GLWE x GGSW 
ct_glwe.external_product_inplace(&module, &res_prepared, scratch.borrow()); diff --git a/poulpy-schemes/src/tfhe/blind_rotation/cggi_algo.rs b/poulpy-schemes/src/tfhe/blind_rotation/cggi_algo.rs index 65633c2..a8bfb87 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/cggi_algo.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/cggi_algo.rs @@ -1,12 +1,12 @@ use itertools::izip; use poulpy_hal::{ api::{ - DFT, IDFT, IDFTConsume, ScratchAvailable, SvpApply, SvpPPolAllocBytes, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, - TakeVecZnxDftSlice, TakeVecZnxSlice, VecZnxAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, - VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftSubABInplace, - VecZnxDftZero, VecZnxIDFTTmpBytes, VecZnxMulXpMinusOneInplace, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxSubABInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd, - VmpApplyDftToDftTmpBytes, + ScratchAvailable, SvpApplyDftToDft, SvpPPolAllocBytes, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, TakeVecZnxDftSlice, + TakeVecZnxSlice, VecZnxAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, + VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftApply, + VecZnxDftSubABInplace, VecZnxDftZero, VecZnxIdftApply, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpBytes, + VecZnxMulXpMinusOneInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, + VecZnxSubABInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataRef, Module, Scratch, SvpPPol, VecZnx, ZnxView, ZnxZero}, }; @@ -36,7 +36,7 @@ where + VmpApplyDftToDftTmpBytes + VecZnxNormalizeTmpBytes + VecZnxBigAllocBytes - + VecZnxIDFTTmpBytes + + VecZnxIdftApplyTmpBytes + VecZnxBigNormalizeTmpBytes, { let brk_size: usize = k_brk.div_ceil(basek); @@ -59,7 +59,7 @@ where + acc_dft_add + vmp_res + vmp_xai - + (vmp | (acc_big + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes()))) + + (vmp | (acc_big + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_apply_tmp_bytes()))) } else { GLWECiphertext::bytes_of(module.n(), basek, k_res, rank) + GLWECiphertext::external_product_scratch_space(module, basek, k_res, k_res, k_brk, 1, rank) @@ -73,13 +73,13 @@ where + SvpPPolAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes - + VecZnxIDFTTmpBytes - + IDFT + + VecZnxIdftApplyTmpBytes + + VecZnxIdftApply + VecZnxDftAdd + VecZnxDftAddInplace - + DFT + + VecZnxDftApply + VecZnxDftZero - + SvpApply + + SvpApplyDftToDft + VecZnxDftSubABInplace + VecZnxBigAddSmallInplace + VecZnxRotate @@ -88,10 +88,10 @@ where + VecZnxNormalize + VecZnxNormalizeInplace + VecZnxCopy - + VecZnxMulXpMinusOneInplace + + VecZnxMulXpMinusOneInplace + VmpApplyDftToDft + VmpApplyDftToDftAdd - + IDFTConsume + + VecZnxIdftApplyConsume + VecZnxBigNormalize + VecZnxNormalizeTmpBytes, Scratch: TakeVecZnxDftSlice + TakeVecZnxDft + TakeVecZnxBig + TakeVecZnxSlice + TakeVecZnx + ScratchAvailable, @@ -135,13 +135,13 @@ fn execute_block_binary_extended( + SvpPPolAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes - + VecZnxIDFTTmpBytes - + IDFT + + VecZnxIdftApplyTmpBytes + + VecZnxIdftApply + VecZnxDftAdd + VecZnxDftAddInplace - + DFT + + VecZnxDftApply + VecZnxDftZero - + SvpApply + + SvpApplyDftToDft + VecZnxDftSubABInplace + VecZnxBigAddSmallInplace + VecZnxRotate @@ -150,7 +150,7 @@ fn 
execute_block_binary_extended( + VecZnxNormalize + VecZnxNormalizeInplace + VecZnxCopy - + VecZnxMulXpMinusOneInplace + + VecZnxMulXpMinusOneInplace + VecZnxBigNormalize + VmpApplyDftToDft, Scratch: TakeVecZnxDftSlice + TakeVecZnxDft + TakeVecZnxBig + TakeVecZnxSlice + ScratchAvailable + TakeVecZnx, @@ -161,11 +161,11 @@ fn execute_block_binary_extended( let rows: usize = brk.rows(); let cols: usize = res.rank() + 1; - let (mut acc, scratch1) = scratch.take_vec_znx_slice(extension_factor, n_glwe, cols, res.size()); - let (mut acc_dft, scratch2) = scratch1.take_vec_znx_dft_slice(extension_factor, n_glwe, cols, rows); - let (mut vmp_res, scratch3) = scratch2.take_vec_znx_dft_slice(extension_factor, n_glwe, cols, brk.size()); - let (mut acc_add_dft, scratch4) = scratch3.take_vec_znx_dft_slice(extension_factor, n_glwe, cols, brk.size()); - let (mut vmp_xai, scratch5) = scratch4.take_vec_znx_dft(n_glwe, 1, brk.size()); + let (mut acc, scratch_1) = scratch.take_vec_znx_slice(extension_factor, n_glwe, cols, res.size()); + let (mut acc_dft, scratch_2) = scratch_1.take_vec_znx_dft_slice(extension_factor, n_glwe, cols, rows); + let (mut vmp_res, scratch_3) = scratch_2.take_vec_znx_dft_slice(extension_factor, n_glwe, cols, brk.size()); + let (mut acc_add_dft, scratch_4) = scratch_3.take_vec_znx_dft_slice(extension_factor, n_glwe, cols, brk.size()); + let (mut vmp_xai, scratch_5) = scratch_4.take_vec_znx_dft(n_glwe, 1, brk.size()); (0..extension_factor).for_each(|i| { acc[i].zero(); @@ -208,7 +208,7 @@ fn execute_block_binary_extended( .for_each(|(ai, ski)| { (0..extension_factor).for_each(|i| { (0..cols).for_each(|j| { - module.dft(1, 0, &mut acc_dft[i], j, &acc[i], j); + module.vec_znx_dft_apply(1, 0, &mut acc_dft[i], j, &acc[i], j); }); module.vec_znx_dft_zero(&mut acc_add_dft[i]) }); @@ -221,7 +221,7 @@ fn execute_block_binary_extended( // vmp_res = DFT(acc) * BRK[i] (0..extension_factor).for_each(|i| { - module.vmp_apply_dft_to_dft(&mut vmp_res[i], &acc_dft[i], skii.data(), scratch5); + module.vmp_apply_dft_to_dft(&mut vmp_res[i], &acc_dft[i], skii.data(), scratch_5); }); // Trivial case: no rotation between polynomials, we can directly multiply with (X^{-ai} - 1) @@ -231,7 +231,7 @@ fn execute_block_binary_extended( // DFT X^{-ai} (0..extension_factor).for_each(|j| { (0..cols).for_each(|i| { - module.svp_apply(&mut vmp_xai, 0, &x_pow_a[ai_hi], 0, &vmp_res[j], i); + module.svp_apply_dft_to_dft(&mut vmp_xai, 0, &x_pow_a[ai_hi], 0, &vmp_res[j], i); module.vec_znx_dft_add_inplace(&mut acc_add_dft[j], i, &vmp_xai, 0); module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[j], i, &vmp_res[j], i); }); @@ -247,7 +247,7 @@ fn execute_block_binary_extended( if (ai_hi + 1) & (two_n - 1) != 0 { for (i, j) in (0..ai_lo).zip(extension_factor - ai_lo..extension_factor) { (0..cols).for_each(|k| { - module.svp_apply(&mut vmp_xai, 0, &x_pow_a[ai_hi + 1], 0, &vmp_res[j], k); + module.svp_apply_dft_to_dft(&mut vmp_xai, 0, &x_pow_a[ai_hi + 1], 0, &vmp_res[j], k); module.vec_znx_dft_add_inplace(&mut acc_add_dft[i], k, &vmp_xai, 0); module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[i], k, &vmp_res[i], k); }); @@ -259,7 +259,7 @@ fn execute_block_binary_extended( // Sets acc_add_dft[ai_lo..extension_factor] += (acc[0..extension_factor - ai_lo] * sk) * X^{-ai} for (i, j) in (ai_lo..extension_factor).zip(0..extension_factor - ai_lo) { (0..cols).for_each(|k| { - module.svp_apply(&mut vmp_xai, 0, &x_pow_a[ai_hi], 0, &vmp_res[j], k); + module.svp_apply_dft_to_dft(&mut vmp_xai, 0, &x_pow_a[ai_hi], 0, &vmp_res[j], k); 
module.vec_znx_dft_add_inplace(&mut acc_add_dft[i], k, &vmp_xai, 0); module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[i], k, &vmp_res[i], k); }); @@ -269,11 +269,11 @@ fn execute_block_binary_extended( }); { - let (mut acc_add_big, scratch7) = scratch5.take_vec_znx_big(n_glwe, 1, brk.size()); + let (mut acc_add_big, scratch7) = scratch_5.take_vec_znx_big(n_glwe, 1, brk.size()); (0..extension_factor).for_each(|j| { (0..cols).for_each(|i| { - module.idft(&mut acc_add_big, 0, &acc_add_dft[j], i, scratch7); + module.vec_znx_idft_apply(&mut acc_add_big, 0, &acc_add_dft[j], i, scratch7); module.vec_znx_big_add_small_inplace(&mut acc_add_big, 0, &acc[j], i); module.vec_znx_big_normalize(basek, &mut acc[j], i, &acc_add_big, 0, scratch7); }); @@ -302,13 +302,13 @@ fn execute_block_binary( + SvpPPolAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes - + VecZnxIDFTTmpBytes - + IDFT + + VecZnxIdftApplyTmpBytes + + VecZnxIdftApply + VecZnxDftAdd + VecZnxDftAddInplace - + DFT + + VecZnxDftApply + VecZnxDftZero - + SvpApply + + SvpApplyDftToDft + VecZnxDftSubABInplace + VecZnxBigAddSmallInplace + VecZnxRotate @@ -317,7 +317,7 @@ fn execute_block_binary( + VecZnxNormalize + VecZnxNormalizeInplace + VecZnxCopy - + VecZnxMulXpMinusOneInplace + + VecZnxMulXpMinusOneInplace + VmpApplyDftToDft + VecZnxBigNormalize, Scratch: TakeVecZnxDftSlice + TakeVecZnxDft + TakeVecZnxBig + TakeVecZnxSlice + ScratchAvailable + TakeVecZnx, @@ -351,10 +351,10 @@ fn execute_block_binary( // ACC + [sum DFT(X^ai -1) * (DFT(ACC) x BRKi)] - let (mut acc_dft, scratch1) = scratch.take_vec_znx_dft(n_glwe, cols, rows); - let (mut vmp_res, scratch2) = scratch1.take_vec_znx_dft(n_glwe, cols, brk.size()); - let (mut acc_add_dft, scratch3) = scratch2.take_vec_znx_dft(n_glwe, cols, brk.size()); - let (mut vmp_xai, scratch4) = scratch3.take_vec_znx_dft(n_glwe, 1, brk.size()); + let (mut acc_dft, scratch_1) = scratch.take_vec_znx_dft(n_glwe, cols, rows); + let (mut vmp_res, scratch_2) = scratch_1.take_vec_znx_dft(n_glwe, cols, brk.size()); + let (mut acc_add_dft, scratch_3) = scratch_2.take_vec_znx_dft(n_glwe, cols, brk.size()); + let (mut vmp_xai, scratch_4) = scratch_3.take_vec_znx_dft(n_glwe, 1, brk.size()); let x_pow_a: &Vec, B>>; if let Some(b) = &brk.x_pow_a { @@ -369,7 +369,7 @@ fn execute_block_binary( ) .for_each(|(ai, ski)| { (0..cols).for_each(|j| { - module.dft(1, 0, &mut acc_dft, j, &out_mut.data, j); + module.vec_znx_dft_apply(1, 0, &mut acc_dft, j, &out_mut.data, j); }); module.vec_znx_dft_zero(&mut acc_add_dft); @@ -378,23 +378,23 @@ fn execute_block_binary( let ai_pos: usize = ((aii + two_n as i64) & (two_n - 1) as i64) as usize; // vmp_res = DFT(acc) * BRK[i] - module.vmp_apply_dft_to_dft(&mut vmp_res, &acc_dft, skii.data(), scratch4); + module.vmp_apply_dft_to_dft(&mut vmp_res, &acc_dft, skii.data(), scratch_4); // DFT(X^ai -1) * (DFT(acc) * BRK[i]) (0..cols).for_each(|i| { - module.svp_apply(&mut vmp_xai, 0, &x_pow_a[ai_pos], 0, &vmp_res, i); + module.svp_apply_dft_to_dft(&mut vmp_xai, 0, &x_pow_a[ai_pos], 0, &vmp_res, i); module.vec_znx_dft_add_inplace(&mut acc_add_dft, i, &vmp_xai, 0); module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft, i, &vmp_res, i); }); }); { - let (mut acc_add_big, scratch5) = scratch4.take_vec_znx_big(n_glwe, 1, brk.size()); + let (mut acc_add_big, scratch_5) = scratch_4.take_vec_znx_big(n_glwe, 1, brk.size()); (0..cols).for_each(|i| { - module.idft(&mut acc_add_big, 0, &acc_add_dft, i, scratch5); + module.vec_znx_idft_apply(&mut acc_add_big, 0, &acc_add_dft, i, scratch_5); 
module.vec_znx_big_add_small_inplace(&mut acc_add_big, 0, &out_mut.data, i); - module.vec_znx_big_normalize(basek, &mut out_mut.data, i, &acc_add_big, 0, scratch5); + module.vec_znx_big_normalize(basek, &mut out_mut.data, i, &acc_add_big, 0, scratch_5); }); } }); @@ -416,13 +416,13 @@ fn execute_standard( + SvpPPolAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes - + VecZnxIDFTTmpBytes - + IDFT + + VecZnxIdftApplyTmpBytes + + VecZnxIdftApply + VecZnxDftAdd + VecZnxDftAddInplace - + DFT + + VecZnxDftApply + VecZnxDftZero - + SvpApply + + SvpApplyDftToDft + VecZnxDftSubABInplace + VecZnxBigAddSmallInplace + VecZnxRotate @@ -431,10 +431,10 @@ fn execute_standard( + VecZnxNormalize + VecZnxNormalizeInplace + VecZnxCopy - + VecZnxMulXpMinusOneInplace + + VecZnxMulXpMinusOneInplace + VmpApplyDftToDft + VmpApplyDftToDftAdd - + IDFTConsume + + VecZnxIdftApplyConsume + VecZnxBigNormalize + VecZnxNormalizeTmpBytes, Scratch: TakeVecZnxDftSlice + TakeVecZnxDft + TakeVecZnxBig + TakeVecZnxSlice + ScratchAvailable + TakeVecZnx, @@ -492,16 +492,16 @@ fn execute_standard( module.vec_znx_rotate(b, &mut out_mut.data, 0, &lut.data[0], 0); // ACC + [sum DFT(X^ai -1) * (DFT(ACC) x BRKi)] - let (mut acc_tmp, scratch1) = scratch.take_glwe_ct(out_mut.n(), basek, out_mut.k(), out_mut.rank()); + let (mut acc_tmp, scratch_1) = scratch.take_glwe_ct(out_mut.n(), basek, out_mut.k(), out_mut.rank()); // TODO: see if faster by skipping normalization in external product and keeping acc in big coeffs // TODO: first iteration can be optimized to be a gglwe product izip!(a.iter(), brk.data.iter()).for_each(|(ai, ski)| { // acc_tmp = sk[i] * acc - acc_tmp.external_product(module, &out_mut, ski, scratch1); + acc_tmp.external_product(module, &out_mut, ski, scratch_1); // acc_tmp = (sk[i] * acc) * (X^{ai} - 1) - acc_tmp.mul_xp_minus_one_inplace(module, *ai); + acc_tmp.mul_xp_minus_one_inplace(module, *ai, scratch_1); // acc = acc + (sk[i] * acc) * (X^{ai} - 1) out_mut.add_inplace(module, &acc_tmp); @@ -509,7 +509,7 @@ fn execute_standard( // We can normalize only at the end because we add normalized values in [-2^{basek-1}, 2^{basek-1}] // on top of each others, thus ~ 2^{63-basek} additions are supported before overflow. 
- out_mut.normalize_inplace(module, scratch1); + out_mut.normalize_inplace(module, scratch_1); } pub fn mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWECiphertext<&[u8]>, rot_dir: LookUpTableRotationDirection) { diff --git a/poulpy-schemes/src/tfhe/blind_rotation/cggi_key.rs b/poulpy-schemes/src/tfhe/blind_rotation/cggi_key.rs index 17ef004..1f33341 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/cggi_key.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/cggi_key.rs @@ -1,8 +1,9 @@ use poulpy_hal::{ api::{ - DFT, IDFTConsume, ScratchAvailable, SvpApplyInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, - VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize, - VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VmpPMatAlloc, VmpPrepare, + ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, + VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, + VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, + VmpPMatAlloc, VmpPrepare, }, layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxView, ZnxViewMut}, source::Source, @@ -50,9 +51,9 @@ where Module: VecZnxAddScalarInplace + VecZnxDftAllocBytes + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform + VecZnxSubABInplace @@ -162,9 +163,9 @@ impl BlindRotationKeyCompressed { Module: VecZnxAddScalarInplace + VecZnxDftAllocBytes + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform + VecZnxSubABInplace diff --git a/poulpy-schemes/src/tfhe/blind_rotation/key.rs b/poulpy-schemes/src/tfhe/blind_rotation/key.rs index 3a28ec0..6b6163a 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/key.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/key.rs @@ -80,10 +80,10 @@ impl Reset for BlindRotationKey { } impl FillUniform for BlindRotationKey { - fn fill_uniform(&mut self, source: &mut Source) { + fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { self.keys .iter_mut() - .for_each(|key| key.fill_uniform(source)); + .for_each(|key| key.fill_uniform(log_bound, source)); } } diff --git a/poulpy-schemes/src/tfhe/blind_rotation/key_compressed.rs b/poulpy-schemes/src/tfhe/blind_rotation/key_compressed.rs index 14837e3..7fa463e 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/key_compressed.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/key_compressed.rs @@ -59,10 +59,10 @@ impl Reset for BlindRotationKeyCompressed FillUniform for BlindRotationKeyCompressed { - fn fill_uniform(&mut self, source: &mut Source) { + fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { self.keys .iter_mut() - .for_each(|key| key.fill_uniform(source)); + .for_each(|key| key.fill_uniform(log_bound, source)); } } diff --git a/poulpy-schemes/src/tfhe/blind_rotation/lut.rs b/poulpy-schemes/src/tfhe/blind_rotation/lut.rs index eab41ea..4ff19f8 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/lut.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/lut.rs @@ -1,10 +1,9 @@ use poulpy_hal::{ api::{ ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxCopy, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotateInplace, - 
VecZnxSwithcDegree, + VecZnxRotateInplaceTmpBytes, VecZnxSwitchRing, }, layouts::{Backend, Module, ScratchOwned, VecZnx, ZnxInfos, ZnxViewMut}, - oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl}, }; #[derive(Debug, Clone, Copy)] @@ -69,15 +68,22 @@ impl LookUpTable { self.rot_dir = rot_dir } - pub fn set(&mut self, module: &Module, f: &[i64], k: usize) + pub fn set(&mut self, module: &Module, f: &[i64], k: usize) where - Module: VecZnxRotateInplace + VecZnxNormalizeInplace + VecZnxNormalizeTmpBytes + VecZnxSwithcDegree + VecZnxCopy, - B: Backend + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, + Module: VecZnxRotateInplace + + VecZnxNormalizeInplace + + VecZnxNormalizeTmpBytes + + VecZnxSwitchRing + + VecZnxCopy + + VecZnxRotateInplaceTmpBytes, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, { assert!(f.len() <= module.n()); let basek: usize = self.basek; + let mut scratch: ScratchOwned = ScratchOwned::alloc(module.vec_znx_normalize_tmp_bytes()); + // Get the number minimum limb to store the message modulus let limbs: usize = k.div_ceil(basek); @@ -124,17 +130,15 @@ impl LookUpTable { if self.extension_factor() > 1 { (0..self.extension_factor()).for_each(|i| { - module.vec_znx_switch_degree(&mut self.data[i], 0, &lut_full, 0); + module.vec_znx_switch_ring(&mut self.data[i], 0, &lut_full, 0); if i < self.extension_factor() { - module.vec_znx_rotate_inplace(-1, &mut lut_full, 0); + module.vec_znx_rotate_inplace(-1, &mut lut_full, 0, scratch.borrow()); } }); } else { module.vec_znx_copy(&mut self.data[0], 0, &lut_full, 0); } - let mut scratch: ScratchOwned = ScratchOwned::alloc(module.vec_znx_normalize_tmp_bytes()); - self.data.iter_mut().for_each(|a| { module.vec_znx_normalize_inplace(self.basek, a, 0, scratch.borrow()); }); @@ -147,23 +151,26 @@ impl LookUpTable { #[allow(dead_code)] pub(crate) fn rotate(&mut self, module: &Module, k: i64) where - Module: VecZnxRotateInplace, + Module: VecZnxRotateInplace + VecZnxRotateInplaceTmpBytes, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, { let extension_factor: usize = self.extension_factor(); let two_n: usize = 2 * self.data[0].n(); let two_n_ext: usize = two_n * extension_factor; + let mut scratch: ScratchOwned<_> = ScratchOwned::alloc(module.vec_znx_rotate_inplace_tmp_bytes()); + let k_pos: usize = ((k + two_n_ext as i64) % two_n_ext as i64) as usize; let k_hi: usize = k_pos / extension_factor; let k_lo: usize = k_pos % extension_factor; (0..extension_factor - k_lo).for_each(|i| { - module.vec_znx_rotate_inplace(k_hi as i64, &mut self.data[i], 0); + module.vec_znx_rotate_inplace(k_hi as i64, &mut self.data[i], 0, scratch.borrow()); }); (extension_factor - k_lo..extension_factor).for_each(|i| { - module.vec_znx_rotate_inplace(k_hi as i64 + 1, &mut self.data[i], 0); + module.vec_znx_rotate_inplace(k_hi as i64 + 1, &mut self.data[i], 0, scratch.borrow()); }); self.data.rotate_right(k_lo); diff --git a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_blind_rotation.rs b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_blind_rotation.rs index e522462..4304a7f 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_blind_rotation.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_blind_rotation.rs @@ -1,12 +1,13 @@ use poulpy_hal::{ api::{ - DFT, IDFT, IDFTConsume, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApply, SvpApplyInplace, SvpPPolAlloc, - SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, - VecZnxBigAddSmallInplace, 
VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAdd, - VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftSubABInplace, VecZnxDftZero, VecZnxFillUniform, VecZnxIDFTTmpBytes, - VecZnxMulXpMinusOneInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, - VecZnxRotateInplace, VecZnxSub, VecZnxSubABInplace, VecZnxSwithcDegree, VmpApplyDftToDft, VmpApplyDftToDftAdd, - VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, ZnAddNormal, ZnFillUniform, ZnNormalizeInplace, + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, + SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, + VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAdd, VecZnxDftAddInplace, + VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftSubABInplace, VecZnxDftZero, VecZnxFillUniform, VecZnxIdftApply, + VecZnxIdftApplyConsume, VecZnxIdftApplyTmpBytes, VecZnxMulXpMinusOneInplace, VecZnxNormalize, VecZnxNormalizeInplace, + VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, VecZnxSub, VecZnxSubABInplace, + VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, ZnAddNormal, + ZnFillUniform, ZnNormalizeInplace, }, layouts::{Backend, Module, ScratchOwned, ZnxView}, oep::{ @@ -33,13 +34,13 @@ where + SvpPPolAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes - + VecZnxIDFTTmpBytes - + IDFT + + VecZnxIdftApplyTmpBytes + + VecZnxIdftApply + VecZnxDftAdd + VecZnxDftAddInplace - + DFT + + VecZnxDftApply + VecZnxDftZero - + SvpApply + + SvpApplyDftToDft + VecZnxDftSubABInplace + VecZnxBigAddSmallInplace + VecZnxRotate @@ -48,19 +49,19 @@ where + VecZnxNormalize + VecZnxNormalizeInplace + VecZnxCopy - + VecZnxMulXpMinusOneInplace + + VecZnxMulXpMinusOneInplace + SvpPrepare + SvpPPolAlloc - + SvpApplyInplace - + IDFTConsume + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + VecZnxBigAddInplace + VecZnxBigNormalize + VecZnxNormalizeTmpBytes + VecZnxFillUniform + VecZnxAddNormal + VecZnxAddScalarInplace - + VecZnxRotateInplace - + VecZnxSwithcDegree + + VecZnxRotateInplace + + VecZnxSwitchRing + VecZnxSub + VmpPMatAlloc + VmpPrepare @@ -68,6 +69,7 @@ where + VmpApplyDftToDftAdd + ZnFillUniform + ZnAddNormal + + VecZnxRotateInplaceTmpBytes + ZnNormalizeInplace, B: Backend + VecZnxDftAllocBytesImpl diff --git a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_lut.rs b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_lut.rs index 6beb131..a57d493 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_lut.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_lut.rs @@ -1,7 +1,10 @@ use std::vec; use poulpy_hal::{ - api::{VecZnxCopy, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotateInplace, VecZnxSwithcDegree}, + api::{ + VecZnxCopy, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, + VecZnxSwitchRing, + }, layouts::{Backend, Module}, oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl}, }; @@ -10,7 +13,12 @@ use crate::tfhe::blind_rotation::{DivRound, LookUpTable}; pub fn test_lut_standard(module: &Module) where - Module: VecZnxRotateInplace + VecZnxNormalizeInplace + VecZnxNormalizeTmpBytes + VecZnxSwithcDegree + VecZnxCopy, + Module: VecZnxRotateInplace + + VecZnxNormalizeInplace + + VecZnxNormalizeTmpBytes + + 
VecZnxSwitchRing + + VecZnxCopy + + VecZnxRotateInplaceTmpBytes, B: Backend + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, { let basek: usize = 20; @@ -45,7 +53,12 @@ where pub fn test_lut_extended(module: &Module) where - Module: VecZnxRotateInplace + VecZnxNormalizeInplace + VecZnxNormalizeTmpBytes + VecZnxSwithcDegree + VecZnxCopy, + Module: VecZnxRotateInplace + + VecZnxNormalizeInplace + + VecZnxNormalizeTmpBytes + + VecZnxSwitchRing + + VecZnxCopy + + VecZnxRotateInplaceTmpBytes, B: Backend + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, { let basek: usize = 20; diff --git a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_serialization.rs b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_serialization.rs index 7310750..15c6fdb 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_serialization.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_serialization.rs @@ -1,4 +1,4 @@ -use poulpy_hal::tests::serialization::test_reader_writer_interface; +use poulpy_hal::test_suite::serialization::test_reader_writer_interface; use crate::tfhe::blind_rotation::{BlindRotationKey, BlindRotationKeyAlloc, BlindRotationKeyCompressed, CGGI}; diff --git a/poulpy-schemes/src/tfhe/blind_rotation/tests/implementation/cpu_spqlios/fft64.rs b/poulpy-schemes/src/tfhe/blind_rotation/tests/implementation/cpu_spqlios/fft64.rs index feeee06..ecb2421 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/tests/implementation/cpu_spqlios/fft64.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/tests/implementation/cpu_spqlios/fft64.rs @@ -1,4 +1,4 @@ -use poulpy_backend::cpu_spqlios::FFT64; +use poulpy_backend::cpu_spqlios::FFT64Spqlios; use poulpy_hal::{api::ModuleNew, layouts::Module}; use crate::tfhe::blind_rotation::tests::{ @@ -8,30 +8,30 @@ use crate::tfhe::blind_rotation::tests::{ #[test] fn lut_standard() { - let module: Module = Module::::new(32); + let module: Module = Module::::new(32); test_lut_standard(&module); } #[test] fn lut_extended() { - let module: Module = Module::::new(32); + let module: Module = Module::::new(32); test_lut_extended(&module); } #[test] fn standard() { - let module: Module = Module::::new(512); + let module: Module = Module::::new(512); test_blind_rotation(&module, 224, 1, 1); } #[test] fn block_binary() { - let module: Module = Module::::new(512); + let module: Module = Module::::new(512); test_blind_rotation(&module, 224, 7, 1); } #[test] fn block_binary_extended() { - let module: Module = Module::::new(512); + let module: Module = Module::::new(512); test_blind_rotation(&module, 224, 7, 2); } diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs index 2e01fc6..5497700 100644 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs @@ -2,12 +2,13 @@ use std::collections::HashMap; use poulpy_hal::{ api::{ - DFT, IDFTConsume, IDFTTmpA, ScratchAvailable, TakeMatZnx, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, TakeVecZnxDftSlice, - TakeVecZnxSlice, VecZnxAddInplace, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, - VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace, VecZnxCopy, - VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftCopy, VecZnxNegateInplace, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubABInplace, - VecZnxSwithcDegree, VmpApplyDftToDft, 
VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + ScratchAvailable, TakeMatZnx, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, TakeVecZnxDftSlice, TakeVecZnxSlice, + VecZnxAddInplace, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigAutomorphismInplace, + VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace, VecZnxCopy, VecZnxDftAddInplace, + VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNegateInplace, + VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, + VecZnxRshInplace, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, + VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataRef, Module, Scratch, ToOwnedDeep}, oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl}, @@ -26,14 +27,14 @@ use crate::tfhe::{ impl CirtuitBootstrappingExecute for CircuitBootstrappingKeyPrepared where - Module: VecZnxRotateInplace + Module: VecZnxRotateInplace + VecZnxNormalizeInplace + VecZnxNormalizeTmpBytes - + VecZnxSwithcDegree + + VecZnxSwitchRing + VecZnxBigAutomorphismInplace - + VecZnxRshInplace + + VecZnxRshInplace + VecZnxDftCopy - + IDFTTmpA + + VecZnxIdftApplyTmpA + VecZnxSub + VecZnxAddInplace + VecZnxNegateInplace @@ -44,12 +45,13 @@ where + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd - + DFT - + IDFTConsume + + VecZnxDftApply + + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize - + VecZnxAutomorphismInplace + + VecZnxAutomorphismInplace + VecZnxBigSubSmallBInplace + + VecZnxRotateInplaceTmpBytes + VecZnxBigAllocBytes + VecZnxDftAddInplace + VecZnxRotate, @@ -124,14 +126,14 @@ pub fn circuit_bootstrap_core( DRes: DataMut, DLwe: DataRef, DBrk: DataRef, - Module: VecZnxRotateInplace + Module: VecZnxRotateInplace + VecZnxNormalizeInplace + VecZnxNormalizeTmpBytes - + VecZnxSwithcDegree + + VecZnxSwitchRing + VecZnxBigAutomorphismInplace - + VecZnxRshInplace + + VecZnxRshInplace + VecZnxDftCopy - + IDFTTmpA + + VecZnxIdftApplyTmpA + VecZnxSub + VecZnxAddInplace + VecZnxNegateInplace @@ -142,14 +144,15 @@ pub fn circuit_bootstrap_core( + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd - + DFT - + IDFTConsume + + VecZnxDftApply + + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize - + VecZnxAutomorphismInplace + + VecZnxAutomorphismInplace + VecZnxBigSubSmallBInplace + VecZnxBigAllocBytes + VecZnxDftAddInplace + + VecZnxRotateInplaceTmpBytes + VecZnxRotate, B: Backend + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, Scratch: TakeVecZnxDftSlice @@ -199,10 +202,10 @@ pub fn circuit_bootstrap_core( } // TODO: separate GGSW k from output of blind rotation k - let (mut res_glwe, scratch1) = scratch.take_glwe_ct(n, basek, k, rank); - let (mut tmp_gglwe, scratch2) = scratch1.take_gglwe(n, basek, k, rows, 1, rank.max(1), rank); + let (mut res_glwe, scratch_1) = scratch.take_glwe_ct(n, basek, k, rank); + let (mut tmp_gglwe, scratch_2) = scratch_1.take_gglwe(n, basek, k, rows, 1, rank.max(1), rank); - key.brk.execute(module, &mut res_glwe, lwe, &lut, scratch2); + key.brk.execute(module, &mut res_glwe, lwe, &lut, scratch_2); let gap: usize = 2 * lut.drift / lut.extension_factor(); @@ -221,19 +224,19 @@ pub fn circuit_bootstrap_core( log_gap_out, log_domain, &key.atk, - scratch2, + scratch_2, ); } else { - tmp_glwe.trace(module, 0, module.log_n(), &res_glwe, &key.atk, scratch2); + tmp_glwe.trace(module, 0, 
module.log_n(), &res_glwe, &key.atk, scratch_2); } if i < rows { - res_glwe.rotate_inplace(module, -(gap as i64)); + res_glwe.rotate_inplace(module, -(gap as i64), scratch_2); } }); // Expands GGLWE to GGSW using GGLWE(s^2) - res.from_gglwe(module, &tmp_gglwe, &key.tsk, scratch2); + res.from_gglwe(module, &tmp_gglwe, &key.tsk, scratch_2); } #[allow(clippy::too_many_arguments)] @@ -249,14 +252,14 @@ fn post_process( ) where DataRes: DataMut, DataA: DataRef, - Module: VecZnxRotateInplace + Module: VecZnxRotateInplace + VecZnxNormalizeInplace + VecZnxNormalizeTmpBytes - + VecZnxSwithcDegree + + VecZnxSwitchRing + VecZnxBigAutomorphismInplace - + VecZnxRshInplace + + VecZnxRshInplace + VecZnxDftCopy - + IDFTTmpA + + VecZnxIdftApplyTmpA + VecZnxSub + VecZnxAddInplace + VecZnxNegateInplace @@ -267,11 +270,11 @@ fn post_process( + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd - + DFT - + IDFTConsume + + VecZnxDftApply + + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize - + VecZnxAutomorphismInplace + + VecZnxAutomorphismInplace + VecZnxBigSubSmallBInplace + VecZnxRotate, Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, @@ -297,7 +300,7 @@ fn post_process( let steps: i32 = 1 << log_domain; (0..steps).for_each(|i| { if i != 0 { - res.rotate_inplace(module, -(1 << log_gap_in)); + res.rotate_inplace(module, -(1 << log_gap_in), scratch); } cts.insert(i as usize * (1 << log_gap_out), res.to_owned_deep()); }); @@ -321,14 +324,14 @@ pub fn pack( auto_keys: &HashMap, B>>, scratch: &mut Scratch, ) where - Module: VecZnxRotateInplace + Module: VecZnxRotateInplace + VecZnxNormalizeInplace + VecZnxNormalizeTmpBytes - + VecZnxSwithcDegree + + VecZnxSwitchRing + VecZnxBigAutomorphismInplace - + VecZnxRshInplace + + VecZnxRshInplace + VecZnxDftCopy - + IDFTTmpA + + VecZnxIdftApplyTmpA + VecZnxSub + VecZnxAddInplace + VecZnxNegateInplace @@ -339,11 +342,11 @@ pub fn pack( + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd - + DFT - + IDFTConsume + + VecZnxDftApply + + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize - + VecZnxAutomorphismInplace + + VecZnxAutomorphismInplace + VecZnxBigSubSmallBInplace + VecZnxRotate, Scratch: TakeVecZnx + TakeVecZnxDft + ScratchAvailable, @@ -400,14 +403,14 @@ fn combine( auto_key: &GGLWEAutomorphismKeyPrepared, scratch: &mut Scratch, ) where - Module: VecZnxRotateInplace + Module: VecZnxRotateInplace + VecZnxNormalizeInplace + VecZnxNormalizeTmpBytes - + VecZnxSwithcDegree + + VecZnxSwitchRing + VecZnxBigAutomorphismInplace - + VecZnxRshInplace + + VecZnxRshInplace + VecZnxDftCopy - + IDFTTmpA + + VecZnxIdftApplyTmpA + VecZnxSub + VecZnxAddInplace + VecZnxNegateInplace @@ -418,11 +421,11 @@ fn combine( + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd - + DFT - + IDFTConsume + + VecZnxDftApply + + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize - + VecZnxAutomorphismInplace + + VecZnxAutomorphismInplace + VecZnxBigSubSmallBInplace + VecZnxRotate, Scratch: TakeVecZnx + TakeVecZnxDft + ScratchAvailable, @@ -446,15 +449,15 @@ fn combine( let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(n, basek, k, rank); // a = a * X^-t - a.rotate_inplace(module, -t); + a.rotate_inplace(module, -t, scratch_1); // tmp_b = a * X^-t - b tmp_b.sub(module, a, b); - tmp_b.rsh(module, 1); + tmp_b.rsh(module, 1, scratch_1); // a = a * X^-t + b a.add_inplace(module, b); - a.rsh(module, 1); + a.rsh(module, 1, scratch_1); tmp_b.normalize_inplace(module, scratch_1); @@ 
-468,9 +471,9 @@ fn combine( // a = a + b * X^t - phi(a * X^-t - b) * X^t // = a + b * X^t - phi(a * X^-t - b) * - phi(X^t) // = a + b * X^t + phi(a - b * X^t) - a.rotate_inplace(module, t); + a.rotate_inplace(module, t, scratch_1); } else { - a.rsh(module, 1); + a.rsh(module, 1, scratch); // a = a + phi(a) a.automorphism_add_inplace(module, auto_key, scratch); } @@ -481,7 +484,7 @@ fn combine( let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(n, basek, k, rank); tmp_b.rotate(module, t, b); - tmp_b.rsh(module, 1); + tmp_b.rsh(module, 1, scratch_1); // a = (b* X^t - phi(b* X^t)) b.automorphism_sub_ba(module, &tmp_b, auto_key, scratch_1); diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/key.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/key.rs index 28af408..69a5586 100644 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/key.rs +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/key.rs @@ -6,11 +6,11 @@ use std::collections::HashMap; use poulpy_hal::{ api::{ - DFT, IDFTConsume, IDFTTmpA, ScratchAvailable, SvpApply, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, - TakeScalarZnx, TakeSvpPPol, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, - VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize, - VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwithcDegree, VmpPMatAlloc, - VmpPrepare, + ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, + TakeSvpPPol, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, + VecZnxAutomorphism, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, + VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, + VecZnxSwitchRing, VmpPMatAlloc, VmpPrepare, }, layouts::{Backend, Data, DataRef, Module, Scratch}, source::Source, @@ -51,14 +51,14 @@ pub struct CircuitBootstrappingKey { impl CircuitBootstrappingKeyEncryptSk for CircuitBootstrappingKey, BRA> where BlindRotationKey, BRA>: BlindRotationKeyAlloc + BlindRotationKeyEncryptSk, - Module: SvpApply - + IDFTTmpA + Module: SvpApplyDftToDft + + VecZnxIdftApplyTmpA + VecZnxAddScalarInplace + VecZnxDftAllocBytes + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxFillUniform + VecZnxSubABInplace @@ -68,7 +68,7 @@ where + VecZnxNormalize + VecZnxSub + SvpPrepare - + VecZnxSwithcDegree + + VecZnxSwitchRing + SvpPPolAllocBytes + SvpPPolAlloc + VecZnxAutomorphism, diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/circuit_bootstrapping.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/circuit_bootstrapping.rs index 03cd21f..8e543eb 100644 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/circuit_bootstrapping.rs +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/circuit_bootstrapping.rs @@ -2,14 +2,15 @@ use std::time::Instant; use poulpy_hal::{ api::{ - DFT, IDFTConsume, IDFTTmpA, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApply, SvpApplyInplace, SvpPPolAlloc, - SvpPPolAllocBytes, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, - VecZnxAutomorphismInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAllocBytes, - 
VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace, VecZnxCopy, - VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftCopy, VecZnxFillUniform, VecZnxNegateInplace, - VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, - VecZnxSub, VecZnxSubABInplace, VecZnxSwithcDegree, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, - VmpPMatAlloc, VmpPrepare, ZnAddNormal, ZnFillUniform, ZnNormalizeInplace, + ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, + SvpPrepare, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxAutomorphismInplace, + VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigAutomorphismInplace, + VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace, VecZnxCopy, VecZnxDftAddInplace, + VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy, VecZnxFillUniform, VecZnxIdftApplyConsume, + VecZnxIdftApplyTmpA, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, + VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, VecZnxRshInplace, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing, + VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, ZnAddNormal, ZnFillUniform, + ZnNormalizeInplace, }, layouts::{Backend, Module, ScalarZnx, ScratchOwned, ZnxView, ZnxViewMut}, oep::{ @@ -43,9 +44,9 @@ where + VecZnxNormalizeInplace + VecZnxDftAllocBytes + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxSubABInplace + VecZnxAddInplace @@ -53,10 +54,10 @@ where + VecZnxSub + VecZnxAddScalarInplace + VecZnxAutomorphism - + VecZnxSwithcDegree + + VecZnxSwitchRing + VecZnxBigAllocBytes - + IDFTTmpA - + SvpApply + + VecZnxIdftApplyTmpA + + SvpApplyDftToDft + VecZnxBigAddInplace + VecZnxBigAddSmallInplace + VecZnxBigAlloc @@ -70,14 +71,15 @@ where + VmpApplyDftToDft + VmpApplyDftToDftAdd + SvpPPolAllocBytes - + VecZnxRotateInplace + + VecZnxRotateInplace + VecZnxBigAutomorphismInplace - + VecZnxRshInplace + + VecZnxRshInplace + VecZnxDftCopy + VecZnxNegateInplace + VecZnxCopy - + VecZnxAutomorphismInplace + + VecZnxAutomorphismInplace + VecZnxBigSubSmallBInplace + + VecZnxRotateInplaceTmpBytes + VecZnxBigAllocBytes + VecZnxDftAddInplace + VecZnxRotate @@ -185,7 +187,12 @@ where // X^{data * 2^log_gap_out} let mut pt_ggsw: ScalarZnx> = ScalarZnx::alloc(n, 1); pt_ggsw.at_mut(0, 0)[0] = 1; - module.vec_znx_rotate_inplace(data * (1 << log_gap_out), &mut pt_ggsw.as_vec_znx_mut(), 0); + module.vec_znx_rotate_inplace( + data * (1 << log_gap_out), + &mut pt_ggsw.as_vec_znx_mut(), + 0, + scratch.borrow(), + ); res.print_noise(module, &sk_glwe_prepared, &pt_ggsw); @@ -224,9 +231,9 @@ where + VecZnxNormalizeInplace + VecZnxDftAllocBytes + VecZnxBigNormalize - + DFT - + SvpApplyInplace - + IDFTConsume + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + VecZnxNormalizeTmpBytes + VecZnxSubABInplace + VecZnxAddInplace @@ -234,10 +241,10 @@ where + VecZnxSub + VecZnxAddScalarInplace + VecZnxAutomorphism - + VecZnxSwithcDegree + + VecZnxSwitchRing + VecZnxBigAllocBytes - + IDFTTmpA - + SvpApply + + VecZnxIdftApplyTmpA + + SvpApplyDftToDft + VecZnxBigAddInplace + VecZnxBigAddSmallInplace 
+ VecZnxBigAlloc @@ -251,13 +258,14 @@ where + VmpApplyDftToDft + VmpApplyDftToDftAdd + SvpPPolAllocBytes - + VecZnxRotateInplace + + VecZnxRotateInplace + VecZnxBigAutomorphismInplace - + VecZnxRshInplace + + VecZnxRotateInplaceTmpBytes + + VecZnxRshInplace + VecZnxDftCopy + VecZnxNegateInplace + VecZnxCopy - + VecZnxAutomorphismInplace + + VecZnxAutomorphismInplace + VecZnxBigSubSmallBInplace + VecZnxBigAllocBytes + VecZnxDftAddInplace diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/implementation/cpu_spqlios/fft64.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/implementation/cpu_spqlios/fft64.rs index 688e28e..3661f81 100644 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/implementation/cpu_spqlios/fft64.rs +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/implementation/cpu_spqlios/fft64.rs @@ -1,4 +1,4 @@ -use poulpy_backend::cpu_spqlios::FFT64; +use poulpy_backend::cpu_spqlios::FFT64Spqlios; use poulpy_hal::{api::ModuleNew, layouts::Module}; use crate::tfhe::{ @@ -10,12 +10,12 @@ use crate::tfhe::{ #[test] fn test_to_constant() { - let module: Module = Module::::new(256); - test_circuit_bootstrapping_to_constant::(&module); + let module: Module = Module::::new(256); + test_circuit_bootstrapping_to_constant::(&module); } #[test] fn test_to_exponent() { - let module: Module = Module::::new(256); - test_circuit_bootstrapping_to_exponent::(&module); + let module: Module = Module::::new(256); + test_circuit_bootstrapping_to_exponent::(&module); }
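Summary of the API renames applied across the hunks above, compiled from the updated import lists and call sites in this diff; it restates what the hunks already show and is not an exhaustive migration guide.

// Trait renames (old -> new), as they appear in the `use poulpy_hal::api::{...}` hunks:
//   DFT                 -> VecZnxDftApply
//   IDFT                -> VecZnxIdftApply
//   IDFTConsume         -> VecZnxIdftApplyConsume
//   IDFTTmpA            -> VecZnxIdftApplyTmpA
//   VecZnxIDFTTmpBytes  -> VecZnxIdftApplyTmpBytes
//   SvpApply            -> SvpApplyDftToDft
//   SvpApplyInplace     -> SvpApplyDftToDftInplace
//   VecZnxSwithcDegree  -> VecZnxSwitchRing
// Backend marker rename: poulpy_backend::cpu_spqlios::FFT64 -> FFT64Spqlios
// Corresponding call-site renames:
//   module.dft(..)                   -> module.vec_znx_dft_apply(..)
//   module.idft(..)                  -> module.vec_znx_idft_apply(..)
//   module.svp_apply(..)             -> module.svp_apply_dft_to_dft(..)
//   module.vec_znx_switch_degree(..) -> module.vec_znx_switch_ring(..)
//   module.vec_znx_idft_tmp_bytes()  -> module.vec_znx_idft_apply_tmp_bytes()
// In addition, FillUniform::fill_uniform now takes a leading `log_bound: usize` argument.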
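Several in-place operations now take an explicit scratch argument, and matching tmp-bytes queries are introduced so callers can size that scratch up front: `vec_znx_rotate_inplace` gains a trailing `scratch` parameter (paired with the new `VecZnxRotateInplaceTmpBytes` trait), and the higher-level `rotate_inplace`, `rsh` and `mul_xp_minus_one_inplace` on GLWE ciphertexts are updated the same way. The fragment below illustrates the calling pattern only; the surrounding trait bounds are elided and the variable names are placeholders, see the lut.rs hunk above for the full bounds.

// Sketch of the new explicit-scratch pattern, assuming a `module: &Module<B>` and a
// one-column `VecZnx` named `lut_full` are already in scope (placeholders, not crate items).
// Size the scratch buffer from the new tmp-bytes query instead of a hard-coded constant:
let mut scratch: ScratchOwned<_> = ScratchOwned::alloc(module.vec_znx_rotate_inplace_tmp_bytes());
// The rotation itself is unchanged apart from the trailing scratch argument:
module.vec_znx_rotate_inplace(-1, &mut lut_full, 0, scratch.borrow());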
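The new cross-backend test at the top of this section relies on a digest discipline: every input (`a`, `mat`, `res`) is hashed with `digest_u64()` before and after the backend calls to check that the operations treat their inputs as read-only, and the reference backend (`BR`) and the backend under test (`BT`) are compared only after `vec_znx_big_normalize`, i.e. on normalized small-coefficient outputs. A minimal, self-contained sketch of that pattern follows; `assert_backends_agree`, `run_ref` and `run_test` are hypothetical stand-ins for the real trait plumbing, and only `digest_u64` mirrors a method actually used in the test.

// Minimal sketch (under the assumptions stated above) of the digest-based
// "inputs unchanged, normalized outputs equal" check used by the VMP tests.
fn assert_backends_agree<T, F, G>(input: &T, digest: impl Fn(&T) -> u64, run_ref: F, run_test: G)
where
    F: FnOnce(&T) -> Vec<i64>,
    G: FnOnce(&T) -> Vec<i64>,
{
    let before: u64 = digest(input);           // hash the shared input once up front
    let out_ref: Vec<i64> = run_ref(input);    // reference pipeline (e.g. the scalar fft64 path)
    let out_test: Vec<i64> = run_test(input);  // pipeline under test (e.g. AVX or spqlios)
    assert_eq!(digest(input), before);         // neither pipeline may mutate the input
    assert_eq!(out_ref, out_test);             // compare only the normalized outputs
}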