diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index 16b7d3a..b9d78f4 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -20,7 +20,7 @@ fn main() { let mut source: Source = Source::new(seed); // s <- Z_{-1, 0, 1}[X]/(X^{N}+1) - let mut s: ScalarZnx> = module.new_scalar(1); + let mut s: ScalarZnx> = module.new_scalar_znx(1); s.fill_ternary_prob(0, 0.5, &mut source); // Buffer to store s in the DFT domain diff --git a/base2k/src/ffi/svp.rs b/base2k/src/ffi/svp.rs index 9d4999f..08b2da1 100644 --- a/base2k/src/ffi/svp.rs +++ b/base2k/src/ffi/svp.rs @@ -39,8 +39,10 @@ unsafe extern "C" { module: *const MODULE, res: *const VEC_ZNX_DFT, res_size: u64, + res_cols: u64, ppol: *const SVP_PPOL, a: *const VEC_ZNX_DFT, a_size: u64, + a_cols: u64, ); } diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs index f3b2525..450a69f 100644 --- a/base2k/src/lib.rs +++ b/base2k/src/lib.rs @@ -177,7 +177,7 @@ impl Scratch { } } - pub fn tmp_scalar_slice(&mut self, len: usize) -> (&mut [T], &mut Self) { + pub fn tmp_slice(&mut self, len: usize) -> (&mut [T], &mut Self) { let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, len * std::mem::size_of::()); unsafe { @@ -188,6 +188,24 @@ impl Scratch { } } + pub fn tmp_scalar(&mut self, module: &Module, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) { + let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_scalar_znx(module, cols)); + + ( + ScalarZnx::from_data(take_slice, module.n(), cols), + Self::new(rem_slice), + ) + } + + pub fn tmp_scalar_dft(&mut self, module: &Module, cols: usize) -> (ScalarZnxDft<&mut [u8], B>, &mut Self) { + let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_scalar_znx_dft(module, cols)); + + ( + ScalarZnxDft::from_data(take_slice, module.n(), cols), + Self::new(rem_slice), + ) + } + pub fn tmp_vec_znx_dft( &mut self, module: &Module, diff --git a/base2k/src/mat_znx_dft_ops.rs b/base2k/src/mat_znx_dft_ops.rs index ae0cbb5..85e6264 100644 --- a/base2k/src/mat_znx_dft_ops.rs +++ b/base2k/src/mat_znx_dft_ops.rs @@ -279,7 +279,7 @@ impl MatZnxDftOps for Module { ); } - let (tmp_bytes, _) = scratch.tmp_scalar_slice(self.vmp_apply_tmp_bytes( + let (tmp_bytes, _) = scratch.tmp_slice(self.vmp_apply_tmp_bytes( res.size(), a.size(), b.rows(), diff --git a/base2k/src/scalar_znx.rs b/base2k/src/scalar_znx.rs index 731add3..dde286a 100644 --- a/base2k/src/scalar_znx.rs +++ b/base2k/src/scalar_znx.rs @@ -98,24 +98,34 @@ impl>> ScalarZnx { pub type ScalarZnxOwned = ScalarZnx>; +pub(crate) fn bytes_of_scalar_znx(module: &Module, cols: usize) -> usize { + ScalarZnxOwned::bytes_of::(module.n(), cols) +} + pub trait ScalarZnxAlloc { - fn bytes_of_scalar(&self, cols: usize) -> usize; - fn new_scalar(&self, cols: usize) -> ScalarZnxOwned; - fn new_scalar_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxOwned; + fn bytes_of_scalar_znx(&self, cols: usize) -> usize; + fn new_scalar_znx(&self, cols: usize) -> ScalarZnxOwned; + fn new_scalar_znx_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxOwned; } impl ScalarZnxAlloc for Module { - fn bytes_of_scalar(&self, cols: usize) -> usize { + fn bytes_of_scalar_znx(&self, cols: usize) -> usize { ScalarZnxOwned::bytes_of::(self.n(), cols) } - fn new_scalar(&self, cols: usize) -> ScalarZnxOwned { + fn new_scalar_znx(&self, cols: usize) -> ScalarZnxOwned { ScalarZnxOwned::new::(self.n(), cols) } - fn new_scalar_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxOwned { + fn new_scalar_znx_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxOwned { ScalarZnxOwned::new_from_bytes::(self.n(), cols, bytes) } } +impl ScalarZnx { + pub(crate) fn from_data(data: D, n: usize, cols: usize) -> Self { + Self { data, n, cols } + } +} + pub trait ScalarZnxToRef { fn to_ref(&self) -> ScalarZnx<&[u8]>; } diff --git a/base2k/src/scalar_znx_dft.rs b/base2k/src/scalar_znx_dft.rs index c93609f..3626625 100644 --- a/base2k/src/scalar_znx_dft.rs +++ b/base2k/src/scalar_znx_dft.rs @@ -52,6 +52,10 @@ impl> ZnxView for ScalarZnxDft { type Scalar = f64; } +pub(crate) fn bytes_of_scalar_znx_dft(module: &Module, cols: usize) -> usize { + ScalarZnxDftOwned::bytes_of(module, cols) +} + impl>, B: Backend> ScalarZnxDft { pub(crate) fn bytes_of(module: &Module, cols: usize) -> usize { unsafe { svp::bytes_of_svp_ppol(module.ptr) as usize * cols } @@ -79,6 +83,17 @@ impl>, B: Backend> ScalarZnxDft { } } +impl ScalarZnxDft { + pub(crate) fn from_data(data: D, n: usize, cols: usize) -> Self { + Self { + data, + n, + cols, + _phantom: PhantomData, + } + } +} + pub type ScalarZnxDftOwned = ScalarZnxDft, B>; pub trait ScalarZnxDftToRef { diff --git a/base2k/src/scalar_znx_dft_ops.rs b/base2k/src/scalar_znx_dft_ops.rs index f5f8f7f..f02fa03 100644 --- a/base2k/src/scalar_znx_dft_ops.rs +++ b/base2k/src/scalar_znx_dft_ops.rs @@ -71,9 +71,11 @@ impl ScalarZnxDftOps for Module { self.ptr, res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t, res.size() as u64, + res.cols() as u64, a.at_ptr(a_col, 0) as *const svp::svp_ppol_t, b.at_ptr(b_col, 0) as *const vec_znx_dft_t, b.size() as u64, + b.cols() as u64, ) } } @@ -90,9 +92,11 @@ impl ScalarZnxDftOps for Module { self.ptr, res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t, res.size() as u64, + res.cols() as u64, a.at_ptr(a_col, 0) as *const svp::svp_ppol_t, res.at_ptr(res_col, 0) as *const vec_znx_dft_t, res.size() as u64, + res.cols() as u64, ) } } diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index 8f70272..f5f220e 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -2,6 +2,7 @@ use crate::ffi::vec_znx_big; use crate::znx_base::{ZnxInfos, ZnxView}; use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, alloc_aligned}; use std::marker::PhantomData; +use std::{cmp::min, fmt}; pub struct VecZnxBig { data: D, @@ -162,3 +163,38 @@ impl VecZnxBigToRef for VecZnxBig<&[u8], B> { } } } + +impl> fmt::Display for VecZnxBig { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!( + f, + "VecZnx(n={}, cols={}, size={})", + self.n, self.cols, self.size + )?; + + for col in 0..self.cols { + writeln!(f, "Column {}:", col)?; + for size in 0..self.size { + let coeffs = self.at(col, size); + write!(f, " Size {}: [", size)?; + + let max_show = 100; + let show_count = coeffs.len().min(max_show); + + for (i, &coeff) in coeffs.iter().take(show_count).enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", coeff)?; + } + + if coeffs.len() > max_show { + write!(f, ", ... ({} more)", coeffs.len() - max_show)?; + } + + writeln!(f, "]")?; + } + } + Ok(()) + } +} diff --git a/base2k/src/vec_znx_big_ops.rs b/base2k/src/vec_znx_big_ops.rs index 169c66a..933deb3 100644 --- a/base2k/src/vec_znx_big_ops.rs +++ b/base2k/src/vec_znx_big_ops.rs @@ -528,7 +528,7 @@ impl VecZnxBigOps for Module { // assert_alignement(tmp_bytes.as_ptr()); } - let (tmp_bytes, _) = scratch.tmp_scalar_slice(::vec_znx_big_normalize_tmp_bytes( + let (tmp_bytes, _) = scratch.tmp_slice(::vec_znx_big_normalize_tmp_bytes( &self, )); unsafe { diff --git a/base2k/src/vec_znx_dft_ops.rs b/base2k/src/vec_znx_dft_ops.rs index 83b7c26..927e39e 100644 --- a/base2k/src/vec_znx_dft_ops.rs +++ b/base2k/src/vec_znx_dft_ops.rs @@ -141,7 +141,7 @@ impl VecZnxDftOps for Module { let mut res_mut = res.to_mut(); let a_ref = a.to_ref(); - let (tmp_bytes, _) = scratch.tmp_scalar_slice(self.vec_znx_idft_tmp_bytes()); + let (tmp_bytes, _) = scratch.tmp_slice(self.vec_znx_idft_tmp_bytes()); let min_size: usize = min(res_mut.size(), a_ref.size()); diff --git a/base2k/src/vec_znx_ops.rs b/base2k/src/vec_znx_ops.rs index cdabe24..c80e9f1 100644 --- a/base2k/src/vec_znx_ops.rs +++ b/base2k/src/vec_znx_ops.rs @@ -175,7 +175,7 @@ impl VecZnxOps for Module { assert_eq!(res.n(), self.n()); } - let (tmp_bytes, _) = scratch.tmp_scalar_slice(self.vec_znx_normalize_tmp_bytes()); + let (tmp_bytes, _) = scratch.tmp_slice(self.vec_znx_normalize_tmp_bytes()); unsafe { vec_znx::vec_znx_normalize_base2k( @@ -203,7 +203,7 @@ impl VecZnxOps for Module { assert_eq!(a.n(), self.n()); } - let (tmp_bytes, _) = scratch.tmp_scalar_slice(self.vec_znx_normalize_tmp_bytes()); + let (tmp_bytes, _) = scratch.tmp_slice(self.vec_znx_normalize_tmp_bytes()); unsafe { vec_znx::vec_znx_normalize_base2k( diff --git a/base2k/src/znx_base.rs b/base2k/src/znx_base.rs index 5230dfd..94da450 100644 --- a/base2k/src/znx_base.rs +++ b/base2k/src/znx_base.rs @@ -171,7 +171,7 @@ where let k_rem: usize = k % log_base2k; if k_rem != 0 { - let (carry, _) = scratch.tmp_scalar_slice::(rsh_tmp_bytes::(n)); + let (carry, _) = scratch.tmp_slice::(rsh_tmp_bytes::(n)); unsafe { std::ptr::write_bytes(carry.as_mut_ptr(), 0, n * size_of::());