diff --git a/Cargo.toml b/Cargo.toml index a17e5f7..6f2a91e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [workspace] -members = ["base2k", "rlwe", "sampling", "utils"] - +members = ["base2k", "core", "sampling", "utils"] +resolver = "3" [workspace.dependencies] rug = "1.27" diff --git a/base2k/.vscode/settings.json b/base2k/.vscode/settings.json new file mode 100644 index 0000000..c38916e --- /dev/null +++ b/base2k/.vscode/settings.json @@ -0,0 +1,11 @@ +{ + "github.copilot.enable": { + "*": false, + "plaintext": false, + "markdown": false, + "scminput": false + }, + "files.associations": { + "random": "c" + } +} \ No newline at end of file diff --git a/base2k/Cargo.toml b/base2k/Cargo.toml index 2ebb8db..089cbde 100644 --- a/base2k/Cargo.toml +++ b/base2k/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "base2k" version = "0.1.0" -edition = "2021" +edition = "2024" [dependencies] rug = {workspace = true} diff --git a/base2k/build.rs b/base2k/build.rs index 4ddb96c..f592b15 100644 --- a/base2k/build.rs +++ b/base2k/build.rs @@ -3,10 +3,11 @@ use std::path::absolute; fn main() { println!( "cargo:rustc-link-search=native={}", - absolute("./spqlios-arithmetic/build/spqlios") + absolute("spqlios-arithmetic/build/spqlios") .unwrap() .to_str() .unwrap() ); - println!("cargo:rustc-link-lib=static=spqlios"); //"cargo:rustc-link-lib=dylib=spqlios" + println!("cargo:rustc-link-lib=static=spqlios"); + // println!("cargo:rustc-link-lib=dylib=spqlios") } diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index f66a4d1..e73db89 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -1,6 +1,7 @@ use base2k::{ - BACKEND, Encoding, Infos, Module, Sampling, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, - VecZnxDftOps, VecZnxOps, alloc_aligned, + AddNormal, Decoding, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, + ScalarZnxDftOps, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, + VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxInfos, }; use itertools::izip; use sampling::source::Source; @@ -8,89 +9,125 @@ use sampling::source::Source; fn main() { let n: usize = 16; let log_base2k: usize = 18; - let cols: usize = 3; - let msg_cols: usize = 2; - let log_scale: usize = msg_cols * log_base2k - 5; - let module: Module = Module::new(n, BACKEND::FFT64); + let ct_size: usize = 3; + let msg_size: usize = 2; + let log_scale: usize = msg_size * log_base2k - 5; + let module: Module = Module::::new(n); - let mut carry: Vec = alloc_aligned(module.vec_znx_big_normalize_tmp_bytes()); + let mut scratch: ScratchOwned = ScratchOwned::new(module.vec_znx_big_normalize_tmp_bytes()); let seed: [u8; 32] = [0; 32]; let mut source: Source = Source::new(seed); - let mut res: VecZnx = module.new_vec_znx(1, cols); - // s <- Z_{-1, 0, 1}[X]/(X^{N}+1) - let mut s: Scalar = Scalar::new(n); - s.fill_ternary_prob(0.5, &mut source); + let mut s: ScalarZnx> = module.new_scalar_znx(1); + s.fill_ternary_prob(0, 0.5, &mut source); // Buffer to store s in the DFT domain - let mut s_ppol: SvpPPol = module.new_svp_ppol(); + let mut s_dft: ScalarZnxDft, FFT64> = module.new_scalar_znx_dft(s.cols()); - // s_ppol <- DFT(s) - module.svp_prepare(&mut s_ppol, &s); + // s_dft <- DFT(s) + module.svp_prepare(&mut s_dft, 0, &s, 0); - // a <- Z_{2^prec}[X]/(X^{N}+1) - let mut a: VecZnx = module.new_vec_znx(1, cols); - module.fill_uniform(log_base2k, &mut a, cols, 
&mut source); + // Allocates a VecZnx with two columns: ct=(0, 0) + let mut ct: VecZnx> = module.new_vec_znx( + 2, // Number of columns + ct_size, // Number of small poly per column + ); - // Scratch space for DFT values - let mut buf_dft: VecZnxDft = module.new_vec_znx_dft(1, a.cols()); + // Fill the second column with random values: ct = (0, a) + ct.fill_uniform(log_base2k, 1, ct_size, &mut source); - // Applies buf_dft <- s * a - module.svp_apply_dft(&mut buf_dft, &s_ppol, &a); + let mut buf_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_size); - // Alias scratch space - let mut buf_big: VecZnxBig = buf_dft.as_vec_znx_big(); + module.vec_znx_dft(&mut buf_dft, 0, &ct, 1); - // buf_big <- IDFT(buf_dft) (not normalized) - module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft); + // Applies DFT(ct[1]) * DFT(s) + module.svp_apply_inplace( + &mut buf_dft, // DFT(ct[1] * s) + 0, // Selects the first column of res + &s_dft, // DFT(s) + 0, // Selects the first column of s_dft + ); - let mut m: VecZnx = module.new_vec_znx(1, msg_cols); + // Alias scratch space (VecZnxDft is always at least as big as VecZnxBig) + // BIG(ct[1] * s) <- IDFT(DFT(ct[1] * s)) (not normalized) + let mut buf_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_size); + module.vec_znx_idft_tmp_a(&mut buf_big, 0, &mut buf_dft, 0); + + // Creates a plaintext: VecZnx with 1 column + let mut m = module.new_vec_znx( + 1, // Number of columns + msg_size, // Number of small polynomials + ); let mut want: Vec = vec![0; n]; want.iter_mut() .for_each(|x| *x = source.next_u64n(16, 15) as i64); - - // m m.encode_vec_i64(0, log_base2k, log_scale, &want, 4); - m.normalize(log_base2k, &mut carry); + module.vec_znx_normalize_inplace(log_base2k, &mut m, 0, scratch.borrow()); - // buf_big <- m - buf_big - module.vec_znx_big_sub_small_a_inplace(&mut buf_big, &m); - - // b <- normalize(buf_big) + e - let mut b: VecZnx = module.new_vec_znx(1, cols); - module.vec_znx_big_normalize(log_base2k, &mut b, &buf_big, &mut carry); - module.add_normal( - log_base2k, - &mut b, - log_base2k * cols, - &mut source, - 3.2, - 19.0, + // m - BIG(ct[1] * s) + module.vec_znx_big_sub_small_b_inplace( + &mut buf_big, + 0, // Selects the first column of the receiver + &m, + 0, // Selects the first column of the message ); - // Decrypt + // Normalizes back to VecZnx + // ct[0] <- m - BIG(c1 * s) + module.vec_znx_big_normalize( + log_base2k, + &mut ct, + 0, // Selects the first column of ct (ct[0]) + &buf_big, + 0, // Selects the first column of buf_big + scratch.borrow(), + ); - // buf_big <- a * s - module.svp_apply_dft(&mut buf_dft, &s_ppol, &a); - module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft); + // Add noise to ct[0] + // ct[0] <- ct[0] + e + ct.add_normal( + log_base2k, + 0, // Selects the first column of ct (ct[0]) + log_base2k * ct_size, // Scaling of the noise: 2^{-log_base2k * limbs} + &mut source, + 3.2, // Standard deviation + 19.0, // Truncatation bound + ); - // buf_big <- a * s + b - module.vec_znx_big_add_small_inplace(&mut buf_big, &b); + // Final ciphertext: ct = (-a * s + m + e, a) - // res <- normalize(buf_big) - module.vec_znx_big_normalize(log_base2k, &mut res, &buf_big, &mut carry); + // Decryption + + // DFT(ct[1] * s) + module.vec_znx_dft(&mut buf_dft, 0, &ct, 1); + module.svp_apply_inplace( + &mut buf_dft, + 0, // Selects the first column of res. 
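(Aside on this example: its correctness rests on the standard RLWE decryption identity, written here with the names used in the diff:

```text
ct = (-a*s + m + e, a)
ct[0] + ct[1]*s = (-a*s + m + e) + a*s = m + e
```

so the decryption steps below recover `m` only up to the noise `e`, which is why the final comparison divides by the scale instead of expecting exact equality.)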
+ &s_dft, + 0, + ); + + // BIG(c1 * s) = IDFT(DFT(c1 * s)) + module.vec_znx_idft_tmp_a(&mut buf_big, 0, &mut buf_dft, 0); + + // BIG(c1 * s) + ct[0] + module.vec_znx_big_add_small_inplace(&mut buf_big, 0, &ct, 0); + + // m + e <- BIG(ct[1] * s + ct[0]) + let mut res = module.new_vec_znx(1, ct_size); + module.vec_znx_big_normalize(log_base2k, &mut res, 0, &buf_big, 0, scratch.borrow()); // have = m * 2^{log_scale} + e let mut have: Vec = vec![i64::default(); n]; - res.decode_vec_i64(0, log_base2k, res.cols() * log_base2k, &mut have); + res.decode_vec_i64(0, log_base2k, res.size() * log_base2k, &mut have); - let scale: f64 = (1 << (res.cols() * log_base2k - log_scale)) as f64; + let scale: f64 = (1 << (res.size() * log_base2k - log_scale)) as f64; izip!(want.iter(), have.iter()) .enumerate() .for_each(|(i, (a, b))| { println!("{}: {} {}", i, a, (*b as f64) / scale); - }) + }); } diff --git a/base2k/examples/vector_matrix_product.rs b/base2k/examples/vector_matrix_product.rs deleted file mode 100644 index a69c857..0000000 --- a/base2k/examples/vector_matrix_product.rs +++ /dev/null @@ -1,58 +0,0 @@ -use base2k::{ - BACKEND, Encoding, Infos, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, - alloc_aligned, -}; - -fn main() { - let log_n: i32 = 5; - let n: usize = 1 << log_n; - - let module: Module = Module::new(n, BACKEND::FFT64); - let log_base2k: usize = 15; - let cols: usize = 5; - let log_k: usize = log_base2k * cols - 5; - - let rows: usize = cols; - let cols: usize = cols + 1; - - // Maximum size of the byte scratch needed - let tmp_bytes: usize = module.vmp_prepare_tmp_bytes(rows, cols) | module.vmp_apply_dft_tmp_bytes(cols, cols, rows, cols); - - let mut buf: Vec = alloc_aligned(tmp_bytes); - - let mut a_values: Vec = vec![i64::default(); n]; - a_values[1] = (1 << log_base2k) + 1; - - let mut a: VecZnx = module.new_vec_znx(1, rows); - a.encode_vec_i64(0, log_base2k, log_k, &a_values, 32); - a.normalize(log_base2k, &mut buf); - - a.print(0, a.cols(), n); - println!(); - - let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(1, rows, cols); - - (0..a.cols()).for_each(|row_i| { - let mut tmp: VecZnx = module.new_vec_znx(1, cols); - tmp.at_mut(row_i)[1] = 1 as i64; - module.vmp_prepare_row(&mut vmp_pmat, tmp.raw(), row_i, &mut buf); - }); - - let mut c_dft: VecZnxDft = module.new_vec_znx_dft(1, cols); - module.vmp_apply_dft(&mut c_dft, &a, &vmp_pmat, &mut buf); - - let mut c_big: VecZnxBig = c_dft.as_vec_znx_big(); - module.vec_znx_idft_tmp_a(&mut c_big, &mut c_dft); - - let mut res: VecZnx = module.new_vec_znx(1, rows); - module.vec_znx_big_normalize(log_base2k, &mut res, &c_big, &mut buf); - - let mut values_res: Vec = vec![i64::default(); n]; - res.decode_vec_i64(0, log_base2k, log_k, &mut values_res); - - res.print(0, res.cols(), n); - - module.free(); - - println!("{:?}", values_res) -} diff --git a/base2k/spqlios-arithmetic b/base2k/spqlios-arithmetic index e3d3247..b919282 160000 --- a/base2k/spqlios-arithmetic +++ b/base2k/spqlios-arithmetic @@ -1 +1 @@ -Subproject commit e3d3247335faccf2b6361213c354cd61b958325e +Subproject commit b919282c9b913e8b11418df6afdb0baa02debc9b diff --git a/base2k/src/encoding.rs b/base2k/src/encoding.rs index c8c08e9..45214c6 100644 --- a/base2k/src/encoding.rs +++ b/base2k/src/encoding.rs @@ -1,5 +1,6 @@ use crate::ffi::znx::znx_zero_i64_ref; -use crate::{Infos, VecZnx}; +use crate::znx_base::{ZnxView, ZnxViewMut}; +use crate::{VecZnx, znx_base::ZnxInfos}; use itertools::izip; use rug::{Assign, Float}; use 
std::cmp::min; @@ -9,129 +10,141 @@ pub trait Encoding { /// /// # Arguments /// - /// * `poly_idx`: the index of the poly where to encode the data. + /// * `col_i`: the index of the poly where to encode the data. /// * `log_base2k`: base two negative logarithm decomposition of the receiver. /// * `log_k`: base two negative logarithm of the scaling of the data. /// * `data`: data to encode on the receiver. /// * `log_max`: base two logarithm of the infinity norm of the input data. - fn encode_vec_i64(&mut self, poly_idx: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize); - - /// decode a vector of i64 from the receiver. - /// - /// # Arguments - /// - /// * `poly_idx`: the index of the poly where to encode the data. - /// * `log_base2k`: base two negative logarithm decomposition of the receiver. - /// * `log_k`: base two logarithm of the scaling of the data. - /// * `data`: data to decode from the receiver. - fn decode_vec_i64(&self, poly_idx: usize, log_base2k: usize, log_k: usize, data: &mut [i64]); - - /// decode a vector of Float from the receiver. - /// - /// # Arguments - /// * `poly_idx`: the index of the poly where to encode the data. - /// * `log_base2k`: base two negative logarithm decomposition of the receiver. - /// * `data`: data to decode from the receiver. - fn decode_vec_float(&self, poly_idx: usize, log_base2k: usize, data: &mut [Float]); + fn encode_vec_i64(&mut self, col_i: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize); /// encodes a single i64 on the receiver at the given index. /// /// # Arguments /// - /// * `poly_idx`: the index of the poly where to encode the data. + /// * `col_i`: the index of the poly where to encode the data. /// * `log_base2k`: base two negative logarithm decomposition of the receiver. /// * `log_k`: base two negative logarithm of the scaling of the data. /// * `i`: index of the coefficient on which to encode the data. /// * `data`: data to encode on the receiver. /// * `log_max`: base two logarithm of the infinity norm of the input data. - fn encode_coeff_i64(&mut self, poly_idx: usize, log_base2k: usize, log_k: usize, i: usize, data: i64, log_max: usize); + fn encode_coeff_i64(&mut self, col_i: usize, log_base2k: usize, log_k: usize, i: usize, data: i64, log_max: usize); +} + +pub trait Decoding { + /// decode a vector of i64 from the receiver. + /// + /// # Arguments + /// + /// * `col_i`: the index of the poly where to encode the data. + /// * `log_base2k`: base two negative logarithm decomposition of the receiver. + /// * `log_k`: base two logarithm of the scaling of the data. + /// * `data`: data to decode from the receiver. + fn decode_vec_i64(&self, col_i: usize, log_base2k: usize, log_k: usize, data: &mut [i64]); + + /// decode a vector of Float from the receiver. + /// + /// # Arguments + /// * `col_i`: the index of the poly where to encode the data. + /// * `log_base2k`: base two negative logarithm decomposition of the receiver. + /// * `data`: data to decode from the receiver. + fn decode_vec_float(&self, col_i: usize, log_base2k: usize, data: &mut [Float]); /// decode a single of i64 from the receiver at the given index. /// /// # Arguments /// - /// * `poly_idx`: the index of the poly where to encode the data. + /// * `col_i`: the index of the poly where to encode the data. /// * `log_base2k`: base two negative logarithm decomposition of the receiver. /// * `log_k`: base two negative logarithm of the scaling of the data. /// * `i`: index of the coefficient to decode. 
/// * `data`: data to decode from the receiver. - fn decode_coeff_i64(&self, poly_idx: usize, log_base2k: usize, log_k: usize, i: usize) -> i64; + fn decode_coeff_i64(&self, col_i: usize, log_base2k: usize, log_k: usize, i: usize) -> i64; } -impl Encoding for VecZnx { - fn encode_vec_i64(&mut self, poly_idx: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) { - encode_vec_i64(self, poly_idx, log_base2k, log_k, data, log_max) +impl + AsRef<[u8]>> Encoding for VecZnx { + fn encode_vec_i64(&mut self, col_i: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) { + encode_vec_i64(self, col_i, log_base2k, log_k, data, log_max) } - fn decode_vec_i64(&self, poly_idx: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) { - decode_vec_i64(self, poly_idx, log_base2k, log_k, data) - } - - fn decode_vec_float(&self, poly_idx: usize, log_base2k: usize, data: &mut [Float]) { - decode_vec_float(self, poly_idx, log_base2k, data) - } - - fn encode_coeff_i64(&mut self, poly_idx: usize, log_base2k: usize, log_k: usize, i: usize, value: i64, log_max: usize) { - encode_coeff_i64(self, poly_idx, log_base2k, log_k, i, value, log_max) - } - - fn decode_coeff_i64(&self, poly_idx: usize, log_base2k: usize, log_k: usize, i: usize) -> i64 { - decode_coeff_i64(self, poly_idx, log_base2k, log_k, i) + fn encode_coeff_i64(&mut self, col_i: usize, log_base2k: usize, log_k: usize, i: usize, value: i64, log_max: usize) { + encode_coeff_i64(self, col_i, log_base2k, log_k, i, value, log_max) } } -fn encode_vec_i64(a: &mut VecZnx, poly_idx: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) { - let cols: usize = (log_k + log_base2k - 1) / log_base2k; +impl> Decoding for VecZnx { + fn decode_vec_i64(&self, col_i: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) { + decode_vec_i64(self, col_i, log_base2k, log_k, data) + } + + fn decode_vec_float(&self, col_i: usize, log_base2k: usize, data: &mut [Float]) { + decode_vec_float(self, col_i, log_base2k, data) + } + + fn decode_coeff_i64(&self, col_i: usize, log_base2k: usize, log_k: usize, i: usize) -> i64 { + decode_coeff_i64(self, col_i, log_base2k, log_k, i) + } +} + +fn encode_vec_i64 + AsRef<[u8]>>( + a: &mut VecZnx, + col_i: usize, + log_base2k: usize, + log_k: usize, + data: &[i64], + log_max: usize, +) { + let size: usize = (log_k + log_base2k - 1) / log_base2k; #[cfg(debug_assertions)] { assert!( - cols <= a.cols(), - "invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.cols()={}", - cols, - a.cols() + size <= a.size(), + "invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.size()={}", + size, + a.size() ); - assert!(poly_idx < a.size); + assert!(col_i < a.cols()); assert!(data.len() <= a.n()) } let data_len: usize = data.len(); let log_k_rem: usize = log_base2k - (log_k % log_base2k); - (0..a.cols()).for_each(|i| unsafe { - znx_zero_i64_ref(a.n() as u64, a.at_poly_mut_ptr(poly_idx, i)); + // Zeroes coefficients of the i-th column + (0..a.size()).for_each(|i| unsafe { + znx_zero_i64_ref(a.n() as u64, a.at_mut_ptr(col_i, i)); }); // If 2^{log_base2k} * 2^{k_rem} < 2^{63}-1, then we can simply copy // values on the last limb. // Else we decompose values base2k. 
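For intuition, the shift-and-mask split performed here (and the matching shift-and-add merge in `decode_vec_i64` below) can be sketched standalone. `split_base2k`/`merge_base2k` are hypothetical helpers, not part of the crate, and only round-trip nonnegative values that fit in `limbs * k` bits:

```rust
/// Splits `v` into base-2^k digits, most significant first, mirroring the
/// `(x >> shift) & mask` loop of `encode_vec_i64`.
fn split_base2k(v: i64, k: usize, limbs: usize) -> Vec<i64> {
    let mask: i64 = (1 << k) - 1;
    (0..limbs).rev().map(|j| (v >> (j * k)) & mask).collect()
}

/// Merges digits back with the `(acc << k) + digit` pass of `decode_vec_i64`.
fn merge_base2k(digits: &[i64], k: usize) -> i64 {
    digits.iter().fold(0, |acc, d| (acc << k) + d)
}

fn main() {
    let v: i64 = 123_456_789;
    let digits = split_base2k(v, 17, 4);
    assert_eq!(merge_base2k(&digits, 17), v);
}
```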
if log_max + log_k_rem < 63 || log_k_rem == log_base2k { - a.at_poly_mut(poly_idx, cols - 1)[..data_len].copy_from_slice(&data[..data_len]); + a.at_mut(col_i, size - 1)[..data_len].copy_from_slice(&data[..data_len]); } else { let mask: i64 = (1 << log_base2k) - 1; - let steps: usize = min(cols, (log_max + log_base2k - 1) / log_base2k); - (cols - steps..cols) + let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k); + (size - steps..size) .rev() .enumerate() .for_each(|(i, i_rev)| { let shift: usize = i * log_base2k; - izip!(a.at_poly_mut(poly_idx, i_rev).iter_mut(), data.iter()).for_each(|(y, x)| *y = (x >> shift) & mask); + izip!(a.at_mut(col_i, i_rev).iter_mut(), data.iter()).for_each(|(y, x)| *y = (x >> shift) & mask); }) } // Case where self.prec % self.k != 0. if log_k_rem != log_base2k { - let steps: usize = min(cols, (log_max + log_base2k - 1) / log_base2k); - (cols - steps..cols).rev().for_each(|i| { - a.at_poly_mut(poly_idx, i)[..data_len] + let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k); + (size - steps..size).rev().for_each(|i| { + a.at_mut(col_i, i)[..data_len] .iter_mut() .for_each(|x| *x <<= log_k_rem); }) } } -fn decode_vec_i64(a: &VecZnx, poly_idx: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) { - let cols: usize = (log_k + log_base2k - 1) / log_base2k; +fn decode_vec_i64>(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) { + let size: usize = (log_k + log_base2k - 1) / log_base2k; #[cfg(debug_assertions)] { assert!( @@ -140,26 +153,26 @@ fn decode_vec_i64(a: &VecZnx, poly_idx: usize, log_base2k: usize, log_k: usize, data.len(), a.n() ); - assert!(poly_idx < a.size()); + assert!(col_i < a.cols()); } - data.copy_from_slice(a.at_poly(poly_idx, 0)); + data.copy_from_slice(a.at(col_i, 0)); let rem: usize = log_base2k - (log_k % log_base2k); - (1..cols).for_each(|i| { - if i == cols - 1 && rem != log_base2k { + (1..size).for_each(|i| { + if i == size - 1 && rem != log_base2k { let k_rem: usize = log_base2k - rem; - izip!(a.at_poly(poly_idx, i).iter(), data.iter_mut()).for_each(|(x, y)| { + izip!(a.at(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| { *y = (*y << k_rem) + (x >> rem); }); } else { - izip!(a.at_poly(poly_idx, i).iter(), data.iter_mut()).for_each(|(x, y)| { + izip!(a.at(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| { *y = (*y << log_base2k) + x; }); } }) } -fn decode_vec_float(a: &VecZnx, poly_idx: usize, log_base2k: usize, data: &mut [Float]) { - let cols: usize = a.cols(); +fn decode_vec_float>(a: &VecZnx, col_i: usize, log_base2k: usize, data: &mut [Float]) { + let size: usize = a.size(); #[cfg(debug_assertions)] { assert!( @@ -168,23 +181,23 @@ fn decode_vec_float(a: &VecZnx, poly_idx: usize, log_base2k: usize, data: &mut [ data.len(), a.n() ); - assert!(poly_idx < a.size()); + assert!(col_i < a.cols()); } - let prec: u32 = (log_base2k * cols) as u32; + let prec: u32 = (log_base2k * size) as u32; // 2^{log_base2k} let base = Float::with_val(prec, (1 << log_base2k) as f64); // y[i] = sum x[j][i] * 2^{-log_base2k*j} - (0..cols).for_each(|i| { + (0..size).for_each(|i| { if i == 0 { - izip!(a.at_poly(poly_idx, cols - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { + izip!(a.at(col_i, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { y.assign(*x); *y /= &base; }); } else { - izip!(a.at_poly(poly_idx, cols - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { + izip!(a.at(col_i, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { *y += Float::with_val(prec, *x); 
*y /= &base; }); @@ -192,54 +205,62 @@ fn decode_vec_float(a: &VecZnx, poly_idx: usize, log_base2k: usize, data: &mut [ }); } -fn encode_coeff_i64(a: &mut VecZnx, poly_idx: usize, log_base2k: usize, log_k: usize, i: usize, value: i64, log_max: usize) { - let cols: usize = (log_k + log_base2k - 1) / log_base2k; +fn encode_coeff_i64 + AsRef<[u8]>>( + a: &mut VecZnx, + col_i: usize, + log_base2k: usize, + log_k: usize, + i: usize, + value: i64, + log_max: usize, +) { + let size: usize = (log_k + log_base2k - 1) / log_base2k; #[cfg(debug_assertions)] { assert!(i < a.n()); assert!( - cols <= a.cols(), - "invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.cols()={}", - cols, - a.cols() + size <= a.size(), + "invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.size()={}", + size, + a.size() ); - assert!(poly_idx < a.size()); + assert!(col_i < a.cols()); } let log_k_rem: usize = log_base2k - (log_k % log_base2k); - (0..a.cols()).for_each(|j| a.at_poly_mut(poly_idx, j)[i] = 0); + (0..a.size()).for_each(|j| a.at_mut(col_i, j)[i] = 0); // If 2^{log_base2k} * 2^{log_k_rem} < 2^{63}-1, then we can simply copy // values on the last limb. // Else we decompose values base2k. if log_max + log_k_rem < 63 || log_k_rem == log_base2k { - a.at_poly_mut(poly_idx, cols - 1)[i] = value; + a.at_mut(col_i, size - 1)[i] = value; } else { let mask: i64 = (1 << log_base2k) - 1; - let steps: usize = min(cols, (log_max + log_base2k - 1) / log_base2k); - (cols - steps..cols) + let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k); + (size - steps..size) .rev() .enumerate() .for_each(|(j, j_rev)| { - a.at_poly_mut(poly_idx, j_rev)[i] = (value >> (j * log_base2k)) & mask; + a.at_mut(col_i, j_rev)[i] = (value >> (j * log_base2k)) & mask; }) } // Case where prec % k != 0. 
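Stepping back to `decode_vec_float` above: it is a Horner evaluation over the limbs, computing `y = sum_j x_j * 2^{-k*(j+1)}`. A toy `f64` analogue of that pass (a hypothetical helper; the crate uses `rug::Float` precisely because `f64` runs out of precision for large limb counts):

```rust
/// Toy f64 analogue of `decode_vec_float`: evaluates
/// sum_j limbs[j] * 2^{-k*(j+1)}, folding from the least significant limb.
fn limbs_to_f64(limbs: &[i64], k: usize) -> f64 {
    let base = (1u64 << k) as f64; // 2^k
    limbs.iter().rev().fold(0.0, |acc, &x| (acc + x as f64) / base)
}

fn main() {
    // One half encoded on two base-2^17 limbs: 2^16 * 2^{-17} = 0.5
    assert_eq!(limbs_to_f64(&[1 << 16, 0], 17), 0.5);
}
```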
if log_k_rem != log_base2k { - let steps: usize = min(cols, (log_max + log_base2k - 1) / log_base2k); - (cols - steps..cols).rev().for_each(|j| { - a.at_poly_mut(poly_idx, j)[i] <<= log_k_rem; + let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k); + (size - steps..size).rev().for_each(|j| { + a.at_mut(col_i, j)[i] <<= log_k_rem; }) } } -fn decode_coeff_i64(a: &VecZnx, poly_idx: usize, log_base2k: usize, log_k: usize, i: usize) -> i64 { +fn decode_coeff_i64>(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, i: usize) -> i64 { #[cfg(debug_assertions)] { assert!(i < a.n()); - assert!(poly_idx < a.size()) + assert!(col_i < a.cols()) } let cols: usize = (log_k + log_base2k - 1) / log_base2k; @@ -261,27 +282,30 @@ fn decode_coeff_i64(a: &VecZnx, poly_idx: usize, log_base2k: usize, log_k: usize #[cfg(test)] mod tests { - use crate::{Encoding, Infos, VecZnx}; + use crate::vec_znx_ops::*; + use crate::znx_base::*; + use crate::{Decoding, Encoding, FFT64, Module, VecZnx, znx_base::ZnxInfos}; use itertools::izip; use sampling::source::Source; #[test] fn test_set_get_i64_lo_norm() { let n: usize = 8; + let module: Module = Module::::new(n); let log_base2k: usize = 17; - let cols: usize = 5; - let log_k: usize = cols * log_base2k - 5; - let mut a: VecZnx = VecZnx::new(n, 2, cols); + let size: usize = 5; + let log_k: usize = size * log_base2k - 5; + let mut a: VecZnx<_> = module.new_vec_znx(2, size); let mut source: Source = Source::new([0u8; 32]); let raw: &mut [i64] = a.raw_mut(); raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); - (0..a.size()).for_each(|poly_idx| { + (0..a.cols()).for_each(|col_i| { let mut have: Vec = vec![i64::default(); n]; have.iter_mut() .for_each(|x| *x = (source.next_i64() << 56) >> 56); - a.encode_vec_i64(poly_idx, log_base2k, log_k, &have, 10); + a.encode_vec_i64(col_i, log_base2k, log_k, &have, 10); let mut want: Vec = vec![i64::default(); n]; - a.decode_vec_i64(poly_idx, log_base2k, log_k, &mut want); + a.decode_vec_i64(col_i, log_base2k, log_k, &mut want); izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); }); } @@ -289,19 +313,20 @@ mod tests { #[test] fn test_set_get_i64_hi_norm() { let n: usize = 8; + let module: Module = Module::::new(n); let log_base2k: usize = 17; - let cols: usize = 5; - let log_k: usize = cols * log_base2k - 5; - let mut a: VecZnx = VecZnx::new(n, 2, cols); + let size: usize = 5; + let log_k: usize = size * log_base2k - 5; + let mut a: VecZnx<_> = module.new_vec_znx(2, size); let mut source = Source::new([0u8; 32]); let raw: &mut [i64] = a.raw_mut(); raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); - (0..a.size()).for_each(|poly_idx| { + (0..a.cols()).for_each(|col_i| { let mut have: Vec = vec![i64::default(); n]; have.iter_mut().for_each(|x| *x = source.next_i64()); - a.encode_vec_i64(poly_idx, log_base2k, log_k, &have, 64); + a.encode_vec_i64(col_i, log_base2k, log_k, &have, 64); let mut want = vec![i64::default(); n]; - a.decode_vec_i64(poly_idx, log_base2k, log_k, &mut want); + a.decode_vec_i64(col_i, log_base2k, log_k, &mut want); izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); }) } diff --git a/base2k/src/ffi/module.rs b/base2k/src/ffi/module.rs index 755d613..e35d4c0 100644 --- a/base2k/src/ffi/module.rs +++ b/base2k/src/ffi/module.rs @@ -3,8 +3,6 @@ pub struct module_info_t { } pub type module_type_t = ::std::os::raw::c_uint; -pub const module_type_t_FFT64: module_type_t = 0; -pub const module_type_t_NTT120: module_type_t = 1; pub use 
self::module_type_t as MODULE_TYPE; pub type MODULE = module_info_t; diff --git a/base2k/src/ffi/svp.rs b/base2k/src/ffi/svp.rs index 71c871d..08b2da1 100644 --- a/base2k/src/ffi/svp.rs +++ b/base2k/src/ffi/svp.rs @@ -33,3 +33,16 @@ unsafe extern "C" { a_sl: u64, ); } + +unsafe extern "C" { + pub unsafe fn svp_apply_dft_to_dft( + module: *const MODULE, + res: *const VEC_ZNX_DFT, + res_size: u64, + res_cols: u64, + ppol: *const SVP_PPOL, + a: *const VEC_ZNX_DFT, + a_size: u64, + a_cols: u64, + ); +} diff --git a/base2k/src/ffi/vec_znx_big.rs b/base2k/src/ffi/vec_znx_big.rs index e1222c3..8c06e90 100644 --- a/base2k/src/ffi/vec_znx_big.rs +++ b/base2k/src/ffi/vec_znx_big.rs @@ -8,17 +8,17 @@ pub struct vec_znx_big_t { pub type VEC_ZNX_BIG = vec_znx_big_t; unsafe extern "C" { - pub fn bytes_of_vec_znx_big(module: *const MODULE, size: u64) -> u64; + pub unsafe fn bytes_of_vec_znx_big(module: *const MODULE, size: u64) -> u64; } unsafe extern "C" { - pub fn new_vec_znx_big(module: *const MODULE, size: u64) -> *mut VEC_ZNX_BIG; + pub unsafe fn new_vec_znx_big(module: *const MODULE, size: u64) -> *mut VEC_ZNX_BIG; } unsafe extern "C" { - pub fn delete_vec_znx_big(res: *mut VEC_ZNX_BIG); + pub unsafe fn delete_vec_znx_big(res: *mut VEC_ZNX_BIG); } unsafe extern "C" { - pub fn vec_znx_big_add( + pub unsafe fn vec_znx_big_add( module: *const MODULE, res: *mut VEC_ZNX_BIG, res_size: u64, @@ -29,7 +29,7 @@ unsafe extern "C" { ); } unsafe extern "C" { - pub fn vec_znx_big_add_small( + pub unsafe fn vec_znx_big_add_small( module: *const MODULE, res: *mut VEC_ZNX_BIG, res_size: u64, @@ -41,7 +41,7 @@ unsafe extern "C" { ); } unsafe extern "C" { - pub fn vec_znx_big_add_small2( + pub unsafe fn vec_znx_big_add_small2( module: *const MODULE, res: *mut VEC_ZNX_BIG, res_size: u64, @@ -54,7 +54,7 @@ unsafe extern "C" { ); } unsafe extern "C" { - pub fn vec_znx_big_sub( + pub unsafe fn vec_znx_big_sub( module: *const MODULE, res: *mut VEC_ZNX_BIG, res_size: u64, @@ -65,7 +65,7 @@ unsafe extern "C" { ); } unsafe extern "C" { - pub fn vec_znx_big_sub_small_b( + pub unsafe fn vec_znx_big_sub_small_b( module: *const MODULE, res: *mut VEC_ZNX_BIG, res_size: u64, @@ -77,7 +77,7 @@ unsafe extern "C" { ); } unsafe extern "C" { - pub fn vec_znx_big_sub_small_a( + pub unsafe fn vec_znx_big_sub_small_a( module: *const MODULE, res: *mut VEC_ZNX_BIG, res_size: u64, @@ -89,7 +89,7 @@ unsafe extern "C" { ); } unsafe extern "C" { - pub fn vec_znx_big_sub_small2( + pub unsafe fn vec_znx_big_sub_small2( module: *const MODULE, res: *mut VEC_ZNX_BIG, res_size: u64, @@ -101,8 +101,13 @@ unsafe extern "C" { b_sl: u64, ); } + unsafe extern "C" { - pub fn vec_znx_big_normalize_base2k( + pub unsafe fn vec_znx_big_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64; +} + +unsafe extern "C" { + pub unsafe fn vec_znx_big_normalize_base2k( module: *const MODULE, log2_base2k: u64, res: *mut i64, @@ -113,34 +118,9 @@ unsafe extern "C" { tmp_space: *mut u8, ); } -unsafe extern "C" { - pub fn vec_znx_big_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64; -} unsafe extern "C" { - pub fn vec_znx_big_automorphism( - module: *const MODULE, - p: i64, - res: *mut VEC_ZNX_BIG, - res_size: u64, - a: *const VEC_ZNX_BIG, - a_size: u64, - ); -} - -unsafe extern "C" { - pub fn vec_znx_big_rotate( - module: *const MODULE, - p: i64, - res: *mut VEC_ZNX_BIG, - res_size: u64, - a: *const VEC_ZNX_BIG, - a_size: u64, - ); -} - -unsafe extern "C" { - pub fn vec_znx_big_range_normalize_base2k( + pub unsafe fn vec_znx_big_range_normalize_base2k( module: 
*const MODULE, log2_base2k: u64, res: *mut i64, @@ -153,6 +133,29 @@ unsafe extern "C" { tmp_space: *mut u8, ); } + unsafe extern "C" { - pub fn vec_znx_big_range_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64; + pub unsafe fn vec_znx_big_range_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64; +} + +unsafe extern "C" { + pub unsafe fn vec_znx_big_automorphism( + module: *const MODULE, + p: i64, + res: *mut VEC_ZNX_BIG, + res_size: u64, + a: *const VEC_ZNX_BIG, + a_size: u64, + ); +} + +unsafe extern "C" { + pub unsafe fn vec_znx_big_rotate( + module: *const MODULE, + p: i64, + res: *mut VEC_ZNX_BIG, + res_size: u64, + a: *const VEC_ZNX_BIG, + a_size: u64, + ); } diff --git a/base2k/src/infos.rs b/base2k/src/infos.rs deleted file mode 100644 index 08472d9..0000000 --- a/base2k/src/infos.rs +++ /dev/null @@ -1,22 +0,0 @@ -use crate::LAYOUT; - -pub trait Infos { - /// Returns the ring degree of the receiver. - fn n(&self) -> usize; - - /// Returns the base two logarithm of the ring dimension of the receiver. - fn log_n(&self) -> usize; - - /// Returns the number of stacked polynomials. - fn size(&self) -> usize; - - /// Returns the memory layout of the stacked polynomials. - fn layout(&self) -> LAYOUT; - - /// Returns the number of columns of the receiver. - /// This method is equivalent to [Infos::cols]. - fn cols(&self) -> usize; - - /// Returns the number of rows of the receiver. - fn rows(&self) -> usize; -} diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs index 7e97b00..89a52ef 100644 --- a/base2k/src/lib.rs +++ b/base2k/src/lib.rs @@ -2,39 +2,43 @@ pub mod encoding; #[allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code, improper_ctypes)] // Other modules and exports pub mod ffi; -pub mod infos; +pub mod mat_znx_dft; +pub mod mat_znx_dft_ops; pub mod module; pub mod sampling; +pub mod scalar_znx; +pub mod scalar_znx_dft; +pub mod scalar_znx_dft_ops; pub mod stats; -pub mod svp; pub mod vec_znx; pub mod vec_znx_big; +pub mod vec_znx_big_ops; pub mod vec_znx_dft; -pub mod vmp; +pub mod vec_znx_dft_ops; +pub mod vec_znx_ops; +pub mod znx_base; pub use encoding::*; -pub use infos::*; +pub use mat_znx_dft::*; +pub use mat_znx_dft_ops::*; pub use module::*; pub use sampling::*; -#[allow(unused_imports)] +pub use scalar_znx::*; +pub use scalar_znx_dft::*; +pub use scalar_znx_dft_ops::*; pub use stats::*; -pub use svp::*; pub use vec_znx::*; pub use vec_znx_big::*; +pub use vec_znx_big_ops::*; pub use vec_znx_dft::*; -pub use vmp::*; +pub use vec_znx_dft_ops::*; +pub use vec_znx_ops::*; +pub use znx_base::*; pub const GALOISGENERATOR: u64 = 5; pub const DEFAULTALIGN: usize = 64; -#[derive(Copy, Clone)] -#[repr(u8)] -pub enum LAYOUT { - ROW, - COL, -} - -pub fn is_aligned_custom(ptr: *const T, align: usize) -> bool { +fn is_aligned_custom(ptr: *const T, align: usize) -> bool { (ptr as usize) % align == 0 } @@ -51,38 +55,35 @@ pub fn assert_alignement(ptr: *const T) { pub fn cast(data: &[T]) -> &[V] { let ptr: *const V = data.as_ptr() as *const V; - let len: usize = data.len() / std::mem::size_of::(); + let len: usize = data.len() / size_of::(); unsafe { std::slice::from_raw_parts(ptr, len) } } pub fn cast_mut(data: &[T]) -> &mut [V] { let ptr: *mut V = data.as_ptr() as *mut V; - let len: usize = data.len() / std::mem::size_of::(); + let len: usize = data.len() / size_of::(); unsafe { std::slice::from_raw_parts_mut(ptr, len) } } -use std::alloc::{Layout, alloc}; -use std::ptr; - /// Allocates a block of bytes with a custom alignement. 
/// Alignement must be a power of two and size a multiple of the alignement. /// Allocated memory is initialized to zero. -pub fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec { +fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec { assert!( align.is_power_of_two(), "Alignment must be a power of two but is {}", align ); assert_eq!( - (size * std::mem::size_of::()) % align, + (size * size_of::()) % align, 0, "size={} must be a multiple of align={}", size, align ); unsafe { - let layout: Layout = Layout::from_size_align(size, align).expect("Invalid alignment"); - let ptr: *mut u8 = alloc(layout); + let layout: std::alloc::Layout = std::alloc::Layout::from_size_align(size, align).expect("Invalid alignment"); + let ptr: *mut u8 = std::alloc::alloc(layout); if ptr.is_null() { panic!("Memory allocation failed"); } @@ -93,36 +94,158 @@ pub fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec { align ); // Init allocated memory to zero - ptr::write_bytes(ptr, 0, size); + std::ptr::write_bytes(ptr, 0, size); Vec::from_raw_parts(ptr, size, size) } } -/// Allocates a block of bytes aligned with [DEFAULTALIGN]. -/// Size must be amultiple of [DEFAULTALIGN]. -/// /// Allocated memory is initialized to zero. -pub fn alloc_aligned_u8(size: usize) -> Vec { - alloc_aligned_custom_u8(size, DEFAULTALIGN) -} - /// Allocates a block of T aligned with [DEFAULTALIGN]. /// Size of T * size msut be a multiple of [DEFAULTALIGN]. pub fn alloc_aligned_custom(size: usize, align: usize) -> Vec { assert_eq!( - (size * std::mem::size_of::()) % align, + (size * size_of::()) % align, 0, "size={} must be a multiple of align={}", size, align ); - let mut vec_u8: Vec = alloc_aligned_custom_u8(std::mem::size_of::() * size, align); + let mut vec_u8: Vec = alloc_aligned_custom_u8(size_of::() * size, align); let ptr: *mut T = vec_u8.as_mut_ptr() as *mut T; - let len: usize = vec_u8.len() / std::mem::size_of::(); - let cap: usize = vec_u8.capacity() / std::mem::size_of::(); + let len: usize = vec_u8.len() / size_of::(); + let cap: usize = vec_u8.capacity() / size_of::(); std::mem::forget(vec_u8); unsafe { Vec::from_raw_parts(ptr, len, cap) } } +/// Allocates an aligned vector of size equal to the smallest multiple +/// of [DEFAULTALIGN]/size_of::() that is equal or greater to `size`. 
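Note that the expression used below, `size + (size % (DEFAULTALIGN / size_of::<T>()))`, is not a round-up to a multiple: for 8 alignment units and `size = 65` it yields 66, which then trips the multiple-of-align assertion in `alloc_aligned_custom`. A minimal sketch of the round-up that this doc comment actually describes, using only the standard library:

```rust
use std::mem::size_of;

/// Rounds `size` up to the next multiple of `m`; `usize::next_multiple_of`
/// is stable since Rust 1.73.
fn round_up(size: usize, m: usize) -> usize {
    size.next_multiple_of(m)
}

fn main() {
    const DEFAULTALIGN: usize = 64; // as defined in lib.rs
    let m = DEFAULTALIGN / size_of::<i64>(); // 8 alignment units for T = i64
    assert_eq!(round_up(65, m), 72); // smallest multiple of 8 >= 65
    assert_eq!(65 + (65 % m), 66); // the diff's expression: not a multiple of 8
}
```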
pub fn alloc_aligned(size: usize) -> Vec { - alloc_aligned_custom::(size, DEFAULTALIGN) + alloc_aligned_custom::( + size + (size % (DEFAULTALIGN / size_of::())), + DEFAULTALIGN, + ) +} + +// Scratch implementation below + +pub struct ScratchOwned(Vec); + +impl ScratchOwned { + pub fn new(byte_count: usize) -> Self { + let data: Vec = alloc_aligned(byte_count); + Self(data) + } + + pub fn borrow(&mut self) -> &mut Scratch { + Scratch::new(&mut self.0) + } +} + +pub struct Scratch { + data: [u8], +} + +impl Scratch { + fn new(data: &mut [u8]) -> &mut Self { + unsafe { &mut *(data as *mut [u8] as *mut Self) } + } + + pub fn available(&self) -> usize { + let ptr: *const u8 = self.data.as_ptr(); + let self_len: usize = self.data.len(); + let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN); + self_len.saturating_sub(aligned_offset) + } + + fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]) { + let ptr: *mut u8 = data.as_mut_ptr(); + let self_len: usize = data.len(); + + let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN); + let aligned_len: usize = self_len.saturating_sub(aligned_offset); + + if let Some(rem_len) = aligned_len.checked_sub(take_len) { + unsafe { + let rem_ptr: *mut u8 = ptr.add(aligned_offset).add(take_len); + let rem_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(rem_ptr, rem_len); + + let take_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(ptr.add(aligned_offset), take_len); + + return (take_slice, rem_slice); + } + } else { + panic!( + "Attempted to take {} from scratch with {} aligned bytes left", + take_len, + aligned_len, + // type_name::(), + // aligned_len + ); + } + } + + pub fn tmp_slice(&mut self, len: usize) -> (&mut [T], &mut Self) { + let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, len * std::mem::size_of::()); + + unsafe { + ( + &mut *(std::ptr::slice_from_raw_parts_mut(take_slice.as_mut_ptr() as *mut T, len)), + Self::new(rem_slice), + ) + } + } + + pub fn tmp_scalar_znx(&mut self, module: &Module, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) { + let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_scalar_znx(module, cols)); + + ( + ScalarZnx::from_data(take_slice, module.n(), cols), + Self::new(rem_slice), + ) + } + + pub fn tmp_scalar_znx_dft(&mut self, module: &Module, cols: usize) -> (ScalarZnxDft<&mut [u8], B>, &mut Self) { + let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_scalar_znx_dft(module, cols)); + + ( + ScalarZnxDft::from_data(take_slice, module.n(), cols), + Self::new(rem_slice), + ) + } + + pub fn tmp_vec_znx_dft( + &mut self, + module: &Module, + cols: usize, + size: usize, + ) -> (VecZnxDft<&mut [u8], B>, &mut Self) { + let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_vec_znx_dft(module, cols, size)); + + ( + VecZnxDft::from_data(take_slice, module.n(), cols, size), + Self::new(rem_slice), + ) + } + + pub fn tmp_vec_znx_big( + &mut self, + module: &Module, + cols: usize, + size: usize, + ) -> (VecZnxBig<&mut [u8], B>, &mut Self) { + let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_vec_znx_big(module, cols, size)); + + ( + VecZnxBig::from_data(take_slice, module.n(), cols, size), + Self::new(rem_slice), + ) + } + + pub fn tmp_vec_znx(&mut self, module: &Module, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self) { + let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, module.bytes_of_vec_znx(cols, size)); + 
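The `take_slice_aligned` helper used throughout `Scratch` is a bump allocator over a byte arena: advance to the next aligned address, carve off `take_len` bytes, and hand back the remainder. A safe-Rust miniature of the same idea (illustrative only; the real version goes through raw pointers so it can return both halves wrapped in the layout types above):

```rust
/// Carves an aligned prefix of `len` bytes off `buf`, returning
/// (taken, remainder). Panics if the arena is too small, like the
/// panic branch of `take_slice_aligned`.
fn take_aligned(buf: &mut [u8], len: usize, align: usize) -> (&mut [u8], &mut [u8]) {
    let off = buf.as_ptr().align_offset(align);
    let (_skipped, aligned) = buf.split_at_mut(off);
    aligned.split_at_mut(len)
}

fn main() {
    let mut arena = vec![0u8; 256];
    let (taken, rest) = take_aligned(&mut arena, 64, 64);
    assert_eq!(taken.len(), 64);
    assert_eq!(taken.as_ptr() as usize % 64, 0);
    assert!(rest.len() >= 128);
}
```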
( + VecZnx::from_data(take_slice, module.n(), cols, size), + Self::new(rem_slice), + ) + } } diff --git a/base2k/src/mat_znx_dft.rs b/base2k/src/mat_znx_dft.rs new file mode 100644 index 0000000..209c696 --- /dev/null +++ b/base2k/src/mat_znx_dft.rs @@ -0,0 +1,232 @@ +use crate::znx_base::ZnxInfos; +use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxView, alloc_aligned}; +use std::marker::PhantomData; + +/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx], +/// stored as a 3D matrix in the DFT domain in a single contiguous array. +/// Each col of the [MatZnxDft] can be seen as a collection of [VecZnxDft]. +/// +/// [MatZnxDft] is used to permform a vector matrix product between a [VecZnx]/[VecZnxDft] and a [MatZnxDft]. +/// See the trait [MatZnxDftOps] for additional information. +pub struct MatZnxDft { + data: D, + n: usize, + size: usize, + rows: usize, + cols_in: usize, + cols_out: usize, + _phantom: PhantomData, +} + +impl ZnxInfos for MatZnxDft { + fn cols(&self) -> usize { + self.cols_in + } + + fn rows(&self) -> usize { + self.rows + } + + fn n(&self) -> usize { + self.n + } + + fn size(&self) -> usize { + self.size + } +} + +impl ZnxSliceSize for MatZnxDft { + fn sl(&self) -> usize { + self.n() * self.cols_out() + } +} + +impl DataView for MatZnxDft { + type D = D; + fn data(&self) -> &Self::D { + &self.data + } +} + +impl DataViewMut for MatZnxDft { + fn data_mut(&mut self) -> &mut Self::D { + &mut self.data + } +} + +impl> ZnxView for MatZnxDft { + type Scalar = f64; +} + +impl MatZnxDft { + pub fn cols_in(&self) -> usize { + self.cols_in + } + + pub fn cols_out(&self) -> usize { + self.cols_out + } +} + +impl>, B: Backend> MatZnxDft { + pub(crate) fn bytes_of(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { + unsafe { + crate::ffi::vmp::bytes_of_vmp_pmat( + module.ptr, + (rows * cols_in) as u64, + (size * cols_out) as u64, + ) as usize + } + } + + pub(crate) fn new(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self { + let data: Vec = alloc_aligned(Self::bytes_of(module, rows, cols_in, cols_out, size)); + Self { + data: data.into(), + n: module.n(), + size, + rows, + cols_in, + cols_out, + _phantom: PhantomData, + } + } + + pub(crate) fn new_from_bytes( + module: &Module, + rows: usize, + cols_in: usize, + cols_out: usize, + size: usize, + bytes: impl Into>, + ) -> Self { + let data: Vec = bytes.into(); + assert!(data.len() == Self::bytes_of(module, rows, cols_in, cols_out, size)); + Self { + data: data.into(), + n: module.n(), + size, + rows, + cols_in, + cols_out, + _phantom: PhantomData, + } + } +} + +impl> MatZnxDft { + /// Returns a copy of the backend array at index (i, j) of the [MatZnxDft]. + /// + /// # Arguments + /// + /// * `row`: row index (i). + /// * `col`: col index (j). + #[allow(dead_code)] + fn at(&self, row: usize, col: usize) -> Vec { + let n: usize = self.n(); + + let mut res: Vec = alloc_aligned(n); + + if n < 8 { + res.copy_from_slice(&self.raw()[(row + col * self.rows()) * n..(row + col * self.rows()) * (n + 1)]); + } else { + (0..n >> 3).for_each(|blk| { + res[blk * 8..(blk + 1) * 8].copy_from_slice(&self.at_block(row, col, blk)[..8]); + }); + } + + res + } + + #[allow(dead_code)] + fn at_block(&self, row: usize, col: usize, blk: usize) -> &[f64] { + let nrows: usize = self.rows(); + let nsize: usize = self.size(); + if col == (nsize - 1) && (nsize & 1 == 1) { + &self.raw()[blk * nrows * nsize * 8 + col * nrows * 8 + row * 8..] 
+        } else {
+            &self.raw()[blk * nrows * nsize * 8 + (col / 2) * (2 * nrows) * 8 + row * 2 * 8 + (col % 2) * 8..]
+        }
+    }
+}
+
+pub type MatZnxDftOwned<B> = MatZnxDft<Vec<u8>, B>;
+
+pub trait MatZnxDftToRef<B: Backend> {
+    fn to_ref(&self) -> MatZnxDft<&[u8], B>;
+}
+
+pub trait MatZnxDftToMut<B: Backend> {
+    fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B>;
+}
+
+impl<B: Backend> MatZnxDftToMut<B> for MatZnxDft<Vec<u8>, B> {
+    fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> {
+        MatZnxDft {
+            data: self.data.as_mut_slice(),
+            n: self.n,
+            rows: self.rows,
+            cols_in: self.cols_in,
+            cols_out: self.cols_out,
+            size: self.size,
+            _phantom: PhantomData,
+        }
+    }
+}
+
+impl<B: Backend> MatZnxDftToRef<B> for MatZnxDft<Vec<u8>, B> {
+    fn to_ref(&self) -> MatZnxDft<&[u8], B> {
+        MatZnxDft {
+            data: self.data.as_slice(),
+            n: self.n,
+            rows: self.rows,
+            cols_in: self.cols_in,
+            cols_out: self.cols_out,
+            size: self.size,
+            _phantom: PhantomData,
+        }
+    }
+}
+
+impl<B: Backend> MatZnxDftToMut<B> for MatZnxDft<&mut [u8], B> {
+    fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> {
+        MatZnxDft {
+            data: self.data,
+            n: self.n,
+            rows: self.rows,
+            cols_in: self.cols_in,
+            cols_out: self.cols_out,
+            size: self.size,
+            _phantom: PhantomData,
+        }
+    }
+}
+
+impl<B: Backend> MatZnxDftToRef<B> for MatZnxDft<&mut [u8], B> {
+    fn to_ref(&self) -> MatZnxDft<&[u8], B> {
+        MatZnxDft {
+            data: self.data,
+            n: self.n,
+            rows: self.rows,
+            cols_in: self.cols_in,
+            cols_out: self.cols_out,
+            size: self.size,
+            _phantom: PhantomData,
+        }
+    }
+}
+
+impl<B: Backend> MatZnxDftToRef<B> for MatZnxDft<&[u8], B> {
+    fn to_ref(&self) -> MatZnxDft<&[u8], B> {
+        MatZnxDft {
+            data: self.data,
+            n: self.n,
+            rows: self.rows,
+            cols_in: self.cols_in,
+            cols_out: self.cols_out,
+            size: self.size,
+            _phantom: PhantomData,
+        }
+    }
+}
diff --git a/base2k/src/mat_znx_dft_ops.rs b/base2k/src/mat_znx_dft_ops.rs
new file mode 100644
index 0000000..24be2e2
--- /dev/null
+++ b/base2k/src/mat_znx_dft_ops.rs
@@ -0,0 +1,487 @@
+use crate::ffi::vec_znx_dft::vec_znx_dft_t;
+use crate::ffi::vmp;
+use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
+use crate::{
+    Backend, FFT64, MatZnxDft, MatZnxDftOwned, MatZnxDftToMut, MatZnxDftToRef, Module, Scratch, VecZnxDft, VecZnxDftToMut,
+    VecZnxDftToRef,
+};
+
+pub trait MatZnxDftAlloc {
+    /// Allocates a new [MatZnxDft] with the given number of rows and columns.
+    ///
+    /// # Arguments
+    ///
+    /// * `rows`: number of rows (number of [VecZnxDft] per input column).
+    /// * `cols_in`: number of input columns.
+    /// * `cols_out`: number of output columns.
+    /// * `size`: size of each [VecZnxDft] (its number of small polynomials).
+    fn new_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxDftOwned;
+
+    fn bytes_of_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
+
+    fn new_mat_znx_dft_from_bytes(
+        &self,
+        rows: usize,
+        cols_in: usize,
+        cols_out: usize,
+        size: usize,
+        bytes: Vec,
+    ) -> MatZnxDftOwned;
+}
+
+pub trait MatZnxDftScratch {
+    /// Returns the size of the scratch space necessary for [MatZnxDftOps::vmp_apply] and [MatZnxDftOps::vmp_apply_add].
+    fn vmp_apply_tmp_bytes(
+        &self,
+        res_size: usize,
+        a_size: usize,
+        b_rows: usize,
+        b_cols_in: usize,
+        b_cols_out: usize,
+        b_size: usize,
+    ) -> usize;
+}
+
+/// This trait implements methods for the vector matrix product,
+/// that is, multiplying a [VecZnx] with a [MatZnxDft].
+pub trait MatZnxDftOps {
+    /// Prepares the i-th row of a [MatZnxDft] from a [VecZnxDft].
+    ///
+    /// # Arguments
+    ///
+    /// * `res`: the [MatZnxDft] on which the values are encoded.
+    /// * `res_row`: the index of the row to prepare.
+    /// * `res_col_in`: the input column of the row to prepare.
+    /// * `a`: the [VecZnxDft] to encode on the [MatZnxDft].
+    ///
+    /// Unlike the previous API, this method does not require scratch space.
+    fn vmp_prepare_row(&self, res: &mut R, res_row: usize, res_col_in: usize, a: &A)
+    where
+        R: MatZnxDftToMut,
+        A: VecZnxDftToRef;
+
+    /// Extracts the i-th row of a [MatZnxDft] into a [VecZnxDft].
+    ///
+    /// # Arguments
+    ///
+    /// * `res`: the [VecZnxDft] into which the selected row of the [MatZnxDft] is extracted.
+    /// * `a`: the [MatZnxDft] from which the row is read.
+    /// * `a_row`: the index of the row to extract.
+    /// * `a_col_in`: the input column of the row to extract.
+    fn vmp_extract_row(&self, res: &mut R, a: &A, a_row: usize, a_col_in: usize)
+    where
+        R: VecZnxDftToMut,
+        A: MatZnxDftToRef;
+
+    /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft].
+    /// The scratch space size is given by [MatZnxDftScratch::vmp_apply_tmp_bytes].
+    ///
+    /// A vector matrix product is equivalent to a sum of scalar-vector products
+    /// (see [crate::ScalarZnxDftOps]), where each limb of the input [VecZnxDft]
+    /// plays the role of a [crate::ScalarZnxDft] and each row of the [MatZnxDft]
+    /// is a [VecZnxDft].
+    ///
+    /// As such, given an input [VecZnxDft] of size `i` and a [MatZnxDft] of `i` rows and
+    /// size `j`, the output is a [VecZnxDft] of size `j`.
+    ///
+    /// If there is a mismatch between the dimensions, the largest valid ones are used.
+    ///
+    /// ```text
+    /// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p|
+    ///             |h i j|
+    ///             |k l m|
+    /// ```
+    /// where each element is a [VecZnxDft].
+    ///
+    /// # Arguments
+    ///
+    /// * `res`: the output of the vector matrix product, as a [VecZnxDft].
+    /// * `a`: the left operand [VecZnxDft] of the vector matrix product.
+    /// * `b`: the right operand [MatZnxDft] of the vector matrix product.
+    /// * `scratch`: scratch space; its size can be obtained with [MatZnxDftScratch::vmp_apply_tmp_bytes].
+    fn vmp_apply(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch)
+    where
+        R: VecZnxDftToMut,
+        A: VecZnxDftToRef,
+        B: MatZnxDftToRef;
+
+    /// Same as [MatZnxDftOps::vmp_apply], except that the result is accumulated into `res` instead of overwriting it.
+ fn vmp_apply_add(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + B: MatZnxDftToRef; +} + +impl MatZnxDftAlloc for Module { + fn bytes_of_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { + MatZnxDftOwned::bytes_of(self, rows, cols_in, cols_out, size) + } + + fn new_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxDftOwned { + MatZnxDftOwned::new(self, rows, cols_in, cols_out, size) + } + + fn new_mat_znx_dft_from_bytes( + &self, + rows: usize, + cols_in: usize, + cols_out: usize, + size: usize, + bytes: Vec, + ) -> MatZnxDftOwned { + MatZnxDftOwned::new_from_bytes(self, rows, cols_in, cols_out, size, bytes) + } +} + +impl MatZnxDftScratch for Module { + fn vmp_apply_tmp_bytes( + &self, + res_size: usize, + a_size: usize, + b_rows: usize, + b_cols_in: usize, + b_cols_out: usize, + b_size: usize, + ) -> usize { + unsafe { + vmp::vmp_apply_dft_to_dft_tmp_bytes( + self.ptr, + (res_size * b_cols_out) as u64, + (a_size * b_cols_in) as u64, + (b_rows * b_cols_in) as u64, + (b_size * b_cols_out) as u64, + ) as usize + } + } +} + +impl MatZnxDftOps for Module { + fn vmp_prepare_row(&self, res: &mut R, res_row: usize, res_col_in: usize, a: &A) + where + R: MatZnxDftToMut, + A: VecZnxDftToRef, + { + let mut res: MatZnxDft<&mut [u8], _> = res.to_mut(); + let a: VecZnxDft<&[u8], _> = a.to_ref(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.n(), self.n()); + assert_eq!(a.n(), self.n()); + assert_eq!( + a.cols(), + res.cols_out(), + "a.cols(): {} != res.cols_out(): {}", + a.cols(), + res.cols_out() + ); + assert!( + res_row < res.rows(), + "res_row: {} >= res.rows(): {}", + res_row, + res.rows() + ); + assert!( + res_col_in < res.cols_in(), + "res_col_in: {} >= res.cols_in(): {}", + res_col_in, + res.cols_in() + ); + assert_eq!( + res.size(), + a.size(), + "res.size(): {} != a.size(): {}", + res.size(), + a.size() + ); + } + + unsafe { + vmp::vmp_prepare_row_dft( + self.ptr, + res.as_mut_ptr() as *mut vmp::vmp_pmat_t, + a.as_ptr() as *const vec_znx_dft_t, + (res_row * res.cols_in() + res_col_in) as u64, + (res.rows() * res.cols_in()) as u64, + (res.size() * res.cols_out()) as u64, + ); + } + } + + fn vmp_extract_row(&self, res: &mut R, a: &A, a_row: usize, a_col_in: usize) + where + R: VecZnxDftToMut, + A: MatZnxDftToRef, + { + let mut res: VecZnxDft<&mut [u8], _> = res.to_mut(); + let a: MatZnxDft<&[u8], _> = a.to_ref(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.n(), self.n()); + assert_eq!(a.n(), self.n()); + assert_eq!( + res.cols(), + a.cols_out(), + "res.cols(): {} != a.cols_out(): {}", + res.cols(), + a.cols_out() + ); + assert!( + a_row < a.rows(), + "a_row: {} >= a.rows(): {}", + a_row, + a.rows() + ); + assert!( + a_col_in < a.cols_in(), + "a_col_in: {} >= a.cols_in(): {}", + a_col_in, + a.cols_in() + ); + assert_eq!( + res.size(), + a.size(), + "res.size(): {} != a.size(): {}", + res.size(), + a.size() + ); + } + unsafe { + vmp::vmp_extract_row_dft( + self.ptr, + res.as_mut_ptr() as *mut vec_znx_dft_t, + a.as_ptr() as *const vmp::vmp_pmat_t, + (a_row * a.cols_in() + a_col_in) as u64, + (a.rows() * a.cols_in()) as u64, + (a.size() * a.cols_out()) as u64, + ); + } + } + + fn vmp_apply(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + B: MatZnxDftToRef, + { + let mut res: VecZnxDft<&mut [u8], _> = res.to_mut(); + let a: VecZnxDft<&[u8], _> = a.to_ref(); + let b: MatZnxDft<&[u8], _> = 
b.to_ref(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.n(), self.n()); + assert_eq!(b.n(), self.n()); + assert_eq!(a.n(), self.n()); + assert_eq!( + res.cols(), + b.cols_out(), + "res.cols(): {} != b.cols_out: {}", + res.cols(), + b.cols_out() + ); + assert_eq!( + a.cols(), + b.cols_in(), + "a.cols(): {} != b.cols_in: {}", + a.cols(), + b.cols_in() + ); + } + + let (tmp_bytes, _) = scratch.tmp_slice(self.vmp_apply_tmp_bytes( + res.size(), + a.size(), + b.rows(), + b.cols_in(), + b.cols_out(), + b.size(), + )); + unsafe { + vmp::vmp_apply_dft_to_dft( + self.ptr, + res.as_mut_ptr() as *mut vec_znx_dft_t, + (res.size() * res.cols()) as u64, + a.as_ptr() as *const vec_znx_dft_t, + (a.size() * a.cols()) as u64, + b.as_ptr() as *const vmp::vmp_pmat_t, + (b.rows() * b.cols_in()) as u64, + (b.size() * b.cols_out()) as u64, + tmp_bytes.as_mut_ptr(), + ) + } + } + + fn vmp_apply_add(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + B: MatZnxDftToRef { + let mut res: VecZnxDft<&mut [u8], _> = res.to_mut(); + let a: VecZnxDft<&[u8], _> = a.to_ref(); + let b: MatZnxDft<&[u8], _> = b.to_ref(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.n(), self.n()); + assert_eq!(b.n(), self.n()); + assert_eq!(a.n(), self.n()); + assert_eq!( + res.cols(), + b.cols_out(), + "res.cols(): {} != b.cols_out: {}", + res.cols(), + b.cols_out() + ); + assert_eq!( + a.cols(), + b.cols_in(), + "a.cols(): {} != b.cols_in: {}", + a.cols(), + b.cols_in() + ); + } + + let (tmp_bytes, _) = scratch.tmp_slice(self.vmp_apply_tmp_bytes( + res.size(), + a.size(), + b.rows(), + b.cols_in(), + b.cols_out(), + b.size(), + )); + unsafe { + vmp::vmp_apply_dft_to_dft_add( + self.ptr, + res.as_mut_ptr() as *mut vec_znx_dft_t, + (res.size() * res.cols()) as u64, + a.as_ptr() as *const vec_znx_dft_t, + (a.size() * a.cols()) as u64, + b.as_ptr() as *const vmp::vmp_pmat_t, + (b.rows() * b.cols_in()) as u64, + (b.size() * b.cols_out()) as u64, + tmp_bytes.as_mut_ptr(), + ) + } + } +} +#[cfg(test)] +mod tests { + use crate::{ + Decoding, FFT64, FillUniform, MatZnxDft, MatZnxDftOps, Module, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig, + VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, ZnxInfos, ZnxView, ZnxViewMut, + }; + use sampling::source::Source; + + use super::{MatZnxDftAlloc, MatZnxDftScratch}; + + #[test] + fn vmp_prepare_row() { + let module: Module = Module::::new(16); + let log_base2k: usize = 8; + let mat_rows: usize = 4; + let mat_cols_in: usize = 2; + let mat_cols_out: usize = 2; + let mat_size: usize = 5; + let mut a: VecZnx> = module.new_vec_znx(mat_cols_out, mat_size); + let mut a_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size); + let mut b_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size); + let mut mat: MatZnxDft, FFT64> = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size); + + for col_in in 0..mat_cols_in { + for row_i in 0..mat_rows { + let mut source: Source = Source::new([0u8; 32]); + (0..mat_cols_out).for_each(|col_out| { + a.fill_uniform(log_base2k, col_out, mat_size, &mut source); + module.vec_znx_dft(&mut a_dft, col_out, &a, col_out); + }); + module.vmp_prepare_row(&mut mat, row_i, col_in, &a_dft); + module.vmp_extract_row(&mut b_dft, &mat, row_i, col_in); + assert_eq!(a_dft.raw(), b_dft.raw()); + } + } + } + + #[test] + fn vmp_apply() { + let log_n: i32 = 5; + let n: usize = 1 << log_n; + + let module: Module = Module::::new(n); + let log_base2k: 
usize = 15; + let a_size: usize = 5; + let mat_size: usize = 6; + let res_size: usize = 5; + + [1, 2].iter().for_each(|in_cols| { + [1, 2].iter().for_each(|out_cols| { + let a_cols: usize = *in_cols; + let res_cols: usize = *out_cols; + + let mat_rows: usize = a_size; + let mat_cols_in: usize = a_cols; + let mat_cols_out: usize = res_cols; + let res_cols: usize = mat_cols_out; + + let mut scratch: ScratchOwned = ScratchOwned::new( + module.vmp_apply_tmp_bytes( + res_size, + a_size, + mat_rows, + mat_cols_in, + mat_cols_out, + mat_size, + ) | module.vec_znx_big_normalize_tmp_bytes(), + ); + + let mut a: VecZnx> = module.new_vec_znx(a_cols, a_size); + + (0..a_cols).for_each(|i| { + a.at_mut(i, 2)[i + 1] = 1; + }); + + let mut mat_znx_dft: MatZnxDft, FFT64> = + module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size); + + let mut c_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size); + let mut c_big: VecZnxBig, FFT64> = module.new_vec_znx_big(mat_cols_out, mat_size); + + let mut tmp: VecZnx> = module.new_vec_znx(mat_cols_out, mat_size); + + // Construts a [VecZnxMatDft] that performs cyclic rotations on each submatrix. + (0..a.size()).for_each(|row_i| { + (0..mat_cols_in).for_each(|col_in_i| { + (0..mat_cols_out).for_each(|col_out_i| { + let idx = 1 + col_in_i * mat_cols_out + col_out_i; + tmp.at_mut(col_out_i, row_i)[idx] = 1 as i64; // X^{idx} + module.vec_znx_dft(&mut c_dft, col_out_i, &tmp, col_out_i); + tmp.at_mut(col_out_i, row_i)[idx] = 0 as i64; + }); + module.vmp_prepare_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft); + }); + }); + + let mut a_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(a_cols, a_size); + (0..a_cols).for_each(|i| { + module.vec_znx_dft(&mut a_dft, i, &a, i); + }); + + module.vmp_apply(&mut c_dft, &a_dft, &mat_znx_dft, scratch.borrow()); + + let mut res_have_vi64: Vec = vec![i64::default(); n]; + + let mut res_have: VecZnx> = module.new_vec_znx(res_cols, res_size); + (0..mat_cols_out).for_each(|i| { + module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i); + module.vec_znx_big_normalize(log_base2k, &mut res_have, i, &c_big, i, scratch.borrow()); + }); + + (0..mat_cols_out).for_each(|col_i| { + let mut res_want_vi64: Vec = vec![i64::default(); n]; + (0..a_cols).for_each(|i| { + res_want_vi64[(i + 1) + (1 + i * mat_cols_out + col_i)] = 1; + }); + res_have.decode_vec_i64(col_i, log_base2k, log_base2k * 3, &mut res_have_vi64); + assert_eq!(res_have_vi64, res_want_vi64); + }); + }); + }); + } +} diff --git a/base2k/src/module.rs b/base2k/src/module.rs index 8cbdbca..f6d0e0e 100644 --- a/base2k/src/module.rs +++ b/base2k/src/module.rs @@ -1,5 +1,6 @@ use crate::GALOISGENERATOR; use crate::ffi::module::{MODULE, delete_module_info, module_info_t, new_module_info}; +use std::marker::PhantomData; #[derive(Copy, Clone)] #[repr(u8)] @@ -8,37 +9,50 @@ pub enum BACKEND { NTT120, } -pub struct Module { - pub ptr: *mut MODULE, - pub n: usize, - pub backend: BACKEND, +pub trait Backend { + const KIND: BACKEND; + fn module_type() -> u32; } -impl Module { +pub struct FFT64; +pub struct NTT120; + +impl Backend for FFT64 { + const KIND: BACKEND = BACKEND::FFT64; + fn module_type() -> u32 { + 0 + } +} + +impl Backend for NTT120 { + const KIND: BACKEND = BACKEND::NTT120; + fn module_type() -> u32 { + 1 + } +} + +pub struct Module { + pub ptr: *mut MODULE, + n: usize, + _marker: PhantomData, +} + +impl Module { // Instantiates a new module. 
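With the backend moved from a runtime enum to the type level, construction now reads as below; a hedged usage sketch against this diff's API (`Module::<FFT64>::new(n)` replaces the old `Module::new(n, BACKEND::FFT64)`, and cleanup moves from the removed explicit `free()` to the `Drop` impl added further down):

```rust
use base2k::{FFT64, Module};

fn main() {
    // The backend is fixed at compile time by the marker type.
    let module: Module<FFT64> = Module::<FFT64>::new(16);
    assert_eq!(module.n(), 16);
    // No explicit `module.free()`: the Drop impl releases the C module.
}
```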
- pub fn new(n: usize, module_type: BACKEND) -> Self { + pub fn new(n: usize) -> Self { unsafe { - let module_type_u32: u32; - match module_type { - BACKEND::FFT64 => module_type_u32 = 0, - BACKEND::NTT120 => module_type_u32 = 1, - } - let m: *mut module_info_t = new_module_info(n as u64, module_type_u32); + let m: *mut module_info_t = new_module_info(n as u64, B::module_type()); if m.is_null() { panic!("Failed to create module."); } Self { ptr: m, n: n, - backend: module_type, + _marker: PhantomData, } } } - pub fn backend(&self) -> BACKEND { - self.backend - } - pub fn n(&self) -> usize { self.n } @@ -51,26 +65,27 @@ impl Module { (self.n() << 1) as _ } - // Returns GALOISGENERATOR^|gen| * sign(gen) - pub fn galois_element(&self, gen: i64) -> i64 { - if gen == 0 { + // Returns GALOISGENERATOR^|generator| * sign(generator) + pub fn galois_element(&self, generator: i64) -> i64 { + if generator == 0 { return 1; } - ((mod_exp_u64(GALOISGENERATOR, gen.abs() as usize) & (self.cyclotomic_order() - 1)) as i64) * gen.signum() + ((mod_exp_u64(GALOISGENERATOR, generator.abs() as usize) & (self.cyclotomic_order() - 1)) as i64) * generator.signum() } // Returns gen^-1 - pub fn galois_element_inv(&self, gen: i64) -> i64 { - if gen == 0 { + pub fn galois_element_inv(&self, gal_el: i64) -> i64 { + if gal_el == 0 { panic!("cannot invert 0") } - ((mod_exp_u64(gen.abs() as u64, (self.cyclotomic_order() - 1) as usize) & (self.cyclotomic_order() - 1)) as i64) - * gen.signum() + ((mod_exp_u64(gal_el.abs() as u64, (self.cyclotomic_order() - 1) as usize) & (self.cyclotomic_order() - 1)) as i64) + * gal_el.signum() } +} - pub fn free(self) { +impl Drop for Module { + fn drop(&mut self) { unsafe { delete_module_info(self.ptr) } - drop(self); } } diff --git a/base2k/src/sampling.rs b/base2k/src/sampling.rs index 064c1e2..b4e1489 100644 --- a/base2k/src/sampling.rs +++ b/base2k/src/sampling.rs @@ -1,56 +1,132 @@ -use crate::{Infos, Module, VecZnx}; +use crate::znx_base::ZnxViewMut; +use crate::{FFT64, VecZnx, VecZnxBig, VecZnxBigToMut, VecZnxToMut}; use rand_distr::{Distribution, Normal}; use sampling::source::Source; -pub trait Sampling { - /// Fills the first `cols` cols with uniform values in \[-2^{log_base2k-1}, 2^{log_base2k-1}\] - fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, cols: usize, source: &mut Source); +pub trait FillUniform { + /// Fills the first `size` limbs with uniform values in \[-2^{log_base2k-1}, 2^{log_base2k-1}\] + fn fill_uniform(&mut self, log_base2k: usize, col_i: usize, size: usize, source: &mut Source); +} - /// Adds vector sampled according to the provided distribution, scaled by 2^{-log_k} and bounded to \[-bound, bound\]. - fn add_dist_f64>( - &self, +pub trait FillDistF64 { + fn fill_dist_f64>( + &mut self, log_base2k: usize, - a: &mut VecZnx, + col_i: usize, log_k: usize, source: &mut Source, dist: D, bound: f64, ); - - /// Adds a discrete normal vector scaled by 2^{-log_k} with the provided standard deviation and bounded to \[-bound, bound\]. - fn add_normal(&self, log_base2k: usize, a: &mut VecZnx, log_k: usize, source: &mut Source, sigma: f64, bound: f64); } -impl Sampling for Module { - fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, cols: usize, source: &mut Source) { +pub trait AddDistF64 { + /// Adds a vector sampled according to the provided distribution, scaled by 2^{-log_k} and bounded to \[-bound, bound\].
+ fn add_dist_f64>( + &mut self, + log_base2k: usize, + col_i: usize, + log_k: usize, + source: &mut Source, + dist: D, + bound: f64, + ); +} + +pub trait FillNormal { + fn fill_normal(&mut self, log_base2k: usize, col_i: usize, log_k: usize, source: &mut Source, sigma: f64, bound: f64); +} + +pub trait AddNormal { + /// Adds a discrete normal vector scaled by 2^{-log_k} with the provided standard deviation and bounded to \[-bound, bound\]. + fn add_normal(&mut self, log_base2k: usize, col_i: usize, log_k: usize, source: &mut Source, sigma: f64, bound: f64); +} + +impl FillUniform for VecZnx +where + VecZnx: VecZnxToMut, +{ + fn fill_uniform(&mut self, log_base2k: usize, col_i: usize, size: usize, source: &mut Source) { + let mut a: VecZnx<&mut [u8]> = self.to_mut(); let base2k: u64 = 1 << log_base2k; let mask: u64 = base2k - 1; let base2k_half: i64 = (base2k >> 1) as i64; - let size: usize = a.n() * cols; - a.raw_mut()[..size] - .iter_mut() - .for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half); + (0..size).for_each(|j| { + a.at_mut(col_i, j) + .iter_mut() + .for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half); + }) } +} - fn add_dist_f64>( - &self, +impl FillDistF64 for VecZnx +where + VecZnx: VecZnxToMut, +{ + fn fill_dist_f64>( + &mut self, log_base2k: usize, - a: &mut VecZnx, + col_i: usize, log_k: usize, source: &mut Source, dist: D, bound: f64, ) { + let mut a: VecZnx<&mut [u8]> = self.to_mut(); assert!( (bound.log2().ceil() as i64) < 64, "invalid bound: ceil(log2(bound))={} > 63", (bound.log2().ceil() as i64) ); - let log_base2k_rem: usize = log_k % log_base2k; + let limb: usize = (log_k + log_base2k - 1) / log_base2k - 1; + let log_base2k_rem: usize = (limb + 1) * log_base2k - log_k; if log_base2k_rem != 0 { - a.at_mut(a.cols() - 1).iter_mut().for_each(|a| { + a.at_mut(col_i, limb).iter_mut().for_each(|a| { + let mut dist_f64: f64 = dist.sample(source); + while dist_f64.abs() > bound { + dist_f64 = dist.sample(source) + } + *a = (dist_f64.round() as i64) << log_base2k_rem; + }); + } else { + a.at_mut(col_i, limb).iter_mut().for_each(|a| { + let mut dist_f64: f64 = dist.sample(source); + while dist_f64.abs() > bound { + dist_f64 = dist.sample(source) + } + *a = dist_f64.round() as i64 + }); + } + } +} + +impl AddDistF64 for VecZnx +where + VecZnx: VecZnxToMut, +{ + fn add_dist_f64>( + &mut self, + log_base2k: usize, + col_i: usize, + log_k: usize, + source: &mut Source, + dist: D, + bound: f64, + ) { + let mut a: VecZnx<&mut [u8]> = self.to_mut(); + assert!( + (bound.log2().ceil() as i64) < 64, + "invalid bound: ceil(log2(bound))={} > 63", + (bound.log2().ceil() as i64) + ); + + let limb: usize = (log_k + log_base2k - 1) / log_base2k - 1; + let log_base2k_rem: usize = (limb + 1) * log_base2k - log_k; + + if log_base2k_rem != 0 { + a.at_mut(col_i, limb).iter_mut().for_each(|a| { let mut dist_f64: f64 = dist.sample(source); while dist_f64.abs() > bound { dist_f64 = dist.sample(source) @@ -58,7 +134,7 @@ impl Sampling for Module { *a += (dist_f64.round() as i64) << log_base2k_rem; }); } else { - a.at_mut(a.cols() - 1).iter_mut().for_each(|a| { + a.at_mut(col_i, limb).iter_mut().for_each(|a| { let mut dist_f64: f64 = dist.sample(source); while dist_f64.abs() > bound { dist_f64 = dist.sample(source) @@ -67,11 +143,16 @@ impl Sampling for Module { }); } } +} - fn add_normal(&self, log_base2k: usize, a: &mut VecZnx, log_k: usize, source: &mut Source, sigma: f64, bound: f64) { - self.add_dist_f64( +impl FillNormal for VecZnx +where + VecZnx: 
VecZnxToMut, +{ + fn fill_normal(&mut self, log_base2k: usize, col_i: usize, log_k: usize, source: &mut Source, sigma: f64, bound: f64) { + self.fill_dist_f64( log_base2k, - a, + col_i, log_k, source, Normal::new(0.0, sigma).unwrap(), @@ -79,3 +160,206 @@ impl Sampling for Module { ); } } + +impl AddNormal for VecZnx +where + VecZnx: VecZnxToMut, +{ + fn add_normal(&mut self, log_base2k: usize, col_i: usize, log_k: usize, source: &mut Source, sigma: f64, bound: f64) { + self.add_dist_f64( + log_base2k, + col_i, + log_k, + source, + Normal::new(0.0, sigma).unwrap(), + bound, + ); + } +} + +impl FillDistF64 for VecZnxBig +where + VecZnxBig: VecZnxBigToMut, +{ + fn fill_dist_f64>( + &mut self, + log_base2k: usize, + col_i: usize, + log_k: usize, + source: &mut Source, + dist: D, + bound: f64, + ) { + let mut a: VecZnxBig<&mut [u8], FFT64> = self.to_mut(); + assert!( + (bound.log2().ceil() as i64) < 64, + "invalid bound: ceil(log2(bound))={} > 63", + (bound.log2().ceil() as i64) + ); + + let limb: usize = (log_k + log_base2k - 1) / log_base2k - 1; + let log_base2k_rem: usize = (limb + 1) * log_base2k - log_k; + + if log_base2k_rem != 0 { + a.at_mut(col_i, limb).iter_mut().for_each(|a| { + let mut dist_f64: f64 = dist.sample(source); + while dist_f64.abs() > bound { + dist_f64 = dist.sample(source) + } + *a = (dist_f64.round() as i64) << log_base2k_rem; + }); + } else { + a.at_mut(col_i, limb).iter_mut().for_each(|a| { + let mut dist_f64: f64 = dist.sample(source); + while dist_f64.abs() > bound { + dist_f64 = dist.sample(source) + } + *a = dist_f64.round() as i64 + }); + } + } +} + +impl AddDistF64 for VecZnxBig +where + VecZnxBig: VecZnxBigToMut, +{ + fn add_dist_f64>( + &mut self, + log_base2k: usize, + col_i: usize, + log_k: usize, + source: &mut Source, + dist: D, + bound: f64, + ) { + let mut a: VecZnxBig<&mut [u8], FFT64> = self.to_mut(); + assert!( + (bound.log2().ceil() as i64) < 64, + "invalid bound: ceil(log2(bound))={} > 63", + (bound.log2().ceil() as i64) + ); + + let limb: usize = (log_k + log_base2k - 1) / log_base2k - 1; + let log_base2k_rem: usize = (limb + 1) * log_base2k - log_k; + + if log_base2k_rem != 0 { + a.at_mut(col_i, limb).iter_mut().for_each(|a| { + let mut dist_f64: f64 = dist.sample(source); + while dist_f64.abs() > bound { + dist_f64 = dist.sample(source) + } + *a += (dist_f64.round() as i64) << log_base2k_rem; + }); + } else { + a.at_mut(col_i, limb).iter_mut().for_each(|a| { + let mut dist_f64: f64 = dist.sample(source); + while dist_f64.abs() > bound { + dist_f64 = dist.sample(source) + } + *a += dist_f64.round() as i64 + }); + } + } +} + +impl FillNormal for VecZnxBig +where + VecZnxBig: VecZnxBigToMut, +{ + fn fill_normal(&mut self, log_base2k: usize, col_i: usize, log_k: usize, source: &mut Source, sigma: f64, bound: f64) { + self.fill_dist_f64( + log_base2k, + col_i, + log_k, + source, + Normal::new(0.0, sigma).unwrap(), + bound, + ); + } +} + +impl AddNormal for VecZnxBig +where + VecZnxBig: VecZnxBigToMut, +{ + fn add_normal(&mut self, log_base2k: usize, col_i: usize, log_k: usize, source: &mut Source, sigma: f64, bound: f64) { + self.add_dist_f64( + log_base2k, + col_i, + log_k, + source, + Normal::new(0.0, sigma).unwrap(), + bound, + ); + } +} + +#[cfg(test)] +mod tests { + use super::{AddNormal, FillUniform}; + use crate::vec_znx_ops::*; + use crate::znx_base::*; + use crate::{FFT64, Module, Stats, VecZnx}; + use sampling::source::Source; + + #[test] + fn vec_znx_fill_uniform() { + let n: usize = 4096; + let module: Module = Module::::new(n); + let 
log_base2k: usize = 17; + let size: usize = 5; + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + let zero: Vec = vec![0; n]; + let one_12_sqrt: f64 = 0.28867513459481287; + (0..cols).for_each(|col_i| { + let mut a: VecZnx<_> = module.new_vec_znx(cols, size); + a.fill_uniform(log_base2k, col_i, size, &mut source); + (0..cols).for_each(|col_j| { + if col_j != col_i { + (0..size).for_each(|limb_i| { + assert_eq!(a.at(col_j, limb_i), zero); + }) + } else { + let std: f64 = a.std(col_i, log_base2k); + assert!( + (std - one_12_sqrt).abs() < 0.01, + "std={} ~!= {}", + std, + one_12_sqrt + ); + } + }) + }); + } + + #[test] + fn vec_znx_add_normal() { + let n: usize = 4096; + let module: Module = Module::::new(n); + let log_base2k: usize = 17; + let log_k: usize = 2 * 17; + let size: usize = 5; + let sigma: f64 = 3.2; + let bound: f64 = 6.0 * sigma; + let mut source: Source = Source::new([0u8; 32]); + let cols: usize = 2; + let zero: Vec = vec![0; n]; + let k_f64: f64 = (1u64 << log_k as u64) as f64; + (0..cols).for_each(|col_i| { + let mut a: VecZnx<_> = module.new_vec_znx(cols, size); + a.add_normal(log_base2k, col_i, log_k, &mut source, sigma, bound); + (0..cols).for_each(|col_j| { + if col_j != col_i { + (0..size).for_each(|limb_i| { + assert_eq!(a.at(col_j, limb_i), zero); + }) + } else { + let std: f64 = a.std(col_i, log_base2k) * k_f64; + assert!((std - sigma).abs() < 0.1, "std={} ~!= {}", std, sigma); + } + }) + }); + } +} diff --git a/base2k/src/scalar_znx.rs b/base2k/src/scalar_znx.rs new file mode 100644 index 0000000..4c981c1 --- /dev/null +++ b/base2k/src/scalar_znx.rs @@ -0,0 +1,306 @@ +use crate::ffi::vec_znx; +use crate::znx_base::ZnxInfos; +use crate::{ + Backend, DataView, DataViewMut, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxSliceSize, ZnxView, ZnxViewMut, alloc_aligned, +}; +use rand::seq::SliceRandom; +use rand_core::RngCore; +use rand_distr::{Distribution, weighted::WeightedIndex}; +use sampling::source::Source; + +pub struct ScalarZnx { + pub(crate) data: D, + pub(crate) n: usize, + pub(crate) cols: usize, +} + +impl ZnxInfos for ScalarZnx { + fn cols(&self) -> usize { + self.cols + } + + fn rows(&self) -> usize { + 1 + } + + fn n(&self) -> usize { + self.n + } + + fn size(&self) -> usize { + 1 + } +} + +impl ZnxSliceSize for ScalarZnx { + fn sl(&self) -> usize { + self.n() + } +} + +impl DataView for ScalarZnx { + type D = D; + fn data(&self) -> &Self::D { + &self.data + } +} + +impl DataViewMut for ScalarZnx { + fn data_mut(&mut self) -> &mut Self::D { + &mut self.data + } +} + +impl> ZnxView for ScalarZnx { + type Scalar = i64; +} + +impl + AsRef<[u8]>> ScalarZnx { + pub fn fill_ternary_prob(&mut self, col: usize, prob: f64, source: &mut Source) { + let choices: [i64; 3] = [-1, 0, 1]; + let weights: [f64; 3] = [prob / 2.0, 1.0 - prob, prob / 2.0]; + let dist: WeightedIndex = WeightedIndex::new(&weights).unwrap(); + self.at_mut(col, 0) + .iter_mut() + .for_each(|x: &mut i64| *x = choices[dist.sample(source)]); + } + + pub fn fill_ternary_hw(&mut self, col: usize, hw: usize, source: &mut Source) { + assert!(hw <= self.n()); + self.at_mut(col, 0)[..hw] + .iter_mut() + .for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1); + self.at_mut(col, 0).shuffle(source); + } +} + +impl>> ScalarZnx { + pub(crate) fn bytes_of(n: usize, cols: usize) -> usize { + n * cols * size_of::() + } + + pub(crate) fn new(n: usize, cols: usize) -> Self { + let data = alloc_aligned::(Self::bytes_of::(n, cols)); + Self { + data: data.into(), + n, + cols, 
+ } + } + + pub(crate) fn new_from_bytes(n: usize, cols: usize, bytes: impl Into>) -> Self { + let data: Vec = bytes.into(); + assert!(data.len() == Self::bytes_of::(n, cols)); + Self { + data: data.into(), + n, + cols, + } + } +} + +pub type ScalarZnxOwned = ScalarZnx>; + +pub(crate) fn bytes_of_scalar_znx(module: &Module, cols: usize) -> usize { + ScalarZnxOwned::bytes_of::(module.n(), cols) +} + +pub trait ScalarZnxAlloc { + fn bytes_of_scalar_znx(&self, cols: usize) -> usize; + fn new_scalar_znx(&self, cols: usize) -> ScalarZnxOwned; + fn new_scalar_znx_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxOwned; +} + +impl ScalarZnxAlloc for Module { + fn bytes_of_scalar_znx(&self, cols: usize) -> usize { + ScalarZnxOwned::bytes_of::(self.n(), cols) + } + fn new_scalar_znx(&self, cols: usize) -> ScalarZnxOwned { + ScalarZnxOwned::new::(self.n(), cols) + } + fn new_scalar_znx_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxOwned { + ScalarZnxOwned::new_from_bytes::(self.n(), cols, bytes) + } +} + +pub trait ScalarZnxOps { + fn scalar_znx_automorphism(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: ScalarZnxToMut, + A: ScalarZnxToRef; + + /// Applies the automorphism X^i -> X^ik on the selected column of `a`. + fn scalar_znx_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) + where + A: ScalarZnxToMut; +} + +impl ScalarZnxOps for Module { + fn scalar_znx_automorphism(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: ScalarZnxToMut, + A: ScalarZnxToRef, + { + let a: ScalarZnx<&[u8]> = a.to_ref(); + let mut res: ScalarZnx<&mut [u8]> = res.to_mut(); + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } + unsafe { + vec_znx::vec_znx_automorphism( + self.ptr, + k, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) + } + } + + fn scalar_znx_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) + where + A: ScalarZnxToMut, + { + let mut a: ScalarZnx<&mut [u8]> = a.to_mut(); + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + } + unsafe { + vec_znx::vec_znx_automorphism( + self.ptr, + k, + a.at_mut_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) + } + } +} + +impl ScalarZnx { + pub(crate) fn from_data(data: D, n: usize, cols: usize) -> Self { + Self { data, n, cols } + } +} + +pub trait ScalarZnxToRef { + fn to_ref(&self) -> ScalarZnx<&[u8]>; +} + +pub trait ScalarZnxToMut { + fn to_mut(&mut self) -> ScalarZnx<&mut [u8]>; +} + +impl ScalarZnxToMut for ScalarZnx> { + fn to_mut(&mut self) -> ScalarZnx<&mut [u8]> { + ScalarZnx { + data: self.data.as_mut_slice(), + n: self.n, + cols: self.cols, + } + } +} + +impl VecZnxToMut for ScalarZnx> { + fn to_mut(&mut self) -> VecZnx<&mut [u8]> { + VecZnx { + data: self.data.as_mut_slice(), + n: self.n, + cols: self.cols, + size: 1, + } + } +} + +impl ScalarZnxToRef for ScalarZnx> { + fn to_ref(&self) -> ScalarZnx<&[u8]> { + ScalarZnx { + data: self.data.as_slice(), + n: self.n, + cols: self.cols, + } + } +} + +impl VecZnxToRef for ScalarZnx> { + fn to_ref(&self) -> VecZnx<&[u8]> { + VecZnx { + data: self.data.as_slice(), + n: self.n, + cols: self.cols, + size: 1, + } + } +} + +impl ScalarZnxToMut for ScalarZnx<&mut [u8]> { + fn to_mut(&mut self) -> ScalarZnx<&mut [u8]> { + ScalarZnx { + data: self.data, + n: self.n, + cols: self.cols, + } + } +} + +impl VecZnxToMut for 
ScalarZnx<&mut [u8]> { + fn to_mut(&mut self) -> VecZnx<&mut [u8]> { + VecZnx { + data: self.data, + n: self.n, + cols: self.cols, + size: 1, + } + } +} + +impl ScalarZnxToRef for ScalarZnx<&mut [u8]> { + fn to_ref(&self) -> ScalarZnx<&[u8]> { + ScalarZnx { + data: self.data, + n: self.n, + cols: self.cols, + } + } +} + +impl VecZnxToRef for ScalarZnx<&mut [u8]> { + fn to_ref(&self) -> VecZnx<&[u8]> { + VecZnx { + data: self.data, + n: self.n, + cols: self.cols, + size: 1, + } + } +} + +impl ScalarZnxToRef for ScalarZnx<&[u8]> { + fn to_ref(&self) -> ScalarZnx<&[u8]> { + ScalarZnx { + data: self.data, + n: self.n, + cols: self.cols, + } + } +} + +impl VecZnxToRef for ScalarZnx<&[u8]> { + fn to_ref(&self) -> VecZnx<&[u8]> { + VecZnx { + data: self.data, + n: self.n, + cols: self.cols, + size: 1, + } + } +} diff --git a/base2k/src/scalar_znx_dft.rs b/base2k/src/scalar_znx_dft.rs new file mode 100644 index 0000000..fa4ab10 --- /dev/null +++ b/base2k/src/scalar_znx_dft.rs @@ -0,0 +1,233 @@ +use std::marker::PhantomData; + +use crate::ffi::svp; +use crate::znx_base::ZnxInfos; +use crate::{ + Backend, DataView, DataViewMut, FFT64, Module, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, ZnxSliceSize, ZnxView, + alloc_aligned, +}; + +pub struct ScalarZnxDft { + data: D, + n: usize, + cols: usize, + _phantom: PhantomData, +} + +impl ZnxInfos for ScalarZnxDft { + fn cols(&self) -> usize { + self.cols + } + + fn rows(&self) -> usize { + 1 + } + + fn n(&self) -> usize { + self.n + } + + fn size(&self) -> usize { + 1 + } +} + +impl ZnxSliceSize for ScalarZnxDft { + fn sl(&self) -> usize { + self.n() + } +} + +impl DataView for ScalarZnxDft { + type D = D; + fn data(&self) -> &Self::D { + &self.data + } +} + +impl DataViewMut for ScalarZnxDft { + fn data_mut(&mut self) -> &mut Self::D { + &mut self.data + } +} + +impl> ZnxView for ScalarZnxDft { + type Scalar = f64; +} + +pub(crate) fn bytes_of_scalar_znx_dft(module: &Module, cols: usize) -> usize { + ScalarZnxDftOwned::bytes_of(module, cols) +} + +impl>, B: Backend> ScalarZnxDft { + pub(crate) fn bytes_of(module: &Module, cols: usize) -> usize { + unsafe { svp::bytes_of_svp_ppol(module.ptr) as usize * cols } + } + + pub(crate) fn new(module: &Module, cols: usize) -> Self { + let data = alloc_aligned::(Self::bytes_of(module, cols)); + Self { + data: data.into(), + n: module.n(), + cols, + _phantom: PhantomData, + } + } + + pub(crate) fn new_from_bytes(module: &Module, cols: usize, bytes: impl Into>) -> Self { + let data: Vec = bytes.into(); + assert!(data.len() == Self::bytes_of(module, cols)); + Self { + data: data.into(), + n: module.n(), + cols, + _phantom: PhantomData, + } + } +} + +impl ScalarZnxDft { + pub(crate) fn from_data(data: D, n: usize, cols: usize) -> Self { + Self { + data, + n, + cols, + _phantom: PhantomData, + } + } + + pub fn as_vec_znx_dft(self) -> VecZnxDft { + VecZnxDft { + data: self.data, + n: self.n, + cols: self.cols, + size: 1, + _phantom: PhantomData, + } + } +} + +pub type ScalarZnxDftOwned = ScalarZnxDft, B>; + +pub trait ScalarZnxDftToRef { + fn to_ref(&self) -> ScalarZnxDft<&[u8], B>; +} + +pub trait ScalarZnxDftToMut { + fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B>; +} + +impl ScalarZnxDftToMut for ScalarZnxDft, B> { + fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B> { + ScalarZnxDft { + data: self.data.as_mut_slice(), + n: self.n, + cols: self.cols, + _phantom: PhantomData, + } + } +} + +impl ScalarZnxDftToRef for ScalarZnxDft, B> { + fn to_ref(&self) -> ScalarZnxDft<&[u8], B> { + ScalarZnxDft { + data: 
self.data.as_slice(), + n: self.n, + cols: self.cols, + _phantom: PhantomData, + } + } +} + +impl ScalarZnxDftToMut for ScalarZnxDft<&mut [u8], B> { + fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B> { + ScalarZnxDft { + data: self.data, + n: self.n, + cols: self.cols, + _phantom: PhantomData, + } + } +} + +impl ScalarZnxDftToRef for ScalarZnxDft<&mut [u8], B> { + fn to_ref(&self) -> ScalarZnxDft<&[u8], B> { + ScalarZnxDft { + data: self.data, + n: self.n, + cols: self.cols, + _phantom: PhantomData, + } + } +} + +impl ScalarZnxDftToRef for ScalarZnxDft<&[u8], B> { + fn to_ref(&self) -> ScalarZnxDft<&[u8], B> { + ScalarZnxDft { + data: self.data, + n: self.n, + cols: self.cols, + _phantom: PhantomData, + } + } +} + +impl VecZnxDftToMut for ScalarZnxDft, B> { + fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { + VecZnxDft { + data: self.data.as_mut_slice(), + n: self.n, + cols: self.cols, + size: 1, + _phantom: PhantomData, + } + } +} + +impl VecZnxDftToRef for ScalarZnxDft, B> { + fn to_ref(&self) -> VecZnxDft<&[u8], B> { + VecZnxDft { + data: self.data.as_slice(), + n: self.n, + cols: self.cols, + size: 1, + _phantom: PhantomData, + } + } +} + +impl VecZnxDftToMut for ScalarZnxDft<&mut [u8], B> { + fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { + VecZnxDft { + data: self.data, + n: self.n, + cols: self.cols, + size: 1, + _phantom: PhantomData, + } + } +} + +impl VecZnxDftToRef for ScalarZnxDft<&mut [u8], B> { + fn to_ref(&self) -> VecZnxDft<&[u8], B> { + VecZnxDft { + data: self.data, + n: self.n, + cols: self.cols, + size: 1, + _phantom: PhantomData, + } + } +} + +impl VecZnxDftToRef for ScalarZnxDft<&[u8], B> { + fn to_ref(&self) -> VecZnxDft<&[u8], B> { + VecZnxDft { + data: self.data, + n: self.n, + cols: self.cols, + size: 1, + _phantom: PhantomData, + } + } +} diff --git a/base2k/src/scalar_znx_dft_ops.rs b/base2k/src/scalar_znx_dft_ops.rs new file mode 100644 index 0000000..1e0313a --- /dev/null +++ b/base2k/src/scalar_znx_dft_ops.rs @@ -0,0 +1,103 @@ +use crate::ffi::svp; +use crate::ffi::vec_znx_dft::vec_znx_dft_t; +use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; +use crate::{ + Backend, FFT64, Module, ScalarZnxDft, ScalarZnxDftOwned, ScalarZnxDftToMut, ScalarZnxDftToRef, ScalarZnxToRef, VecZnxDft, + VecZnxDftToMut, VecZnxDftToRef, +}; + +pub trait ScalarZnxDftAlloc { + fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned; + fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize; + fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxDftOwned; +} + +pub trait ScalarZnxDftOps { + fn svp_prepare(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: ScalarZnxDftToMut, + A: ScalarZnxToRef; + fn svp_apply(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxDftToMut, + A: ScalarZnxDftToRef, + B: VecZnxDftToRef; + fn svp_apply_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: ScalarZnxDftToRef; +} + +impl ScalarZnxDftAlloc for Module { + fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned { + ScalarZnxDftOwned::new(self, cols) + } + + fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize { + ScalarZnxDftOwned::bytes_of(self, cols) + } + + fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxDftOwned { + ScalarZnxDftOwned::new_from_bytes(self, cols, bytes) + } +} + +impl ScalarZnxDftOps for Module { + fn svp_prepare(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: 
ScalarZnxDftToMut, + A: ScalarZnxToRef, + { + unsafe { + svp::svp_prepare( + self.ptr, + res.to_mut().at_mut_ptr(res_col, 0) as *mut svp::svp_ppol_t, + a.to_ref().at_ptr(a_col, 0), + ) + } + } + + fn svp_apply(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxDftToMut, + A: ScalarZnxDftToRef, + B: VecZnxDftToRef, + { + let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); + let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref(); + let b: VecZnxDft<&[u8], FFT64> = b.to_ref(); + unsafe { + svp::svp_apply_dft_to_dft( + self.ptr, + res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t, + res.size() as u64, + res.cols() as u64, + a.at_ptr(a_col, 0) as *const svp::svp_ppol_t, + b.at_ptr(b_col, 0) as *const vec_znx_dft_t, + b.size() as u64, + b.cols() as u64, + ) + } + } + + fn svp_apply_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: ScalarZnxDftToRef, + { + let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); + let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref(); + unsafe { + svp::svp_apply_dft_to_dft( + self.ptr, + res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t, + res.size() as u64, + res.cols() as u64, + a.at_ptr(a_col, 0) as *const svp::svp_ppol_t, + res.at_ptr(res_col, 0) as *const vec_znx_dft_t, + res.size() as u64, + res.cols() as u64, + ) + } + } +} diff --git a/base2k/src/stats.rs b/base2k/src/stats.rs index f72ebaa..8db40f2 100644 --- a/base2k/src/stats.rs +++ b/base2k/src/stats.rs @@ -1,13 +1,19 @@ -use crate::{Encoding, Infos, VecZnx}; +use crate::znx_base::ZnxInfos; +use crate::{Decoding, VecZnx}; use rug::Float; use rug::float::Round; use rug::ops::{AddAssignRound, DivAssignRound, SubAssignRound}; -impl VecZnx { - pub fn std(&self, poly_idx: usize, log_base2k: usize) -> f64 { - let prec: u32 = (self.cols() * log_base2k) as u32; +pub trait Stats { + /// Returns the standard deviation of the `col_i`-th column.
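// [Editor's aside] A plain-f64 sketch of the statistic computed by `std` below;
// the actual implementation works on the decoded coefficients with
// arbitrary-precision `rug::Float`:
fn std_sketch(data: &[f64]) -> f64 {
    let n = data.len() as f64;
    let avg: f64 = data.iter().sum::<f64>() / n;
    // population standard deviation: sqrt(sum((x_i - avg)^2) / n)
    (data.iter().map(|&x| (x - avg) * (x - avg)).sum::<f64>() / n).sqrt()
}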
+ fn std(&self, col_i: usize, log_base2k: usize) -> f64; +} + +impl> Stats for VecZnx { + fn std(&self, col_i: usize, log_base2k: usize) -> f64 { + let prec: u32 = (self.size() * log_base2k) as u32; let mut data: Vec = (0..self.n()).map(|_| Float::with_val(prec, 0)).collect(); - self.decode_vec_float(poly_idx, log_base2k, &mut data); + self.decode_vec_float(col_i, log_base2k, &mut data); // std = sqrt(sum((xi - avg)^2) / n) let mut avg: Float = Float::with_val(prec, 0); data.iter().for_each(|x| { diff --git a/base2k/src/svp.rs b/base2k/src/svp.rs deleted file mode 100644 index 0e85a31..0000000 --- a/base2k/src/svp.rs +++ /dev/null @@ -1,276 +0,0 @@ -use crate::ffi::svp::{self, svp_ppol_t}; -use crate::ffi::vec_znx_dft::vec_znx_dft_t; -use crate::{BACKEND, LAYOUT, Module, VecZnx, VecZnxDft, assert_alignement}; - -use crate::{Infos, alloc_aligned, cast_mut}; -use rand::seq::SliceRandom; -use rand_core::RngCore; -use rand_distr::{Distribution, weighted::WeightedIndex}; -use sampling::source::Source; - -pub struct Scalar { - pub n: usize, - pub data: Vec, - pub ptr: *mut i64, -} - -impl Module { - pub fn new_scalar(&self) -> Scalar { - Scalar::new(self.n()) - } -} - -impl Scalar { - pub fn new(n: usize) -> Self { - let mut data: Vec = alloc_aligned::(n); - let ptr: *mut i64 = data.as_mut_ptr(); - Self { - n: n, - data: data, - ptr: ptr, - } - } - - pub fn n(&self) -> usize { - self.n - } - - pub fn bytes_of(n: usize) -> usize { - n * std::mem::size_of::() - } - - pub fn from_bytes(n: usize, bytes: &mut [u8]) -> Self { - let size: usize = Self::bytes_of(n); - debug_assert!( - bytes.len() == size, - "invalid buffer: bytes.len()={} < self.bytes_of(n={})={}", - bytes.len(), - n, - size - ); - #[cfg(debug_assertions)] - { - assert_alignement(bytes.as_ptr()) - } - unsafe { - let bytes_i64: &mut [i64] = cast_mut::(bytes); - let ptr: *mut i64 = bytes_i64.as_mut_ptr(); - Self { - n: n, - data: Vec::from_raw_parts(bytes_i64.as_mut_ptr(), bytes.len(), bytes.len()), - ptr: ptr, - } - } - } - - pub fn from_bytes_borrow(n: usize, bytes: &mut [u8]) -> Self { - let size: usize = Self::bytes_of(n); - debug_assert!( - bytes.len() == size, - "invalid buffer: bytes.len()={} < self.bytes_of(n={})={}", - bytes.len(), - n, - size - ); - #[cfg(debug_assertions)] - { - assert_alignement(bytes.as_ptr()) - } - let bytes_i64: &mut [i64] = cast_mut::(bytes); - let ptr: *mut i64 = bytes_i64.as_mut_ptr(); - Self { - n: n, - data: Vec::new(), - ptr: ptr, - } - } - - pub fn as_ptr(&self) -> *const i64 { - self.ptr - } - - pub fn raw(&self) -> &[i64] { - unsafe { std::slice::from_raw_parts(self.ptr, self.n) } - } - - pub fn raw_mut(&self) -> &mut [i64] { - unsafe { std::slice::from_raw_parts_mut(self.ptr, self.n) } - } - - pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) { - let choices: [i64; 3] = [-1, 0, 1]; - let weights: [f64; 3] = [prob / 2.0, 1.0 - prob, prob / 2.0]; - let dist: WeightedIndex = WeightedIndex::new(&weights).unwrap(); - self.data - .iter_mut() - .for_each(|x: &mut i64| *x = choices[dist.sample(source)]); - } - - pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) { - assert!(hw <= self.n()); - self.data[..hw] - .iter_mut() - .for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1); - self.data.shuffle(source); - } - - pub fn as_vec_znx(&self) -> VecZnx { - VecZnx { - n: self.n, - size: 1, // TODO REVIEW IF NEED TO ADD size TO SCALAR - cols: 1, - layout: LAYOUT::COL, - data: Vec::new(), - ptr: self.ptr, - } - } -} - -pub trait ScalarOps { - fn 
bytes_of_scalar(&self) -> usize; - fn new_scalar(&self) -> Scalar; - fn new_scalar_from_bytes(&self, bytes: &mut [u8]) -> Scalar; - fn new_scalar_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> Scalar; -} -impl ScalarOps for Module { - fn bytes_of_scalar(&self) -> usize { - Scalar::bytes_of(self.n()) - } - fn new_scalar(&self) -> Scalar { - Scalar::new(self.n()) - } - fn new_scalar_from_bytes(&self, bytes: &mut [u8]) -> Scalar { - Scalar::from_bytes(self.n(), bytes) - } - fn new_scalar_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> Scalar { - Scalar::from_bytes_borrow(self.n(), tmp_bytes) - } -} - -pub struct SvpPPol { - pub n: usize, - pub data: Vec, - pub ptr: *mut u8, - pub backend: BACKEND, -} - -/// A prepared [crate::Scalar] for [SvpPPolOps::svp_apply_dft]. -/// An [SvpPPol] an be seen as a [VecZnxDft] of one limb. -impl SvpPPol { - pub fn new(module: &Module) -> Self { - module.new_svp_ppol() - } - - /// Returns the ring degree of the [SvpPPol]. - pub fn n(&self) -> usize { - self.n - } - - pub fn bytes_of(module: &Module) -> usize { - module.bytes_of_svp_ppol() - } - - pub fn from_bytes(module: &Module, bytes: &mut [u8]) -> SvpPPol { - #[cfg(debug_assertions)] - { - assert_alignement(bytes.as_ptr()); - assert_eq!(bytes.len(), module.bytes_of_svp_ppol()); - } - unsafe { - Self { - n: module.n(), - data: Vec::from_raw_parts(bytes.as_mut_ptr(), bytes.len(), bytes.len()), - ptr: bytes.as_mut_ptr(), - backend: module.backend(), - } - } - } - - pub fn from_bytes_borrow(module: &Module, tmp_bytes: &mut [u8]) -> SvpPPol { - #[cfg(debug_assertions)] - { - assert_alignement(tmp_bytes.as_ptr()); - assert_eq!(tmp_bytes.len(), module.bytes_of_svp_ppol()); - } - Self { - n: module.n(), - data: Vec::new(), - ptr: tmp_bytes.as_mut_ptr(), - backend: module.backend(), - } - } - - /// Returns the number of cols of the [SvpPPol], which is always 1. - pub fn cols(&self) -> usize { - 1 - } -} - -pub trait SvpPPolOps { - /// Allocates a new [SvpPPol]. - fn new_svp_ppol(&self) -> SvpPPol; - - /// Returns the minimum number of bytes necessary to allocate - /// a new [SvpPPol] through [SvpPPol::from_bytes] ro. - fn bytes_of_svp_ppol(&self) -> usize; - - /// Allocates a new [SvpPPol] from an array of bytes. - /// The array of bytes is owned by the [SvpPPol]. - /// The method will panic if bytes.len() < [SvpPPolOps::bytes_of_svp_ppol] - fn new_svp_ppol_from_bytes(&self, bytes: &mut [u8]) -> SvpPPol; - - /// Allocates a new [SvpPPol] from an array of bytes. - /// The array of bytes is borrowed by the [SvpPPol]. - /// The method will panic if bytes.len() < [SvpPPolOps::bytes_of_svp_ppol] - fn new_svp_ppol_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> SvpPPol; - - /// Prepares a [crate::Scalar] for a [SvpPPolOps::svp_apply_dft]. - fn svp_prepare(&self, svp_ppol: &mut SvpPPol, a: &Scalar); - - /// Applies the [SvpPPol] x [VecZnxDft] product, where each limb of - /// the [VecZnxDft] is multiplied with [SvpPPol]. 
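// [Editor's aside] Conceptually, the scalar-vector product documented above
// multiplies every limb of the DFT-domain vector pointwise by the same prepared
// scalar polynomial. Illustrative sketch over plain f64 slices (the FFT64
// backend packs complex values as f64; this is not the library's calling
// convention):
fn svp_apply_sketch(scalar_dft: &[f64], limbs_dft: &mut [Vec<f64>]) {
    for limb in limbs_dft.iter_mut() {
        for (v, s) in limb.iter_mut().zip(scalar_dft.iter()) {
            *v *= *s; // pointwise product, repeated for each limb
        }
    }
}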
- fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx); -} - -impl SvpPPolOps for Module { - fn new_svp_ppol(&self) -> SvpPPol { - let mut data: Vec = alloc_aligned::(self.bytes_of_svp_ppol()); - let ptr: *mut u8 = data.as_mut_ptr(); - SvpPPol { - data: data, - ptr: ptr, - n: self.n(), - backend: self.backend(), - } - } - - fn bytes_of_svp_ppol(&self) -> usize { - unsafe { svp::bytes_of_svp_ppol(self.ptr) as usize } - } - - fn new_svp_ppol_from_bytes(&self, bytes: &mut [u8]) -> SvpPPol { - SvpPPol::from_bytes(self, bytes) - } - - fn new_svp_ppol_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> SvpPPol { - SvpPPol::from_bytes_borrow(self, tmp_bytes) - } - - fn svp_prepare(&self, svp_ppol: &mut SvpPPol, a: &Scalar) { - unsafe { svp::svp_prepare(self.ptr, svp_ppol.ptr as *mut svp_ppol_t, a.as_ptr()) } - } - - fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx) { - unsafe { - svp::svp_apply_dft( - self.ptr, - c.ptr as *mut vec_znx_dft_t, - c.cols() as u64, - a.ptr as *const svp_ppol_t, - b.as_ptr(), - b.cols() as u64, - b.n() as u64, - ) - } - } -} diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index 7445b5b..d4b0b9c 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -1,11 +1,14 @@ -use crate::LAYOUT; +use crate::DataView; +use crate::DataViewMut; +use crate::ScalarZnx; +use crate::ZnxSliceSize; +use crate::ZnxZero; +use crate::alloc_aligned; +use crate::assert_alignement; use crate::cast_mut; -use crate::ffi::vec_znx; use crate::ffi::znx; -use crate::{Infos, Module}; -use crate::{alloc_aligned, assert_alignement}; -use itertools::izip; -use std::cmp::min; +use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; +use std::{cmp::min, fmt}; /// [VecZnx] represents collection of contiguously stacked vector of small norm polynomials of /// Zn\[X\] with [i64] coefficients. @@ -17,205 +20,22 @@ use std::cmp::min; /// Given 3 polynomials (a, b, c) of Zn\[X\], each with 4 columns, then the memory /// layout is: `[a0, b0, c0, a1, b1, c1, a2, b2, c2, a3, b3, c3]`, where ai, bi, ci /// are small polynomials of Zn\[X\]. -#[derive(Clone)] -pub struct VecZnx { - /// Polynomial degree. +pub struct VecZnx { + pub data: D, pub n: usize, - - /// Stack size - pub size: usize, - - /// Stacking layout - pub layout: LAYOUT, - - /// Number of columns. pub cols: usize, - - /// Polynomial coefficients, as a contiguous array. Each col is equally spaced by n. - pub data: Vec, - - /// Pointer to data (data can be enpty if [VecZnx] borrows space instead of owning it). - pub ptr: *mut i64, + pub size: usize, } -pub fn bytes_of_vec_znx(n: usize, size: usize, cols: usize) -> usize { - n * size * cols * 8 -} - -impl VecZnx { - /// Returns a new struct implementing [VecZnx] with the provided data as backing array. - /// - /// The struct will take ownership of buf[..[VecZnx::bytes_of]] - /// - /// User must ensure that data is properly alligned and that - /// the size of data is equal to [VecZnx::bytes_of]. 
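// [Editor's note] Index arithmetic implied by the layout comment above
// ("[a0, b0, c0, a1, b1, c1, ...]"): storage is limb-major with the columns of
// one limb interleaved, so limb `i` of column `j` starts at coefficient offset
// n * (i * cols + j). Hypothetical helper, inferred from the comment rather
// than taken from the crate:
fn coeff_offset(n: usize, cols: usize, col: usize, limb: usize) -> usize {
    n * (limb * cols + col)
}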
- pub fn from_bytes(n: usize, size: usize, cols: usize, bytes: &mut [u8]) -> Self { - #[cfg(debug_assertions)] - { - assert!(size > 0); - assert_eq!(bytes.len(), Self::bytes_of(n, size, cols)); - assert_alignement(bytes.as_ptr()); - } - unsafe { - let bytes_i64: &mut [i64] = cast_mut::(bytes); - let ptr: *mut i64 = bytes_i64.as_mut_ptr(); - VecZnx { - n: n, - size: size, - cols: cols, - layout: LAYOUT::COL, - data: Vec::from_raw_parts(ptr, bytes.len(), bytes.len()), - ptr: ptr, - } - } +impl ZnxInfos for VecZnx { + fn cols(&self) -> usize { + self.cols } - pub fn from_bytes_borrow(n: usize, size: usize, cols: usize, bytes: &mut [u8]) -> Self { - #[cfg(debug_assertions)] - { - assert!(size > 0); - assert!(bytes.len() >= Self::bytes_of(n, size, cols)); - assert_alignement(bytes.as_ptr()); - } - VecZnx { - n: n, - size: size, - cols: cols, - layout: LAYOUT::COL, - data: Vec::new(), - ptr: bytes.as_mut_ptr() as *mut i64, - } + fn rows(&self) -> usize { + 1 } - pub fn bytes_of(n: usize, size: usize, cols: usize) -> usize { - bytes_of_vec_znx(n, size, cols) - } - - pub fn copy_from(&mut self, a: &VecZnx) { - copy_vec_znx_from(self, a); - } - - pub fn borrowing(&self) -> bool { - self.data.len() == 0 - } - - /// Total size is [VecZnx::n()] * [VecZnx::size()] * [VecZnx::cols()]. - pub fn raw(&self) -> &[i64] { - unsafe { std::slice::from_raw_parts(self.ptr, self.n * self.size * self.cols) } - } - - /// Returns a reference to backend slice of the receiver. - /// Total size is [VecZnx::n()] * [VecZnx::size()] * [VecZnx::cols()]. - pub fn raw_mut(&mut self) -> &mut [i64] { - unsafe { std::slice::from_raw_parts_mut(self.ptr, self.n * self.size * self.cols) } - } - - /// Returns a non-mutable pointer to the backedn slice of the receiver. - pub fn as_ptr(&self) -> *const i64 { - self.ptr - } - - /// Returns a mutable pointer to the backedn slice of the receiver. - pub fn as_mut_ptr(&mut self) -> *mut i64 { - self.ptr - } - - /// Returns a non-mutable pointer starting a the j-th column. - pub fn at_ptr(&self, i: usize) -> *const i64 { - #[cfg(debug_assertions)] - { - assert!(i < self.cols); - } - let offset: usize = self.n * self.size * i; - self.ptr.wrapping_add(offset) - } - - /// Returns non-mutable reference to the ith-column. - /// The slice contains [VecZnx::size()] small polynomials, each of [VecZnx::n()] coefficients. - pub fn at(&self, i: usize) -> &[i64] { - unsafe { std::slice::from_raw_parts(self.at_ptr(i), self.n * self.size) } - } - - /// Returns a non-mutable pointer starting a the j-th column of the i-th polynomial. - pub fn at_poly_ptr(&self, i: usize, j: usize) -> *const i64 { - #[cfg(debug_assertions)] - { - assert!(i < self.size); - assert!(j < self.cols); - } - let offset: usize = self.n * (self.size * j + i); - self.ptr.wrapping_add(offset) - } - - /// Returns non-mutable reference to the j-th column of the i-th polynomial. - /// The slice contains one small polynomial of [VecZnx::n()] coefficients. - pub fn at_poly(&self, i: usize, j: usize) -> &[i64] { - unsafe { std::slice::from_raw_parts(self.at_poly_ptr(i, j), self.n) } - } - - /// Returns a mutable pointer starting a the j-th column. - pub fn at_mut_ptr(&self, i: usize) -> *mut i64 { - #[cfg(debug_assertions)] - { - assert!(i < self.cols); - } - let offset: usize = self.n * self.size * i; - self.ptr.wrapping_add(offset) - } - - /// Returns mutable reference to the ith-column. - /// The slice contains [VecZnx::size()] small polynomials, each of [VecZnx::n()] coefficients. 
- pub fn at_mut(&mut self, i: usize) -> &mut [i64] { - unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i), self.n * self.size) } - } - - /// Returns a mutable pointer starting a the j-th column of the i-th polynomial. - pub fn at_poly_mut_ptr(&mut self, i: usize, j: usize) -> *mut i64 { - #[cfg(debug_assertions)] - { - assert!(i < self.size); - assert!(j < self.cols); - } - - let offset: usize = self.n * (self.size * j + i); - self.ptr.wrapping_add(offset) - } - - /// Returns mutable reference to the j-th column of the i-th polynomial. - /// The slice contains one small polynomial of [VecZnx::n()] coefficients. - pub fn at_poly_mut(&mut self, i: usize, j: usize) -> &mut [i64] { - let ptr: *mut i64 = self.at_poly_mut_ptr(i, j); - unsafe { std::slice::from_raw_parts_mut(ptr, self.n) } - } - - pub fn zero(&mut self) { - unsafe { znx::znx_zero_i64_ref((self.n * self.cols * self.size) as u64, self.ptr) } - } - - pub fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]) { - normalize(log_base2k, self, carry) - } - - pub fn rsh(&mut self, log_base2k: usize, k: usize, carry: &mut [u8]) { - rsh(log_base2k, self, k, carry) - } - - pub fn switch_degree(&self, a: &mut VecZnx) { - switch_degree(a, self) - } - - pub fn print(&self, poly: usize, cols: usize, n: usize) { - (0..cols).for_each(|i| println!("{}: {:?}", i, &self.at_poly(poly, i)[..n])) - } -} - -impl Infos for VecZnx { - /// Returns the base 2 logarithm of the [VecZnx] degree. - fn log_n(&self) -> usize { - (usize::BITS - (self.n - 1).leading_zeros()) as _ - } - - /// Returns the [VecZnx] degree. fn n(&self) -> usize { self.n } @@ -223,119 +43,138 @@ impl Infos for VecZnx { fn size(&self) -> usize { self.size } +} - fn layout(&self) -> LAYOUT { - self.layout - } - - /// Returns the number of cols of the [VecZnx]. - fn cols(&self) -> usize { - self.cols - } - - /// Returns the number of rows of the [VecZnx]. - fn rows(&self) -> usize { - 1 +impl ZnxSliceSize for VecZnx { + fn sl(&self) -> usize { + self.n() * self.cols() } } -/// Copies the coefficients of `a` on the receiver. -/// Copy is done with the minimum size matching both backing arrays. -pub fn copy_vec_znx_from(b: &mut VecZnx, a: &VecZnx) { - let data_a: &[i64] = a.raw(); - let data_b: &mut [i64] = b.raw_mut(); - let size = min(data_b.len(), data_a.len()); - data_b[..size].copy_from_slice(&data_a[..size]) +impl DataView for VecZnx { + type D = D; + fn data(&self) -> &Self::D { + &self.data + } } -impl VecZnx { - /// Allocates a new [VecZnx] composed of #cols polynomials of Z\[X\]. - pub fn new(n: usize, size: usize, cols: usize) -> Self { - #[cfg(debug_assertions)] - { - assert!(n > 0); - assert!(n & (n - 1) == 0); - assert!(size > 0); - assert!(cols > 0); - } - let mut data: Vec = alloc_aligned::(n * size * cols); - let ptr: *mut i64 = data.as_mut_ptr(); - Self { - n: n, - size: size, - layout: LAYOUT::COL, - cols: cols, - data: data, - ptr: ptr, - } +impl DataViewMut for VecZnx { + fn data_mut(&mut self) -> &mut Self::D { + &mut self.data } +} +impl> ZnxView for VecZnx { + type Scalar = i64; +} + +impl + AsRef<[u8]>> VecZnx { /// Truncates the precision of the [VecZnx] by k bits. /// /// # Arguments /// /// * `log_base2k`: the base two logarithm of the coefficients decomposition. /// * `k`: the number of bits of precision to drop. 
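// [Editor's note] Worked example of the truncation arithmetic used by
// `trunc_pow2` below, assuming log_base2k = 18 and k = 40: the vector loses
// 40 / 18 = 2 whole limbs, and the remaining 40 % 18 = 4 low bits are masked
// off the new last limb. Illustrative split only:
fn trunc_split(log_base2k: usize, k: usize) -> (usize, usize) {
    let limbs_dropped = k / log_base2k; // whole limbs removed from `size`
    let bits_masked = k % log_base2k; // low bits cleared in the last limb
    (limbs_dropped, bits_masked)
}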
- pub fn trunc_pow2(&mut self, log_base2k: usize, k: usize) { + pub fn trunc_pow2(&mut self, log_base2k: usize, k: usize, col: usize) { if k == 0 { return; } - if !self.borrowing() { - self.data - .truncate((self.cols() - k / log_base2k) * self.n() * self.size()); - } - - self.cols -= k / log_base2k; + self.size -= k / log_base2k; let k_rem: usize = k % log_base2k; if k_rem != 0 { let mask: i64 = ((1 << (log_base2k - k_rem - 1)) - 1) << k_rem; - self.at_mut(self.cols() - 1) + self.at_mut(col, self.size() - 1) .iter_mut() .for_each(|x: &mut i64| *x &= mask) } } } -pub fn switch_degree(b: &mut VecZnx, a: &VecZnx) { - let (n_in, n_out) = (a.n(), b.n()); - let (gap_in, gap_out): (usize, usize); - - if n_in > n_out { - (gap_in, gap_out) = (n_in / n_out, 1) - } else { - (gap_in, gap_out) = (1, n_out / n_in); - b.zero(); +impl>> VecZnx { + pub(crate) fn bytes_of(n: usize, cols: usize, size: usize) -> usize { + n * cols * size * size_of::() } - let cols = min(a.cols(), b.cols()); + pub(crate) fn new(n: usize, cols: usize, size: usize) -> Self { + let data = alloc_aligned::(Self::bytes_of::(n, cols, size)); + Self { + data: data.into(), + n, + cols, + size, + } + } - (0..cols).for_each(|i| { - izip!( - a.at(i).iter().step_by(gap_in), - b.at_mut(i).iter_mut().step_by(gap_out) - ) - .for_each(|(x_in, x_out)| *x_out = *x_in); - }); + pub(crate) fn new_from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into>) -> Self { + let data: Vec = bytes.into(); + assert!(data.len() == Self::bytes_of::(n, cols, size)); + Self { + data: data.into(), + n, + cols, + size, + } + } } -fn normalize_tmp_bytes(n: usize, size: usize) -> usize { - n * size * std::mem::size_of::() +impl VecZnx { + pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self { + Self { + data, + n, + cols, + size, + } + } + + pub fn to_scalar_znx(self) -> ScalarZnx { + debug_assert_eq!( + self.size, 1, + "cannot convert VecZnx to ScalarZnx if size: {} != 1", + self.size + ); + ScalarZnx { + data: self.data, + n: self.n, + cols: self.cols, + } + } } -fn normalize(log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]) { +/// Copies the coefficients of `a` on the receiver. +/// Copy is done with the minimum size matching both backing arrays. +/// Panics if the cols do not match.
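// [Editor's note] The copy below truncates to the shorter of the two backing
// arrays: copying a 3-limb vector into a 5-limb one overwrites the first three
// limbs and leaves the rest untouched. Minimal sketch of that semantic:
fn copy_min_sketch(dst: &mut [i64], src: &[i64]) {
    let len = dst.len().min(src.len());
    dst[..len].copy_from_slice(&src[..len]);
}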
+pub fn copy_vec_znx_from(b: &mut VecZnx, a: &VecZnx) +where + DataMut: AsMut<[u8]> + AsRef<[u8]>, + Data: AsRef<[u8]>, +{ + assert_eq!(b.cols(), a.cols()); + let data_a: &[i64] = a.raw(); + let data_b: &mut [i64] = b.raw_mut(); + let size = min(data_b.len(), data_a.len()); + data_b[..size].copy_from_slice(&data_a[..size]) +} + +#[allow(dead_code)] +fn normalize_tmp_bytes(n: usize) -> usize { + n * std::mem::size_of::() +} + +#[allow(dead_code)] +fn normalize + AsRef<[u8]>>(log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]) { let n: usize = a.n(); - let size: usize = a.size(); debug_assert!( - tmp_bytes.len() >= normalize_tmp_bytes(n, size), - "invalid tmp_bytes: tmp_bytes.len()={} < normalize_tmp_bytes({}, {})", + tmp_bytes.len() >= normalize_tmp_bytes(n), + "invalid tmp_bytes: tmp_bytes.len()={} < normalize_tmp_bytes({})", tmp_bytes.len(), n, - size, ); + #[cfg(debug_assertions)] { assert_alignement(tmp_bytes.as_ptr()) @@ -345,462 +184,150 @@ fn normalize(log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]) { unsafe { znx::znx_zero_i64_ref(n as u64, carry_i64.as_mut_ptr()); - (0..a.cols()).rev().for_each(|i| { + (0..a.size()).rev().for_each(|i| { znx::znx_normalize( - (n * size) as u64, + n as u64, log_base2k as u64, - a.at_mut_ptr(i), + a.at_mut_ptr(a_col, i), carry_i64.as_mut_ptr(), - a.at_mut_ptr(i), + a.at_mut_ptr(a_col, i), carry_i64.as_mut_ptr(), ) }); } } -pub fn rsh_tmp_bytes(n: usize, size: usize) -> usize { - n * size * std::mem::size_of::() -} - -pub fn rsh(log_base2k: usize, a: &mut VecZnx, k: usize, tmp_bytes: &mut [u8]) { - let n: usize = a.n(); - let size: usize = a.size(); - - #[cfg(debug_assertions)] +impl VecZnx +where + VecZnx: VecZnxToMut + ZnxInfos, +{ + /// Extracts the `a_col`-th column of `a` and stores it in the `self_col`-th column of [Self]. + pub fn extract_column(&mut self, self_col: usize, a: &R, a_col: usize) + where + R: VecZnxToRef + ZnxInfos, { - assert!( - tmp_bytes.len() >= rsh_tmp_bytes(n, size), - "invalid carry: carry.len()/8={} < rsh_tmp_bytes({}, {})", - tmp_bytes.len() >> 3, - n, - size, - ); - assert_alignement(tmp_bytes.as_ptr()); - } - - let cols: usize = a.cols(); - let cols_steps: usize = k / log_base2k; - - a.raw_mut().rotate_right(n * size * cols_steps); - unsafe { - znx::znx_zero_i64_ref((n * size * cols_steps) as u64, a.as_mut_ptr()); - } - - let k_rem = k % log_base2k; - - if k_rem != 0 { - let carry_i64: &mut [i64] = cast_mut(tmp_bytes); - - unsafe { - znx::znx_zero_i64_ref((n * size) as u64, carry_i64.as_mut_ptr()); - } - - let log_base2k: usize = log_base2k; - - (cols_steps..cols).for_each(|i| { - izip!(carry_i64.iter_mut(), a.at_mut(i).iter_mut()).for_each(|(ci, xi)| { - *xi += *ci << log_base2k; - *ci = get_base_k_carry(*xi, k_rem); - *xi = (*xi - *ci) >> k_rem; - }); - }) - } -} - -#[inline(always)] -fn get_base_k_carry(x: i64, k: usize) -> i64 { - (x << 64 - k) >> (64 - k) -} - -pub trait VecZnxOps { - /// Allocates a new [VecZnx]. - /// - /// # Arguments - /// - /// * `cols`: the number of cols. - fn new_vec_znx(&self, size: usize, cols: usize) -> VecZnx; - - /// Returns the minimum number of bytes necessary to allocate - /// a new [VecZnx] through [VecZnx::from_bytes]. - fn bytes_of_vec_znx(&self, size: usize, cols: usize) -> usize; - - fn vec_znx_normalize_tmp_bytes(&self, size: usize) -> usize; - - /// c <- a + b. - fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx); - - /// b <- b + a. - fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx); - - /// c <- a - b.
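// [Editor's aside] The `normalize` helper above (which defers to the
// znx_normalize FFI routine) propagates carries so that every limb ends up
// centered in [-2^(log_base2k-1), 2^(log_base2k-1)). A scalar model of one
// coefficient across its limbs, most-significant limb first as stored,
// assuming an exact base-2^k splitting:
fn normalize_sketch(log_base2k: usize, limbs: &mut [i64]) {
    let base = 1i64 << log_base2k;
    let half = base >> 1;
    let mut carry = 0i64;
    for x in limbs.iter_mut().rev() {
        // walk from the least-significant limb (stored last) upward
        let v = *x + carry;
        let mut r = v.rem_euclid(base); // in [0, base)
        if r >= half {
            r -= base; // re-center into [-base/2, base/2)
        }
        *x = r;
        carry = (v - r) >> log_base2k; // exact: v - r is a multiple of base
    }
    // a final nonzero carry overflows the most-significant limb and is dropped,
    // mirroring the fixed-precision behaviour of the library routine
}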
- fn vec_znx_sub(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx); - - /// b <- a - b. - fn vec_znx_sub_ab_inplace(&self, b: &mut VecZnx, a: &VecZnx); - - /// b <- b - a. - fn vec_znx_sub_ba_inplace(&self, b: &mut VecZnx, a: &VecZnx); - - /// b <- -a. - fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx); - - /// b <- -b. - fn vec_znx_negate_inplace(&self, a: &mut VecZnx); - - /// b <- a * X^k (mod X^{n} + 1) - fn vec_znx_rotate(&self, k: i64, b: &mut VecZnx, a: &VecZnx); - - /// a <- a * X^k (mod X^{n} + 1) - fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx); - - /// b <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1)) - fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx); - - /// a <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1)) - fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx); - - /// Splits b into subrings and copies them them into a. - /// - /// # Panics - /// - /// This method requires that all [VecZnx] of b have the same ring degree - /// and that b.n() * b.len() <= a.n() - fn vec_znx_split(&self, b: &mut Vec, a: &VecZnx, buf: &mut VecZnx); - - /// Merges the subrings a into b. - /// - /// # Panics - /// - /// This method requires that all [VecZnx] of a have the same ring degree - /// and that a.n() * a.len() <= b.n() - fn vec_znx_merge(&self, b: &mut VecZnx, a: &Vec); -} - -impl VecZnxOps for Module { - fn new_vec_znx(&self, size: usize, cols: usize) -> VecZnx { - VecZnx::new(self.n(), size, cols) - } - - fn bytes_of_vec_znx(&self, size: usize, cols: usize) -> usize { - bytes_of_vec_znx(self.n(), size, cols) - } - - fn vec_znx_normalize_tmp_bytes(&self, size: usize) -> usize { - unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize * size } - } - - // c <- a + b - fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx) { - let n: usize = self.n(); #[cfg(debug_assertions)] { - assert_eq!(c.n(), n); - assert_eq!(a.n(), n); - assert_eq!(b.n(), n); + assert!(self_col < self.cols()); + assert!(a_col < a.cols()); } - unsafe { - vec_znx::vec_znx_add( - self.ptr, - c.as_mut_ptr(), - c.cols() as u64, - (n * c.size()) as u64, - a.as_ptr(), - a.cols() as u64, - (n * a.size()) as u64, - b.as_ptr(), - b.cols() as u64, - (n * b.size()) as u64, - ) - } - } - // b <- a + b - fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx) { - let n: usize = self.n(); - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), n); - assert_eq!(b.n(), n); - } - unsafe { - vec_znx::vec_znx_add( - self.ptr, - b.as_mut_ptr(), - b.cols() as u64, - (n * b.size()) as u64, - a.as_ptr(), - a.cols() as u64, - (n * a.size()) as u64, - b.as_ptr(), - b.cols() as u64, - (n * b.size()) as u64, - ) - } - } + let min_size: usize = self.size.min(a.size()); + let max_size: usize = self.size; - // c <- a + b - fn vec_znx_sub(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx) { - let n: usize = self.n(); - #[cfg(debug_assertions)] - { - assert_eq!(c.n(), n); - assert_eq!(a.n(), n); - assert_eq!(b.n(), n); - } - unsafe { - vec_znx::vec_znx_sub( - self.ptr, - c.as_mut_ptr(), - c.cols() as u64, - (n * c.size()) as u64, - a.as_ptr(), - a.cols() as u64, - (n * a.size()) as u64, - b.as_ptr(), - b.cols() as u64, - (n * b.size()) as u64, - ) - } - } + let mut self_mut: VecZnx<&mut [u8]> = self.to_mut(); + let a_ref: VecZnx<&[u8]> = a.to_ref(); - // b <- a - b - fn vec_znx_sub_ab_inplace(&self, b: &mut VecZnx, a: &VecZnx) { - let n: usize = self.n(); - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), n); - assert_eq!(b.n(), n); - } - unsafe { - vec_znx::vec_znx_sub( - 
self.ptr, - b.as_mut_ptr(), - b.cols() as u64, - (n * b.size()) as u64, - a.as_ptr(), - a.cols() as u64, - (n * a.size()) as u64, - b.as_ptr(), - b.cols() as u64, - (n * b.size()) as u64, - ) - } - } - - // b <- b - a - fn vec_znx_sub_ba_inplace(&self, b: &mut VecZnx, a: &VecZnx) { - let n: usize = self.n(); - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), n); - assert_eq!(b.n(), n); - } - unsafe { - vec_znx::vec_znx_sub( - self.ptr, - b.as_mut_ptr(), - b.cols() as u64, - (n * b.size()) as u64, - b.as_ptr(), - b.cols() as u64, - (n * b.size()) as u64, - a.as_ptr(), - a.cols() as u64, - (n * a.size()) as u64, - ) - } - } - - fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx) { - let n: usize = self.n(); - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), n); - assert_eq!(b.n(), n); - } - unsafe { - vec_znx::vec_znx_negate( - self.ptr, - b.as_mut_ptr(), - b.cols() as u64, - (n * b.size()) as u64, - a.as_ptr(), - a.cols() as u64, - (n * a.size()) as u64, - ) - } - } - - fn vec_znx_negate_inplace(&self, a: &mut VecZnx) { - let n: usize = self.n(); - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), n); - } - unsafe { - vec_znx::vec_znx_negate( - self.ptr, - a.as_mut_ptr(), - a.cols() as u64, - (n * a.size()) as u64, - a.as_ptr(), - a.cols() as u64, - (n * a.size()) as u64, - ) - } - } - - fn vec_znx_rotate(&self, k: i64, b: &mut VecZnx, a: &VecZnx) { - let n: usize = self.n(); - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), n); - assert_eq!(b.n(), n); - } - unsafe { - vec_znx::vec_znx_rotate( - self.ptr, - k, - b.as_mut_ptr(), - b.cols() as u64, - (n * b.size()) as u64, - a.as_ptr(), - a.cols() as u64, - (n * a.size()) as u64, - ) - } - } - - fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx) { - let n: usize = self.n(); - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), n); - } - unsafe { - vec_znx::vec_znx_rotate( - self.ptr, - k, - a.as_mut_ptr(), - a.cols() as u64, - (n * a.size()) as u64, - a.as_ptr(), - a.cols() as u64, - (n * a.size()) as u64, - ) - } - } - - /// Maps X^i to X^{ik} mod X^{n}+1. The mapping is applied independently on each cols. - /// - /// # Arguments - /// - /// * `a`: input. - /// * `b`: output. - /// * `k`: the power to which to map each coefficients. - /// * `a_cols`: the number of a_cols on which to apply the mapping. - /// - /// # Panics - /// - /// The method will panic if the argument `a` is greater than `a.cols()`. - fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx) { - let n: usize = self.n(); - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), n); - assert_eq!(b.n(), n); - } - unsafe { - vec_znx::vec_znx_automorphism( - self.ptr, - k, - b.as_mut_ptr(), - b.cols() as u64, - (n * b.size()) as u64, - a.as_ptr(), - a.cols() as u64, - (n * a.size()) as u64, - ); - } - } - - /// Maps X^i to X^{ik} mod X^{n}+1. The mapping is applied independently on each cols. - /// - /// # Arguments - /// - /// * `a`: input and output. - /// * `k`: the power to which to map each coefficients. - /// * `a_cols`: the number of cols on which to apply the mapping. - /// - /// # Panics - /// - /// The method will panic if the argument `cols` is greater than `self.cols()`. 
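// [Editor's note] A plain-Rust model of the negacyclic automorphism
// X^i -> X^{ik} (mod X^n + 1) documented above: exponents reduce mod 2n, and
// crossing a multiple of n flips the sign. For the map to be a permutation,
// k must be odd (invertible mod 2n). Illustrative only; the library applies
// this per column through the FFI:
fn automorphism_sketch(k: usize, a: &[i64]) -> Vec<i64> {
    let n = a.len();
    let mut b = vec![0i64; n];
    for (i, &c) in a.iter().enumerate() {
        let e = (i * k) % (2 * n);
        if e < n {
            b[e] += c;
        } else {
            b[e - n] -= c; // X^{n+j} = -X^j in Z[X]/(X^n + 1)
        }
    }
    b
}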
- fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx) { - let n: usize = self.n(); - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), n); - } - unsafe { - vec_znx::vec_znx_automorphism( - self.ptr, - k, - a.as_mut_ptr(), - a.cols() as u64, - (n * a.size()) as u64, - a.as_ptr(), - a.cols() as u64, - (n * a.size()) as u64, - ); - } - } - - fn vec_znx_split(&self, b: &mut Vec, a: &VecZnx, buf: &mut VecZnx) { - let (n_in, n_out) = (a.n(), b[0].n()); - - debug_assert!( - n_out < n_in, - "invalid a: output ring degree should be smaller" - ); - b[1..].iter().for_each(|bi| { - debug_assert_eq!( - bi.n(), - n_out, - "invalid input a: all VecZnx must have the same degree" - ) + (0..min_size).for_each(|i: usize| { + self_mut + .at_mut(self_col, i) + .copy_from_slice(a_ref.at(a_col, i)); }); - b.iter_mut().enumerate().for_each(|(i, bi)| { - if i == 0 { - switch_degree(bi, a); - self.vec_znx_rotate(-1, buf, a); - } else { - switch_degree(bi, buf); - self.vec_znx_rotate_inplace(-1, buf); + (min_size..max_size).for_each(|i| { + self_mut.zero_at(self_col, i); + }); + } +} + +impl> fmt::Display for VecZnx { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!( + f, + "VecZnx(n={}, cols={}, size={})", + self.n, self.cols, self.size + )?; + + for col in 0..self.cols { + writeln!(f, "Column {}:", col)?; + for size in 0..self.size { + let coeffs = self.at(col, size); + write!(f, " Size {}: [", size)?; + + let max_show = 100; + let show_count = coeffs.len().min(max_show); + + for (i, &coeff) in coeffs.iter().take(show_count).enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", coeff)?; + } + + if coeffs.len() > max_show { + write!(f, ", ... ({} more)", coeffs.len() - max_show)?; + } + + writeln!(f, "]")?; } - }) - } - - fn vec_znx_merge(&self, b: &mut VecZnx, a: &Vec) { - let (n_in, n_out) = (b.n(), a[0].n()); - - debug_assert!( - n_out < n_in, - "invalid a: output ring degree should be smaller" - ); - a[1..].iter().for_each(|ai| { - debug_assert_eq!( - ai.n(), - n_out, - "invalid input a: all VecZnx must have the same degree" - ) - }); - - a.iter().enumerate().for_each(|(_, ai)| { - switch_degree(b, ai); - self.vec_znx_rotate_inplace(-1, b); - }); - - self.vec_znx_rotate_inplace(a.len() as i64, b); + } + Ok(()) + } +} + +pub type VecZnxOwned = VecZnx>; +pub type VecZnxMut<'a> = VecZnx<&'a mut [u8]>; +pub type VecZnxRef<'a> = VecZnx<&'a [u8]>; + +pub trait VecZnxToRef { + fn to_ref(&self) -> VecZnx<&[u8]>; +} + +pub trait VecZnxToMut { + fn to_mut(&mut self) -> VecZnx<&mut [u8]>; +} + +impl VecZnxToMut for VecZnx> { + fn to_mut(&mut self) -> VecZnx<&mut [u8]> { + VecZnx { + data: self.data.as_mut_slice(), + n: self.n, + cols: self.cols, + size: self.size, + } + } +} + +impl VecZnxToRef for VecZnx> { + fn to_ref(&self) -> VecZnx<&[u8]> { + VecZnx { + data: self.data.as_slice(), + n: self.n, + cols: self.cols, + size: self.size, + } + } +} + +impl VecZnxToMut for VecZnx<&mut [u8]> { + fn to_mut(&mut self) -> VecZnx<&mut [u8]> { + VecZnx { + data: self.data, + n: self.n, + cols: self.cols, + size: self.size, + } + } +} + +impl VecZnxToRef for VecZnx<&mut [u8]> { + fn to_ref(&self) -> VecZnx<&[u8]> { + VecZnx { + data: self.data, + n: self.n, + cols: self.cols, + size: self.size, + } + } +} + +impl VecZnxToRef for VecZnx<&[u8]> { + fn to_ref(&self) -> VecZnx<&[u8]> { + VecZnx { + data: self.data, + n: self.n, + cols: self.cols, + size: self.size, + } } } diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index 705a5ec..2bf4dcc 100644 --- 
a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -1,90 +1,26 @@ -use crate::ffi::vec_znx_big::{self, vec_znx_big_t}; -use crate::{BACKEND, Infos, LAYOUT, Module, VecZnx, VecZnxDft, alloc_aligned, assert_alignement}; +use crate::ffi::vec_znx_big; +use crate::znx_base::{ZnxInfos, ZnxView}; +use crate::{Backend, DataView, DataViewMut, FFT64, Module, VecZnx, ZnxSliceSize, ZnxViewMut, ZnxZero, alloc_aligned}; +use std::fmt; +use std::marker::PhantomData; -pub struct VecZnxBig { - pub data: Vec, - pub ptr: *mut u8, - pub n: usize, - pub size: usize, - pub cols: usize, - pub layout: LAYOUT, - pub backend: BACKEND, +pub struct VecZnxBig { + data: D, + n: usize, + cols: usize, + size: usize, + _phantom: PhantomData, } -impl VecZnxBig { - /// Returns a new [VecZnxBig] with the provided data as backing array. - /// User must ensure that data is properly alligned and that - /// the size of data is at least equal to [Module::bytes_of_vec_znx_big]. - pub fn from_bytes(module: &Module, size: usize, cols: usize, bytes: &mut [u8]) -> Self { - #[cfg(debug_assertions)] - { - assert_eq!(bytes.len(), module.bytes_of_vec_znx_big(size, cols)); - assert_alignement(bytes.as_ptr()) - }; - unsafe { - Self { - data: Vec::from_raw_parts(bytes.as_mut_ptr(), bytes.len(), bytes.len()), - ptr: bytes.as_mut_ptr(), - n: module.n(), - size: size, - layout: LAYOUT::COL, - cols: cols, - backend: module.backend, - } - } +impl ZnxInfos for VecZnxBig { + fn cols(&self) -> usize { + self.cols } - pub fn from_bytes_borrow(module: &Module, size: usize, cols: usize, bytes: &mut [u8]) -> Self { - #[cfg(debug_assertions)] - { - assert_eq!(bytes.len(), module.bytes_of_vec_znx_big(size, cols)); - assert_alignement(bytes.as_ptr()); - } - Self { - data: Vec::new(), - ptr: bytes.as_mut_ptr(), - n: module.n(), - size: size, - layout: LAYOUT::COL, - cols: cols, - backend: module.backend, - } + fn rows(&self) -> usize { + 1 } - pub fn as_vec_znx_dft(&mut self) -> VecZnxDft { - VecZnxDft { - data: Vec::new(), - ptr: self.ptr, - n: self.n, - size: self.size, - layout: LAYOUT::COL, - cols: self.cols, - backend: self.backend, - } - } - - pub fn backend(&self) -> BACKEND { - self.backend - } - - /// Returns a non-mutable reference of `T` of the entire contiguous array of the [VecZnxDft]. - /// When using [`crate::FFT64`] as backend, `T` should be [f64]. - /// When using [`crate::NTT120`] as backend, `T` should be [i64]. - /// The length of the returned array is cols * n. - pub fn raw(&self, module: &Module) -> &[T] { - let ptr: *const T = self.ptr as *const T; - let len: usize = (self.cols() * module.n() * 8) / std::mem::size_of::(); - unsafe { &std::slice::from_raw_parts(ptr, len) } - } -} - -impl Infos for VecZnxBig { - /// Returns the base 2 logarithm of the [VecZnx] degree. - fn log_n(&self) -> usize { - (usize::BITS - (self.n - 1).leading_zeros()) as _ - } - - /// Returns the [VecZnx] degree. fn n(&self) -> usize { self.n } @@ -92,270 +28,217 @@ impl Infos for VecZnxBig { fn size(&self) -> usize { self.size } +} - fn layout(&self) -> LAYOUT { - self.layout - } - - /// Returns the number of cols of the [VecZnx]. - fn cols(&self) -> usize { - self.cols - } - - /// Returns the number of rows of the [VecZnx]. - fn rows(&self) -> usize { - 1 +impl ZnxSliceSize for VecZnxBig { + fn sl(&self) -> usize { + self.n() * self.cols() } } -pub trait VecZnxBigOps { - /// Allocates a vector Z[X]/(X^N+1) that stores not normalized values. 
- fn new_vec_znx_big(&self, size: usize, cols: usize) -> VecZnxBig; - - /// Returns a new [VecZnxBig] with the provided bytes array as backing array. - /// - /// Behavior: takes ownership of the backing array. - /// - /// # Arguments - /// - /// * `cols`: the number of cols of the [VecZnxBig]. - /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big]. - /// - /// # Panics - /// If `bytes.len()` < [Module::bytes_of_vec_znx_big]. - fn new_vec_znx_big_from_bytes(&self, size: usize, cols: usize, bytes: &mut [u8]) -> VecZnxBig; - - /// Returns a new [VecZnxBig] with the provided bytes array as backing array. - /// - /// Behavior: the backing array is only borrowed. - /// - /// # Arguments - /// - /// * `cols`: the number of cols of the [VecZnxBig]. - /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big]. - /// - /// # Panics - /// If `bytes.len()` < [Module::bytes_of_vec_znx_big]. - fn new_vec_znx_big_from_bytes_borrow(&self, size: usize, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxBig; - - /// Returns the minimum number of bytes necessary to allocate - /// a new [VecZnxBig] through [VecZnxBig::from_bytes]. - fn bytes_of_vec_znx_big(&self, size: usize, cols: usize) -> usize; - - /// b <- b - a - fn vec_znx_big_sub_small_a_inplace(&self, b: &mut VecZnxBig, a: &VecZnx); - - /// c <- b - a - fn vec_znx_big_sub_small_a(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig); - - /// c <- b + a - fn vec_znx_big_add_small(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig); - - /// b <- b + a - fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig, a: &VecZnx); - - fn vec_znx_big_normalize_tmp_bytes(&self) -> usize; - - /// b <- normalize(a) - fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig, tmp_bytes: &mut [u8]); - - fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize; - - fn vec_znx_big_range_normalize_base2k( - &self, - log_base2k: usize, - res: &mut VecZnx, - a: &VecZnxBig, - a_range_begin: usize, - a_range_xend: usize, - a_range_step: usize, - tmp_bytes: &mut [u8], - ); - - fn vec_znx_big_automorphism(&self, gal_el: i64, b: &mut VecZnxBig, a: &VecZnxBig); - - fn vec_znx_big_automorphism_inplace(&self, gal_el: i64, a: &mut VecZnxBig); +impl DataView for VecZnxBig { + type D = D; + fn data(&self) -> &Self::D { + &self.data + } } -impl VecZnxBigOps for Module { - fn new_vec_znx_big(&self, size: usize, cols: usize) -> VecZnxBig { - let mut data: Vec = alloc_aligned::(self.bytes_of_vec_znx_big(size, cols)); - let ptr: *mut u8 = data.as_mut_ptr(); +impl DataViewMut for VecZnxBig { + fn data_mut(&mut self) -> &mut Self::D { + &mut self.data + } +} + +impl> ZnxView for VecZnxBig { + type Scalar = i64; +} + +pub(crate) fn bytes_of_vec_znx_big(module: &Module, cols: usize, size: usize) -> usize { + unsafe { vec_znx_big::bytes_of_vec_znx_big(module.ptr, size as u64) as usize * cols } +} + +impl>, B: Backend> VecZnxBig { + pub(crate) fn new(module: &Module, cols: usize, size: usize) -> Self { + let data = alloc_aligned::(bytes_of_vec_znx_big(module, cols, size)); + Self { + data: data.into(), + n: module.n(), + cols, + size, + _phantom: PhantomData, + } + } + + pub(crate) fn new_from_bytes(module: &Module, cols: usize, size: usize, bytes: impl Into>) -> Self { + let data: Vec = bytes.into(); + assert!(data.len() == bytes_of_vec_znx_big(module, cols, size)); + Self { + data: data.into(), + n: module.n(), + cols, + size, + _phantom: PhantomData, + } + } +} + +impl VecZnxBig { + pub(crate) fn from_data(data: D, n: usize, 
cols: usize, size: usize) -> Self { + Self { + data, + n, + cols, + size, + _phantom: PhantomData, + } + } +} + +impl VecZnxBig +where + VecZnxBig: VecZnxBigToMut + ZnxInfos, +{ + // Consumes the VecZnxBig to return a VecZnx. + // Useful when no normalization is needed. + pub fn to_vec_znx_small(self) -> VecZnx { + VecZnx { + data: self.data, + n: self.n, + cols: self.cols, + size: self.size, + } + } + + /// Extracts the a_col-th column of 'a' and stores it on the self_col-th column [Self]. + pub fn extract_column(&mut self, self_col: usize, a: &VecZnxBig, a_col: usize) + where + VecZnxBig: VecZnxBigToRef + ZnxInfos, + { + #[cfg(debug_assertions)] + { + assert!(self_col < self.cols()); + assert!(a_col < a.cols()); + } + + let min_size: usize = self.size.min(a.size()); + let max_size: usize = self.size; + + let mut self_mut: VecZnxBig<&mut [u8], FFT64> = self.to_mut(); + let a_ref: VecZnxBig<&[u8], FFT64> = a.to_ref(); + + (0..min_size).for_each(|i: usize| { + self_mut + .at_mut(self_col, i) + .copy_from_slice(a_ref.at(a_col, i)); + }); + + (min_size..max_size).for_each(|i| { + self_mut.zero_at(self_col, i); + }); + } +} + +pub type VecZnxBigOwned = VecZnxBig, B>; + +pub trait VecZnxBigToRef { + fn to_ref(&self) -> VecZnxBig<&[u8], B>; +} + +pub trait VecZnxBigToMut { + fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B>; +} + +impl VecZnxBigToMut for VecZnxBig, B> { + fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B> { VecZnxBig { - data: data, - ptr: ptr, - n: self.n(), - size: size, - layout: LAYOUT::COL, - cols: cols, - backend: self.backend(), - } - } - - fn new_vec_znx_big_from_bytes(&self, size: usize, cols: usize, bytes: &mut [u8]) -> VecZnxBig { - VecZnxBig::from_bytes(self, size, cols, bytes) - } - - fn new_vec_znx_big_from_bytes_borrow(&self, size: usize, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxBig { - VecZnxBig::from_bytes_borrow(self, size, cols, tmp_bytes) - } - - fn bytes_of_vec_znx_big(&self, size: usize, cols: usize) -> usize { - unsafe { vec_znx_big::bytes_of_vec_znx_big(self.ptr, cols as u64) as usize * size } - } - - fn vec_znx_big_sub_small_a_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) { - unsafe { - vec_znx_big::vec_znx_big_sub_small_a( - self.ptr, - b.ptr as *mut vec_znx_big_t, - b.cols() as u64, - a.as_ptr(), - a.cols() as u64, - a.n() as u64, - b.ptr as *mut vec_znx_big_t, - b.cols() as u64, - ) - } - } - - fn vec_znx_big_sub_small_a(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig) { - unsafe { - vec_znx_big::vec_znx_big_sub_small_a( - self.ptr, - c.ptr as *mut vec_znx_big_t, - c.cols() as u64, - a.as_ptr(), - a.cols() as u64, - a.n() as u64, - b.ptr as *mut vec_znx_big_t, - b.cols() as u64, - ) - } - } - - fn vec_znx_big_add_small(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig) { - unsafe { - vec_znx_big::vec_znx_big_add_small( - self.ptr, - c.ptr as *mut vec_znx_big_t, - c.cols() as u64, - b.ptr as *mut vec_znx_big_t, - b.cols() as u64, - a.as_ptr(), - a.cols() as u64, - a.n() as u64, - ) - } - } - - fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) { - unsafe { - vec_znx_big::vec_znx_big_add_small( - self.ptr, - b.ptr as *mut vec_znx_big_t, - b.cols() as u64, - b.ptr as *mut vec_znx_big_t, - b.cols() as u64, - a.as_ptr(), - a.cols() as u64, - a.n() as u64, - ) - } - } - - fn vec_znx_big_normalize_tmp_bytes(&self) -> usize { - unsafe { vec_znx_big::vec_znx_big_normalize_base2k_tmp_bytes(self.ptr) as usize } - } - - fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig, tmp_bytes: &mut [u8]) { - 
debug_assert!( - tmp_bytes.len() >= ::vec_znx_big_normalize_tmp_bytes(self), - "invalid tmp_bytes: tmp_bytes.len()={} <= self.vec_znx_big_normalize_tmp_bytes()={}", - tmp_bytes.len(), - ::vec_znx_big_normalize_tmp_bytes(self) - ); - #[cfg(debug_assertions)] - { - assert_alignement(tmp_bytes.as_ptr()) - } - unsafe { - vec_znx_big::vec_znx_big_normalize_base2k( - self.ptr, - log_base2k as u64, - b.as_mut_ptr(), - b.cols() as u64, - b.n() as u64, - a.ptr as *mut vec_znx_big_t, - a.cols() as u64, - tmp_bytes.as_mut_ptr(), - ) - } - } - - fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize { - unsafe { vec_znx_big::vec_znx_big_range_normalize_base2k_tmp_bytes(self.ptr) as usize } - } - - fn vec_znx_big_range_normalize_base2k( - &self, - log_base2k: usize, - res: &mut VecZnx, - a: &VecZnxBig, - a_range_begin: usize, - a_range_xend: usize, - a_range_step: usize, - tmp_bytes: &mut [u8], - ) { - debug_assert!( - tmp_bytes.len() >= ::vec_znx_big_range_normalize_base2k_tmp_bytes(self), - "invalid tmp_bytes: tmp_bytes.len()={} <= self.vec_znx_big_range_normalize_base2k_tmp_bytes()={}", - tmp_bytes.len(), - ::vec_znx_big_range_normalize_base2k_tmp_bytes(self) - ); - #[cfg(debug_assertions)] - { - assert_alignement(tmp_bytes.as_ptr()) - } - unsafe { - vec_znx_big::vec_znx_big_range_normalize_base2k( - self.ptr, - log_base2k as u64, - res.as_mut_ptr(), - res.cols() as u64, - res.n() as u64, - a.ptr as *mut vec_znx_big_t, - a_range_begin as u64, - a_range_xend as u64, - a_range_step as u64, - tmp_bytes.as_mut_ptr(), - ); - } - } - - fn vec_znx_big_automorphism(&self, gal_el: i64, b: &mut VecZnxBig, a: &VecZnxBig) { - unsafe { - vec_znx_big::vec_znx_big_automorphism( - self.ptr, - gal_el, - b.ptr as *mut vec_znx_big_t, - b.cols() as u64, - a.ptr as *mut vec_znx_big_t, - a.cols() as u64, - ); - } - } - - fn vec_znx_big_automorphism_inplace(&self, gal_el: i64, a: &mut VecZnxBig) { - unsafe { - vec_znx_big::vec_znx_big_automorphism( - self.ptr, - gal_el, - a.ptr as *mut vec_znx_big_t, - a.cols() as u64, - a.ptr as *mut vec_znx_big_t, - a.cols() as u64, - ); + data: self.data.as_mut_slice(), + n: self.n, + cols: self.cols, + size: self.size, + _phantom: PhantomData, } } } + +impl VecZnxBigToRef for VecZnxBig, B> { + fn to_ref(&self) -> VecZnxBig<&[u8], B> { + VecZnxBig { + data: self.data.as_slice(), + n: self.n, + cols: self.cols, + size: self.size, + _phantom: PhantomData, + } + } +} + +impl VecZnxBigToMut for VecZnxBig<&mut [u8], B> { + fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B> { + VecZnxBig { + data: self.data, + n: self.n, + cols: self.cols, + size: self.size, + _phantom: PhantomData, + } + } +} + +impl VecZnxBigToRef for VecZnxBig<&mut [u8], B> { + fn to_ref(&self) -> VecZnxBig<&[u8], B> { + VecZnxBig { + data: self.data, + n: self.n, + cols: self.cols, + size: self.size, + _phantom: PhantomData, + } + } +} + +impl VecZnxBigToRef for VecZnxBig<&[u8], B> { + fn to_ref(&self) -> VecZnxBig<&[u8], B> { + VecZnxBig { + data: self.data, + n: self.n, + cols: self.cols, + size: self.size, + _phantom: PhantomData, + } + } +} + +impl> fmt::Display for VecZnxBig { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!( + f, + "VecZnxBig(n={}, cols={}, size={})", + self.n, self.cols, self.size + )?; + + for col in 0..self.cols { + writeln!(f, "Column {}:", col)?; + for size in 0..self.size { + let coeffs = self.at(col, size); + write!(f, " Size {}: [", size)?; + + let max_show = 100; + let show_count = coeffs.len().min(max_show); + + for (i, &coeff) in 
coeffs.iter().take(show_count).enumerate() {
+                if i > 0 {
+                    write!(f, ", ")?;
+                }
+                write!(f, "{}", coeff)?;
+            }
+
+            if coeffs.len() > max_show {
+                write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
+            }
+
+            writeln!(f, "]")?;
+            }
+        }
+        Ok(())
+    }
+}
diff --git a/base2k/src/vec_znx_big_ops.rs b/base2k/src/vec_znx_big_ops.rs
new file mode 100644
index 0000000..8208c97
--- /dev/null
+++ b/base2k/src/vec_znx_big_ops.rs
@@ -0,0 +1,632 @@
+use crate::ffi::vec_znx;
+use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
+use crate::{
+    Backend, FFT64, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxScratch,
+    VecZnxToMut, VecZnxToRef, ZnxSliceSize, bytes_of_vec_znx_big,
+};
+
+pub trait VecZnxBigAlloc {
+    /// Allocates a vector Z[X]/(X^N+1) that stores non-normalized values.
+    fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBigOwned;
+
+    /// Returns a new [VecZnxBig] with the provided bytes array as backing array.
+    ///
+    /// Behavior: takes ownership of the backing array.
+    ///
+    /// # Arguments
+    ///
+    /// * `cols`: the number of polynomials.
+    /// * `size`: the number of small polynomials per column.
+    /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big].
+    ///
+    /// # Panics
+    /// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
+    fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxBigOwned;
+
+    // /// Returns a new [VecZnxBig] with the provided bytes array as backing array.
+    // ///
+    // /// Behavior: the backing array is only borrowed.
+    // ///
+    // /// # Arguments
+    // ///
+    // /// * `cols`: the number of polynomials.
+    // /// * `size`: the number of small polynomials per column.
+    // /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big].
+    // ///
+    // /// # Panics
+    // /// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
+    // fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig;
+
+    /// Returns the minimum number of bytes necessary to allocate
+    /// a new [VecZnxBig] through [VecZnxBigAlloc::new_vec_znx_big_from_bytes].
+    fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize;
+}
+
+pub trait VecZnxBigOps {
+    /// Adds `a` to `b` and stores the result in `res` (res <- a + b).
+    fn vec_znx_big_add(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
+    where
+        R: VecZnxBigToMut,
+        A: VecZnxBigToRef,
+        B: VecZnxBigToRef;
+
+    /// Adds `a` to `res` in place (res <- res + a).
+    fn vec_znx_big_add_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxBigToMut,
+        A: VecZnxBigToRef;
+
+    /// Adds the small vector `b` (a [VecZnx]) to `a` and stores the result in `res` (res <- a + b).
+    fn vec_znx_big_add_small(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
+    where
+        R: VecZnxBigToMut,
+        A: VecZnxBigToRef,
+        B: VecZnxToRef;
+
+    /// Adds the small vector `a` (a [VecZnx]) to `res` in place (res <- res + a).
+    fn vec_znx_big_add_small_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxBigToMut,
+        A: VecZnxToRef;
+
+    /// Subtracts `b` from `a` and stores the result in `res` (res <- a - b).
+    fn vec_znx_big_sub(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
+    where
+        R: VecZnxBigToMut,
+        A: VecZnxBigToRef,
+        B: VecZnxBigToRef;
+
+    /// Subtracts `a` from `res` in place (res <- res - a).
+    fn vec_znx_big_sub_ab_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxBigToMut,
+        A: VecZnxBigToRef;
+
+    /// Subtracts `res` from `a` and stores the result in `res` (res <- a - res).
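+    ///
+    /// A minimal usage sketch (hedged: `module`, `res` and `a` are assumed to be an
+    /// FFT64 [Module] and two already-allocated [VecZnxBig] of the same degree):
+    ///
+    /// ```ignore
+    /// // res[0] <- a[0] - res[0]
+    /// module.vec_znx_big_sub_ba_inplace(&mut res, 0, &a, 0);
+    /// ```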
+    fn vec_znx_big_sub_ba_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxBigToMut,
+        A: VecZnxBigToRef;
+
+    /// Subtracts the big vector `b` from the small vector `a` (a [VecZnx]) and stores the result in `res` (res <- a - b).
+    fn vec_znx_big_sub_small_a(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
+    where
+        R: VecZnxBigToMut,
+        A: VecZnxToRef,
+        B: VecZnxBigToRef;
+
+    /// Subtracts the small vector `a` (a [VecZnx]) from `res` in place (res <- res - a).
+    fn vec_znx_big_sub_small_a_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxBigToMut,
+        A: VecZnxToRef;
+
+    /// Subtracts the small vector `b` (a [VecZnx]) from `a` and stores the result in `res` (res <- a - b).
+    fn vec_znx_big_sub_small_b(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
+    where
+        R: VecZnxBigToMut,
+        A: VecZnxBigToRef,
+        B: VecZnxToRef;
+
+    /// Subtracts `res` from the small vector `a` (a [VecZnx]) and stores the result in `res` (res <- a - res).
+    fn vec_znx_big_sub_small_b_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxBigToMut,
+        A: VecZnxToRef;
+
+    /// Negates `a` in place.
+    fn vec_znx_big_negate_inplace(&self, a: &mut A, a_col: usize)
+    where
+        A: VecZnxBigToMut;
+
+    /// Normalizes `a` and stores the result in `res`.
+    ///
+    /// # Arguments
+    ///
+    /// * `log_base2k`: normalization basis.
+    /// * `scratch`: scratch space of size at least [VecZnxBigScratch::vec_znx_big_normalize_tmp_bytes].
+    fn vec_znx_big_normalize(
+        &self,
+        log_base2k: usize,
+        res: &mut R,
+        res_col: usize,
+        a: &A,
+        a_col: usize,
+        scratch: &mut Scratch,
+    ) where
+        R: VecZnxToMut,
+        A: VecZnxBigToRef;
+
+    /// Applies the automorphism X^i -> X^ik on `a` and stores the result in `res`.
+    fn vec_znx_big_automorphism(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxBigToMut,
+        A: VecZnxBigToRef;
+
+    /// Applies the automorphism X^i -> X^ik on `a` in place.
+    fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize)
+    where
+        A: VecZnxBigToMut;
+}
+
+pub trait VecZnxBigScratch {
+    /// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_normalize].
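+    ///
+    /// Sketch of the intended allocation pattern, following the usage shown in the
+    /// `rlwe_encrypt` example of this crate (`ScratchOwned` and `borrow` are taken
+    /// from there; the variable names are illustrative):
+    ///
+    /// ```ignore
+    /// let mut scratch = ScratchOwned::new(module.vec_znx_big_normalize_tmp_bytes());
+    /// module.vec_znx_big_normalize(log_base2k, &mut res, 0, &a_big, 0, scratch.borrow());
+    /// ```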
+ fn vec_znx_big_normalize_tmp_bytes(&self) -> usize; +} + +impl VecZnxBigAlloc for Module { + fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBigOwned { + VecZnxBig::new(self, cols, size) + } + + fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxBigOwned { + VecZnxBig::new_from_bytes(self, cols, size, bytes) + } + + fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize { + bytes_of_vec_znx_big(self, cols, size) + } +} + +impl VecZnxBigOps for Module { + fn vec_znx_big_add(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + B: VecZnxBigToRef, + { + let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); + let b: VecZnxBig<&[u8], FFT64> = b.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(b.n(), self.n()); + assert_eq!(res.n(), self.n()); + assert_ne!(a.as_ptr(), b.as_ptr()); + } + unsafe { + vec_znx::vec_znx_add( + self.ptr, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + b.at_ptr(b_col, 0), + b.size() as u64, + b.sl() as u64, + ) + } + } + + fn vec_znx_big_add_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + { + let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } + unsafe { + vec_znx::vec_znx_add( + self.ptr, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + res.at_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + ) + } + } + + fn vec_znx_big_sub(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + B: VecZnxBigToRef, + { + let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); + let b: VecZnxBig<&[u8], FFT64> = b.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(b.n(), self.n()); + assert_eq!(res.n(), self.n()); + assert_ne!(a.as_ptr(), b.as_ptr()); + } + unsafe { + vec_znx::vec_znx_sub( + self.ptr, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + b.at_ptr(b_col, 0), + b.size() as u64, + b.sl() as u64, + ) + } + } + + fn vec_znx_big_sub_ab_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + { + let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } + unsafe { + vec_znx::vec_znx_sub( + self.ptr, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + res.at_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) + } + } + + fn vec_znx_big_sub_ba_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + { + let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), 
self.n()); + } + unsafe { + vec_znx::vec_znx_sub( + self.ptr, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + res.at_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + ) + } + } + + fn vec_znx_big_sub_small_b(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + B: VecZnxToRef, + { + let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); + let b: VecZnx<&[u8]> = b.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(b.n(), self.n()); + assert_eq!(res.n(), self.n()); + assert_ne!(a.as_ptr(), b.as_ptr()); + } + unsafe { + vec_znx::vec_znx_sub( + self.ptr, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + b.at_ptr(b_col, 0), + b.size() as u64, + b.sl() as u64, + ) + } + } + + fn vec_znx_big_sub_small_b_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } + unsafe { + vec_znx::vec_znx_sub( + self.ptr, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + res.at_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + ) + } + } + + fn vec_znx_big_sub_small_a(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef, + B: VecZnxBigToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let b: VecZnxBig<&[u8], FFT64> = b.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(b.n(), self.n()); + assert_eq!(res.n(), self.n()); + assert_ne!(a.as_ptr(), b.as_ptr()); + } + unsafe { + vec_znx::vec_znx_sub( + self.ptr, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + b.at_ptr(b_col, 0), + b.size() as u64, + b.sl() as u64, + ) + } + } + + fn vec_znx_big_sub_small_a_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } + unsafe { + vec_znx::vec_znx_sub( + self.ptr, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + res.at_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) + } + } + + fn vec_znx_big_add_small(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + B: VecZnxToRef, + { + let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); + let b: VecZnx<&[u8]> = b.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(b.n(), self.n()); + assert_eq!(res.n(), self.n()); + assert_ne!(a.as_ptr(), b.as_ptr()); + } + unsafe { + vec_znx::vec_znx_add( + self.ptr, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + 
a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + b.at_ptr(b_col, 0), + b.size() as u64, + b.sl() as u64, + ) + } + } + + fn vec_znx_big_add_small_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } + unsafe { + vec_znx::vec_znx_add( + self.ptr, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + res.at_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) + } + } + + fn vec_znx_big_negate_inplace(&self, a: &mut A, res_col: usize) + where + A: VecZnxBigToMut, + { + let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut(); + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + } + unsafe { + vec_znx::vec_znx_negate( + self.ptr, + a.at_mut_ptr(res_col, 0), + a.size() as u64, + a.sl() as u64, + a.at_ptr(res_col, 0), + a.size() as u64, + a.sl() as u64, + ) + } + } + + fn vec_znx_big_normalize( + &self, + log_base2k: usize, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxToMut, + A: VecZnxBigToRef, + { + let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + //(Jay)Note: This is calling VezZnxOps::vec_znx_normalize_tmp_bytes and not VecZnxBigOps::vec_znx_big_normalize_tmp_bytes. + // In the FFT backend the tmp sizes are same but will be different in the NTT backend + // assert!(tmp_bytes.len() >= >::vec_znx_normalize_tmp_bytes(&self)); + // assert_alignement(tmp_bytes.as_ptr()); + } + + let (tmp_bytes, _) = scratch.tmp_slice(::vec_znx_big_normalize_tmp_bytes( + &self, + )); + unsafe { + vec_znx::vec_znx_normalize_base2k( + self.ptr, + log_base2k as u64, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + tmp_bytes.as_mut_ptr(), + ); + } + } + + fn vec_znx_big_automorphism(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxBigToRef, + { + let a: VecZnxBig<&[u8], FFT64> = a.to_ref(); + let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } + unsafe { + vec_znx::vec_znx_automorphism( + self.ptr, + k, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) + } + } + + fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) + where + A: VecZnxBigToMut, + { + let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + } + unsafe { + vec_znx::vec_znx_automorphism( + self.ptr, + k, + a.at_mut_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) + } + } +} + +impl VecZnxBigScratch for Module { + fn vec_znx_big_normalize_tmp_bytes(&self) -> usize { + ::vec_znx_normalize_tmp_bytes(self) + } +} diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index 8b31ea6..7b4ec29 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -1,114 +1,35 @@ -use crate::ffi::vec_znx_big::vec_znx_big_t; +use 
std::marker::PhantomData; + use crate::ffi::vec_znx_dft; -use crate::ffi::vec_znx_dft::{bytes_of_vec_znx_dft, vec_znx_dft_t}; -use crate::{BACKEND, Infos, LAYOUT, Module, VecZnxBig, assert_alignement}; -use crate::{DEFAULTALIGN, VecZnx, alloc_aligned}; +use crate::znx_base::ZnxInfos; +use crate::{ + Backend, DataView, DataViewMut, FFT64, Module, VecZnxBig, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero, alloc_aligned, +}; +use std::fmt; -pub struct VecZnxDft { - pub data: Vec, - pub ptr: *mut u8, - pub n: usize, - pub size: usize, - pub layout: LAYOUT, - pub cols: usize, - pub backend: BACKEND, +pub struct VecZnxDft { + pub(crate) data: D, + pub(crate) n: usize, + pub(crate) cols: usize, + pub(crate) size: usize, + pub(crate) _phantom: PhantomData, } -impl VecZnxDft { - /// Returns a new [VecZnxDft] with the provided data as backing array. - /// User must ensure that data is properly alligned and that - /// the size of data is at least equal to [Module::bytes_of_vec_znx_dft]. - pub fn from_bytes(module: &Module, size: usize, cols: usize, bytes: &mut [u8]) -> VecZnxDft { - #[cfg(debug_assertions)] - { - assert_eq!(bytes.len(), module.bytes_of_vec_znx_dft(size, cols)); - assert_alignement(bytes.as_ptr()) - } - unsafe { - VecZnxDft { - data: Vec::from_raw_parts(bytes.as_mut_ptr(), bytes.len(), bytes.len()), - ptr: bytes.as_mut_ptr(), - n: module.n(), - size: size, - layout: LAYOUT::COL, - cols: cols, - backend: module.backend, - } - } - } - - pub fn from_bytes_borrow(module: &Module, size: usize, cols: usize, bytes: &mut [u8]) -> VecZnxDft { - #[cfg(debug_assertions)] - { - assert_eq!(bytes.len(), module.bytes_of_vec_znx_dft(size, cols)); - assert_alignement(bytes.as_ptr()); - } - VecZnxDft { - data: Vec::new(), - ptr: bytes.as_mut_ptr(), - n: module.n(), - size: size, - layout: LAYOUT::COL, - cols: cols, - backend: module.backend, - } - } - - /// Cast a [VecZnxDft] into a [VecZnxBig]. - /// The returned [VecZnxBig] shares the backing array - /// with the original [VecZnxDft]. - pub fn as_vec_znx_big(&mut self) -> VecZnxBig { - VecZnxBig { - data: Vec::new(), - ptr: self.ptr, - n: self.n, - layout: LAYOUT::COL, - size: self.size, - cols: self.cols, - backend: self.backend, - } - } - - pub fn backend(&self) -> BACKEND { - self.backend - } - - /// Returns a non-mutable reference of `T` of the entire contiguous array of the [VecZnxDft]. - /// When using [`crate::FFT64`] as backend, `T` should be [f64]. - /// When using [`crate::NTT120`] as backend, `T` should be [i64]. - /// The length of the returned array is cols * n. - pub fn raw(&self, module: &Module) -> &[T] { - let ptr: *const T = self.ptr as *const T; - let len: usize = (self.cols() * module.n() * 8) / std::mem::size_of::(); - unsafe { &std::slice::from_raw_parts(ptr, len) } - } - - pub fn at(&self, module: &Module, col_i: usize) -> &[T] { - &self.raw::(module)[col_i * module.n()..(col_i + 1) * module.n()] - } - - /// Returns a mutable reference of `T` of the entire contiguous array of the [VecZnxDft]. - /// When using [`crate::FFT64`] as backend, `T` should be [f64]. - /// When using [`crate::NTT120`] as backend, `T` should be [i64]. - /// The length of the returned array is cols * n. 
- pub fn raw_mut(&self, module: &Module) -> &mut [T] { - let ptr: *mut T = self.ptr as *mut T; - let len: usize = (self.cols() * module.n() * 8) / std::mem::size_of::(); - unsafe { std::slice::from_raw_parts_mut(ptr, len) } - } - - pub fn at_mut(&self, module: &Module, col_i: usize) -> &mut [T] { - &mut self.raw_mut::(module)[col_i * module.n()..(col_i + 1) * module.n()] +impl VecZnxDft { + pub fn into_big(self) -> VecZnxBig { + VecZnxBig::::from_data(self.data, self.n, self.cols, self.size) } } -impl Infos for VecZnxDft { - /// Returns the base 2 logarithm of the [VecZnx] degree. - fn log_n(&self) -> usize { - (usize::BITS - (self.n - 1).leading_zeros()) as _ +impl ZnxInfos for VecZnxDft { + fn cols(&self) -> usize { + self.cols + } + + fn rows(&self) -> usize { + 1 } - /// Returns the [VecZnx] degree. fn n(&self) -> usize { self.n } @@ -116,254 +37,206 @@ impl Infos for VecZnxDft { fn size(&self) -> usize { self.size } +} - fn layout(&self) -> LAYOUT { - self.layout - } - - /// Returns the number of cols of the [VecZnx]. - fn cols(&self) -> usize { - self.cols - } - - /// Returns the number of rows of the [VecZnx]. - fn rows(&self) -> usize { - 1 +impl ZnxSliceSize for VecZnxDft { + fn sl(&self) -> usize { + self.n() * self.cols() } } -pub trait VecZnxDftOps { - /// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space. - fn new_vec_znx_dft(&self, size: usize, cols: usize) -> VecZnxDft; - - /// Returns a new [VecZnxDft] with the provided bytes array as backing array. - /// - /// Behavior: takes ownership of the backing array. - /// - /// # Arguments - /// - /// * `cols`: the number of cols of the [VecZnxDft]. - /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft]. - /// - /// # Panics - /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. - fn new_vec_znx_dft_from_bytes(&self, size: usize, cols: usize, bytes: &mut [u8]) -> VecZnxDft; - - /// Returns a new [VecZnxDft] with the provided bytes array as backing array. - /// - /// Behavior: the backing array is only borrowed. - /// - /// # Arguments - /// - /// * `cols`: the number of cols of the [VecZnxDft]. - /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft]. - /// - /// # Panics - /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. - fn new_vec_znx_dft_from_bytes_borrow(&self, size: usize, cols: usize, bytes: &mut [u8]) -> VecZnxDft; - - /// Returns a new [VecZnxDft] with the provided bytes array as backing array. - /// - /// # Arguments - /// - /// * `cols`: the number of cols of the [VecZnxDft]. - /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft]. - /// - /// # Panics - /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. - fn bytes_of_vec_znx_dft(&self, size: usize, cols: usize) -> usize; - - /// Returns the minimum number of bytes necessary to allocate - /// a new [VecZnxDft] through [VecZnxDft::from_bytes]. - fn vec_znx_idft_tmp_bytes(&self) -> usize; - - /// b <- IDFT(a), uses a as scratch space. 
- fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft); - - fn vec_znx_idft(&self, b: &mut VecZnxBig, a: &VecZnxDft, tmp_bytes: &mut [u8]); - - fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx); - - fn vec_znx_dft_automorphism(&self, k: i64, b: &mut VecZnxDft, a: &VecZnxDft); - - fn vec_znx_dft_automorphism_inplace(&self, k: i64, a: &mut VecZnxDft, tmp_bytes: &mut [u8]); - - fn vec_znx_dft_automorphism_tmp_bytes(&self) -> usize; +impl DataView for VecZnxDft { + type D = D; + fn data(&self) -> &Self::D { + &self.data + } } -impl VecZnxDftOps for Module { - fn new_vec_znx_dft(&self, size: usize, cols: usize) -> VecZnxDft { - let mut data: Vec = alloc_aligned::(self.bytes_of_vec_znx_dft(size, cols)); - let ptr: *mut u8 = data.as_mut_ptr(); - VecZnxDft { - data: data, - ptr: ptr, - n: self.n(), - size: size, - layout: LAYOUT::COL, - cols: cols, - backend: self.backend(), +impl DataViewMut for VecZnxDft { + fn data_mut(&mut self) -> &mut Self::D { + &mut self.data + } +} + +impl> ZnxView for VecZnxDft { + type Scalar = f64; +} + +pub(crate) fn bytes_of_vec_znx_dft(module: &Module, cols: usize, size: usize) -> usize { + unsafe { vec_znx_dft::bytes_of_vec_znx_dft(module.ptr, size as u64) as usize * cols } +} + +impl>, B: Backend> VecZnxDft { + pub(crate) fn new(module: &Module, cols: usize, size: usize) -> Self { + let data = alloc_aligned::(bytes_of_vec_znx_dft(module, cols, size)); + Self { + data: data.into(), + n: module.n(), + cols, + size, + _phantom: PhantomData, } } - fn new_vec_znx_dft_from_bytes(&self, size: usize, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxDft { - VecZnxDft::from_bytes(self, size, cols, tmp_bytes) - } - - fn new_vec_znx_dft_from_bytes_borrow(&self, size: usize, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxDft { - VecZnxDft::from_bytes_borrow(self, size, cols, tmp_bytes) - } - - fn bytes_of_vec_znx_dft(&self, size: usize, cols: usize) -> usize { - unsafe { bytes_of_vec_znx_dft(self.ptr, cols as u64) as usize * size } - } - - fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft) { - unsafe { - vec_znx_dft::vec_znx_idft_tmp_a( - self.ptr, - b.ptr as *mut vec_znx_big_t, - b.cols() as u64, - a.ptr as *mut vec_znx_dft_t, - a.cols() as u64, - ) + pub(crate) fn new_from_bytes(module: &Module, cols: usize, size: usize, bytes: impl Into>) -> Self { + let data: Vec = bytes.into(); + assert!(data.len() == bytes_of_vec_znx_dft(module, cols, size)); + Self { + data: data.into(), + n: module.n(), + cols, + size, + _phantom: PhantomData, } } +} - fn vec_znx_idft_tmp_bytes(&self) -> usize { - unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(self.ptr) as usize } - } - - /// b <- DFT(a) - /// - /// # Panics - /// If b.cols < a_cols - fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx) { - unsafe { - vec_znx_dft::vec_znx_dft( - self.ptr, - b.ptr as *mut vec_znx_dft_t, - b.cols() as u64, - a.as_ptr(), - a.cols() as u64, - a.n() as u64, - ) - } - } - - // b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes]. - fn vec_znx_idft(&self, b: &mut VecZnxBig, a: &VecZnxDft, tmp_bytes: &mut [u8]) { +impl VecZnxDft +where + VecZnxDft: VecZnxDftToMut + ZnxInfos, +{ + /// Extracts the a_col-th column of 'a' and stores it on the self_col-th column [Self]. 
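+    ///
+    /// Hedged sketch (assumes `b` and `a` are two [VecZnxDft] of the same degree;
+    /// limbs of the receiver past `a`'s size are zeroed, as the body below shows):
+    ///
+    /// ```ignore
+    /// // copy column 1 of `a` into column 0 of `b`
+    /// b.extract_column(0, &a, 1);
+    /// ```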
+ pub fn extract_column(&mut self, self_col: usize, a: &VecZnxDft, a_col: usize) + where + VecZnxDft: VecZnxDftToRef + ZnxInfos, + { #[cfg(debug_assertions)] { - assert!( - tmp_bytes.len() >= Self::vec_znx_idft_tmp_bytes(self), - "invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_idft_tmp_bytes()={}", - tmp_bytes.len(), - Self::vec_znx_idft_tmp_bytes(self) - ); - assert_alignement(tmp_bytes.as_ptr()) + assert!(self_col < self.cols()); + assert!(a_col < a.cols()); } - unsafe { - vec_znx_dft::vec_znx_idft( - self.ptr, - b.ptr as *mut vec_znx_big_t, - b.cols() as u64, - a.ptr as *const vec_znx_dft_t, - a.cols() as u64, - tmp_bytes.as_mut_ptr(), - ) - } - } - fn vec_znx_dft_automorphism(&self, k: i64, b: &mut VecZnxDft, a: &VecZnxDft) { - unsafe { - vec_znx_dft::vec_znx_dft_automorphism( - self.ptr, - k, - b.ptr as *mut vec_znx_dft_t, - b.cols() as u64, - a.ptr as *const vec_znx_dft_t, - a.cols() as u64, - [0u8; 0].as_mut_ptr(), - ); - } - } + let min_size: usize = self.size.min(a.size()); + let max_size: usize = self.size; - fn vec_znx_dft_automorphism_inplace(&self, k: i64, a: &mut VecZnxDft, tmp_bytes: &mut [u8]) { - #[cfg(debug_assertions)] - { - assert!( - tmp_bytes.len() >= Self::vec_znx_dft_automorphism_tmp_bytes(self), - "invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_dft_automorphism_tmp_bytes()={}", - tmp_bytes.len(), - Self::vec_znx_dft_automorphism_tmp_bytes(self) - ); - assert_alignement(tmp_bytes.as_ptr()) - } - unsafe { - vec_znx_dft::vec_znx_dft_automorphism( - self.ptr, - k, - a.ptr as *mut vec_znx_dft_t, - a.cols() as u64, - a.ptr as *const vec_znx_dft_t, - a.cols() as u64, - tmp_bytes.as_mut_ptr(), - ); - } - } + let mut self_mut: VecZnxDft<&mut [u8], FFT64> = self.to_mut(); + let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); - fn vec_znx_dft_automorphism_tmp_bytes(&self) -> usize { - unsafe { - std::cmp::max( - vec_znx_dft::vec_znx_dft_automorphism_tmp_bytes(self.ptr) as usize, - DEFAULTALIGN, - ) - } - } -} - -#[cfg(test)] -mod tests { - use crate::{BACKEND, Module, Sampling, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, alloc_aligned}; - use itertools::izip; - use sampling::source::{Source, new_seed}; - - #[test] - fn test_automorphism_dft() { - let module: Module = Module::new(128, BACKEND::FFT64); - - let cols: usize = 2; - let log_base2k: usize = 17; - let mut a: VecZnx = module.new_vec_znx(1, cols); - let mut a_dft: VecZnxDft = module.new_vec_znx_dft(1, cols); - let mut b_dft: VecZnxDft = module.new_vec_znx_dft(1, cols); - - let mut source: Source = Source::new(new_seed()); - module.fill_uniform(log_base2k, &mut a, cols, &mut source); - - let mut tmp_bytes: Vec = alloc_aligned(module.vec_znx_dft_automorphism_tmp_bytes()); - - let p: i64 = -5; - - // a_dft <- DFT(a) - module.vec_znx_dft(&mut a_dft, &a); - - // a_dft <- AUTO(a_dft) - module.vec_znx_dft_automorphism_inplace(p, &mut a_dft, &mut tmp_bytes); - - // a <- AUTO(a) - module.vec_znx_automorphism_inplace(p, &mut a); - - // b_dft <- DFT(AUTO(a)) - module.vec_znx_dft(&mut b_dft, &a); - - let a_f64: &[f64] = a_dft.raw(&module); - let b_f64: &[f64] = b_dft.raw(&module); - izip!(a_f64.iter(), b_f64.iter()).for_each(|(ai, bi)| { - assert!((ai - bi).abs() <= 1e-9, "{:+e} > 1e-9", (ai - bi).abs()); + (0..min_size).for_each(|i: usize| { + self_mut + .at_mut(self_col, i) + .copy_from_slice(a_ref.at(a_col, i)); }); - module.free() + (min_size..max_size).for_each(|i| { + self_mut.zero_at(self_col, i); + }); + } +} + +pub type VecZnxDftOwned = VecZnxDft, B>; + +impl VecZnxDft { + pub(crate) fn from_data(data: D, n: 
usize, cols: usize, size: usize) -> Self { + Self { + data, + n, + cols, + size, + _phantom: PhantomData, + } + } +} + +pub trait VecZnxDftToRef { + fn to_ref(&self) -> VecZnxDft<&[u8], B>; +} + +pub trait VecZnxDftToMut { + fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B>; +} + +impl VecZnxDftToMut for VecZnxDft, B> { + fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { + VecZnxDft { + data: self.data.as_mut_slice(), + n: self.n, + cols: self.cols, + size: self.size, + _phantom: PhantomData, + } + } +} + +impl VecZnxDftToRef for VecZnxDft, B> { + fn to_ref(&self) -> VecZnxDft<&[u8], B> { + VecZnxDft { + data: self.data.as_slice(), + n: self.n, + cols: self.cols, + size: self.size, + _phantom: PhantomData, + } + } +} + +impl VecZnxDftToMut for VecZnxDft<&mut [u8], B> { + fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { + VecZnxDft { + data: self.data, + n: self.n, + cols: self.cols, + size: self.size, + _phantom: PhantomData, + } + } +} + +impl VecZnxDftToRef for VecZnxDft<&mut [u8], B> { + fn to_ref(&self) -> VecZnxDft<&[u8], B> { + VecZnxDft { + data: self.data, + n: self.n, + cols: self.cols, + size: self.size, + _phantom: PhantomData, + } + } +} + +impl VecZnxDftToRef for VecZnxDft<&[u8], B> { + fn to_ref(&self) -> VecZnxDft<&[u8], B> { + VecZnxDft { + data: self.data, + n: self.n, + cols: self.cols, + size: self.size, + _phantom: PhantomData, + } + } +} + +impl> fmt::Display for VecZnxDft { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!( + f, + "VecZnxDft(n={}, cols={}, size={})", + self.n, self.cols, self.size + )?; + + for col in 0..self.cols { + writeln!(f, "Column {}:", col)?; + for size in 0..self.size { + let coeffs = self.at(col, size); + write!(f, " Size {}: [", size)?; + + let max_show = 100; + let show_count = coeffs.len().min(max_show); + + for (i, &coeff) in coeffs.iter().take(show_count).enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", coeff)?; + } + + if coeffs.len() > max_show { + write!(f, ", ... ({} more)", coeffs.len() - max_show)?; + } + + writeln!(f, "]")?; + } + } + Ok(()) } } diff --git a/base2k/src/vec_znx_dft_ops.rs b/base2k/src/vec_znx_dft_ops.rs new file mode 100644 index 0000000..e4d6c33 --- /dev/null +++ b/base2k/src/vec_znx_dft_ops.rs @@ -0,0 +1,287 @@ +use crate::ffi::{vec_znx_big, vec_znx_dft}; +use crate::vec_znx_dft::bytes_of_vec_znx_dft; +use crate::znx_base::ZnxInfos; +use crate::{ + Backend, Scratch, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, + ZnxSliceSize, +}; +use crate::{FFT64, Module, ZnxView, ZnxViewMut, ZnxZero}; +use std::cmp::min; + +pub trait VecZnxDftAlloc { + /// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space. + fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDftOwned; + + /// Returns a new [VecZnxDft] with the provided bytes array as backing array. + /// + /// Behavior: takes ownership of the backing array. + /// + /// # Arguments + /// + /// * `cols`: the number of cols of the [VecZnxDft]. + /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft]. + /// + /// # Panics + /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. + fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxDftOwned; + + /// Returns a new [VecZnxDft] with the provided bytes array as backing array. + /// + /// # Arguments + /// + /// * `cols`: the number of cols of the [VecZnxDft]. + /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft]. 
+ /// + /// # Panics + /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. + fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize; +} + +pub trait VecZnxDftOps { + /// Returns the minimum number of bytes necessary to allocate + /// a new [VecZnxDft] through [VecZnxDft::from_bytes]. + fn vec_znx_idft_tmp_bytes(&self) -> usize; + + fn vec_znx_dft_add(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + D: VecZnxDftToRef; + + fn vec_znx_dft_add_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef; + + fn vec_znx_dft_copy(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef; + + /// b <- IDFT(a), uses a as scratch space. + fn vec_znx_idft_tmp_a(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxDftToMut; + + /// Consumes a to return IDFT(a) in big coeff space. + fn vec_znx_idft_consume(&self, a: VecZnxDft) -> VecZnxBig + where + VecZnxDft: VecZnxDftToMut; + + fn vec_znx_idft(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) + where + R: VecZnxBigToMut, + A: VecZnxDftToRef; + + fn vec_znx_dft(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxToRef; +} + +impl VecZnxDftAlloc for Module { + fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDftOwned { + VecZnxDftOwned::new(&self, cols, size) + } + + fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxDftOwned { + VecZnxDftOwned::new_from_bytes(self, cols, size, bytes) + } + + fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize { + bytes_of_vec_znx_dft(self, cols, size) + } +} + +impl VecZnxDftOps for Module { + fn vec_znx_dft_add(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + D: VecZnxDftToRef, + { + let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); + let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); + let b_ref: VecZnxDft<&[u8], FFT64> = b.to_ref(); + + let min_size: usize = res_mut.size().min(a_ref.size()).min(b_ref.size()); + + unsafe { + (0..min_size).for_each(|j| { + vec_znx_dft::vec_dft_add( + self.ptr, + res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, + 1, + a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + b_ref.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + ); + }); + } + (min_size..res_mut.size()).for_each(|j| { + res_mut.zero_at(res_col, j); + }) + } + + fn vec_znx_dft_add_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + { + let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); + let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref(); + + let min_size: usize = res_mut.size().min(a_ref.size()); + + unsafe { + (0..min_size).for_each(|j| { + vec_znx_dft::vec_dft_add( + self.ptr, + res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, + 1, + res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t, + 1, + ); + }); + } + } + + fn vec_znx_dft_copy(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxDftToRef, + { + let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); + let a_ref: VecZnxDft<&[u8], FFT64> = 
a.to_ref(); + + let min_size: usize = min(res_mut.size(), a_ref.size()); + + (0..min_size).for_each(|j| { + res_mut + .at_mut(res_col, j) + .copy_from_slice(a_ref.at(a_col, j)); + }); + (min_size..res_mut.size()).for_each(|j| { + res_mut.zero_at(res_col, j); + }) + } + + fn vec_znx_idft_tmp_a(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize) + where + R: VecZnxBigToMut, + A: VecZnxDftToMut, + { + let mut res_mut: VecZnxBig<&mut [u8], FFT64> = res.to_mut(); + let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut(); + + let min_size: usize = min(res_mut.size(), a_mut.size()); + + unsafe { + (0..min_size).for_each(|j| { + vec_znx_dft::vec_znx_idft_tmp_a( + self.ptr, + res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t, + 1 as u64, + a_mut.at_mut_ptr(a_col, j) as *mut vec_znx_dft::vec_znx_dft_t, + 1 as u64, + ) + }); + (min_size..res_mut.size()).for_each(|j| { + res_mut.zero_at(res_col, j); + }) + } + } + + fn vec_znx_idft_consume(&self, mut a: VecZnxDft) -> VecZnxBig + where + VecZnxDft: VecZnxDftToMut, + { + let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut(); + + unsafe { + // Rev col and rows because ZnxDft.sl() >= ZnxBig.sl() + (0..a_mut.size()).for_each(|j| { + (0..a_mut.cols()).for_each(|i| { + vec_znx_dft::vec_znx_idft_tmp_a( + self.ptr, + a_mut.at_mut_ptr(i, j) as *mut vec_znx_big::vec_znx_big_t, + 1 as u64, + a_mut.at_mut_ptr(i, j) as *mut vec_znx_dft::vec_znx_dft_t, + 1 as u64, + ) + }); + }); + } + + a.into_big() + } + + fn vec_znx_idft_tmp_bytes(&self) -> usize { + unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(self.ptr) as usize } + } + + /// b <- DFT(a) + /// + /// # Panics + /// If b.cols < a_col + fn vec_znx_dft(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxDftToMut, + A: VecZnxToRef, + { + let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut(); + let a_ref: crate::VecZnx<&[u8]> = a.to_ref(); + + let min_size: usize = min(res_mut.size(), a_ref.size()); + + unsafe { + (0..min_size).for_each(|j| { + vec_znx_dft::vec_znx_dft( + self.ptr, + res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t, + 1 as u64, + a_ref.at_ptr(a_col, j), + 1 as u64, + a_ref.sl() as u64, + ) + }); + (min_size..res_mut.size()).for_each(|j| { + res_mut.zero_at(res_col, j); + }); + } + } + + // b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes]. 
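+    //
+    // Hedged usage sketch (assuming `module: Module<FFT64>`, a [VecZnxBig] receiver
+    // `res_big`, a [VecZnxDft] input `a_dft`, and a scratch sized with
+    // `module.vec_znx_idft_tmp_bytes()`; the names are illustrative):
+    //
+    //     module.vec_znx_idft(&mut res_big, 0, &a_dft, 0, scratch.borrow());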
+    fn vec_znx_idft(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
+    where
+        R: VecZnxBigToMut,
+        A: VecZnxDftToRef,
+    {
+        let mut res_mut: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
+        let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
+
+        let (tmp_bytes, _) = scratch.tmp_slice(self.vec_znx_idft_tmp_bytes());
+
+        let min_size: usize = min(res_mut.size(), a_ref.size());
+
+        unsafe {
+            (0..min_size).for_each(|j| {
+                vec_znx_dft::vec_znx_idft(
+                    self.ptr,
+                    res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t,
+                    1 as u64,
+                    a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
+                    1 as u64,
+                    tmp_bytes.as_mut_ptr(),
+                )
+            });
+            (min_size..res_mut.size()).for_each(|j| {
+                res_mut.zero_at(res_col, j);
+            });
+        }
+    }
+}
diff --git a/base2k/src/vec_znx_ops.rs b/base2k/src/vec_znx_ops.rs
new file mode 100644
index 0000000..b97e6b7
--- /dev/null
+++ b/base2k/src/vec_znx_ops.rs
@@ -0,0 +1,694 @@
+use crate::ffi::vec_znx;
+use crate::{
+    Backend, Module, ScalarZnxToRef, Scratch, VecZnx, VecZnxOwned, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxSliceSize, ZnxView,
+    ZnxViewMut, ZnxZero,
+};
+use itertools::izip;
+use std::cmp::min;
+
+pub trait VecZnxAlloc {
+    /// Allocates a new [VecZnx].
+    ///
+    /// # Arguments
+    ///
+    /// * `cols`: the number of polynomials.
+    /// * `size`: the number of small polynomials per column.
+    fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnxOwned;
+
+    /// Instantiates a new [VecZnx] from a buffer of bytes.
+    /// The returned [VecZnx] takes ownership of the buffer.
+    ///
+    /// # Arguments
+    ///
+    /// * `cols`: the number of polynomials.
+    /// * `size`: the number of small polynomials per column.
+    ///
+    /// # Panics
+    /// Requires the length of the byte buffer to be equal to [VecZnxAlloc::bytes_of_vec_znx].
+    fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxOwned;
+
+    /// Returns the number of bytes necessary to allocate
+    /// a new [VecZnx] through [VecZnxAlloc::new_vec_znx_from_bytes].
+    fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize;
+}
+
+pub trait VecZnxOps {
+    /// Normalizes the selected column of `a` and stores the result into the selected column of `res`.
+    fn vec_znx_normalize(&self, log_base2k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef;
+
+    /// Normalizes the selected column of `a`.
+    fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch)
+    where
+        A: VecZnxToMut;
+
+    /// Adds the selected column of `a` to the selected column of `b` and writes the result to the selected column of `res`.
+    fn vec_znx_add(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef,
+        B: VecZnxToRef;
+
+    /// Adds the selected column of `a` to the selected column of `res` and writes the result to the selected column of `res`.
+    fn vec_znx_add_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef;
+
+    /// Adds the selected column of `a` to the selected column and limb of `res`.
+    fn vec_znx_add_scalar_inplace(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxToMut,
+        A: ScalarZnxToRef;
+
+    /// Subtracts the selected column of `b` from the selected column of `a` and writes the result to the selected column of `res`.
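+    ///
+    /// A minimal sketch (hedged: `module` is an FFT64 [Module]; `res`, `a` and `b`
+    /// are [VecZnx] of matching degree):
+    ///
+    /// ```ignore
+    /// // res[0] <- a[0] - b[1]
+    /// module.vec_znx_sub(&mut res, 0, &a, 0, &b, 1);
+    /// ```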
+    fn vec_znx_sub(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef,
+        B: VecZnxToRef;
+
+    /// Subtracts the selected column of `a` from the selected column of `res` in place:
+    ///
+    /// res[res_col] -= a[a_col]
+    fn vec_znx_sub_ab_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef;
+
+    /// Subtracts the selected column of `res` from the selected column of `a` and stores the result in `res`:
+    ///
+    /// res[res_col] = a[a_col] - res[res_col]
+    fn vec_znx_sub_ba_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef;
+
+    /// Subtracts the selected column of `a` from the selected column and limb of `res`.
+    fn vec_znx_sub_scalar_inplace(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxToMut,
+        A: ScalarZnxToRef;
+
+    /// Negates the selected column of `a` and stores the result in the `res_col`-th column of `res`.
+    fn vec_znx_negate(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef;
+
+    /// Negates the selected column of `a`.
+    fn vec_znx_negate_inplace(&self, a: &mut A, a_col: usize)
+    where
+        A: VecZnxToMut;
+
+    /// Multiplies the selected column of `a` by X^k and stores the result in the `res_col`-th column of `res`.
+    fn vec_znx_rotate(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef;
+
+    /// Multiplies the selected column of `a` by X^k.
+    fn vec_znx_rotate_inplace(&self, k: i64, a: &mut A, a_col: usize)
+    where
+        A: VecZnxToMut;
+
+    /// Applies the automorphism X^i -> X^ik on the selected column of `a` and stores the result in the `res_col`-th column of `res`.
+    fn vec_znx_automorphism(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef;
+
+    /// Applies the automorphism X^i -> X^ik on the selected column of `a`.
+    fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize)
+    where
+        A: VecZnxToMut;
+
+    /// Splits the selected column of `a` into subrings and copies them into the selected column of each [VecZnx] in `res`.
+    ///
+    /// # Panics
+    ///
+    /// This method requires that all [VecZnx] in `res` have the same ring degree
+    /// and that res[0].n() * res.len() <= a.n()
+    fn vec_znx_split(&self, res: &mut Vec, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef;
+
+    /// Merges the subrings of the selected column of each [VecZnx] in `a` into the selected column of `res`.
+    ///
+    /// # Panics
+    ///
+    /// This method requires that all [VecZnx] in `a` have the same ring degree
+    /// and that a[0].n() * a.len() <= res.n()
+    fn vec_znx_merge(&self, res: &mut R, res_col: usize, a: Vec, a_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef;
+
+    fn switch_degree(&self, r: &mut R, col_b: usize, a: &A, col_a: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef;
+}
+
+pub trait VecZnxScratch {
+    /// Returns the minimum number of bytes necessary for normalization.
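+    ///
+    /// Sketch of the intended pattern (assuming the `ScratchOwned` helper used by
+    /// this crate's examples; the names are illustrative):
+    ///
+    /// ```ignore
+    /// let mut scratch = ScratchOwned::new(module.vec_znx_normalize_tmp_bytes());
+    /// module.vec_znx_normalize_inplace(log_base2k, &mut a, 0, scratch.borrow());
+    /// ```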
+ fn vec_znx_normalize_tmp_bytes(&self) -> usize; +} + +impl VecZnxAlloc for Module { + fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnxOwned { + VecZnxOwned::new::(self.n(), cols, size) + } + + fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize { + VecZnxOwned::bytes_of::(self.n(), cols, size) + } + + fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxOwned { + VecZnxOwned::new_from_bytes::(self.n(), cols, size, bytes) + } +} + +impl VecZnxOps for Module { + fn vec_znx_normalize(&self, log_base2k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } + + let (tmp_bytes, _) = scratch.tmp_slice(self.vec_znx_normalize_tmp_bytes()); + + unsafe { + vec_znx::vec_znx_normalize_base2k( + self.ptr, + log_base2k as u64, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + tmp_bytes.as_mut_ptr(), + ); + } + } + + fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) + where + A: VecZnxToMut, + { + let mut a: VecZnx<&mut [u8]> = a.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + } + + let (tmp_bytes, _) = scratch.tmp_slice(self.vec_znx_normalize_tmp_bytes()); + + unsafe { + vec_znx::vec_znx_normalize_base2k( + self.ptr, + log_base2k as u64, + a.at_mut_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + tmp_bytes.as_mut_ptr(), + ); + } + } + + fn vec_znx_add(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + B: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let b: VecZnx<&[u8]> = b.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(b.n(), self.n()); + assert_eq!(res.n(), self.n()); + assert_ne!(a.as_ptr(), b.as_ptr()); + } + unsafe { + vec_znx::vec_znx_add( + self.ptr, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + b.at_ptr(b_col, 0), + b.size() as u64, + b.sl() as u64, + ) + } + } + + fn vec_znx_add_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } + unsafe { + vec_znx::vec_znx_add( + self.ptr, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + res.at_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + ) + } + } + + fn vec_znx_add_scalar_inplace(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: ScalarZnxToRef, + { + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + let a: crate::ScalarZnx<&[u8]> = a.to_ref(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } + + unsafe { + vec_znx::vec_znx_add( + self.ptr, + res.at_mut_ptr(res_col, res_limb), + 1 as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as 
u64, + a.sl() as u64, + res.at_ptr(res_col, res_limb), + 1 as u64, + res.sl() as u64, + ) + } + } + + fn vec_znx_sub(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + B: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let b: VecZnx<&[u8]> = b.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(b.n(), self.n()); + assert_eq!(res.n(), self.n()); + assert_ne!(a.as_ptr(), b.as_ptr()); + } + unsafe { + vec_znx::vec_znx_sub( + self.ptr, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + b.at_ptr(b_col, 0), + b.size() as u64, + b.sl() as u64, + ) + } + } + + fn vec_znx_sub_scalar_inplace(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: ScalarZnxToRef, + { + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + let a: crate::ScalarZnx<&[u8]> = a.to_ref(); + + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } + + unsafe { + vec_znx::vec_znx_sub( + self.ptr, + res.at_mut_ptr(res_col, res_limb), + 1 as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + res.at_ptr(res_col, res_limb), + 1 as u64, + res.sl() as u64, + ) + } + } + + fn vec_znx_sub_ab_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } + unsafe { + vec_znx::vec_znx_sub( + self.ptr, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + res.at_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) + } + } + + fn vec_znx_sub_ba_inplace(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } + unsafe { + vec_znx::vec_znx_sub( + self.ptr, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + res.at_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + ) + } + } + + fn vec_znx_negate(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } + unsafe { + vec_znx::vec_znx_negate( + self.ptr, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) + } + } + + fn vec_znx_negate_inplace(&self, a: &mut A, a_col: usize) + where + A: VecZnxToMut, + { + let mut a: VecZnx<&mut [u8]> = a.to_mut(); + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + } + unsafe { + vec_znx::vec_znx_negate( + self.ptr, + a.at_mut_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) + } + } + + fn vec_znx_rotate(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + 
A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } + unsafe { + vec_znx::vec_znx_rotate( + self.ptr, + k, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) + } + } + + fn vec_znx_rotate_inplace(&self, k: i64, a: &mut A, a_col: usize) + where + A: VecZnxToMut, + { + let mut a: VecZnx<&mut [u8]> = a.to_mut(); + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + } + unsafe { + vec_znx::vec_znx_rotate( + self.ptr, + k, + a.at_mut_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) + } + } + + fn vec_znx_automorphism(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert_eq!(res.n(), self.n()); + } + unsafe { + vec_znx::vec_znx_automorphism( + self.ptr, + k, + res.at_mut_ptr(res_col, 0), + res.size() as u64, + res.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) + } + } + + fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) + where + A: VecZnxToMut, + { + let mut a: VecZnx<&mut [u8]> = a.to_mut(); + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), self.n()); + assert!( + k & 1 != 0, + "invalid galois element: must be odd but is {}", + k + ); + } + unsafe { + vec_znx::vec_znx_automorphism( + self.ptr, + k, + a.at_mut_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + ) + } + } + + fn vec_znx_split(&self, res: &mut Vec, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + + let (n_in, n_out) = (a.n(), res[0].to_mut().n()); + + let (mut buf, _) = scratch.tmp_vec_znx(self, 1, a.size()); + + debug_assert!( + n_out < n_in, + "invalid a: output ring degree should be smaller" + ); + res[1..].iter_mut().for_each(|bi| { + debug_assert_eq!( + bi.to_mut().n(), + n_out, + "invalid input a: all VecZnx must have the same degree" + ) + }); + + res.iter_mut().enumerate().for_each(|(i, bi)| { + if i == 0 { + self.switch_degree(bi, res_col, &a, a_col); + self.vec_znx_rotate(-1, &mut buf, 0, &a, a_col); + } else { + self.switch_degree(bi, res_col, &mut buf, a_col); + self.vec_znx_rotate_inplace(-1, &mut buf, a_col); + } + }) + } + + fn vec_znx_merge(&self, res: &mut R, res_col: usize, a: Vec, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + + let (n_in, n_out) = (res.n(), a[0].to_ref().n()); + + debug_assert!( + n_out < n_in, + "invalid a: output ring degree should be smaller" + ); + a[1..].iter().for_each(|ai| { + debug_assert_eq!( + ai.to_ref().n(), + n_out, + "invalid input a: all VecZnx must have the same degree" + ) + }); + + a.iter().enumerate().for_each(|(_, ai)| { + self.switch_degree(&mut res, res_col, ai, a_col); + self.vec_znx_rotate_inplace(-1, &mut res, res_col); + }); + + self.vec_znx_rotate_inplace(a.len() as i64, &mut res, res_col); + } + + fn switch_degree(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + let a: VecZnx<&[u8]> = a.to_ref(); + let mut res: VecZnx<&mut 
[u8]> = res.to_mut(); + + let (n_in, n_out) = (a.n(), res.n()); + let (gap_in, gap_out): (usize, usize); + + if n_in > n_out { + (gap_in, gap_out) = (n_in / n_out, 1) + } else { + (gap_in, gap_out) = (1, n_out / n_in); + res.zero(); + } + + let size: usize = min(a.size(), res.size()); + + (0..size).for_each(|i| { + izip!( + a.at(a_col, i).iter().step_by(gap_in), + res.at_mut(res_col, i).iter_mut().step_by(gap_out) + ) + .for_each(|(x_in, x_out)| *x_out = *x_in); + }); + } +} + +impl VecZnxScratch for Module { + fn vec_znx_normalize_tmp_bytes(&self) -> usize { + unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize } + } +} diff --git a/base2k/src/vmp.rs b/base2k/src/vmp.rs deleted file mode 100644 index 7d6c26f..0000000 --- a/base2k/src/vmp.rs +++ /dev/null @@ -1,694 +0,0 @@ -use crate::ffi::vec_znx_big::vec_znx_big_t; -use crate::ffi::vec_znx_dft::vec_znx_dft_t; -use crate::ffi::vmp::{self, vmp_pmat_t}; -use crate::{BACKEND, Infos, LAYOUT, Module, VecZnx, VecZnxBig, VecZnxDft, alloc_aligned, assert_alignement}; - -/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx], -/// stored as a 3D matrix in the DFT domain in a single contiguous array. -/// Each row of the [VmpPMat] can be seen as a [VecZnxDft]. -/// -/// The backend array of [VmpPMat] is allocate in C, -/// and thus must be manually freed. -/// -/// [VmpPMat] is used to permform a vector matrix product between a [VecZnx] and a [VmpPMat]. -/// See the trait [VmpPMatOps] for additional information. -pub struct VmpPMat { - /// Raw data, is empty if borrowing scratch space. - data: Vec, - /// Pointer to data. Can point to scratch space. - ptr: *mut u8, - /// The number of [VecZnxDft]. - rows: usize, - /// The number of cols in each [VecZnxDft]. - cols: usize, - /// The ring degree of each [VecZnxDft]. - n: usize, - /// The number of stacked [VmpPMat], must be a square. - size: usize, - /// The memory layout of the stacked [VmpPMat]. - layout: LAYOUT, - /// The backend fft or ntt. - backend: BACKEND, -} - -impl Infos for VmpPMat { - /// Returns the ring dimension of the [VmpPMat]. - fn n(&self) -> usize { - self.n - } - - fn log_n(&self) -> usize { - (usize::BITS - (self.n() - 1).leading_zeros()) as _ - } - - fn size(&self) -> usize { - self.size - } - - fn layout(&self) -> LAYOUT { - self.layout - } - - /// Returns the number of rows (i.e. of [VecZnxDft]) of the [VmpPMat] - fn rows(&self) -> usize { - self.rows - } - - /// Returns the number of cols of the [VmpPMat]. - /// The number of cols refers to the number of cols - /// of each [VecZnxDft]. - /// This method is equivalent to [Self::cols]. - fn cols(&self) -> usize { - self.cols - } -} - -impl VmpPMat { - pub fn as_ptr(&self) -> *const u8 { - self.ptr - } - - pub fn as_mut_ptr(&self) -> *mut u8 { - self.ptr - } - - pub fn borrowed(&self) -> bool { - self.data.len() == 0 - } - - /// Returns a non-mutable reference of `T` of the entire contiguous array of the [VmpPMat]. - /// When using [`crate::FFT64`] as backend, `T` should be [f64]. - /// When using [`crate::NTT120`] as backend, `T` should be [i64]. - /// The length of the returned array is rows * cols * n. - pub fn raw(&self) -> &[T] { - let ptr: *const T = self.ptr as *const T; - let len: usize = (self.rows() * self.cols() * self.n() * 8) / std::mem::size_of::(); - unsafe { &std::slice::from_raw_parts(ptr, len) } - } - - /// Returns a non-mutable reference of `T` of the entire contiguous array of the [VmpPMat]. - /// When using [`crate::FFT64`] as backend, `T` should be [f64]. 
- /// When using [`crate::NTT120`] as backend, `T` should be [i64]. - /// The length of the returned array is rows * cols * n. - pub fn raw_mut(&self) -> &mut [T] { - let ptr: *mut T = self.ptr as *mut T; - let len: usize = (self.rows() * self.cols() * self.n() * 8) / std::mem::size_of::(); - unsafe { std::slice::from_raw_parts_mut(ptr, len) } - } - - /// Returns a copy of the backend array at index (i, j) of the [VmpPMat]. - /// When using [`crate::FFT64`] as backend, `T` should be [f64]. - /// When using [`crate::NTT120`] as backend, `T` should be [i64]. - /// - /// # Arguments - /// - /// * `row`: row index (i). - /// * `col`: col index (j). - pub fn at(&self, row: usize, col: usize) -> Vec { - let mut res: Vec = alloc_aligned(self.n); - - if self.n < 8 { - res.copy_from_slice( - &self.raw::()[(row + col * self.rows()) * self.n()..(row + col * self.rows()) * (self.n() + 1)], - ); - } else { - (0..self.n >> 3).for_each(|blk| { - res[blk * 8..(blk + 1) * 8].copy_from_slice(&self.at_block(row, col, blk)[..8]); - }); - } - - res - } - - /// When using [`crate::FFT64`] as backend, `T` should be [f64]. - /// When using [`crate::NTT120`] as backend, `T` should be [i64]. - fn at_block(&self, row: usize, col: usize, blk: usize) -> &[T] { - let nrows: usize = self.rows(); - let ncols: usize = self.cols(); - if col == (ncols - 1) && (ncols & 1 == 1) { - &self.raw::()[blk * nrows * ncols * 8 + col * nrows * 8 + row * 8..] - } else { - &self.raw::()[blk * nrows * ncols * 8 + (col / 2) * (2 * nrows) * 8 + row * 2 * 8 + (col % 2) * 8..] - } - } - - fn backend(&self) -> BACKEND { - self.backend - } -} - -/// This trait implements methods for vector matrix product, -/// that is, multiplying a [VecZnx] with a [VmpPMat]. -pub trait VmpPMatOps { - fn bytes_of_vmp_pmat(&self, size: usize, rows: usize, cols: usize) -> usize; - - /// Allocates a new [VmpPMat] with the given number of rows and columns. - /// - /// # Arguments - /// - /// * `rows`: number of rows (number of [VecZnxDft]). - /// * `cols`: number of cols (number of cols of each [VecZnxDft]). - fn new_vmp_pmat(&self, size: usize, rows: usize, cols: usize) -> VmpPMat; - - /// Returns the number of bytes needed as scratch space for [VmpPMatOps::vmp_prepare_contiguous]. - /// - /// # Arguments - /// - /// * `rows`: number of rows of the [VmpPMat] used in [VmpPMatOps::vmp_prepare_contiguous]. - /// * `cols`: number of cols of the [VmpPMat] used in [VmpPMatOps::vmp_prepare_contiguous]. - fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize) -> usize; - - /// Prepares a [VmpPMat] from a contiguous array of [i64]. - /// The helper struct [Matrix3D] can be used to contruct and populate - /// the appropriate contiguous array. - /// - /// # Arguments - /// - /// * `b`: [VmpPMat] on which the values are encoded. - /// * `a`: the contiguous array of [i64] of the 3D matrix to encode on the [VmpPMat]. - /// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. - fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], buf: &mut [u8]); - - /// Prepares a [VmpPMat] from a vector of [VecZnx]. - /// - /// # Arguments - /// - /// * `b`: [VmpPMat] on which the values are encoded. - /// * `a`: the vector of [VecZnx] to encode on the [VmpPMat]. - /// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. - /// - /// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. 
- fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], buf: &mut [u8]); - - /// Prepares the ith-row of [VmpPMat] from a [VecZnx]. - /// - /// # Arguments - /// - /// * `b`: [VmpPMat] on which the values are encoded. - /// * `a`: the vector of [VecZnx] to encode on the [VmpPMat]. - /// * `row_i`: the index of the row to prepare. - /// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. - /// - /// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. - fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]); - - /// Extracts the ith-row of [VmpPMat] into a [VecZnxBig]. - /// - /// # Arguments - /// - /// * `b`: the [VecZnxBig] to on which to extract the row of the [VmpPMat]. - /// * `a`: [VmpPMat] on which the values are encoded. - /// * `row_i`: the index of the row to extract. - fn vmp_extract_row(&self, b: &mut VecZnxBig, a: &VmpPMat, row_i: usize); - - /// Prepares the ith-row of [VmpPMat] from a [VecZnxDft]. - /// - /// # Arguments - /// - /// * `b`: [VmpPMat] on which the values are encoded. - /// * `a`: the [VecZnxDft] to encode on the [VmpPMat]. - /// * `row_i`: the index of the row to prepare. - /// - /// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. - fn vmp_prepare_row_dft(&self, b: &mut VmpPMat, a: &VecZnxDft, row_i: usize); - - /// Extracts the ith-row of [VmpPMat] into a [VecZnxDft]. - /// - /// # Arguments - /// - /// * `b`: the [VecZnxDft] to on which to extract the row of the [VmpPMat]. - /// * `a`: [VmpPMat] on which the values are encoded. - /// * `row_i`: the index of the row to extract. - fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &VmpPMat, row_i: usize); - - /// Returns the size of the stratch space necessary for [VmpPMatOps::vmp_apply_dft]. - /// - /// # Arguments - /// - /// * `c_cols`: number of cols of the output [VecZnxDft]. - /// * `a_cols`: number of cols of the input [VecZnx]. - /// * `rows`: number of rows of the input [VmpPMat]. - /// * `cols`: number of cols of the input [VmpPMat]. - fn vmp_apply_dft_tmp_bytes(&self, c_cols: usize, a_cols: usize, rows: usize, cols: usize) -> usize; - - /// Applies the vector matrix product [VecZnxDft] x [VmpPMat]. - /// - /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft] - /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) - /// and each vector a [VecZnxDft] (row) of the [VmpPMat]. - /// - /// As such, given an input [VecZnx] of `i` cols and a [VmpPMat] of `i` rows and - /// `j` cols, the output is a [VecZnx] of `j` cols. - /// - /// If there is a mismatch between the dimensions the largest valid ones are used. - /// - /// ```text - /// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p| - /// |h i j| - /// |k l m| - /// ``` - /// where each element is a [VecZnxDft]. - /// - /// # Arguments - /// - /// * `c`: the output of the vector matrix product, as a [VecZnxDft]. - /// * `a`: the left operand [VecZnx] of the vector matrix product. - /// * `b`: the right operand [VmpPMat] of the vector matrix product. - /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_tmp_bytes]. - fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, buf: &mut [u8]); - - /// Applies the vector matrix product [VecZnxDft] x [VmpPMat] and adds on the receiver. 
- /// - /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft] - /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) - /// and each vector a [VecZnxDft] (row) of the [VmpPMat]. - /// - /// As such, given an input [VecZnx] of `i` cols and a [VmpPMat] of `i` rows and - /// `j` cols, the output is a [VecZnx] of `j` cols. - /// - /// If there is a mismatch between the dimensions the largest valid ones are used. - /// - /// ```text - /// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p| - /// |h i j| - /// |k l m| - /// ``` - /// where each element is a [VecZnxDft]. - /// - /// # Arguments - /// - /// * `c`: the operand on which the output of the vector matrix product is added, as a [VecZnxDft]. - /// * `a`: the left operand [VecZnx] of the vector matrix product. - /// * `b`: the right operand [VmpPMat] of the vector matrix product. - /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_tmp_bytes]. - fn vmp_apply_dft_add(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, buf: &mut [u8]); - - /// Returns the size of the stratch space necessary for [VmpPMatOps::vmp_apply_dft_to_dft]. - /// - /// # Arguments - /// - /// * `c_cols`: number of cols of the output [VecZnxDft]. - /// * `a_cols`: number of cols of the input [VecZnxDft]. - /// * `rows`: number of rows of the input [VmpPMat]. - /// * `cols`: number of cols of the input [VmpPMat]. - fn vmp_apply_dft_to_dft_tmp_bytes(&self, c_cols: usize, a_cols: usize, rows: usize, cols: usize) -> usize; - - /// Applies the vector matrix product [VecZnxDft] x [VmpPMat]. - /// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. - /// - /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft] - /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) - /// and each vector a [VecZnxDft] (row) of the [VmpPMat]. - /// - /// As such, given an input [VecZnx] of `i` cols and a [VmpPMat] of `i` rows and - /// `j` cols, the output is a [VecZnx] of `j` cols. - /// - /// If there is a mismatch between the dimensions the largest valid ones are used. - /// - /// ```text - /// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p| - /// |h i j| - /// |k l m| - /// ``` - /// where each element is a [VecZnxDft]. - /// - /// # Arguments - /// - /// * `c`: the output of the vector matrix product, as a [VecZnxDft]. - /// * `a`: the left operand [VecZnxDft] of the vector matrix product. - /// * `b`: the right operand [VmpPMat] of the vector matrix product. - /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. - fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, buf: &mut [u8]); - - /// Applies the vector matrix product [VecZnxDft] x [VmpPMat] and adds on top of the receiver instead of overwritting it. - /// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. - /// - /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft] - /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) - /// and each vector a [VecZnxDft] (row) of the [VmpPMat]. - /// - /// As such, given an input [VecZnx] of `i` cols and a [VmpPMat] of `i` rows and - /// `j` cols, the output is a [VecZnx] of `j` cols. 
- /// - /// If there is a mismatch between the dimensions the largest valid ones are used. - /// - /// ```text - /// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p| - /// |h i j| - /// |k l m| - /// ``` - /// where each element is a [VecZnxDft]. - /// - /// # Arguments - /// - /// * `c`: the operand on which the output of the vector matrix product is added, as a [VecZnxDft]. - /// * `a`: the left operand [VecZnxDft] of the vector matrix product. - /// * `b`: the right operand [VmpPMat] of the vector matrix product. - /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. - fn vmp_apply_dft_to_dft_add(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, buf: &mut [u8]); - - /// Applies the vector matrix product [VecZnxDft] x [VmpPMat] in place. - /// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. - /// - /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft] - /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) - /// and each vector a [VecZnxDft] (row) of the [VmpPMat]. - /// - /// As such, given an input [VecZnx] of `i` cols and a [VmpPMat] of `i` rows and - /// `j` cols, the output is a [VecZnx] of `j` cols. - /// - /// If there is a mismatch between the dimensions the largest valid ones are used. - /// - /// ```text - /// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p| - /// |h i j| - /// |k l m| - /// ``` - /// where each element is a [VecZnxDft]. - /// - /// # Arguments - /// - /// * `b`: the input and output of the vector matrix product, as a [VecZnxDft]. - /// * `a`: the right operand [VmpPMat] of the vector matrix product. - /// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. 
- fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, buf: &mut [u8]); -} - -impl VmpPMatOps for Module { - fn bytes_of_vmp_pmat(&self, size: usize, rows: usize, cols: usize) -> usize { - unsafe { vmp::bytes_of_vmp_pmat(self.ptr, rows as u64, cols as u64) as usize * size } - } - - fn new_vmp_pmat(&self, size: usize, rows: usize, cols: usize) -> VmpPMat { - let mut data: Vec = alloc_aligned::(self.bytes_of_vmp_pmat(size, rows, cols)); - let ptr: *mut u8 = data.as_mut_ptr(); - VmpPMat { - data: data, - ptr: ptr, - n: self.n(), - size: size, - layout: LAYOUT::COL, - cols: cols, - rows: rows, - backend: self.backend(), - } - } - - fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize) -> usize { - unsafe { vmp::vmp_prepare_tmp_bytes(self.ptr, rows as u64, cols as u64) as usize } - } - - fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], tmp_bytes: &mut [u8]) { - debug_assert_eq!(a.len(), b.n * b.rows * b.cols); - debug_assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols())); - #[cfg(debug_assertions)] - { - assert_alignement(tmp_bytes.as_ptr()); - } - unsafe { - vmp::vmp_prepare_contiguous( - self.ptr, - b.as_mut_ptr() as *mut vmp_pmat_t, - a.as_ptr(), - b.rows() as u64, - b.cols() as u64, - tmp_bytes.as_mut_ptr(), - ); - } - } - - fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], tmp_bytes: &mut [u8]) { - let ptrs: Vec<*const i64> = a.iter().map(|v| v.as_ptr()).collect(); - #[cfg(debug_assertions)] - { - debug_assert_eq!(a.len(), b.rows); - a.iter().for_each(|ai| { - debug_assert_eq!(ai.len(), b.n * b.cols); - }); - debug_assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols())); - assert_alignement(tmp_bytes.as_ptr()); - } - unsafe { - vmp::vmp_prepare_dblptr( - self.ptr, - b.as_mut_ptr() as *mut vmp_pmat_t, - ptrs.as_ptr(), - b.rows() as u64, - b.cols() as u64, - tmp_bytes.as_mut_ptr(), - ); - } - } - - fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]) { - #[cfg(debug_assertions)] - { - assert_eq!(a.len(), b.cols() * self.n()); - assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols())); - assert_alignement(tmp_bytes.as_ptr()); - } - unsafe { - vmp::vmp_prepare_row( - self.ptr, - b.as_mut_ptr() as *mut vmp_pmat_t, - a.as_ptr(), - row_i as u64, - b.rows() as u64, - b.cols() as u64, - tmp_bytes.as_mut_ptr(), - ); - } - } - - fn vmp_extract_row(&self, b: &mut VecZnxBig, a: &VmpPMat, row_i: usize) { - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), b.n()); - assert_eq!(a.cols(), b.cols()); - } - unsafe { - vmp::vmp_extract_row( - self.ptr, - b.ptr as *mut vec_znx_big_t, - a.as_ptr() as *const vmp_pmat_t, - row_i as u64, - a.rows() as u64, - a.cols() as u64, - ); - } - } - - fn vmp_prepare_row_dft(&self, b: &mut VmpPMat, a: &VecZnxDft, row_i: usize) { - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), b.n()); - assert_eq!(a.cols(), b.cols()); - } - unsafe { - vmp::vmp_prepare_row_dft( - self.ptr, - b.as_mut_ptr() as *mut vmp_pmat_t, - a.ptr as *const vec_znx_dft_t, - row_i as u64, - b.rows() as u64, - b.cols() as u64, - ); - } - } - - fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &VmpPMat, row_i: usize) { - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), b.n()); - assert_eq!(a.cols(), b.cols()); - } - unsafe { - vmp::vmp_extract_row_dft( - self.ptr, - b.ptr as *mut vec_znx_dft_t, - a.as_ptr() as *const vmp_pmat_t, - row_i as u64, - a.rows() as u64, - a.cols() as u64, - ); - } - } - - fn vmp_apply_dft_tmp_bytes(&self, res_cols: usize, a_cols: usize, 
gct_rows: usize, gct_cols: usize) -> usize { - unsafe { - vmp::vmp_apply_dft_tmp_bytes( - self.ptr, - res_cols as u64, - a_cols as u64, - gct_rows as u64, - gct_cols as u64, - ) as usize - } - } - - fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, tmp_bytes: &mut [u8]) { - debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols())); - #[cfg(debug_assertions)] - { - assert_alignement(tmp_bytes.as_ptr()); - } - unsafe { - vmp::vmp_apply_dft( - self.ptr, - c.ptr as *mut vec_znx_dft_t, - c.cols() as u64, - a.as_ptr(), - a.cols() as u64, - a.n() as u64, - b.as_ptr() as *const vmp_pmat_t, - b.rows() as u64, - b.cols() as u64, - tmp_bytes.as_mut_ptr(), - ) - } - } - - fn vmp_apply_dft_add(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, tmp_bytes: &mut [u8]) { - debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols())); - #[cfg(debug_assertions)] - { - assert_alignement(tmp_bytes.as_ptr()); - } - unsafe { - vmp::vmp_apply_dft_add( - self.ptr, - c.ptr as *mut vec_znx_dft_t, - c.cols() as u64, - a.as_ptr(), - a.cols() as u64, - a.n() as u64, - b.as_ptr() as *const vmp_pmat_t, - b.rows() as u64, - b.cols() as u64, - tmp_bytes.as_mut_ptr(), - ) - } - } - - fn vmp_apply_dft_to_dft_tmp_bytes(&self, res_cols: usize, a_cols: usize, gct_rows: usize, gct_cols: usize) -> usize { - unsafe { - vmp::vmp_apply_dft_to_dft_tmp_bytes( - self.ptr, - res_cols as u64, - a_cols as u64, - gct_rows as u64, - gct_cols as u64, - ) as usize - } - } - - fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, tmp_bytes: &mut [u8]) { - debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols())); - #[cfg(debug_assertions)] - { - assert_alignement(tmp_bytes.as_ptr()); - } - unsafe { - vmp::vmp_apply_dft_to_dft( - self.ptr, - c.ptr as *mut vec_znx_dft_t, - c.cols() as u64, - a.ptr as *const vec_znx_dft_t, - a.cols() as u64, - b.as_ptr() as *const vmp_pmat_t, - b.rows() as u64, - b.cols() as u64, - tmp_bytes.as_mut_ptr(), - ) - } - } - - fn vmp_apply_dft_to_dft_add(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, tmp_bytes: &mut [u8]) { - debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols())); - #[cfg(debug_assertions)] - { - assert_alignement(tmp_bytes.as_ptr()); - } - unsafe { - vmp::vmp_apply_dft_to_dft_add( - self.ptr, - c.ptr as *mut vec_znx_dft_t, - c.cols() as u64, - a.ptr as *const vec_znx_dft_t, - a.cols() as u64, - b.as_ptr() as *const vmp_pmat_t, - b.rows() as u64, - b.cols() as u64, - tmp_bytes.as_mut_ptr(), - ) - } - } - - fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, tmp_bytes: &mut [u8]) { - debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(b.cols(), b.cols(), a.rows(), a.cols())); - #[cfg(debug_assertions)] - { - assert_alignement(tmp_bytes.as_ptr()); - } - unsafe { - vmp::vmp_apply_dft_to_dft( - self.ptr, - b.ptr as *mut vec_znx_dft_t, - b.cols() as u64, - b.ptr as *mut vec_znx_dft_t, - b.cols() as u64, - a.as_ptr() as *const vmp_pmat_t, - a.rows() as u64, - a.cols() as u64, - tmp_bytes.as_mut_ptr(), - ) - } - } -} - -#[cfg(test)] -mod tests { - use crate::{ - Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, alloc_aligned, - }; - use sampling::source::Source; - - #[test] - fn vmp_prepare_row_dft() { - let module: Module = Module::new(32, crate::BACKEND::FFT64); - let 
vpmat_rows: usize = 4; - let vpmat_cols: usize = 5; - let log_base2k: usize = 8; - let mut a: VecZnx = module.new_vec_znx(1, vpmat_cols); - let mut a_dft: VecZnxDft = module.new_vec_znx_dft(1, vpmat_cols); - let mut a_big: VecZnxBig = module.new_vec_znx_big(1, vpmat_cols); - let mut b_big: VecZnxBig = module.new_vec_znx_big(1, vpmat_cols); - let mut b_dft: VecZnxDft = module.new_vec_znx_dft(1, vpmat_cols); - let mut vmpmat_0: VmpPMat = module.new_vmp_pmat(1, vpmat_rows, vpmat_cols); - let mut vmpmat_1: VmpPMat = module.new_vmp_pmat(1, vpmat_rows, vpmat_cols); - - let mut tmp_bytes: Vec = alloc_aligned(module.vmp_prepare_tmp_bytes(vpmat_rows, vpmat_cols)); - - for row_i in 0..vpmat_rows { - let mut source: Source = Source::new([0u8; 32]); - module.fill_uniform(log_base2k, &mut a, vpmat_cols, &mut source); - module.vec_znx_dft(&mut a_dft, &a); - module.vmp_prepare_row(&mut vmpmat_0, &a.raw(), row_i, &mut tmp_bytes); - - // Checks that prepare(vmp_pmat, a) = prepare_dft(vmp_pmat, a_dft) - module.vmp_prepare_row_dft(&mut vmpmat_1, &a_dft, row_i); - assert_eq!(vmpmat_0.raw::(), vmpmat_1.raw::()); - - // Checks that a_dft = extract_dft(prepare(vmp_pmat, a), b_dft) - module.vmp_extract_row_dft(&mut b_dft, &vmpmat_0, row_i); - assert_eq!(a_dft.raw::(&module), b_dft.raw::(&module)); - - // Checks that a_big = extract(prepare_dft(vmp_pmat, a_dft), b_big) - module.vmp_extract_row(&mut b_big, &vmpmat_0, row_i); - module.vec_znx_idft(&mut a_big, &a_dft, &mut tmp_bytes); - assert_eq!(a_big.raw::(&module), b_big.raw::(&module)); - } - - module.free(); - } -} diff --git a/base2k/src/znx_base.rs b/base2k/src/znx_base.rs new file mode 100644 index 0000000..f618446 --- /dev/null +++ b/base2k/src/znx_base.rs @@ -0,0 +1,199 @@ +use itertools::izip; +use rand_distr::num_traits::Zero; + +pub trait ZnxInfos { + /// Returns the ring degree of the polynomials. + fn n(&self) -> usize; + + /// Returns the base two logarithm of the ring dimension of the polynomials. + fn log_n(&self) -> usize { + (usize::BITS - (self.n() - 1).leading_zeros()) as _ + } + + /// Returns the number of rows. + fn rows(&self) -> usize; + + /// Returns the number of polynomials in each row. + fn cols(&self) -> usize; + + /// Returns the number of size per polynomial. + fn size(&self) -> usize; + + /// Returns the total number of small polynomials. + fn poly_count(&self) -> usize { + self.rows() * self.cols() * self.size() + } +} + +pub trait ZnxSliceSize { + /// Returns the slice size, which is the offset between + /// two size of the same column. + fn sl(&self) -> usize; +} + +pub trait DataView { + type D; + fn data(&self) -> &Self::D; +} + +pub trait DataViewMut: DataView { + fn data_mut(&mut self) -> &mut Self::D; +} + +pub trait ZnxView: ZnxInfos + DataView> { + type Scalar: Copy; + + /// Returns a non-mutable pointer to the underlying coefficients array. + fn as_ptr(&self) -> *const Self::Scalar { + self.data().as_ref().as_ptr() as *const Self::Scalar + } + + /// Returns a non-mutable reference to the entire underlying coefficient array. + fn raw(&self) -> &[Self::Scalar] { + unsafe { std::slice::from_raw_parts(self.as_ptr(), self.n() * self.poly_count()) } + } + + /// Returns a non-mutable pointer starting at the j-th small polynomial of the i-th column. 
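+    ///
+    /// Small polynomials are stored limb-major: the (i, j)-th small polynomial
+    /// starts at coefficient offset `n * (j * cols + i)` in the underlying
+    /// array, as computed below.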
+ fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar { + #[cfg(debug_assertions)] + { + assert!(i < self.cols()); + assert!(j < self.size()); + } + let offset: usize = self.n() * (j * self.cols() + i); + unsafe { self.as_ptr().add(offset) } + } + + /// Returns non-mutable reference to the (i, j)-th small polynomial. + fn at(&self, i: usize, j: usize) -> &[Self::Scalar] { + unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n()) } + } +} + +pub trait ZnxViewMut: ZnxView + DataViewMut> { + /// Returns a mutable pointer to the underlying coefficients array. + fn as_mut_ptr(&mut self) -> *mut Self::Scalar { + self.data_mut().as_mut().as_mut_ptr() as *mut Self::Scalar + } + + /// Returns a mutable reference to the entire underlying coefficient array. + fn raw_mut(&mut self) -> &mut [Self::Scalar] { + unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.n() * self.poly_count()) } + } + + /// Returns a mutable pointer starting at the j-th small polynomial of the i-th column. + fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar { + #[cfg(debug_assertions)] + { + assert!(i < self.cols()); + assert!(j < self.size()); + } + let offset: usize = self.n() * (j * self.cols() + i); + unsafe { self.as_mut_ptr().add(offset) } + } + + /// Returns mutable reference to the (i, j)-th small polynomial. + fn at_mut(&mut self, i: usize, j: usize) -> &mut [Self::Scalar] { + unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n()) } + } +} + +//(Jay)Note: Can't provide blanket impl. of ZnxView because Scalar is not known +impl ZnxViewMut for T where T: ZnxView + DataViewMut> {} + +pub trait ZnxZero: ZnxViewMut + ZnxSliceSize +where + Self: Sized, +{ + fn zero(&mut self) { + unsafe { + std::ptr::write_bytes(self.as_mut_ptr(), 0, self.n() * self.poly_count()); + } + } + + fn zero_at(&mut self, i: usize, j: usize) { + unsafe { + std::ptr::write_bytes(self.at_mut_ptr(i, j), 0, self.n()); + } + } +} + +// Blanket implementations +impl ZnxZero for T where T: ZnxViewMut + ZnxSliceSize {} // WARNING should not work for mat_znx_dft but it does + +use std::ops::{Add, AddAssign, Div, Mul, Neg, Shl, Shr, Sub}; + +use crate::Scratch; +pub trait Integer: + Copy + + Default + + PartialEq + + PartialOrd + + Add + + Sub + + Mul + + Div + + Neg + + Shl + + Shr + + AddAssign +{ + const BITS: u32; +} + +impl Integer for i64 { + const BITS: u32 = 64; +} + +impl Integer for i128 { + const BITS: u32 = 128; +} + +//(Jay)Note: `rsh` impl. 
ignores the column
+pub fn rsh<V: ZnxZero>(k: usize, log_base2k: usize, a: &mut V, _a_col: usize, scratch: &mut Scratch)
+where
+    V::Scalar: From<usize> + Integer + Zero,
+{
+    let n: usize = a.n();
+    let cols: usize = a.cols();
+    let size: usize = a.size();
+    let steps: usize = k / log_base2k;
+
+    a.raw_mut().rotate_right(n * steps * cols);
+    (0..cols).for_each(|i| {
+        (0..steps).for_each(|j| {
+            a.zero_at(i, j);
+        })
+    });
+
+    let k_rem: usize = k % log_base2k;
+
+    if k_rem != 0 {
+        let (carry, _) = scratch.tmp_slice::<V::Scalar>(rsh_tmp_bytes::<V::Scalar>(n));
+
+        carry.iter_mut().for_each(|ci| *ci = V::Scalar::zero());
+
+        let log_base2k_t = V::Scalar::from(log_base2k);
+        let shift = V::Scalar::from(V::Scalar::BITS as usize - k_rem);
+        let k_rem_t = V::Scalar::from(k_rem);
+
+        (0..cols).for_each(|i| {
+            (steps..size).for_each(|j| {
+                izip!(carry.iter_mut(), a.at_mut(i, j).iter_mut()).for_each(|(ci, xi)| {
+                    *xi += *ci << log_base2k_t;
+                    *ci = (*xi << shift) >> shift;
+                    *xi = (*xi - *ci) >> k_rem_t;
+                });
+            });
+            carry.iter_mut().for_each(|r| *r = V::Scalar::zero());
+        })
+    }
+}
+
+pub fn rsh_tmp_bytes<Scalar>(n: usize) -> usize {
+    n * std::mem::size_of::<Scalar>()
+}
diff --git a/rlwe/Cargo.toml b/core/Cargo.toml
similarity index 75%
rename from rlwe/Cargo.toml
rename to core/Cargo.toml
index a8b8207..a54bd5a 100644
--- a/rlwe/Cargo.toml
+++ b/core/Cargo.toml
@@ -1,5 +1,3 @@
-cargo-features = ["edition2024"]
-
 [package]
 name = "rlwe"
 version = "0.1.0"
@@ -14,5 +12,9 @@ rand_distr = {workspace = true}
 itertools = {workspace = true}
 
 [[bench]]
-name = "gadget_product"
+name = "external_product_glwe_fft64"
 harness = false
+
+[[bench]]
+name = "keyswitch_glwe_fft64"
+harness = false
\ No newline at end of file
diff --git a/core/benches/external_product_glwe_fft64.rs b/core/benches/external_product_glwe_fft64.rs
new file mode 100644
index 0000000..1739211
--- /dev/null
+++ b/core/benches/external_product_glwe_fft64.rs
@@ -0,0 +1,202 @@
+use base2k::{FFT64, Module, ScalarZnxAlloc, ScratchOwned};
+use criterion::{BenchmarkId, Criterion, black_box, criterion_group, criterion_main};
+use rlwe::{
+    elem::Infos,
+    ggsw_ciphertext::GGSWCiphertext,
+    glwe_ciphertext::GLWECiphertext,
+    keys::{SecretKey, SecretKeyFourier},
+};
+use sampling::source::Source;
+
+fn bench_external_product_glwe_fft64(c: &mut Criterion) {
+    let mut group = c.benchmark_group("external_product_glwe_fft64");
+
+    struct Params {
+        log_n: usize,
+        basek: usize,
+        k_ct_in: usize,
+        k_ct_out: usize,
+        k_ggsw: usize,
+        rank: usize,
+    }
+
+    fn runner(p: Params) -> impl FnMut() {
+        let module: Module<FFT64> = Module::<FFT64>::new(1 << p.log_n);
+
+        let basek: usize = p.basek;
+        let k_ct_in: usize = p.k_ct_in;
+        let k_ct_out: usize = p.k_ct_out;
+        let k_ggsw: usize = p.k_ggsw;
+        let rank: usize = p.rank;
+
+        let rows: usize = (p.k_ct_in + p.basek - 1) / p.basek;
+        let sigma: f64 = 3.2;
+
+        let mut ct_rgsw: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank);
+        let mut ct_rlwe_in: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct_in, rank);
+        let mut ct_rlwe_out: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct_out, rank);
+        let pt_rgsw: base2k::ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
+
+        let mut scratch = ScratchOwned::new(
+            GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size())
+                | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe_in.size())
+                | GLWECiphertext::external_product_scratch_space(
+                    &module,
+                    ct_rlwe_out.size(),
+                    ct_rlwe_in.size(),
+                    ct_rgsw.size(),
+                    rank,
+                ),
); + + let mut source_xs = Source::new([0u8; 32]); + let mut source_xe = Source::new([0u8; 32]); + let mut source_xa = Source::new([0u8; 32]); + + let mut sk: SecretKey> = SecretKey::new(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_dft.dft(&module, &sk); + + ct_rgsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_rlwe_in.encrypt_zero_sk( + &module, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + move || { + ct_rlwe_out.external_product( + black_box(&module), + black_box(&ct_rlwe_in), + black_box(&ct_rgsw), + black_box(scratch.borrow()), + ); + } + } + + let params_set: Vec = vec![Params { + log_n: 10, + basek: 7, + k_ct_in: 27, + k_ct_out: 27, + k_ggsw: 27, + rank: 1, + }]; + + for params in params_set { + let id = BenchmarkId::new("EXTERNAL_PRODUCT_GLWE_FFT64", ""); + let mut runner = runner(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +fn bench_external_product_glwe_inplace_fft64(c: &mut Criterion) { + let mut group = c.benchmark_group("external_product_glwe_inplace_fft64"); + + struct Params { + log_n: usize, + basek: usize, + k_ct: usize, + k_ggsw: usize, + rank: usize, + } + + fn runner(p: Params) -> impl FnMut() { + let module: Module = Module::::new(1 << p.log_n); + + let basek: usize = p.basek; + let k_glwe: usize = p.k_ct; + let k_ggsw: usize = p.k_ggsw; + let rank: usize = p.rank; + + let rows: usize = (p.k_ct + p.basek - 1) / p.basek; + let sigma: f64 = 3.2; + + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank); + let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_glwe, rank); + let pt_rgsw: base2k::ScalarZnx> = module.new_scalar_znx(1); + + let mut scratch = ScratchOwned::new( + GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | GLWECiphertext::external_product_inplace_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size(), rank), + ); + + let mut source_xs = Source::new([0u8; 32]); + let mut source_xe = Source::new([0u8; 32]); + let mut source_xa = Source::new([0u8; 32]); + + let mut sk: SecretKey> = SecretKey::new(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_dft.dft(&module, &sk); + + ct_rgsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_rlwe.encrypt_zero_sk( + &module, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + move || { + let scratch_borrow = scratch.borrow(); + (0..687).for_each(|_| { + ct_rlwe.external_product_inplace( + black_box(&module), + black_box(&ct_rgsw), + black_box(scratch_borrow), + ); + }); + } + } + + let params_set: Vec = vec![Params { + log_n: 12, + basek: 18, + k_ct: 54, + k_ggsw: 54, + rank: 1, + }]; + + for params in params_set { + let id = BenchmarkId::new("EXTERNAL_PRODUCT_GLWE_INPLACE_FFT64", ""); + let mut runner = runner(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +criterion_group!( + benches, + bench_external_product_glwe_fft64, + bench_external_product_glwe_inplace_fft64 +); +criterion_main!(benches); diff --git 
a/core/benches/keyswitch_glwe_fft64.rs b/core/benches/keyswitch_glwe_fft64.rs new file mode 100644 index 0000000..1c1b7f8 --- /dev/null +++ b/core/benches/keyswitch_glwe_fft64.rs @@ -0,0 +1,211 @@ +use base2k::{FFT64, Module, ScratchOwned}; +use criterion::{BenchmarkId, Criterion, black_box, criterion_group, criterion_main}; +use rlwe::{ + elem::Infos, + glwe_ciphertext::GLWECiphertext, + keys::{SecretKey, SecretKeyFourier}, + keyswitch_key::GLWESwitchingKey, +}; +use sampling::source::Source; + +fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { + let mut group = c.benchmark_group("keyswitch_glwe_fft64"); + + struct Params { + log_n: usize, + basek: usize, + k_ct_in: usize, + k_ct_out: usize, + k_ksk: usize, + rank_in: usize, + rank_out: usize, + } + + fn runner(p: Params) -> impl FnMut() { + let module: Module = Module::::new(1 << p.log_n); + + let basek: usize = p.basek; + let k_rlwe_in: usize = p.k_ct_in; + let k_rlwe_out: usize = p.k_ct_out; + let k_grlwe: usize = p.k_ksk; + let rank_in: usize = p.rank_in; + let rank_out: usize = p.rank_out; + + let rows: usize = (p.k_ct_in + p.basek - 1) / p.basek; + let sigma: f64 = 3.2; + + let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, k_grlwe, rows, rank_in, rank_out); + let mut ct_in: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_rlwe_in, rank_in); + let mut ct_out: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_rlwe_out, rank_out); + + let mut scratch = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank_out, ksk.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_in.size()) + | GLWECiphertext::keyswitch_scratch_space( + &module, + ct_out.size(), + ct_in.size(), + ksk.size(), + rank_in, + rank_out, + ), + ); + + let mut source_xs = Source::new([0u8; 32]); + let mut source_xe = Source::new([0u8; 32]); + let mut source_xa = Source::new([0u8; 32]); + + let mut sk_in: SecretKey> = SecretKey::new(&module, rank_in); + sk_in.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_in_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank_in); + sk_in_dft.dft(&module, &sk_in); + + let mut sk_out: SecretKey> = SecretKey::new(&module, rank_out); + sk_out.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_out_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank_out); + sk_out_dft.dft(&module, &sk_out); + + ksk.encrypt_sk( + &module, + &sk_in, + &sk_out_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_in.encrypt_zero_sk( + &module, + &sk_in_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + move || { + ct_out.keyswitch( + black_box(&module), + black_box(&ct_in), + black_box(&ksk), + black_box(scratch.borrow()), + ); + } + } + + let params_set: Vec = vec![Params { + log_n: 16, + basek: 50, + k_ct_in: 1250, + k_ct_out: 1250, + k_ksk: 1250 + 66, + rank_in: 1, + rank_out: 1, + }]; + + for params in params_set { + let id = BenchmarkId::new("KEYSWITCH_GLWE_FFT64", ""); + let mut runner = runner(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) { + let mut group = c.benchmark_group("keyswitch_glwe_inplace_fft64"); + + struct Params { + log_n: usize, + basek: usize, + k_ct: usize, + k_ksk: usize, + rank: usize, + } + + fn runner(p: Params) -> impl FnMut() { + let module: Module = Module::::new(1 << p.log_n); + + let basek: usize = p.basek; + let k_ct: usize = p.k_ct; + let k_ksk: 
usize = p.k_ksk; + let rank: usize = p.rank; + + let rows: usize = (p.k_ct + p.basek - 1) / p.basek; + let sigma: f64 = 3.2; + + let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank, rank); + let mut ct: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct, rank); + + let mut scratch = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ksk.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct.size()) + | GLWECiphertext::keyswitch_inplace_scratch_space(&module, ct.size(), ksk.size(), rank), + ); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut sk_in: SecretKey> = SecretKey::new(&module, rank); + sk_in.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_in_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_in_dft.dft(&module, &sk_in); + + let mut sk_out: SecretKey> = SecretKey::new(&module, rank); + sk_out.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_out_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_out_dft.dft(&module, &sk_out); + + ksk.encrypt_sk( + &module, + &sk_in, + &sk_out_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct.encrypt_zero_sk( + &module, + &sk_in_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + move || { + ct.keyswitch_inplace( + black_box(&module), + black_box(&ksk), + black_box(scratch.borrow()), + ); + } + } + + let params_set: Vec = vec![Params { + log_n: 9, + basek: 18, + k_ct: 27, + k_ksk: 27, + rank: 1, + }]; + + for params in params_set { + let id = BenchmarkId::new("KEYSWITCH_GLWE_INPLACE_FFT64", ""); + let mut runner = runner(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +criterion_group!( + benches, + bench_keyswitch_glwe_fft64, + bench_keyswitch_glwe_inplace_fft64 +); +criterion_main!(benches); diff --git a/core/src/automorphism.rs b/core/src/automorphism.rs new file mode 100644 index 0000000..8dca7ec --- /dev/null +++ b/core/src/automorphism.rs @@ -0,0 +1,386 @@ +use base2k::{ + Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDftOps, ScalarZnxOps, + ScalarZnxToRef, Scratch, VecZnx, VecZnxBigAlloc, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, + ZnxZero, +}; +use sampling::source::Source; + +use crate::{ + elem::{GetRow, Infos, SetRow}, + gglwe_ciphertext::GGLWECiphertext, + ggsw_ciphertext::GGSWCiphertext, + glwe_ciphertext::GLWECiphertext, + glwe_ciphertext_fourier::GLWECiphertextFourier, + keys::{SecretKey, SecretKeyFourier}, + keyswitch_key::GLWESwitchingKey, +}; + +pub struct AutomorphismKey { + pub(crate) key: GLWESwitchingKey, + pub(crate) p: i64, +} + +impl AutomorphismKey, FFT64> { + pub fn new(module: &Module, basek: usize, k: usize, rows: usize, rank: usize) -> Self { + AutomorphismKey { + key: GLWESwitchingKey::new(module, basek, k, rows, rank, rank), + p: 0, + } + } +} + +impl Infos for AutomorphismKey { + type Inner = MatZnxDft; + + fn inner(&self) -> &Self::Inner { + &self.key.inner() + } + + fn basek(&self) -> usize { + self.key.basek() + } + + fn k(&self) -> usize { + self.key.k() + } +} + +impl AutomorphismKey { + pub fn p(&self) -> i64 { + self.p + } + + pub fn rank(&self) -> usize { + self.key.rank() + } + + pub fn rank_in(&self) -> usize { + self.key.rank_in() + } + + pub fn 
rank_out(&self) -> usize { + self.key.rank_out() + } +} + +impl MatZnxDftToMut for AutomorphismKey +where + MatZnxDft: MatZnxDftToMut, +{ + fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { + self.key.to_mut() + } +} + +impl MatZnxDftToRef for AutomorphismKey +where + MatZnxDft: MatZnxDftToRef, +{ + fn to_ref(&self) -> MatZnxDft<&[u8], B> { + self.key.to_ref() + } +} + +impl GetRow for AutomorphismKey +where + MatZnxDft: MatZnxDftToRef, +{ + fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut R) + where + R: VecZnxDftToMut, + { + module.vmp_extract_row(res, self, row_i, col_j); + } +} + +impl SetRow for AutomorphismKey +where + MatZnxDft: MatZnxDftToMut, +{ + fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &R) + where + R: VecZnxDftToRef, + { + module.vmp_prepare_row(self, row_i, col_j, a); + } +} + +impl AutomorphismKey, FFT64> { + pub fn encrypt_sk_scratch_space(module: &Module, rank: usize, size: usize) -> usize { + GGLWECiphertext::encrypt_sk_scratch_space(module, rank, size) + } + + pub fn encrypt_pk_scratch_space(module: &Module, rank: usize, pk_size: usize) -> usize { + GGLWECiphertext::encrypt_pk_scratch_space(module, rank, pk_size) + } + + pub fn keyswitch_scratch_space( + module: &Module, + out_size: usize, + in_size: usize, + ksk_size: usize, + rank: usize, + ) -> usize { + GLWESwitchingKey::keyswitch_scratch_space(module, out_size, rank, in_size, rank, ksk_size) + } + + pub fn keyswitch_inplace_scratch_space(module: &Module, out_size: usize, out_rank: usize, ksk_size: usize) -> usize { + GLWESwitchingKey::keyswitch_inplace_scratch_space(module, out_size, out_rank, ksk_size) + } + + pub fn automorphism_scratch_space( + module: &Module, + out_size: usize, + in_size: usize, + ksk_size: usize, + rank: usize, + ) -> usize { + let tmp_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size); + let tmp_idft: usize = module.bytes_of_vec_znx_big(rank + 1, out_size); + let idft: usize = module.vec_znx_idft_tmp_bytes(); + let keyswitch: usize = GLWECiphertext::keyswitch_inplace_scratch_space(module, out_size, rank, ksk_size); + tmp_dft + tmp_idft + idft + keyswitch + } + + pub fn automorphism_inplace_scratch_space(module: &Module, out_size: usize, ksk_size: usize, rank: usize) -> usize { + AutomorphismKey::automorphism_scratch_space(module, out_size, out_size, ksk_size, rank) + } + + pub fn external_product_scratch_space( + module: &Module, + out_size: usize, + in_size: usize, + ggsw_size: usize, + rank: usize, + ) -> usize { + GLWESwitchingKey::external_product_scratch_space(module, out_size, in_size, ggsw_size, rank) + } + + pub fn external_product_inplace_scratch_space( + module: &Module, + out_size: usize, + ggsw_size: usize, + rank: usize, + ) -> usize { + GLWESwitchingKey::external_product_inplace_scratch_space(module, out_size, ggsw_size, rank) + } +} + +impl AutomorphismKey +where + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, +{ + pub fn encrypt_sk( + &mut self, + module: &Module, + p: i64, + sk: &SecretKey, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) where + ScalarZnx: ScalarZnxToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.n(), module.n()); + assert_eq!(sk.n(), module.n()); + assert_eq!(self.rank_out(), self.rank_in()); + assert_eq!(sk.rank(), self.rank()); + } + + let (sk_out_dft_data, scratch_1) = scratch.tmp_scalar_znx_dft(module, sk.rank()); + + let mut sk_out_dft: SecretKeyFourier<&mut [u8], FFT64> = SecretKeyFourier { + data: sk_out_dft_data, + dist: sk.dist, + }; 
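+        // Computes sk_out = pi^{-1}_p(sk) in the DFT domain: each column of the
+        // secret is passed through the inverse Galois automorphism and then
+        // prepared with `svp_prepare`, matching the (-pi^{-1}_p(s)a + s, a)
+        // convention documented in `automorphism` below.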
+ + { + (0..self.rank()).for_each(|i| { + let (mut sk_inv_auto, _) = scratch_1.tmp_scalar_znx(module, 1); + module.scalar_znx_automorphism(module.galois_element_inv(p), &mut sk_inv_auto, 0, sk, i); + module.svp_prepare(&mut sk_out_dft, i, &sk_inv_auto, 0); + }); + } + + self.key.encrypt_sk( + module, + &sk, + &sk_out_dft, + source_xa, + source_xe, + sigma, + scratch_1, + ); + + self.p = p; + } +} + +impl AutomorphismKey +where + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, +{ + pub fn automorphism( + &mut self, + module: &Module, + lhs: &AutomorphismKey, + rhs: &AutomorphismKey, + scratch: &mut base2k::Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!( + self.rank_in(), + lhs.rank_in(), + "ksk_out input rank: {} != ksk_in input rank: {}", + self.rank_in(), + lhs.rank_in() + ); + assert_eq!( + lhs.rank_out(), + rhs.rank_in(), + "ksk_in output rank: {} != ksk_apply input rank: {}", + lhs.rank_out(), + rhs.rank_in() + ); + assert_eq!( + self.rank_out(), + rhs.rank_out(), + "ksk_out output rank: {} != ksk_apply output rank: {}", + self.rank_out(), + rhs.rank_out() + ); + } + + let cols_out: usize = rhs.rank_out() + 1; + + let (tmp_dft_data, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, lhs.size()); + + let mut tmp_dft: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_dft_data, + basek: lhs.basek(), + k: lhs.k(), + }; + + (0..self.rank_in()).for_each(|col_i| { + (0..self.rows()).for_each(|row_j| { + // Extracts relevant row + lhs.get_row(module, row_j, col_i, &mut tmp_dft); + + // Get a VecZnxBig from scratch space + let (mut tmp_idft_data, scratch2) = scratch1.tmp_vec_znx_big(module, cols_out, self.size()); + + // Switches input outside of DFT + (0..cols_out).for_each(|i| { + module.vec_znx_idft(&mut tmp_idft_data, i, &tmp_dft.data, i, scratch2); + }); + + // Consumes to small vec znx + let mut tmp_idft_small_data: VecZnx<&mut [u8]> = tmp_idft_data.to_vec_znx_small(); + + // Reverts the automorphism key from (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a) + (0..cols_out).for_each(|i| { + module.vec_znx_automorphism_inplace(lhs.p(), &mut tmp_idft_small_data, i); + }); + + // Wraps into ciphertext + let mut tmp_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { + data: tmp_idft_small_data, + basek: self.basek(), + k: self.k(), + }; + + // Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) + tmp_idft.keyswitch_inplace(module, &rhs.key, scratch2); + + // Applies back the automorphism X^{k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) -> (-pi^{-1}_{k'+k}(s)a + s, a) + // and switches back to DFT domain + (0..self.rank_out() + 1).for_each(|i| { + module.vec_znx_automorphism_inplace(lhs.p(), &mut tmp_idft, i); + module.vec_znx_dft(&mut tmp_dft, i, &tmp_idft, i); + }); + + // Sets back the relevant row + self.set_row(module, row_j, col_i, &tmp_dft); + }); + }); + + tmp_dft.data.zero(); + + (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| { + (0..self.rank_in()).for_each(|col_j| { + self.set_row(module, row_i, col_j, &tmp_dft); + }); + }); + + self.p = (lhs.p * rhs.p) % (module.cyclotomic_order() as i64); + } + + pub fn automorphism_inplace( + &mut self, + module: &Module, + rhs: &AutomorphismKey, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + unsafe { + let self_ptr: *mut AutomorphismKey = self as *mut AutomorphismKey; + self.automorphism(&module, &*self_ptr, rhs, scratch); + } + } + + pub fn keyswitch( + &mut
self, + module: &Module, + lhs: &AutomorphismKey, + rhs: &GLWESwitchingKey, + scratch: &mut base2k::Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + self.key.keyswitch(module, &lhs.key, rhs, scratch); + } + + pub fn keyswitch_inplace( + &mut self, + module: &Module, + rhs: &GLWESwitchingKey, + scratch: &mut base2k::Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + self.key.keyswitch_inplace(module, &rhs, scratch); + } + + pub fn external_product( + &mut self, + module: &Module, + lhs: &AutomorphismKey, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + self.key.external_product(module, &lhs.key, rhs, scratch); + } + + pub fn external_product_inplace( + &mut self, + module: &Module, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + self.key.external_product_inplace(module, rhs, scratch); + } +} diff --git a/core/src/elem.rs b/core/src/elem.rs new file mode 100644 index 0000000..66cb1d0 --- /dev/null +++ b/core/src/elem.rs @@ -0,0 +1,59 @@ +use base2k::{Backend, Module, VecZnxDftToMut, VecZnxDftToRef, ZnxInfos}; + +use crate::utils::derive_size; + +pub trait Infos { + type Inner: ZnxInfos; + + fn inner(&self) -> &Self::Inner; + + /// Returns the ring degree of the polynomials. + fn n(&self) -> usize { + self.inner().n() + } + + /// Returns the base two logarithm of the ring dimension of the polynomials. + fn log_n(&self) -> usize { + self.inner().log_n() + } + + /// Returns the number of rows. + fn rows(&self) -> usize { + self.inner().rows() + } + + /// Returns the number of polynomials in each row. + fn cols(&self) -> usize { + self.inner().cols() + } + + /// Returns the size, i.e. the number of small polynomials per (row, column) entry. + fn size(&self) -> usize { + let size: usize = self.inner().size(); + debug_assert_eq!(size, derive_size(self.basek(), self.k())); + size + } + + /// Returns the total number of small polynomials. + fn poly_count(&self) -> usize { + self.rows() * self.cols() * self.size() + } + + /// Returns the base 2 logarithm of the ciphertext base. + fn basek(&self) -> usize; + + /// Returns the bit precision of the ciphertext.
+ fn k(&self) -> usize; +} + +pub trait GetRow { + fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut R) + where + R: VecZnxDftToMut; +} + +pub trait SetRow { + fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &R) + where + R: VecZnxDftToRef; +} diff --git a/core/src/gglwe_ciphertext.rs b/core/src/gglwe_ciphertext.rs new file mode 100644 index 0000000..f8983c8 --- /dev/null +++ b/core/src/gglwe_ciphertext.rs @@ -0,0 +1,211 @@ +use base2k::{ + Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, + ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, ZnxInfos, + ZnxZero, +}; +use sampling::source::Source; + +use crate::{ + elem::{GetRow, Infos, SetRow}, + glwe_ciphertext::GLWECiphertext, + glwe_ciphertext_fourier::GLWECiphertextFourier, + glwe_plaintext::GLWEPlaintext, + keys::SecretKeyFourier, + utils::derive_size, +}; + +pub struct GGLWECiphertext { + pub(crate) data: MatZnxDft, + pub(crate) basek: usize, + pub(crate) k: usize, +} + +impl GGLWECiphertext, B> { + pub fn new(module: &Module, basek: usize, k: usize, rows: usize, rank_in: usize, rank_out: usize) -> Self { + Self { + data: module.new_mat_znx_dft(rows, rank_in, rank_out + 1, derive_size(basek, k)), + basek: basek, + k, + } + } +} + +impl Infos for GGLWECiphertext { + type Inner = MatZnxDft; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn basek(&self) -> usize { + self.basek + } + + fn k(&self) -> usize { + self.k + } +} + +impl GGLWECiphertext { + pub fn rank(&self) -> usize { + self.data.cols_out() - 1 + } + + pub fn rank_in(&self) -> usize { + self.data.cols_in() + } + + pub fn rank_out(&self) -> usize { + self.data.cols_out() - 1 + } +} + +impl MatZnxDftToMut for GGLWECiphertext +where + MatZnxDft: MatZnxDftToMut, +{ + fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { + self.data.to_mut() + } +} + +impl MatZnxDftToRef for GGLWECiphertext +where + MatZnxDft: MatZnxDftToRef, +{ + fn to_ref(&self) -> MatZnxDft<&[u8], B> { + self.data.to_ref() + } +} + +impl GGLWECiphertext, FFT64> { + pub fn encrypt_sk_scratch_space(module: &Module, rank: usize, size: usize) -> usize { + GLWECiphertext::encrypt_sk_scratch_space(module, size) + + module.bytes_of_vec_znx(rank + 1, size) + + module.bytes_of_vec_znx(1, size) + + module.bytes_of_vec_znx_dft(rank + 1, size) + } + + pub fn encrypt_pk_scratch_space(_module: &Module, _rank: usize, _pk_size: usize) -> usize { + unimplemented!() + } +} + +impl GGLWECiphertext +where + MatZnxDft: MatZnxDftToMut + ZnxInfos, +{ + pub fn encrypt_sk( + &mut self, + module: &Module, + pt: &ScalarZnx, + sk_dft: &SecretKeyFourier, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) where + ScalarZnx: ScalarZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank_in(), pt.cols()); + assert_eq!(self.rank_out(), sk_dft.rank()); + assert_eq!(self.n(), module.n()); + assert_eq!(sk_dft.n(), module.n()); + assert_eq!(pt.n(), module.n()); + } + + let rows: usize = self.rows(); + let size: usize = self.size(); + let basek: usize = self.basek(); + let k: usize = self.k(); + + let cols_in: usize = self.rank_in(); + let cols_out: usize = self.rank_out() + 1; + + let (tmp_znx_pt, scratch_1) = scratch.tmp_vec_znx(module, 1, size); + let (tmp_znx_ct, scratch_2) = scratch_1.tmp_vec_znx(module, cols_out, size); + let (tmp_znx_dft_ct, scratch_3) = scratch_2.tmp_vec_znx_dft(module, cols_out, size); + + let mut vec_znx_pt: GLWEPlaintext<&mut [u8]> = GLWEPlaintext { + data: tmp_znx_pt, + basek, + k, + }; + + let mut vec_znx_ct: GLWECiphertext<&mut [u8]> = GLWECiphertext { + data: tmp_znx_ct, + basek, + k, + }; + + let mut vec_znx_ct_dft: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier { + data: tmp_znx_dft_ct, + basek, + k, + }; + + // For each input column (i.e. rank) produces a GGLWE ciphertext of rank_out+1 columns + // + // Example for ksk rank 2 to rank 3: + // + // (-(a0*s0 + a1*s1 + a2*s2) + s0', a0, a1, a2) + // (-(b0*s0 + b1*s1 + b2*s2) + s0', b0, b1, b2) + // + // Example ksk rank 2 to rank 1 + // + // (-(a*s) + s0, a) + // (-(b*s) + s1, b) + (0..cols_in).for_each(|col_i| { + (0..rows).for_each(|row_i| { + // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt + vec_znx_pt.data.zero(); // zeroes for next iteration + module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_i, pt, col_i); // Selects the i-th + module.vec_znx_normalize_inplace(basek, &mut vec_znx_pt, 0, scratch_3); + + // rlwe encrypt of vec_znx_pt into vec_znx_ct + vec_znx_ct.encrypt_sk( + module, + &vec_znx_pt, + sk_dft, + source_xa, + source_xe, + sigma, + scratch_3, + ); + + // Switch vec_znx_ct into DFT domain + vec_znx_ct.dft(module, &mut vec_znx_ct_dft); + + // Stores vec_znx_dft_ct into the i-th row of the MatZnxDft + module.vmp_prepare_row(self, row_i, col_i, &vec_znx_ct_dft); + }); + }); + } +} + +impl GetRow for GGLWECiphertext +where + MatZnxDft: MatZnxDftToRef, +{ + fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut R) + where + R: VecZnxDftToMut, + { + module.vmp_extract_row(res, self, row_i, col_j); + } +} + +impl SetRow for GGLWECiphertext +where + MatZnxDft: MatZnxDftToMut, +{ + fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &R) + where + R: VecZnxDftToRef, + { + module.vmp_prepare_row(self, row_i, col_j, a); + } +} diff --git a/core/src/ggsw_ciphertext.rs b/core/src/ggsw_ciphertext.rs new file mode 100644 index 0000000..8955b7a --- /dev/null +++ b/core/src/ggsw_ciphertext.rs @@ -0,0 +1,684 @@ +use base2k::{ + Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, + ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, + VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, + VecZnxToRef, ZnxInfos, ZnxZero, +}; +use sampling::source::Source; + +use crate::{ + automorphism::AutomorphismKey, + elem::{GetRow, Infos, SetRow}, + glwe_ciphertext::GLWECiphertext, + glwe_ciphertext_fourier::GLWECiphertextFourier, + glwe_plaintext::GLWEPlaintext, + keys::SecretKeyFourier, + keyswitch_key::GLWESwitchingKey, + tensor_key::TensorKey, + utils::derive_size, +}; + +pub struct GGSWCiphertext { + pub data: MatZnxDft, + pub basek: usize, + pub k: usize, +} + +impl GGSWCiphertext, B> { + pub fn new(module: &Module, basek: usize, k: usize, rows: usize, rank: usize) -> Self { + Self { + data: module.new_mat_znx_dft(rows, rank + 1, rank + 1, derive_size(basek, k)), + basek: basek, + k: k, + } + } +} + +impl Infos for GGSWCiphertext { + type Inner = MatZnxDft; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn basek(&self) -> usize { + self.basek + } + + fn k(&self) -> usize { + self.k + } +} + +impl GGSWCiphertext { + pub fn rank(&self) -> usize { + self.data.cols_out() - 1 + } +} + +impl MatZnxDftToMut for
GGSWCiphertext +where + MatZnxDft: MatZnxDftToMut, +{ + fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { + self.data.to_mut() + } +} + +impl MatZnxDftToRef for GGSWCiphertext +where + MatZnxDft: MatZnxDftToRef, +{ + fn to_ref(&self) -> MatZnxDft<&[u8], B> { + self.data.to_ref() + } +} + +impl GGSWCiphertext, FFT64> { + pub fn encrypt_sk_scratch_space(module: &Module, rank: usize, size: usize) -> usize { + GLWECiphertext::encrypt_sk_scratch_space(module, size) + + module.bytes_of_vec_znx(rank + 1, size) + + module.bytes_of_vec_znx(1, size) + + module.bytes_of_vec_znx_dft(rank + 1, size) + } + + pub(crate) fn expand_row_scratch_space( + module: &Module, + self_size: usize, + tensor_key_size: usize, + rank: usize, + ) -> usize { + let tmp_dft_i: usize = module.bytes_of_vec_znx_dft(rank + 1, tensor_key_size); + let tmp_dft_col_data: usize = module.bytes_of_vec_znx_dft(1, self_size); + let vmp: usize = + tmp_dft_col_data + module.vmp_apply_tmp_bytes(self_size, self_size, self_size, rank, rank, tensor_key_size); + let tmp_idft: usize = module.bytes_of_vec_znx_big(1, tensor_key_size); + let norm: usize = module.vec_znx_big_normalize_tmp_bytes(); + tmp_dft_i + ((tmp_dft_col_data + vmp) | (tmp_idft + norm)) + } + + pub(crate) fn keyswitch_internal_col0_scratch_space( + module: &Module, + out_size: usize, + in_size: usize, + ksk_size: usize, + rank: usize, + ) -> usize { + GLWECiphertext::keyswitch_from_fourier_scratch_space(module, out_size, rank, in_size, rank, ksk_size) + + module.bytes_of_vec_znx_dft(rank + 1, in_size) + } + + pub fn keyswitch_scratch_space( + module: &Module, + out_size: usize, + in_size: usize, + ksk_size: usize, + tensor_key_size: usize, + rank: usize, + ) -> usize { + let res_znx: usize = module.bytes_of_vec_znx(rank + 1, out_size); + let ci_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size); + let ks: usize = GGSWCiphertext::keyswitch_internal_col0_scratch_space(module, out_size, in_size, ksk_size, rank); + let expand_rows: usize = GGSWCiphertext::expand_row_scratch_space(module, out_size, tensor_key_size, rank); + let res_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size); + res_znx + ci_dft + (ks | expand_rows | res_dft) + } + + pub fn keyswitch_inplace_scratch_space( + module: &Module, + out_size: usize, + ksk_size: usize, + tensor_key_size: usize, + rank: usize, + ) -> usize { + GGSWCiphertext::keyswitch_scratch_space(module, out_size, out_size, ksk_size, tensor_key_size, rank) + } + + pub fn automorphism_scratch_space( + module: &Module, + out_size: usize, + in_size: usize, + auto_key_size: usize, + tensor_key_size: usize, + rank: usize, + ) -> usize { + GGSWCiphertext::keyswitch_scratch_space( + module, + out_size, + in_size, + auto_key_size, + tensor_key_size, + rank, + ) + } + + pub fn automorphism_inplace_scratch_space( + module: &Module, + out_size: usize, + auto_key_size: usize, + tensor_key_size: usize, + rank: usize, + ) -> usize { + GGSWCiphertext::automorphism_scratch_space( + module, + out_size, + out_size, + auto_key_size, + tensor_key_size, + rank, + ) + } + + pub fn external_product_scratch_space( + module: &Module, + out_size: usize, + in_size: usize, + ggsw_size: usize, + rank: usize, + ) -> usize { + let tmp_in: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size); + let tmp_out: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size); + let ggsw: usize = GLWECiphertextFourier::external_product_scratch_space(module, out_size, in_size, ggsw_size, rank); + tmp_in + tmp_out + ggsw + } + + pub fn 
external_product_inplace_scratch_space( + module: &Module, + out_size: usize, + ggsw_size: usize, + rank: usize, + ) -> usize { + let tmp: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size); + let ggsw: usize = GLWECiphertextFourier::external_product_inplace_scratch_space(module, out_size, ggsw_size, rank); + tmp + ggsw + } +} + +impl GGSWCiphertext +where + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, +{ + pub fn encrypt_sk( + &mut self, + module: &Module, + pt: &ScalarZnx, + sk_dft: &SecretKeyFourier, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) where + ScalarZnx: ScalarZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), sk_dft.rank()); + assert_eq!(self.n(), module.n()); + assert_eq!(pt.n(), module.n()); + assert_eq!(sk_dft.n(), module.n()); + } + + let size: usize = self.size(); + let basek: usize = self.basek(); + let k: usize = self.k(); + let cols: usize = self.rank() + 1; + + let (tmp_znx_pt, scratch_1) = scratch.tmp_vec_znx(module, 1, size); + let (tmp_znx_ct, scratch_2) = scratch_1.tmp_vec_znx(module, cols, size); + + let mut vec_znx_pt: GLWEPlaintext<&mut [u8]> = GLWEPlaintext { + data: tmp_znx_pt, + basek: basek, + k: k, + }; + + let mut vec_znx_ct: GLWECiphertext<&mut [u8]> = GLWECiphertext { + data: tmp_znx_ct, + basek: basek, + k, + }; + + (0..self.rows()).for_each(|row_i| { + vec_znx_pt.data.zero(); + + // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt + module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_i, pt, 0); + module.vec_znx_normalize_inplace(basek, &mut vec_znx_pt, 0, scratch_2); + + (0..cols).for_each(|col_j| { + // rlwe encrypt of vec_znx_pt into vec_znx_ct + + vec_znx_ct.encrypt_sk_private( + module, + Some((&vec_znx_pt, col_j)), + sk_dft, + source_xa, + source_xe, + sigma, + scratch_2, + ); + + // Switch vec_znx_ct into DFT domain + { + let (mut vec_znx_dft_ct, _) = scratch_2.tmp_vec_znx_dft(module, cols, size); + + (0..cols).for_each(|i| { + module.vec_znx_dft(&mut vec_znx_dft_ct, i, &vec_znx_ct, i); + }); + + self.set_row(module, row_i, col_j, &vec_znx_dft_ct); + } + }); + }); + } + + pub(crate) fn expand_row( + &mut self, + module: &Module, + col_j: usize, + res: &mut R, + ci_dft: &VecZnxDft, + tsk: &TensorKey, + scratch: &mut Scratch, + ) where + R: VecZnxToMut, + VecZnxDft: VecZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + let cols: usize = self.rank() + 1; + + // Example for rank 3: + // + // Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column is + // actually composed of that many rows and we focus on a specific row here + // implicitly given by ci_dft. + // + // # Input + // + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 ) + // col 1: (0, 0, 0, 0) + // col 2: (0, 0, 0, 0) + // col 3: (0, 0, 0, 0) + // + // # Output + // + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 ) + // col 1: (-(b0s0 + b1s1 + b2s2) , b0 + M[i], b1 , b2 ) + // col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + M[i], c2 ) + // col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + M[i]) + + let (mut tmp_dft_i, scratch1) = scratch.tmp_vec_znx_dft(module, cols, tsk.size()); + { + let (mut tmp_dft_col_data, scratch2) = scratch1.tmp_vec_znx_dft(module, 1, self.size()); + + // Performs a key-switch for each combination of s[i]*s[j], i.e.
for a0, a1, a2 + // + // # Example for col=1 + // + // a0 * (-(f0s0 + f1s1 + f1s2) + s0^2, f0, f1, f2) = (-(a0f0s0 + a0f1s1 + a0f1s2) + a0s0^2, a0f0, a0f1, a0f2) + // + + // a1 * (-(g0s0 + g1s1 + g1s2) + s0s1, g0, g1, g2) = (-(a1g0s0 + a1g1s1 + a1g1s2) + a1s0s1, a1g0, a1g1, a1g2) + // + + // a2 * (-(h0s0 + h1s1 + h1s2) + s0s2, h0, h1, h2) = (-(a2h0s0 + a2h1s1 + a2h1s2) + a2s0s2, a2h0, a2h1, a2h2) + // = + // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2) + (1..cols).for_each(|col_i| { + // Extracts a[i] and multiplies it with Enc(s[i]s[j]) + tmp_dft_col_data.extract_column(0, ci_dft, col_i); + + if col_i == 1 { + module.vmp_apply( + &mut tmp_dft_i, + &tmp_dft_col_data, + tsk.at(col_i - 1, col_j - 1), // Selects Enc(s[i]s[j]) + scratch2, + ); + } else { + module.vmp_apply_add( + &mut tmp_dft_i, + &tmp_dft_col_data, + tsk.at(col_i - 1, col_j - 1), // Selects Enc(s[i]s[j]) + scratch2, + ); + } + }); + } + + // Adds (-(sum a[i] * s[i]) + M[i]) to the col_j-th column of tmp_dft_i + // + // (-(x0s0 + x1s1 + x2s2) + a0s0s0 + a1s0s1 + a2s0s2, x0, x1, x2) + // + + // (0, -(a0s0 + a1s1 + a2s2) + M[i], 0, 0) + // = + // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0 -(a0s0 + a1s1 + a2s2) + M[i], x1, x2) + // = + // (-(x0s0 + x1s1 + x2s2), x0 + M[i], x1, x2) + module.vec_znx_dft_add_inplace(&mut tmp_dft_i, col_j, ci_dft, 0); + let (mut tmp_idft, scratch2) = scratch1.tmp_vec_znx_big(module, 1, tsk.size()); + (0..cols).for_each(|i| { + module.vec_znx_idft_tmp_a(&mut tmp_idft, 0, &mut tmp_dft_i, i); + module.vec_znx_big_normalize(self.basek(), res, i, &tmp_idft, 0, scratch2); + }); + } + + pub fn keyswitch( + &mut self, + module: &Module, + lhs: &GGSWCiphertext, + ksk: &GLWESwitchingKey, + tsk: &TensorKey, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + let cols: usize = self.rank() + 1; + + let (res_data, scratch1) = scratch.tmp_vec_znx(&module, cols, self.size()); + let mut res: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { + data: res_data, + basek: self.basek(), + k: self.k(), + }; + + let (mut ci_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, self.size()); + + // Keyswitch the j-th row of the col 0 + (0..lhs.rows()).for_each(|row_i| { + // Key-switch column 0, i.e.
+ // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2) + lhs.keyswitch_internal_col0(module, row_i, &mut res, ksk, scratch2); + + // Isolates DFT(a[i]) + (0..cols).for_each(|col_i| { + module.vec_znx_dft(&mut ci_dft, col_i, &res, col_i); + }); + + self.set_row(module, row_i, 0, &ci_dft); + + // Generates + // + // col 1: (-(b0s0' + b1s1' + b2s2') , b0 + M[i], b1 , b2 ) + // col 2: (-(c0s0' + c1s1' + c2s2') , c0 , c1 + M[i], c2 ) + // col 3: (-(d0s0' + d1s1' + d2s2') , d0 , d1 , d2 + M[i]) + (1..cols).for_each(|col_j| { + self.expand_row(module, col_j, &mut res, &ci_dft, tsk, scratch2); + + let (mut res_dft, _) = scratch2.tmp_vec_znx_dft(module, cols, self.size()); + (0..cols).for_each(|i| { + module.vec_znx_dft(&mut res_dft, i, &res, i); + }); + + self.set_row(module, row_i, col_j, &res_dft); + }) + }) + } + + pub fn keyswitch_inplace( + &mut self, + module: &Module, + ksk: &GLWESwitchingKey, + tsk: &TensorKey, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + unsafe { + let self_ptr: *mut GGSWCiphertext = self as *mut GGSWCiphertext; + self.keyswitch(module, &*self_ptr, ksk, tsk, scratch); + } + } + + pub fn automorphism( + &mut self, + module: &Module, + lhs: &GGSWCiphertext, + auto_key: &AutomorphismKey, + tensor_key: &TensorKey, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!( + self.rank(), + lhs.rank(), + "ggsw_out rank: {} != ggsw_in rank: {}", + self.rank(), + lhs.rank() + ); + assert_eq!( + self.rank(), + auto_key.rank(), + "ggsw_in rank: {} != auto_key rank: {}", + self.rank(), + auto_key.rank() + ); + assert_eq!( + self.rank(), + tensor_key.rank(), + "ggsw_in rank: {} != tensor_key rank: {}", + self.rank(), + tensor_key.rank() + ); + }; + + let cols: usize = self.rank() + 1; + + let (res_data, scratch1) = scratch.tmp_vec_znx(&module, cols, self.size()); + let mut res: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { + data: res_data, + basek: self.basek(), + k: self.k(), + }; + + let (mut ci_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, self.size()); + + // Keyswitch the j-th row of the col 0 + (0..lhs.rows()).for_each(|row_i| { + // Key-switch column 0, i.e. 
+ // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2) + lhs.keyswitch_internal_col0(module, row_i, &mut res, &auto_key.key, scratch2); + + // Isolates DFT(AUTO(a[i])) + (0..cols).for_each(|col_i| { + // (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2) -> (-(a0s0 + a1s1 + a2s2) + pi(M[i]), a0, a1, a2) + module.vec_znx_automorphism_inplace(auto_key.p(), &mut res, col_i); + module.vec_znx_dft(&mut ci_dft, col_i, &res, col_i); + }); + + self.set_row(module, row_i, 0, &ci_dft); + + // Generates + // + // col 1: (-(b0s0 + b1s1 + b2s2) , b0 + pi(M[i]), b1 , b2 ) + // col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + pi(M[i]), c2 ) + // col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + pi(M[i])) + (1..cols).for_each(|col_j| { + self.expand_row(module, col_j, &mut res, &ci_dft, tensor_key, scratch2); + + let (mut res_dft, _) = scratch2.tmp_vec_znx_dft(module, cols, self.size()); + (0..cols).for_each(|i| { + module.vec_znx_dft(&mut res_dft, i, &res, i); + }); + + self.set_row(module, row_i, col_j, &res_dft); + }) + }) + } + + pub fn automorphism_inplace( + &mut self, + module: &Module, + auto_key: &AutomorphismKey, + tensor_key: &TensorKey, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + unsafe { + let self_ptr: *mut GGSWCiphertext = self as *mut GGSWCiphertext; + self.automorphism(module, &*self_ptr, auto_key, tensor_key, scratch); + } + } + + pub fn external_product( + &mut self, + module: &Module, + lhs: &GGSWCiphertext, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!( + self.rank(), + lhs.rank(), + "ggsw_out rank: {} != ggsw_in rank: {}", + self.rank(), + lhs.rank() + ); + assert_eq!( + self.rank(), + rhs.rank(), + "ggsw_in rank: {} != ggsw_apply rank: {}", + self.rank(), + rhs.rank() + ); + } + + let (tmp_in_data, scratch1) = scratch.tmp_vec_znx_dft(module, lhs.rank() + 1, lhs.size()); + + let mut tmp_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_in_data, + basek: lhs.basek(), + k: lhs.k(), + }; + + let (tmp_out_data, scratch2) = scratch1.tmp_vec_znx_dft(module, self.rank() + 1, self.size()); + + let mut tmp_out: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_out_data, + basek: self.basek(), + k: self.k(), + }; + + (0..self.rank() + 1).for_each(|col_i| { + (0..self.rows()).for_each(|row_j| { + lhs.get_row(module, row_j, col_i, &mut tmp_in); + tmp_out.external_product(module, &tmp_in, rhs, scratch2); + self.set_row(module, row_j, col_i, &tmp_out); + }); + }); + + tmp_out.data.zero(); + + (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| { + (0..self.rank() + 1).for_each(|col_j| { + self.set_row(module, row_i, col_j, &tmp_out); + }); + }); + } + + pub fn external_product_inplace( + &mut self, + module: &Module, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!( + self.rank(), + rhs.rank(), + "ggsw_out rank: {} != ggsw_apply: {}", + self.rank(), + rhs.rank() + ); + } + + let (tmp_data, scratch1) = scratch.tmp_vec_znx_dft(module, self.rank() + 1, self.size()); + + let mut tmp: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_data, + basek: self.basek(), + k: self.k(), + }; + + (0..self.rank() + 1).for_each(|col_i| { 
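+ // For each (row, column) pair: load the row into the scratch ciphertext `tmp`, multiply it by `rhs` via the Fourier-domain external product, then write it back in place.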
+ (0..self.rows()).for_each(|row_j| { + self.get_row(module, row_j, col_i, &mut tmp); + tmp.external_product_inplace(module, rhs, scratch1); + self.set_row(module, row_j, col_i, &tmp); + }); + }); + } +} + +impl GGSWCiphertext +where + MatZnxDft: MatZnxDftToRef, +{ + pub(crate) fn keyswitch_internal_col0( + &self, + module: &Module, + row_i: usize, + res: &mut GLWECiphertext, + ksk: &GLWESwitchingKey, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToMut + VecZnxToRef, + MatZnxDft: MatZnxDftToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), ksk.rank()); + assert_eq!(res.rank(), ksk.rank()); + } + + let (tmp_dft_in_data, scratch2) = scratch.tmp_vec_znx_dft(module, self.rank() + 1, self.size()); + let mut tmp_dft_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_dft_in_data, + basek: self.basek(), + k: self.k(), + }; + self.get_row(module, row_i, 0, &mut tmp_dft_in); + res.keyswitch_from_fourier(module, &tmp_dft_in, ksk, scratch2); + } +} + +impl GetRow for GGSWCiphertext +where + MatZnxDft: MatZnxDftToRef, +{ + fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut R) + where + R: VecZnxDftToMut, + { + module.vmp_extract_row(res, self, row_i, col_j); + } +} + +impl SetRow for GGSWCiphertext +where + MatZnxDft: MatZnxDftToMut, +{ + fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &R) + where + R: VecZnxDftToRef, + { + module.vmp_prepare_row(self, row_i, col_j, a); + } +} diff --git a/core/src/glwe_ciphertext.rs b/core/src/glwe_ciphertext.rs new file mode 100644 index 0000000..e319d21 --- /dev/null +++ b/core/src/glwe_ciphertext.rs @@ -0,0 +1,696 @@ +use base2k::{ + AddNormal, Backend, FFT64, FillUniform, MatZnxDft, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToRef, Module, ScalarZnxAlloc, + ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, + VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, + VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxZero, +}; +use sampling::source::Source; + +use crate::{ + SIX_SIGMA, + automorphism::AutomorphismKey, + elem::Infos, + ggsw_ciphertext::GGSWCiphertext, + glwe_ciphertext_fourier::GLWECiphertextFourier, + glwe_plaintext::GLWEPlaintext, + keys::{GLWEPublicKey, SecretDistribution, SecretKeyFourier}, + keyswitch_key::GLWESwitchingKey, + utils::derive_size, +}; + +pub struct GLWECiphertext { + pub data: VecZnx, + pub basek: usize, + pub k: usize, +} + +impl GLWECiphertext> { + pub fn new(module: &Module, basek: usize, k: usize, rank: usize) -> Self { + Self { + data: module.new_vec_znx(rank + 1, derive_size(basek, k)), + basek, + k, + } + } +} + +impl Infos for GLWECiphertext { + type Inner = VecZnx; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn basek(&self) -> usize { + self.basek + } + + fn k(&self) -> usize { + self.k + } +} + +impl GLWECiphertext { + pub fn rank(&self) -> usize { + self.cols() - 1 + } +} + +impl VecZnxToMut for GLWECiphertext +where + VecZnx: VecZnxToMut, +{ + fn to_mut(&mut self) -> VecZnx<&mut [u8]> { + self.data.to_mut() + } +} + +impl VecZnxToRef for GLWECiphertext +where + VecZnx: VecZnxToRef, +{ + fn to_ref(&self) -> VecZnx<&[u8]> { + self.data.to_ref() + } +} + +impl GLWECiphertext +where + VecZnx: VecZnxToRef, +{ + #[allow(dead_code)] + pub(crate) fn dft(&self, module: &Module, res: &mut GLWECiphertextFourier) + where + VecZnxDft: VecZnxDftToMut + ZnxInfos, + { + 
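// Forward-DFT of every column of `self` (rank + 1 in total) into `res`. +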
#[cfg(debug_assertions)] + { + assert_eq!(self.rank(), res.rank()); + assert_eq!(self.basek(), res.basek()) + } + + (0..self.rank() + 1).for_each(|i| { + module.vec_znx_dft(res, i, self, i); + }) + } +} + +impl GLWECiphertext> { + pub fn encrypt_sk_scratch_space(module: &Module, ct_size: usize) -> usize { + module.vec_znx_big_normalize_tmp_bytes() + + module.bytes_of_vec_znx_dft(1, ct_size) + + module.bytes_of_vec_znx_big(1, ct_size) + } + pub fn encrypt_pk_scratch_space(module: &Module, pk_size: usize) -> usize { + ((module.bytes_of_vec_znx_dft(1, pk_size) + module.bytes_of_vec_znx_big(1, pk_size)) | module.bytes_of_scalar_znx(1)) + + module.bytes_of_scalar_znx_dft(1) + + module.vec_znx_big_normalize_tmp_bytes() + } + + pub fn decrypt_scratch_space(module: &Module, ct_size: usize) -> usize { + (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, ct_size)) + + module.bytes_of_vec_znx_big(1, ct_size) + } + + pub fn keyswitch_scratch_space( + module: &Module, + out_size: usize, + out_rank: usize, + in_size: usize, + in_rank: usize, + ksk_size: usize, + ) -> usize { + let res_dft: usize = module.bytes_of_vec_znx_dft(out_rank + 1, ksk_size); + let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, in_rank, out_rank + 1, ksk_size) + + module.bytes_of_vec_znx_dft(in_rank, in_size); + let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(); + + return res_dft + (vmp | normalize); + } + + pub fn keyswitch_from_fourier_scratch_space( + module: &Module, + out_size: usize, + out_rank: usize, + in_size: usize, + in_rank: usize, + ksk_size: usize, + ) -> usize { + let res_dft = module.bytes_of_vec_znx_dft(out_rank + 1, ksk_size); + + let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, in_rank, out_rank + 1, ksk_size) + + module.bytes_of_vec_znx_dft(in_rank, in_size); + + let norm: usize = module.vec_znx_big_normalize_tmp_bytes(); + + res_dft + (vmp | norm) + } + + pub fn keyswitch_inplace_scratch_space(module: &Module, out_size: usize, out_rank: usize, ksk_size: usize) -> usize { + GLWECiphertext::keyswitch_scratch_space(module, out_size, out_rank, out_size, out_rank, ksk_size) + } + + pub fn automorphism_scratch_space( + module: &Module, + out_size: usize, + out_rank: usize, + in_size: usize, + autokey_size: usize, + ) -> usize { + GLWECiphertext::keyswitch_scratch_space(module, out_size, out_rank, in_size, out_rank, autokey_size) + } + + pub fn automorphism_inplace_scratch_space( + module: &Module, + out_size: usize, + out_rank: usize, + autokey_size: usize, + ) -> usize { + GLWECiphertext::keyswitch_scratch_space(module, out_size, out_rank, out_size, out_rank, autokey_size) + } + + pub fn external_product_scratch_space( + module: &Module, + out_size: usize, + out_rank: usize, + in_size: usize, + ggsw_size: usize, + ) -> usize { + let res_dft: usize = module.bytes_of_vec_znx_dft(out_rank + 1, ggsw_size); + let vmp: usize = module.bytes_of_vec_znx_dft(out_rank + 1, in_size) + + module.vmp_apply_tmp_bytes( + out_size, + in_size, + in_size, // rows + out_rank + 1, // cols in + out_rank + 1, // cols out + ggsw_size, + ); + let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(); + + res_dft + (vmp | normalize) + } + + pub fn external_product_inplace_scratch_space( + module: &Module, + out_size: usize, + out_rank: usize, + ggsw_size: usize, + ) -> usize { + GLWECiphertext::external_product_scratch_space(module, out_size, out_rank, out_size, ggsw_size) + } +} + +impl GLWECiphertext +where + VecZnx: VecZnxToMut + VecZnxToRef, +{ + 
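// Illustrative usage (sketch; assumes a ScratchOwned sized with the *_scratch_space + // helpers above and a prepared SecretKeyFourier): + // let mut ct = GLWECiphertext::new(&module, basek, k, rank); + // ct.encrypt_sk(&module, &pt, &sk_dft, &mut source_xa, &mut source_xe, 3.2, scratch.borrow()); + // ct.decrypt(&module, &mut pt_dec, &sk_dft, scratch.borrow()); +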
pub fn encrypt_sk( + &mut self, + module: &Module, + pt: &GLWEPlaintext, + sk_dft: &SecretKeyFourier, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, + { + self.encrypt_sk_private( + module, + Some((pt, 0)), + sk_dft, + source_xa, + source_xe, + sigma, + scratch, + ); + } + + pub fn encrypt_zero_sk( + &mut self, + module: &Module, + sk_dft: &SecretKeyFourier, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) where + ScalarZnxDft: ScalarZnxDftToRef, + { + self.encrypt_sk_private(module, None, sk_dft, source_xa, source_xe, sigma, scratch); + } + + pub fn encrypt_pk( + &mut self, + module: &Module, + pt: &GLWEPlaintext, + pk: &GLWEPublicKey, + source_xu: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToRef, + VecZnxDft: VecZnxDftToRef, + { + self.encrypt_pk_private( + module, + Some((pt, 0)), + pk, + source_xu, + source_xe, + sigma, + scratch, + ); + } + + pub fn encrypt_zero_pk( + &mut self, + module: &Module, + pk: &GLWEPublicKey, + source_xu: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) where + VecZnxDft: VecZnxDftToRef, + { + self.encrypt_pk_private(module, None, pk, source_xu, source_xe, sigma, scratch); + } + + pub fn automorphism( + &mut self, + module: &Module, + lhs: &GLWECiphertext, + rhs: &AutomorphismKey, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToRef, + MatZnxDft: MatZnxDftToRef, + { + self.keyswitch(module, lhs, &rhs.key, scratch); + (0..self.rank() + 1).for_each(|i| { + module.vec_znx_automorphism_inplace(rhs.p(), self, i); + }) + } + + pub fn automorphism_inplace( + &mut self, + module: &Module, + rhs: &AutomorphismKey, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + self.keyswitch_inplace(module, &rhs.key, scratch); + (0..self.rank() + 1).for_each(|i| { + module.vec_znx_automorphism_inplace(rhs.p(), self, i); + }) + } + + pub(crate) fn keyswitch_from_fourier( + &mut self, + module: &Module, + lhs: &GLWECiphertextFourier, + rhs: &GLWESwitchingKey, + scratch: &mut Scratch, + ) where + VecZnxDft: VecZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + let basek: usize = self.basek(); + + #[cfg(debug_assertions)] + { + assert_eq!(lhs.rank(), rhs.rank_in()); + assert_eq!(self.rank(), rhs.rank_out()); + assert_eq!(self.basek(), basek); + assert_eq!(lhs.basek(), basek); + assert_eq!(rhs.n(), module.n()); + assert_eq!(self.n(), module.n()); + assert_eq!(lhs.n(), module.n()); + assert!( + scratch.available() + >= GLWECiphertext::keyswitch_from_fourier_scratch_space( + module, + self.size(), + self.rank(), + lhs.size(), + lhs.rank(), + rhs.size(), + ) + ); + } + + let cols_in: usize = rhs.rank_in(); + let cols_out: usize = rhs.rank_out() + 1; + + // Buffer of the result of VMP in DFT + let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise + + { + // Applies VMP + let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, lhs.size()); + (0..cols_in).for_each(|col_i| { + module.vec_znx_dft_copy(&mut ai_dft, col_i, lhs, col_i + 1); + }); + module.vmp_apply(&mut res_dft, &ai_dft, rhs, scratch2); + } + + module.vec_znx_dft_add_inplace(&mut res_dft, 0, lhs, 0); + + // Switches result of VMP outside of DFT + let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft); + + (0..cols_out).for_each(|i| { + 
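// Renormalize each column of the high-precision result back to base-2^basek digits. +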
module.vec_znx_big_normalize(basek, self, i, &res_big, i, scratch1); + }); + } + + pub fn keyswitch( + &mut self, + module: &Module, + lhs: &GLWECiphertext, + rhs: &GLWESwitchingKey, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToRef, + MatZnxDft: MatZnxDftToRef, + { + let basek: usize = self.basek(); + + #[cfg(debug_assertions)] + { + assert_eq!(lhs.rank(), rhs.rank_in()); + assert_eq!(self.rank(), rhs.rank_out()); + assert_eq!(self.basek(), basek); + assert_eq!(lhs.basek(), basek); + assert_eq!(rhs.n(), module.n()); + assert_eq!(self.n(), module.n()); + assert_eq!(lhs.n(), module.n()); + assert!( + scratch.available() + >= GLWECiphertext::keyswitch_scratch_space( + module, + self.size(), + self.rank(), + lhs.size(), + lhs.rank(), + rhs.size(), + ) + ); + } + + let cols_in: usize = rhs.rank_in(); + let cols_out: usize = rhs.rank_out() + 1; + + let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise + + { + let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, lhs.size()); + (0..cols_in).for_each(|col_i| { + module.vec_znx_dft(&mut ai_dft, col_i, lhs, col_i + 1); + }); + module.vmp_apply(&mut res_dft, &ai_dft, rhs, scratch2); + } + + let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft); + + module.vec_znx_big_add_small_inplace(&mut res_big, 0, lhs, 0); + + (0..cols_out).for_each(|i| { + module.vec_znx_big_normalize(basek, self, i, &res_big, i, scratch1); + }); + } + + pub fn keyswitch_inplace( + &mut self, + module: &Module, + rhs: &GLWESwitchingKey, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + unsafe { + let self_ptr: *mut GLWECiphertext = self as *mut GLWECiphertext; + self.keyswitch(&module, &*self_ptr, rhs, scratch); + } + } + + pub fn external_product( + &mut self, + module: &Module, + lhs: &GLWECiphertext, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToRef, + MatZnxDft: MatZnxDftToRef, + { + let basek: usize = self.basek(); + + #[cfg(debug_assertions)] + { + assert_eq!(rhs.rank(), lhs.rank()); + assert_eq!(rhs.rank(), self.rank()); + assert_eq!(self.basek(), basek); + assert_eq!(lhs.basek(), basek); + assert_eq!(rhs.n(), module.n()); + assert_eq!(self.n(), module.n()); + assert_eq!(lhs.n(), module.n()); + } + + let cols: usize = rhs.rank() + 1; + + let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); // Todo optimise + + { + let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, lhs.size()); + (0..cols).for_each(|col_i| { + module.vec_znx_dft(&mut a_dft, col_i, lhs, col_i); + }); + module.vmp_apply(&mut res_dft, &a_dft, rhs, scratch2); + } + + let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft); + + (0..cols).for_each(|i| { + module.vec_znx_big_normalize(basek, self, i, &res_big, i, scratch1); + }); + } + + pub fn external_product_inplace( + &mut self, + module: &Module, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + unsafe { + let self_ptr: *mut GLWECiphertext = self as *mut GLWECiphertext; + self.external_product(&module, &*self_ptr, rhs, scratch); + } + } + + pub(crate) fn encrypt_sk_private( + &mut self, + module: &Module, + pt: Option<(&GLWEPlaintext, usize)>, + sk_dft: &SecretKeyFourier, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), 
sk_dft.rank()); + assert_eq!(sk_dft.n(), module.n()); + assert_eq!(self.n(), module.n()); + if let Some((pt, col)) = pt { + assert_eq!(pt.n(), module.n()); + assert!(col < self.rank() + 1); + } + } + + let log_base2k: usize = self.basek(); + let log_k: usize = self.k(); + let size: usize = self.size(); + let cols: usize = self.rank() + 1; + + let (mut c0_big, scratch_1) = scratch.tmp_vec_znx(module, 1, size); + c0_big.zero(); + + { + // c[i] = uniform + // c[0] -= c[i] * s[i], + (1..cols).for_each(|i| { + let (mut ci_dft, scratch_2) = scratch_1.tmp_vec_znx_dft(module, 1, size); + + // c[i] = uniform + self.data.fill_uniform(log_base2k, i, size, source_xa); + + // c[i] = norm(IDFT(DFT(c[i]) * DFT(s[i]))) + module.vec_znx_dft(&mut ci_dft, 0, self, i); + module.svp_apply_inplace(&mut ci_dft, 0, sk_dft, i - 1); + let ci_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(ci_dft); + + // use c[0] as buffer, which is overwritten later by the normalization step + module.vec_znx_big_normalize(log_base2k, self, 0, &ci_big, 0, scratch_2); + + // c0_tmp = -c[i] * s[i] (use c[0] as buffer) + module.vec_znx_sub_ab_inplace(&mut c0_big, 0, self, 0); + + // c[i] += m if col = i + if let Some((pt, col)) = pt { + if i == col { + module.vec_znx_add_inplace(self, i, pt, 0); + module.vec_znx_normalize_inplace(log_base2k, self, i, scratch_2); + } + } + }); + } + + // c[0] += e + c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, sigma * SIX_SIGMA); + + // c[0] += m if col = 0 + if let Some((pt, col)) = pt { + if col == 0 { + module.vec_znx_add_inplace(&mut c0_big, 0, pt, 0); + } + } + + // c[0] = norm(c[0]) + module.vec_znx_normalize(log_base2k, self, 0, &c0_big, 0, scratch_1); + } + + pub(crate) fn encrypt_pk_private( + &mut self, + module: &Module, + pt: Option<(&GLWEPlaintext, usize)>, + pk: &GLWEPublicKey, + source_xu: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToRef, + VecZnxDft: VecZnxDftToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.basek(), pk.basek()); + assert_eq!(self.n(), module.n()); + assert_eq!(pk.n(), module.n()); + assert_eq!(self.rank(), pk.rank()); + if let Some((pt, _)) = pt { + assert_eq!(pt.basek(), pk.basek()); + assert_eq!(pt.n(), module.n()); + } + } + + let log_base2k: usize = pk.basek(); + let size_pk: usize = pk.size(); + let cols: usize = self.rank() + 1; + + // Generates u according to the underlying secret distribution. 
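+ // Informally, the code below computes ct[i] = norm(u * pk[i] + e_i) for every column i, adding the plaintext at the selected column; u is an ephemeral secret drawn from the distribution recorded in pk.dist.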
+ let (mut u_dft, scratch_1) = scratch.tmp_scalar_znx_dft(module, 1); + + { + let (mut u, _) = scratch_1.tmp_scalar_znx(module, 1); + match pk.dist { + SecretDistribution::NONE => panic!( + "invalid public key: SecretDistribution::NONE, ensure it has been correctly initialized through \ + Self::generate" + ), + SecretDistribution::TernaryFixed(hw) => u.fill_ternary_hw(0, hw, source_xu), + SecretDistribution::TernaryProb(prob) => u.fill_ternary_prob(0, prob, source_xu), + SecretDistribution::ZERO => {} + } + + module.svp_prepare(&mut u_dft, 0, &u, 0); + } + + // ct[i] = pk[i] * u + ei (+ m if col = i) + (0..cols).for_each(|i| { + let (mut ci_dft, scratch_2) = scratch_1.tmp_vec_znx_dft(module, 1, size_pk); + // ci_dft = DFT(u) * DFT(pk[i]) + module.svp_apply(&mut ci_dft, 0, &u_dft, 0, pk, i); + + // ci_big = u * pk[i] + let mut ci_big = module.vec_znx_idft_consume(ci_dft); + + // ci_big = u * pk[i] + e + ci_big.add_normal(log_base2k, 0, pk.k(), source_xe, sigma, sigma * SIX_SIGMA); + + // ci_big = u * pk[i] + e + m (if col = i) + if let Some((pt, col)) = pt { + if col == i { + module.vec_znx_big_add_small_inplace(&mut ci_big, 0, pt, 0); + } + } + + // ct[i] = norm(ci_big) + module.vec_znx_big_normalize(log_base2k, self, i, &ci_big, 0, scratch_2); + }); + } +} + +impl GLWECiphertext +where + VecZnx: VecZnxToRef, +{ + pub fn decrypt( + &self, + module: &Module, + pt: &mut GLWEPlaintext, + sk_dft: &SecretKeyFourier, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToMut, + ScalarZnxDft: ScalarZnxDftToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), sk_dft.rank()); + assert_eq!(self.n(), module.n()); + assert_eq!(pt.n(), module.n()); + assert_eq!(sk_dft.n(), module.n()); + } + + let cols: usize = self.rank() + 1; + + let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, self.size()); // TODO optimize size when pt << ct + c0_big.zero(); + + { + (1..cols).for_each(|i| { + // ci_dft = DFT(a[i]) * DFT(s[i]) + let (mut ci_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, self.size()); // TODO optimize size when pt << ct + module.vec_znx_dft(&mut ci_dft, 0, self, i); + module.svp_apply_inplace(&mut ci_dft, 0, sk_dft, i - 1); + let ci_big = module.vec_znx_idft_consume(ci_dft); + + // c0_big += a[i] * s[i] + module.vec_znx_big_add_inplace(&mut c0_big, 0, &ci_big, 0); + }); + } + + // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e) + module.vec_znx_big_add_small_inplace(&mut c0_big, 0, self, 0); + + // pt = norm(BIG(m + e)) + module.vec_znx_big_normalize(self.basek(), pt, 0, &mut c0_big, 0, scratch_1); + + pt.basek = self.basek(); + pt.k = pt.k().min(self.k()); + } +} diff --git a/core/src/glwe_ciphertext_fourier.rs b/core/src/glwe_ciphertext_fourier.rs new file mode 100644 index 0000000..135a2dd --- /dev/null +++ b/core/src/glwe_ciphertext_fourier.rs @@ -0,0 +1,323 @@ +use base2k::{ + Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToRef, Module, ScalarZnxDft, ScalarZnxDftOps, + ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, + VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxZero, +}; +use sampling::source::Source; + +use crate::{ + elem::Infos, ggsw_ciphertext::GGSWCiphertext, glwe_ciphertext::GLWECiphertext, glwe_plaintext::GLWEPlaintext, + keys::SecretKeyFourier, keyswitch_key::GLWESwitchingKey, utils::derive_size, +}; + +pub struct GLWECiphertextFourier { + pub data: VecZnxDft, + pub basek: usize, + pub k: usize, +} + +impl
GLWECiphertextFourier, B> { + pub fn new(module: &Module, basek: usize, k: usize, rank: usize) -> Self { + Self { + data: module.new_vec_znx_dft(rank + 1, derive_size(basek, k)), + basek: basek, + k: k, + } + } +} + +impl Infos for GLWECiphertextFourier { + type Inner = VecZnxDft; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn basek(&self) -> usize { + self.basek + } + + fn k(&self) -> usize { + self.k + } +} + +impl GLWECiphertextFourier { + pub fn rank(&self) -> usize { + self.cols() - 1 + } +} + +impl VecZnxDftToMut for GLWECiphertextFourier +where + VecZnxDft: VecZnxDftToMut, +{ + fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { + self.data.to_mut() + } +} + +impl VecZnxDftToRef for GLWECiphertextFourier +where + VecZnxDft: VecZnxDftToRef, +{ + fn to_ref(&self) -> VecZnxDft<&[u8], B> { + self.data.to_ref() + } +} + +impl GLWECiphertextFourier, FFT64> { + #[allow(dead_code)] + pub(crate) fn idft_scratch_space(module: &Module, size: usize) -> usize { + module.bytes_of_vec_znx(1, size) + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes()) + } + + pub fn encrypt_sk_scratch_space(module: &Module, rank: usize, ct_size: usize) -> usize { + module.bytes_of_vec_znx(rank + 1, ct_size) + GLWECiphertext::encrypt_sk_scratch_space(module, ct_size) + } + + pub fn decrypt_scratch_space(module: &Module, ct_size: usize) -> usize { + (module.vec_znx_big_normalize_tmp_bytes() + | module.bytes_of_vec_znx_dft(1, ct_size) + | (module.bytes_of_vec_znx_big(1, ct_size) + module.vec_znx_idft_tmp_bytes())) + + module.bytes_of_vec_znx_big(1, ct_size) + } + + pub fn keyswitch_scratch_space( + module: &Module, + out_size: usize, + out_rank: usize, + in_size: usize, + in_rank: usize, + ksk_size: usize, + ) -> usize { + module.bytes_of_vec_znx(out_rank + 1, out_size) + + GLWECiphertext::keyswitch_from_fourier_scratch_space(module, out_size, out_rank, in_size, in_rank, ksk_size) + } + + pub fn keyswitch_inplace_scratch_space(module: &Module, out_size: usize, out_rank: usize, ksk_size: usize) -> usize { + Self::keyswitch_scratch_space(module, out_size, out_rank, out_size, out_rank, ksk_size) + } + + pub fn external_product_scratch_space( + module: &Module, + out_size: usize, + in_size: usize, + ggsw_size: usize, + rank: usize, + ) -> usize { + let res_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size); + let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, rank + 1, rank + 1, ggsw_size); + let res_small: usize = module.bytes_of_vec_znx(rank + 1, out_size); + let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(); + + res_dft + (vmp | (res_small + normalize)) + } + + pub fn external_product_inplace_scratch_space( + module: &Module, + out_size: usize, + ggsw_size: usize, + rank: usize, + ) -> usize { + Self::external_product_scratch_space(module, out_size, out_size, ggsw_size, rank) + } +} + +impl GLWECiphertextFourier +where + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, +{ + pub fn encrypt_zero_sk( + &mut self, + module: &Module, + sk_dft: &SecretKeyFourier, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) where + ScalarZnxDft: ScalarZnxDftToRef, + { + let (vec_znx_tmp, scratch_1) = scratch.tmp_vec_znx(module, self.rank() + 1, self.size()); + let mut ct_idft = GLWECiphertext { + data: vec_znx_tmp, + basek: self.basek, + k: self.k, + }; + ct_idft.encrypt_zero_sk(module, sk_dft, source_xa, source_xe, sigma, scratch_1); + + ct_idft.dft(module, self); + } + + pub fn keyswitch( + &mut self, + module: 
&Module, + lhs: &GLWECiphertextFourier, + rhs: &GLWESwitchingKey, + scratch: &mut Scratch, + ) where + VecZnxDft: VecZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + let cols_out: usize = rhs.rank_out() + 1; + + // Space for normalized VMP result outside of DFT domain + let (res_idft_data, scratch1) = scratch.tmp_vec_znx(module, cols_out, lhs.size()); + + let mut res_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { + data: res_idft_data, + basek: lhs.basek, + k: lhs.k, + }; + + res_idft.keyswitch_from_fourier(module, lhs, rhs, scratch1); + + (0..cols_out).for_each(|i| { + module.vec_znx_dft(self, i, &res_idft, i); + }); + } + + pub fn keyswitch_inplace( + &mut self, + module: &Module, + rhs: &GLWESwitchingKey, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + unsafe { + let self_ptr: *mut GLWECiphertextFourier = self as *mut GLWECiphertextFourier; + self.keyswitch(&module, &*self_ptr, rhs, scratch); + } + } + + pub fn external_product( + &mut self, + module: &Module, + lhs: &GLWECiphertextFourier, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) where + VecZnxDft: VecZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + let basek: usize = self.basek(); + + #[cfg(debug_assertions)] + { + assert_eq!(rhs.rank(), lhs.rank()); + assert_eq!(rhs.rank(), self.rank()); + assert_eq!(self.basek(), basek); + assert_eq!(lhs.basek(), basek); + assert_eq!(rhs.n(), module.n()); + assert_eq!(self.n(), module.n()); + assert_eq!(lhs.n(), module.n()); + } + + let cols: usize = rhs.rank() + 1; + + // Space for VMP result in DFT domain and high precision + let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); + + { + module.vmp_apply(&mut res_dft, lhs, rhs, scratch1); + } + + // VMP result in high precision + let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft); + + // Space for VMP result normalized + let (mut res_small, scratch2) = scratch1.tmp_vec_znx(module, cols, rhs.size()); + (0..cols).for_each(|i| { + module.vec_znx_big_normalize(basek, &mut res_small, i, &res_big, i, scratch2); + module.vec_znx_dft(self, i, &res_small, i); + }); + } + + pub fn external_product_inplace( + &mut self, + module: &Module, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + unsafe { + let self_ptr: *mut GLWECiphertextFourier = self as *mut GLWECiphertextFourier; + self.external_product(&module, &*self_ptr, rhs, scratch); + } + } +} + +impl GLWECiphertextFourier +where + VecZnxDft: VecZnxDftToRef, +{ + pub fn decrypt( + &self, + module: &Module, + pt: &mut GLWEPlaintext, + sk_dft: &SecretKeyFourier, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToMut + VecZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), sk_dft.rank()); + assert_eq!(self.n(), module.n()); + assert_eq!(pt.n(), module.n()); + assert_eq!(sk_dft.n(), module.n()); + } + + let cols = self.rank() + 1; + + let (mut pt_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, self.size()); // TODO optimize size when pt << ct + pt_big.zero(); + + { + (1..cols).for_each(|i| { + let (mut ci_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, self.size()); // TODO optimize size when pt << ct + module.svp_apply(&mut ci_dft, 0, sk_dft, i - 1, self, i); + let ci_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(ci_dft); + module.vec_znx_big_add_inplace(&mut pt_big, 0, &ci_big, 0); + }); + } + + { + let (mut c0_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1,
self.size()); + // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e) + module.vec_znx_idft(&mut c0_big, 0, self, 0, scratch_2); + module.vec_znx_big_add_inplace(&mut pt_big, 0, &c0_big, 0); + } + + // pt = norm(BIG(m + e)) + module.vec_znx_big_normalize(self.basek(), pt, 0, &mut pt_big, 0, scratch_1); + + pt.basek = self.basek(); + pt.k = pt.k().min(self.k()); + } + + #[allow(dead_code)] + pub(crate) fn idft(&self, module: &Module, res: &mut GLWECiphertext, scratch: &mut Scratch) + where + GLWECiphertext: VecZnxToMut, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), res.rank()); + assert_eq!(self.basek(), res.basek()) + } + + let min_size: usize = self.size().min(res.size()); + + let (mut res_big, scratch1) = scratch.tmp_vec_znx_big(module, 1, min_size); + + (0..self.rank() + 1).for_each(|i| { + module.vec_znx_idft(&mut res_big, 0, self, i, scratch1); + module.vec_znx_big_normalize(self.basek(), res, i, &res_big, 0, scratch1); + }); + } +} diff --git a/core/src/glwe_plaintext.rs b/core/src/glwe_plaintext.rs new file mode 100644 index 0000000..4900fa0 --- /dev/null +++ b/core/src/glwe_plaintext.rs @@ -0,0 +1,53 @@ +use base2k::{Backend, Module, VecZnx, VecZnxAlloc, VecZnxToMut, VecZnxToRef}; + +use crate::{elem::Infos, utils::derive_size}; + +pub struct GLWEPlaintext { + pub data: VecZnx, + pub basek: usize, + pub k: usize, +} + +impl Infos for GLWEPlaintext { + type Inner = VecZnx; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn basek(&self) -> usize { + self.basek + } + + fn k(&self) -> usize { + self.k + } +} + +impl VecZnxToMut for GLWEPlaintext +where + VecZnx: VecZnxToMut, +{ + fn to_mut(&mut self) -> VecZnx<&mut [u8]> { + self.data.to_mut() + } +} + +impl VecZnxToRef for GLWEPlaintext +where + VecZnx: VecZnxToRef, +{ + fn to_ref(&self) -> VecZnx<&[u8]> { + self.data.to_ref() + } +} + +impl GLWEPlaintext> { + pub fn new(module: &Module, basek: usize, k: usize) -> Self { + Self { + data: module.new_vec_znx(1, derive_size(basek, k)), + basek: basek, + k, + } + } +} diff --git a/core/src/keys.rs b/core/src/keys.rs new file mode 100644 index 0000000..8a4d5e1 --- /dev/null +++ b/core/src/keys.rs @@ -0,0 +1,247 @@ +use base2k::{ + Backend, FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxDftToMut, + ScalarZnxDftToRef, ScalarZnxToMut, ScalarZnxToRef, ScratchOwned, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, ZnxInfos, + ZnxZero, +}; +use sampling::source::Source; + +use crate::{elem::Infos, glwe_ciphertext_fourier::GLWECiphertextFourier}; + +#[derive(Clone, Copy, Debug)] +pub enum SecretDistribution { + TernaryFixed(usize), // Ternary with fixed Hamming weight + TernaryProb(f64), // Ternary with probabilistic Hamming weight + ZERO, // Debug mode + NONE, +} + +pub struct SecretKey { + pub data: ScalarZnx, + pub dist: SecretDistribution, +} + +impl SecretKey> { + pub fn new(module: &Module, rank: usize) -> Self { + Self { + data: module.new_scalar_znx(rank), + dist: SecretDistribution::NONE, + } + } +} + +impl SecretKey { + pub fn n(&self) -> usize { + self.data.n() + } + + pub fn log_n(&self) -> usize { + self.data.log_n() + } + + pub fn rank(&self) -> usize { + self.data.cols() + } +} + +impl SecretKey +where + S: AsMut<[u8]> + AsRef<[u8]>, +{ + pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) { + (0..self.rank()).for_each(|i| { + self.data.fill_ternary_prob(i, prob, source); + }); + self.dist = SecretDistribution::TernaryProb(prob); + } + + pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut
Source) { + (0..self.rank()).for_each(|i| { + self.data.fill_ternary_hw(i, hw, source); + }); + self.dist = SecretDistribution::TernaryFixed(hw); + } + + pub fn fill_zero(&mut self) { + self.data.zero(); + self.dist = SecretDistribution::ZERO; + } +} + +impl ScalarZnxToMut for SecretKey +where + ScalarZnx: ScalarZnxToMut, +{ + fn to_mut(&mut self) -> ScalarZnx<&mut [u8]> { + self.data.to_mut() + } +} + +impl ScalarZnxToRef for SecretKey +where + ScalarZnx: ScalarZnxToRef, +{ + fn to_ref(&self) -> ScalarZnx<&[u8]> { + self.data.to_ref() + } +} + +pub struct SecretKeyFourier { + pub data: ScalarZnxDft, + pub dist: SecretDistribution, +} + +impl SecretKeyFourier { + pub fn n(&self) -> usize { + self.data.n() + } + + pub fn log_n(&self) -> usize { + self.data.log_n() + } + + pub fn rank(&self) -> usize { + self.data.cols() + } +} + +impl SecretKeyFourier, B> { + pub fn new(module: &Module, rank: usize) -> Self { + Self { + data: module.new_scalar_znx_dft(rank), + dist: SecretDistribution::NONE, + } + } + + pub fn dft(&mut self, module: &Module, sk: &SecretKey) + where + SecretKeyFourier, B>: ScalarZnxDftToMut, + SecretKey: ScalarZnxToRef, + { + #[cfg(debug_assertions)] + { + match sk.dist { + SecretDistribution::NONE => panic!("invalid sk: SecretDistribution::NONE"), + _ => {} + } + + assert_eq!(self.n(), module.n()); + assert_eq!(sk.n(), module.n()); + assert_eq!(self.rank(), sk.rank()); + } + + (0..self.rank()).for_each(|i| { + module.svp_prepare(self, i, sk, i); + }); + self.dist = sk.dist; + } +} + +impl ScalarZnxDftToMut for SecretKeyFourier +where + ScalarZnxDft: ScalarZnxDftToMut, +{ + fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B> { + self.data.to_mut() + } +} + +impl ScalarZnxDftToRef for SecretKeyFourier +where + ScalarZnxDft: ScalarZnxDftToRef, +{ + fn to_ref(&self) -> ScalarZnxDft<&[u8], B> { + self.data.to_ref() + } +} + +pub struct GLWEPublicKey { + pub data: GLWECiphertextFourier, + pub dist: SecretDistribution, +} + +impl GLWEPublicKey, B> { + pub fn new(module: &Module, log_base2k: usize, log_k: usize, rank: usize) -> Self { + Self { + data: GLWECiphertextFourier::new(module, log_base2k, log_k, rank), + dist: SecretDistribution::NONE, + } + } +} + +impl Infos for GLWEPublicKey { + type Inner = VecZnxDft; + + fn inner(&self) -> &Self::Inner { + &self.data.data + } + + fn basek(&self) -> usize { + self.data.basek + } + + fn k(&self) -> usize { + self.data.k + } +} + +impl GLWEPublicKey { + pub fn rank(&self) -> usize { + self.cols() - 1 + } +} + +impl VecZnxDftToMut for GLWEPublicKey +where + VecZnxDft: VecZnxDftToMut, +{ + fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { + self.data.to_mut() + } +} + +impl VecZnxDftToRef for GLWEPublicKey +where + VecZnxDft: VecZnxDftToRef, +{ + fn to_ref(&self) -> VecZnxDft<&[u8], B> { + self.data.to_ref() + } +} + +impl GLWEPublicKey { + pub fn generate( + &mut self, + module: &Module, + sk_dft: &SecretKeyFourier, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + ) where + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, + ScalarZnxDft: ScalarZnxDftToRef + ZnxInfos, + { + #[cfg(debug_assertions)] + { + match sk_dft.dist { + SecretDistribution::NONE => panic!("invalid sk_dft: SecretDistribution::NONE"), + _ => {} + } + } + + // It's OK to allocate scratch space here since pk is usually generated only once. 
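+        // (Editor's note: key generation is off the hot path, so a one-shot
+        // ScratchOwned allocation is acceptable here; steady-state operations
+        // instead take a caller-provided &mut Scratch to avoid per-call
+        // allocations.)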
+ let mut scratch: ScratchOwned = ScratchOwned::new(GLWECiphertextFourier::encrypt_sk_scratch_space( + module, + self.rank(), + self.size(), + )); + self.data.encrypt_zero_sk( + module, + sk_dft, + source_xa, + source_xe, + sigma, + scratch.borrow(), + ); + self.dist = sk_dft.dist; + } +} diff --git a/core/src/keyswitch_key.rs b/core/src/keyswitch_key.rs new file mode 100644 index 0000000..cade469 --- /dev/null +++ b/core/src/keyswitch_key.rs @@ -0,0 +1,385 @@ +use base2k::{ + Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, ScalarZnxDftToRef, + ScalarZnxToRef, Scratch, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef, ZnxZero, +}; +use sampling::source::Source; + +use crate::{ + elem::{GetRow, Infos, SetRow}, + gglwe_ciphertext::GGLWECiphertext, + ggsw_ciphertext::GGSWCiphertext, + glwe_ciphertext_fourier::GLWECiphertextFourier, + keys::{SecretKey, SecretKeyFourier}, +}; + +pub struct GLWESwitchingKey(pub(crate) GGLWECiphertext); + +impl GLWESwitchingKey, FFT64> { + pub fn new(module: &Module, basek: usize, k: usize, rows: usize, rank_in: usize, rank_out: usize) -> Self { + GLWESwitchingKey(GGLWECiphertext::new( + module, basek, k, rows, rank_in, rank_out, + )) + } +} + +impl Infos for GLWESwitchingKey { + type Inner = MatZnxDft; + + fn inner(&self) -> &Self::Inner { + self.0.inner() + } + + fn basek(&self) -> usize { + self.0.basek() + } + + fn k(&self) -> usize { + self.0.k() + } +} + +impl GLWESwitchingKey { + pub fn rank(&self) -> usize { + self.0.data.cols_out() - 1 + } + + pub fn rank_in(&self) -> usize { + self.0.data.cols_in() + } + + pub fn rank_out(&self) -> usize { + self.0.data.cols_out() - 1 + } +} + +impl MatZnxDftToMut for GLWESwitchingKey +where + MatZnxDft: MatZnxDftToMut, +{ + fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { + self.0.data.to_mut() + } +} + +impl MatZnxDftToRef for GLWESwitchingKey +where + MatZnxDft: MatZnxDftToRef, +{ + fn to_ref(&self) -> MatZnxDft<&[u8], B> { + self.0.data.to_ref() + } +} + +impl GetRow for GLWESwitchingKey +where + MatZnxDft: MatZnxDftToRef, +{ + fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut R) + where + R: VecZnxDftToMut, + { + module.vmp_extract_row(res, self, row_i, col_j); + } +} + +impl SetRow for GLWESwitchingKey +where + MatZnxDft: MatZnxDftToMut, +{ + fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &R) + where + R: VecZnxDftToRef, + { + module.vmp_prepare_row(self, row_i, col_j, a); + } +} + +impl GLWESwitchingKey, FFT64> { + pub fn encrypt_sk_scratch_space(module: &Module, rank: usize, size: usize) -> usize { + GGLWECiphertext::encrypt_sk_scratch_space(module, rank, size) + } + + pub fn encrypt_pk_scratch_space(module: &Module, rank: usize, pk_size: usize) -> usize { + GGLWECiphertext::encrypt_pk_scratch_space(module, rank, pk_size) + } + + pub fn keyswitch_scratch_space( + module: &Module, + out_size: usize, + out_rank: usize, + in_size: usize, + in_rank: usize, + ksk_size: usize, + ) -> usize { + let tmp_in: usize = module.bytes_of_vec_znx_dft(in_rank + 1, in_size); + let tmp_out: usize = module.bytes_of_vec_znx_dft(out_rank + 1, out_size); + let ksk: usize = GLWECiphertextFourier::keyswitch_scratch_space(module, out_size, out_rank, in_size, in_rank, ksk_size); + tmp_in + tmp_out + ksk + } + + pub fn keyswitch_inplace_scratch_space(module: &Module, out_size: usize, out_rank: usize, ksk_size: usize) -> usize { + let tmp: usize = module.bytes_of_vec_znx_dft(out_rank + 1, out_size); + let ksk: usize = 
GLWECiphertextFourier::keyswitch_inplace_scratch_space(module, out_size, out_rank, ksk_size); + tmp + ksk + } + + pub fn external_product_scratch_space( + module: &Module, + out_size: usize, + in_size: usize, + ggsw_size: usize, + rank: usize, + ) -> usize { + let tmp_in: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size); + let tmp_out: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size); + let ggsw: usize = GLWECiphertextFourier::external_product_scratch_space(module, out_size, in_size, ggsw_size, rank); + tmp_in + tmp_out + ggsw + } + + pub fn external_product_inplace_scratch_space( + module: &Module, + out_size: usize, + ggsw_size: usize, + rank: usize, + ) -> usize { + let tmp: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size); + let ggsw: usize = GLWECiphertextFourier::external_product_inplace_scratch_space(module, out_size, ggsw_size, rank); + tmp + ggsw + } +} +impl GLWESwitchingKey +where + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, +{ + pub fn encrypt_sk( + &mut self, + module: &Module, + sk_in: &SecretKey, + sk_out_dft: &SecretKeyFourier, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) where + ScalarZnx: ScalarZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, + { + self.0.encrypt_sk( + module, + &sk_in.data, + sk_out_dft, + source_xa, + source_xe, + sigma, + scratch, + ); + } + + pub fn keyswitch( + &mut self, + module: &Module, + lhs: &GLWESwitchingKey, + rhs: &GLWESwitchingKey, + scratch: &mut base2k::Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!( + self.rank_in(), + lhs.rank_in(), + "ksk_out input rank: {} != ksk_in input rank: {}", + self.rank_in(), + lhs.rank_in() + ); + assert_eq!( + lhs.rank_out(), + rhs.rank_in(), + "ksk_in output rank: {} != ksk_apply input rank: {}", + self.rank_out(), + rhs.rank_in() + ); + assert_eq!( + self.rank_out(), + rhs.rank_out(), + "ksk_out output rank: {} != ksk_apply output rank: {}", + self.rank_out(), + rhs.rank_out() + ); + } + + let (tmp_in_data, scratch1) = scratch.tmp_vec_znx_dft(module, lhs.rank_out() + 1, lhs.size()); + + let mut tmp_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_in_data, + basek: lhs.basek(), + k: lhs.k(), + }; + + let (tmp_out_data, scratch2) = scratch1.tmp_vec_znx_dft(module, self.rank_out() + 1, self.size()); + + let mut tmp_out: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_out_data, + basek: self.basek(), + k: self.k(), + }; + + (0..self.rank_in()).for_each(|col_i| { + (0..self.rows()).for_each(|row_j| { + lhs.get_row(module, row_j, col_i, &mut tmp_in); + tmp_out.keyswitch(module, &tmp_in, rhs, scratch2); + self.set_row(module, row_j, col_i, &tmp_out); + }); + }); + + tmp_out.data.zero(); + + (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| { + (0..self.rank_in()).for_each(|col_j| { + self.set_row(module, row_i, col_j, &tmp_out); + }); + }); + } + + pub fn keyswitch_inplace( + &mut self, + module: &Module, + rhs: &GLWESwitchingKey, + scratch: &mut base2k::Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!( + self.rank_out(), + rhs.rank_out(), + "ksk_out output rank: {} != ksk_apply output rank: {}", + self.rank_out(), + rhs.rank_out() + ); + } + + let (tmp_data, scratch1) = scratch.tmp_vec_znx_dft(module, self.rank_out() + 1, self.size()); + + let mut tmp: GLWECiphertextFourier<&mut [u8], FFT64> = 
GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_data, + basek: self.basek(), + k: self.k(), + }; + + (0..self.rank_in()).for_each(|col_i| { + (0..self.rows()).for_each(|row_j| { + self.get_row(module, row_j, col_i, &mut tmp); + tmp.keyswitch_inplace(module, rhs, scratch1); + self.set_row(module, row_j, col_i, &tmp); + }); + }); + } + + pub fn external_product( + &mut self, + module: &Module, + lhs: &GLWESwitchingKey, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!( + self.rank_in(), + lhs.rank_in(), + "ksk_out input rank: {} != ksk_in input rank: {}", + self.rank_in(), + lhs.rank_in() + ); + assert_eq!( + lhs.rank_out(), + rhs.rank(), + "ksk_in output rank: {} != ggsw rank: {}", + self.rank_out(), + rhs.rank() + ); + assert_eq!( + self.rank_out(), + rhs.rank(), + "ksk_out output rank: {} != ggsw rank: {}", + self.rank_out(), + rhs.rank() + ); + } + + let (tmp_in_data, scratch1) = scratch.tmp_vec_znx_dft(module, lhs.rank_out() + 1, lhs.size()); + + let mut tmp_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_in_data, + basek: lhs.basek(), + k: lhs.k(), + }; + + let (tmp_out_data, scratch2) = scratch1.tmp_vec_znx_dft(module, self.rank_out() + 1, self.size()); + + let mut tmp_out: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_out_data, + basek: self.basek(), + k: self.k(), + }; + + (0..self.rank_in()).for_each(|col_i| { + (0..self.rows()).for_each(|row_j| { + lhs.get_row(module, row_j, col_i, &mut tmp_in); + tmp_out.external_product(module, &tmp_in, rhs, scratch2); + self.set_row(module, row_j, col_i, &tmp_out); + }); + }); + + tmp_out.data.zero(); + + (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| { + (0..self.rank_in()).for_each(|col_j| { + self.set_row(module, row_i, col_j, &tmp_out); + }); + }); + } + + pub fn external_product_inplace( + &mut self, + module: &Module, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!( + self.rank_out(), + rhs.rank(), + "ksk_out output rank: {} != ggsw rank: {}", + self.rank_out(), + rhs.rank() + ); + } + + let (tmp_data, scratch1) = scratch.tmp_vec_znx_dft(module, self.rank_out() + 1, self.size()); + + let mut tmp: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_data, + basek: self.basek(), + k: self.k(), + }; + + (0..self.rank_in()).for_each(|col_i| { + (0..self.rows()).for_each(|row_j| { + self.get_row(module, row_j, col_i, &mut tmp); + tmp.external_product_inplace(module, rhs, scratch1); + self.set_row(module, row_j, col_i, &tmp); + }); + }); + } +} diff --git a/core/src/lib.rs b/core/src/lib.rs new file mode 100644 index 0000000..74ed7ef --- /dev/null +++ b/core/src/lib.rs @@ -0,0 +1,15 @@ +pub mod automorphism; +pub mod elem; +pub mod gglwe_ciphertext; +pub mod ggsw_ciphertext; +pub mod glwe_ciphertext; +pub mod glwe_ciphertext_fourier; +pub mod glwe_plaintext; +pub mod keys; +pub mod keyswitch_key; +pub mod tensor_key; +#[cfg(test)] +mod test_fft64; +mod utils; + +pub(crate) const SIX_SIGMA: f64 = 6.0; diff --git a/core/src/tensor_key.rs b/core/src/tensor_key.rs new file mode 100644 index 0000000..158274d --- /dev/null +++ b/core/src/tensor_key.rs @@ -0,0 +1,130 @@ +use base2k::{ + Backend, FFT64, MatZnxDft, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, 
ScalarZnxDftAlloc, + ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnxDftOps, VecZnxDftToRef, +}; +use sampling::source::Source; + +use crate::{ + elem::Infos, + keys::{SecretKey, SecretKeyFourier}, + keyswitch_key::GLWESwitchingKey, +}; + +pub struct TensorKey { + pub(crate) keys: Vec>, +} + +impl TensorKey, FFT64> { + pub fn new(module: &Module, basek: usize, k: usize, rows: usize, rank: usize) -> Self { + let mut keys: Vec, FFT64>> = Vec::new(); + let pairs: usize = ((rank + 1) * rank) >> 1; + (0..pairs).for_each(|_| { + keys.push(GLWESwitchingKey::new(module, basek, k, rows, 1, rank)); + }); + Self { keys } + } +} + +impl Infos for TensorKey { + type Inner = MatZnxDft; + + fn inner(&self) -> &Self::Inner { + self.keys[0].inner() + } + + fn basek(&self) -> usize { + self.keys[0].basek() + } + + fn k(&self) -> usize { + self.keys[0].k() + } +} + +impl TensorKey { + pub fn rank(&self) -> usize { + self.keys[0].rank() + } + + pub fn rank_in(&self) -> usize { + self.keys[0].rank_in() + } + + pub fn rank_out(&self) -> usize { + self.keys[0].rank_out() + } +} + +impl TensorKey, FFT64> { + pub fn encrypt_sk_scratch_space(module: &Module, rank: usize, size: usize) -> usize { + module.bytes_of_scalar_znx_dft(1) + GLWESwitchingKey::encrypt_sk_scratch_space(module, rank, size) + } +} + +impl TensorKey +where + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, +{ + pub fn encrypt_sk( + &mut self, + module: &Module, + sk_dft: &SecretKeyFourier, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) where + ScalarZnxDft: VecZnxDftToRef + ScalarZnxDftToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), sk_dft.rank()); + assert_eq!(self.n(), module.n()); + assert_eq!(sk_dft.n(), module.n()); + } + + let rank: usize = self.rank(); + + (0..rank).for_each(|i| { + (i..rank).for_each(|j| { + let (mut sk_ij_dft, scratch1) = scratch.tmp_scalar_znx_dft(module, 1); + module.svp_apply(&mut sk_ij_dft, 0, &sk_dft.data, i, &sk_dft.data, j); + let sk_ij: ScalarZnx<&mut [u8]> = module + .vec_znx_idft_consume(sk_ij_dft.as_vec_znx_dft()) + .to_vec_znx_small() + .to_scalar_znx(); + let sk_ij: SecretKey<&mut [u8]> = SecretKey { + data: sk_ij, + dist: sk_dft.dist, + }; + + self.at_mut(i, j).encrypt_sk( + module, &sk_ij, sk_dft, source_xa, source_xe, sigma, scratch1, + ); + }); + }); + } + + // Returns a mutable reference to GLWESwitchingKey_{s}(s[i] * s[j]) + pub fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GLWESwitchingKey { + if i > j { + std::mem::swap(&mut i, &mut j); + } + let rank: usize = self.rank(); + &mut self.keys[i * rank + j - (i * (i + 1) / 2)] + } +} + +impl TensorKey +where + MatZnxDft: MatZnxDftToRef, +{ + // Returns a reference to GLWESwitchingKey_{s}(s[i] * s[j]) + pub fn at(&self, mut i: usize, mut j: usize) -> &GLWESwitchingKey { + if i > j { + std::mem::swap(&mut i, &mut j); + } + let rank: usize = self.rank(); + &self.keys[i * rank + j - (i * (i + 1) / 2)] + } +} diff --git a/core/src/test_fft64/automorphism_key.rs b/core/src/test_fft64/automorphism_key.rs new file mode 100644 index 0000000..ea63550 --- /dev/null +++ b/core/src/test_fft64/automorphism_key.rs @@ -0,0 +1,216 @@ +use base2k::{FFT64, Module, ScalarZnxOps, ScratchOwned, Stats, VecZnxOps}; +use sampling::source::Source; + +use crate::{ + automorphism::AutomorphismKey, + elem::{GetRow, Infos}, + glwe_ciphertext_fourier::GLWECiphertextFourier, + glwe_plaintext::GLWEPlaintext, + keys::{SecretKey, SecretKeyFourier}, + test_fft64::gglwe::log2_std_noise_gglwe_product, +}; + 
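+// Editor's note: the tests below exercise keys for the Galois automorphism
+// X -> X^p on Z[X]/(X^N + 1). As a plain-coefficient reference for what such
+// a map computes, here is a minimal standalone sketch (hypothetical helper,
+// not part of the base2k API; assumes N is a power of two and p is odd, hence
+// a unit modulo 2N):
+#[allow(dead_code)]
+fn automorphism_naive(p: i64, a: &[i64]) -> Vec<i64> {
+    let n = a.len();
+    let p = p.rem_euclid(2 * n as i64) as usize; // reduce the exponent mod 2N
+    let mut out = vec![0i64; n];
+    for (i, &c) in a.iter().enumerate() {
+        let j = (i * p) % (2 * n); // X^i maps to X^{i*p mod 2N}
+        if j < n {
+            out[j] = c; // lands on X^j
+        } else {
+            out[j - n] = -c; // X^j = -X^{j-N} because X^N = -1
+        }
+    }
+    out
+}
+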
+#[test] +fn automorphism() { + (1..4).for_each(|rank| { + println!("test automorphism rank: {}", rank); + test_automorphism(-1, 5, 12, 12, 60, 3.2, rank); + }); +} + +#[test] +fn automorphism_inplace() { + (1..4).for_each(|rank| { + println!("test automorphism_inplace rank: {}", rank); + test_automorphism_inplace(-1, 5, 12, 12, 60, 3.2, rank); + }); +} + +fn test_automorphism(p0: i64, p1: i64, log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank: usize) { + let module: Module = Module::::new(1 << log_n); + let rows = (k_ksk + basek - 1) / basek; + + let mut auto_key_in: AutomorphismKey, FFT64> = AutomorphismKey::new(&module, basek, k_ksk, rows, rank); + let mut auto_key_out: AutomorphismKey, FFT64> = AutomorphismKey::new(&module, basek, k_ksk, rows, rank); + let mut auto_key_apply: AutomorphismKey, FFT64> = AutomorphismKey::new(&module, basek, k_ksk, rows, rank); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + AutomorphismKey::encrypt_sk_scratch_space(&module, rank, auto_key_in.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, auto_key_out.size()) + | AutomorphismKey::automorphism_scratch_space( + &module, + auto_key_out.size(), + auto_key_in.size(), + auto_key_apply.size(), + rank, + ), + ); + + let mut sk: SecretKey> = SecretKey::new(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_dft.dft(&module, &sk); + + // gglwe_{s1}(s0) = s0 -> s1 + auto_key_in.encrypt_sk( + &module, + p0, + &sk, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + // gglwe_{s2}(s1) -> s1 -> s2 + auto_key_apply.encrypt_sk( + &module, + p1, + &sk, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) + auto_key_out.automorphism(&module, &auto_key_in, &auto_key_apply, scratch.borrow()); + + let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ksk, rank); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ksk); + + let mut sk_auto: SecretKey> = SecretKey::new(&module, rank); + sk_auto.fill_zero(); // Necessary to avoid panic of unfilled sk + (0..rank).for_each(|i| { + module.scalar_znx_automorphism(module.galois_element_inv(p0 * p1), &mut sk_auto, i, &sk, i); + }); + + let mut sk_auto_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_auto_dft.dft(&module, &sk_auto); + + (0..auto_key_out.rank_in()).for_each(|col_i| { + (0..auto_key_out.rows()).for_each(|row_i| { + auto_key_out.get_row(&module, row_i, col_i, &mut ct_glwe_dft); + + ct_glwe_dft.decrypt(&module, &mut pt, &sk_auto_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk, col_i); + + let noise_have: f64 = pt.data.std(0, basek).log2(); + let noise_want: f64 = log2_std_noise_gglwe_product( + module.n() as f64, + basek, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + rank as f64, + k_ksk, + k_ksk, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + }); + }); +} + +fn test_automorphism_inplace(p0: i64, p1: i64, log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank: usize) { + let module: Module = Module::::new(1 << log_n); + let rows = (k_ksk + basek - 1) / basek; + + let mut auto_key: AutomorphismKey, 
FFT64> = AutomorphismKey::new(&module, basek, k_ksk, rows, rank); + let mut auto_key_apply: AutomorphismKey, FFT64> = AutomorphismKey::new(&module, basek, k_ksk, rows, rank); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + AutomorphismKey::encrypt_sk_scratch_space(&module, rank, auto_key.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, auto_key.size()) + | AutomorphismKey::automorphism_inplace_scratch_space(&module, auto_key.size(), auto_key_apply.size(), rank), + ); + + let mut sk: SecretKey> = SecretKey::new(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_dft.dft(&module, &sk); + + // gglwe_{s1}(s0) = s0 -> s1 + auto_key.encrypt_sk( + &module, + p0, + &sk, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + // gglwe_{s2}(s1) -> s1 -> s2 + auto_key_apply.encrypt_sk( + &module, + p1, + &sk, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) + auto_key.automorphism_inplace(&module, &auto_key_apply, scratch.borrow()); + + let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ksk, rank); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ksk); + + let mut sk_auto: SecretKey> = SecretKey::new(&module, rank); + sk_auto.fill_zero(); // Necessary to avoid panic of unfilled sk + (0..rank).for_each(|i| { + module.scalar_znx_automorphism(module.galois_element_inv(p0 * p1), &mut sk_auto, i, &sk, i); + }); + + let mut sk_auto_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_auto_dft.dft(&module, &sk_auto); + + (0..auto_key.rank_in()).for_each(|col_i| { + (0..auto_key.rows()).for_each(|row_i| { + auto_key.get_row(&module, row_i, col_i, &mut ct_glwe_dft); + + ct_glwe_dft.decrypt(&module, &mut pt, &sk_auto_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk, col_i); + + let noise_have: f64 = pt.data.std(0, basek).log2(); + let noise_want: f64 = log2_std_noise_gglwe_product( + module.n() as f64, + basek, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + rank as f64, + k_ksk, + k_ksk, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + }); + }); +} diff --git a/core/src/test_fft64/gglwe.rs b/core/src/test_fft64/gglwe.rs new file mode 100644 index 0000000..d497dbf --- /dev/null +++ b/core/src/test_fft64/gglwe.rs @@ -0,0 +1,630 @@ +use base2k::{FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxToMut, ScratchOwned, Stats, VecZnxOps, ZnxViewMut}; +use sampling::source::Source; + +use crate::{ + elem::{GetRow, Infos}, + ggsw_ciphertext::GGSWCiphertext, + glwe_ciphertext_fourier::GLWECiphertextFourier, + glwe_plaintext::GLWEPlaintext, + keys::{SecretKey, SecretKeyFourier}, + keyswitch_key::GLWESwitchingKey, + test_fft64::ggsw::noise_ggsw_product, +}; + +#[test] +fn encrypt_sk() { + (1..4).for_each(|rank_in| { + (1..4).for_each(|rank_out| { + println!("test encrypt_sk rank_in rank_out: {} {}", rank_in, rank_out); + test_encrypt_sk(12, 8, 54, 3.2, rank_in, rank_out); + }); + }); +} + +#[test] +fn key_switch() { + (1..4).for_each(|rank_in_s0s1| { + (1..4).for_each(|rank_out_s0s1| { + (1..4).for_each(|rank_out_s1s2| { + println!( + "test key_switch : ({},{},{})", 
+ rank_in_s0s1, rank_out_s0s1, rank_out_s1s2 + ); + test_key_switch(12, 15, 60, 3.2, rank_in_s0s1, rank_out_s0s1, rank_out_s1s2); + }) + }); + }); +} + +#[test] +fn key_switch_inplace() { + (1..4).for_each(|rank_in_s0s1| { + (1..4).for_each(|rank_out_s0s1| { + println!( + "test key_switch_inplace : ({},{})", + rank_in_s0s1, rank_out_s0s1 + ); + test_key_switch_inplace(12, 15, 60, 3.2, rank_in_s0s1, rank_out_s0s1); + }); + }); +} + +#[test] +fn external_product() { + (1..4).for_each(|rank_in| { + (1..4).for_each(|rank_out| { + println!("test external_product rank: {} {}", rank_in, rank_out); + test_external_product(12, 12, 60, 3.2, rank_in, rank_out); + }); + }); +} + +#[test] +fn external_product_inplace() { + (1..4).for_each(|rank_in| { + (1..4).for_each(|rank_out| { + println!( + "test external_product_inplace rank: {} {}", + rank_in, rank_out + ); + test_external_product_inplace(12, 12, 60, 3.2, rank_in, rank_out); + }); + }); +} + +fn test_encrypt_sk(log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank_in: usize, rank_out: usize) { + let module: Module = Module::::new(1 << log_n); + let rows = (k_ksk + basek - 1) / basek; + + let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank_in, rank_out); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ksk); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank_out, ksk.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ksk.size()), + ); + + let mut sk_in: SecretKey> = SecretKey::new(&module, rank_in); + sk_in.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_in_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank_in); + sk_in_dft.dft(&module, &sk_in); + + let mut sk_out: SecretKey> = SecretKey::new(&module, rank_out); + sk_out.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_out_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank_out); + sk_out_dft.dft(&module, &sk_out); + + ksk.encrypt_sk( + &module, + &sk_in, + &sk_out_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ksk, rank_out); + + (0..ksk.rank_in()).for_each(|col_i| { + (0..ksk.rows()).for_each(|row_i| { + ksk.get_row(&module, row_i, col_i, &mut ct_glwe_fourier); + ct_glwe_fourier.decrypt(&module, &mut pt, &sk_out_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk_in, col_i); + let std_pt: f64 = pt.data.std(0, basek) * (k_ksk as f64).exp2(); + assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); + }); + }); +} + +fn test_key_switch( + log_n: usize, + basek: usize, + k_ksk: usize, + sigma: f64, + rank_in_s0s1: usize, + rank_out_s0s1: usize, + rank_out_s1s2: usize, +) { + let module: Module = Module::::new(1 << log_n); + let rows = (k_ksk + basek - 1) / basek; + + let mut ct_gglwe_s0s1: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank_in_s0s1, rank_out_s0s1); + let mut ct_gglwe_s1s2: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank_out_s0s1, rank_out_s1s2); + let mut ct_gglwe_s0s2: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank_in_s0s1, rank_out_s1s2); + + let mut 
source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank_in_s0s1 | rank_out_s0s1, ct_gglwe_s0s1.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_gglwe_s0s2.size()) + | GLWESwitchingKey::keyswitch_scratch_space( + &module, + ct_gglwe_s0s2.size(), + ct_gglwe_s0s2.rank(), + ct_gglwe_s0s1.size(), + ct_gglwe_s0s1.rank(), + ct_gglwe_s1s2.size(), + ), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module, rank_in_s0s1); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank_in_s0s1); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module, rank_out_s0s1); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank_out_s0s1); + sk1_dft.dft(&module, &sk1); + + let mut sk2: SecretKey> = SecretKey::new(&module, rank_out_s1s2); + sk2.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk2_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank_out_s1s2); + sk2_dft.dft(&module, &sk2); + + // gglwe_{s1}(s0) = s0 -> s1 + ct_gglwe_s0s1.encrypt_sk( + &module, + &sk0, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + // gglwe_{s2}(s1) -> s1 -> s2 + ct_gglwe_s1s2.encrypt_sk( + &module, + &sk1, + &sk2_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) + ct_gglwe_s0s2.keyswitch(&module, &ct_gglwe_s0s1, &ct_gglwe_s1s2, scratch.borrow()); + + let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ksk, rank_out_s1s2); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ksk); + + (0..ct_gglwe_s0s2.rank_in()).for_each(|col_i| { + (0..ct_gglwe_s0s2.rows()).for_each(|row_i| { + ct_gglwe_s0s2.get_row(&module, row_i, col_i, &mut ct_glwe_dft); + ct_glwe_dft.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, col_i); + + let noise_have: f64 = pt.data.std(0, basek).log2(); + let noise_want: f64 = log2_std_noise_gglwe_product( + module.n() as f64, + basek, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + rank_out_s0s1 as f64, + k_ksk, + k_ksk, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + }); + }); +} + +fn test_key_switch_inplace(log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank_in_s0s1: usize, rank_out_s0s1: usize) { + let module: Module = Module::::new(1 << log_n); + let rows: usize = (k_ksk + basek - 1) / basek; + + let mut ct_gglwe_s0s1: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank_in_s0s1, rank_out_s0s1); + let mut ct_gglwe_s1s2: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank_out_s0s1, rank_out_s0s1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank_out_s0s1, ct_gglwe_s0s1.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_gglwe_s0s1.size()) + | 
GLWESwitchingKey::keyswitch_inplace_scratch_space( + &module, + ct_gglwe_s0s1.size(), + ct_gglwe_s0s1.rank(), + ct_gglwe_s1s2.size(), + ), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module, rank_in_s0s1); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank_in_s0s1); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module, rank_out_s0s1); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank_out_s0s1); + sk1_dft.dft(&module, &sk1); + + let mut sk2: SecretKey> = SecretKey::new(&module, rank_out_s0s1); + sk2.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk2_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank_out_s0s1); + sk2_dft.dft(&module, &sk2); + + // gglwe_{s1}(s0) = s0 -> s1 + ct_gglwe_s0s1.encrypt_sk( + &module, + &sk0, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + // gglwe_{s2}(s1) -> s1 -> s2 + ct_gglwe_s1s2.encrypt_sk( + &module, + &sk1, + &sk2_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) + ct_gglwe_s0s1.keyswitch_inplace(&module, &ct_gglwe_s1s2, scratch.borrow()); + + let ct_gglwe_s0s2: GLWESwitchingKey, FFT64> = ct_gglwe_s0s1; + + let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ksk, rank_out_s0s1); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ksk); + + (0..ct_gglwe_s0s2.rank_in()).for_each(|col_i| { + (0..ct_gglwe_s0s2.rows()).for_each(|row_i| { + ct_gglwe_s0s2.get_row(&module, row_i, col_i, &mut ct_glwe_dft); + ct_glwe_dft.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, col_i); + + let noise_have: f64 = pt.data.std(0, basek).log2(); + let noise_want: f64 = log2_std_noise_gglwe_product( + module.n() as f64, + basek, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + rank_out_s0s1 as f64, + k_ksk, + k_ksk, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + }); + }); +} + +fn test_external_product(log_n: usize, basek: usize, k: usize, sigma: f64, rank_in: usize, rank_out: usize) { + let module: Module = Module::::new(1 << log_n); + + let rows: usize = (k + basek - 1) / basek; + + let mut ct_gglwe_in: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, k, rows, rank_in, rank_out); + let mut ct_gglwe_out: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, k, rows, rank_in, rank_out); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank_out); + + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank_out, ct_gglwe_in.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_gglwe_out.size()) + | GLWESwitchingKey::external_product_scratch_space( + &module, + ct_gglwe_out.size(), + ct_gglwe_in.size(), + ct_rgsw.size(), + rank_out, + ) + | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank_out, ct_rgsw.size()), + ); + + let r: usize = 1; + + pt_rgsw.to_mut().raw_mut()[r] = 1; // X^{r} + + let mut sk_in: 
SecretKey> = SecretKey::new(&module, rank_in); + sk_in.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_in_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank_in); + sk_in_dft.dft(&module, &sk_in); + + let mut sk_out: SecretKey> = SecretKey::new(&module, rank_out); + sk_out.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_out_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank_out); + sk_out_dft.dft(&module, &sk_out); + + // gglwe_{s1}(s0) = s0 -> s1 + ct_gglwe_in.encrypt_sk( + &module, + &sk_in, + &sk_out_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_rgsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_out_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + // gglwe_(m) (x) RGSW_(X^k) = gglwe_(m * X^k) + ct_gglwe_out.external_product(&module, &ct_gglwe_in, &ct_rgsw, scratch.borrow()); + + scratch = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank_out, ct_gglwe_in.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_gglwe_out.size()) + | GLWESwitchingKey::external_product_scratch_space( + &module, + ct_gglwe_out.size(), + ct_gglwe_in.size(), + ct_rgsw.size(), + rank_out, + ) + | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank_out, ct_rgsw.size()), + ); + + let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k, rank_out); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k); + + (0..rank_in).for_each(|i| { + module.vec_znx_rotate_inplace(r as i64, &mut sk_in.data, i); // * X^{r} + }); + + (0..rank_in).for_each(|col_i| { + (0..ct_gglwe_out.rows()).for_each(|row_i| { + ct_gglwe_out.get_row(&module, row_i, col_i, &mut ct_glwe_dft); + ct_glwe_dft.decrypt(&module, &mut pt, &sk_out_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk_in, col_i); + + let noise_have: f64 = pt.data.std(0, basek).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_ggsw_product( + module.n() as f64, + basek, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + rank_out as f64, + k, + k, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + }); + }); +} + +fn test_external_product_inplace(log_n: usize, basek: usize, k: usize, sigma: f64, rank_in: usize, rank_out: usize) { + let module: Module = Module::::new(1 << log_n); + + let rows: usize = (k + basek - 1) / basek; + + let mut ct_gglwe: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, k, rows, rank_in, rank_out); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank_out); + + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank_out, ct_gglwe.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_gglwe.size()) + | GLWESwitchingKey::external_product_inplace_scratch_space(&module, ct_gglwe.size(), ct_rgsw.size(), rank_out) + | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank_out, ct_rgsw.size()), + 
); + + let r: usize = 1; + + pt_rgsw.to_mut().raw_mut()[r] = 1; // X^{r} + + let mut sk_in: SecretKey> = SecretKey::new(&module, rank_in); + sk_in.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_in_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank_in); + sk_in_dft.dft(&module, &sk_in); + + let mut sk_out: SecretKey> = SecretKey::new(&module, rank_out); + sk_out.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_out_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank_out); + sk_out_dft.dft(&module, &sk_out); + + // gglwe_{s1}(s0) = s0 -> s1 + ct_gglwe.encrypt_sk( + &module, + &sk_in, + &sk_out_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_rgsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_out_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + // gglwe_(m) (x) RGSW_(X^k) = gglwe_(m * X^k) + ct_gglwe.external_product_inplace(&module, &ct_rgsw, scratch.borrow()); + + let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k, rank_out); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k); + + (0..rank_in).for_each(|i| { + module.vec_znx_rotate_inplace(r as i64, &mut sk_in.data, i); // * X^{r} + }); + + (0..rank_in).for_each(|col_i| { + (0..ct_gglwe.rows()).for_each(|row_i| { + ct_gglwe.get_row(&module, row_i, col_i, &mut ct_glwe_dft); + ct_glwe_dft.decrypt(&module, &mut pt, &sk_out_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk_in, col_i); + + let noise_have: f64 = pt.data.std(0, basek).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_ggsw_product( + module.n() as f64, + basek, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + rank_out as f64, + k, + k, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + }); + }); +} + +pub(crate) fn var_noise_gglwe_product( + n: f64, + basek: usize, + var_xs: f64, + var_msg: f64, + var_a_err: f64, + var_gct_err_lhs: f64, + var_gct_err_rhs: f64, + rank_in: f64, + a_logq: usize, + b_logq: usize, +) -> f64 { + let a_logq: usize = a_logq.min(b_logq); + let a_cols: usize = (a_logq + basek - 1) / basek; + + let b_scale = 2.0f64.powi(b_logq as i32); + let a_scale: f64 = 2.0f64.powi((b_logq - a_logq) as i32); + + let base: f64 = (1 << (basek)) as f64; + let var_base: f64 = base * base / 12f64; + + // lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2) + // rhs = a_cols * n * var_base * var_gct_err_rhs * var_xs + let mut noise: f64 = (a_cols as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs); + noise += var_msg * var_a_err * a_scale * a_scale * n; + noise *= rank_in; + noise /= b_scale * b_scale; + noise +} + +pub(crate) fn log2_std_noise_gglwe_product( + n: f64, + basek: usize, + var_xs: f64, + var_msg: f64, + var_a_err: f64, + var_gct_err_lhs: f64, + var_gct_err_rhs: f64, + rank_in: f64, + a_logq: usize, + b_logq: usize, +) -> f64 { + let mut noise: f64 = var_noise_gglwe_product( + n, + basek, + var_xs, + var_msg, + var_a_err, + var_gct_err_lhs, + var_gct_err_rhs, + rank_in, + a_logq, + b_logq, + ); + noise = noise.sqrt(); + noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}] +} diff --git a/core/src/test_fft64/ggsw.rs 
b/core/src/test_fft64/ggsw.rs new file mode 100644 index 0000000..f02bd87 --- /dev/null +++ b/core/src/test_fft64/ggsw.rs @@ -0,0 +1,934 @@ +use base2k::{ + FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScalarZnxOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc, + VecZnxBigOps, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, VecZnxToMut, ZnxViewMut, ZnxZero, +}; +use sampling::source::Source; + +use crate::{ + automorphism::AutomorphismKey, + elem::{GetRow, Infos}, + ggsw_ciphertext::GGSWCiphertext, + glwe_ciphertext_fourier::GLWECiphertextFourier, + glwe_plaintext::GLWEPlaintext, + keys::{SecretKey, SecretKeyFourier}, + keyswitch_key::GLWESwitchingKey, + tensor_key::TensorKey, +}; + +use super::gglwe::var_noise_gglwe_product; + +#[test] +fn encrypt_sk() { + (1..4).for_each(|rank| { + println!("test encrypt_sk rank: {}", rank); + test_encrypt_sk(11, 8, 54, 3.2, rank); + }); +} + +#[test] +fn keyswitch() { + (1..4).for_each(|rank| { + println!("test keyswitch rank: {}", rank); + test_keyswitch(12, 15, 60, rank, 3.2); + }); +} + +#[test] +fn keyswitch_inplace() { + (1..4).for_each(|rank| { + println!("test keyswitch_inplace rank: {}", rank); + test_keyswitch_inplace(12, 15, 60, rank, 3.2); + }); +} + +#[test] +fn automorphism() { + (1..4).for_each(|rank| { + println!("test automorphism rank: {}", rank); + test_automorphism(-5, 12, 15, 60, rank, 3.2); + }); +} + +#[test] +fn automorphism_inplace() { + (1..4).for_each(|rank| { + println!("test automorphism_inplace rank: {}", rank); + test_automorphism_inplace(-5, 12, 15, 60, rank, 3.2); + }); +} + +#[test] +fn external_product() { + (1..4).for_each(|rank| { + println!("test external_product rank: {}", rank); + test_external_product(12, 12, 60, rank, 3.2); + }); +} + +#[test] +fn external_product_inplace() { + (1..4).for_each(|rank| { + println!("test external_product rank: {}", rank); + test_external_product_inplace(12, 15, 60, rank, 3.2); + }); +} + +fn test_encrypt_sk(log_n: usize, basek: usize, k_ggsw: usize, sigma: f64, rank: usize) { + let module: Module = Module::::new(1 << log_n); + + let rows: usize = (k_ggsw + basek - 1) / basek; + + let mut ct: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ggsw); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ggsw); + let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct.size()), + ); + + let mut sk: SecretKey> = SecretKey::new(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_dft.dft(&module, &sk); + + ct.encrypt_sk( + &module, + &pt_scalar, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ggsw, rank); + let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct.size()); + let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct.size()); + + (0..ct.rank() + 1).for_each(|col_j| { + 
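+        // Editor's note: row `row_i` of column `col_j` of a GGSW decrypts to
+        // pt * 2^{-basek*(row_i+1)} when col_j == 0, and to
+        // pt * s[col_j - 1] * 2^{-basek*(row_i+1)} otherwise; the loop below
+        // rebuilds exactly that reference value in pt_want.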
(0..ct.rows()).for_each(|row_i| { + module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0); + + // mul with sk[col_j-1] + if col_j > 0 { + module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1); + module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); + module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); + } + + ct.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); + + ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let std_pt: f64 = pt_have.data.std(0, basek) * (k_ggsw as f64).exp2(); + assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); + + pt_want.data.zero(); + }); + }); +} + +fn test_keyswitch(log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) { + let module: Module = Module::::new(1 << log_n); + let rows: usize = (k + basek - 1) / basek; + + let mut ct_in: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank); + let mut ct_out: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank); + let mut tsk: TensorKey, FFT64> = TensorKey::new(&module, basek, k, rows, rank); + let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, k, rows, rank, rank); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k); + let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_in.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_out.size()) + | GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ksk.size()) + | TensorKey::encrypt_sk_scratch_space(&module, rank, ksk.size()) + | GGSWCiphertext::keyswitch_scratch_space( + &module, + ct_out.size(), + ct_in.size(), + ksk.size(), + tsk.size(), + rank, + ), + ); + + let var_xs: f64 = 0.5; + + let mut sk_in: SecretKey> = SecretKey::new(&module, rank); + sk_in.fill_ternary_prob(var_xs, &mut source_xs); + + let mut sk_in_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_in_dft.dft(&module, &sk_in); + + let mut sk_out: SecretKey> = SecretKey::new(&module, rank); + sk_out.fill_ternary_prob(var_xs, &mut source_xs); + + let mut sk_out_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_out_dft.dft(&module, &sk_out); + + ksk.encrypt_sk( + &module, + &sk_in, + &sk_out_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + tsk.encrypt_sk( + &module, + &sk_out_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); + + ct_in.encrypt_sk( + &module, + &pt_scalar, + &sk_in_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_out.keyswitch(&module, &ct_in, &ksk, &tsk, scratch.borrow()); + + let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k, rank); + let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_out.size()); + let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_out.size()); + + (0..ct_out.rank() + 1).for_each(|col_j| { + 
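+        // Editor's note: after the key switch every row must decrypt correctly
+        // under the *output* secret, so pt_want is rebuilt with sk_out_dft and
+        // the measured noise is checked against noise_ggsw_keyswitch.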
(0..ct_out.rows()).for_each(|row_i| { + module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0); + + // mul with sk[col_j-1] + if col_j > 0 { + module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_out_dft, col_j - 1); + module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); + module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); + } + + ct_out.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); + + ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + let noise_want: f64 = noise_ggsw_keyswitch( + module.n() as f64, + basek, + col_j, + var_xs, + 0f64, + sigma * sigma, + 0f64, + rank as f64, + k, + k, + ); + + println!("{} {}", noise_have, noise_want); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + pt_want.data.zero(); + }); + }); +} + +fn test_keyswitch_inplace(log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) { + let module: Module = Module::::new(1 << log_n); + let rows: usize = (k + basek - 1) / basek; + + let mut ct: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank); + let mut tsk: TensorKey, FFT64> = TensorKey::new(&module, basek, k, rows, rank); + let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, k, rows, rank, rank); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k); + let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct.size()) + | GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ksk.size()) + | TensorKey::encrypt_sk_scratch_space(&module, rank, ksk.size()) + | GGSWCiphertext::keyswitch_inplace_scratch_space(&module, ct.size(), ksk.size(), tsk.size(), rank), + ); + + let var_xs: f64 = 0.5; + + let mut sk_in: SecretKey> = SecretKey::new(&module, rank); + sk_in.fill_ternary_prob(var_xs, &mut source_xs); + + let mut sk_in_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_in_dft.dft(&module, &sk_in); + + let mut sk_out: SecretKey> = SecretKey::new(&module, rank); + sk_out.fill_ternary_prob(var_xs, &mut source_xs); + + let mut sk_out_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_out_dft.dft(&module, &sk_out); + + ksk.encrypt_sk( + &module, + &sk_in, + &sk_out_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + tsk.encrypt_sk( + &module, + &sk_out_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); + + ct.encrypt_sk( + &module, + &pt_scalar, + &sk_in_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct.keyswitch_inplace(&module, &ksk, &tsk, scratch.borrow()); + + let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k, rank); + let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct.size()); + let mut pt_big: VecZnxBig, FFT64> = 
module.new_vec_znx_big(1, ct.size()); + + (0..ct.rank() + 1).for_each(|col_j| { + (0..ct.rows()).for_each(|row_i| { + module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0); + + // mul with sk[col_j-1] + if col_j > 0 { + module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_out_dft, col_j - 1); + module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); + module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); + } + + ct.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); + + ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + let noise_want: f64 = noise_ggsw_keyswitch( + module.n() as f64, + basek, + col_j, + var_xs, + 0f64, + sigma * sigma, + 0f64, + rank as f64, + k, + k, + ); + + println!("{} {}", noise_have, noise_want); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + pt_want.data.zero(); + }); + }); +} + +pub(crate) fn noise_ggsw_keyswitch( + n: f64, + basek: usize, + col: usize, + var_xs: f64, + var_a_err: f64, + var_gct_err_lhs: f64, + var_gct_err_rhs: f64, + rank: f64, + a_logq: usize, + b_logq: usize, +) -> f64 { + let var_si_x_sj: f64 = n * var_xs * var_xs; + + // Initial KS for col = 0 + let mut noise: f64 = var_noise_gglwe_product( + n, + basek, + var_xs, + var_xs, + var_a_err, + var_gct_err_lhs, + var_gct_err_rhs, + rank, + a_logq, + b_logq, + ); + + // Other GGSW reconstruction for col > 0 + if col > 0 { + noise += var_noise_gglwe_product( + n, + basek, + var_xs, + var_si_x_sj, + var_a_err + 1f64 / 12.0, + var_gct_err_lhs, + var_gct_err_rhs, + rank, + a_logq, + b_logq, + ); + noise += n * noise * var_xs * 0.5; + } + + noise = noise.sqrt(); + noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}] +} + +fn test_automorphism(p: i64, log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) { + let module: Module = Module::::new(1 << log_n); + let rows: usize = (k + basek - 1) / basek; + + let mut ct_in: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank); + let mut ct_out: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank); + let mut tensor_key: TensorKey, FFT64> = TensorKey::new(&module, basek, k, rows, rank); + let mut auto_key: AutomorphismKey, FFT64> = AutomorphismKey::new(&module, basek, k, rows, rank); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k); + let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_in.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_out.size()) + | AutomorphismKey::encrypt_sk_scratch_space(&module, rank, auto_key.size()) + | TensorKey::encrypt_sk_scratch_space(&module, rank, tensor_key.size()) + | GGSWCiphertext::automorphism_scratch_space( + &module, + ct_out.size(), + ct_in.size(), + auto_key.size(), + tensor_key.size(), + rank, + ), + ); + + let var_xs: f64 = 0.5; + + let mut sk: SecretKey> = SecretKey::new(&module, rank); + sk.fill_ternary_prob(var_xs, &mut source_xs); + + let mut sk_dft: 
SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_dft.dft(&module, &sk); + + auto_key.encrypt_sk( + &module, + p, + &sk, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + tensor_key.encrypt_sk( + &module, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); + + ct_in.encrypt_sk( + &module, + &pt_scalar, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_out.automorphism(&module, &ct_in, &auto_key, &tensor_key, scratch.borrow()); + + module.scalar_znx_automorphism_inplace(p, &mut pt_scalar, 0); + + let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k, rank); + let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_out.size()); + let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_out.size()); + + (0..ct_out.rank() + 1).for_each(|col_j| { + (0..ct_out.rows()).for_each(|row_i| { + module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0); + + // mul with sk[col_j-1] + if col_j > 0 { + module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1); + module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); + module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); + } + + ct_out.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); + + ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + let noise_want: f64 = noise_ggsw_keyswitch( + module.n() as f64, + basek, + col_j, + var_xs, + 0f64, + sigma * sigma, + 0f64, + rank as f64, + k, + k, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + pt_want.data.zero(); + }); + }); +} + +fn test_automorphism_inplace(p: i64, log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) { + let module: Module = Module::::new(1 << log_n); + let rows: usize = (k + basek - 1) / basek; + + let mut ct: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank); + let mut tensor_key: TensorKey, FFT64> = TensorKey::new(&module, basek, k, rows, rank); + let mut auto_key: AutomorphismKey, FFT64> = AutomorphismKey::new(&module, basek, k, rows, rank); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k); + let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct.size()) + | AutomorphismKey::encrypt_sk_scratch_space(&module, rank, auto_key.size()) + | TensorKey::encrypt_sk_scratch_space(&module, rank, tensor_key.size()) + | GGSWCiphertext::automorphism_inplace_scratch_space(&module, ct.size(), auto_key.size(), tensor_key.size(), rank), + ); + + let var_xs: f64 = 0.5; + + let mut sk: SecretKey> = SecretKey::new(&module, rank); + sk.fill_ternary_prob(var_xs, &mut source_xs); + + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_dft.dft(&module, &sk); + 
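+    // Editor's note: the GGSW automorphism consumes two auxiliary keys set up
+    // below: the AutomorphismKey (GLWE switching keys for X -> X^p) and the
+    // TensorKey (encryptions of s[i]*s[j]), the latter being used to rebuild
+    // the GGSW columns that carry products with the secret.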
+ auto_key.encrypt_sk( + &module, + p, + &sk, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + tensor_key.encrypt_sk( + &module, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); + + ct.encrypt_sk( + &module, + &pt_scalar, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct.automorphism_inplace(&module, &auto_key, &tensor_key, scratch.borrow()); + + module.scalar_znx_automorphism_inplace(p, &mut pt_scalar, 0); + + let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k, rank); + let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct.size()); + let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct.size()); + + (0..ct.rank() + 1).for_each(|col_j| { + (0..ct.rows()).for_each(|row_i| { + module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0); + + // mul with sk[col_j-1] + if col_j > 0 { + module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1); + module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); + module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); + } + + ct.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); + + ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + let noise_want: f64 = noise_ggsw_keyswitch( + module.n() as f64, + basek, + col_j, + var_xs, + 0f64, + sigma * sigma, + 0f64, + rank as f64, + k, + k, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); + + pt_want.data.zero(); + }); + }); +} + +fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, rank: usize, sigma: f64) { + let module: Module = Module::::new(1 << log_n); + + let rows: usize = (k_ggsw + basek - 1) / basek; + + let mut ct_ggsw_rhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank); + let mut ct_ggsw_lhs_in: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank); + let mut ct_ggsw_lhs_out: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank); + let mut pt_ggsw_lhs: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_ggsw_rhs: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + pt_ggsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs); + + let k: usize = 1; + + pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWECiphertextFourier::decrypt_scratch_space(&module, ct_ggsw_lhs_out.size()) + | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_ggsw_lhs_in.size()) + | GGSWCiphertext::external_product_scratch_space( + &module, + ct_ggsw_lhs_out.size(), + ct_ggsw_lhs_in.size(), + ct_ggsw_rhs.size(), + rank, + ), + ); + + let mut sk: SecretKey> = SecretKey::new(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_dft.dft(&module, &sk); + + ct_ggsw_rhs.encrypt_sk( + &module, + &pt_ggsw_rhs, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + 
&pt_ggsw_rhs, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + +
ct_ggsw_lhs_in.encrypt_sk( + &module, + &pt_ggsw_lhs, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_ggsw_lhs_out.external_product(&module, &ct_ggsw_lhs_in, &ct_ggsw_rhs, scratch.borrow()); + + let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ggsw, rank); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ggsw); + let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_ggsw_lhs_out.size()); + let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_ggsw_lhs_out.size()); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ggsw); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_ggsw_lhs, 0); + + (0..ct_ggsw_lhs_out.rank() + 1).for_each(|col_j| { + (0..ct_ggsw_lhs_out.rows()).for_each(|row_i| { + module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_ggsw_lhs, 0); + + if col_j > 0 { + module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0); + module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1); + module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0); + module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow()); + } + + ct_ggsw_lhs_out.get_row(&module, row_i, col_j, &mut ct_glwe_fourier); + ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0); + + let noise_have: f64 = pt.data.std(0, basek).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_ggsw_product( + module.n() as f64, + basek, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + rank as f64, + k_ggsw, + k_ggsw, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "have: {} want: {}", + noise_have, + noise_want + ); + + pt_want.data.zero(); + }); + }); +} + +fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, rank: usize, sigma: f64) { + let module: Module = Module::::new(1 << log_n); + let rows: usize = (k_ggsw + basek - 1) / basek; + + let mut ct_ggsw_rhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank); + let mut ct_ggsw_lhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank); + let mut pt_ggsw_lhs: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_ggsw_rhs: ScalarZnx> = module.new_scalar_znx(1); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + pt_ggsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs); + + let k: usize = 1; + + pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_ggsw_rhs.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_ggsw_lhs.size()) + | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_ggsw_lhs.size()) + | GGSWCiphertext::external_product_inplace_scratch_space(&module, ct_ggsw_lhs.size(), ct_ggsw_rhs.size(), rank), + ); + + let mut sk: SecretKey> = SecretKey::new(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_dft.dft(&module, &sk); + + ct_ggsw_rhs.encrypt_sk( + &module, + 
&pt_ggsw_rhs,
+        &sk_dft,
+        &mut source_xa,
+        &mut source_xe,
+        sigma,
+        scratch.borrow(),
+    );
+
+    ct_ggsw_lhs.encrypt_sk(
+        &module,
+        &pt_ggsw_lhs,
+        &sk_dft,
+        &mut source_xa,
+        &mut source_xe,
+        sigma,
+        scratch.borrow(),
+    );
+
+    ct_ggsw_lhs.external_product_inplace(&module, &ct_ggsw_rhs, scratch.borrow());
+
+    let mut ct_glwe_fourier: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ggsw, rank);
+    let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ggsw);
+    let mut pt_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct_ggsw_lhs.size());
+    let mut pt_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct_ggsw_lhs.size());
+    let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ggsw);
+
+    module.vec_znx_rotate_inplace(k as i64, &mut pt_ggsw_lhs, 0);
+
+    (0..ct_ggsw_lhs.rank() + 1).for_each(|col_j| {
+        (0..ct_ggsw_lhs.rows()).for_each(|row_i| {
+            module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_ggsw_lhs, 0);
+
+            if col_j > 0 {
+                module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
+                module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1);
+                module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
+                module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
+            }
+
+            ct_ggsw_lhs.get_row(&module, row_i, col_j, &mut ct_glwe_fourier);
+            ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
+
+            module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0);
+
+            let noise_have: f64 = pt.data.std(0, basek).log2();
+
+            let var_gct_err_lhs: f64 = sigma * sigma;
+            let var_gct_err_rhs: f64 = 0f64;
+
+            let var_msg: f64 = 1f64 / module.n() as f64; // X^{k}
+            let var_a0_err: f64 = sigma * sigma;
+            let var_a1_err: f64 = 1f64 / 12f64;
+
+            let noise_want: f64 = noise_ggsw_product(
+                module.n() as f64,
+                basek,
+                0.5,
+                var_msg,
+                var_a0_err,
+                var_a1_err,
+                var_gct_err_lhs,
+                var_gct_err_rhs,
+                rank as f64,
+                k_ggsw,
+                k_ggsw,
+            );
+
+            assert!(
+                (noise_have - noise_want).abs() <= 0.1,
+                "have: {} want: {}",
+                noise_have,
+                noise_want
+            );
+
+            pt_want.data.zero();
+        });
+    });
+}
+pub(crate) fn noise_ggsw_product(
+    n: f64,
+    basek: usize,
+    var_xs: f64,
+    var_msg: f64,
+    var_a0_err: f64,
+    var_a1_err: f64,
+    var_gct_err_lhs: f64,
+    var_gct_err_rhs: f64,
+    rank: f64,
+    a_logq: usize,
+    b_logq: usize,
+) -> f64 {
+    let a_logq: usize = a_logq.min(b_logq);
+    let a_cols: usize = (a_logq + basek - 1) / basek;
+
+    let b_scale = 2.0f64.powi(b_logq as i32);
+    let a_scale: f64 = 2.0f64.powi((b_logq - a_logq) as i32);
+
+    let base: f64 = (1 << basek) as f64;
+    let var_base: f64 = base * base / 12f64;
+
+    // lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2)
+    // rhs = a_cols * n * var_base * var_gct_err_rhs * var_xs
+    let mut noise: f64 = (rank + 1.0) * (a_cols as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs);
+    noise += var_msg * var_a0_err * a_scale * a_scale * n;
+    noise += var_msg * var_a1_err * a_scale * a_scale * n * var_xs * rank;
+    noise = noise.sqrt();
+    noise /= b_scale;
+    noise.log2().min(-1.0) // capped at -1: after normalization the error is at most uniform over [-2^{-1}, 2^{-1}]
+}
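+
+// Worked example (a sketch, plugging in the parameters of the GLWE
+// external-product test in test_fft64/glwe.rs further below: n = 2^12,
+// basek = 12, a_logq = 45, b_logq = 60, rank = 1, var_xs = 0.5,
+// var_msg = 1/n, var_a0_err = var_gct_err_lhs = 3.2^2, var_a1_err = 1/12,
+// var_gct_err_rhs = 0):
+//
+//   a_cols = ceil(45 / 12) = 4, a_scale = 2^15, var_base = 2^24 / 12
+//   gadget term : 2 * 4 * 2^12 * (2^24 / 12) * 10.24    ~= 2^38.8
+//   message terms: (1/n) * 10.24 * 2^30 * n             ~= 2^33.4
+//                  (1/n) * (1/12) * 2^30 * n * 0.5      ~= 2^25.4
+//
+//   noise ~= sqrt(2^38.8 + 2^33.4 + 2^25.4) / 2^60, i.e. log2(noise) ~= -40.6,
+//   which is the magnitude of noise_want that the asserts compare against.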
diff --git a/core/src/test_fft64/glwe.rs b/core/src/test_fft64/glwe.rs
new file mode 100644
index 0000000..0f7fcc1
--- /dev/null
+++ b/core/src/test_fft64/glwe.rs
@@ -0,0 +1,805 @@
+use base2k::{
+    Decoding, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut,
+    ZnxViewMut, ZnxZero,
+};
+use itertools::izip;
+use sampling::source::Source;
+
+use crate::{
+    automorphism::AutomorphismKey,
+    elem::Infos,
+    ggsw_ciphertext::GGSWCiphertext,
+    glwe_ciphertext::GLWECiphertext,
+    glwe_ciphertext_fourier::GLWECiphertextFourier,
+    glwe_plaintext::GLWEPlaintext,
+    keys::{GLWEPublicKey, SecretKey, SecretKeyFourier},
+    keyswitch_key::GLWESwitchingKey,
+    test_fft64::{gglwe::log2_std_noise_gglwe_product, ggsw::noise_ggsw_product},
+};
+
+#[test]
+fn encrypt_sk() {
+    (1..4).for_each(|rank| {
+        println!("test encrypt_sk rank: {}", rank);
+        test_encrypt_sk(11, 8, 54, 30, 3.2, rank);
+    });
+}
+
+#[test]
+fn encrypt_zero_sk() {
+    (1..4).for_each(|rank| {
+        println!("test encrypt_zero_sk rank: {}", rank);
+        test_encrypt_zero_sk(11, 8, 64, 3.2, rank);
+    });
+}
+
+#[test]
+fn encrypt_pk() {
+    (1..4).for_each(|rank| {
+        println!("test encrypt_pk rank: {}", rank);
+        test_encrypt_pk(11, 8, 64, 64, 3.2, rank)
+    });
+}
+
+#[test]
+fn keyswitch() {
+    (1..4).for_each(|rank_in| {
+        (1..4).for_each(|rank_out| {
+            println!("test keyswitch rank_in: {} rank_out: {}", rank_in, rank_out);
+            test_keyswitch(12, 12, 60, 45, 60, rank_in, rank_out, 3.2);
+        });
+    });
+}
+
+#[test]
+fn keyswitch_inplace() {
+    (1..4).for_each(|rank| {
+        println!("test keyswitch_inplace rank: {}", rank);
+        test_keyswitch_inplace(12, 12, 60, 45, rank, 3.2);
+    });
+}
+
+#[test]
+fn external_product() {
+    (1..4).for_each(|rank| {
+        println!("test external_product rank: {}", rank);
+        test_external_product(12, 12, 60, 45, 60, rank, 3.2);
+    });
+}
+
+#[test]
+fn external_product_inplace() {
+    (1..4).for_each(|rank| {
+        println!("test external_product_inplace rank: {}", rank);
+        test_external_product_inplace(12, 15, 60, 60, rank, 3.2);
+    });
+}
+
+#[test]
+fn automorphism_inplace() {
+    (1..4).for_each(|rank| {
+        println!("test automorphism_inplace rank: {}", rank);
+        test_automorphism_inplace(12, 12, -5, 60, 60, rank, 3.2);
+    });
+}
+
+#[test]
+fn automorphism() {
+    (1..4).for_each(|rank| {
+        println!("test automorphism rank: {}", rank);
+        test_automorphism(12, 12, -5, 60, 45, 60, rank, 3.2);
+    });
+}
+
+fn test_encrypt_sk(log_n: usize, basek: usize, k_ct: usize, k_pt: usize, sigma: f64, rank: usize) {
+    let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
+
+    let mut ct: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct, rank);
+    let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_pt);
+
+    let mut source_xs: Source = Source::new([0u8; 32]);
+    let mut source_xe: Source = Source::new([0u8; 32]);
+    let mut source_xa: Source = Source::new([0u8; 32]);
+
+    let mut scratch: ScratchOwned = ScratchOwned::new(
+        GLWECiphertext::encrypt_sk_scratch_space(&module, ct.size()) | GLWECiphertext::decrypt_scratch_space(&module, ct.size()),
+    );
+
+    let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
+    sk.fill_ternary_prob(0.5, &mut source_xs);
+
+    let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
+    sk_dft.dft(&module, &sk);
+
+    let mut data_want: Vec<i64> = vec![0i64; module.n()];
+
+    data_want
+        .iter_mut()
+        .for_each(|x| *x = source_xa.next_i64() & 0xFF);
+
+    pt.data.encode_vec_i64(0, basek, k_pt, &data_want, 10);
+
+    ct.encrypt_sk(
+        &module,
+        &pt,
+        &sk_dft,
+        &mut source_xa,
+        &mut source_xe,
+        sigma,
+        scratch.borrow(),
+    );
+
+    pt.data.zero();
+
+    ct.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
+
+    let mut data_have: Vec<i64> = vec![0i64; module.n()];
+
+    pt.data
+        .decode_vec_i64(0, basek, pt.size() * basek, &mut data_have);
+
+    // TODO: properly assert the decryption noise through std(dec(ct) - pt)
+    let scale: f64 = (1 << (pt.size() * basek - k_pt)) as f64;
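+    // decode_vec_i64 above reads back all pt.size() * basek bits of the
+    // plaintext, while the data was encoded at scale 2^k_pt, so each decoded
+    // value carries an extra factor of 2^(pt.size() * basek - k_pt);
+    // dividing by `scale` below maps it back next to its entry in `data_want`.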
izip!(data_want.iter(), data_have.iter()).for_each(|(a, b)| { + let b_scaled = (*b as f64) / scale; + assert!( + (*a as f64 - b_scaled).abs() < 0.1, + "{} {}", + *a as f64, + b_scaled + ) + }); +} + +fn test_encrypt_zero_sk(log_n: usize, basek: usize, k_ct: usize, sigma: f64, rank: usize) { + let module: Module = Module::::new(1 << log_n); + + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([1u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut sk: SecretKey> = SecretKey::new(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_dft.dft(&module, &sk); + + let mut ct_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ct, rank); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWECiphertextFourier::decrypt_scratch_space(&module, ct_dft.size()) + | GLWECiphertextFourier::encrypt_sk_scratch_space(&module, rank, ct_dft.size()), + ); + + ct_dft.encrypt_zero_sk( + &module, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + ct_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + + assert!((sigma - pt.data.std(0, basek) * (k_ct as f64).exp2()) <= 0.2); +} + +fn test_encrypt_pk(log_n: usize, basek: usize, k_ct: usize, k_pk: usize, sigma: f64, rank: usize) { + let module: Module = Module::::new(1 << log_n); + + let mut ct: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + let mut source_xu: Source = Source::new([0u8; 32]); + + let mut sk: SecretKey> = SecretKey::new(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_dft.dft(&module, &sk); + + let mut pk: GLWEPublicKey, FFT64> = GLWEPublicKey::new(&module, basek, k_pk, rank); + pk.generate(&module, &sk_dft, &mut source_xa, &mut source_xe, sigma); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWECiphertext::encrypt_sk_scratch_space(&module, ct.size()) + | GLWECiphertext::decrypt_scratch_space(&module, ct.size()) + | GLWECiphertext::encrypt_pk_scratch_space(&module, pk.size()), + ); + + let mut data_want: Vec = vec![0i64; module.n()]; + + data_want + .iter_mut() + .for_each(|x| *x = source_xa.next_i64() & 0); + + pt_want.data.encode_vec_i64(0, basek, k_ct, &data_want, 10); + + ct.encrypt_pk( + &module, + &pt_want, + &pk, + &mut source_xu, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct); + + ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_want, 0, &pt_have, 0); + + let noise_have: f64 = pt_want.data.std(0, basek).log2(); + let noise_want: f64 = ((((rank as f64) + 1.0) * module.n() as f64 * 0.5 * sigma * sigma).sqrt()).log2() - (k_ct as f64); + + assert!( + (noise_have - noise_want).abs() < 0.2, + "{} {}", + noise_have, + noise_want + ); +} + +fn test_keyswitch( + log_n: usize, + basek: usize, + k_keyswitch: usize, + k_ct_in: usize, + k_ct_out: usize, + rank_in: usize, + rank_out: usize, + sigma: f64, +) { + let module: Module = 
Module::::new(1 << log_n); + let rows: usize = (k_ct_in + basek - 1) / basek; + + let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, k_keyswitch, rows, rank_in, rank_out); + let mut ct_in: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct_in, rank_in); + let mut ct_out: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct_out, rank_out); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank_in, ksk.size()) + | GLWECiphertext::decrypt_scratch_space(&module, ct_out.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_in.size()) + | GLWECiphertext::keyswitch_scratch_space( + &module, + ct_out.size(), + rank_out, + ct_in.size(), + rank_in, + ksk.size(), + ), + ); + + let mut sk_in: SecretKey> = SecretKey::new(&module, rank_in); + sk_in.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_in_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank_in); + sk_in_dft.dft(&module, &sk_in); + + let mut sk_out: SecretKey> = SecretKey::new(&module, rank_out); + sk_out.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_out_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank_out); + sk_out_dft.dft(&module, &sk_out); + + ksk.encrypt_sk( + &module, + &sk_in, + &sk_out_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_in.encrypt_sk( + &module, + &pt_want, + &sk_in_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_out.keyswitch(&module, &ct_in, &ksk, scratch.borrow()); + + ct_out.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + let noise_want: f64 = log2_std_noise_gglwe_product( + module.n() as f64, + basek, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + rank_in as f64, + k_ct_in, + k_keyswitch, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); +} + +fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ksk: usize, k_ct: usize, rank: usize, sigma: f64) { + let module: Module = Module::::new(1 << log_n); + let rows: usize = (k_ct + basek - 1) / basek; + + let mut ct_grlwe: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank, rank); + let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) + | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe.size()) + | 
GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | GLWECiphertext::keyswitch_inplace_scratch_space(&module, ct_rlwe.size(), rank, ct_grlwe.size()), + ); + + let mut sk0: SecretKey> = SecretKey::new(&module, rank); + sk0.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk0_dft.dft(&module, &sk0); + + let mut sk1: SecretKey> = SecretKey::new(&module, rank); + sk1.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk1_dft.dft(&module, &sk1); + + ct_grlwe.encrypt_sk( + &module, + &sk0, + &sk1_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_rlwe.encrypt_sk( + &module, + &pt_want, + &sk0_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_rlwe.keyswitch_inplace(&module, &ct_grlwe, scratch.borrow()); + + ct_rlwe.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + let noise_want: f64 = log2_std_noise_gglwe_product( + module.n() as f64, + basek, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + rank as f64, + k_ct, + k_ksk, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); +} + +fn test_automorphism( + log_n: usize, + basek: usize, + p: i64, + k_autokey: usize, + k_ct_in: usize, + k_ct_out: usize, + rank: usize, + sigma: f64, +) { + let module: Module = Module::::new(1 << log_n); + let rows: usize = (k_ct_in + basek - 1) / basek; + + let mut autokey: AutomorphismKey, FFT64> = AutomorphismKey::new(&module, basek, k_autokey, rows, rank); + let mut ct_in: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct_in, rank); + let mut ct_out: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct_out, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + pt_want + .data + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, autokey.size()) + | GLWECiphertext::decrypt_scratch_space(&module, ct_out.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_in.size()) + | GLWECiphertext::automorphism_scratch_space(&module, ct_out.size(), rank, ct_in.size(), autokey.size()), + ); + + let mut sk: SecretKey> = SecretKey::new(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_dft.dft(&module, &sk); + + autokey.encrypt_sk( + &module, + p, + &sk, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_in.encrypt_sk( + &module, + &pt_want, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_out.automorphism(&module, &ct_in, &autokey, scratch.borrow()); + ct_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + module.vec_znx_automorphism_inplace(p, &mut pt_want, 0); + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + module.vec_znx_normalize_inplace(basek, &mut pt_have, 0, scratch.borrow()); + + let noise_have: 
f64 = pt_have.data.std(0, basek).log2(); + + println!("{}", noise_have); + + let noise_want: f64 = log2_std_noise_gglwe_product( + module.n() as f64, + basek, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + rank as f64, + k_ct_in, + k_autokey, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); +} + +fn test_automorphism_inplace(log_n: usize, basek: usize, p: i64, k_autokey: usize, k_ct: usize, rank: usize, sigma: f64) { + let module: Module = Module::::new(1 << log_n); + let rows: usize = (k_ct + basek - 1) / basek; + + let mut autokey: AutomorphismKey, FFT64> = AutomorphismKey::new(&module, basek, k_autokey, rows, rank); + let mut ct: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, autokey.size()) + | GLWECiphertext::decrypt_scratch_space(&module, ct.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct.size()) + | GLWECiphertext::automorphism_inplace_scratch_space(&module, ct.size(), rank, autokey.size()), + ); + + let mut sk: SecretKey> = SecretKey::new(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_dft.dft(&module, &sk); + + autokey.encrypt_sk( + &module, + p, + &sk, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct.encrypt_sk( + &module, + &pt_want, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct.automorphism_inplace(&module, &autokey, scratch.borrow()); + ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + module.vec_znx_automorphism_inplace(p, &mut pt_want, 0); + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + module.vec_znx_normalize_inplace(basek, &mut pt_have, 0, scratch.borrow()); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + let noise_want: f64 = log2_std_noise_gglwe_product( + module.n() as f64, + basek, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + rank as f64, + k_ct, + k_autokey, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); +} + +fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, k_ct_in: usize, k_ct_out: usize, rank: usize, sigma: f64) { + let module: Module = Module::::new(1 << log_n); + + let rows: usize = (k_ct_in + basek - 1) / basek; + + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank); + let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct_in, rank); + let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct_out, rank); + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: 
Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + + pt_want.to_mut().at_mut(0, 0)[1] = 1; + + let k: usize = 1; + + pt_rgsw.raw_mut()[k] = 1; // X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::new( + GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()) + | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe_out.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | GLWECiphertext::external_product_scratch_space( + &module, + ct_rlwe_out.size(), + ct_rlwe_in.size(), + ct_rgsw.size(), + rank, + ), + ); + + let mut sk: SecretKey> = SecretKey::new(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_dft.dft(&module, &sk); + + ct_rgsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_rlwe_in.encrypt_sk( + &module, + &pt_want, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_rlwe_out.external_product(&module, &ct_rlwe_in, &ct_rgsw, scratch.borrow()); + + ct_rlwe_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_ggsw_product( + module.n() as f64, + basek, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + rank as f64, + k_ct_in, + k_ggsw, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); +} + +fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, k_ct: usize, rank: usize, sigma: f64) { + let module: Module = Module::::new(1 << log_n); + let rows: usize = (k_ct + basek - 1) / basek; + + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank); + let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct, rank); + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + + pt_want.to_mut().at_mut(0, 0)[1] = 1; + + let k: usize = 1; + + pt_rgsw.raw_mut()[k] = 1; // X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::new( + GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()) + | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | GLWECiphertext::external_product_inplace_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size(), rank), + ); + + let mut sk: SecretKey> = SecretKey::new(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, 
rank);
+    sk_dft.dft(&module, &sk);
+
+    ct_rgsw.encrypt_sk(
+        &module,
+        &pt_rgsw,
+        &sk_dft,
+        &mut source_xa,
+        &mut source_xe,
+        sigma,
+        scratch.borrow(),
+    );
+
+    ct_rlwe.encrypt_sk(
+        &module,
+        &pt_want,
+        &sk_dft,
+        &mut source_xa,
+        &mut source_xe,
+        sigma,
+        scratch.borrow(),
+    );
+
+    ct_rlwe.external_product_inplace(&module, &ct_rgsw, scratch.borrow());
+
+    ct_rlwe.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
+
+    module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0);
+
+    module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
+
+    let noise_have: f64 = pt_have.data.std(0, basek).log2();
+
+    let var_gct_err_lhs: f64 = sigma * sigma;
+    let var_gct_err_rhs: f64 = 0f64;
+
+    let var_msg: f64 = 1f64 / module.n() as f64; // X^{k}
+    let var_a0_err: f64 = sigma * sigma;
+    let var_a1_err: f64 = 1f64 / 12f64;
+
+    let noise_want: f64 = noise_ggsw_product(
+        module.n() as f64,
+        basek,
+        0.5,
+        var_msg,
+        var_a0_err,
+        var_a1_err,
+        var_gct_err_lhs,
+        var_gct_err_rhs,
+        rank as f64,
+        k_ct,
+        k_ggsw,
+    );
+
+    assert!(
+        (noise_have - noise_want).abs() <= 0.1,
+        "{} {}",
+        noise_have,
+        noise_want
+    );
+}
diff --git a/core/src/test_fft64/glwe_fourier.rs b/core/src/test_fft64/glwe_fourier.rs
new file mode 100644
index 0000000..d8bd11c
--- /dev/null
+++ b/core/src/test_fft64/glwe_fourier.rs
@@ -0,0 +1,445 @@
+use crate::{
+    elem::Infos,
+    ggsw_ciphertext::GGSWCiphertext,
+    glwe_ciphertext::GLWECiphertext,
+    glwe_ciphertext_fourier::GLWECiphertextFourier,
+    glwe_plaintext::GLWEPlaintext,
+    keys::{SecretKey, SecretKeyFourier},
+    keyswitch_key::GLWESwitchingKey,
+    test_fft64::{gglwe::log2_std_noise_gglwe_product, ggsw::noise_ggsw_product},
+};
+use base2k::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut, ZnxViewMut};
+use sampling::source::Source;
+
+#[test]
+fn keyswitch() {
+    (1..4).for_each(|rank_in| {
+        (1..4).for_each(|rank_out| {
+            println!("test keyswitch rank_in: {} rank_out: {}", rank_in, rank_out);
+            test_keyswitch(12, 12, 60, 45, 60, rank_in, rank_out, 3.2);
+        });
+    });
+}
+
+#[test]
+fn keyswitch_inplace() {
+    (1..4).for_each(|rank| {
+        println!("test keyswitch_inplace rank: {}", rank);
+        test_keyswitch_inplace(12, 12, 60, 45, rank, 3.2);
+    });
+}
+
+#[test]
+fn external_product() {
+    (1..4).for_each(|rank| {
+        println!("test external_product rank: {}", rank);
+        test_external_product(12, 12, 60, 45, 60, rank, 3.2);
+    });
+}
+
+#[test]
+fn external_product_inplace() {
+    (1..4).for_each(|rank| {
+        println!("test external_product_inplace rank: {}", rank);
+        test_external_product_inplace(12, 15, 60, 60, rank, 3.2);
+    });
+}
+
+fn test_keyswitch(
+    log_n: usize,
+    basek: usize,
+    k_ksk: usize,
+    k_ct_in: usize,
+    k_ct_out: usize,
+    rank_in: usize,
+    rank_out: usize,
+    sigma: f64,
+) {
+    let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
+
+    let rows: usize = (k_ct_in + basek - 1) / basek;
+
+    let mut ksk: GLWESwitchingKey<Vec<u8>, FFT64> = GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank_in, rank_out);
+    let mut ct_glwe_in: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct_in, rank_in);
+    let mut ct_glwe_dft_in: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ct_in, rank_in);
+    let mut ct_glwe_out: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct_out, rank_out);
+    let mut ct_glwe_dft_out: GLWECiphertextFourier<Vec<u8>, FFT64> =
+        GLWECiphertextFourier::new(&module, basek, k_ct_out, rank_out);
+    let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct_in);
+    let mut pt_have:
GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank_out, ksk.size()) + | GLWECiphertext::decrypt_scratch_space(&module, ct_glwe_out.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_glwe_in.size()) + | GLWECiphertextFourier::keyswitch_scratch_space( + &module, + ct_glwe_out.size(), + rank_out, + ct_glwe_in.size(), + rank_in, + ksk.size(), + ), + ); + + let mut sk_in: SecretKey> = SecretKey::new(&module, rank_in); + sk_in.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_in_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank_in); + sk_in_dft.dft(&module, &sk_in); + + let mut sk_out: SecretKey> = SecretKey::new(&module, rank_out); + sk_out.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_out_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank_out); + sk_out_dft.dft(&module, &sk_out); + + ksk.encrypt_sk( + &module, + &sk_in, + &sk_out_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_glwe_in.encrypt_sk( + &module, + &pt_want, + &sk_in_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_glwe_in.dft(&module, &mut ct_glwe_dft_in); + ct_glwe_dft_out.keyswitch(&module, &ct_glwe_dft_in, &ksk, scratch.borrow()); + ct_glwe_dft_out.idft(&module, &mut ct_glwe_out, scratch.borrow()); + + ct_glwe_out.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + let noise_want: f64 = log2_std_noise_gglwe_product( + module.n() as f64, + basek, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + rank_in as f64, + k_ct_in, + k_ksk, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); +} + +fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ksk: usize, k_ct: usize, rank: usize, sigma: f64) { + let module: Module = Module::::new(1 << log_n); + let rows: usize = (k_ct + basek - 1) / basek; + + let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank, rank); + let mut ct_glwe: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct, rank); + let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ct, rank); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + + let mut scratch: ScratchOwned = ScratchOwned::new( + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ksk.size()) + | GLWECiphertext::decrypt_scratch_space(&module, ct_glwe.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_glwe.size()) + | GLWECiphertextFourier::keyswitch_inplace_scratch_space(&module, ct_rlwe_dft.size(), ksk.size(), rank), + ); + + let mut sk_in: SecretKey> = 
SecretKey::new(&module, rank); + sk_in.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_in_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_in_dft.dft(&module, &sk_in); + + let mut sk_out: SecretKey> = SecretKey::new(&module, rank); + sk_out.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_out_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_out_dft.dft(&module, &sk_out); + + ksk.encrypt_sk( + &module, + &sk_in, + &sk_out_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_glwe.encrypt_sk( + &module, + &pt_want, + &sk_in_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_glwe.dft(&module, &mut ct_rlwe_dft); + ct_rlwe_dft.keyswitch_inplace(&module, &ksk, scratch.borrow()); + ct_rlwe_dft.idft(&module, &mut ct_glwe, scratch.borrow()); + + ct_glwe.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + let noise_want: f64 = log2_std_noise_gglwe_product( + module.n() as f64, + basek, + 0.5, + 0.5, + 0f64, + sigma * sigma, + 0f64, + rank as f64, + k_ct, + k_ksk, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); +} + +fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, k_ct_in: usize, k_ct_out: usize, rank: usize, sigma: f64) { + let module: Module = Module::::new(1 << log_n); + + let rows: usize = (k_ct_in + basek - 1) / basek; + + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank); + let mut ct_in: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct_in, rank); + let mut ct_out: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct_out, rank); + let mut ct_in_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ct_in, rank); + let mut ct_out_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ct_out, rank); + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct_out); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + + pt_want.to_mut().at_mut(0, 0)[1] = 1; + + let k: usize = 1; + + pt_rgsw.raw_mut()[k] = 1; // X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::new( + GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()) + | GLWECiphertext::decrypt_scratch_space(&module, ct_out.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_in.size()) + | GLWECiphertextFourier::external_product_scratch_space(&module, ct_out.size(), ct_in.size(), ct_rgsw.size(), rank), + ); + + let mut sk: SecretKey> = SecretKey::new(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_dft.dft(&module, &sk); + + ct_rgsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_in.encrypt_sk( + &module, + &pt_want, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct_in.dft(&module, 
&mut ct_in_dft); + ct_out_dft.external_product(&module, &ct_in_dft, &ct_rgsw, scratch.borrow()); + ct_out_dft.idft(&module, &mut ct_out, scratch.borrow()); + + ct_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = noise_ggsw_product( + module.n() as f64, + basek, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + rank as f64, + k_ct_in, + k_ggsw, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); +} + +fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, k_ct: usize, rank: usize, sigma: f64) { + let module: Module = Module::::new(1 << log_n); + let rows: usize = (k_ct + basek - 1) / basek; + + let mut ct_ggsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank); + let mut ct: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_ct, rank); + let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ct, rank); + let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k_ct); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + // Random input plaintext + pt_want + .data + .fill_uniform(basek, 0, pt_want.size(), &mut source_xa); + + pt_want.to_mut().at_mut(0, 0)[1] = 1; + + let k: usize = 1; + + pt_rgsw.raw_mut()[k] = 1; // X^{k} + + let mut scratch: ScratchOwned = ScratchOwned::new( + GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_ggsw.size()) + | GLWECiphertext::decrypt_scratch_space(&module, ct.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct.size()) + | GLWECiphertextFourier::external_product_inplace_scratch_space(&module, ct.size(), ct_ggsw.size(), rank), + ); + + let mut sk: SecretKey> = SecretKey::new(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_dft.dft(&module, &sk); + + ct_ggsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct.encrypt_sk( + &module, + &pt_want, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + ct.dft(&module, &mut ct_rlwe_dft); + ct_rlwe_dft.external_product_inplace(&module, &ct_ggsw, scratch.borrow()); + ct_rlwe_dft.idft(&module, &mut ct, scratch.borrow()); + + ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0); + + module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0); + + let noise_have: f64 = pt_have.data.std(0, basek).log2(); + + let var_gct_err_lhs: f64 = sigma * sigma; + let var_gct_err_rhs: f64 = 0f64; + + let var_msg: f64 = 1f64 / module.n() as f64; // X^{k} + let var_a0_err: f64 = sigma * sigma; + let var_a1_err: f64 = 1f64 / 12f64; + + let noise_want: f64 = 
noise_ggsw_product( + module.n() as f64, + basek, + 0.5, + var_msg, + var_a0_err, + var_a1_err, + var_gct_err_lhs, + var_gct_err_rhs, + rank as f64, + k_ct, + k_ggsw, + ); + + assert!( + (noise_have - noise_want).abs() <= 0.1, + "{} {}", + noise_have, + noise_want + ); +} diff --git a/core/src/test_fft64/mod.rs b/core/src/test_fft64/mod.rs new file mode 100644 index 0000000..fb2129e --- /dev/null +++ b/core/src/test_fft64/mod.rs @@ -0,0 +1,6 @@ +mod automorphism_key; +mod gglwe; +mod ggsw; +mod glwe; +mod glwe_fourier; +mod tensor_key; diff --git a/core/src/test_fft64/tensor_key.rs b/core/src/test_fft64/tensor_key.rs new file mode 100644 index 0000000..920341b --- /dev/null +++ b/core/src/test_fft64/tensor_key.rs @@ -0,0 +1,77 @@ +use base2k::{FFT64, Module, ScalarZnx, ScalarZnxDftAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxDftOps, VecZnxOps}; +use sampling::source::Source; + +use crate::{ + elem::{GetRow, Infos}, + glwe_ciphertext_fourier::GLWECiphertextFourier, + glwe_plaintext::GLWEPlaintext, + keys::{SecretKey, SecretKeyFourier}, + tensor_key::TensorKey, +}; + +#[test] +fn encrypt_sk() { + (1..4).for_each(|rank| { + println!("test encrypt_sk rank: {}", rank); + test_encrypt_sk(12, 16, 54, 3.2, rank); + }); +} + +fn test_encrypt_sk(log_n: usize, basek: usize, k: usize, sigma: f64, rank: usize) { + let module: Module = Module::::new(1 << log_n); + + let rows: usize = (k + basek - 1) / basek; + + let mut tensor_key: TensorKey, FFT64> = TensorKey::new(&module, basek, k, rows, rank); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::new(TensorKey::encrypt_sk_scratch_space( + &module, + rank, + tensor_key.size(), + )); + + let mut sk: SecretKey> = SecretKey::new(&module, rank); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); + sk_dft.dft(&module, &sk); + + tensor_key.encrypt_sk( + &module, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + scratch.borrow(), + ); + + let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, basek, k, rank); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, basek, k); + + (0..rank).for_each(|i| { + (0..rank).for_each(|j| { + let mut sk_ij_dft: base2k::ScalarZnxDft, FFT64> = module.new_scalar_znx_dft(1); + module.svp_apply(&mut sk_ij_dft, 0, &sk_dft.data, i, &sk_dft.data, j); + let sk_ij: ScalarZnx> = module + .vec_znx_idft_consume(sk_ij_dft.as_vec_znx_dft()) + .to_vec_znx_small() + .to_scalar_znx(); + + (0..tensor_key.rank_in()).for_each(|col_i| { + (0..tensor_key.rows()).for_each(|row_i| { + tensor_key + .at(i, j) + .get_row(&module, row_i, col_i, &mut ct_glwe_fourier); + ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk_ij, col_i); + let std_pt: f64 = pt.data.std(0, basek) * (k as f64).exp2(); + assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); + }); + }); + }) + }) +} diff --git a/core/src/utils.rs b/core/src/utils.rs new file mode 100644 index 0000000..c3bc5d5 --- /dev/null +++ b/core/src/utils.rs @@ -0,0 +1,3 @@ +pub(crate) fn derive_size(basek: usize, k: usize) -> usize { + (k + basek - 1) / basek +} diff --git a/rlwe/benches/gadget_product.rs b/rlwe/benches/gadget_product.rs deleted file mode 100644 index fdd2240..0000000 --- a/rlwe/benches/gadget_product.rs +++ 
/dev/null @@ -1,139 +0,0 @@ -use base2k::{BACKEND, Module, Sampling, SvpPPolOps, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, alloc_aligned_u8}; -use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; -use rlwe::{ - ciphertext::{Ciphertext, new_gadget_ciphertext}, - elem::ElemCommon, - encryptor::{encrypt_grlwe_sk, encrypt_grlwe_sk_tmp_bytes}, - gadget_product::{gadget_product_core, gadget_product_core_tmp_bytes}, - keys::SecretKey, - parameters::{Parameters, ParametersLiteral}, -}; -use sampling::source::Source; - -fn bench_gadget_product_inplace(c: &mut Criterion) { - fn runner<'a>( - module: &'a Module, - res_dft_0: &'a mut VecZnxDft, - res_dft_1: &'a mut VecZnxDft, - a: &'a VecZnx, - b: &'a Ciphertext, - b_cols: usize, - tmp_bytes: &'a mut [u8], - ) -> Box { - Box::new(move || { - gadget_product_core(module, res_dft_0, res_dft_1, a, b, b_cols, tmp_bytes); - }) - } - - let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = c.benchmark_group("gadget_product_inplace"); - - for log_n in 10..11 { - let params_lit: ParametersLiteral = ParametersLiteral { - backend: BACKEND::FFT64, - log_n: log_n, - log_q: 32, - log_p: 0, - log_base2k: 16, - log_scale: 20, - xe: 3.2, - xs: 128, - }; - - let params: Parameters = Parameters::new(¶ms_lit); - - let mut tmp_bytes: Vec = alloc_aligned_u8( - params.encrypt_rlwe_sk_tmp_bytes(params.log_q()) - | gadget_product_core_tmp_bytes( - params.module(), - params.log_base2k(), - params.log_q(), - params.log_q(), - params.cols_q(), - params.log_qp(), - ) - | encrypt_grlwe_sk_tmp_bytes( - params.module(), - params.log_base2k(), - params.cols_qp(), - params.log_qp(), - ), - ); - - let mut source: Source = Source::new([3; 32]); - - let mut sk0: SecretKey = SecretKey::new(params.module()); - let mut sk1: SecretKey = SecretKey::new(params.module()); - sk0.fill_ternary_hw(params.xs(), &mut source); - sk1.fill_ternary_hw(params.xs(), &mut source); - - let mut source_xe: Source = Source::new([4; 32]); - let mut source_xa: Source = Source::new([5; 32]); - - let mut sk0_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol(); - params.module().svp_prepare(&mut sk0_svp_ppol, &sk0.0); - - let mut sk1_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol(); - params.module().svp_prepare(&mut sk1_svp_ppol, &sk1.0); - - let mut gadget_ct: Ciphertext = new_gadget_ciphertext( - params.module(), - params.log_base2k(), - params.cols_q(), - params.log_qp(), - ); - - encrypt_grlwe_sk( - params.module(), - &mut gadget_ct, - &sk0.0, - &sk1_svp_ppol, - &mut source_xa, - &mut source_xe, - params.xe(), - &mut tmp_bytes, - ); - - let mut ct: Ciphertext = params.new_ciphertext(params.log_q()); - - params.encrypt_rlwe_sk( - &mut ct, - None, - &sk0_svp_ppol, - &mut source_xa, - &mut source_xe, - &mut tmp_bytes, - ); - - let mut res_dft_0: VecZnxDft = params.module().new_vec_znx_dft(1, gadget_ct.cols()); - let mut res_dft_1: VecZnxDft = params.module().new_vec_znx_dft(1, gadget_ct.cols()); - - let mut a: VecZnx = params.module().new_vec_znx(0, params.cols_q()); - params - .module() - .fill_uniform(params.log_base2k(), &mut a, params.cols_q(), &mut source_xa); - - let b_cols: usize = gadget_ct.cols(); - - let runners: [(String, Box); 1] = [(format!("gadget_product"), { - runner( - params.module(), - &mut res_dft_0, - &mut res_dft_1, - &mut a, - &gadget_ct, - b_cols, - &mut tmp_bytes, - ) - })]; - - for (name, mut runner) in runners { - let id: BenchmarkId = BenchmarkId::new(name, format!("n={}", 1 << log_n)); - b.bench_with_input(id, &(), 
|b: &mut criterion::Bencher<'_>, _| { - b.iter(&mut runner) - }); - } - } -} - -criterion_group!(benches, bench_gadget_product_inplace); -criterion_main!(benches); diff --git a/rlwe/examples/encryption.rs b/rlwe/examples/encryption.rs deleted file mode 100644 index b9d66cd..0000000 --- a/rlwe/examples/encryption.rs +++ /dev/null @@ -1,76 +0,0 @@ -use base2k::{Encoding, SvpPPolOps, VecZnx, alloc_aligned}; -use rlwe::{ - ciphertext::Ciphertext, - elem::ElemCommon, - keys::SecretKey, - parameters::{Parameters, ParametersLiteral}, - plaintext::Plaintext, -}; -use sampling::source::Source; - -fn main() { - let params_lit: ParametersLiteral = ParametersLiteral { - backend: base2k::BACKEND::FFT64, - log_n: 10, - log_q: 54, - log_p: 0, - log_base2k: 17, - log_scale: 20, - xe: 3.2, - xs: 128, - }; - - let params: Parameters = Parameters::new(¶ms_lit); - - let mut tmp_bytes: Vec = - alloc_aligned(params.decrypt_rlwe_tmp_byte(params.log_q()) | params.encrypt_rlwe_sk_tmp_bytes(params.log_q())); - - let mut source: Source = Source::new([0; 32]); - let mut sk: SecretKey = SecretKey::new(params.module()); - sk.fill_ternary_hw(params.xs(), &mut source); - - let mut want = vec![i64::default(); params.n()]; - - want.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); - - let mut pt: Plaintext = params.new_plaintext(params.log_q()); - - let log_base2k = pt.log_base2k(); - - let log_k: usize = params.log_q() - 20; - - pt.0.value[0].encode_vec_i64(0, log_base2k, log_k, &want, 32); - pt.0.value[0].normalize(log_base2k, &mut tmp_bytes); - - println!("log_k: {}", log_k); - pt.0.value[0].print(0, pt.cols(), 16); - println!(); - - let mut ct: Ciphertext = params.new_ciphertext(params.log_q()); - - let mut source_xe: Source = Source::new([1; 32]); - let mut source_xa: Source = Source::new([2; 32]); - - let mut sk_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol(); - params.module().svp_prepare(&mut sk_svp_ppol, &sk.0); - - params.encrypt_rlwe_sk( - &mut ct, - Some(&pt), - &sk_svp_ppol, - &mut source_xa, - &mut source_xe, - &mut tmp_bytes, - ); - - params.decrypt_rlwe(&mut pt, &ct, &sk_svp_ppol, &mut tmp_bytes); - pt.0.value[0].print(0, pt.cols(), 16); - - let mut have = vec![i64::default(); params.n()]; - - println!("pt: {}", log_k); - pt.0.value[0].decode_vec_i64(0, pt.log_base2k(), log_k, &mut have); - - println!("want: {:?}", &want[..16]); - println!("have: {:?}", &have[..16]); -} diff --git a/rlwe/src/automorphism.rs b/rlwe/src/automorphism.rs deleted file mode 100644 index 5e5b48a..0000000 --- a/rlwe/src/automorphism.rs +++ /dev/null @@ -1,349 +0,0 @@ -use crate::{ - ciphertext::{Ciphertext, new_gadget_ciphertext}, - elem::ElemCommon, - encryptor::{encrypt_grlwe_sk, encrypt_grlwe_sk_tmp_bytes}, - key_switching::{key_switch_rlwe, key_switch_rlwe_inplace, key_switch_tmp_bytes}, - keys::SecretKey, - parameters::Parameters, -}; -use base2k::{ - Module, Scalar, ScalarOps, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, - VmpPMatOps, assert_alignement, -}; -use sampling::source::Source; -use std::collections::HashMap; - -/// Stores DFT([-A*AUTO(s, -p) + 2^{-K*i}*s + E, A]) where AUTO(X, p): X^{i} -> X^{i*p} -pub struct AutomorphismKey { - pub value: Ciphertext, - pub p: i64, -} - -pub fn automorphis_key_new_tmp_bytes(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> usize { - module.bytes_of_scalar() + module.bytes_of_svp_ppol() + encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q) -} - -impl Parameters { - pub fn 
automorphism_key_new_tmp_bytes(&self, rows: usize, log_q: usize) -> usize {
-        automorphis_key_new_tmp_bytes(self.module(), self.log_base2k(), rows, log_q)
-    }
-
-    pub fn automorphism_tmp_bytes(&self, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize {
-        automorphism_tmp_bytes(
-            self.module(),
-            self.log_base2k(),
-            res_logq,
-            in_logq,
-            gct_logq,
-        )
-    }
-}
-
-impl AutomorphismKey {
-    pub fn new(
-        module: &Module,
-        p: i64,
-        sk: &SecretKey,
-        log_base2k: usize,
-        rows: usize,
-        log_q: usize,
-        source_xa: &mut Source,
-        source_xe: &mut Source,
-        sigma: f64,
-        tmp_bytes: &mut [u8],
-    ) -> Self {
-        Self::new_many_core(
-            module,
-            &vec![p],
-            sk,
-            log_base2k,
-            rows,
-            log_q,
-            source_xa,
-            source_xe,
-            sigma,
-            tmp_bytes,
-        )
-        .into_iter()
-        .next()
-        .unwrap()
-    }
-
-    pub fn new_many(
-        module: &Module,
-        p: &Vec<i64>,
-        sk: &SecretKey,
-        log_base2k: usize,
-        rows: usize,
-        log_q: usize,
-        source_xa: &mut Source,
-        source_xe: &mut Source,
-        sigma: f64,
-        tmp_bytes: &mut [u8],
-    ) -> HashMap<i64, AutomorphismKey> {
-        Self::new_many_core(
-            module, p, sk, log_base2k, rows, log_q, source_xa, source_xe, sigma, tmp_bytes,
-        )
-        .into_iter()
-        .zip(p.iter().cloned())
-        .map(|(key, pi)| (pi, key))
-        .collect()
-    }
-
-    fn new_many_core(
-        module: &Module,
-        p: &Vec<i64>,
-        sk: &SecretKey,
-        log_base2k: usize,
-        rows: usize,
-        log_q: usize,
-        source_xa: &mut Source,
-        source_xe: &mut Source,
-        sigma: f64,
-        tmp_bytes: &mut [u8],
-    ) -> Vec<AutomorphismKey> {
-        let (sk_auto_bytes, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_scalar());
-        let (sk_out_bytes, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_svp_ppol());
-
-        let sk_auto: Scalar = module.new_scalar_from_bytes_borrow(sk_auto_bytes);
-        let mut sk_out: SvpPPol = module.new_svp_ppol_from_bytes_borrow(sk_out_bytes);
-
-        let mut keys: Vec<AutomorphismKey> = Vec::new();
-
-        p.iter().for_each(|pi| {
-            let mut value: Ciphertext<VmpPMat> = new_gadget_ciphertext(module, log_base2k, rows, log_q);
-
-            let p_inv: i64 = module.galois_element_inv(*pi);
-
-            module.vec_znx_automorphism(p_inv, &mut sk_auto.as_vec_znx(), &sk.0.as_vec_znx());
-            module.svp_prepare(&mut sk_out, &sk_auto);
-            encrypt_grlwe_sk(
-                module, &mut value, &sk.0, &sk_out, source_xa, source_xe, sigma, tmp_bytes,
-            );
-
-            keys.push(Self { value, p: *pi })
-        });
-
-        keys
-    }
-}
-
-pub fn automorphism_tmp_bytes(module: &Module, log_base2k: usize, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize {
-    key_switch_tmp_bytes(module, log_base2k, res_logq, in_logq, gct_logq)
-}
-
-pub fn automorphism(
-    module: &Module,
-    c: &mut Ciphertext<VecZnx>,
-    a: &Ciphertext<VecZnx>,
-    b: &AutomorphismKey,
-    b_cols: usize,
-    tmp_bytes: &mut [u8],
-) {
-    key_switch_rlwe(module, c, a, &b.value, b_cols, tmp_bytes);
-    // c[0] = AUTO([-b*AUTO(s, -p) + m + e], p) = [-AUTO(b, p)*s + AUTO(m, p) + AUTO(e, p)]
-    module.vec_znx_automorphism_inplace(b.p, c.at_mut(0));
-    // c[1] = AUTO(b, p)
-    module.vec_znx_automorphism_inplace(b.p, c.at_mut(1));
-}
-
-pub fn automorphism_inplace_tmp_bytes(module: &Module, c_cols: usize, a_cols: usize, b_rows: usize, b_cols: usize) -> usize {
-    return module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, b_rows, b_cols)
-        + 2 * module.bytes_of_vec_znx_dft(1, std::cmp::min(c_cols, a_cols));
-}
-
-pub fn automorphism_inplace(
-    module: &Module,
-    a: &mut Ciphertext<VecZnx>,
-    b: &AutomorphismKey,
-    b_cols: usize,
-    tmp_bytes: &mut [u8],
-) {
-    key_switch_rlwe_inplace(module, a, &b.value, b_cols, tmp_bytes);
-    // a[0] = AUTO([-b*AUTO(s, -p) + m + e], p) = [-AUTO(b, p)*s + AUTO(m, p) + AUTO(e, p)]
-    module.vec_znx_automorphism_inplace(b.p, a.at_mut(0));
-    // a[1] = AUTO(b, p)
-    module.vec_znx_automorphism_inplace(b.p, a.at_mut(1));
-}
-
-pub fn automorphism_big(
-    module: &Module,
-    c: &mut Ciphertext<VecZnxBig>,
-    a: &Ciphertext<VecZnx>,
-    b: &AutomorphismKey,
-    tmp_bytes: &mut [u8],
-) {
-    let cols = std::cmp::min(c.cols(), a.cols());
-
-    #[cfg(debug_assertions)]
-    {
-        assert!(tmp_bytes.len() >= automorphism_inplace_tmp_bytes(module, c.cols(), a.cols(), b.value.rows(), b.value.cols()));
-        assert_alignement(tmp_bytes.as_ptr());
-    }
-
-    let (tmp_bytes_a1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
-    let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
-
-    let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_a1_dft);
-    let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_res_dft);
-
-    // a1_dft = DFT(a[1])
-    module.vec_znx_dft(&mut a1_dft, a.at(1));
-
-    // c[0] = IDFT(<a1_dft, b[0]>) = [-b*AUTO(s, -p) + a * s + e]
-    module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.value.at(0), tmp_bytes);
-    module.vec_znx_idft_tmp_a(c.at_mut(0), &mut res_dft);
-
-    // c[0] = [-b*AUTO(s, -p) + a * s + e] + [-a * s + m + e] = [-b*AUTO(s, -p) + m + e]
-    module.vec_znx_big_add_small_inplace(c.at_mut(0), a.at(0));
-
-    // c[0] = AUTO([-b*AUTO(s, -p) + m + e], p) = [-AUTO(b, p)*s + AUTO(m, p) + AUTO(e, p)]
-    module.vec_znx_big_automorphism_inplace(b.p, c.at_mut(0));
-
-    // c[1] = IDFT(<a1_dft, b[1]>) = [b]
-    module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.value.at(1), tmp_bytes);
-    module.vec_znx_idft_tmp_a(c.at_mut(1), &mut res_dft);
-
-    // c[1] = AUTO(b, p)
-    module.vec_znx_big_automorphism_inplace(b.p, c.at_mut(1));
-}
-
-#[cfg(test)]
-mod test {
-    use super::{AutomorphismKey, automorphism};
-    use crate::{
-        ciphertext::Ciphertext,
-        decryptor::decrypt_rlwe,
-        elem::ElemCommon,
-        encryptor::encrypt_rlwe_sk,
-        keys::SecretKey,
-        parameters::{Parameters, ParametersLiteral},
-        plaintext::Plaintext,
-    };
-    use base2k::{BACKEND, Encoding, Module, SvpPPol, SvpPPolOps, VecZnx, VecZnxOps, alloc_aligned};
-    use sampling::source::{Source, new_seed};
-
-    #[test]
-    fn test_automorphism() {
-        let log_base2k: usize = 10;
-        let log_q: usize = 50;
-        let log_p: usize = 15;
-
-        // Basic parameters with enough limbs to test edge cases
-        let params_lit: ParametersLiteral = ParametersLiteral {
-            backend: BACKEND::FFT64,
-            log_n: 12,
-            log_q,
-            log_p,
-            log_base2k,
-            log_scale: 20,
-            xe: 3.2,
-            xs: 1 << 11,
-        };
-
-        let params: Parameters = Parameters::new(&params_lit);
-
-        let module: &Module = params.module();
-        let log_q: usize = params.log_q();
-        let log_qp: usize = params.log_qp();
-        let gct_rows: usize = params.cols_q();
-        let gct_cols: usize = params.cols_qp();
-
-        // scratch space (bitwise OR of the sizes is a cheap upper bound of their max)
-        let mut tmp_bytes: Vec<u8> = alloc_aligned(
-            params.decrypt_rlwe_tmp_byte(log_q)
-                | params.encrypt_rlwe_sk_tmp_bytes(log_q)
-                | params.automorphism_key_new_tmp_bytes(gct_rows, log_qp)
-                | params.automorphism_tmp_bytes(log_q, log_q, log_qp),
-        );
-
-        // Samplers for public and private randomness
-        let mut source_xe: Source = Source::new(new_seed());
-        let mut source_xa: Source = Source::new(new_seed());
-        let mut source_xs: Source = Source::new(new_seed());
-
-        let mut sk: SecretKey = SecretKey::new(module);
-        sk.fill_ternary_hw(params.xs(), &mut source_xs);
-        let mut sk_svp_ppol: SvpPPol = module.new_svp_ppol();
-        module.svp_prepare(&mut sk_svp_ppol, &sk.0);
-
-        let p: i64 = -5;
-
-        let auto_key: AutomorphismKey = AutomorphismKey::new(
-            module,
-            p,
-            &sk,
-            log_base2k,
-            gct_rows,
-            log_qp,
-            &mut source_xa,
-            &mut source_xe,
-            params.xe(),
-            &mut tmp_bytes,
-        );
-
-        let mut data: Vec<i64> = vec![0i64; params.n()];
-
-        data.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
-
-        let log_k: usize = 2 * log_base2k;
-
-        let mut ct: Ciphertext<VecZnx> = params.new_ciphertext(log_q);
-        let mut pt: Plaintext = params.new_plaintext(log_q);
-        let mut pt_auto: Plaintext = params.new_plaintext(log_q);
-
-        pt.at_mut(0).encode_vec_i64(0, log_base2k, log_k, &data, 32);
-        module.vec_znx_automorphism(p, pt_auto.at_mut(0), pt.at(0));
-
-        encrypt_rlwe_sk(
-            module,
-            ct.elem_mut(),
-            Some(pt.at(0)),
-            &sk_svp_ppol,
-            &mut source_xa,
-            &mut source_xe,
-            params.xe(),
-            &mut tmp_bytes,
-        );
-
-        let mut ct_auto: Ciphertext<VecZnx> = params.new_ciphertext(log_q);
-
-        // ct_auto <- AUTO(ct)
-        automorphism(
-            module,
-            &mut ct_auto,
-            &ct,
-            &auto_key,
-            gct_cols,
-            &mut tmp_bytes,
-        );
-
-        // pt = dec(auto(ct)) - auto(pt)
-        decrypt_rlwe(
-            module,
-            pt.elem_mut(),
-            ct_auto.elem(),
-            &sk_svp_ppol,
-            &mut tmp_bytes,
-        );
-
-        module.vec_znx_sub_ba_inplace(pt.at_mut(0), pt_auto.at(0));
-
-        // pt.at(0).print(pt.cols(), 16);
-
-        let noise_have: f64 = pt.at(0).std(0, log_base2k).log2();
-
-        let var_msg: f64 = (params.xs() as f64) / params.n() as f64;
-        let var_a_err: f64 = 1f64 / 12f64;
-
-        let noise_pred: f64 = params.noise_grlwe_product(var_msg, var_a_err, ct_auto.log_q(), auto_key.value.log_q());
-
-        println!("noise_pred: {}", noise_pred);
-        println!("noise_have: {}", noise_have);
-
-        assert!(noise_have <= noise_pred + 1.0);
-    }
-}
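// Usage sketch for the automorphism API deleted above (a hedged reconstruction:
// `params`, `gct_rows`, `gct_cols`, the sources and sizes mirror the test code
// and are placeholders, not a normative API):
//
//     let auto_key = AutomorphismKey::new(
//         module, -5, &sk, log_base2k, gct_rows, log_qp,
//         &mut source_xa, &mut source_xe, sigma, &mut tmp_bytes,
//     );
//     let mut ct_auto = params.new_ciphertext(log_q);
//     automorphism(module, &mut ct_auto, &ct, &auto_key, gct_cols, &mut tmp_bytes);
//
// Decrypting ct_auto yields AUTO(m, p) = m(X^p): the map X -> X^p is applied
// under encryption at the cost of one gadget product with the key-switching
// ciphertext stored in the key.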
diff --git a/rlwe/src/ciphertext.rs b/rlwe/src/ciphertext.rs
deleted file mode 100644
index 9d1fe1a..0000000
--- a/rlwe/src/ciphertext.rs
+++ /dev/null
@@ -1,93 +0,0 @@
-use crate::elem::{Elem, ElemCommon};
-use crate::parameters::Parameters;
-use base2k::{Infos, LAYOUT, Module, VecZnx, VmpPMat};
-
-pub struct Ciphertext<T>(pub Elem<T>);
-
-impl Parameters {
-    pub fn new_ciphertext(&self, log_q: usize) -> Ciphertext<VecZnx> {
-        Ciphertext::new(self.module(), self.log_base2k(), log_q, 2)
-    }
-}
-
-impl<T> ElemCommon<T> for Ciphertext<T>
-where
-    T: Infos,
-{
-    fn n(&self) -> usize {
-        self.elem().n()
-    }
-
-    fn log_n(&self) -> usize {
-        self.elem().log_n()
-    }
-
-    fn log_q(&self) -> usize {
-        self.elem().log_q()
-    }
-
-    fn elem(&self) -> &Elem<T> {
-        &self.0
-    }
-
-    fn elem_mut(&mut self) -> &mut Elem<T> {
-        &mut self.0
-    }
-
-    fn size(&self) -> usize {
-        self.elem().size()
-    }
-
-    fn layout(&self) -> LAYOUT {
-        self.elem().layout()
-    }
-
-    fn rows(&self) -> usize {
-        self.elem().rows()
-    }
-
-    fn cols(&self) -> usize {
-        self.elem().cols()
-    }
-
-    fn at(&self, i: usize) -> &T {
-        self.elem().at(i)
-    }
-
-    fn at_mut(&mut self, i: usize) -> &mut T {
-        self.elem_mut().at_mut(i)
-    }
-
-    fn log_base2k(&self) -> usize {
-        self.elem().log_base2k()
-    }
-
-    fn log_scale(&self) -> usize {
-        self.elem().log_scale()
-    }
-}
-
-impl Ciphertext<VecZnx> {
-    pub fn new(module: &Module, log_base2k: usize, log_q: usize, rows: usize) -> Self {
-        Self(Elem::<VecZnx>::new(module, log_base2k, log_q, rows))
-    }
-}
-
-pub fn new_rlwe_ciphertext(module: &Module, log_base2k: usize, log_q: usize) -> Ciphertext<VecZnx> {
-    let rows: usize = 2;
-    Ciphertext::<VecZnx>::new(module, log_base2k, log_q, rows)
-}
-
-pub fn new_gadget_ciphertext(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> Ciphertext<VmpPMat> {
-    let cols: usize = (log_q + log_base2k - 1) / log_base2k;
-    let mut elem: Elem<VmpPMat> = Elem::<VmpPMat>::new(module, log_base2k, 2, rows, cols);
-    elem.log_q = log_q;
-    Ciphertext(elem)
-}
-
-pub fn new_rgsw_ciphertext(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> Ciphertext<VmpPMat> {
-    let cols: usize = (log_q + log_base2k - 1) / log_base2k;
-    let mut elem: Elem<VmpPMat> = Elem::<VmpPMat>::new(module, log_base2k, 4, rows, cols);
-    elem.log_q = log_q;
-    Ciphertext(elem)
-}
diff --git a/rlwe/src/decryptor.rs b/rlwe/src/decryptor.rs
deleted file mode 100644
index 6eeea27..0000000
--- a/rlwe/src/decryptor.rs
+++ /dev/null
@@ -1,67 +0,0 @@
-use crate::{
-    ciphertext::Ciphertext,
-    elem::{Elem, ElemCommon},
-    keys::SecretKey,
-    parameters::Parameters,
-    plaintext::Plaintext,
-};
-use base2k::{Module, SvpPPol, SvpPPolOps, VecZnx, VecZnxBigOps, VecZnxDft, VecZnxDftOps};
-use std::cmp::min;
-
-pub struct Decryptor {
-    sk: SvpPPol,
-}
-
-impl Decryptor {
-    pub fn new(params: &Parameters, sk: &SecretKey) -> Self {
-        let mut sk_svp_ppol: SvpPPol = params.module().new_svp_ppol();
-        sk.prepare(params.module(), &mut sk_svp_ppol);
-        Self { sk: sk_svp_ppol }
-    }
-}
-
-pub fn decrypt_rlwe_tmp_byte(module: &Module, cols: usize) -> usize {
-    module.bytes_of_vec_znx_dft(1, cols) + module.vec_znx_big_normalize_tmp_bytes()
-}
-
-impl Parameters {
-    pub fn decrypt_rlwe_tmp_byte(&self, log_q: usize) -> usize {
-        decrypt_rlwe_tmp_byte(
-            self.module(),
-            (log_q + self.log_base2k() - 1) / self.log_base2k(),
-        )
-    }
-
-    pub fn decrypt_rlwe(&self, res: &mut Plaintext, ct: &Ciphertext<VecZnx>, sk: &SvpPPol, tmp_bytes: &mut [u8]) {
-        decrypt_rlwe(self.module(), &mut res.0, &ct.0, sk, tmp_bytes)
-    }
-}
-
-pub fn decrypt_rlwe(module: &Module, res: &mut Elem<VecZnx>, a: &Elem<VecZnx>, sk: &SvpPPol, tmp_bytes: &mut [u8]) {
-    let cols: usize = a.cols();
-
-    assert!(
-        tmp_bytes.len() >= decrypt_rlwe_tmp_byte(module, cols),
-        "invalid tmp_bytes: tmp_bytes.len()={} < decrypt_rlwe_tmp_byte={}",
-        tmp_bytes.len(),
-        decrypt_rlwe_tmp_byte(module, cols)
-    );
-
-    let (tmp_bytes_vec_znx_dft, tmp_bytes_normalize) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
-
-    let mut res_dft: VecZnxDft = VecZnxDft::from_bytes_borrow(module, 1, cols, tmp_bytes_vec_znx_dft);
-    let mut res_big: base2k::VecZnxBig = res_dft.as_vec_znx_big();
-
-    // res_dft <- DFT(ct[1]) * DFT(sk)
-    module.svp_apply_dft(&mut res_dft, sk, a.at(1));
-    // res_big <- ct[1] x sk
-    module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft);
-    // res_big <- ct[1] x sk + ct[0]
-    module.vec_znx_big_add_small_inplace(&mut res_big, a.at(0));
-    // res <- normalize(ct[1] x sk + ct[0])
-    module.vec_znx_big_normalize(a.log_base2k(), res.at_mut(0), &res_big, tmp_bytes_normalize);
-
-    res.log_base2k = a.log_base2k();
-    res.log_q = min(res.log_q(), a.log_q());
-    res.log_scale = a.log_scale();
-}
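// What decrypt_rlwe above computes, in equation form (a sketch based on its
// in-line comments): for ct = (-a*s + m + e, a),
//
//     <ct, (1, s)> = ct[0] + ct[1]*s = m + e,
//
// evaluated as svp_apply_dft + vec_znx_idft_tmp_a + big_add_small, then
// normalized back to balanced base-2^k limbs. The result keeps the smaller of
// the two precisions, which is why res.log_q = min(res.log_q, a.log_q).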
diff --git a/rlwe/src/elem.rs b/rlwe/src/elem.rs
deleted file mode 100644
index e7e61c4..0000000
--- a/rlwe/src/elem.rs
+++ /dev/null
@@ -1,168 +0,0 @@
-use base2k::{Infos, LAYOUT, Module, VecZnx, VecZnxOps, VmpPMat, VmpPMatOps};
-
-pub struct Elem<T> {
-    pub value: Vec<T>,
-    pub log_base2k: usize,
-    pub log_q: usize,
-    pub log_scale: usize,
-}
-
-pub trait ElemVecZnx {
-    fn from_bytes(module: &Module, log_base2k: usize, log_q: usize, size: usize, bytes: &mut [u8]) -> Elem<VecZnx>;
-    fn from_bytes_borrow(module: &Module, log_base2k: usize, log_q: usize, size: usize, bytes: &mut [u8]) -> Elem<VecZnx>;
-    fn bytes_of(module: &Module, log_base2k: usize, log_q: usize, size: usize) -> usize;
-    fn zero(&mut self);
-}
-
-impl ElemVecZnx for Elem<VecZnx> {
-    fn bytes_of(module: &Module, log_base2k: usize, log_q: usize, size: usize) -> usize {
-        let cols = (log_q + log_base2k - 1) / log_base2k;
-        module.n() * cols * size * 8
-    }
-
-    fn from_bytes(module: &Module, log_base2k: usize, log_q: usize, size: usize, bytes: &mut [u8]) -> Elem<VecZnx> {
-        assert!(size > 0);
-        let n: usize = module.n();
-        assert!(bytes.len() >= Self::bytes_of(module, log_base2k, log_q, size));
-        let mut value: Vec<VecZnx> = Vec::new();
-        let cols: usize = (log_q + log_base2k - 1) / log_base2k;
-        // Stride of one polynomial (1, not `size`, to match from_bytes_borrow below)
-        let elem_size = VecZnx::bytes_of(n, 1, cols);
-        let mut ptr: usize = 0;
-        (0..size).for_each(|_| {
-            value.push(VecZnx::from_bytes(n, 1, cols, &mut bytes[ptr..]));
-            ptr += elem_size
-        });
-        Self {
-            value,
-            log_q,
-            log_base2k,
-            log_scale: 0,
-        }
-    }
-
-    fn from_bytes_borrow(module: &Module, log_base2k: usize, log_q: usize, size: usize, bytes: &mut [u8]) -> Elem<VecZnx> {
-        assert!(size > 0);
-        let n: usize = module.n();
-        assert!(bytes.len() >= Self::bytes_of(module, log_base2k, log_q, size));
-        let mut value: Vec<VecZnx> = Vec::new();
-        let cols: usize = (log_q + log_base2k - 1) / log_base2k;
-        let elem_size = VecZnx::bytes_of(n, 1, cols);
-        let mut ptr: usize = 0;
-        (0..size).for_each(|_| {
-            value.push(VecZnx::from_bytes_borrow(n, 1, cols, &mut bytes[ptr..]));
-            ptr += elem_size
-        });
-        Self {
-            value,
-            log_q,
-            log_base2k,
-            log_scale: 0,
-        }
-    }
-
-    fn zero(&mut self) {
-        self.value.iter_mut().for_each(|i| i.zero());
-    }
-}
-
-pub trait ElemCommon<T> {
-    fn n(&self) -> usize;
-    fn log_n(&self) -> usize;
-    fn elem(&self) -> &Elem<T>;
-    fn elem_mut(&mut self) -> &mut Elem<T>;
-    fn size(&self) -> usize;
-    fn layout(&self) -> LAYOUT;
-    fn rows(&self) -> usize;
-    fn cols(&self) -> usize;
-    fn log_base2k(&self) -> usize;
-    fn log_q(&self) -> usize;
-    fn log_scale(&self) -> usize;
-    fn at(&self, i: usize) -> &T;
-    fn at_mut(&mut self, i: usize) -> &mut T;
-}
-
-impl<T: Infos> ElemCommon<T> for Elem<T> {
-    fn n(&self) -> usize {
-        self.value[0].n()
-    }
-
-    fn log_n(&self) -> usize {
-        self.value[0].log_n()
-    }
-
-    fn elem(&self) -> &Elem<T> {
-        self
-    }
-
-    fn elem_mut(&mut self) -> &mut Elem<T> {
-        self
-    }
-
-    fn size(&self) -> usize {
-        self.value.len()
-    }
-
-    fn layout(&self) -> LAYOUT {
-        self.value[0].layout()
-    }
-
-    fn rows(&self) -> usize {
-        self.value[0].rows()
-    }
-
-    fn cols(&self) -> usize {
-        self.value[0].cols()
-    }
-
-    fn log_base2k(&self) -> usize {
-        self.log_base2k
-    }
-
-    fn log_q(&self) -> usize {
-        self.log_q
-    }
-
-    fn log_scale(&self) -> usize {
-        self.log_scale
-    }
-
-    fn at(&self, i: usize) -> &T {
-        assert!(i < self.size());
-        &self.value[i]
-    }
-
-    fn at_mut(&mut self, i: usize) -> &mut T {
-        assert!(i < self.size());
-        &mut self.value[i]
-    }
-}
-
-impl Elem<VecZnx> {
-    pub fn new(module: &Module, log_base2k: usize, log_q: usize, rows: usize) -> Self {
-        assert!(rows > 0);
-        let cols: usize = (log_q + log_base2k - 1) / log_base2k;
-        let mut value: Vec<VecZnx> = Vec::new();
-        (0..rows).for_each(|_| value.push(module.new_vec_znx(1, cols)));
-        Self {
-            value,
-            log_q,
-            log_base2k,
-            log_scale: 0,
-        }
-    }
-}
-
-impl Elem<VmpPMat> {
-    pub fn new(module: &Module, log_base2k: usize, size: usize, rows: usize, cols: usize) -> Self {
-        assert!(rows > 0);
-        assert!(cols > 0);
-        let mut value: Vec<VmpPMat> = Vec::new();
-        (0..size).for_each(|_| value.push(module.new_vmp_pmat(1, rows, cols)));
-        Self {
-            value,
-            log_q: 0,
-            log_base2k,
-            log_scale: 0,
-        }
-    }
-}
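// Sizing check for Elem::<VecZnx>::bytes_of above, assuming 8-byte i64 limbs
// (the constant 8 in the formula): with n = 4096, log_q = 65, log_base2k = 10,
//
//     cols = ceil(65 / 10) = 7
//     one VecZnx          = 4096 * 7 * 8 = 229_376 bytes
//     Elem of `size` polys = size * 229_376 bytes
//
// which matches module.n() * cols * size * 8.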
diff --git a/rlwe/src/encryptor.rs b/rlwe/src/encryptor.rs
deleted file mode 100644
index bdb383c..0000000
--- a/rlwe/src/encryptor.rs
+++ /dev/null
@@ -1,369 +0,0 @@
-use crate::ciphertext::Ciphertext;
-use crate::elem::{Elem, ElemCommon, ElemVecZnx};
-use crate::keys::SecretKey;
-use crate::parameters::Parameters;
-use crate::plaintext::Plaintext;
-use base2k::sampling::Sampling;
-use base2k::{
-    Infos, Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat,
-    VmpPMatOps,
-};
-
-use sampling::source::{Source, new_seed};
-
-impl Parameters {
-    pub fn encrypt_rlwe_sk_tmp_bytes(&self, log_q: usize) -> usize {
-        encrypt_rlwe_sk_tmp_bytes(self.module(), self.log_base2k(), log_q)
-    }
-    pub fn encrypt_rlwe_sk(
-        &self,
-        ct: &mut Ciphertext<VecZnx>,
-        pt: Option<&Plaintext>,
-        sk: &SvpPPol,
-        source_xa: &mut Source,
-        source_xe: &mut Source,
-        tmp_bytes: &mut [u8],
-    ) {
-        encrypt_rlwe_sk(
-            self.module(),
-            &mut ct.0,
-            pt.map(|pt| pt.at(0)),
-            sk,
-            source_xa,
-            source_xe,
-            self.xe(),
-            tmp_bytes,
-        )
-    }
-}
-
-pub struct EncryptorSk {
-    sk: SvpPPol,
-    source_xa: Source,
-    source_xe: Source,
-    initialized: bool,
-    tmp_bytes: Vec<u8>,
-}
-
-impl EncryptorSk {
-    pub fn new(params: &Parameters, sk: Option<&SecretKey>) -> Self {
-        let mut sk_svp_ppol: SvpPPol = params.module().new_svp_ppol();
-        let mut initialized: bool = false;
-        if let Some(sk) = sk {
-            sk.prepare(params.module(), &mut sk_svp_ppol);
-            initialized = true;
-        }
-        Self {
-            sk: sk_svp_ppol,
-            initialized,
-            source_xa: Source::new(new_seed()),
-            source_xe: Source::new(new_seed()),
-            tmp_bytes: vec![0u8; params.encrypt_rlwe_sk_tmp_bytes(params.log_qp())],
-        }
-    }
-
-    pub fn set_sk(&mut self, module: &Module, sk: &SecretKey) {
-        sk.prepare(module, &mut self.sk);
-        self.initialized = true;
-    }
-
-    pub fn seed_source_xa(&mut self, seed: [u8; 32]) {
-        self.source_xa = Source::new(seed)
-    }
-
-    pub fn seed_source_xe(&mut self, seed: [u8; 32]) {
-        self.source_xe = Source::new(seed)
-    }
-
-    pub fn encrypt_rlwe_sk(&mut self, params: &Parameters, ct: &mut Ciphertext<VecZnx>, pt: Option<&Plaintext>) {
-        assert!(
-            self.initialized,
-            "invalid call to [EncryptorSk.encrypt_rlwe_sk]: [EncryptorSk] has not been initialized with a [SecretKey]"
-        );
-        params.encrypt_rlwe_sk(
-            ct,
-            pt,
-            &self.sk,
-            &mut self.source_xa,
-            &mut self.source_xe,
-            &mut self.tmp_bytes,
-        );
-    }
-
-    pub fn encrypt_rlwe_sk_core(
-        &self,
-        params: &Parameters,
-        ct: &mut Ciphertext<VecZnx>,
-        pt: Option<&Plaintext>,
-        source_xa: &mut Source,
-        source_xe: &mut Source,
-        tmp_bytes: &mut [u8],
-    ) {
-        assert!(
-            self.initialized,
-            "invalid call to [EncryptorSk.encrypt_rlwe_sk]: [EncryptorSk] has not been initialized with a [SecretKey]"
-        );
-        params.encrypt_rlwe_sk(ct, pt, &self.sk, source_xa, source_xe, tmp_bytes);
-    }
-}
-
-pub fn encrypt_rlwe_sk_tmp_bytes(module: &Module, log_base2k: usize, log_q: usize) -> usize {
-    module.bytes_of_vec_znx_dft(1, (log_q + log_base2k - 1) / log_base2k) + module.vec_znx_big_normalize_tmp_bytes()
-}
-pub fn encrypt_rlwe_sk(
-    module: &Module,
-    ct: &mut Elem<VecZnx>,
-    pt: Option<&VecZnx>,
-    sk: &SvpPPol,
-    source_xa: &mut Source,
-    source_xe: &mut Source,
-    sigma: f64,
-    tmp_bytes: &mut [u8],
-) {
-    encrypt_rlwe_sk_core::<0>(module, ct, pt, sk, source_xa, source_xe, sigma, tmp_bytes)
-}
-
-fn encrypt_rlwe_sk_core<const PT_POS: usize>(
-    module: &Module,
-    ct: &mut Elem<VecZnx>,
-    pt: Option<&VecZnx>,
-    sk: &SvpPPol,
-    source_xa: &mut Source,
-    source_xe: &mut Source,
-    sigma: f64,
-    tmp_bytes: &mut [u8],
-) {
-    let cols: usize = ct.cols();
-    let log_base2k: usize = ct.log_base2k();
-    let log_q: usize = ct.log_q();
-
-    assert!(
-        tmp_bytes.len() >= encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q),
-        "invalid tmp_bytes: tmp_bytes={} < encrypt_rlwe_sk_tmp_bytes={}",
-        tmp_bytes.len(),
-        encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q)
-    );
-
-    let c1: &mut VecZnx = ct.at_mut(1);
-
-    // c1 <- Z_{2^prec}[X]/(X^{N}+1)
-    module.fill_uniform(log_base2k, c1, cols, source_xa);
-
-    let (tmp_bytes_vec_znx_dft, tmp_bytes_normalize) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
-
-    // Scratch space for DFT values
-    let mut buf_dft: VecZnxDft = VecZnxDft::from_bytes_borrow(module, 1, cols, tmp_bytes_vec_znx_dft);
-
-    // Applies buf_dft <- DFT(s) * DFT(c1)
-    module.svp_apply_dft(&mut buf_dft, sk, c1);
-
-    // Alias scratch space
-    let mut buf_big: VecZnxBig = buf_dft.as_vec_znx_big();
-
-    // buf_big = s x c1
-    module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft);
-
-    match PT_POS {
-        // c0 <- -s x c1 + m
-        0 => {
-            let c0: &mut VecZnx = ct.at_mut(0);
-            if let Some(pt) = pt {
-                module.vec_znx_big_sub_small_a_inplace(&mut buf_big, pt);
-                module.vec_znx_big_normalize(log_base2k, c0, &buf_big, tmp_bytes_normalize);
-            } else {
-                module.vec_znx_big_normalize(log_base2k, c0, &buf_big, tmp_bytes_normalize);
-                module.vec_znx_negate_inplace(c0);
-            }
-        }
-        // c1 <- c1 + m
-        1 => {
-            if let Some(pt) = pt {
-                module.vec_znx_add_inplace(c1, pt);
-                c1.normalize(log_base2k, tmp_bytes_normalize);
-            }
-            let c0: &mut VecZnx = ct.at_mut(0);
-            module.vec_znx_big_normalize(log_base2k, c0, &buf_big, tmp_bytes_normalize);
-            module.vec_znx_negate_inplace(c0);
-        }
-        _ => panic!("PT_POS must be 0 or 1"),
-    }
-
-    // c0 <- -s x c1 + m + e
-    module.add_normal(
-        log_base2k,
-        ct.at_mut(0),
-        log_q,
-        source_xe,
-        sigma,
-        (sigma * 6.0).ceil(),
-    );
-}
-
-impl Parameters {
-    pub fn encrypt_grlwe_sk_tmp_bytes(&self, rows: usize, log_q: usize) -> usize {
-        encrypt_grlwe_sk_tmp_bytes(self.module(), self.log_base2k(), rows, log_q)
-    }
-}
-
-pub fn encrypt_grlwe_sk_tmp_bytes(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> usize {
-    let cols = (log_q + log_base2k - 1) / log_base2k;
-    Elem::<VecZnx>::bytes_of(module, log_base2k, log_q, 2)
-        + Plaintext::bytes_of(module, log_base2k, log_q)
-        + encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q)
-        + module.vmp_prepare_tmp_bytes(rows, cols)
-}
-
-pub fn encrypt_grlwe_sk(
-    module: &Module,
-    ct: &mut Ciphertext<VmpPMat>,
-    m: &Scalar,
-    sk: &SvpPPol,
-    source_xa: &mut Source,
-    source_xe: &mut Source,
-    sigma: f64,
-    tmp_bytes: &mut [u8],
-) {
-    let log_q: usize = ct.log_q();
-    let log_base2k: usize = ct.log_base2k();
-    let (left, right) = ct.0.value.split_at_mut(1);
-    encrypt_grlwe_sk_core::<0>(
-        module,
-        log_base2k,
-        [&mut left[0], &mut right[0]],
-        log_q,
-        m,
-        sk,
-        source_xa,
-        source_xe,
-        sigma,
-        tmp_bytes,
-    )
-}
-
-impl Parameters {
-    pub fn encrypt_rgsw_sk_tmp_bytes(&self, rows: usize, log_q: usize) -> usize {
-        encrypt_rgsw_sk_tmp_bytes(self.module(), self.log_base2k(), rows, log_q)
-    }
-}
-
-pub fn encrypt_rgsw_sk_tmp_bytes(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> usize {
-    let cols = (log_q + log_base2k - 1) / log_base2k;
-    Elem::<VecZnx>::bytes_of(module, log_base2k, log_q, 2)
-        + Plaintext::bytes_of(module, log_base2k, log_q)
-        + encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q)
-        + module.vmp_prepare_tmp_bytes(rows, cols)
-}
-
-pub fn encrypt_rgsw_sk(
-    module: &Module,
-    ct: &mut Ciphertext<VmpPMat>,
-    m: &Scalar,
-    sk: &SvpPPol,
-    source_xa: &mut Source,
-    source_xe: &mut Source,
-    sigma: f64,
-    tmp_bytes: &mut [u8],
-) {
-    let log_q: usize = ct.log_q();
-    let log_base2k: usize = ct.log_base2k();
-
-    let (left, right) = ct.0.value.split_at_mut(2);
-    let (ll, lr) = left.split_at_mut(1);
-    let (rl, rr) = right.split_at_mut(1);
-
-    encrypt_grlwe_sk_core::<0>(
-        module,
-        log_base2k,
-        [&mut ll[0], &mut lr[0]],
-        log_q,
-        m,
-        sk,
-        source_xa,
-        source_xe,
-        sigma,
-        tmp_bytes,
-    );
-    encrypt_grlwe_sk_core::<1>(
-        module,
-        log_base2k,
-        [&mut rl[0], &mut rr[0]],
-        log_q,
-        m,
-        sk,
-        source_xa,
-        source_xe,
-        sigma,
-        tmp_bytes,
-    );
-}
-
-fn encrypt_grlwe_sk_core<const PT_POS: usize>(
-    module: &Module,
-    log_base2k: usize,
-    mut ct: [&mut VmpPMat; 2],
-    log_q: usize,
-    m: &Scalar,
-    sk: &SvpPPol,
-    source_xa: &mut Source,
-    source_xe: &mut Source,
-    sigma: f64,
-    tmp_bytes: &mut [u8],
-) {
-    let rows: usize = ct[0].rows();
-
-    let min_tmp_bytes_len = encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q);
-
-    assert!(
-        tmp_bytes.len() >= min_tmp_bytes_len,
-        "invalid tmp_bytes: tmp_bytes.len()={} < encrypt_grlwe_sk_tmp_bytes={}",
-        tmp_bytes.len(),
-        min_tmp_bytes_len
-    );
-
-    let bytes_of_elem: usize = Elem::<VecZnx>::bytes_of(module, log_base2k, log_q, 2);
-    let bytes_of_pt: usize = Plaintext::bytes_of(module, log_base2k, log_q);
-    let bytes_of_enc_sk: usize = encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q);
-
-    let (tmp_bytes_pt, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_pt);
-    let (tmp_bytes_enc_sk, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_enc_sk);
-    let (tmp_bytes_elem, tmp_bytes_vmp_prepare_row) = tmp_bytes.split_at_mut(bytes_of_elem);
-
-    let mut tmp_elem: Elem<VecZnx> = Elem::<VecZnx>::from_bytes_borrow(module, log_base2k, log_q, 2, tmp_bytes_elem);
-    let mut tmp_pt: Plaintext = Plaintext::from_bytes_borrow(module, log_base2k, log_q, tmp_bytes_pt);
-
-    (0..rows).for_each(|row_i| {
-        // Sets the i-th row of the RLWE sample to m (i.e. m * 2^{-log_base2k*i})
-        tmp_pt.at_mut(0).at_mut(row_i).copy_from_slice(&m.raw());
-
-        // Encrypts RLWE(m * 2^{-log_base2k*i})
-        encrypt_rlwe_sk_core::<PT_POS>(
-            module,
-            &mut tmp_elem,
-            Some(tmp_pt.at(0)),
-            sk,
-            source_xa,
-            source_xe,
-            sigma,
-            tmp_bytes_enc_sk,
-        );
-
-        // Zeroes the i-th row of tmp_pt
-        tmp_pt.at_mut(0).at_mut(row_i).fill(0);
-
-        // GRLWE[row_i][0||1] = [-as + m * 2^{-i*log_base2k} + e*2^{-log_q} || a]
-        module.vmp_prepare_row(
-            &mut ct[0],
-            tmp_elem.at(0).raw(),
-            row_i,
-            tmp_bytes_vmp_prepare_row,
-        );
-        module.vmp_prepare_row(
-            &mut ct[1],
-            tmp_elem.at(1).raw(),
-            row_i,
-            tmp_bytes_vmp_prepare_row,
-        );
-    });
-}
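// Row structure produced by encrypt_grlwe_sk_core above (a sketch following
// the loop comments; k' = log_base2k): row i of the gadget ciphertext is a
// fresh RLWE sample carrying m in plaintext limb i, i.e. an encryption of
// m * 2^{-i*k'}:
//
//     GRLWE(m)[i] = (-a_i*s + m * 2^{-i*k'} + e_i, a_i)
//
// The const parameter PT_POS picks where m is added: 0 places it in c0
// (GRLWE / switching keys), 1 places it in c1 (the second half of an RGSW
// ciphertext, as used by encrypt_rgsw_sk).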
diff --git a/rlwe/src/gadget_product.rs b/rlwe/src/gadget_product.rs
deleted file mode 100644
index bbf9642..0000000
--- a/rlwe/src/gadget_product.rs
+++ /dev/null
@@ -1,383 +0,0 @@
-use crate::{ciphertext::Ciphertext, elem::ElemCommon, parameters::Parameters};
-use base2k::{Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps};
-use std::cmp::min;
-
-pub fn gadget_product_core_tmp_bytes(
-    module: &Module,
-    log_base2k: usize,
-    res_log_q: usize,
-    in_log_q: usize,
-    gct_rows: usize,
-    gct_log_q: usize,
-) -> usize {
-    let gct_cols: usize = (gct_log_q + log_base2k - 1) / log_base2k;
-    let in_cols: usize = (in_log_q + log_base2k - 1) / log_base2k;
-    let out_cols: usize = (res_log_q + log_base2k - 1) / log_base2k;
-    module.vmp_apply_dft_to_dft_tmp_bytes(out_cols, in_cols, gct_rows, gct_cols)
-}
-
-impl Parameters {
-    pub fn gadget_product_tmp_bytes(&self, res_log_q: usize, in_log_q: usize, gct_rows: usize, gct_log_q: usize) -> usize {
-        gadget_product_core_tmp_bytes(
-            self.module(),
-            self.log_base2k(),
-            res_log_q,
-            in_log_q,
-            gct_rows,
-            gct_log_q,
-        )
-    }
-}
-
-pub fn gadget_product_core(
-    module: &Module,
-    res_dft_0: &mut VecZnxDft,
-    res_dft_1: &mut VecZnxDft,
-    a: &VecZnx,
-    b: &Ciphertext<VmpPMat>,
-    b_cols: usize,
-    tmp_bytes: &mut [u8],
-) {
-    assert!(b_cols <= b.cols());
-    module.vec_znx_dft(res_dft_1, a);
-    module.vmp_apply_dft_to_dft(res_dft_0, res_dft_1, b.at(0), tmp_bytes);
-    module.vmp_apply_dft_to_dft_inplace(res_dft_1, b.at(1), tmp_bytes);
-}
-
-pub fn gadget_product_big_tmp_bytes(module: &Module, c_cols: usize, a_cols: usize, b_rows: usize, b_cols: usize) -> usize {
-    return module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, b_rows, b_cols)
-        + 2 * module.bytes_of_vec_znx_dft(1, min(c_cols, a_cols));
-}
-
-/// Evaluates the gadget product: c.at(i) = IDFT(<DFT(a[1]), b.at(i)>)
-///
-/// # Arguments
-///
-/// * `module`: backend support for operations mod (X^N + 1).
-/// * `c`: a [Ciphertext] with cols_c cols.
-/// * `a`: a [Ciphertext] with cols_a cols.
-/// * `b`: a [Ciphertext] with at least min(cols_c, cols_a) rows.
-pub fn gadget_product_big(
-    module: &Module,
-    c: &mut Ciphertext<VecZnxBig>,
-    a: &Ciphertext<VecZnx>,
-    b: &Ciphertext<VmpPMat>,
-    tmp_bytes: &mut [u8],
-) {
-    let cols: usize = min(c.cols(), a.cols());
-
-    let (tmp_bytes_a1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
-    let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
-
-    let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_a1_dft);
-    let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_res_dft);
-
-    // a1_dft = DFT(a[1])
-    module.vec_znx_dft(&mut a1_dft, a.at(1));
-
-    // c[i] = IDFT(DFT(a[1]) * b[i])
-    (0..2).for_each(|i| {
-        module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.at(i), tmp_bytes);
-        module.vec_znx_idft_tmp_a(c.at_mut(i), &mut res_dft);
-    })
-}
-
-/// Evaluates the gadget product: c.at(i) = NORMALIZE(IDFT(<DFT(a[1]), b.at(i)>))
-///
-/// # Arguments
-///
-/// * `module`: backend support for operations mod (X^N + 1).
-/// * `c`: a [Ciphertext] with cols_c cols.
-/// * `a`: a [Ciphertext] with cols_a cols.
-/// * `b`: a [Ciphertext] with at least min(cols_c, cols_a) rows.
-pub fn gadget_product(
-    module: &Module,
-    c: &mut Ciphertext<VecZnx>,
-    a: &Ciphertext<VecZnx>,
-    b: &Ciphertext<VmpPMat>,
-    tmp_bytes: &mut [u8],
-) {
-    let cols: usize = min(c.cols(), a.cols());
-
-    let (tmp_bytes_a1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
-    let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
-
-    let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_a1_dft);
-    let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_res_dft);
-    let mut res_big: VecZnxBig = res_dft.as_vec_znx_big();
-
-    // a1_dft = DFT(a[1])
-    module.vec_znx_dft(&mut a1_dft, a.at(1));
-
-    // c[i] = NORMALIZE(IDFT(DFT(a[1]) * b[i]))
-    (0..2).for_each(|i| {
-        module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.at(i), tmp_bytes);
-        module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft);
-        module.vec_znx_big_normalize(c.log_base2k(), c.at_mut(i), &mut res_big, tmp_bytes);
-    })
-}
-
-#[cfg(test)]
-mod test {
-    use crate::{
-        ciphertext::{Ciphertext, new_gadget_ciphertext},
-        decryptor::decrypt_rlwe,
-        elem::{Elem, ElemCommon, ElemVecZnx},
-        encryptor::encrypt_grlwe_sk,
-        gadget_product::gadget_product_core,
-        keys::SecretKey,
-        parameters::{Parameters, ParametersLiteral},
-        plaintext::Plaintext,
-    };
-    use base2k::{
-        BACKEND, Infos, Sampling, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat,
-        alloc_aligned_u8,
-    };
-    use sampling::source::{Source, new_seed};
-
-    #[test]
-    fn test_gadget_product_core() {
-        let log_base2k: usize = 10;
-        let q_cols: usize = 7;
-        let p_cols: usize = 1;
-
-        // Basic parameters with enough limbs to test edge cases
-        let params_lit: ParametersLiteral = ParametersLiteral {
-            backend: BACKEND::FFT64,
-            log_n: 12,
-            log_q: q_cols * log_base2k,
-            log_p: p_cols * log_base2k,
-            log_base2k,
-            log_scale: 20,
-            xe: 3.2,
-            xs: 1 << 11,
-        };
-
-        let params: Parameters = Parameters::new(&params_lit);
-
-        // scratch space
-        let mut tmp_bytes: Vec<u8> = alloc_aligned_u8(
-            params.decrypt_rlwe_tmp_byte(params.log_qp())
-                | params.gadget_product_tmp_bytes(
-                    params.log_qp(),
-                    params.log_qp(),
-                    params.cols_qp(),
-                    params.log_qp(),
-                )
-                | params.encrypt_grlwe_sk_tmp_bytes(params.cols_qp(), params.log_qp()),
-        );
-
-        // Samplers for public and private randomness
-        let mut source_xe: Source = Source::new(new_seed());
-        let mut source_xa: Source = Source::new(new_seed());
-        let mut source_xs: Source = Source::new(new_seed());
-
-        // Two secret keys
-        let mut sk0: SecretKey = SecretKey::new(params.module());
-        sk0.fill_ternary_hw(params.xs(), &mut source_xs);
-        let mut sk0_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol();
-        params.module().svp_prepare(&mut sk0_svp_ppol, &sk0.0);
-
-        let mut sk1: SecretKey = SecretKey::new(params.module());
-        sk1.fill_ternary_hw(params.xs(), &mut source_xs);
-        let mut sk1_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol();
-        params.module().svp_prepare(&mut sk1_svp_ppol, &sk1.0);
-
-        // The gadget ciphertext
-        let mut gadget_ct: Ciphertext<VmpPMat> = new_gadget_ciphertext(
-            params.module(),
-            log_base2k,
-            params.cols_qp(),
-            params.log_qp(),
-        );
-
-        // gct = [-b*sk1 + g(sk0) + e, b]
-        encrypt_grlwe_sk(
-            params.module(),
-            &mut gadget_ct,
-            &sk0.0,
-            &sk1_svp_ppol,
-            &mut source_xa,
-            &mut source_xe,
-            params.xe(),
-            &mut tmp_bytes,
-        );
-
-        // Intermediate buffers
-
-        // Input polynomial, uniformly distributed
-        let mut a: VecZnx = params.module().new_vec_znx(1, params.cols_q());
-        params
-            .module()
-            .fill_uniform(log_base2k, &mut a, params.cols_q(), &mut source_xa);
-
-        // res = g^-1(a) * gct
-        let mut elem_res: Elem<VecZnx> = Elem::<VecZnx>::new(params.module(), log_base2k, params.log_qp(), 2);
-
-        // Ideal output = a * s
-        let mut a_dft: VecZnxDft = params.module().new_vec_znx_dft(1, a.cols());
-        let mut a_big: VecZnxBig = a_dft.as_vec_znx_big();
-        let mut a_times_s: VecZnx = params.module().new_vec_znx(1, a.cols());
-
-        // a * sk0
-        params.module().svp_apply_dft(&mut a_dft, &sk0_svp_ppol, &a);
-        params.module().vec_znx_idft_tmp_a(&mut a_big, &mut a_dft);
-        params
-            .module()
-            .vec_znx_big_normalize(params.log_base2k(), &mut a_times_s, &a_big, &mut tmp_bytes);
-
-        // Plaintext for decrypted output of gadget product
-        let mut pt: Plaintext = Plaintext::new(params.module(), params.log_base2k(), params.log_qp());
-
-        // Iterates over all possible cols values for input/output polynomials and gadget ciphertext.
-        (1..a.cols() + 1).for_each(|a_cols| {
-            let mut a_trunc: VecZnx = params.module().new_vec_znx(1, a_cols);
-            a_trunc.copy_from(&a);
-
-            (1..gadget_ct.cols() + 1).for_each(|b_cols| {
-                let mut res_dft_0: VecZnxDft = params.module().new_vec_znx_dft(1, b_cols);
-                let mut res_dft_1: VecZnxDft = params.module().new_vec_znx_dft(1, b_cols);
-                let mut res_big_0: VecZnxBig = res_dft_0.as_vec_znx_big();
-                let mut res_big_1: VecZnxBig = res_dft_1.as_vec_znx_big();
-
-                pt.elem_mut().zero();
-                elem_res.zero();
-
-                // let b_cols: usize = min(a_cols+1, gadget_ct.cols());
-
-                println!("a_cols: {} b_cols: {}", a_cols, b_cols);
-
-                // res_dft_0 = DFT(gct[0] * ct[1] = a * (-bs' + s + e) = -cs' + as + e')
-                // res_dft_1 = DFT(gct[1] * ct[1] = a * b = c)
-                gadget_product_core(
-                    params.module(),
-                    &mut res_dft_0,
-                    &mut res_dft_1,
-                    &a_trunc,
-                    &gadget_ct,
-                    b_cols,
-                    &mut tmp_bytes,
-                );
-
-                // res_big_0 = IDFT(res_dft_0)
-                params
                    .module()
-                    .vec_znx_idft_tmp_a(&mut res_big_0, &mut res_dft_0);
-                // res_big_1 = IDFT(res_dft_1)
-                params
-                    .module()
-                    .vec_znx_idft_tmp_a(&mut res_big_1, &mut res_dft_1);
-
-                // elem_res[0] = normalize(res_big_0)
-                params
-                    .module()
-                    .vec_znx_big_normalize(log_base2k, elem_res.at_mut(0), &res_big_0, &mut tmp_bytes);
-
-                // elem_res[1] = normalize(res_big_1)
-                params
-                    .module()
-                    .vec_znx_big_normalize(log_base2k, elem_res.at_mut(1), &res_big_1, &mut tmp_bytes);
-
-                // <(-c*sk1 + a*sk0 + e, a), (1, sk1)> = a*sk0 + e
-                decrypt_rlwe(
-                    params.module(),
-                    pt.elem_mut(),
-                    &elem_res,
-                    &sk1_svp_ppol,
-                    &mut tmp_bytes,
-                );
-
-                // a * sk0 + e - a*sk0 = e
-                params
-                    .module()
-                    .vec_znx_sub_ab_inplace(pt.at_mut(0), &mut a_times_s);
-                pt.at_mut(0).normalize(log_base2k, &mut tmp_bytes);
-
-                // pt.at(0).print(pt.elem().cols(), 16);
-
-                let noise_have: f64 = pt.at(0).std(0, log_base2k).log2();
-
-                let var_a_err: f64 = if a_cols < a.cols() { 1f64 / 12f64 } else { 0f64 };
-
-                let a_logq: usize = a_cols * log_base2k;
-                let b_logq: usize = b_cols * log_base2k;
-                let var_msg: f64 = (params.xs() as f64) / params.n() as f64;
-
-                println!("{} {} {} {}", var_msg, var_a_err, a_logq, b_logq);
-
-                let noise_pred: f64 = params.noise_grlwe_product(var_msg, var_a_err, a_logq, b_logq);
-
-                println!("noise_pred: {}", noise_pred);
-                println!("noise_have: {}", noise_have);
-
-                // assert!(noise_have <= noise_pred + 1.0);
-            });
-        });
-    }
-}
-
-impl Parameters {
-    pub fn noise_grlwe_product(&self, var_msg: f64, var_a_err: f64, a_logq: usize, b_logq: usize) -> f64 {
-        let n: f64 = self.n() as f64;
-        let var_xs: f64 = self.xs() as f64;
-
-        let var_gct_err_lhs: f64;
-        let var_gct_err_rhs: f64;
-        if b_logq < self.log_qp() {
-            let var_round: f64 = 1f64 / 12f64;
-            var_gct_err_lhs = var_round;
-            var_gct_err_rhs = var_round;
-        } else {
-            var_gct_err_lhs = self.xe() * self.xe();
-            var_gct_err_rhs = 0f64;
-        }
-
-        noise_grlwe_product(
-            n,
-            self.log_base2k(),
-            var_xs,
-            var_msg,
-            var_a_err,
-            var_gct_err_lhs,
-            var_gct_err_rhs,
-            a_logq,
-            b_logq,
-        )
-    }
-}
-
-pub fn noise_grlwe_product(
-    n: f64,
-    log_base2k: usize,
-    var_xs: f64,
-    var_msg: f64,
-    var_a_err: f64,
-    var_gct_err_lhs: f64,
-    var_gct_err_rhs: f64,
-    a_logq: usize,
-    b_logq: usize,
-) -> f64 {
-    let a_logq: usize = min(a_logq, b_logq);
-    let a_cols: usize = (a_logq + log_base2k - 1) / log_base2k;
-
-    let b_scale = 2.0f64.powi(b_logq as i32);
-    let a_scale: f64 = 2.0f64.powi((b_logq - a_logq) as i32);
-
-    let base: f64 = (1 << log_base2k) as f64;
-    let var_base: f64 = base * base / 12f64;
-
-    // lhs = a_cols * n * var_base * var_gct_err_lhs
-    // rhs = a_cols * n * var_base * var_gct_err_rhs * var_xs
-    let mut noise: f64 = (a_cols as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs);
-    noise += var_msg * var_a_err * a_scale * a_scale * n;
-    noise = noise.sqrt();
-    noise /= b_scale;
-    noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}]
-}
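// A worked instance of noise_grlwe_product above (illustrative numbers only):
// take n = 2^12, log_base2k = 10, a_logq = b_logq = 50 (so a_cols = 5,
// a_scale = 1, b_scale = 2^50), var_base = 2^20 / 12, var_xs = 2^11, and a
// full-width key so var_gct_err_lhs = 3.2^2, var_gct_err_rhs = 0. Then
//
//     noise^2 = 5 * 2^12 * (2^20 / 12) * 3.2^2 + var_msg * var_a_err * 2^12
//
// and the function returns log2(sqrt(noise^2) / 2^50), clamped to at most -1
// because a normalized limb vector cannot exceed [-1/2, 1/2].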
diff --git a/rlwe/src/key_generator.rs b/rlwe/src/key_generator.rs
deleted file mode 100644
index 4f62a2c..0000000
--- a/rlwe/src/key_generator.rs
+++ /dev/null
@@ -1,55 +0,0 @@
-use crate::encryptor::{encrypt_grlwe_sk, encrypt_grlwe_sk_tmp_bytes};
-use crate::keys::{PublicKey, SecretKey, SwitchingKey};
-use crate::parameters::Parameters;
-use base2k::{Module, SvpPPol};
-use sampling::source::Source;
-
-pub struct KeyGenerator {}
-
-impl KeyGenerator {
-    pub fn gen_secret_key_thread_safe(&self, params: &Parameters, source: &mut Source) -> SecretKey {
-        let mut sk: SecretKey = SecretKey::new(params.module());
-        sk.fill_ternary_hw(params.xs(), source);
-        sk
-    }
-
-    pub fn gen_public_key_thread_safe(
-        &self,
-        params: &Parameters,
-        sk_ppol: &SvpPPol,
-        source: &mut Source,
-        tmp_bytes: &mut [u8],
-    ) -> PublicKey {
-        let mut xa_source: Source = source.branch();
-        let mut xe_source: Source = source.branch();
-        let mut pk: PublicKey = PublicKey::new(params.module(), params.log_base2k(), params.log_qp());
-        pk.gen_thread_safe(
-            params.module(),
-            sk_ppol,
-            params.xe(),
-            &mut xa_source,
-            &mut xe_source,
-            tmp_bytes,
-        );
-        pk
-    }
-}
-
-pub fn gen_switching_key_tmp_bytes(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> usize {
-    encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q)
-}
-
-pub fn gen_switching_key(
-    module: &Module,
-    swk: &mut SwitchingKey,
-    sk_in: &SecretKey,
-    sk_out: &SvpPPol,
-    source_xa: &mut Source,
-    source_xe: &mut Source,
-    sigma: f64,
-    tmp_bytes: &mut [u8],
-) {
-    encrypt_grlwe_sk(
-        module, &mut swk.0, &sk_in.0, sk_out, source_xa, source_xe, sigma, tmp_bytes,
-    );
-}
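// Minimal flow for the key generator above (a sketch; `rows`, `log_qp`, the
// sources and `sigma` are placeholders taken from the surrounding tests):
//
//     let mut swk = SwitchingKey::new(module, log_base2k, rows, log_qp);
//     gen_switching_key(
//         module, &mut swk, &sk_in, &sk_out_ppol,
//         &mut source_xa, &mut source_xe, sigma, &mut tmp_bytes,
//     );
//
// swk is then GRLWE_{s_out}(s_in): feeding it to key_switch_rlwe (next file)
// moves a ciphertext from key s_in to key s_out.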
diff --git a/rlwe/src/key_switching.rs b/rlwe/src/key_switching.rs
deleted file mode 100644
index 4e0001a..0000000
--- a/rlwe/src/key_switching.rs
+++ /dev/null
@@ -1,79 +0,0 @@
-use crate::ciphertext::Ciphertext;
-use crate::elem::ElemCommon;
-use base2k::{Module, VecZnx, VecZnxBigOps, VecZnxDftOps, VmpPMat, VmpPMatOps, assert_alignement};
-use std::cmp::min;
-
-pub fn key_switch_tmp_bytes(module: &Module, log_base2k: usize, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize {
-    let gct_cols: usize = (gct_logq + log_base2k - 1) / log_base2k;
-    let in_cols: usize = (in_logq + log_base2k - 1) / log_base2k;
-    let res_cols: usize = (res_logq + log_base2k - 1) / log_base2k;
-    return module.vmp_apply_dft_to_dft_tmp_bytes(res_cols, in_cols, in_cols, gct_cols)
-        + module.bytes_of_vec_znx_dft(1, std::cmp::min(res_cols, in_cols))
-        + module.bytes_of_vec_znx_dft(1, gct_cols);
-}
-
-pub fn key_switch_rlwe(
-    module: &Module,
-    c: &mut Ciphertext<VecZnx>,
-    a: &Ciphertext<VecZnx>,
-    b: &Ciphertext<VmpPMat>,
-    b_cols: usize,
-    tmp_bytes: &mut [u8],
-) {
-    key_switch_rlwe_core(module, c, a, b, b_cols, tmp_bytes);
-}
-
-pub fn key_switch_rlwe_inplace(
-    module: &Module,
-    a: &mut Ciphertext<VecZnx>,
-    b: &Ciphertext<VmpPMat>,
-    b_cols: usize,
-    tmp_bytes: &mut [u8],
-) {
-    key_switch_rlwe_core(module, a, a, b, b_cols, tmp_bytes);
-}
-
-fn key_switch_rlwe_core(
-    module: &Module,
-    c: *mut Ciphertext<VecZnx>,
-    a: *const Ciphertext<VecZnx>,
-    b: &Ciphertext<VmpPMat>,
-    b_cols: usize,
-    tmp_bytes: &mut [u8],
-) {
-    // SAFETY WARNING: must ensure `c` and `a` are valid for read/write
-    let c: &mut Ciphertext<VecZnx> = unsafe { &mut *c };
-    let a: &Ciphertext<VecZnx> = unsafe { &*a };
-
-    let cols: usize = min(min(c.cols(), a.cols()), b.rows());
-
-    #[cfg(debug_assertions)]
-    {
-        assert!(b_cols <= b.cols());
-        assert!(tmp_bytes.len() >= key_switch_tmp_bytes(module, c.cols(), a.cols(), b.rows(), b.cols()));
-        assert_alignement(tmp_bytes.as_ptr());
-    }
-
-    let (tmp_bytes_a1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
-    let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, b_cols));
-
-    let mut a1_dft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_a1_dft);
-    let mut res_dft = module.new_vec_znx_dft_from_bytes_borrow(1, b_cols, tmp_bytes_res_dft);
-    let mut res_big = res_dft.as_vec_znx_big();
-
-    module.vec_znx_dft(&mut a1_dft, a.at(1));
-    module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.at(0), tmp_bytes);
-    module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft);
-
-    module.vec_znx_big_add_small_inplace(&mut res_big, a.at(0));
-    module.vec_znx_big_normalize(c.log_base2k(), c.at_mut(0), &mut res_big, tmp_bytes);
-
-    module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.at(1), tmp_bytes);
-    module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft);
-
-    module.vec_znx_big_normalize(c.log_base2k(), c.at_mut(1), &mut res_big, tmp_bytes);
-}
-
-pub fn key_switch_grlwe(module: &Module, c: &mut Ciphertext<VmpPMat>, a: &Ciphertext<VmpPMat>, b: &Ciphertext<VmpPMat>) {}
-
-pub fn key_switch_rgsw(module: &Module, c: &mut Ciphertext<VmpPMat>, a: &Ciphertext<VmpPMat>, b: &Ciphertext<VmpPMat>) {}
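// What key_switch_rlwe_core computes, in equation form (sketch): for
// a = (a0, a1) = (-u*s_in + m + e, u) and a switching key
// b = GRLWE_{s_out}(s_in),
//
//     c1 = NORMALIZE(IDFT(<DFT(a1), b[1]>))        (fresh mask under s_out)
//     c0 = NORMALIZE(a0 + IDFT(<DFT(a1), b[0]>))
//
// so that <c, (1, s_out)> = m + e + e_ks: the message moves from key s_in to
// key s_out at the cost of one gadget product's worth of extra noise.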
diff --git a/rlwe/src/keys.rs b/rlwe/src/keys.rs
deleted file mode 100644
index da7c412..0000000
--- a/rlwe/src/keys.rs
+++ /dev/null
@@ -1,82 +0,0 @@
-use crate::ciphertext::{Ciphertext, new_gadget_ciphertext};
-use crate::elem::{Elem, ElemCommon};
-use crate::encryptor::{encrypt_rlwe_sk, encrypt_rlwe_sk_tmp_bytes};
-use base2k::{Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VmpPMat};
-use sampling::source::Source;
-
-pub struct SecretKey(pub Scalar);
-
-impl SecretKey {
-    pub fn new(module: &Module) -> Self {
-        SecretKey(Scalar::new(module.n()))
-    }
-
-    pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) {
-        self.0.fill_ternary_prob(prob, source);
-    }
-
-    pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) {
-        self.0.fill_ternary_hw(hw, source);
-    }
-
-    pub fn prepare(&self, module: &Module, sk_ppol: &mut SvpPPol) {
-        module.svp_prepare(sk_ppol, &self.0)
-    }
-}
-
-pub struct PublicKey(pub Elem<VecZnx>);
-
-impl PublicKey {
-    pub fn new(module: &Module, log_base2k: usize, log_q: usize) -> PublicKey {
-        PublicKey(Elem::<VecZnx>::new(module, log_base2k, log_q, 2))
-    }
-
-    pub fn gen_thread_safe(
-        &mut self,
-        module: &Module,
-        sk: &SvpPPol,
-        xe: f64,
-        xa_source: &mut Source,
-        xe_source: &mut Source,
-        tmp_bytes: &mut [u8],
-    ) {
-        encrypt_rlwe_sk(
-            module,
-            &mut self.0,
-            None,
-            sk,
-            xa_source,
-            xe_source,
-            xe,
-            tmp_bytes,
-        );
-    }
-
-    pub fn gen_thread_safe_tmp_bytes(module: &Module, log_base2k: usize, log_q: usize) -> usize {
-        encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q)
-    }
-}
-
-pub struct SwitchingKey(pub Ciphertext<VmpPMat>);
-
-impl SwitchingKey {
-    pub fn new(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> SwitchingKey {
-        SwitchingKey(new_gadget_ciphertext(module, log_base2k, rows, log_q))
-    }
-
-    pub fn n(&self) -> usize {
-        self.0.n()
-    }
-
-    pub fn rows(&self) -> usize {
-        self.0.rows()
-    }
-
-    pub fn cols(&self) -> usize {
-        self.0.cols()
-    }
-
-    pub fn log_base2k(&self) -> usize {
-        self.0.log_base2k()
-    }
-}
diff --git a/rlwe/src/lib.rs b/rlwe/src/lib.rs
deleted file mode 100644
index aecb526..0000000
--- a/rlwe/src/lib.rs
+++ /dev/null
@@ -1,13 +0,0 @@
-pub mod automorphism;
-pub mod ciphertext;
-pub mod decryptor;
-pub mod elem;
-pub mod encryptor;
-pub mod gadget_product;
-pub mod key_generator;
-pub mod key_switching;
-pub mod keys;
-pub mod parameters;
-pub mod plaintext;
-pub mod rgsw_product;
-pub mod trace;
diff --git a/rlwe/src/parameters.rs b/rlwe/src/parameters.rs
deleted file mode 100644
index cd3a91d..0000000
--- a/rlwe/src/parameters.rs
+++ /dev/null
@@ -1,88 +0,0 @@
-use base2k::module::{BACKEND, Module};
-
-pub const DEFAULT_SIGMA: f64 = 3.2;
-
-pub struct ParametersLiteral {
-    pub backend: BACKEND,
-    pub log_n: usize,
-    pub log_q: usize,
-    pub log_p: usize,
-    pub log_base2k: usize,
-    pub log_scale: usize,
-    pub xe: f64,
-    pub xs: usize,
-}
-
-pub struct Parameters {
-    log_n: usize,
-    log_q: usize,
-    log_p: usize,
-    log_scale: usize,
-    log_base2k: usize,
-    xe: f64,
-    xs: usize,
-    module: Module,
-}
-
-impl Parameters {
-    pub fn new(p: &ParametersLiteral) -> Self {
-        assert!(
-            p.log_n + 2 * p.log_base2k <= 53,
-            "invalid parameters: p.log_n + 2*p.log_base2k > 53"
-        );
-        Self {
-            log_n: p.log_n,
-            log_q: p.log_q,
-            log_p: p.log_p,
-            log_scale: p.log_scale,
-            log_base2k: p.log_base2k,
-            xe: p.xe,
-            xs: p.xs,
-            module: Module::new(1 << p.log_n, p.backend),
-        }
-    }
-
-    pub fn n(&self) -> usize {
-        1 << self.log_n
-    }
-
-    pub fn log_scale(&self) -> usize {
-        self.log_scale
-    }
-
-    pub fn log_q(&self) -> usize {
-        self.log_q
-    }
-
-    pub fn log_p(&self) -> usize {
-        self.log_p
-    }
-
-    pub fn log_qp(&self) -> usize {
-        self.log_q + self.log_p
-    }
-
-    pub fn cols_q(&self) -> usize {
-        (self.log_q + self.log_base2k - 1) / self.log_base2k
-    }
-
-    pub fn cols_qp(&self) -> usize {
-        (self.log_q + self.log_p + self.log_base2k - 1) / self.log_base2k
-    }
-
-    pub fn log_base2k(&self) -> usize {
-        self.log_base2k
-    }
-
-    pub fn module(&self) -> &Module {
-        &self.module
-    }
-
-    pub fn xe(&self) -> f64 {
-        self.xe
-    }
-
-    pub fn xs(&self) -> usize {
-        self.xs
-    }
-}
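// Limb accounting in Parameters (values hypothetical): with log_q = 50,
// log_p = 15 and log_base2k = 10,
//
//     cols_q  = ceil(50 / 10)        = 5
//     cols_qp = ceil((50 + 15) / 10) = 7
//
// The assert in Parameters::new, log_n + 2*log_base2k <= 53, presumably keeps
// limb-product magnitudes inside the 53-bit f64 mantissa used by the FFT64
// backend.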
diff --git a/rlwe/src/plaintext.rs b/rlwe/src/plaintext.rs
deleted file mode 100644
index 86f7e32..0000000
--- a/rlwe/src/plaintext.rs
+++ /dev/null
@@ -1,109 +0,0 @@
-use crate::ciphertext::Ciphertext;
-use crate::elem::{Elem, ElemCommon, ElemVecZnx};
-use crate::parameters::Parameters;
-use base2k::{LAYOUT, Module, VecZnx};
-
-pub struct Plaintext(pub Elem<VecZnx>);
-
-impl Parameters {
-    pub fn new_plaintext(&self, log_q: usize) -> Plaintext {
-        Plaintext::new(self.module(), self.log_base2k(), log_q)
-    }
-
-    pub fn bytes_of_plaintext(&self, log_q: usize) -> usize {
-        Elem::<VecZnx>::bytes_of(self.module(), self.log_base2k(), log_q, 1)
-    }
-
-    pub fn plaintext_from_bytes(&self, log_q: usize, bytes: &mut [u8]) -> Plaintext {
-        Plaintext(Elem::<VecZnx>::from_bytes(
-            self.module(),
-            self.log_base2k(),
-            log_q,
-            1,
-            bytes,
-        ))
-    }
-}
-
-impl Plaintext {
-    pub fn new(module: &Module, log_base2k: usize, log_q: usize) -> Self {
-        Self(Elem::<VecZnx>::new(module, log_base2k, log_q, 1))
-    }
-}
-
-impl Plaintext {
-    pub fn bytes_of(module: &Module, log_base2k: usize, log_q: usize) -> usize {
-        Elem::<VecZnx>::bytes_of(module, log_base2k, log_q, 1)
-    }
-
-    pub fn from_bytes(module: &Module, log_base2k: usize, log_q: usize, bytes: &mut [u8]) -> Self {
-        Self(Elem::<VecZnx>::from_bytes(
-            module, log_base2k, log_q, 1, bytes,
-        ))
-    }
-
-    pub fn from_bytes_borrow(module: &Module, log_base2k: usize, log_q: usize, bytes: &mut [u8]) -> Self {
-        Self(Elem::<VecZnx>::from_bytes_borrow(
-            module, log_base2k, log_q, 1, bytes,
-        ))
-    }
-
-    pub fn as_ciphertext(&self) -> Ciphertext<VecZnx> {
-        unsafe { Ciphertext::<VecZnx>(std::ptr::read(&self.0)) }
-    }
-}
-
-impl ElemCommon<VecZnx> for Plaintext {
-    fn n(&self) -> usize {
-        self.0.n()
-    }
-
-    fn log_n(&self) -> usize {
-        self.elem().log_n()
-    }
-
-    fn log_q(&self) -> usize {
-        self.0.log_q
-    }
-
-    fn elem(&self) -> &Elem<VecZnx> {
-        &self.0
-    }
-
-    fn elem_mut(&mut self) -> &mut Elem<VecZnx> {
-        &mut self.0
-    }
-
-    fn size(&self) -> usize {
-        self.elem().size()
-    }
-
-    fn layout(&self) -> LAYOUT {
-        self.elem().layout()
-    }
-
-    fn rows(&self) -> usize {
-        self.0.rows()
-    }
-
-    fn cols(&self) -> usize {
-        self.0.cols()
-    }
-
-    fn at(&self, i: usize) -> &VecZnx {
-        self.0.at(i)
-    }
-
-    fn at_mut(&mut self, i: usize) -> &mut VecZnx {
-        self.0.at_mut(i)
-    }
-
-    fn log_base2k(&self) -> usize {
-        self.0.log_base2k()
-    }
-
-    fn log_scale(&self) -> usize {
-        self.0.log_scale()
-    }
-}
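// Typical plaintext lifecycle, following the deleted tests (a sketch; log_k
// is the encoding scale those tests set to 2 * log_base2k):
//
//     let mut pt = params.new_plaintext(log_q);
//     pt.at_mut(0).encode_vec_i64(0, log_base2k, log_k, &data, 32);
//     pt.at_mut(0).normalize(log_base2k, &mut tmp_bytes);
//
// encode_vec_i64 spreads `data`, scaled by 2^log_k, across the base-2^k
// limbs; normalize then reduces every limb to the centered digit range.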
diff --git a/rlwe/src/rgsw_product.rs b/rlwe/src/rgsw_product.rs
deleted file mode 100644
index dc42602..0000000
--- a/rlwe/src/rgsw_product.rs
+++ /dev/null
@@ -1,300 +0,0 @@
-use crate::{ciphertext::Ciphertext, elem::ElemCommon, parameters::Parameters};
-use base2k::{Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, assert_alignement};
-use std::cmp::min;
-
-impl Parameters {
-    pub fn rgsw_product_tmp_bytes(&self, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize {
-        rgsw_product_tmp_bytes(
-            self.module(),
-            self.log_base2k(),
-            res_logq,
-            in_logq,
-            gct_logq,
-        )
-    }
-}
-pub fn rgsw_product_tmp_bytes(module: &Module, log_base2k: usize, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize {
-    let gct_cols: usize = (gct_logq + log_base2k - 1) / log_base2k;
-    let in_cols: usize = (in_logq + log_base2k - 1) / log_base2k;
-    let res_cols: usize = (res_logq + log_base2k - 1) / log_base2k;
-    return module.vmp_apply_dft_to_dft_tmp_bytes(res_cols, in_cols, in_cols, gct_cols)
-        + module.bytes_of_vec_znx_dft(1, std::cmp::min(res_cols, in_cols))
-        + 2 * module.bytes_of_vec_znx_dft(1, gct_cols);
-}
-
-pub fn rgsw_product(
-    module: &Module,
-    c: &mut Ciphertext<VecZnx>,
-    a: &Ciphertext<VecZnx>,
-    b: &Ciphertext<VmpPMat>,
-    b_cols: usize,
-    tmp_bytes: &mut [u8],
-) {
-    #[cfg(debug_assertions)]
-    {
-        assert!(b_cols <= b.cols());
-        assert_eq!(c.size(), 2);
-        assert_eq!(a.size(), 2);
-        assert_eq!(b.size(), 4);
-        assert!(tmp_bytes.len() >= rgsw_product_tmp_bytes(module, c.cols(), a.cols(), min(b.rows(), a.cols()), b_cols));
-        assert_alignement(tmp_bytes.as_ptr());
-    }
-
-    let (tmp_bytes_ai_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, a.cols()));
-    let (tmp_bytes_c0_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, b_cols));
-    let (tmp_bytes_c1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, b_cols));
-
-    let mut ai_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, a.cols(), tmp_bytes_ai_dft);
-    let mut c0_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, b_cols, tmp_bytes_c0_dft);
-    let mut c1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, b_cols, tmp_bytes_c1_dft);
-
-    let mut c0_big: VecZnxBig = c0_dft.as_vec_znx_big();
-    let mut c1_big: VecZnxBig = c1_dft.as_vec_znx_big();
-
-    module.vec_znx_dft(&mut ai_dft, a.at(0));
-    module.vmp_apply_dft_to_dft(&mut c0_dft, &ai_dft, b.at(0), tmp_bytes);
-    module.vmp_apply_dft_to_dft(&mut c1_dft, &ai_dft, b.at(1), tmp_bytes);
-
-    module.vec_znx_dft(&mut ai_dft, a.at(1));
-    module.vmp_apply_dft_to_dft_add(&mut c0_dft, &ai_dft, b.at(2), tmp_bytes);
-    module.vmp_apply_dft_to_dft_add(&mut c1_dft, &ai_dft, b.at(3), tmp_bytes);
-
-    module.vec_znx_idft_tmp_a(&mut c0_big, &mut c0_dft);
-    module.vec_znx_idft_tmp_a(&mut c1_big, &mut c1_dft);
-
-    module.vec_znx_big_normalize(c.log_base2k(), c.at_mut(0), &mut c0_big, tmp_bytes);
-    module.vec_znx_big_normalize(c.log_base2k(), c.at_mut(1), &mut c1_big, tmp_bytes);
-}
-
-pub fn rgsw_product_inplace(
-    module: &Module,
-    a: &mut Ciphertext<VecZnx>,
-    b: &Ciphertext<VmpPMat>,
-    b_cols: usize,
-    tmp_bytes: &mut [u8],
-) {
-    #[cfg(debug_assertions)]
-    {
-        assert!(b_cols <= b.cols());
-        assert_eq!(a.size(), 2);
-        assert_eq!(b.size(), 4);
-        assert!(tmp_bytes.len() >= rgsw_product_tmp_bytes(module, a.cols(), a.cols(), min(b.rows(), a.cols()), b_cols));
-        assert_alignement(tmp_bytes.as_ptr());
-    }
-
-    let (tmp_bytes_ai_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, a.cols()));
-    let (tmp_bytes_c0_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, b_cols));
-    let (tmp_bytes_c1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, b_cols));
-
-    let mut ai_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, a.cols(), tmp_bytes_ai_dft);
-    let mut c0_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, b_cols, tmp_bytes_c0_dft);
-    let mut c1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, b_cols, tmp_bytes_c1_dft);
-
-    let mut c0_big: VecZnxBig = c0_dft.as_vec_znx_big();
-    let mut c1_big: VecZnxBig = c1_dft.as_vec_znx_big();
-
-    module.vec_znx_dft(&mut ai_dft, a.at(0));
-    module.vmp_apply_dft_to_dft(&mut c0_dft, &ai_dft, b.at(0), tmp_bytes);
-    module.vmp_apply_dft_to_dft(&mut c1_dft, &ai_dft, b.at(1), tmp_bytes);
-
-    module.vec_znx_dft(&mut ai_dft, a.at(1));
-    module.vmp_apply_dft_to_dft_add(&mut c0_dft, &ai_dft, b.at(2), tmp_bytes);
-    module.vmp_apply_dft_to_dft_add(&mut c1_dft, &ai_dft, b.at(3), tmp_bytes);
-
-    module.vec_znx_idft_tmp_a(&mut c0_big, &mut c0_dft);
-    module.vec_znx_idft_tmp_a(&mut c1_big, &mut c1_dft);
-
-    module.vec_znx_big_normalize(a.log_base2k(), a.at_mut(0), &mut c0_big, tmp_bytes);
-    module.vec_znx_big_normalize(a.log_base2k(), a.at_mut(1), &mut c1_big, tmp_bytes);
-}
-
-#[cfg(test)]
-mod test {
-    use crate::{
-        ciphertext::{Ciphertext, new_rgsw_ciphertext},
-        decryptor::decrypt_rlwe,
-        elem::ElemCommon,
-        encryptor::{encrypt_rgsw_sk, encrypt_rlwe_sk},
-        keys::SecretKey,
-        parameters::{DEFAULT_SIGMA, Parameters, ParametersLiteral},
-        plaintext::Plaintext,
-        rgsw_product::rgsw_product_inplace,
-    };
-    use base2k::{BACKEND, Encoding, Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxOps, VmpPMat, alloc_aligned};
-    use sampling::source::{Source, new_seed};
-
-    #[test]
-    fn test_rgsw_product() {
-        let log_base2k: usize = 10;
-        let log_q: usize = 50;
-        let log_p: usize = 15;
-
-        // Basic parameters with enough limbs to test edge cases
-        let params_lit: ParametersLiteral = ParametersLiteral {
-            backend: BACKEND::FFT64,
-            log_n: 12,
-            log_q,
-            log_p,
-            log_base2k,
-            log_scale: 20,
-            xe: 3.2,
-            xs: 1 << 11,
-        };
-
-        let params: Parameters = Parameters::new(&params_lit);
-
-        let module: &Module = params.module();
-        let log_q: usize = params.log_q();
-        let log_qp: usize = params.log_qp();
-        let gct_rows: usize = params.cols_q();
-        let gct_cols: usize = params.cols_qp();
-
-        // scratch space
-        let mut tmp_bytes: Vec<u8> = alloc_aligned(
-            params.decrypt_rlwe_tmp_byte(log_q)
-                | params.encrypt_rlwe_sk_tmp_bytes(log_q)
-                | params.rgsw_product_tmp_bytes(log_q, log_q, log_qp)
-                | params.encrypt_rgsw_sk_tmp_bytes(gct_rows, log_qp),
-        );
-
-        // Samplers for public and private randomness
-        let mut source_xe: Source = Source::new(new_seed());
-        let mut source_xa: Source = Source::new(new_seed());
-        let mut source_xs: Source = Source::new(new_seed());
-
-        let mut sk: SecretKey = SecretKey::new(module);
-        sk.fill_ternary_hw(params.xs(), &mut source_xs);
-        let mut sk_svp_ppol: SvpPPol = module.new_svp_ppol();
-        module.svp_prepare(&mut sk_svp_ppol, &sk.0);
-
-        let mut ct_rgsw: Ciphertext<VmpPMat> = new_rgsw_ciphertext(module, log_base2k, gct_rows, log_qp);
-
-        let k: i64 = 3;
-
-        // m = X^k
-        let mut m: Scalar = module.new_scalar();
-        m.raw_mut()[k as usize] = 1;
-
-        encrypt_rgsw_sk(
-            module,
-            &mut ct_rgsw,
-            &m,
-            &sk_svp_ppol,
-            &mut source_xa,
-            &mut source_xe,
-            DEFAULT_SIGMA,
-            &mut tmp_bytes,
-        );
-
-        let log_k: usize = 2 * log_base2k;
-
-        let mut ct: Ciphertext<VecZnx> = params.new_ciphertext(log_q);
-        let mut pt: Plaintext = params.new_plaintext(log_q);
-        let mut pt_rotate: Plaintext = params.new_plaintext(log_q);
-
-        pt.at_mut(0).encode_vec_i64(0, log_base2k, log_k, m.raw(), 32);
-
-        module.vec_znx_rotate(k, pt_rotate.at_mut(0), pt.at(0));
-
-        encrypt_rlwe_sk(
-            module,
-            ct.elem_mut(),
-            Some(pt.at(0)),
-            &sk_svp_ppol,
-            &mut source_xa,
-            &mut source_xe,
-            params.xe(),
-            &mut tmp_bytes,
-        );
-
-        rgsw_product_inplace(module, &mut ct, &ct_rgsw, gct_cols, &mut tmp_bytes);
-
-        decrypt_rlwe(
-            module,
-            pt.elem_mut(),
-            ct.elem(),
-            &sk_svp_ppol,
-            &mut tmp_bytes,
-        );
-
-        module.vec_znx_sub_ba_inplace(pt.at_mut(0), pt_rotate.at(0));
-
-        // pt.at(0).print(pt.cols(), 16);
-
-        let noise_have: f64 = pt.at(0).std(0, log_base2k).log2();
-
-        let var_msg: f64 = 1f64 / params.n() as f64; // X^{k}
-        let var_a0_err: f64 = params.xe() * params.xe();
-        let var_a1_err: f64 = 1f64 / 12f64;
-
-        let noise_pred: f64 = params.noise_rgsw_product(var_msg, var_a0_err, var_a1_err, ct.log_q(), ct_rgsw.log_q());
-
-        println!("noise_pred: {}", noise_pred);
-        println!("noise_have: {}", noise_have);
-
-        assert!(noise_have <= noise_pred + 1.0);
-    }
-}
-
-impl Parameters {
-    pub fn noise_rgsw_product(&self, var_msg: f64, var_a0_err: f64, var_a1_err: f64, a_logq: usize, b_logq: usize) -> f64 {
-        let n: f64 = self.n() as f64;
-        let var_xs: f64 = self.xs() as f64;
-
-        let var_gct_err_lhs: f64;
-        let var_gct_err_rhs: f64;
-        if b_logq < self.log_qp() {
-            let var_round: f64 = 1f64 / 12f64;
-            var_gct_err_lhs = var_round;
-            var_gct_err_rhs = var_round;
-        } else {
-            var_gct_err_lhs = self.xe() * self.xe();
-            var_gct_err_rhs = 0f64;
-        }
-
-        noise_rgsw_product(
-            n,
-            self.log_base2k(),
-            var_xs,
-            var_msg,
-            var_a0_err,
-            var_a1_err,
-            var_gct_err_lhs,
-            var_gct_err_rhs,
-            a_logq,
-            b_logq,
-        )
-    }
-}
-
-pub fn noise_rgsw_product(
-    n: f64,
-    log_base2k: usize,
-    var_xs: f64,
-    var_msg: f64,
-    var_a0_err: f64,
-    var_a1_err: f64,
-    var_gct_err_lhs: f64,
-    var_gct_err_rhs: f64,
-    a_logq: usize,
-    b_logq: usize,
-) -> f64 {
-    let a_logq: usize = min(a_logq, b_logq);
-    let a_cols: usize = (a_logq + log_base2k - 1) / log_base2k;
-
-    let b_scale = 2.0f64.powi(b_logq as i32);
-    let a_scale: f64 = 2.0f64.powi((b_logq - a_logq) as i32);
-
-    let base: f64 = (1 << log_base2k) as f64;
-    let var_base: f64 = base * base / 12f64;
-
-    // lhs = a_cols * n * var_base * var_gct_err_lhs
-    // rhs = a_cols * n * var_base * var_gct_err_rhs * var_xs
-    let mut noise: f64 = 2.0 * (a_cols as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs);
-    noise += var_msg * var_a0_err * a_scale * a_scale * n;
-    noise += var_msg * var_a1_err * a_scale * a_scale * n * var_xs;
-    noise = noise.sqrt();
-    noise /= b_scale;
-    noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}]
-}
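// The identity behind rgsw_product above (a sketch in the file's own
// notation): an RGSW ciphertext holds four gadget columns b[0..4], and
//
//     c0 = NORMALIZE(IDFT(<DFT(a[0]), b[0]> + <DFT(a[1]), b[2]>))
//     c1 = NORMALIZE(IDFT(<DFT(a[0]), b[1]> + <DFT(a[1]), b[3]>))
//
// is the external product RLWE(m_a) x RGSW(m_b) -> RLWE(m_a * m_b). With
// m_b = X^k this rotates the encrypted message, which is exactly what
// test_rgsw_product checks against vec_znx_rotate.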
ct.elem_mut(), - Some(pt.elem()), - &self.sk1_ppol, - &mut source_xa, - &mut source_xe, - self.params.xe(), - &mut self.tmp_bytes, - ); - } - - pub fn decrypt_sk0(&mut self, pt: &mut Plaintext, ct: &Ciphertext) { - decrypt_rlwe( - self.params.module(), - pt.elem_mut(), - ct.elem(), - &self.sk0_ppol, - &mut self.tmp_bytes, - ); - } - - pub fn decrypt_sk1(&mut self, pt: &mut Plaintext, ct: &Ciphertext) { - decrypt_rlwe( - self.params.module(), - pt.elem_mut(), - ct.elem(), - &self.sk1_ppol, - &mut self.tmp_bytes, - ); - } -} \ No newline at end of file diff --git a/rlwe/src/trace.rs b/rlwe/src/trace.rs deleted file mode 100644 index 9e7feb8..0000000 --- a/rlwe/src/trace.rs +++ /dev/null @@ -1,236 +0,0 @@ -use crate::{automorphism::AutomorphismKey, ciphertext::Ciphertext, elem::ElemCommon, parameters::Parameters}; -use base2k::{Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMatOps, assert_alignement}; -use std::collections::HashMap; - -pub fn trace_galois_elements(module: &Module) -> Vec<i64> { - let mut gal_els: Vec<i64> = Vec::new(); - (0..module.log_n()).for_each(|i| { - if i == 0 { - gal_els.push(-1); - } else { - gal_els.push(module.galois_element(1 << (i - 1))); - } - }); - gal_els -} - -impl Parameters { - pub fn trace_tmp_bytes(&self, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize { - self.automorphism_tmp_bytes(res_logq, in_logq, gct_logq) - } -} - -pub fn trace_tmp_bytes(module: &Module, c_cols: usize, a_cols: usize, b_rows: usize, b_cols: usize) -> usize { - module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, b_rows, b_cols) + 2 * module.bytes_of_vec_znx_dft(1, std::cmp::min(c_cols, a_cols)) -} - -pub fn trace_inplace( - module: &Module, - a: &mut Ciphertext, - start: usize, - end: usize, - b: &HashMap<i64, AutomorphismKey>, - b_cols: usize, - tmp_bytes: &mut [u8], -) { - let cols: usize = a.cols(); - - let b_rows: usize; - - if let Some((_, key)) = b.iter().next() { - b_rows = key.value.rows(); - #[cfg(debug_assertions)] - { - assert!(b_cols <= key.value.cols()) - } - } else { - panic!("b: HashMap<i64, AutomorphismKey> is empty") - } - - #[cfg(debug_assertions)] - { - assert!(start <= end); - assert!(end <= module.n()); - assert!(tmp_bytes.len() >= trace_tmp_bytes(module, cols, cols, b_rows, b_cols)); - assert_alignement(tmp_bytes.as_ptr()); - } - - let cols: usize = std::cmp::min(b_cols, a.cols()); - - let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols)); - let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, b_cols)); - - let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_b1_dft); - let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, b_cols, tmp_bytes_res_dft); - let mut res_big: VecZnxBig = res_dft.as_vec_znx_big(); - - let log_base2k: usize = a.log_base2k(); - - (start..end).for_each(|i| { - // Halve both components first; over all log(N) iterations this compensates the factor N introduced by the trace - a.at_mut(0).rsh(log_base2k, 1, tmp_bytes); - a.at_mut(1).rsh(log_base2k, 1, tmp_bytes); - - let p: i64 = if i == 0 { -1 } else { module.galois_element(1 << (i - 1)) }; - - if let Some(key) = b.get(&p) { - module.vec_znx_dft(&mut a1_dft, a.at(1)); - - // a[0] = NORMALIZE(a[0] + AUTO(a[0] + IDFT(a1_dft x key[0]))) - module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, key.value.at(0), tmp_bytes); - module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft); - module.vec_znx_big_add_small_inplace(&mut res_big, a.at(0)); - module.vec_znx_big_automorphism_inplace(p, &mut res_big);
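- // The second addition of a[0] happens after the automorphism, completing a[0] + AUTO(a[0] + IDFT(a1_dft x key[0]))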
module.vec_znx_big_add_small_inplace(&mut res_big, a.at(0)); - module.vec_znx_big_normalize(a.log_base2k(), a.at_mut(0), &mut res_big, tmp_bytes); - - // a[1] = NORMALIZE(a[1] + AUTO(IDFT(a1_dft x key[1]))) - module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, key.value.at(1), tmp_bytes); - module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft); - module.vec_znx_big_automorphism_inplace(p, &mut res_big); - module.vec_znx_big_add_small_inplace(&mut res_big, a.at(1)); - module.vec_znx_big_normalize(a.log_base2k(), a.at_mut(1), &mut res_big, tmp_bytes); - } else { - panic!("no automorphism key for galois element {}", p) - } - }) -} - -#[cfg(test)] -mod test { - use super::{trace_galois_elements, trace_inplace}; - use crate::{ - automorphism::AutomorphismKey, - ciphertext::Ciphertext, - decryptor::decrypt_rlwe, - elem::ElemCommon, - encryptor::encrypt_rlwe_sk, - keys::SecretKey, - parameters::{DEFAULT_SIGMA, Parameters, ParametersLiteral}, - plaintext::Plaintext, - }; - use base2k::{BACKEND, Encoding, Module, SvpPPol, SvpPPolOps, alloc_aligned}; - use sampling::source::{Source, new_seed}; - use std::collections::HashMap; - - #[test] - fn test_trace_inplace() { - let log_base2k: usize = 10; - let log_q: usize = 50; - let log_p: usize = 15; - - // Basic parameters with enough limbs to test edge cases - let params_lit: ParametersLiteral = ParametersLiteral { - backend: BACKEND::FFT64, - log_n: 12, - log_q, - log_p, - log_base2k, - log_scale: 20, - xe: 3.2, - xs: 1 << 11, - }; - - let params: Parameters = Parameters::new(&params_lit); - - let module: &Module = params.module(); - let log_q: usize = params.log_q(); - let log_qp: usize = params.log_qp(); - let gct_rows: usize = params.cols_q(); - let gct_cols: usize = params.cols_qp(); - - // Scratch space; x | y >= max(x, y), so one buffer covers every call below - let mut tmp_bytes: Vec<u8> = alloc_aligned( - params.decrypt_rlwe_tmp_byte(log_q) - | params.encrypt_rlwe_sk_tmp_bytes(log_q) - | params.automorphism_key_new_tmp_bytes(gct_rows, log_qp) - | params.automorphism_tmp_bytes(log_q, log_q, log_qp), - ); - - // Samplers for public and private randomness - let mut source_xe: Source = Source::new(new_seed()); - let mut source_xa: Source = Source::new(new_seed()); - let mut source_xs: Source = Source::new(new_seed()); - - let mut sk: SecretKey = SecretKey::new(module); - sk.fill_ternary_hw(params.xs(), &mut source_xs); - let mut sk_svp_ppol: SvpPPol = module.new_svp_ppol(); - module.svp_prepare(&mut sk_svp_ppol, &sk.0); - - let gal_els: Vec<i64> = trace_galois_elements(module); - - let auto_keys: HashMap<i64, AutomorphismKey> = AutomorphismKey::new_many( - module, - &gal_els, - &sk, - log_base2k, - gct_rows, - log_qp, - &mut source_xa, - &mut source_xe, - DEFAULT_SIGMA, - &mut tmp_bytes, - ); - - let mut data: Vec<i64> = vec![0i64; params.n()]; - - data.iter_mut() - .enumerate() - .for_each(|(i, x)| *x = 1 + i as i64); - - let log_k: usize = 2 * log_base2k; - - let mut ct: Ciphertext = params.new_ciphertext(log_q); - let mut pt: Plaintext = params.new_plaintext(log_q); - - pt.at_mut(0).encode_vec_i64(0, log_base2k, log_k, &data, 32); - pt.at_mut(0).normalize(log_base2k, &mut tmp_bytes); - - // Round-trip sanity check of the encoding before encrypting - pt.at(0).decode_vec_i64(0, log_base2k, log_k, &mut data); - pt.at(0).print(0, pt.cols(), 16); - - encrypt_rlwe_sk( - module, - ct.elem_mut(), - Some(pt.at(0)), - &sk_svp_ppol, - &mut source_xa, - &mut source_xe, - params.xe(), - &mut tmp_bytes, - ); - - // Split the trace into two calls to also exercise partial iteration ranges - trace_inplace(module, &mut ct, 0, 4, &auto_keys, gct_cols, &mut tmp_bytes); - trace_inplace( - module, - &mut ct, - 4, - module.log_n(), - &auto_keys, - gct_cols, - &mut tmp_bytes, - ); - - // pt <- dec(ct) = (1/N) * Tr(pt) + e: only the constant coefficient of the message should survive
- decrypt_rlwe( - module, - pt.elem_mut(), - ct.elem(), - &sk_svp_ppol, - &mut tmp_bytes, - ); - - pt.at(0).print(0, pt.cols(), 16); - - pt.at(0).decode_vec_i64(0, log_base2k, log_k, &mut data); - - // Expected up to noise: data = [1, 0, 0, ...] - println!("trace: {:?}", &data[..16]); - } -}
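A note on the identity behind the deleted trace_inplace: it computes (1/N) * Tr over Z[X]/(X^N + 1) with log(N) rounds of a <- a + AUTO(a), and the test above expects the trace to keep only the constant coefficient. The sketch below checks that identity directly by summing the automorphism X -> X^p over all N odd exponents p mod 2N. It is a minimal, self-contained illustration: it uses no base2k API, and the helper name automorphism is local to the sketch.

/// Applies X -> X^p on Z[X]/(X^N + 1) for an odd exponent p.
fn automorphism(a: &[i64], p: usize) -> Vec<i64> {
    let n = a.len();
    let two_n = 2 * n;
    let mut b = vec![0i64; n];
    for (i, &c) in a.iter().enumerate() {
        let e = (i * p) % two_n; // exponent of X^{i*p}, reduced mod 2N
        if e < n {
            b[e] += c; // X^e
        } else {
            b[e - n] -= c; // X^e = -X^{e-N} since X^N = -1
        }
    }
    b
}

fn main() {
    let n: usize = 16;
    let a: Vec<i64> = (1..=n as i64).collect(); // arbitrary test polynomial, a[0] = 1

    // Tr(a): sum of the automorphisms over the N odd residues p mod 2N.
    let mut tr = vec![0i64; n];
    for p in (1..2 * n).step_by(2) {
        for (t, c) in tr.iter_mut().zip(automorphism(&a, p)) {
            *t += c;
        }
    }

    // The trace keeps only the constant coefficient, scaled by the group order N.
    assert_eq!(tr[0], n as i64 * a[0]);
    assert!(tr[1..].iter().all(|&c| c == 0));
    println!("Tr(a) = {:?}", tr);
}

Summing all N automorphisms directly would cost N keyswitches on ciphertexts; the a + AUTO(a) schedule folds the same sum into log(N) steps, which is why trace_inplace iterates over only module.log_n() Galois elements.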