From 2f9a1cf6d9e493606dd80c4d10f735d1821cdda0 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Mon, 28 Apr 2025 10:33:15 +0200 Subject: [PATCH] refactoring of vec_znx --- base2k/examples/rlwe_encrypt.rs | 6 +- base2k/examples/vector_matrix_product.rs | 2 +- base2k/src/commons.rs | 227 ++++++- base2k/src/encoding.rs | 74 +-- base2k/src/lib.rs | 2 + base2k/src/mat_znx_dft.rs | 98 ++- base2k/src/sampling.rs | 22 +- base2k/src/scalar_znx_dft.rs | 2 +- base2k/src/stats.rs | 2 +- base2k/src/vec_znx.rs | 556 +--------------- base2k/src/vec_znx_big.rs | 86 ++- base2k/src/vec_znx_dft.rs | 84 ++- base2k/src/vec_znx_ops.rs | 795 +++++++++++++++++++++++ 13 files changed, 1218 insertions(+), 738 deletions(-) create mode 100644 base2k/src/vec_znx_ops.rs diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index 3d53141..3661f0d 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -35,7 +35,7 @@ fn main() { module.fill_uniform(log_base2k, &mut a, 0, limbs, &mut source); // Scratch space for DFT values - let mut buf_dft: VecZnxDft = module.new_vec_znx_dft(1, a.limbs()); + let mut buf_dft: VecZnxDft = module.new_vec_znx_dft(1, a.size()); // Applies buf_dft <- s * a module.svp_apply_dft(&mut buf_dft, &s_ppol, &a); @@ -93,9 +93,9 @@ fn main() { // have = m * 2^{log_scale} + e let mut have: Vec = vec![i64::default(); n]; - res.decode_vec_i64(0, log_base2k, res.limbs() * log_base2k, &mut have); + res.decode_vec_i64(0, log_base2k, res.size() * log_base2k, &mut have); - let scale: f64 = (1 << (res.limbs() * log_base2k - log_scale)) as f64; + let scale: f64 = (1 << (res.size() * log_base2k - log_scale)) as f64; izip!(want.iter(), have.iter()) .enumerate() .for_each(|(i, (a, b))| { diff --git a/base2k/examples/vector_matrix_product.rs b/base2k/examples/vector_matrix_product.rs index 0120f61..96a0df7 100644 --- a/base2k/examples/vector_matrix_product.rs +++ b/base2k/examples/vector_matrix_product.rs @@ -33,7 +33,7 @@ fn main() { let mut mat_znx_dft: MatZnxDft = module.new_mat_znx_dft(rows_mat, 1, limbs_mat); - (0..a.limbs()).for_each(|row_i| { + (0..a.size()).for_each(|row_i| { let mut tmp: VecZnx = module.new_vec_znx(1, limbs_mat); tmp.at_limb_mut(row_i)[1] = 1 as i64; module.vmp_prepare_row(&mut mat_znx_dft, tmp.raw(), row_i, &mut buf); diff --git a/base2k/src/commons.rs b/base2k/src/commons.rs index 290599d..1d7a0c9 100644 --- a/base2k/src/commons.rs +++ b/base2k/src/commons.rs @@ -1,11 +1,15 @@ -use crate::{Backend, Module}; +use crate::{Backend, Module, assert_alignement, cast_mut}; +use itertools::izip; +use std::cmp::{max, min}; pub trait ZnxInfos { /// Returns the ring degree of the polynomials. fn n(&self) -> usize; /// Returns the base two logarithm of the ring dimension of the polynomials. - fn log_n(&self) -> usize; + fn log_n(&self) -> usize { + (usize::BITS - (self.n() - 1).leading_zeros()) as _ + } /// Returns the number of rows. fn rows(&self) -> usize; @@ -13,21 +17,28 @@ pub trait ZnxInfos { /// Returns the number of polynomials in each row. fn cols(&self) -> usize; - /// Returns the number of limbs per polynomial. - fn limbs(&self) -> usize; + /// Returns the number of size per polynomial. + fn size(&self) -> usize; /// Returns the total number of small polynomials. - fn poly_count(&self) -> usize; + fn poly_count(&self) -> usize { + self.rows() * self.cols() * self.size() + } + + /// Returns the slice size, which is the offset between + /// two size of the same column. + fn sl(&self) -> usize { + self.n() * self.cols() + } } pub trait ZnxBase { type Scalar; - fn new(module: &Module, cols: usize, limbs: usize) -> Self; - fn from_bytes(module: &Module, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self; - fn from_bytes_borrow(module: &Module, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self; - fn bytes_of(module: &Module, cols: usize, limbs: usize) -> usize; + fn new(module: &Module, cols: usize, size: usize) -> Self; + fn from_bytes(module: &Module, cols: usize, size: usize, bytes: &mut [u8]) -> Self; + fn from_bytes_borrow(module: &Module, cols: usize, size: usize, bytes: &mut [u8]) -> Self; + fn bytes_of(module: &Module, cols: usize, size: usize) -> usize; } - pub trait ZnxLayout: ZnxInfos { type Scalar; @@ -52,7 +63,7 @@ pub trait ZnxLayout: ZnxInfos { #[cfg(debug_assertions)] { assert!(i < self.cols()); - assert!(j < self.limbs()); + assert!(j < self.size()); } let offset = self.n() * (j * self.cols() + i); unsafe { self.as_ptr().add(offset) } @@ -63,7 +74,7 @@ pub trait ZnxLayout: ZnxInfos { #[cfg(debug_assertions)] { assert!(i < self.cols()); - assert!(j < self.limbs()); + assert!(j < self.size()); } let offset = self.n() * (j * self.cols() + i); unsafe { self.as_mut_ptr().add(offset) } @@ -89,3 +100,195 @@ pub trait ZnxLayout: ZnxInfos { unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(0, j), self.n() * self.cols()) } } } + +use std::convert::TryFrom; +use std::num::TryFromIntError; +use std::ops::{Add, AddAssign, Div, Mul, Neg, Shl, Shr, Sub}; +pub trait IntegerType: + Copy + + std::fmt::Debug + + Default + + PartialEq + + PartialOrd + + Add + + Sub + + Mul + + Div + + Neg + + Shr + + Shl + + AddAssign + + TryFrom +{ + const BITS: u32; +} + +impl IntegerType for i64 { + const BITS: u32 = 64; +} + +impl IntegerType for i128 { + const BITS: u32 = 128; +} + +pub trait ZnxBasics: ZnxLayout +where + Self: Sized, + Self::Scalar: IntegerType, +{ + fn zero(&mut self) { + unsafe { + std::ptr::write_bytes(self.as_mut_ptr(), 0, self.n() * size_of::()); + } + } + + fn zero_at(&mut self, i: usize, j: usize) { + unsafe { + std::ptr::write_bytes( + self.at_mut_ptr(i, j), + 0, + self.n() * size_of::(), + ); + } + } + + fn rsh(&mut self, log_base2k: usize, k: usize, carry: &mut [u8]) { + rsh(log_base2k, self, k, carry) + } +} + +pub fn rsh(log_base2k: usize, a: &mut V, k: usize, tmp_bytes: &mut [u8]) +where + V::Scalar: IntegerType, +{ + let n: usize = a.n(); + let size: usize = a.size(); + let cols: usize = a.cols(); + + #[cfg(debug_assertions)] + { + assert!( + tmp_bytes.len() >= rsh_tmp_bytes::(n, cols), + "invalid carry: carry.len()/size_ofSelf::Scalar={} < rsh_tmp_bytes({}, {})", + tmp_bytes.len() / size_of::(), + n, + size, + ); + assert_alignement(tmp_bytes.as_ptr()); + } + + let size: usize = a.size(); + let steps: usize = k / log_base2k; + + a.raw_mut().rotate_right(n * steps * cols); + (0..cols).for_each(|i| { + (0..steps).for_each(|j| { + a.zero_at(i, j); + }) + }); + + let k_rem: usize = k % log_base2k; + + if k_rem != 0 { + let carry: &mut [V::Scalar] = cast_mut(tmp_bytes); + + unsafe { + std::ptr::write_bytes(carry.as_mut_ptr(), 0, n * size_of::()); + } + + let log_base2k_t: V::Scalar = V::Scalar::try_from(log_base2k).unwrap(); + let shift: V::Scalar = V::Scalar::try_from(V::Scalar::BITS as usize - k_rem).unwrap(); + let k_rem_t: V::Scalar = V::Scalar::try_from(k_rem).unwrap(); + + (steps..size).for_each(|i| { + izip!(carry.iter_mut(), a.at_limb_mut(i).iter_mut()).for_each(|(ci, xi)| { + *xi += *ci << log_base2k_t; + *ci = get_base_k_carry(*xi, shift); + *xi = (*xi - *ci) >> k_rem_t; + }); + }) + } +} + +#[inline(always)] +fn get_base_k_carry(x: T, shift: T) -> T { + (x << shift) >> shift +} + +pub fn rsh_tmp_bytes(n: usize, cols: usize) -> usize { + n * cols * std::mem::size_of::() +} + +pub fn switch_degree(b: &mut T, a: &T) +where + ::Scalar: IntegerType, +{ + let (n_in, n_out) = (a.n(), b.n()); + let (gap_in, gap_out): (usize, usize); + + if n_in > n_out { + (gap_in, gap_out) = (n_in / n_out, 1) + } else { + (gap_in, gap_out) = (1, n_out / n_in); + b.zero(); + } + + let size: usize = min(a.size(), b.size()); + + (0..size).for_each(|i| { + izip!( + a.at_limb(i).iter().step_by(gap_in), + b.at_limb_mut(i).iter_mut().step_by(gap_out) + ) + .for_each(|(x_in, x_out)| *x_out = *x_in); + }); +} + +pub fn znx_post_process_ternary_op(c: &mut T, a: &T, b: &T) +where + ::Scalar: IntegerType, +{ + #[cfg(debug_assertions)] + { + assert_ne!(a.as_ptr(), b.as_ptr()); + assert_ne!(b.as_ptr(), c.as_ptr()); + assert_ne!(a.as_ptr(), c.as_ptr()); + } + + let a_cols: usize = a.cols(); + let b_cols: usize = b.cols(); + let c_cols: usize = c.cols(); + + let min_ab_cols: usize = min(a_cols, b_cols); + let max_ab_cols: usize = max(a_cols, b_cols); + + // Copies shared shared cols between (c, max(a, b)) + if a_cols != b_cols { + let mut x: &T = a; + if a_cols < b_cols { + x = b; + } + + let min_size = min(c.size(), x.size()); + (min_ab_cols..min(max_ab_cols, c_cols)).for_each(|i| { + (0..min_size).for_each(|j| { + c.at_poly_mut(i, j).copy_from_slice(x.at_poly(i, j)); + if NEGATE { + c.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x); + } + }); + (min_size..c.size()).for_each(|j| { + c.zero_at(i, j); + }); + }); + } + + // Zeroes the cols of c > max(a, b). + if c_cols > max_ab_cols { + (max_ab_cols..c_cols).for_each(|i| { + (0..c.size()).for_each(|j| { + c.zero_at(i, j); + }) + }); + } +} diff --git a/base2k/src/encoding.rs b/base2k/src/encoding.rs index 980dab4..8c41381 100644 --- a/base2k/src/encoding.rs +++ b/base2k/src/encoding.rs @@ -81,15 +81,15 @@ impl Encoding for VecZnx { } fn encode_vec_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) { - let limbs: usize = (log_k + log_base2k - 1) / log_base2k; + let size: usize = (log_k + log_base2k - 1) / log_base2k; #[cfg(debug_assertions)] { assert!( - limbs <= a.limbs(), - "invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.limbs()={}", - limbs, - a.limbs() + size <= a.size(), + "invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.size()={}", + size, + a.size() ); assert!(col_i < a.cols()); assert!(data.len() <= a.n()) @@ -99,7 +99,7 @@ fn encode_vec_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize, let log_k_rem: usize = log_base2k - (log_k % log_base2k); // Zeroes coefficients of the i-th column - (0..a.limbs()).for_each(|i| unsafe { + (0..a.size()).for_each(|i| unsafe { znx_zero_i64_ref(a.n() as u64, a.at_mut_ptr(col_i, i)); }); @@ -107,11 +107,11 @@ fn encode_vec_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize, // values on the last limb. // Else we decompose values base2k. if log_max + log_k_rem < 63 || log_k_rem == log_base2k { - a.at_poly_mut(col_i, limbs - 1)[..data_len].copy_from_slice(&data[..data_len]); + a.at_poly_mut(col_i, size - 1)[..data_len].copy_from_slice(&data[..data_len]); } else { let mask: i64 = (1 << log_base2k) - 1; - let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k); - (limbs - steps..limbs) + let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k); + (size - steps..size) .rev() .enumerate() .for_each(|(i, i_rev)| { @@ -122,8 +122,8 @@ fn encode_vec_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize, // Case where self.prec % self.k != 0. if log_k_rem != log_base2k { - let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k); - (limbs - steps..limbs).rev().for_each(|i| { + let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k); + (size - steps..size).rev().for_each(|i| { a.at_poly_mut(col_i, i)[..data_len] .iter_mut() .for_each(|x| *x <<= log_k_rem); @@ -132,7 +132,7 @@ fn encode_vec_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize, } fn decode_vec_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) { - let limbs: usize = (log_k + log_base2k - 1) / log_base2k; + let size: usize = (log_k + log_base2k - 1) / log_base2k; #[cfg(debug_assertions)] { assert!( @@ -145,8 +145,8 @@ fn decode_vec_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, dat } data.copy_from_slice(a.at_poly(col_i, 0)); let rem: usize = log_base2k - (log_k % log_base2k); - (1..limbs).for_each(|i| { - if i == limbs - 1 && rem != log_base2k { + (1..size).for_each(|i| { + if i == size - 1 && rem != log_base2k { let k_rem: usize = log_base2k - rem; izip!(a.at_poly(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| { *y = (*y << k_rem) + (x >> rem); @@ -160,7 +160,7 @@ fn decode_vec_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, dat } fn decode_vec_float(a: &VecZnx, col_i: usize, log_base2k: usize, data: &mut [Float]) { - let limbs: usize = a.limbs(); + let size: usize = a.size(); #[cfg(debug_assertions)] { assert!( @@ -172,20 +172,20 @@ fn decode_vec_float(a: &VecZnx, col_i: usize, log_base2k: usize, data: &mut [Flo assert!(col_i < a.cols()); } - let prec: u32 = (log_base2k * limbs) as u32; + let prec: u32 = (log_base2k * size) as u32; // 2^{log_base2k} let base = Float::with_val(prec, (1 << log_base2k) as f64); // y[i] = sum x[j][i] * 2^{-log_base2k*j} - (0..limbs).for_each(|i| { + (0..size).for_each(|i| { if i == 0 { - izip!(a.at_poly(col_i, limbs - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { + izip!(a.at_poly(col_i, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { y.assign(*x); *y /= &base; }); } else { - izip!(a.at_poly(col_i, limbs - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { + izip!(a.at_poly(col_i, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { *y += Float::with_val(prec, *x); *y /= &base; }); @@ -194,32 +194,32 @@ fn decode_vec_float(a: &VecZnx, col_i: usize, log_base2k: usize, data: &mut [Flo } fn encode_coeff_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usize, i: usize, value: i64, log_max: usize) { - let limbs: usize = (log_k + log_base2k - 1) / log_base2k; + let size: usize = (log_k + log_base2k - 1) / log_base2k; #[cfg(debug_assertions)] { assert!(i < a.n()); assert!( - limbs <= a.limbs(), - "invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.limbs()={}", - limbs, - a.limbs() + size <= a.size(), + "invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.size()={}", + size, + a.size() ); assert!(col_i < a.cols()); } let log_k_rem: usize = log_base2k - (log_k % log_base2k); - (0..a.limbs()).for_each(|j| a.at_poly_mut(col_i, j)[i] = 0); + (0..a.size()).for_each(|j| a.at_poly_mut(col_i, j)[i] = 0); // If 2^{log_base2k} * 2^{log_k_rem} < 2^{63}-1, then we can simply copy // values on the last limb. // Else we decompose values base2k. if log_max + log_k_rem < 63 || log_k_rem == log_base2k { - a.at_poly_mut(col_i, limbs - 1)[i] = value; + a.at_poly_mut(col_i, size - 1)[i] = value; } else { let mask: i64 = (1 << log_base2k) - 1; - let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k); - (limbs - steps..limbs) + let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k); + (size - steps..size) .rev() .enumerate() .for_each(|(j, j_rev)| { @@ -229,8 +229,8 @@ fn encode_coeff_i64(a: &mut VecZnx, col_i: usize, log_base2k: usize, log_k: usiz // Case where prec % k != 0. if log_k_rem != log_base2k { - let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k); - (limbs - steps..limbs).rev().for_each(|j| { + let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k); + (size - steps..size).rev().for_each(|j| { a.at_poly_mut(col_i, j)[i] <<= log_k_rem; }) } @@ -247,7 +247,7 @@ fn decode_coeff_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, i let data: &[i64] = a.raw(); let mut res: i64 = data[i]; let rem: usize = log_base2k - (log_k % log_base2k); - let slice_size: usize = a.n() * a.limbs(); + let slice_size: usize = a.n() * a.size(); (1..cols).for_each(|i| { let x = data[i * slice_size]; if i == cols - 1 && rem != log_base2k { @@ -271,9 +271,9 @@ mod tests { let n: usize = 8; let module: Module = Module::::new(n); let log_base2k: usize = 17; - let limbs: usize = 5; - let log_k: usize = limbs * log_base2k - 5; - let mut a: VecZnx = VecZnx::new(&module, 2, limbs); + let size: usize = 5; + let log_k: usize = size * log_base2k - 5; + let mut a: VecZnx = VecZnx::new(&module, 2, size); let mut source: Source = Source::new([0u8; 32]); let raw: &mut [i64] = a.raw_mut(); raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); @@ -293,9 +293,9 @@ mod tests { let n: usize = 8; let module: Module = Module::::new(n); let log_base2k: usize = 17; - let limbs: usize = 5; - let log_k: usize = limbs * log_base2k - 5; - let mut a: VecZnx = VecZnx::new(&module, 2, limbs); + let size: usize = 5; + let log_k: usize = size * log_base2k - 5; + let mut a: VecZnx = VecZnx::new(&module, 2, size); let mut source = Source::new([0u8; 32]); let raw: &mut [i64] = a.raw_mut(); raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs index 40df3bb..3c48319 100644 --- a/base2k/src/lib.rs +++ b/base2k/src/lib.rs @@ -11,6 +11,7 @@ pub mod stats; pub mod vec_znx; pub mod vec_znx_big; pub mod vec_znx_dft; +pub mod vec_znx_ops; pub use commons::*; pub use encoding::*; @@ -23,6 +24,7 @@ pub use stats::*; pub use vec_znx::*; pub use vec_znx_big::*; pub use vec_znx_dft::*; +pub use vec_znx_ops::*; pub const GALOISGENERATOR: u64 = 5; pub const DEFAULTALIGN: usize = 64; diff --git a/base2k/src/mat_znx_dft.rs b/base2k/src/mat_znx_dft.rs index 9466696..b40ed71 100644 --- a/base2k/src/mat_znx_dft.rs +++ b/base2k/src/mat_znx_dft.rs @@ -22,7 +22,7 @@ pub struct MatZnxDft { /// Number of cols cols: usize, /// The number of small polynomials - limbs: usize, + size: usize, _marker: PhantomData, } @@ -31,10 +31,6 @@ impl ZnxInfos for MatZnxDft { self.n } - fn log_n(&self) -> usize { - (usize::BITS - (self.n() - 1).leading_zeros()) as _ - } - fn rows(&self) -> usize { self.rows } @@ -43,18 +39,14 @@ impl ZnxInfos for MatZnxDft { self.cols } - fn limbs(&self) -> usize { - self.limbs - } - - fn poly_count(&self) -> usize { - self.rows * self.cols * self.limbs + fn size(&self) -> usize { + self.size } } impl MatZnxDft { - fn new(module: &Module, rows: usize, cols: usize, limbs: usize) -> MatZnxDft { - let mut data: Vec = alloc_aligned::(module.bytes_of_mat_znx_dft(rows, cols, limbs)); + fn new(module: &Module, rows: usize, cols: usize, size: usize) -> MatZnxDft { + let mut data: Vec = alloc_aligned::(module.bytes_of_mat_znx_dft(rows, cols, size)); let ptr: *mut u8 = data.as_mut_ptr(); MatZnxDft:: { data: data, @@ -62,7 +54,7 @@ impl MatZnxDft { n: module.n(), rows: rows, cols: cols, - limbs: limbs, + size: size, _marker: PhantomData, } } @@ -115,7 +107,7 @@ impl MatZnxDft { fn at_block(&self, row: usize, col: usize, blk: usize) -> &[f64] { let nrows: usize = self.rows(); - let nsize: usize = self.limbs(); + let nsize: usize = self.size(); if col == (nsize - 1) && (nsize & 1 == 1) { &self.raw()[blk * nrows * nsize * 8 + col * nrows * 8 + row * 8..] } else { @@ -127,7 +119,7 @@ impl MatZnxDft { /// This trait implements methods for vector matrix product, /// that is, multiplying a [VecZnx] with a [VmpPMat]. pub trait MatZnxDftOps { - fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, limbs: usize) -> usize; + fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> usize; /// Allocates a new [VmpPMat] with the given number of rows and columns. /// @@ -135,7 +127,7 @@ pub trait MatZnxDftOps { /// /// * `rows`: number of rows (number of [VecZnxDft]). /// * `size`: number of size (number of size of each [VecZnxDft]). - fn new_mat_znx_dft(&self, rows: usize, cols: usize, limbs: usize) -> MatZnxDft; + fn new_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> MatZnxDft; /// Returns the number of bytes needed as scratch space for [VmpPMatOps::vmp_prepare_contiguous]. /// @@ -351,12 +343,12 @@ pub trait MatZnxDftOps { } impl MatZnxDftOps for Module { - fn new_mat_znx_dft(&self, rows: usize, cols: usize, limbs: usize) -> MatZnxDft { - MatZnxDft::::new(self, rows, cols, limbs) + fn new_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> MatZnxDft { + MatZnxDft::::new(self, rows, cols, size) } - fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, limbs: usize) -> usize { - unsafe { vmp::bytes_of_vmp_pmat(self.ptr, rows as u64, (limbs * cols) as u64) as usize } + fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> usize { + unsafe { vmp::bytes_of_vmp_pmat(self.ptr, rows as u64, (size * cols) as u64) as usize } } fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize, size: usize) -> usize { @@ -367,7 +359,7 @@ impl MatZnxDftOps for Module { #[cfg(debug_assertions)] { assert_eq!(a.len(), b.n() * b.poly_count()); - assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.limbs())); + assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.size())); assert_alignement(tmp_bytes.as_ptr()); } unsafe { @@ -376,7 +368,7 @@ impl MatZnxDftOps for Module { b.as_mut_ptr() as *mut vmp_pmat_t, a.as_ptr(), b.rows() as u64, - (b.limbs() * b.cols()) as u64, + (b.size() * b.cols()) as u64, tmp_bytes.as_mut_ptr(), ); } @@ -385,8 +377,8 @@ impl MatZnxDftOps for Module { fn vmp_prepare_row(&self, b: &mut MatZnxDft, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]) { #[cfg(debug_assertions)] { - assert_eq!(a.len(), b.limbs() * self.n() * b.cols()); - assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.limbs())); + assert_eq!(a.len(), b.size() * self.n() * b.cols()); + assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.size())); assert_alignement(tmp_bytes.as_ptr()); } unsafe { @@ -396,7 +388,7 @@ impl MatZnxDftOps for Module { a.as_ptr(), row_i as u64, b.rows() as u64, - (b.limbs() * b.cols()) as u64, + (b.size() * b.cols()) as u64, tmp_bytes.as_mut_ptr(), ); } @@ -406,7 +398,7 @@ impl MatZnxDftOps for Module { #[cfg(debug_assertions)] { assert_eq!(a.n(), b.n()); - assert_eq!(a.limbs(), b.limbs()); + assert_eq!(a.size(), b.size()); assert_eq!(a.cols(), b.cols()); } unsafe { @@ -416,7 +408,7 @@ impl MatZnxDftOps for Module { a.as_ptr() as *const vmp_pmat_t, row_i as u64, a.rows() as u64, - (a.limbs() * a.cols()) as u64, + (a.size() * a.cols()) as u64, ); } } @@ -425,7 +417,7 @@ impl MatZnxDftOps for Module { #[cfg(debug_assertions)] { assert_eq!(a.n(), b.n()); - assert_eq!(a.limbs(), b.limbs()); + assert_eq!(a.size(), b.size()); } unsafe { vmp::vmp_prepare_row_dft( @@ -434,7 +426,7 @@ impl MatZnxDftOps for Module { a.ptr as *const vec_znx_dft_t, row_i as u64, b.rows() as u64, - b.limbs() as u64, + b.size() as u64, ); } } @@ -443,7 +435,7 @@ impl MatZnxDftOps for Module { #[cfg(debug_assertions)] { assert_eq!(a.n(), b.n()); - assert_eq!(a.limbs(), b.limbs()); + assert_eq!(a.size(), b.size()); } unsafe { vmp::vmp_extract_row_dft( @@ -452,7 +444,7 @@ impl MatZnxDftOps for Module { a.as_ptr() as *const vmp_pmat_t, row_i as u64, a.rows() as u64, - a.limbs() as u64, + a.size() as u64, ); } } @@ -470,7 +462,7 @@ impl MatZnxDftOps for Module { } fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, tmp_bytes: &mut [u8]) { - debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.limbs())); + debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size())); #[cfg(debug_assertions)] { assert_alignement(tmp_bytes.as_ptr()); @@ -479,20 +471,20 @@ impl MatZnxDftOps for Module { vmp::vmp_apply_dft( self.ptr, c.ptr as *mut vec_znx_dft_t, - c.limbs() as u64, + c.size() as u64, a.as_ptr(), - a.limbs() as u64, + a.size() as u64, (a.n() * a.cols()) as u64, b.as_ptr() as *const vmp_pmat_t, b.rows() as u64, - b.limbs() as u64, + b.size() as u64, tmp_bytes.as_mut_ptr(), ) } } fn vmp_apply_dft_add(&self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, tmp_bytes: &mut [u8]) { - debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.limbs())); + debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size())); #[cfg(debug_assertions)] { assert_alignement(tmp_bytes.as_ptr()); @@ -501,13 +493,13 @@ impl MatZnxDftOps for Module { vmp::vmp_apply_dft_add( self.ptr, c.ptr as *mut vec_znx_dft_t, - c.limbs() as u64, + c.size() as u64, a.as_ptr(), - a.limbs() as u64, - (a.n() * a.limbs()) as u64, + a.size() as u64, + (a.n() * a.size()) as u64, b.as_ptr() as *const vmp_pmat_t, b.rows() as u64, - b.limbs() as u64, + b.size() as u64, tmp_bytes.as_mut_ptr(), ) } @@ -526,7 +518,7 @@ impl MatZnxDftOps for Module { } fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &MatZnxDft, tmp_bytes: &mut [u8]) { - debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.limbs())); + debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size())); #[cfg(debug_assertions)] { assert_alignement(tmp_bytes.as_ptr()); @@ -535,12 +527,12 @@ impl MatZnxDftOps for Module { vmp::vmp_apply_dft_to_dft( self.ptr, c.ptr as *mut vec_znx_dft_t, - c.limbs() as u64, + c.size() as u64, a.ptr as *const vec_znx_dft_t, - a.limbs() as u64, + a.size() as u64, b.as_ptr() as *const vmp_pmat_t, b.rows() as u64, - b.limbs() as u64, + b.size() as u64, tmp_bytes.as_mut_ptr(), ) } @@ -553,7 +545,7 @@ impl MatZnxDftOps for Module { b: &MatZnxDft, tmp_bytes: &mut [u8], ) { - debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.limbs())); + debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size())); #[cfg(debug_assertions)] { assert_alignement(tmp_bytes.as_ptr()); @@ -562,19 +554,19 @@ impl MatZnxDftOps for Module { vmp::vmp_apply_dft_to_dft_add( self.ptr, c.ptr as *mut vec_znx_dft_t, - c.limbs() as u64, + c.size() as u64, a.ptr as *const vec_znx_dft_t, - a.limbs() as u64, + a.size() as u64, b.as_ptr() as *const vmp_pmat_t, b.rows() as u64, - b.limbs() as u64, + b.size() as u64, tmp_bytes.as_mut_ptr(), ) } } fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &MatZnxDft, tmp_bytes: &mut [u8]) { - debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(b.limbs(), b.limbs(), a.rows(), a.limbs())); + debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(b.size(), b.size(), a.rows(), a.size())); #[cfg(debug_assertions)] { assert_alignement(tmp_bytes.as_ptr()); @@ -583,12 +575,12 @@ impl MatZnxDftOps for Module { vmp::vmp_apply_dft_to_dft( self.ptr, b.ptr as *mut vec_znx_dft_t, - b.limbs() as u64, + b.size() as u64, b.ptr as *mut vec_znx_dft_t, - b.limbs() as u64, + b.size() as u64, a.as_ptr() as *const vmp_pmat_t, a.rows() as u64, - a.limbs() as u64, + a.size() as u64, tmp_bytes.as_mut_ptr(), ) } diff --git a/base2k/src/sampling.rs b/base2k/src/sampling.rs index 80d174c..a96937e 100644 --- a/base2k/src/sampling.rs +++ b/base2k/src/sampling.rs @@ -3,8 +3,8 @@ use rand_distr::{Distribution, Normal}; use sampling::source::Source; pub trait Sampling { - /// Fills the first `limbs` limbs with uniform values in \[-2^{log_base2k-1}, 2^{log_base2k-1}\] - fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_i: usize, limbs: usize, source: &mut Source); + /// Fills the first `size` size with uniform values in \[-2^{log_base2k-1}, 2^{log_base2k-1}\] + fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_i: usize, size: usize, source: &mut Source); /// Adds vector sampled according to the provided distribution, scaled by 2^{-log_k} and bounded to \[-bound, bound\]. fn add_dist_f64>( @@ -32,11 +32,11 @@ pub trait Sampling { } impl Sampling for Module { - fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_i: usize, limbs: usize, source: &mut Source) { + fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, col_i: usize, size: usize, source: &mut Source) { let base2k: u64 = 1 << log_base2k; let mask: u64 = base2k - 1; let base2k_half: i64 = (base2k >> 1) as i64; - (0..limbs).for_each(|j| { + (0..size).for_each(|j| { a.at_poly_mut(col_i, j) .iter_mut() .for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half); @@ -114,17 +114,17 @@ mod tests { let n: usize = 4096; let module: Module = Module::::new(n); let log_base2k: usize = 17; - let limbs: usize = 5; + let size: usize = 5; let mut source: Source = Source::new([0u8; 32]); let cols: usize = 2; let zero: Vec = vec![0; n]; let one_12_sqrt: f64 = 0.28867513459481287; (0..cols).for_each(|col_i| { - let mut a: VecZnx = VecZnx::new(&module, cols, limbs); - module.fill_uniform(log_base2k, &mut a, col_i, limbs, &mut source); + let mut a: VecZnx = VecZnx::new(&module, cols, size); + module.fill_uniform(log_base2k, &mut a, col_i, size, &mut source); (0..cols).for_each(|col_j| { if col_j != col_i { - (0..limbs).for_each(|limb_i| { + (0..size).for_each(|limb_i| { assert_eq!(a.at_poly(col_j, limb_i), zero); }) } else { @@ -146,7 +146,7 @@ mod tests { let module: Module = Module::::new(n); let log_base2k: usize = 17; let log_k: usize = 2 * 17; - let limbs: usize = 5; + let size: usize = 5; let sigma: f64 = 3.2; let bound: f64 = 6.0 * sigma; let mut source: Source = Source::new([0u8; 32]); @@ -154,11 +154,11 @@ mod tests { let zero: Vec = vec![0; n]; let k_f64: f64 = (1u64 << log_k as u64) as f64; (0..cols).for_each(|col_i| { - let mut a: VecZnx = VecZnx::new(&module, cols, limbs); + let mut a: VecZnx = VecZnx::new(&module, cols, size); module.add_normal(log_base2k, &mut a, col_i, log_k, &mut source, sigma, bound); (0..cols).for_each(|col_j| { if col_j != col_i { - (0..limbs).for_each(|limb_i| { + (0..size).for_each(|limb_i| { assert_eq!(a.at_poly(col_j, limb_i), zero); }) } else { diff --git a/base2k/src/scalar_znx_dft.rs b/base2k/src/scalar_znx_dft.rs index 7457ca2..cfe2f45 100644 --- a/base2k/src/scalar_znx_dft.rs +++ b/base2k/src/scalar_znx_dft.rs @@ -120,7 +120,7 @@ impl Scalar { VecZnx { n: self.n, cols: 1, - limbs: 1, + size: 1, data: Vec::new(), ptr: self.ptr, } diff --git a/base2k/src/stats.rs b/base2k/src/stats.rs index 7fcf7c3..4e2a512 100644 --- a/base2k/src/stats.rs +++ b/base2k/src/stats.rs @@ -10,7 +10,7 @@ pub trait Stats { impl Stats for VecZnx { fn std(&self, col_i: usize, log_base2k: usize) -> f64 { - let prec: u32 = (self.limbs() * log_base2k) as u32; + let prec: u32 = (self.size() * log_base2k) as u32; let mut data: Vec = (0..self.n()).map(|_| Float::with_val(prec, 0)).collect(); self.decode_vec_float(col_i, log_base2k, &mut data); // std = sqrt(sum((xi - avg)^2) / n) diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index 89173f0..1bb8ab3 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -1,11 +1,10 @@ use crate::Backend; use crate::ZnxBase; use crate::cast_mut; -use crate::ffi::vec_znx; use crate::ffi::znx; -use crate::{Module, ZnxInfos, ZnxLayout}; +use crate::switch_degree; +use crate::{Module, ZnxBasics, ZnxInfos, ZnxLayout}; use crate::{alloc_aligned, assert_alignement}; -use itertools::izip; use std::cmp::min; /// [VecZnx] represents collection of contiguously stacked vector of small norm polynomials of @@ -26,8 +25,8 @@ pub struct VecZnx { /// The number of polynomials pub cols: usize, - /// The number of limbs per polynomial (a.k.a small polynomials). - pub limbs: usize, + /// The number of size per polynomial (a.k.a small polynomials). + pub size: usize, /// Polynomial coefficients, as a contiguous array. Each col is equally spaced by n. pub data: Vec, @@ -41,10 +40,6 @@ impl ZnxInfos for VecZnx { self.n } - fn log_n(&self) -> usize { - (usize::BITS - (self.n() - 1).leading_zeros()) as _ - } - fn rows(&self) -> usize { 1 } @@ -53,12 +48,8 @@ impl ZnxInfos for VecZnx { self.cols } - fn limbs(&self) -> usize { - self.limbs - } - - fn poly_count(&self) -> usize { - self.cols * self.limbs + fn size(&self) -> usize { + self.size } } @@ -74,6 +65,8 @@ impl ZnxLayout for VecZnx { } } +impl ZnxBasics for VecZnx {} + /// Copies the coefficients of `a` on the receiver. /// Copy is done with the minimum size matching both backing arrays. /// Panics if the cols do not match. @@ -89,28 +82,28 @@ impl ZnxBase for VecZnx { type Scalar = i64; /// Allocates a new [VecZnx] composed of #size polynomials of Z\[X\]. - fn new(module: &Module, cols: usize, limbs: usize) -> Self { + fn new(module: &Module, cols: usize, size: usize) -> Self { let n: usize = module.n(); #[cfg(debug_assertions)] { assert!(n > 0); assert!(n & (n - 1) == 0); assert!(cols > 0); - assert!(limbs > 0); + assert!(size > 0); } - let mut data: Vec = alloc_aligned::(Self::bytes_of(module, cols, limbs)); + let mut data: Vec = alloc_aligned::(Self::bytes_of(module, cols, size)); let ptr: *mut i64 = data.as_mut_ptr(); Self { n: n, cols: cols, - limbs: limbs, + size: size, data: data, ptr: ptr, } } - fn bytes_of(module: &Module, cols: usize, limbs: usize) -> usize { - module.n() * cols * limbs * size_of::() + fn bytes_of(module: &Module, cols: usize, size: usize) -> usize { + module.n() * cols * size * size_of::() } /// Returns a new struct implementing [VecZnx] with the provided data as backing array. @@ -118,14 +111,14 @@ impl ZnxBase for VecZnx { /// The struct will take ownership of buf[..[Self::bytes_of]] /// /// User must ensure that data is properly alligned and that - /// the limbs of data is equal to [Self::bytes_of]. - fn from_bytes(module: &Module, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { + /// the size of data is equal to [Self::bytes_of]. + fn from_bytes(module: &Module, cols: usize, size: usize, bytes: &mut [u8]) -> Self { let n: usize = module.n(); #[cfg(debug_assertions)] { assert!(cols > 0); - assert!(limbs > 0); - assert_eq!(bytes.len(), Self::bytes_of(module, cols, limbs)); + assert!(size > 0); + assert_eq!(bytes.len(), Self::bytes_of(module, cols, size)); assert_alignement(bytes.as_ptr()); } unsafe { @@ -134,25 +127,25 @@ impl ZnxBase for VecZnx { Self { n: n, cols: cols, - limbs: limbs, + size: size, data: Vec::from_raw_parts(ptr, bytes.len(), bytes.len()), ptr: ptr, } } } - fn from_bytes_borrow(module: &Module, cols: usize, limbs: usize, bytes: &mut [u8]) -> Self { + fn from_bytes_borrow(module: &Module, cols: usize, size: usize, bytes: &mut [u8]) -> Self { #[cfg(debug_assertions)] { assert!(cols > 0); - assert!(limbs > 0); - assert!(bytes.len() >= Self::bytes_of(module, cols, limbs)); + assert!(size > 0); + assert!(bytes.len() >= Self::bytes_of(module, cols, size)); assert_alignement(bytes.as_ptr()); } Self { n: module.n(), cols: cols, - limbs: limbs, + size: size, data: Vec::new(), ptr: bytes.as_mut_ptr() as *mut i64, } @@ -173,16 +166,16 @@ impl VecZnx { if !self.borrowing() { self.data - .truncate(self.n() * self.cols() * (self.limbs() - k / log_base2k)); + .truncate(self.n() * self.cols() * (self.size() - k / log_base2k)); } - self.limbs -= k / log_base2k; + self.size -= k / log_base2k; let k_rem: usize = k % log_base2k; if k_rem != 0 { let mask: i64 = ((1 << (log_base2k - k_rem - 1)) - 1) << k_rem; - self.at_limb_mut(self.limbs() - 1) + self.at_limb_mut(self.size() - 1) .iter_mut() .for_each(|x: &mut i64| *x &= mask) } @@ -196,52 +189,22 @@ impl VecZnx { self.data.len() == 0 } - pub fn zero(&mut self) { - unsafe { znx::znx_zero_i64_ref((self.n * self.poly_count()) as u64, self.ptr) } - } - pub fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]) { normalize(log_base2k, self, carry) } - pub fn rsh(&mut self, log_base2k: usize, k: usize, carry: &mut [u8]) { - rsh(log_base2k, self, k, carry) - } - pub fn switch_degree(&self, a: &mut Self) { switch_degree(a, self) } // Prints the first `n` coefficients of each limb pub fn print(&self, n: usize) { - (0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])) + (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])) } } -pub fn switch_degree(b: &mut VecZnx, a: &VecZnx) { - let (n_in, n_out) = (a.n(), b.n()); - let (gap_in, gap_out): (usize, usize); - - if n_in > n_out { - (gap_in, gap_out) = (n_in / n_out, 1) - } else { - (gap_in, gap_out) = (1, n_out / n_in); - b.zero(); - } - - let limbs: usize = min(a.limbs(), b.limbs()); - - (0..limbs).for_each(|i| { - izip!( - a.at_limb(i).iter().step_by(gap_in), - b.at_limb_mut(i).iter_mut().step_by(gap_out) - ) - .for_each(|(x_in, x_out)| *x_out = *x_in); - }); -} - -fn normalize_tmp_bytes(n: usize, limbs: usize) -> usize { - n * limbs * std::mem::size_of::() +fn normalize_tmp_bytes(n: usize, size: usize) -> usize { + n * size * std::mem::size_of::() } fn normalize(log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]) { @@ -264,7 +227,7 @@ fn normalize(log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]) { unsafe { znx::znx_zero_i64_ref(n as u64, carry_i64.as_mut_ptr()); - (0..a.limbs()).rev().for_each(|i| { + (0..a.size()).rev().for_each(|i| { znx::znx_normalize( (n * cols) as u64, log_base2k as u64, @@ -276,462 +239,3 @@ fn normalize(log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]) { }); } } - -pub fn rsh_tmp_bytes(n: usize, limbs: usize) -> usize { - n * limbs * std::mem::size_of::() -} - -pub fn rsh(log_base2k: usize, a: &mut VecZnx, k: usize, tmp_bytes: &mut [u8]) { - let n: usize = a.n(); - let limbs: usize = a.limbs(); - - #[cfg(debug_assertions)] - { - assert!( - tmp_bytes.len() >= rsh_tmp_bytes(n, limbs), - "invalid carry: carry.len()/8={} < rsh_tmp_bytes({}, {})", - tmp_bytes.len() >> 3, - n, - limbs, - ); - assert_alignement(tmp_bytes.as_ptr()); - } - - let limbs: usize = a.limbs(); - let size_steps: usize = k / log_base2k; - - a.raw_mut().rotate_right(n * limbs * size_steps); - unsafe { - znx::znx_zero_i64_ref((n * limbs * size_steps) as u64, a.as_mut_ptr()); - } - - let k_rem = k % log_base2k; - - if k_rem != 0 { - let carry_i64: &mut [i64] = cast_mut(tmp_bytes); - - unsafe { - znx::znx_zero_i64_ref((n * limbs) as u64, carry_i64.as_mut_ptr()); - } - - let log_base2k: usize = log_base2k; - - (size_steps..limbs).for_each(|i| { - izip!(carry_i64.iter_mut(), a.at_limb_mut(i).iter_mut()).for_each(|(ci, xi)| { - *xi += *ci << log_base2k; - *ci = get_base_k_carry(*xi, k_rem); - *xi = (*xi - *ci) >> k_rem; - }); - }) - } -} - -#[inline(always)] -fn get_base_k_carry(x: i64, k: usize) -> i64 { - (x << 64 - k) >> (64 - k) -} - -pub trait VecZnxOps { - /// Allocates a new [VecZnx]. - /// - /// # Arguments - /// - /// * `cols`: the number of polynomials. - /// * `limbs`: the number of limbs per polynomial (a.k.a small polynomials). - fn new_vec_znx(&self, cols: usize, limbs: usize) -> VecZnx; - - fn new_vec_znx_from_bytes(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnx; - fn new_vec_znx_from_bytes_borrow(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnx; - - /// Returns the minimum number of bytes necessary to allocate - /// a new [VecZnx] through [VecZnx::from_bytes]. - fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize; - - fn vec_znx_normalize_tmp_bytes(&self, cols: usize) -> usize; - - /// c <- a + b. - fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx); - - /// b <- b + a. - fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx); - - /// c <- a - b. - fn vec_znx_sub(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx); - - /// b <- a - b. - fn vec_znx_sub_ab_inplace(&self, b: &mut VecZnx, a: &VecZnx); - - /// b <- b - a. - fn vec_znx_sub_ba_inplace(&self, b: &mut VecZnx, a: &VecZnx); - - /// b <- -a. - fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx); - - /// b <- -b. - fn vec_znx_negate_inplace(&self, a: &mut VecZnx); - - /// b <- a * X^k (mod X^{n} + 1) - fn vec_znx_rotate(&self, k: i64, b: &mut VecZnx, a: &VecZnx); - - /// a <- a * X^k (mod X^{n} + 1) - fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx); - - /// b <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1)) - fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx); - - /// a <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1)) - fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx); - - /// Splits b into subrings and copies them them into a. - /// - /// # Panics - /// - /// This method requires that all [VecZnx] of b have the same ring degree - /// and that b.n() * b.len() <= a.n() - fn vec_znx_split(&self, b: &mut Vec, a: &VecZnx, buf: &mut VecZnx); - - /// Merges the subrings a into b. - /// - /// # Panics - /// - /// This method requires that all [VecZnx] of a have the same ring degree - /// and that a.n() * a.len() <= b.n() - fn vec_znx_merge(&self, b: &mut VecZnx, a: &Vec); -} - -impl VecZnxOps for Module { - fn new_vec_znx(&self, cols: usize, limbs: usize) -> VecZnx { - VecZnx::new(self, cols, limbs) - } - - fn bytes_of_vec_znx(&self, cols: usize, limbs: usize) -> usize { - VecZnx::bytes_of(self, cols, limbs) - } - - fn new_vec_znx_from_bytes(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnx { - VecZnx::from_bytes(self, cols, limbs, bytes) - } - - fn new_vec_znx_from_bytes_borrow(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnx { - VecZnx::from_bytes_borrow(self, cols, limbs, tmp_bytes) - } - - fn vec_znx_normalize_tmp_bytes(&self, cols: usize) -> usize { - unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize * cols } - } - - // c <- a + b - fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx) { - let n: usize = self.n(); - #[cfg(debug_assertions)] - { - assert_eq!(c.n(), n); - assert_eq!(a.n(), n); - assert_eq!(b.n(), n); - } - unsafe { - vec_znx::vec_znx_add( - self.ptr, - c.as_mut_ptr(), - c.limbs() as u64, - (n * c.cols()) as u64, - a.as_ptr(), - a.limbs() as u64, - (n * a.cols()) as u64, - b.as_ptr(), - b.limbs() as u64, - (n * b.cols()) as u64, - ) - } - } - - // b <- a + b - fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx) { - let n: usize = self.n(); - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), n); - assert_eq!(b.n(), n); - } - unsafe { - vec_znx::vec_znx_add( - self.ptr, - b.as_mut_ptr(), - b.limbs() as u64, - (n * b.cols()) as u64, - a.as_ptr(), - a.limbs() as u64, - (n * a.cols()) as u64, - b.as_ptr(), - b.limbs() as u64, - (n * b.cols()) as u64, - ) - } - } - - // c <- a + b - fn vec_znx_sub(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx) { - let n: usize = self.n(); - #[cfg(debug_assertions)] - { - assert_eq!(c.n(), n); - assert_eq!(a.n(), n); - assert_eq!(b.n(), n); - } - unsafe { - vec_znx::vec_znx_sub( - self.ptr, - c.as_mut_ptr(), - c.limbs() as u64, - (n * c.cols()) as u64, - a.as_ptr(), - a.limbs() as u64, - (n * a.cols()) as u64, - b.as_ptr(), - b.limbs() as u64, - (n * b.cols()) as u64, - ) - } - } - - // b <- a - b - fn vec_znx_sub_ab_inplace(&self, b: &mut VecZnx, a: &VecZnx) { - let n: usize = self.n(); - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), n); - assert_eq!(b.n(), n); - } - unsafe { - vec_znx::vec_znx_sub( - self.ptr, - b.as_mut_ptr(), - b.limbs() as u64, - (n * b.cols()) as u64, - a.as_ptr(), - a.limbs() as u64, - (n * a.cols()) as u64, - b.as_ptr(), - b.limbs() as u64, - (n * b.cols()) as u64, - ) - } - } - - // b <- b - a - fn vec_znx_sub_ba_inplace(&self, b: &mut VecZnx, a: &VecZnx) { - let n: usize = self.n(); - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), n); - assert_eq!(b.n(), n); - } - unsafe { - vec_znx::vec_znx_sub( - self.ptr, - b.as_mut_ptr(), - b.limbs() as u64, - (n * b.cols()) as u64, - b.as_ptr(), - b.limbs() as u64, - (n * b.cols()) as u64, - a.as_ptr(), - a.limbs() as u64, - (n * a.cols()) as u64, - ) - } - } - - fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx) { - let n: usize = self.n(); - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), n); - assert_eq!(b.n(), n); - } - unsafe { - vec_znx::vec_znx_negate( - self.ptr, - b.as_mut_ptr(), - b.limbs() as u64, - (n * b.cols()) as u64, - a.as_ptr(), - a.limbs() as u64, - (n * a.cols()) as u64, - ) - } - } - - fn vec_znx_negate_inplace(&self, a: &mut VecZnx) { - let n: usize = self.n(); - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), n); - } - unsafe { - vec_znx::vec_znx_negate( - self.ptr, - a.as_mut_ptr(), - a.limbs() as u64, - (n * a.cols()) as u64, - a.as_ptr(), - a.limbs() as u64, - (n * a.cols()) as u64, - ) - } - } - - fn vec_znx_rotate(&self, k: i64, b: &mut VecZnx, a: &VecZnx) { - let n: usize = self.n(); - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), n); - assert_eq!(b.n(), n); - } - unsafe { - vec_znx::vec_znx_rotate( - self.ptr, - k, - b.as_mut_ptr(), - b.limbs() as u64, - (n * b.cols()) as u64, - a.as_ptr(), - a.limbs() as u64, - (n * a.cols()) as u64, - ) - } - } - - fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx) { - let n: usize = self.n(); - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), n); - } - unsafe { - vec_znx::vec_znx_rotate( - self.ptr, - k, - a.as_mut_ptr(), - a.limbs() as u64, - (n * a.cols()) as u64, - a.as_ptr(), - a.limbs() as u64, - (n * a.cols()) as u64, - ) - } - } - - /// Maps X^i to X^{ik} mod X^{n}+1. The mapping is applied independently on each size. - /// - /// # Arguments - /// - /// * `a`: input. - /// * `b`: output. - /// * `k`: the power to which to map each coefficients. - /// * `a_size`: the number of a_size on which to apply the mapping. - /// - /// # Panics - /// - /// The method will panic if the argument `a` is greater than `a.limbs()`. - fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx) { - let n: usize = self.n(); - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), n); - assert_eq!(b.n(), n); - } - unsafe { - vec_znx::vec_znx_automorphism( - self.ptr, - k, - b.as_mut_ptr(), - b.limbs() as u64, - (n * b.cols()) as u64, - a.as_ptr(), - a.limbs() as u64, - (n * a.cols()) as u64, - ); - } - } - - /// Maps X^i to X^{ik} mod X^{n}+1. The mapping is applied independently on each size. - /// - /// # Arguments - /// - /// * `a`: input and output. - /// * `k`: the power to which to map each coefficients. - /// * `a_size`: the number of size on which to apply the mapping. - /// - /// # Panics - /// - /// The method will panic if the argument `size` is greater than `self.limbs()`. - fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx) { - let n: usize = self.n(); - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), n); - } - unsafe { - vec_znx::vec_znx_automorphism( - self.ptr, - k, - a.as_mut_ptr(), - a.limbs() as u64, - (n * a.cols()) as u64, - a.as_ptr(), - a.limbs() as u64, - (n * a.cols()) as u64, - ); - } - } - - fn vec_znx_split(&self, b: &mut Vec, a: &VecZnx, buf: &mut VecZnx) { - let (n_in, n_out) = (a.n(), b[0].n()); - - debug_assert!( - n_out < n_in, - "invalid a: output ring degree should be smaller" - ); - b[1..].iter().for_each(|bi| { - debug_assert_eq!( - bi.n(), - n_out, - "invalid input a: all VecZnx must have the same degree" - ) - }); - - b.iter_mut().enumerate().for_each(|(i, bi)| { - if i == 0 { - switch_degree(bi, a); - self.vec_znx_rotate(-1, buf, a); - } else { - switch_degree(bi, buf); - self.vec_znx_rotate_inplace(-1, buf); - } - }) - } - - fn vec_znx_merge(&self, b: &mut VecZnx, a: &Vec) { - let (n_in, n_out) = (b.n(), a[0].n()); - - debug_assert!( - n_out < n_in, - "invalid a: output ring degree should be smaller" - ); - a[1..].iter().for_each(|ai| { - debug_assert_eq!( - ai.n(), - n_out, - "invalid input a: all VecZnx must have the same degree" - ) - }); - - a.iter().enumerate().for_each(|(_, ai)| { - switch_degree(b, ai); - self.vec_znx_rotate_inplace(-1, b); - }); - - self.vec_znx_rotate_inplace(a.len() as i64, b); - } -} diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index 8c67a8d..7f647da 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -7,43 +7,43 @@ pub struct VecZnxBig { pub ptr: *mut u8, pub n: usize, pub cols: usize, - pub limbs: usize, + pub size: usize, pub _marker: PhantomData, } impl ZnxBase for VecZnxBig { type Scalar = u8; - fn new(module: &Module, cols: usize, limbs: usize) -> Self { + fn new(module: &Module, cols: usize, size: usize) -> Self { #[cfg(debug_assertions)] { assert!(cols > 0); - assert!(limbs > 0); + assert!(size > 0); } - let mut data: Vec = alloc_aligned::(Self::bytes_of(module, cols, limbs)); + let mut data: Vec = alloc_aligned::(Self::bytes_of(module, cols, size)); let ptr: *mut Self::Scalar = data.as_mut_ptr(); Self { data: data, ptr: ptr, n: module.n(), cols: cols, - limbs: limbs, + size: size, _marker: PhantomData, } } - fn bytes_of(module: &Module, cols: usize, limbs: usize) -> usize { - unsafe { vec_znx_big::bytes_of_vec_znx_big(module.ptr, limbs as u64) as usize * cols } + fn bytes_of(module: &Module, cols: usize, size: usize) -> usize { + unsafe { vec_znx_big::bytes_of_vec_znx_big(module.ptr, size as u64) as usize * cols } } /// Returns a new [VecZnxBig] with the provided data as backing array. /// User must ensure that data is properly alligned and that /// the size of data is at least equal to [Module::bytes_of_vec_znx_big]. - fn from_bytes(module: &Module, cols: usize, limbs: usize, bytes: &mut [Self::Scalar]) -> Self { + fn from_bytes(module: &Module, cols: usize, size: usize, bytes: &mut [Self::Scalar]) -> Self { #[cfg(debug_assertions)] { assert!(cols > 0); - assert!(limbs > 0); - assert_eq!(bytes.len(), Self::bytes_of(module, cols, limbs)); + assert!(size > 0); + assert_eq!(bytes.len(), Self::bytes_of(module, cols, size)); assert_alignement(bytes.as_ptr()) }; unsafe { @@ -52,18 +52,18 @@ impl ZnxBase for VecZnxBig { ptr: bytes.as_mut_ptr(), n: module.n(), cols: cols, - limbs: limbs, + size: size, _marker: PhantomData, } } } - fn from_bytes_borrow(module: &Module, cols: usize, limbs: usize, bytes: &mut [Self::Scalar]) -> Self { + fn from_bytes_borrow(module: &Module, cols: usize, size: usize, bytes: &mut [Self::Scalar]) -> Self { #[cfg(debug_assertions)] { assert!(cols > 0); - assert!(limbs > 0); - assert_eq!(bytes.len(), Self::bytes_of(module, cols, limbs)); + assert!(size > 0); + assert_eq!(bytes.len(), Self::bytes_of(module, cols, size)); assert_alignement(bytes.as_ptr()); } Self { @@ -71,17 +71,13 @@ impl ZnxBase for VecZnxBig { ptr: bytes.as_mut_ptr(), n: module.n(), cols: cols, - limbs: limbs, + size: size, _marker: PhantomData, } } } impl ZnxInfos for VecZnxBig { - fn log_n(&self) -> usize { - (usize::BITS - (self.n - 1).leading_zeros()) as _ - } - fn n(&self) -> usize { self.n } @@ -94,12 +90,8 @@ impl ZnxInfos for VecZnxBig { 1 } - fn limbs(&self) -> usize { - self.limbs - } - - fn poly_count(&self) -> usize { - self.cols * self.limbs + fn size(&self) -> usize { + self.size } } @@ -117,13 +109,13 @@ impl ZnxLayout for VecZnxBig { impl VecZnxBig { pub fn print(&self, n: usize) { - (0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])); + (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])); } } pub trait VecZnxBigOps { /// Allocates a vector Z[X]/(X^N+1) that stores not normalized values. - fn new_vec_znx_big(&self, cols: usize, limbs: usize) -> VecZnxBig; + fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBig; /// Returns a new [VecZnxBig] with the provided bytes array as backing array. /// @@ -132,12 +124,12 @@ pub trait VecZnxBigOps { /// # Arguments /// /// * `cols`: the number of polynomials.. - /// * `limbs`: the number of limbs (a.k.a small polynomials) per polynomial. + /// * `size`: the number of size (a.k.a small polynomials) per polynomial. /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big]. /// /// # Panics /// If `bytes.len()` < [Module::bytes_of_vec_znx_big]. - fn new_vec_znx_big_from_bytes(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnxBig; + fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxBig; /// Returns a new [VecZnxBig] with the provided bytes array as backing array. /// @@ -146,25 +138,25 @@ pub trait VecZnxBigOps { /// # Arguments /// /// * `cols`: the number of polynomials.. - /// * `limbs`: the number of limbs (a.k.a small polynomials) per polynomial. + /// * `size`: the number of size (a.k.a small polynomials) per polynomial. /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big]. /// /// # Panics /// If `bytes.len()` < [Module::bytes_of_vec_znx_big]. - fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnxBig; + fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig; /// Returns the minimum number of bytes necessary to allocate /// a new [VecZnxBig] through [VecZnxBig::from_bytes]. - fn bytes_of_vec_znx_big(&self, cols: usize, limbs: usize) -> usize; + fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize; /// b[VecZnxBig] <- b[VecZnxBig] - a[VecZnx] /// /// # Behavior /// - /// [VecZnxBig] (3 cols and 4 limbs) + /// [VecZnxBig] (3 cols and 4 size) /// [a0, b0, c0] [a1, b1, c1] [a2, b2, c2] [a3, b3, c3] /// - - /// [VecZnx] (2 cols and 3 limbs) + /// [VecZnx] (2 cols and 3 size) /// [d0, e0] [d1, e1] [d2, e2] /// = /// [a0-d0, b0-e0, c0] [a1-d1, b1-e1, c1] [a2-d2, b2-e2, c2] [a3, b3, c3] @@ -203,26 +195,26 @@ pub trait VecZnxBigOps { } impl VecZnxBigOps for Module { - fn new_vec_znx_big(&self, cols: usize, limbs: usize) -> VecZnxBig { - VecZnxBig::new(self, cols, limbs) + fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBig { + VecZnxBig::new(self, cols, size) } - fn new_vec_znx_big_from_bytes(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnxBig { - VecZnxBig::from_bytes(self, cols, limbs, bytes) + fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxBig { + VecZnxBig::from_bytes(self, cols, size, bytes) } - fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnxBig { - VecZnxBig::from_bytes_borrow(self, cols, limbs, tmp_bytes) + fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig { + VecZnxBig::from_bytes_borrow(self, cols, size, tmp_bytes) } - fn bytes_of_vec_znx_big(&self, cols: usize, limbs: usize) -> usize { - VecZnxBig::bytes_of(self, cols, limbs) + fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize { + VecZnxBig::bytes_of(self, cols, size) } - /// [VecZnxBig] (3 cols and 4 limbs) + /// [VecZnxBig] (3 cols and 4 size) /// [a0, b0, c0] [a1, b1, c1] [a2, b2, c2] [a3, b3, c3] /// - - /// [VecZnx] (2 cols and 3 limbs) + /// [VecZnx] (2 cols and 3 size) /// [d0, e0] [d1, e1] [d2, e2] /// = /// [a0-d0, b0-e0, c0] [a1-d1, b1-e1, c1] [a2-d2, b2-e2, c2] [a3, b3, c3] @@ -306,10 +298,10 @@ impl VecZnxBigOps for Module { self.ptr, log_base2k as u64, b.as_mut_ptr(), - b.limbs() as u64, + b.size() as u64, b.n() as u64, a.ptr as *mut vec_znx_big_t, - a.limbs() as u64, + a.size() as u64, tmp_bytes.as_mut_ptr(), ) } @@ -344,7 +336,7 @@ impl VecZnxBigOps for Module { self.ptr, log_base2k as u64, res.as_mut_ptr(), - res.limbs() as u64, + res.size() as u64, res.n() as u64, a.ptr as *mut vec_znx_big_t, a_range_begin as u64, diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index 7724710..d9c9e60 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -10,44 +10,44 @@ pub struct VecZnxDft { pub ptr: *mut u8, pub n: usize, pub cols: usize, - pub limbs: usize, + pub size: usize, pub _marker: PhantomData, } impl ZnxBase for VecZnxDft { type Scalar = u8; - fn new(module: &Module, cols: usize, limbs: usize) -> Self { + fn new(module: &Module, cols: usize, size: usize) -> Self { #[cfg(debug_assertions)] { assert!(cols > 0); - assert!(limbs > 0); + assert!(size > 0); } - let mut data: Vec = alloc_aligned(Self::bytes_of(module, cols, limbs)); + let mut data: Vec = alloc_aligned(Self::bytes_of(module, cols, size)); let ptr: *mut Self::Scalar = data.as_mut_ptr(); Self { data: data, ptr: ptr, n: module.n(), - limbs: limbs, + size: size, cols: cols, _marker: PhantomData, } } - fn bytes_of(module: &Module, cols: usize, limbs: usize) -> usize { - unsafe { bytes_of_vec_znx_dft(module.ptr, limbs as u64) as usize * cols } + fn bytes_of(module: &Module, cols: usize, size: usize) -> usize { + unsafe { bytes_of_vec_znx_dft(module.ptr, size as u64) as usize * cols } } /// Returns a new [VecZnxDft] with the provided data as backing array. /// User must ensure that data is properly alligned and that /// the size of data is at least equal to [Module::bytes_of_vec_znx_dft]. - fn from_bytes(module: &Module, cols: usize, limbs: usize, bytes: &mut [Self::Scalar]) -> Self { + fn from_bytes(module: &Module, cols: usize, size: usize, bytes: &mut [Self::Scalar]) -> Self { #[cfg(debug_assertions)] { assert!(cols > 0); - assert!(limbs > 0); - assert_eq!(bytes.len(), Self::bytes_of(module, cols, limbs)); + assert!(size > 0); + assert_eq!(bytes.len(), Self::bytes_of(module, cols, size)); assert_alignement(bytes.as_ptr()) } unsafe { @@ -56,18 +56,18 @@ impl ZnxBase for VecZnxDft { ptr: bytes.as_mut_ptr(), n: module.n(), cols: cols, - limbs: limbs, + size: size, _marker: PhantomData, } } } - fn from_bytes_borrow(module: &Module, cols: usize, limbs: usize, bytes: &mut [Self::Scalar]) -> Self { + fn from_bytes_borrow(module: &Module, cols: usize, size: usize, bytes: &mut [Self::Scalar]) -> Self { #[cfg(debug_assertions)] { assert!(cols > 0); - assert!(limbs > 0); - assert_eq!(bytes.len(), Self::bytes_of(module, cols, limbs)); + assert!(size > 0); + assert_eq!(bytes.len(), Self::bytes_of(module, cols, size)); assert_alignement(bytes.as_ptr()); } Self { @@ -75,7 +75,7 @@ impl ZnxBase for VecZnxDft { ptr: bytes.as_mut_ptr(), n: module.n(), cols: cols, - limbs: limbs, + size: size, _marker: PhantomData, } } @@ -91,7 +91,7 @@ impl VecZnxDft { ptr: self.ptr, n: self.n, cols: self.cols, - limbs: self.limbs, + size: self.size, _marker: PhantomData, } } @@ -102,10 +102,6 @@ impl ZnxInfos for VecZnxDft { self.n } - fn log_n(&self) -> usize { - (usize::BITS - (self.n() - 1).leading_zeros()) as _ - } - fn rows(&self) -> usize { 1 } @@ -114,12 +110,8 @@ impl ZnxInfos for VecZnxDft { self.cols } - fn limbs(&self) -> usize { - self.limbs - } - - fn poly_count(&self) -> usize { - self.cols * self.limbs + fn size(&self) -> usize { + self.size } } @@ -137,13 +129,13 @@ impl ZnxLayout for VecZnxDft { impl VecZnxDft { pub fn print(&self, n: usize) { - (0..self.limbs()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])); + (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])); } } pub trait VecZnxDftOps { /// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space. - fn new_vec_znx_dft(&self, cols: usize, limbs: usize) -> VecZnxDft; + fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDft; /// Returns a new [VecZnxDft] with the provided bytes array as backing array. /// @@ -156,7 +148,7 @@ pub trait VecZnxDftOps { /// /// # Panics /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. - fn new_vec_znx_dft_from_bytes(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnxDft; + fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft; /// Returns a new [VecZnxDft] with the provided bytes array as backing array. /// @@ -169,7 +161,7 @@ pub trait VecZnxDftOps { /// /// # Panics /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. - fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, limbs: usize, bytes: &mut [u8]) -> VecZnxDft; + fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft; /// Returns a new [VecZnxDft] with the provided bytes array as backing array. /// @@ -180,7 +172,7 @@ pub trait VecZnxDftOps { /// /// # Panics /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. - fn bytes_of_vec_znx_dft(&self, cols: usize, limbs: usize) -> usize; + fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize; /// Returns the minimum number of bytes necessary to allocate /// a new [VecZnxDft] through [VecZnxDft::from_bytes]. @@ -201,20 +193,20 @@ pub trait VecZnxDftOps { } impl VecZnxDftOps for Module { - fn new_vec_znx_dft(&self, cols: usize, limbs: usize) -> VecZnxDft { - VecZnxDft::::new(&self, cols, limbs) + fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDft { + VecZnxDft::::new(&self, cols, size) } - fn new_vec_znx_dft_from_bytes(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnxDft { - VecZnxDft::from_bytes(self, cols, limbs, tmp_bytes) + fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxDft { + VecZnxDft::from_bytes(self, cols, size, tmp_bytes) } - fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, limbs: usize, tmp_bytes: &mut [u8]) -> VecZnxDft { - VecZnxDft::from_bytes_borrow(self, cols, limbs, tmp_bytes) + fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxDft { + VecZnxDft::from_bytes_borrow(self, cols, size, tmp_bytes) } - fn bytes_of_vec_znx_dft(&self, cols: usize, limbs: usize) -> usize { - VecZnxDft::bytes_of(&self, cols, limbs) + fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize { + VecZnxDft::bytes_of(&self, cols, size) } fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft) { @@ -242,9 +234,9 @@ impl VecZnxDftOps for Module { vec_znx_dft::vec_znx_dft( self.ptr, b.ptr as *mut vec_znx_dft_t, - b.limbs() as u64, + b.size() as u64, a.as_ptr(), - a.limbs() as u64, + a.size() as u64, (a.n() * a.cols()) as u64, ) } @@ -329,14 +321,14 @@ mod tests { let n: usize = 8; let module: Module = Module::::new(n); - let limbs: usize = 2; + let size: usize = 2; let log_base2k: usize = 17; - let mut a: VecZnx = module.new_vec_znx(1, limbs); - let mut a_dft: VecZnxDft = module.new_vec_znx_dft(1, limbs); - let mut b_dft: VecZnxDft = module.new_vec_znx_dft(1, limbs); + let mut a: VecZnx = module.new_vec_znx(1, size); + let mut a_dft: VecZnxDft = module.new_vec_znx_dft(1, size); + let mut b_dft: VecZnxDft = module.new_vec_znx_dft(1, size); let mut source: Source = Source::new([0u8; 32]); - module.fill_uniform(log_base2k, &mut a, 0, limbs, &mut source); + module.fill_uniform(log_base2k, &mut a, 0, size, &mut source); let mut tmp_bytes: Vec = alloc_aligned(module.vec_znx_dft_automorphism_tmp_bytes()); diff --git a/base2k/src/vec_znx_ops.rs b/base2k/src/vec_znx_ops.rs new file mode 100644 index 0000000..7afcc9a --- /dev/null +++ b/base2k/src/vec_znx_ops.rs @@ -0,0 +1,795 @@ +use crate::ffi::module::MODULE; +use crate::ffi::vec_znx; +use crate::{Backend, Module, VecZnx, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, switch_degree, znx_post_process_ternary_op}; +use std::cmp::min; +pub trait VecZnxOps { + /// Allocates a new [VecZnx]. + /// + /// # Arguments + /// + /// * `cols`: the number of polynomials. + /// * `size`: the number of size per polynomial (a.k.a small polynomials). + fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnx; + + fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnx; + fn new_vec_znx_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnx; + + /// Returns the minimum number of bytes necessary to allocate + /// a new [VecZnx] through [VecZnx::from_bytes]. + fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize; + + fn vec_znx_normalize_tmp_bytes(&self, cols: usize) -> usize; + + /// c <- a + b. + fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx); + + /// b <- b + a. + fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx); + + /// c <- a - b. + fn vec_znx_sub(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx); + + /// b <- a - b. + fn vec_znx_sub_ab_inplace(&self, b: &mut VecZnx, a: &VecZnx); + + /// b <- b - a. + fn vec_znx_sub_ba_inplace(&self, b: &mut VecZnx, a: &VecZnx); + + /// b <- -a. + fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx); + + /// b <- -b. + fn vec_znx_negate_inplace(&self, a: &mut VecZnx); + + /// b <- a * X^k (mod X^{n} + 1) + fn vec_znx_rotate(&self, k: i64, b: &mut VecZnx, a: &VecZnx); + + /// a <- a * X^k (mod X^{n} + 1) + fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx); + + /// b <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1)) + fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx); + + /// a <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1)) + fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx); + + /// Splits b into subrings and copies them them into a. + /// + /// # Panics + /// + /// This method requires that all [VecZnx] of b have the same ring degree + /// and that b.n() * b.len() <= a.n() + fn vec_znx_split(&self, b: &mut Vec, a: &VecZnx, buf: &mut VecZnx); + + /// Merges the subrings a into b. + /// + /// # Panics + /// + /// This method requires that all [VecZnx] of a have the same ring degree + /// and that a.n() * a.len() <= b.n() + fn vec_znx_merge(&self, b: &mut VecZnx, a: &Vec); +} + +impl VecZnxOps for Module { + fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnx { + VecZnx::new(self, cols, size) + } + + fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize { + VecZnx::bytes_of(self, cols, size) + } + + fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnx { + VecZnx::from_bytes(self, cols, size, bytes) + } + + fn new_vec_znx_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnx { + VecZnx::from_bytes_borrow(self, cols, size, tmp_bytes) + } + + fn vec_znx_normalize_tmp_bytes(&self, cols: usize) -> usize { + unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize * cols } + } + + fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx) { + let op = ffi_ternary_op_factory( + self.ptr, + c.size(), + c.sl(), + a.size(), + a.sl(), + b.size(), + b.sl(), + vec_znx::vec_znx_add, + ); + vec_znx_apply_binary_op::(self, c, a, b, op); + } + + fn vec_znx_add_inplace(&self, b: &mut VecZnx, a: &VecZnx) { + unsafe { + let b_ptr: *mut VecZnx = b as *mut VecZnx; + Self::vec_znx_add(self, &mut *b_ptr, a, &*b_ptr); + } + } + + fn vec_znx_sub(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx) { + let op = ffi_ternary_op_factory( + self.ptr, + c.size(), + c.sl(), + a.size(), + a.sl(), + b.size(), + b.sl(), + vec_znx::vec_znx_sub, + ); + vec_znx_apply_binary_op::(self, c, a, b, op); + } + + fn vec_znx_sub_ab_inplace(&self, b: &mut VecZnx, a: &VecZnx) { + unsafe { + let b_ptr: *mut VecZnx = b as *mut VecZnx; + Self::vec_znx_sub(self, &mut *b_ptr, a, &*b_ptr); + } + } + + fn vec_znx_sub_ba_inplace(&self, b: &mut VecZnx, a: &VecZnx) { + unsafe { + let b_ptr: *mut VecZnx = b as *mut VecZnx; + Self::vec_znx_sub(self, &mut *b_ptr, &*b_ptr, a); + } + } + + fn vec_znx_negate(&self, b: &mut VecZnx, a: &VecZnx) { + let op = ffi_binary_op_factory_type_0( + self.ptr, + b.size(), + b.sl(), + a.size(), + a.sl(), + vec_znx::vec_znx_negate, + ); + vec_znx_apply_unary_op::(self, b, a, op); + } + + fn vec_znx_negate_inplace(&self, a: &mut VecZnx) { + unsafe { + let a_ptr: *mut VecZnx = a as *mut VecZnx; + Self::vec_znx_negate(self, &mut *a_ptr, &*a_ptr); + } + } + + fn vec_znx_rotate(&self, k: i64, b: &mut VecZnx, a: &VecZnx) { + let op = ffi_binary_op_factory_type_1( + self.ptr, + k, + b.size(), + b.sl(), + a.size(), + a.sl(), + vec_znx::vec_znx_rotate, + ); + vec_znx_apply_unary_op::(self, b, a, op); + } + + fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx) { + unsafe { + let a_ptr: *mut VecZnx = a as *mut VecZnx; + Self::vec_znx_rotate(self, k, &mut *a_ptr, &*a_ptr); + } + } + + /// Maps X^i to X^{ik} mod X^{n}+1. The mapping is applied independently on each size. + /// + /// # Arguments + /// + /// * `a`: input. + /// * `b`: output. + /// * `k`: the power to which to map each coefficients. + /// * `a_size`: the number of a_size on which to apply the mapping. + /// + /// # Panics + /// + /// The method will panic if the argument `a` is greater than `a.size()`. + fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx) { + let op = ffi_binary_op_factory_type_1( + self.ptr, + k, + b.size(), + b.sl(), + a.size(), + a.sl(), + vec_znx::vec_znx_automorphism, + ); + vec_znx_apply_unary_op::(self, b, a, op); + } + + /// Maps X^i to X^{ik} mod X^{n}+1. The mapping is applied independently on each size. + /// + /// # Arguments + /// + /// * `a`: input and output. + /// * `k`: the power to which to map each coefficients. + /// * `a_size`: the number of size on which to apply the mapping. + /// + /// # Panics + /// + /// The method will panic if the argument `size` is greater than `self.size()`. + fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx) { + unsafe { + let a_ptr: *mut VecZnx = a as *mut VecZnx; + Self::vec_znx_automorphism(self, k, &mut *a_ptr, &*a_ptr); + } + } + + fn vec_znx_split(&self, b: &mut Vec, a: &VecZnx, buf: &mut VecZnx) { + let (n_in, n_out) = (a.n(), b[0].n()); + + debug_assert!( + n_out < n_in, + "invalid a: output ring degree should be smaller" + ); + b[1..].iter().for_each(|bi| { + debug_assert_eq!( + bi.n(), + n_out, + "invalid input a: all VecZnx must have the same degree" + ) + }); + + b.iter_mut().enumerate().for_each(|(i, bi)| { + if i == 0 { + switch_degree(bi, a); + self.vec_znx_rotate(-1, buf, a); + } else { + switch_degree(bi, buf); + self.vec_znx_rotate_inplace(-1, buf); + } + }) + } + + fn vec_znx_merge(&self, b: &mut VecZnx, a: &Vec) { + let (n_in, n_out) = (b.n(), a[0].n()); + + debug_assert!( + n_out < n_in, + "invalid a: output ring degree should be smaller" + ); + a[1..].iter().for_each(|ai| { + debug_assert_eq!( + ai.n(), + n_out, + "invalid input a: all VecZnx must have the same degree" + ) + }); + + a.iter().enumerate().for_each(|(_, ai)| { + switch_degree(b, ai); + self.vec_znx_rotate_inplace(-1, b); + }); + + self.vec_znx_rotate_inplace(a.len() as i64, b); + } +} + +fn ffi_ternary_op_factory( + module_ptr: *const MODULE, + c_size: usize, + c_sl: usize, + a_size: usize, + a_sl: usize, + b_size: usize, + b_sl: usize, + op_fn: unsafe extern "C" fn(*const MODULE, *mut i64, u64, u64, *const i64, u64, u64, *const i64, u64, u64), +) -> impl Fn(&mut [i64], &[i64], &[i64]) { + move |cv: &mut [i64], av: &[i64], bv: &[i64]| unsafe { + op_fn( + module_ptr, + cv.as_mut_ptr(), + c_size as u64, + c_sl as u64, + av.as_ptr(), + a_size as u64, + a_sl as u64, + bv.as_ptr(), + b_size as u64, + b_sl as u64, + ) + } +} + +fn ffi_binary_op_factory_type_0( + module_ptr: *const MODULE, + b_size: usize, + b_sl: usize, + a_size: usize, + a_sl: usize, + op_fn: unsafe extern "C" fn(*const MODULE, *mut i64, u64, u64, *const i64, u64, u64), +) -> impl Fn(&mut [i64], &[i64]) { + move |bv: &mut [i64], av: &[i64]| unsafe { + op_fn( + module_ptr, + bv.as_mut_ptr(), + b_size as u64, + b_sl as u64, + av.as_ptr(), + a_size as u64, + a_sl as u64, + ) + } +} + +fn ffi_binary_op_factory_type_1( + module_ptr: *const MODULE, + k: i64, + b_size: usize, + b_sl: usize, + a_size: usize, + a_sl: usize, + op_fn: unsafe extern "C" fn(*const MODULE, i64, *mut i64, u64, u64, *const i64, u64, u64), +) -> impl Fn(&mut [i64], &[i64]) { + move |bv: &mut [i64], av: &[i64]| unsafe { + op_fn( + module_ptr, + k, + bv.as_mut_ptr(), + b_size as u64, + b_sl as u64, + av.as_ptr(), + a_size as u64, + a_sl as u64, + ) + } +} + +#[inline(always)] +pub fn vec_znx_apply_binary_op( + module: &Module, + c: &mut VecZnx, + a: &VecZnx, + b: &VecZnx, + op: impl Fn(&mut [i64], &[i64], &[i64]), +) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert_eq!(b.n(), module.n()); + assert_eq!(c.n(), module.n()); + assert_ne!(a.as_ptr(), b.as_ptr()); + } + + let a_cols: usize = a.cols(); + let b_cols: usize = b.cols(); + let c_cols: usize = c.cols(); + + let min_ab_cols: usize = min(a_cols, b_cols); + let min_cols: usize = min(c_cols, min_ab_cols); + + // Applies over shared cols between (a, b, c) + (0..min_cols).for_each(|i| op(c.at_poly_mut(i, 0), a.at_poly(i, 0), b.at_poly(i, 0))); + // Copies/Negates/Zeroes the remaining cols if op is not inplace. + if c.as_ptr() != a.as_ptr() && c.as_ptr() != b.as_ptr() { + znx_post_process_ternary_op::(c, a, b); + } +} + +#[inline(always)] +pub fn vec_znx_apply_unary_op(module: &Module, b: &mut VecZnx, a: &VecZnx, op: impl Fn(&mut [i64], &[i64])) { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert_eq!(b.n(), module.n()); + } + let a_cols: usize = a.cols(); + let b_cols: usize = b.cols(); + let min_cols: usize = min(a_cols, b_cols); + // Applies over the shared cols between (a, b) + (0..min_cols).for_each(|i| op(b.at_poly_mut(i, 0), a.at_poly(i, 0))); + // Zeroes the remaining cols of b. + (min_cols..b_cols).for_each(|i| (0..b.size()).for_each(|j| b.zero_at(i, j))); +} + +#[cfg(test)] +mod tests { + use crate::{ + Backend, FFT64, Module, Sampling, VecZnx, VecZnxOps, ZnxBasics, ZnxInfos, ZnxLayout, ffi::vec_znx, + znx_post_process_ternary_op, + }; + use itertools::izip; + use sampling::source::Source; + use std::cmp::min; + + #[test] + fn vec_znx_add() { + let n: usize = 8; + let module: Module = Module::::new(n); + let op = |cv: &mut [i64], av: &[i64], bv: &[i64]| { + izip!(cv.iter_mut(), bv.iter(), av.iter()).for_each(|(ci, bi, ai)| *ci = *bi + *ai); + }; + test_binary_op::( + &module, + &|c: &mut VecZnx, a: &VecZnx, b: &VecZnx| module.vec_znx_add(c, a, b), + op, + ); + } + + #[test] + fn vec_znx_add_inplace() { + let n: usize = 8; + let module: Module = Module::::new(n); + let op = |bv: &mut [i64], av: &[i64]| { + izip!(bv.iter_mut(), av.iter()).for_each(|(bi, ai)| *bi = *bi + *ai); + }; + test_binary_op_inplace::( + &module, + &|b: &mut VecZnx, a: &VecZnx| module.vec_znx_add_inplace(b, a), + op, + ); + } + + #[test] + fn vec_znx_sub() { + let n: usize = 8; + let module: Module = Module::::new(n); + let op = |cv: &mut [i64], av: &[i64], bv: &[i64]| { + izip!(cv.iter_mut(), bv.iter(), av.iter()).for_each(|(ci, bi, ai)| *ci = *bi - *ai); + }; + test_binary_op::( + &module, + &|c: &mut VecZnx, a: &VecZnx, b: &VecZnx| module.vec_znx_sub(c, a, b), + op, + ); + } + + #[test] + fn vec_znx_sub_ab_inplace() { + let n: usize = 8; + let module: Module = Module::::new(n); + let op = |bv: &mut [i64], av: &[i64]| { + izip!(bv.iter_mut(), av.iter()).for_each(|(bi, ai)| *bi = *ai - *bi); + }; + test_binary_op_inplace::( + &module, + &|b: &mut VecZnx, a: &VecZnx| module.vec_znx_sub_ab_inplace(b, a), + op, + ); + } + + #[test] + fn vec_znx_sub_ba_inplace() { + let n: usize = 8; + let module: Module = Module::::new(n); + let op = |bv: &mut [i64], av: &[i64]| { + izip!(bv.iter_mut(), av.iter()).for_each(|(bi, ai)| *bi = *bi - *ai); + }; + test_binary_op_inplace::( + &module, + &|b: &mut VecZnx, a: &VecZnx| module.vec_znx_sub_ba_inplace(b, a), + op, + ); + } + + #[test] + fn vec_znx_negate() { + let n: usize = 8; + let module: Module = Module::::new(n); + let op = |b: &mut [i64], a: &[i64]| { + izip!(b.iter_mut(), a.iter()).for_each(|(bi, ai)| *bi = -*ai); + }; + test_unary_op( + &module, + |b: &mut VecZnx, a: &VecZnx| module.vec_znx_negate(b, a), + op, + ) + } + + #[test] + fn vec_znx_negate_inplace() { + let n: usize = 8; + let module: Module = Module::::new(n); + let op = |a: &mut [i64]| a.iter_mut().for_each(|xi| *xi = -*xi); + test_unary_op_inplace( + &module, + |a: &mut VecZnx| module.vec_znx_negate_inplace(a), + op, + ) + } + + #[test] + fn vec_znx_rotate() { + let n: usize = 8; + let module: Module = Module::::new(n); + let k: i64 = 53; + let op = |b: &mut [i64], a: &[i64]| { + assert_eq!(b.len(), a.len()); + b.copy_from_slice(a); + + let mut k_mod2n: i64 = k % (2 * n as i64); + if k_mod2n < 0 { + k_mod2n += 2 * n as i64; + } + let sign: i64 = (k_mod2n.abs() / (n as i64)) & 1; + let k_modn: i64 = k_mod2n % (n as i64); + + b.rotate_right(k_modn as usize); + b[0..k_modn as usize].iter_mut().for_each(|x| *x = -*x); + + if sign == 1 { + b.iter_mut().for_each(|x| *x = -*x); + } + }; + test_unary_op( + &module, + |b: &mut VecZnx, a: &VecZnx| module.vec_znx_rotate(k, b, a), + op, + ) + } + + #[test] + fn vec_znx_rotate_inplace() { + let n: usize = 8; + let module: Module = Module::::new(n); + let k: i64 = 53; + let rot = |a: &mut [i64]| { + let mut k_mod2n: i64 = k % (2 * n as i64); + if k_mod2n < 0 { + k_mod2n += 2 * n as i64; + } + let sign: i64 = (k_mod2n.abs() / (n as i64)) & 1; + let k_modn: i64 = k_mod2n % (n as i64); + + a.rotate_right(k_modn as usize); + a[0..k_modn as usize].iter_mut().for_each(|x| *x = -*x); + + if sign == 1 { + a.iter_mut().for_each(|x| *x = -*x); + } + }; + test_unary_op_inplace( + &module, + |a: &mut VecZnx| module.vec_znx_rotate_inplace(k, a), + rot, + ) + } + + #[test] + fn vec_znx_automorphism() { + let n: usize = 8; + let module: Module = Module::::new(n); + let k: i64 = -5; + let op = |b: &mut [i64], a: &[i64]| { + assert_eq!(b.len(), a.len()); + unsafe { + vec_znx::vec_znx_automorphism( + module.ptr, + k, + b.as_mut_ptr(), + 1u64, + n as u64, + a.as_ptr(), + 1u64, + n as u64, + ); + } + }; + test_unary_op( + &module, + |b: &mut VecZnx, a: &VecZnx| module.vec_znx_automorphism(k, b, a), + op, + ) + } + + #[test] + fn vec_znx_automorphism_inplace() { + let n: usize = 8; + let module: Module = Module::::new(n); + let k: i64 = -5; + let op = |a: &mut [i64]| unsafe { + vec_znx::vec_znx_automorphism( + module.ptr, + k, + a.as_mut_ptr(), + 1u64, + n as u64, + a.as_ptr(), + 1u64, + n as u64, + ); + }; + test_unary_op_inplace( + &module, + |a: &mut VecZnx| module.vec_znx_automorphism_inplace(k, a), + op, + ) + } + + fn test_binary_op( + module: &Module, + func_have: impl Fn(&mut VecZnx, &VecZnx, &VecZnx), + func_want: impl Fn(&mut [i64], &[i64], &[i64]), + ) { + let a_size: usize = 3; + let b_size: usize = 4; + let c_size: usize = 5; + let mut source: Source = Source::new([0u8; 32]); + + [1usize, 2, 3].iter().for_each(|a_cols| { + [1usize, 2, 3].iter().for_each(|b_cols| { + [1usize, 2, 3].iter().for_each(|c_cols| { + let min_ab_cols: usize = min(*a_cols, *b_cols); + let min_cols: usize = min(*c_cols, min_ab_cols); + let min_size: usize = min(c_size, min(a_size, b_size)); + + let mut a: VecZnx = module.new_vec_znx(*a_cols, a_size); + (0..*a_cols).for_each(|i| { + module.fill_uniform(3, &mut a, i, a_size, &mut source); + }); + + let mut b: VecZnx = module.new_vec_znx(*b_cols, b_size); + (0..*b_cols).for_each(|i| { + module.fill_uniform(3, &mut b, i, b_size, &mut source); + }); + + let mut c_have: VecZnx = module.new_vec_znx(*c_cols, c_size); + (0..c_have.cols()).for_each(|i| { + module.fill_uniform(3, &mut c_have, i, c_size, &mut source); + }); + + func_have(&mut c_have, &a, &b); + + let mut c_want: VecZnx = module.new_vec_znx(*c_cols, c_size); + + // Adds with the minimum matching columns + (0..min_cols).for_each(|i| { + // Adds with th eminimum matching size + (0..min_size).for_each(|j| { + func_want(c_want.at_poly_mut(i, j), b.at_poly(i, j), a.at_poly(i, j)); + }); + + if a_size > b_size { + // Copies remaining size of lh if lh.size() > rh.size() + (min_size..a_size).for_each(|j| { + izip!(c_want.at_poly_mut(i, j).iter_mut(), a.at_poly(i, j).iter()).for_each(|(ci, ai)| *ci = *ai); + if NEGATE { + c_want.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x); + } + }); + } else { + // Copies the remaining size of rh if the are greater + (min_size..b_size).for_each(|j| { + izip!(c_want.at_poly_mut(i, j).iter_mut(), b.at_poly(i, j).iter()).for_each(|(ci, bi)| *ci = *bi); + if NEGATE { + c_want.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x); + } + }); + } + }); + + znx_post_process_ternary_op::<_, NEGATE>(&mut c_want, &a, &b); + + assert_eq!(c_have.raw(), c_want.raw()); + }); + }); + }); + } + + fn test_binary_op_inplace( + module: &Module, + func_have: impl Fn(&mut VecZnx, &VecZnx), + func_want: impl Fn(&mut [i64], &[i64]), + ) { + let a_size: usize = 3; + let b_size: usize = 5; + let mut source = Source::new([0u8; 32]); + + [1usize, 2, 3].iter().for_each(|a_cols| { + [1usize, 2, 3].iter().for_each(|b_cols| { + let min_cols: usize = min(*b_cols, *a_cols); + let min_size: usize = min(b_size, a_size); + + let mut a: VecZnx = module.new_vec_znx(*a_cols, a_size); + (0..*a_cols).for_each(|i| { + module.fill_uniform(3, &mut a, i, a_size, &mut source); + }); + + let mut b_have: VecZnx = module.new_vec_znx(*b_cols, b_size); + (0..*b_cols).for_each(|i| { + module.fill_uniform(3, &mut b_have, i, b_size, &mut source); + }); + + let mut b_want: VecZnx = module.new_vec_znx(*b_cols, b_size); + b_want.raw_mut().copy_from_slice(b_have.raw()); + + func_have(&mut b_have, &a); + + // Applies with the minimum matching columns + (0..min_cols).for_each(|i| { + // Adds with th eminimum matching size + (0..min_size).for_each(|j| func_want(b_want.at_poly_mut(i, j), a.at_poly(i, j))); + if NEGATE { + (min_size..b_size).for_each(|j| { + b_want.at_poly_mut(i, j).iter_mut().for_each(|x| *x = -*x); + }); + } + }); + + assert_eq!(b_have.raw(), b_want.raw()); + }); + }); + } + + fn test_unary_op( + module: &Module, + func_have: impl Fn(&mut VecZnx, &VecZnx), + func_want: impl Fn(&mut [i64], &[i64]), + ) { + let a_size: usize = 3; + let b_size: usize = 5; + let mut source = Source::new([0u8; 32]); + + [1usize, 2, 3].iter().for_each(|a_cols| { + [1usize, 2, 3].iter().for_each(|b_cols| { + let min_cols: usize = min(*b_cols, *a_cols); + let min_size: usize = min(b_size, a_size); + + let mut a: VecZnx = module.new_vec_znx(*a_cols, a_size); + (0..a.cols()).for_each(|i| { + module.fill_uniform(3, &mut a, i, a_size, &mut source); + }); + + let mut b_have: VecZnx = module.new_vec_znx(*b_cols, b_size); + (0..b_have.cols()).for_each(|i| { + module.fill_uniform(3, &mut b_have, i, b_size, &mut source); + }); + + let mut b_want: VecZnx = module.new_vec_znx(*b_cols, b_size); + + func_have(&mut b_have, &a); + + // Applies on the minimum matching columns + (0..min_cols).for_each(|i| { + // Applies on the minimum matching size + (0..min_size).for_each(|j| func_want(b_want.at_poly_mut(i, j), a.at_poly(i, j))); + + // Zeroes the unmatching size + (min_size..b_size).for_each(|j| { + b_want.zero_at(i, j); + }) + }); + + // Zeroes the unmatching columns + (min_cols..*b_cols).for_each(|i| { + (0..b_size).for_each(|j| { + b_want.zero_at(i, j); + }) + }); + + assert_eq!(b_have.raw(), b_want.raw()); + }); + }); + } + + fn test_unary_op_inplace(module: &Module, func_have: impl Fn(&mut VecZnx), func_want: impl Fn(&mut [i64])) { + let a_size: usize = 3; + let mut source = Source::new([0u8; 32]); + [1usize, 2, 3].iter().for_each(|a_cols| { + let mut a_have: VecZnx = module.new_vec_znx(*a_cols, a_size); + (0..*a_cols).for_each(|i| { + module.fill_uniform(3, &mut a_have, i, a_size, &mut source); + }); + + let mut a_want: VecZnx = module.new_vec_znx(*a_cols, a_size); + a_have.raw_mut().copy_from_slice(a_want.raw()); + + func_have(&mut a_have); + + // Applies on the minimum matching columns + (0..*a_cols).for_each(|i| { + // Applies on the minimum matching size + (0..a_size).for_each(|j| func_want(a_want.at_poly_mut(i, j))); + }); + + assert_eq!(a_have.raw(), a_want.raw()); + }); + } +}