diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs
index 395fdf6..ee2bd02 100644
--- a/base2k/examples/rlwe_encrypt.rs
+++ b/base2k/examples/rlwe_encrypt.rs
@@ -13,7 +13,7 @@ fn main() {
     let log_scale: usize = msg_size * log_base2k - 5;
 
     let module: Module<FFT64> = Module::<FFT64>::new(n);
-    let mut carry: Vec<u8> = alloc_aligned(module.vec_znx_big_normalize_tmp_bytes(2));
+    let mut carry: Vec<u8> = alloc_aligned(module.vec_znx_big_normalize_tmp_bytes());
 
     let seed: [u8; 32] = [0; 32];
     let mut source: Source = Source::new(seed);
@@ -28,69 +28,95 @@ fn main() {
     // s_ppol <- DFT(s)
     module.svp_prepare(&mut s_ppol, &s);
 
-    // ct = (c0, c1)
-    let mut ct: VecZnx = module.new_vec_znx(2, ct_size);
+    // Allocates a VecZnx with two columns: ct = (0, 0)
+    let mut ct: VecZnx = module.new_vec_znx(
+        2,       // Number of columns
+        ct_size, // Number of small polynomials per column
+    );
 
-    // Fill c1 with random values
+    // Fill the second column with random values: ct = (0, a)
     module.fill_uniform(log_base2k, &mut ct, 1, ct_size, &mut source);
 
     // Scratch space for DFT values
-    let mut buf_dft: VecZnxDft = module.new_vec_znx_dft(1, ct.size());
-
-    // Applies buf_dft <- s * c1
-    module.svp_apply_dft(
-        &mut buf_dft, // DFT(c1 * s)
-        &s_ppol,
-        &ct,
-        1, // c1
+    let mut buf_dft: VecZnxDft = module.new_vec_znx_dft(
+        1,         // Number of columns
+        ct.size(), // Number of polynomials per column
     );
 
-    // Alias scratch space (VecZnxDftis always at least as big as VecZnxBig)
+    // Applies DFT(ct[1]) * DFT(s)
+    module.svp_apply_dft(
+        &mut buf_dft, // DFT(ct[1] * s)
+        &s_ppol,      // DFT(s)
+        &ct,
+        1, // Selects the second column of ct
+    );
+
+    // Alias scratch space (VecZnxDft is always at least as big as VecZnxBig)
     let mut buf_big: VecZnxBig = buf_dft.as_vec_znx_big();
 
-    // BIG(c1 * s) <- IDFT(DFT(c1 * s)) (not normalized)
+    // BIG(ct[1] * s) <- IDFT(DFT(ct[1] * s)) (not normalized)
     module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft);
 
-    // m <- (0)
-    let mut m: VecZnx = module.new_vec_znx(1, msg_size);
+    // Creates a plaintext: VecZnx with 1 column
+    let mut m: VecZnx = module.new_vec_znx(
+        1,        // Number of columns
+        msg_size, // Number of small polynomials
+    );
     let mut want: Vec<i64> = vec![0; n];
     want.iter_mut()
         .for_each(|x| *x = source.next_u64n(16, 15) as i64);
     m.encode_vec_i64(0, log_base2k, log_scale, &want, 4);
     m.normalize(log_base2k, &mut carry);
 
-    // m - BIG(c1 * s)
-    module.vec_znx_big_sub_small_ab_inplace(&mut buf_big, &m);
+    // m - BIG(ct[1] * s)
+    module.vec_znx_big_sub_small_a_inplace(
+        &mut buf_big,
+        0, // Selects the first column of the receiver
+        &m,
+        0, // Selects the first column of the message
+    );
 
-    // c0 <- m - BIG(c1 * s)
-    module.vec_znx_big_normalize(log_base2k, &mut ct, &buf_big, &mut carry);
+    // Normalizes back to VecZnx
+    // ct[0] <- m - BIG(c1 * s)
+    module.vec_znx_big_normalize(
+        log_base2k, &mut ct, 0, // Selects the first column of ct (ct[0])
+        &buf_big, 0, // Selects the first column of buf_big
+        &mut carry,
+    );
 
-    ct.print(ct.sl());
-
-    // (c0 + e, c1)
+    // Add noise to ct[0]
+    // ct[0] <- ct[0] + e
     module.add_normal(
         log_base2k,
         &mut ct,
-        0, // c0
-        log_base2k * ct_size,
+        0, // Selects the first column of ct (ct[0])
+        log_base2k * ct_size, // Scaling of the noise: 2^{-log_base2k * limbs}
         &mut source,
-        3.2,
-        19.0,
+        3.2,  // Standard deviation
+        19.0, // Truncation bound
     );
 
-    // Decrypt
+    // Final ciphertext: ct = (-a * s + m + e, a)
+
+    // Decryption
+
+    // DFT(ct[1] * s)
+    module.svp_apply_dft(
+        &mut buf_dft,
+        &s_ppol,
+        &ct,
+        1, // Selects the second column of ct (ct[1])
+    );
 
-    // DFT(c1 * s)
-    module.svp_apply_dft(&mut buf_dft, &s_ppol, &ct, 1);
 
     // BIG(c1 * s) = IDFT(DFT(c1 * s))
     module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft);
 
-    // BIG(c1 * s) + c0
-    module.vec_znx_big_add_small_inplace(&mut buf_big, &ct);
+    // BIG(c1 * s) + ct[0]
+    module.vec_znx_big_add_small_inplace(&mut buf_big, 0, &ct, 0);
 
-    // m + e <- BIG(c1 * s + c0)
+    // m + e <- BIG(ct[1] * s + ct[0])
     let mut res: VecZnx = module.new_vec_znx(1, ct_size);
-    module.vec_znx_big_normalize(log_base2k, &mut res, &buf_big, &mut carry);
+    module.vec_znx_big_normalize(log_base2k, &mut res, 0, &buf_big, 0, &mut carry);
 
     // have = m * 2^{log_scale} + e
     let mut have: Vec<i64> = vec![i64::default(); n];
diff --git a/base2k/examples/vector_matrix_product.rs b/base2k/examples/vector_matrix_product.rs
index 96a0df7..2f4b1fb 100644
--- a/base2k/examples/vector_matrix_product.rs
+++ b/base2k/examples/vector_matrix_product.rs
@@ -46,7 +46,7 @@ fn main() {
     module.vec_znx_idft_tmp_a(&mut c_big, &mut c_dft);
 
     let mut res: VecZnx = module.new_vec_znx(1, limbs_vec);
-    module.vec_znx_big_normalize(log_base2k, &mut res, &c_big, &mut buf);
+    module.vec_znx_big_normalize(log_base2k, &mut res, 0, &c_big, 0, &mut buf);
 
     let mut values_res: Vec<i64> = vec![i64::default(); n];
     res.decode_vec_i64(0, log_base2k, log_k, &mut values_res);
diff --git a/base2k/src/vec_znx_big_ops.rs b/base2k/src/vec_znx_big_ops.rs
index c87c95d..e59fda1 100644
--- a/base2k/src/vec_znx_big_ops.rs
+++ b/base2k/src/vec_znx_big_ops.rs
@@ -1,8 +1,5 @@
-use std::cmp::min;
-
 use crate::ffi::vec_znx;
-use crate::internals::{apply_binary_op, apply_unary_op, ffi_binary_op_factory_type_1, ffi_ternary_op_factory};
-use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxOps, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, assert_alignement};
+use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxOps, ZnxBase, ZnxInfos, ZnxLayout, assert_alignement};
 
 pub trait VecZnxBigOps {
     /// Allocates a vector Z[X]/(X^N+1) that stores not normalized values.
@@ -41,40 +38,80 @@ pub trait VecZnxBigOps {
     fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize;
 
     /// Adds `a` to `b` and stores the result on `c`.
-    fn vec_znx_big_add(&self, c: &mut VecZnxBig, a: &VecZnxBig, b: &VecZnxBig);
+    fn vec_znx_big_add(
+        &self,
+        res: &mut VecZnxBig,
+        col_res: usize,
+        a: &VecZnxBig,
+        col_a: usize,
+        b: &VecZnxBig,
+        col_b: usize,
+    );
 
     /// Adds `a` to `b` and stores the result on `b`.
-    fn vec_znx_big_add_inplace(&self, b: &mut VecZnxBig, a: &VecZnxBig);
+    fn vec_znx_big_add_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnxBig, col_a: usize);
 
     /// Adds `a` to `b` and stores the result on `c`.
-    fn vec_znx_big_add_small(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig);
+    fn vec_znx_big_add_small(
+        &self,
+        res: &mut VecZnxBig,
+        col_res: usize,
+        a: &VecZnx,
+        col_a: usize,
+        b: &VecZnxBig,
+        col_b: usize,
+    );
 
     /// Adds `a` to `b` and stores the result on `b`.
-    fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig, a: &VecZnx);
+    fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnx, col_a: usize);
 
     /// Subtracts `a` to `b` and stores the result on `c`.
-    fn vec_znx_big_sub(&self, c: &mut VecZnxBig, a: &VecZnxBig, b: &VecZnxBig);
+    fn vec_znx_big_sub(
+        &self,
+        res: &mut VecZnxBig,
+        col_res: usize,
+        a: &VecZnxBig,
+        col_a: usize,
+        b: &VecZnxBig,
+        col_b: usize,
+    );
 
     /// Subtracts `a` to `b` and stores the result on `b`.
-    fn vec_znx_big_sub_ab_inplace(&self, b: &mut VecZnxBig, a: &VecZnxBig);
+    fn vec_znx_big_sub_ab_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnxBig, col_a: usize);
 
     /// Subtracts `b` to `a` and stores the result on `b`.
-    fn vec_znx_big_sub_ba_inplace(&self, b: &mut VecZnxBig, a: &VecZnxBig);
+    fn vec_znx_big_sub_ba_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnxBig, col_a: usize);
 
     /// Subtracts `b` to `a` and stores the result on `c`.
-    fn vec_znx_big_sub_small_ab(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig);
+    fn vec_znx_big_sub_small_a(
+        &self,
+        res: &mut VecZnxBig,
+        col_res: usize,
+        a: &VecZnx,
+        col_a: usize,
+        b: &VecZnxBig,
+        col_b: usize,
+    );
 
     /// Subtracts `a` to `b` and stores the result on `b`.
-    fn vec_znx_big_sub_small_ab_inplace(&self, b: &mut VecZnxBig, a: &VecZnx);
+    fn vec_znx_big_sub_small_a_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnx, col_a: usize);
 
     /// Subtracts `b` to `a` and stores the result on `c`.
-    fn vec_znx_big_sub_small_ba(&self, c: &mut VecZnxBig, a: &VecZnxBig, b: &VecZnx);
+    fn vec_znx_big_sub_small_b(
+        &self,
+        res: &mut VecZnxBig,
+        col_res: usize,
+        a: &VecZnxBig,
+        col_a: usize,
+        b: &VecZnx,
+        col_b: usize,
+    );
 
     /// Subtracts `b` to `a` and stores the result on `b`.
-    fn vec_znx_big_sub_small_ba_inplace(&self, b: &mut VecZnxBig, a: &VecZnx);
+    fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnx, col_a: usize);
 
     /// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_normalize].
-    fn vec_znx_big_normalize_tmp_bytes(&self, cols: usize) -> usize;
+    fn vec_znx_big_normalize_tmp_bytes(&self) -> usize;
 
     /// Normalizes `a` and stores the result on `b`.
     ///
@@ -82,13 +119,21 @@ pub trait VecZnxBigOps {
     ///
     /// * `log_base2k`: normalization basis.
     /// * `tmp_bytes`: scratch space of size at least [VecZnxBigOps::vec_znx_big_normalize].
-    fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig, tmp_bytes: &mut [u8]);
+    fn vec_znx_big_normalize(
+        &self,
+        log_base2k: usize,
+        res: &mut VecZnx,
+        col_res: usize,
+        a: &VecZnxBig,
+        col_a: usize,
+        tmp_bytes: &mut [u8],
+    );
 
     /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`.
-    fn vec_znx_big_automorphism(&self, k: i64, b: &mut VecZnxBig, a: &VecZnxBig);
+    fn vec_znx_big_automorphism(&self, k: i64, res: &mut VecZnxBig, col_res: usize, a: &VecZnxBig, col_a: usize);
 
     /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`.
-    fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig);
+    fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig, col_a: usize);
 }
 
 impl VecZnxBigOps for Module<FFT64> {
@@ -108,170 +153,267 @@ impl VecZnxBigOps for Module<FFT64> {
         VecZnxBig::bytes_of(self, cols, size)
     }
 
-    fn vec_znx_big_add(&self, c: &mut VecZnxBig, a: &VecZnxBig, b: &VecZnxBig) {
-        let op = ffi_ternary_op_factory(
-            self.ptr,
-            c.size(),
-            c.sl(),
-            a.size(),
-            a.sl(),
-            b.size(),
-            b.sl(),
-            vec_znx::vec_znx_add,
-        );
-        apply_binary_op::<VecZnxBig<FFT64>, VecZnxBig<FFT64>, VecZnxBig<FFT64>, false>(self, c, a, b, op);
-    }
-
-    fn vec_znx_big_add_inplace(&self, b: &mut VecZnxBig, a: &VecZnxBig) {
-        unsafe {
-            let b_ptr: *mut VecZnxBig = b as *mut VecZnxBig;
-            Self::vec_znx_big_add(self, &mut *b_ptr, a, &*b_ptr);
-        }
-    }
-
-    fn vec_znx_big_sub(&self, c: &mut VecZnxBig, a: &VecZnxBig, b: &VecZnxBig) {
-        let op = ffi_ternary_op_factory(
-            self.ptr,
-            c.size(),
-            c.sl(),
-            a.size(),
-            a.sl(),
-            b.size(),
-            b.sl(),
-            vec_znx::vec_znx_sub,
-        );
-        apply_binary_op::<VecZnxBig<FFT64>, VecZnxBig<FFT64>, VecZnxBig<FFT64>, true>(self, c, a, b, op);
-    }
-
-    fn vec_znx_big_sub_ab_inplace(&self, b: &mut VecZnxBig, a: &VecZnxBig) {
-        unsafe {
-            let b_ptr: *mut VecZnxBig = b as *mut VecZnxBig;
-            Self::vec_znx_big_sub(self, &mut *b_ptr, a, &*b_ptr);
-        }
-    }
-
-    fn vec_znx_big_sub_ba_inplace(&self, b: &mut VecZnxBig, a: &VecZnxBig) {
-        unsafe {
-            let b_ptr: *mut VecZnxBig = b as *mut VecZnxBig;
-            Self::vec_znx_big_sub(self, &mut *b_ptr, &*b_ptr, a);
-        }
-    }
-
-    fn vec_znx_big_sub_small_ba(&self, c: &mut VecZnxBig, a: &VecZnxBig, b: &VecZnx) {
-        let op = ffi_ternary_op_factory(
-            self.ptr,
-            c.size(),
-            c.sl(),
-            a.size(),
-            a.sl(),
-            b.size(),
-            b.sl(),
-            vec_znx::vec_znx_sub,
-        );
-        apply_binary_op::<VecZnxBig<FFT64>, VecZnxBig<FFT64>, VecZnx, true>(self, c, a, b, op);
-    }
-
-    fn vec_znx_big_sub_small_ba_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) {
-        unsafe {
-            let b_ptr: *mut VecZnxBig = b as *mut VecZnxBig;
-            Self::vec_znx_big_sub_small_ba(self, &mut *b_ptr, &*b_ptr, a);
-        }
-    }
-
-    fn vec_znx_big_sub_small_ab(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig) {
-        let op = ffi_ternary_op_factory(
-            self.ptr,
-            c.size(),
-            c.sl(),
-            a.size(),
-            a.sl(),
-            b.size(),
-            b.sl(),
-            vec_znx::vec_znx_sub,
-        );
-        apply_binary_op::<VecZnxBig<FFT64>, VecZnx, VecZnxBig<FFT64>, true>(self, c, a, b, op);
-    }
-
-    fn vec_znx_big_sub_small_ab_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) {
-        unsafe {
-            let b_ptr: *mut VecZnxBig = b as *mut VecZnxBig;
-            Self::vec_znx_big_sub_small_ab(self, &mut *b_ptr, a, &*b_ptr);
-        }
-    }
-
-    fn vec_znx_big_add_small(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig) {
-        let op = ffi_ternary_op_factory(
-            self.ptr,
-            c.size(),
-            c.sl(),
-            a.size(),
-            a.sl(),
-            b.size(),
-            b.sl(),
-            vec_znx::vec_znx_add,
-        );
-        apply_binary_op::<VecZnxBig<FFT64>, VecZnx, VecZnxBig<FFT64>, false>(self, c, a, b, op);
-    }
-
-    fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) {
-        unsafe {
-            let b_ptr: *mut VecZnxBig = b as *mut VecZnxBig;
-            Self::vec_znx_big_add_small(self, &mut *b_ptr, a, &*b_ptr);
-        }
-    }
-
-    fn vec_znx_big_normalize_tmp_bytes(&self, cols: usize) -> usize {
-        Self::vec_znx_normalize_tmp_bytes(self, cols)
-    }
-
-    fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig, tmp_bytes: &mut [u8]) {
+    fn vec_znx_big_add(
+        &self,
+        res: &mut VecZnxBig,
+        col_res: usize,
+        a: &VecZnxBig,
+        col_a: usize,
+        b: &VecZnxBig,
+        col_b: usize,
+    ) {
         #[cfg(debug_assertions)]
         {
-            assert!(tmp_bytes.len() >= Self::vec_znx_big_normalize_tmp_bytes(&self, a.cols()));
+            assert_eq!(a.n(), self.n());
+            assert_eq!(b.n(), self.n());
+            assert_eq!(res.n(), self.n());
+            assert_ne!(a.as_ptr(), b.as_ptr());
+        }
+        unsafe {
+            vec_znx::vec_znx_add(
+                self.ptr,
+                res.at_mut_ptr(col_res, 0),
+                res.size() as u64,
+                res.sl() as u64,
+                a.at_ptr(col_a, 0),
+                a.size() as u64,
+                a.sl() as u64,
+                b.at_ptr(col_b, 0),
+                b.size() as u64,
+                b.sl() as u64,
+            )
+        }
+    }
+
+    fn vec_znx_big_add_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnxBig, col_a: usize) {
+        unsafe {
+            let res_ptr: *mut VecZnxBig = res as *mut VecZnxBig;
+            Self::vec_znx_big_add(self, &mut *res_ptr, col_res, a, col_a, &*res_ptr, col_res);
+        }
+    }
+
+    fn vec_znx_big_sub(
+        &self,
+        res: &mut VecZnxBig,
+        col_res: usize,
+        a: &VecZnxBig,
+        col_a: usize,
+        b: &VecZnxBig,
+        col_b: usize,
+    ) {
+        #[cfg(debug_assertions)]
+        {
+            assert_eq!(a.n(), self.n());
+            assert_eq!(b.n(), self.n());
+            assert_eq!(res.n(), self.n());
+            assert_ne!(a.as_ptr(), b.as_ptr());
+        }
+        unsafe {
+            vec_znx::vec_znx_sub(
+                self.ptr,
+                res.at_mut_ptr(col_res, 0),
+                res.size() as u64,
+                res.sl() as u64,
+                a.at_ptr(col_a, 0),
+                a.size() as u64,
+                a.sl() as u64,
+                b.at_ptr(col_b, 0),
+                b.size() as u64,
+                b.sl() as u64,
+            )
+        }
+    }
+
+    fn vec_znx_big_sub_ab_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnxBig, col_a: usize) {
+        unsafe {
+            let res_ptr: *mut VecZnxBig = res as *mut VecZnxBig;
+            Self::vec_znx_big_sub(self, &mut *res_ptr, col_res, a, col_a, &*res_ptr, col_res);
+        }
+    }
+
+    fn vec_znx_big_sub_ba_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnxBig, col_a: usize) {
+        unsafe {
+            let res_ptr: *mut VecZnxBig = res as *mut VecZnxBig;
+            Self::vec_znx_big_sub(self, &mut *res_ptr, col_res, &*res_ptr, col_res, a, col_a);
+        }
+    }
+
+    fn vec_znx_big_sub_small_b(
+        &self,
+        res: &mut VecZnxBig,
+        col_res: usize,
+        a: &VecZnxBig,
+        col_a: usize,
+        b: &VecZnx,
+        col_b: usize,
+    ) {
+        #[cfg(debug_assertions)]
+        {
+            assert_eq!(a.n(), self.n());
+            assert_eq!(b.n(), self.n());
+            assert_eq!(res.n(), self.n());
+            assert_ne!(a.as_ptr(), b.as_ptr());
+        }
+        unsafe {
+            vec_znx::vec_znx_sub(
+                self.ptr,
+                res.at_mut_ptr(col_res, 0),
+                res.size() as u64,
+                res.sl() as u64,
+                a.at_ptr(col_a, 0),
+                a.size() as u64,
+                a.sl() as u64,
+                b.at_ptr(col_b, 0),
+                b.size() as u64,
+                b.sl() as u64,
+            )
+        }
+    }
+
+    fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnx, col_a: usize) {
+        unsafe {
+            let res_ptr: *mut VecZnxBig = res as *mut VecZnxBig;
+            Self::vec_znx_big_sub_small_b(self, &mut *res_ptr, col_res, &*res_ptr, col_res, a, col_a);
+        }
+    }
+
+    fn vec_znx_big_sub_small_a(
+        &self,
+        res: &mut VecZnxBig,
+        col_res: usize,
+        a: &VecZnx,
+        col_a: usize,
+        b: &VecZnxBig,
+        col_b: usize,
+    ) {
+        #[cfg(debug_assertions)]
+        {
+            assert_eq!(a.n(), self.n());
+            assert_eq!(b.n(), self.n());
+            assert_eq!(res.n(), self.n());
+            assert_ne!(a.as_ptr(), b.as_ptr());
+        }
+        unsafe {
+            vec_znx::vec_znx_sub(
+                self.ptr,
+                res.at_mut_ptr(col_res, 0),
+                res.size() as u64,
+                res.sl() as u64,
+                a.at_ptr(col_a, 0),
+                a.size() as u64,
+                a.sl() as u64,
+                b.at_ptr(col_b, 0),
+                b.size() as u64,
+                b.sl() as u64,
+            )
+        }
+    }
+
+    fn vec_znx_big_sub_small_a_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnx, col_a: usize) {
+        unsafe {
+            let res_ptr: *mut VecZnxBig = res as *mut VecZnxBig;
+            Self::vec_znx_big_sub_small_a(self, &mut *res_ptr, col_res, a, col_a, &*res_ptr, col_res);
+        }
+    }
+
+    fn vec_znx_big_add_small(
+        &self,
+        res: &mut VecZnxBig,
+        col_res: usize,
+        a: &VecZnx,
+        col_a: usize,
+        b: &VecZnxBig,
+        col_b: usize,
+    ) {
+        #[cfg(debug_assertions)]
+        {
+            assert_eq!(a.n(), self.n());
+            assert_eq!(b.n(), self.n());
+            assert_eq!(res.n(), self.n());
+            assert_ne!(a.as_ptr(), b.as_ptr());
+        }
+        unsafe {
+            vec_znx::vec_znx_add(
+                self.ptr,
+                res.at_mut_ptr(col_res, 0),
+                res.size() as u64,
+                res.sl() as u64,
+                a.at_ptr(col_a, 0),
+                a.size() as u64,
+                a.sl() as u64,
+                b.at_ptr(col_b, 0),
+                b.size() as u64,
+                b.sl() as u64,
+            )
+        }
+    }
+
+    fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig, col_res: usize, a: &VecZnx, col_a: usize) {
+        unsafe {
+            let res_ptr: *mut VecZnxBig = res as *mut VecZnxBig;
+            Self::vec_znx_big_add_small(self, &mut *res_ptr, col_res, a, col_a, &*res_ptr, col_res);
+        }
+    }
+
+    fn vec_znx_big_normalize_tmp_bytes(&self) -> usize {
+        Self::vec_znx_normalize_tmp_bytes(self)
+    }
+
+    fn vec_znx_big_normalize(
+        &self,
+        log_base2k: usize,
+        res: &mut VecZnx,
+        col_res: usize,
+        a: &VecZnxBig,
+        col_a: usize,
+        tmp_bytes: &mut [u8],
+    ) {
+        #[cfg(debug_assertions)]
+        {
+            assert_eq!(a.n(), self.n());
+            assert_eq!(res.n(), self.n());
+            assert!(tmp_bytes.len() >= Self::vec_znx_normalize_tmp_bytes(self));
             assert_alignement(tmp_bytes.as_ptr());
         }
-
-        let a_size: usize = a.size();
-        let b_size: usize = b.size();
-        let a_sl: usize = a.sl();
-        let b_sl: usize = b.sl();
-        let a_cols: usize = a.cols();
-        let b_cols: usize = b.cols();
-        let min_cols: usize = min(a_cols, b_cols);
-        (0..min_cols).for_each(|i| unsafe {
+        unsafe {
             vec_znx::vec_znx_normalize_base2k(
                 self.ptr,
                 log_base2k as u64,
-                b.at_mut_ptr(i, 0),
-                b_size as u64,
-                b_sl as u64,
-                a.at_ptr(i, 0),
-                a_size as u64,
-                a_sl as u64,
+                res.at_mut_ptr(col_res, 0),
+                res.size() as u64,
+                res.sl() as u64,
+                a.at_ptr(col_a, 0),
+                a.size() as u64,
+                a.sl() as u64,
                 tmp_bytes.as_mut_ptr(),
             );
-        });
-
-        (min_cols..b_cols).for_each(|i| (0..b_size).for_each(|j| b.zero_at(i, j)));
+        }
     }
 
-    fn vec_znx_big_automorphism(&self, k: i64, b: &mut VecZnxBig, a: &VecZnxBig) {
-        let op = ffi_binary_op_factory_type_1(
-            self.ptr,
-            k,
-            b.size(),
-            b.sl(),
-            a.size(),
-            a.sl(),
-            vec_znx::vec_znx_automorphism,
-        );
-        apply_unary_op::<VecZnxBig<FFT64>>(self, b, a, op);
+    fn vec_znx_big_automorphism(&self, k: i64, res: &mut VecZnxBig, col_res: usize, a: &VecZnxBig, col_a: usize) {
+        #[cfg(debug_assertions)]
+        {
+            assert_eq!(a.n(), self.n());
+            assert_eq!(res.n(), self.n());
+        }
+        unsafe {
+            vec_znx::vec_znx_automorphism(
+                self.ptr,
+                k,
+                res.at_mut_ptr(col_res, 0),
+                res.size() as u64,
+                res.sl() as u64,
+                a.at_ptr(col_a, 0),
+                a.size() as u64,
+                a.sl() as u64,
+            )
+        }
     }
 
-    fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig) {
+    fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig, col_a: usize) {
         unsafe {
             let a_ptr: *mut VecZnxBig = a as *mut VecZnxBig;
-            Self::vec_znx_big_automorphism(self, k, &mut *a_ptr, &*a_ptr);
+            Self::vec_znx_big_automorphism(self, k, &mut *a_ptr, col_a, &*a_ptr, col_a);
         }
     }
 }
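
Reviewer note: a minimal before/after sketch of how a call site migrates, using only calls that appear in this diff (the buffer names are those of the rlwe_encrypt example; treat this as illustrative, not as part of the patch):

```rust
// Before: operations acted on all columns implicitly, and the scratch size
// depended on the number of columns.
// let mut carry: Vec<u8> = alloc_aligned(module.vec_znx_big_normalize_tmp_bytes(2));
// module.vec_znx_big_sub_small_ab_inplace(&mut buf_big, &m);
// module.vec_znx_big_normalize(log_base2k, &mut ct, &buf_big, &mut carry);

// After: every operand is addressed as (vector, column index), and the
// scratch size is column-independent.
let mut carry: Vec<u8> = alloc_aligned(module.vec_znx_big_normalize_tmp_bytes());
module.vec_znx_big_sub_small_a_inplace(&mut buf_big, 0, &m, 0);
module.vec_znx_big_normalize(log_base2k, &mut ct, 0, &buf_big, 0, &mut carry);
```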
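The explicit indices also let a caller mix columns of different vectors, which the old whole-vector API could not express. A hypothetical sketch (assumes a `new_vec_znx_big(cols, size)` allocator matching the trait's `bytes_of_vec_znx_big(cols, size)`; the sizes are arbitrary):

```rust
// Hypothetical: combine individual columns of multi-column operands.
let mut res = module.new_vec_znx_big(2, 4); // 2 columns, 4 limbs each
let a = module.new_vec_znx_big(1, 4);
let b = module.new_vec_znx_big(1, 4);

// res[1] <- a[0] + b[0]: each operand selects its own column.
module.vec_znx_big_add(&mut res, 1, &a, 0, &b, 0);

// res[0] <- res[0] - a[0]: the in-place receiver selects its column too
// (sub_ba keeps the receiver as the left operand, per the impl above).
module.vec_znx_big_sub_ba_inplace(&mut res, 0, &a, 0);
```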