diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index 8a5d09f..803f371 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -8,17 +8,17 @@ use sampling::source::Source; fn main() { let n: usize = 16; let log_base2k: usize = 18; - let limbs: usize = 3; - let msg_cols: usize = 2; - let log_scale: usize = msg_cols * log_base2k - 5; + let ct_size: usize = 3; + let msg_size: usize = 2; + let log_scale: usize = msg_size * log_base2k - 5; let module: Module = Module::::new(n); - let mut carry: Vec = alloc_aligned(module.vec_znx_big_normalize_tmp_bytes()); + let mut carry: Vec = alloc_aligned(module.vec_znx_big_normalize_tmp_bytes(1)); let seed: [u8; 32] = [0; 32]; let mut source: Source = Source::new(seed); - let mut res: VecZnx = module.new_vec_znx(1, limbs); + let mut res: VecZnx = module.new_vec_znx(1, ct_size); // s <- Z_{-1, 0, 1}[X]/(X^{N}+1) let mut s: Scalar = Scalar::new(n); @@ -31,8 +31,8 @@ fn main() { module.svp_prepare(&mut s_ppol, &s); // a <- Z_{2^prec}[X]/(X^{N}+1) - let mut a: VecZnx = module.new_vec_znx(1, limbs); - module.fill_uniform(log_base2k, &mut a, 0, limbs, &mut source); + let mut a: VecZnx = module.new_vec_znx(1, ct_size); + module.fill_uniform(log_base2k, &mut a, 0, ct_size, &mut source); // Scratch space for DFT values let mut buf_dft: VecZnxDft = module.new_vec_znx_dft(1, a.size()); @@ -48,7 +48,7 @@ fn main() { println!("{:?}", buf_big.raw()); - let mut m: VecZnx = module.new_vec_znx(1, msg_cols); + let mut m: VecZnx = module.new_vec_znx(1, msg_size); let mut want: Vec = vec![0; n]; want.iter_mut() @@ -64,14 +64,14 @@ fn main() { println!("{:?}", buf_big.raw()); // b <- normalize(buf_big) + e - let mut b: VecZnx = module.new_vec_znx(1, limbs); + let mut b: VecZnx = module.new_vec_znx(1, ct_size); module.vec_znx_big_normalize(log_base2k, &mut b, &buf_big, &mut carry); b.print(n); module.add_normal( log_base2k, &mut b, 0, - log_base2k * limbs, + log_base2k * ct_size, &mut source, 3.2, 19.0, diff --git a/base2k/src/vec_znx_big_ops.rs b/base2k/src/vec_znx_big_ops.rs index 530cb54..4b4e54e 100644 --- a/base2k/src/vec_znx_big_ops.rs +++ b/base2k/src/vec_znx_big_ops.rs @@ -1,7 +1,8 @@ -use crate::ffi::vec_znx_big::vec_znx_big_t; -use crate::ffi::{vec_znx, vec_znx_big}; -use crate::internals::{apply_binary_op, ffi_ternary_op_factory}; -use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, ZnxBase, ZnxInfos, ZnxLayout, assert_alignement}; +use std::cmp::min; + +use crate::ffi::vec_znx; +use crate::internals::{apply_binary_op, apply_unary_op, ffi_binary_op_factory_type_1, ffi_ternary_op_factory}; +use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxOps, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, assert_alignement}; pub trait VecZnxBigOps { /// Allocates a vector Z[X]/(X^N+1) that stores not normalized values. @@ -73,7 +74,7 @@ pub trait VecZnxBigOps { fn vec_znx_big_sub_small_ba_inplace(&self, b: &mut VecZnxBig, a: &VecZnx); /// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_normalize]. - fn vec_znx_big_normalize_tmp_bytes(&self) -> usize; + fn vec_znx_big_normalize_tmp_bytes(&self, cols: usize) -> usize; /// Normalizes `a` and stores the result on `b`. /// @@ -83,29 +84,6 @@ pub trait VecZnxBigOps { /// * `tmp_bytes`: scratch space of size at least [VecZnxBigOps::vec_znx_big_normalize]. fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig, tmp_bytes: &mut [u8]); - /// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_range_normalize_base2k]. - fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize; - - /// Normalize `a`, taking into account column interleaving and stores the result on `b`. - /// - /// # Arguments - /// - /// * `log_base2k`: normalization basis. - /// * `a_range_begin`: column to start. - /// * `a_range_end`: column to end. - /// * `a_range_step`: column step size. - /// * `tmp_bytes`: scratch space of size at least [VecZnxBigOps::vec_znx_big_range_normalize_base2k_tmp_bytes]. - fn vec_znx_big_range_normalize_base2k( - &self, - log_base2k: usize, - b: &mut VecZnx, - a: &VecZnxBig, - a_range_begin: usize, - a_range_xend: usize, - a_range_step: usize, - tmp_bytes: &mut [u8], - ); - /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`. fn vec_znx_big_automorphism(&self, k: i64, b: &mut VecZnxBig, a: &VecZnxBig); @@ -242,98 +220,58 @@ impl VecZnxBigOps for Module { } } - fn vec_znx_big_normalize_tmp_bytes(&self) -> usize { - unsafe { vec_znx_big::vec_znx_big_normalize_base2k_tmp_bytes(self.ptr) as usize } + fn vec_znx_big_normalize_tmp_bytes(&self, cols: usize) -> usize { + Self::vec_znx_normalize_tmp_bytes(self, cols) } fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig, tmp_bytes: &mut [u8]) { - debug_assert!( - tmp_bytes.len() >= Self::vec_znx_big_normalize_tmp_bytes(self), - "invalid tmp_bytes: tmp_bytes.len()={} <= self.vec_znx_big_normalize_tmp_bytes()={}", - tmp_bytes.len(), - Self::vec_znx_big_normalize_tmp_bytes(self) - ); #[cfg(debug_assertions)] { - assert_alignement(tmp_bytes.as_ptr()) + assert!(tmp_bytes.len() >= Self::vec_znx_big_normalize_tmp_bytes(&self, a.cols())); + assert_alignement(tmp_bytes.as_ptr()); } - unsafe { - vec_znx_big::vec_znx_big_normalize_base2k( + + let a_size: usize = a.size(); + let b_size: usize = b.sl(); + let a_sl: usize = a.size(); + let b_sl: usize = a.sl(); + let a_cols: usize = a.cols(); + let b_cols: usize = b.cols(); + let min_cols: usize = min(a_cols, b_cols); + (0..min_cols).for_each(|i| unsafe { + vec_znx::vec_znx_normalize_base2k( self.ptr, log_base2k as u64, - b.as_mut_ptr(), - b.size() as u64, - b.n() as u64, - a.ptr as *mut vec_znx_big_t, - a.size() as u64, + b.at_mut_ptr(i, 0), + b_size as u64, + b_sl as u64, + a.at_ptr(i, 0), + a_size as u64, + a_sl as u64, tmp_bytes.as_mut_ptr(), - ) - } + ); + }); + + (min_cols..b_cols).for_each(|i| (0..b_size).for_each(|j| b.zero_at(i, j))); } - fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize { - unsafe { vec_znx_big::vec_znx_big_range_normalize_base2k_tmp_bytes(self.ptr) as usize } - } - - fn vec_znx_big_range_normalize_base2k( - &self, - log_base2k: usize, - res: &mut VecZnx, - a: &VecZnxBig, - a_range_begin: usize, - a_range_xend: usize, - a_range_step: usize, - tmp_bytes: &mut [u8], - ) { - debug_assert!( - tmp_bytes.len() >= Self::vec_znx_big_range_normalize_base2k_tmp_bytes(self), - "invalid tmp_bytes: tmp_bytes.len()={} <= self.vec_znx_big_range_normalize_base2k_tmp_bytes()={}", - tmp_bytes.len(), - Self::vec_znx_big_range_normalize_base2k_tmp_bytes(self) + fn vec_znx_big_automorphism(&self, k: i64, b: &mut VecZnxBig, a: &VecZnxBig) { + let op = ffi_binary_op_factory_type_1( + self.ptr, + k, + b.size(), + b.sl(), + a.size(), + a.sl(), + vec_znx::vec_znx_automorphism, ); - #[cfg(debug_assertions)] - { - assert_alignement(tmp_bytes.as_ptr()) - } - unsafe { - vec_znx_big::vec_znx_big_range_normalize_base2k( - self.ptr, - log_base2k as u64, - res.as_mut_ptr(), - res.size() as u64, - res.n() as u64, - a.ptr as *mut vec_znx_big_t, - a_range_begin as u64, - a_range_xend as u64, - a_range_step as u64, - tmp_bytes.as_mut_ptr(), - ); - } + apply_unary_op::>(self, b, a, op); } - fn vec_znx_big_automorphism(&self, gal_el: i64, b: &mut VecZnxBig, a: &VecZnxBig) { + fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig) { unsafe { - vec_znx_big::vec_znx_big_automorphism( - self.ptr, - gal_el, - b.ptr as *mut vec_znx_big_t, - b.poly_count() as u64, - a.ptr as *mut vec_znx_big_t, - a.poly_count() as u64, - ); - } - } - - fn vec_znx_big_automorphism_inplace(&self, gal_el: i64, a: &mut VecZnxBig) { - unsafe { - vec_znx_big::vec_znx_big_automorphism( - self.ptr, - gal_el, - a.ptr as *mut vec_znx_big_t, - a.poly_count() as u64, - a.ptr as *mut vec_znx_big_t, - a.poly_count() as u64, - ); + let a_ptr: *mut VecZnxBig = a as *mut VecZnxBig; + Self::vec_znx_big_automorphism(self, k, &mut *a_ptr, &*a_ptr); } } } diff --git a/base2k/src/vec_znx_ops.rs b/base2k/src/vec_znx_ops.rs index f67b6b0..c7c8d85 100644 --- a/base2k/src/vec_znx_ops.rs +++ b/base2k/src/vec_znx_ops.rs @@ -1,7 +1,9 @@ +use std::cmp::min; + use crate::ffi::module::MODULE; use crate::ffi::vec_znx; use crate::internals::{apply_binary_op, apply_unary_op, ffi_binary_op_factory_type_0, ffi_binary_op_factory_type_1}; -use crate::{Backend, Module, VecZnx, ZnxBase, ZnxInfos, switch_degree}; +use crate::{Backend, Module, VecZnx, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, assert_alignement, switch_degree}; pub trait VecZnxOps { /// Allocates a new [VecZnx]. /// @@ -43,6 +45,12 @@ pub trait VecZnxOps { /// Returns the minimum number of bytes necessary for normalization. fn vec_znx_normalize_tmp_bytes(&self, cols: usize) -> usize; + /// Normalizes `a` and stores the result into `b`. + fn vec_znx_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnx, tmp_bytes: &mut [u8]); + + /// Normalizes `a` and stores the result into `a`. + fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]); + /// Adds `a` to `b` and write the result on `c`. fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx); @@ -114,6 +122,44 @@ impl VecZnxOps for Module { unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize * cols } } + fn vec_znx_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnx, tmp_bytes: &mut [u8]) { + #[cfg(debug_assertions)] + { + assert!(tmp_bytes.len() >= Self::vec_znx_normalize_tmp_bytes(&self, a.cols())); + assert_alignement(tmp_bytes.as_ptr()); + } + + let a_size: usize = a.size(); + let b_size: usize = b.sl(); + let a_sl: usize = a.size(); + let b_sl: usize = a.sl(); + let a_cols: usize = a.cols(); + let b_cols: usize = b.cols(); + let min_cols: usize = min(a_cols, b_cols); + (0..min_cols).for_each(|i| unsafe { + vec_znx::vec_znx_normalize_base2k( + self.ptr, + log_base2k as u64, + b.at_mut_ptr(i, 0), + b_size as u64, + b_sl as u64, + a.at_ptr(i, 0), + a_size as u64, + a_sl as u64, + tmp_bytes.as_mut_ptr(), + ); + }); + + (min_cols..b_cols).for_each(|i| (0..b_size).for_each(|j| b.zero_at(i, j))); + } + + fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]) { + unsafe { + let a_ptr: *mut VecZnx = a as *mut VecZnx; + Self::vec_znx_normalize(self, log_base2k, &mut *a_ptr, &*a_ptr, tmp_bytes); + } + } + fn vec_znx_add(&self, c: &mut VecZnx, a: &VecZnx, b: &VecZnx) { let op = ffi_ternary_op_factory( self.ptr,