diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index 0f75ef3..07fe1c6 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -1,6 +1,6 @@ use base2k::{ - Encoding, FFT64, Module, Sampling, Scalar, ScalarZnxDft, ScalarZnxDftOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, - VecZnxDftOps, VecZnxOps, ZnxInfos, alloc_aligned, + Encoding, FFT64, Module, Sampling, Scalar, ScalarOps, ScalarZnxDft, ScalarZnxDftOps, VecZnx, VecZnxBig, VecZnxBigOps, + VecZnxDft, VecZnxDftOps, VecZnxOps, ZnxInfos, alloc_aligned, }; use itertools::izip; use sampling::source::Source; @@ -19,14 +19,14 @@ fn main() { let mut source: Source = Source::new(seed); // s <- Z_{-1, 0, 1}[X]/(X^{N}+1) - let mut s: Scalar = Scalar::new(n); - s.fill_ternary_prob(0.5, &mut source); + let mut s: Scalar = module.new_scalar(1); + s.fill_ternary_prob(0, 0.5, &mut source); // Buffer to store s in the DFT domain - let mut s_dft: ScalarZnxDft = module.new_scalar_znx_dft(); + let mut s_dft: ScalarZnxDft = module.new_scalar_znx_dft(s.cols()); // s_dft <- DFT(s) - module.svp_prepare(&mut s_dft, &s); + module.svp_prepare(&mut s_dft, 0, &s, 0); // Allocates a VecZnx with two columns: ct=(0, 0) let mut ct: VecZnx = module.new_vec_znx( @@ -48,6 +48,7 @@ fn main() { &mut buf_dft, // DFT(ct[1] * s) 0, // Selects the first column of res &s_dft, // DFT(s) + 0, // Selects the first column of s_dft &ct, 1, // Selects the second column of ct ); @@ -106,6 +107,7 @@ fn main() { &mut buf_dft, 0, // Selects the first column of res. &s_dft, + 0, &ct, 1, // Selects the second column of ct (ct[1]) ); diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs index f57e482..3fa0bbe 100644 --- a/base2k/src/lib.rs +++ b/base2k/src/lib.rs @@ -5,7 +5,9 @@ pub mod ffi; pub mod mat_znx_dft; pub mod module; pub mod sampling; +pub mod scalar_znx; pub mod scalar_znx_dft; +pub mod scalar_znx_dft_ops; pub mod stats; pub mod vec_znx; pub mod vec_znx_big; @@ -19,8 +21,11 @@ pub use encoding::*; pub use mat_znx_dft::*; pub use module::*; pub use sampling::*; +#[allow(unused_imports)] +pub use scalar_znx::*; pub use scalar_znx_dft::*; #[allow(unused_imports)] +pub use scalar_znx_dft_ops::*; pub use stats::*; pub use vec_znx::*; pub use vec_znx_big::*; @@ -50,13 +55,13 @@ pub fn assert_alignement(ptr: *const T) { pub fn cast(data: &[T]) -> &[V] { let ptr: *const V = data.as_ptr() as *const V; - let len: usize = data.len() / std::mem::size_of::(); + let len: usize = data.len() / size_of::(); unsafe { std::slice::from_raw_parts(ptr, len) } } pub fn cast_mut(data: &[T]) -> &mut [V] { let ptr: *mut V = data.as_ptr() as *mut V; - let len: usize = data.len() / std::mem::size_of::(); + let len: usize = data.len() / size_of::(); unsafe { std::slice::from_raw_parts_mut(ptr, len) } } @@ -70,7 +75,7 @@ fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec { align ); assert_eq!( - (size * std::mem::size_of::()) % align, + (size * size_of::()) % align, 0, "size={} must be a multiple of align={}", size, @@ -98,22 +103,25 @@ fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec { /// Size of T * size msut be a multiple of [DEFAULTALIGN]. pub fn alloc_aligned_custom(size: usize, align: usize) -> Vec { assert_eq!( - (size * std::mem::size_of::()) % align, + (size * size_of::()) % align, 0, "size={} must be a multiple of align={}", size, align ); - let mut vec_u8: Vec = alloc_aligned_custom_u8(std::mem::size_of::() * size, align); + let mut vec_u8: Vec = alloc_aligned_custom_u8(size_of::() * size, align); let ptr: *mut T = vec_u8.as_mut_ptr() as *mut T; - let len: usize = vec_u8.len() / std::mem::size_of::(); - let cap: usize = vec_u8.capacity() / std::mem::size_of::(); + let len: usize = vec_u8.len() / size_of::(); + let cap: usize = vec_u8.capacity() / size_of::(); std::mem::forget(vec_u8); unsafe { Vec::from_raw_parts(ptr, len, cap) } } -/// Allocates an aligned of size equal to the smallest multiple -/// of [DEFAULTALIGN] that is equal or greater to `size`. +/// Allocates an aligned vector of size equal to the smallest multiple +/// of [DEFAULTALIGN]/size_of::() that is equal or greater to `size`. pub fn alloc_aligned(size: usize) -> Vec { - alloc_aligned_custom::(size + (size % DEFAULTALIGN), DEFAULTALIGN) + alloc_aligned_custom::( + size + (size % (DEFAULTALIGN / size_of::())), + DEFAULTALIGN, + ) } diff --git a/base2k/src/mat_znx_dft.rs b/base2k/src/mat_znx_dft.rs index 9b5e2ca..44d44df 100644 --- a/base2k/src/mat_znx_dft.rs +++ b/base2k/src/mat_znx_dft.rs @@ -160,7 +160,7 @@ pub trait MatZnxDftOps { /// * `b`: the [VecZnxDft] to on which to extract the row of the [MatZnxDft]. /// * `a`: [MatZnxDft] on which the values are encoded. /// * `row_i`: the index of the row to extract. - fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &MatZnxDft, row_i: usize); + fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, row_i: usize, a: &MatZnxDft); /// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft]. /// @@ -170,7 +170,7 @@ pub trait MatZnxDftOps { /// * `a_size`: number of size of the input [VecZnx]. /// * `rows`: number of rows of the input [MatZnxDft]. /// * `size`: number of size of the input [MatZnxDft]. - fn vmp_apply_dft_tmp_bytes(&self, c_size: usize, a_size: usize, rows: usize, size: usize) -> usize; + fn vmp_apply_dft_tmp_bytes(&self, c_size: usize, a_size: usize, b_rows: usize, b_size: usize) -> usize; /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft]. /// @@ -404,7 +404,7 @@ impl MatZnxDftOps for Module { } } - fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &MatZnxDft, row_i: usize) { + fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, row_i: usize, a: &MatZnxDft) { #[cfg(debug_assertions)] { assert_eq!(a.n(), b.n()); @@ -422,14 +422,14 @@ impl MatZnxDftOps for Module { } } - fn vmp_apply_dft_tmp_bytes(&self, res_size: usize, a_size: usize, gct_rows: usize, gct_size: usize) -> usize { + fn vmp_apply_dft_tmp_bytes(&self, res_size: usize, a_size: usize, b_rows: usize, b_size: usize) -> usize { unsafe { vmp::vmp_apply_dft_tmp_bytes( self.ptr, res_size as u64, a_size as u64, - gct_rows as u64, - gct_size as u64, + b_rows as u64, + b_size as u64, ) as usize } } @@ -595,7 +595,7 @@ mod tests { assert_eq!(vmpmat_0.raw(), vmpmat_1.raw()); // Checks that a_dft = extract_dft(prepare(mat_znx_dft, a), b_dft) - module.vmp_extract_row_dft(&mut b_dft, &vmpmat_0, row_i); + module.vmp_extract_row_dft(&mut b_dft, row_i, &vmpmat_0); assert_eq!(a_dft.raw(), b_dft.raw()); // Checks that a_big = extract(prepare_dft(mat_znx_dft, a_dft), b_big) diff --git a/base2k/src/scalar_znx.rs b/base2k/src/scalar_znx.rs new file mode 100644 index 0000000..df3e6d1 --- /dev/null +++ b/base2k/src/scalar_znx.rs @@ -0,0 +1,113 @@ +use crate::znx_base::{ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize}; +use crate::{Backend, GetZnxBase, Module, VecZnx}; +use rand::seq::SliceRandom; +use rand_core::RngCore; +use rand_distr::{Distribution, weighted::WeightedIndex}; +use sampling::source::Source; + +pub const SCALAR_ZNX_ROWS: usize = 1; +pub const SCALAR_ZNX_SIZE: usize = 1; + +pub struct Scalar { + pub inner: ZnxBase, +} + +impl GetZnxBase for Scalar { + fn znx(&self) -> &ZnxBase { + &self.inner + } + + fn znx_mut(&mut self) -> &mut ZnxBase { + &mut self.inner + } +} + +impl ZnxInfos for Scalar {} + +impl ZnxAlloc for Scalar { + type Scalar = i64; + + fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, _size: usize, bytes: &mut [u8]) -> Self { + Self { + inner: ZnxBase::from_bytes_borrow(module.n(), SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE, bytes), + } + } + + fn bytes_of(module: &Module, _rows: usize, cols: usize, _size: usize) -> usize { + debug_assert_eq!( + _rows, SCALAR_ZNX_ROWS, + "rows != {} not supported for Scalar", + SCALAR_ZNX_ROWS + ); + debug_assert_eq!( + _size, SCALAR_ZNX_SIZE, + "rows != {} not supported for Scalar", + SCALAR_ZNX_SIZE + ); + module.n() * cols * std::mem::size_of::() + } +} + +impl ZnxLayout for Scalar { + type Scalar = i64; +} + +impl ZnxSliceSize for Scalar { + fn sl(&self) -> usize { + self.n() + } +} + +impl Scalar { + pub fn fill_ternary_prob(&mut self, col: usize, prob: f64, source: &mut Source) { + let choices: [i64; 3] = [-1, 0, 1]; + let weights: [f64; 3] = [prob / 2.0, 1.0 - prob, prob / 2.0]; + let dist: WeightedIndex = WeightedIndex::new(&weights).unwrap(); + self.at_mut(col, 0) + .iter_mut() + .for_each(|x: &mut i64| *x = choices[dist.sample(source)]); + } + + pub fn fill_ternary_hw(&mut self, col: usize, hw: usize, source: &mut Source) { + assert!(hw <= self.n()); + self.at_mut(col, 0)[..hw] + .iter_mut() + .for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1); + self.at_mut(col, 0).shuffle(source); + } + + pub fn alias_as_vec_znx(&self) -> VecZnx { + VecZnx { + inner: ZnxBase { + n: self.n(), + rows: 1, + cols: 1, + size: 1, + data: Vec::new(), + ptr: self.ptr() as *mut u8, + }, + } + } +} + +pub trait ScalarOps { + fn bytes_of_scalar(&self, cols: usize) -> usize; + fn new_scalar(&self, cols: usize) -> Scalar; + fn new_scalar_from_bytes(&self, cols: usize, bytes: Vec) -> Scalar; + fn new_scalar_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> Scalar; +} + +impl ScalarOps for Module { + fn bytes_of_scalar(&self, cols: usize) -> usize { + Scalar::bytes_of(self, SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE) + } + fn new_scalar(&self, cols: usize) -> Scalar { + Scalar::new(self, SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE) + } + fn new_scalar_from_bytes(&self, cols: usize, bytes: Vec) -> Scalar { + Scalar::from_bytes(self, SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE, bytes) + } + fn new_scalar_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> Scalar { + Scalar::from_bytes_borrow(self, SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE, bytes) + } +} diff --git a/base2k/src/scalar_znx_dft.rs b/base2k/src/scalar_znx_dft.rs index 07e156d..ffb54b5 100644 --- a/base2k/src/scalar_znx_dft.rs +++ b/base2k/src/scalar_znx_dft.rs @@ -1,279 +1,66 @@ use std::marker::PhantomData; -use crate::ffi::svp::{self, svp_ppol_t}; -use crate::ffi::vec_znx_dft::vec_znx_dft_t; -use crate::znx_base::{ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize}; -use crate::{Backend, FFT64, Module, VecZnx, VecZnxDft, alloc_aligned, assert_alignement, cast_mut}; -use rand::seq::SliceRandom; -use rand_core::RngCore; -use rand_distr::{Distribution, weighted::WeightedIndex}; -use sampling::source::Source; +use crate::ffi::svp; +use crate::znx_base::{ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize}; +use crate::{Backend, FFT64, GetZnxBase, Module}; -pub struct Scalar { - pub n: usize, - pub data: Vec, - pub ptr: *mut i64, -} - -impl Module { - pub fn new_scalar(&self) -> Scalar { - Scalar::new(self.n()) - } -} - -impl Scalar { - pub fn new(n: usize) -> Self { - let mut data: Vec = alloc_aligned::(n); - let ptr: *mut i64 = data.as_mut_ptr(); - Self { - n: n, - data: data, - ptr: ptr, - } - } - - pub fn n(&self) -> usize { - self.n - } - - pub fn bytes_of(n: usize) -> usize { - n * std::mem::size_of::() - } - - pub fn from_bytes(n: usize, bytes: &mut [u8]) -> Self { - let size: usize = Self::bytes_of(n); - debug_assert!( - bytes.len() == size, - "invalid buffer: bytes.len()={} < self.bytes_of(n={})={}", - bytes.len(), - n, - size - ); - #[cfg(debug_assertions)] - { - assert_alignement(bytes.as_ptr()) - } - unsafe { - let bytes_i64: &mut [i64] = cast_mut::(bytes); - let ptr: *mut i64 = bytes_i64.as_mut_ptr(); - Self { - n: n, - data: Vec::from_raw_parts(bytes_i64.as_mut_ptr(), bytes.len(), bytes.len()), - ptr: ptr, - } - } - } - - pub fn from_bytes_borrow(n: usize, bytes: &mut [u8]) -> Self { - let size: usize = Self::bytes_of(n); - debug_assert!( - bytes.len() == size, - "invalid buffer: bytes.len()={} < self.bytes_of(n={})={}", - bytes.len(), - n, - size - ); - #[cfg(debug_assertions)] - { - assert_alignement(bytes.as_ptr()) - } - let bytes_i64: &mut [i64] = cast_mut::(bytes); - let ptr: *mut i64 = bytes_i64.as_mut_ptr(); - Self { - n: n, - data: Vec::new(), - ptr: ptr, - } - } - - pub fn as_ptr(&self) -> *const i64 { - self.ptr - } - - pub fn raw(&self) -> &[i64] { - unsafe { std::slice::from_raw_parts(self.ptr, self.n) } - } - - pub fn raw_mut(&self) -> &mut [i64] { - unsafe { std::slice::from_raw_parts_mut(self.ptr, self.n) } - } - - pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) { - let choices: [i64; 3] = [-1, 0, 1]; - let weights: [f64; 3] = [prob / 2.0, 1.0 - prob, prob / 2.0]; - let dist: WeightedIndex = WeightedIndex::new(&weights).unwrap(); - self.data - .iter_mut() - .for_each(|x: &mut i64| *x = choices[dist.sample(source)]); - } - - pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) { - assert!(hw <= self.n()); - self.data[..hw] - .iter_mut() - .for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1); - self.data.shuffle(source); - } - - pub fn as_vec_znx(&self) -> VecZnx { - VecZnx { - inner: ZnxBase { - n: self.n, - rows: 1, - cols: 1, - size: 1, - data: Vec::new(), - ptr: self.ptr as *mut u8, - }, - } - } -} - -pub trait ScalarOps { - fn bytes_of_scalar(&self) -> usize; - fn new_scalar(&self) -> Scalar; - fn new_scalar_from_bytes(&self, bytes: &mut [u8]) -> Scalar; - fn new_scalar_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> Scalar; -} -impl ScalarOps for Module { - fn bytes_of_scalar(&self) -> usize { - Scalar::bytes_of(self.n()) - } - fn new_scalar(&self) -> Scalar { - Scalar::new(self.n()) - } - fn new_scalar_from_bytes(&self, bytes: &mut [u8]) -> Scalar { - Scalar::from_bytes(self.n(), bytes) - } - fn new_scalar_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> Scalar { - Scalar::from_bytes_borrow(self.n(), tmp_bytes) - } -} +pub const SCALAR_ZNX_DFT_ROWS: usize = 1; +pub const SCALAR_ZNX_DFT_SIZE: usize = 1; pub struct ScalarZnxDft { - pub n: usize, - pub data: Vec, - pub ptr: *mut u8, + pub inner: ZnxBase, _marker: PhantomData, } -/// A prepared [crate::Scalar] for [SvpPPolOps::svp_apply_dft]. -/// An [SvpPPol] an be seen as a [VecZnxDft] of one limb. -impl ScalarZnxDft { - pub fn new(module: &Module) -> Self { - module.new_scalar_znx_dft() +impl GetZnxBase for ScalarZnxDft { + fn znx(&self) -> &ZnxBase { + &self.inner } - /// Returns the ring degree of the [SvpPPol]. - pub fn n(&self) -> usize { - self.n + fn znx_mut(&mut self) -> &mut ZnxBase { + &mut self.inner } +} - pub fn bytes_of(module: &Module) -> usize { - module.bytes_of_scalar_znx_dft() - } +impl ZnxInfos for ScalarZnxDft {} - pub fn from_bytes(module: &Module, bytes: &mut [u8]) -> Self { - #[cfg(debug_assertions)] - { - assert_alignement(bytes.as_ptr()); - assert_eq!(bytes.len(), module.bytes_of_scalar_znx_dft()); - } - unsafe { - Self { - n: module.n(), - data: Vec::from_raw_parts(bytes.as_mut_ptr(), bytes.len(), bytes.len()), - ptr: bytes.as_mut_ptr(), - _marker: PhantomData, - } - } - } +impl ZnxAlloc for ScalarZnxDft { + type Scalar = u8; - pub fn from_bytes_borrow(module: &Module, tmp_bytes: &mut [u8]) -> Self { - #[cfg(debug_assertions)] - { - assert_alignement(tmp_bytes.as_ptr()); - assert_eq!(tmp_bytes.len(), module.bytes_of_scalar_znx_dft()); - } + fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, _size: usize, bytes: &mut [u8]) -> Self { Self { - n: module.n(), - data: Vec::new(), - ptr: tmp_bytes.as_mut_ptr(), + inner: ZnxBase::from_bytes_borrow( + module.n(), + SCALAR_ZNX_DFT_ROWS, + cols, + SCALAR_ZNX_DFT_SIZE, + bytes, + ), _marker: PhantomData, } } - /// Returns the number of cols of the [SvpPPol], which is always 1. - pub fn cols(&self) -> usize { - 1 + fn bytes_of(module: &Module, _rows: usize, cols: usize, _size: usize) -> usize { + debug_assert_eq!( + _rows, SCALAR_ZNX_DFT_ROWS, + "rows != {} not supported for ScalarZnxDft", + SCALAR_ZNX_DFT_ROWS + ); + debug_assert_eq!( + _size, SCALAR_ZNX_DFT_SIZE, + "rows != {} not supported for ScalarZnxDft", + SCALAR_ZNX_DFT_SIZE + ); + unsafe { svp::bytes_of_svp_ppol(module.ptr) as usize * cols } } } -pub trait ScalarZnxDftOps { - /// Allocates a new [SvpPPol]. - fn new_scalar_znx_dft(&self) -> ScalarZnxDft; - - /// Returns the minimum number of bytes necessary to allocate - /// a new [SvpPPol] through [SvpPPol::from_bytes] ro. - fn bytes_of_scalar_znx_dft(&self) -> usize; - - /// Allocates a new [SvpPPol] from an array of bytes. - /// The array of bytes is owned by the [SvpPPol]. - /// The method will panic if bytes.len() < [SvpPPolOps::bytes_of_svp_ppol] - fn new_scalar_znx_dft_from_bytes(&self, bytes: &mut [u8]) -> ScalarZnxDft; - - /// Allocates a new [SvpPPol] from an array of bytes. - /// The array of bytes is borrowed by the [SvpPPol]. - /// The method will panic if bytes.len() < [SvpPPolOps::bytes_of_svp_ppol] - fn new_scalar_znx_dft_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> ScalarZnxDft; - - /// Prepares a [crate::Scalar] for a [SvpPPolOps::svp_apply_dft]. - fn svp_prepare(&self, svp_ppol: &mut ScalarZnxDft, a: &Scalar); - - /// Applies the [SvpPPol] x [VecZnxDft] product, where each limb of - /// the [VecZnxDft] is multiplied with [SvpPPol]. - fn svp_apply_dft(&self, res: &mut VecZnxDft, res_col: usize, a: &ScalarZnxDft, b: &VecZnx, b_col: usize); +impl ZnxLayout for ScalarZnxDft { + type Scalar = f64; } -impl ScalarZnxDftOps for Module { - fn new_scalar_znx_dft(&self) -> ScalarZnxDft { - let mut data: Vec = alloc_aligned::(self.bytes_of_scalar_znx_dft()); - let ptr: *mut u8 = data.as_mut_ptr(); - ScalarZnxDft:: { - data: data, - ptr: ptr, - n: self.n(), - _marker: PhantomData, - } - } - - fn bytes_of_scalar_znx_dft(&self) -> usize { - unsafe { svp::bytes_of_svp_ppol(self.ptr) as usize } - } - - fn new_scalar_znx_dft_from_bytes(&self, bytes: &mut [u8]) -> ScalarZnxDft { - ScalarZnxDft::from_bytes(self, bytes) - } - - fn new_scalar_znx_dft_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> ScalarZnxDft { - ScalarZnxDft::from_bytes_borrow(self, tmp_bytes) - } - - fn svp_prepare(&self, res: &mut ScalarZnxDft, a: &Scalar) { - unsafe { svp::svp_prepare(self.ptr, res.ptr as *mut svp_ppol_t, a.as_ptr()) } - } - - fn svp_apply_dft(&self, res: &mut VecZnxDft, res_col: usize, a: &ScalarZnxDft, b: &VecZnx, b_col: usize) { - unsafe { - svp::svp_apply_dft( - self.ptr, - res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t, - res.size() as u64, - a.ptr as *const svp_ppol_t, - b.at_ptr(b_col, 0), - b.size() as u64, - b.sl() as u64, - ) - } +impl ZnxSliceSize for ScalarZnxDft { + fn sl(&self) -> usize { + self.n() } } diff --git a/base2k/src/scalar_znx_dft_ops.rs b/base2k/src/scalar_znx_dft_ops.rs new file mode 100644 index 0000000..4fbe99d --- /dev/null +++ b/base2k/src/scalar_znx_dft_ops.rs @@ -0,0 +1,63 @@ +use crate::ffi::svp::{self, svp_ppol_t}; +use crate::ffi::vec_znx_dft::vec_znx_dft_t; +use crate::znx_base::{ZnxAlloc, ZnxInfos, ZnxLayout, ZnxSliceSize}; +use crate::{Backend, FFT64, Module, SCALAR_ZNX_DFT_ROWS, SCALAR_ZNX_DFT_SIZE, Scalar, ScalarZnxDft, VecZnx, VecZnxDft}; + +pub trait ScalarZnxDftOps { + fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDft; + fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize; + fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxDft; + fn new_scalar_znx_dft_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> ScalarZnxDft; + fn svp_prepare(&self, res: &mut ScalarZnxDft, res_col: usize, a: &Scalar, a_col: usize); + fn svp_apply_dft(&self, res: &mut VecZnxDft, res_col: usize, a: &ScalarZnxDft, a_col: usize, b: &VecZnx, b_col: usize); +} + +impl ScalarZnxDftOps for Module { + fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDft { + ScalarZnxDft::::new(&self, SCALAR_ZNX_DFT_ROWS, cols, SCALAR_ZNX_DFT_SIZE) + } + + fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize { + ScalarZnxDft::::bytes_of(self, SCALAR_ZNX_DFT_ROWS, cols, SCALAR_ZNX_DFT_SIZE) + } + + fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxDft { + ScalarZnxDft::from_bytes(self, SCALAR_ZNX_DFT_ROWS, cols, SCALAR_ZNX_DFT_SIZE, bytes) + } + + fn new_scalar_znx_dft_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> ScalarZnxDft { + ScalarZnxDft::from_bytes_borrow(self, SCALAR_ZNX_DFT_ROWS, cols, SCALAR_ZNX_DFT_SIZE, bytes) + } + + fn svp_prepare(&self, res: &mut ScalarZnxDft, res_col: usize, a: &Scalar, a_col: usize) { + unsafe { + svp::svp_prepare( + self.ptr, + res.at_mut_ptr(res_col, 0) as *mut svp_ppol_t, + a.at_ptr(a_col, 0), + ) + } + } + + fn svp_apply_dft( + &self, + res: &mut VecZnxDft, + res_col: usize, + a: &ScalarZnxDft, + a_col: usize, + b: &VecZnx, + b_col: usize, + ) { + unsafe { + svp::svp_apply_dft( + self.ptr, + res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t, + res.size() as u64, + a.at_ptr(a_col, 0) as *const svp_ppol_t, + b.at_ptr(b_col, 0), + b.size() as u64, + b.sl() as u64, + ) + } + } +} diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index 09ee971..a9dd378 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -26,7 +26,7 @@ impl ZnxAlloc for VecZnxDft { type Scalar = u8; fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self { - VecZnxDft { + Self { inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_DFT_ROWS, cols, size, bytes), _marker: PhantomData, } diff --git a/base2k/src/vec_znx_ops.rs b/base2k/src/vec_znx_ops.rs index 7ee1529..6365ad3 100644 --- a/base2k/src/vec_znx_ops.rs +++ b/base2k/src/vec_znx_ops.rs @@ -47,47 +47,47 @@ pub trait VecZnxOps { &self, log_base2k: usize, res: &mut VecZnx, - col_res: usize, + res_col: usize, a: &VecZnx, - col_a: usize, + a_col: usize, tmp_bytes: &mut [u8], ); /// Normalizes the selected column of `a`. - fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, col_a: usize, tmp_bytes: &mut [u8]); + fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]); /// Adds the selected column of `a` to the selected column of `b` and write the result on the selected column of `c`. - fn vec_znx_add(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize, b: &VecZnx, col_b: usize); + fn vec_znx_add(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize, b: &VecZnx, b_col: usize); /// Adds the selected column of `a` to the selected column of `b` and write the result on the selected column of `res`. - fn vec_znx_add_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize); + fn vec_znx_add_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); /// Subtracts the selected column of `b` to the selected column of `a` and write the result on the selected column of `res`. - fn vec_znx_sub(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize, b: &VecZnx, col_b: usize); + fn vec_znx_sub(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize, b: &VecZnx, b_col: usize); /// Subtracts the selected column of `a` to the selected column of `res`. - fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize); + fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); /// Subtracts the selected column of `a` to the selected column of `res` and negates the selected column of `res`. - fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize); + fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); // Negates the selected column of `a` and stores the result on the selected column of `res`. - fn vec_znx_negate(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize); + fn vec_znx_negate(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); /// Negates the selected column of `a`. - fn vec_znx_negate_inplace(&self, a: &mut VecZnx, col_a: usize); + fn vec_znx_negate_inplace(&self, a: &mut VecZnx, a_col: usize); /// Multiplies the selected column of `a` by X^k and stores the result on the selected column of `res`. - fn vec_znx_rotate(&self, k: i64, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize); + fn vec_znx_rotate(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); /// Multiplies the selected column of `a` by X^k. - fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx, col_a: usize); + fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize); /// Applies the automorphism X^i -> X^ik on the selected column of `a` and stores the result on the selected column of `res`. - fn vec_znx_automorphism(&self, k: i64, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize); + fn vec_znx_automorphism(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize); /// Applies the automorphism X^i -> X^ik on the selected column of `a`. - fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, col_a: usize); + fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize); /// Splits the selected columns of `b` into subrings and copies them them into the selected column of `res`. /// @@ -95,7 +95,7 @@ pub trait VecZnxOps { /// /// This method requires that all [VecZnx] of b have the same ring degree /// and that b.n() * b.len() <= a.n() - fn vec_znx_split(&self, res: &mut Vec, col_res: usize, a: &VecZnx, col_a: usize, buf: &mut VecZnx); + fn vec_znx_split(&self, res: &mut Vec, res_col: usize, a: &VecZnx, a_col: usize, buf: &mut VecZnx); /// Merges the subrings of the selected column of `a` into the selected column of `res`. /// @@ -103,7 +103,7 @@ pub trait VecZnxOps { /// /// This method requires that all [VecZnx] of a have the same ring degree /// and that a.n() * a.len() <= b.n() - fn vec_znx_merge(&self, res: &mut VecZnx, col_res: usize, a: &Vec, col_a: usize); + fn vec_znx_merge(&self, res: &mut VecZnx, res_col: usize, a: &Vec, a_col: usize); } impl VecZnxOps for Module { @@ -131,9 +131,9 @@ impl VecZnxOps for Module { &self, log_base2k: usize, res: &mut VecZnx, - col_res: usize, + res_col: usize, a: &VecZnx, - col_a: usize, + a_col: usize, tmp_bytes: &mut [u8], ) { #[cfg(debug_assertions)] @@ -147,10 +147,10 @@ impl VecZnxOps for Module { vec_znx::vec_znx_normalize_base2k( self.ptr, log_base2k as u64, - res.at_mut_ptr(col_res, 0), + res.at_mut_ptr(res_col, 0), res.size() as u64, res.sl() as u64, - a.at_ptr(col_a, 0), + a.at_ptr(a_col, 0), a.size() as u64, a.sl() as u64, tmp_bytes.as_mut_ptr(), @@ -158,22 +158,22 @@ impl VecZnxOps for Module { } } - fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, col_a: usize, tmp_bytes: &mut [u8]) { + fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]) { unsafe { let a_ptr: *mut VecZnx = a as *mut VecZnx; Self::vec_znx_normalize( self, log_base2k, &mut *a_ptr, - col_a, + a_col, &*a_ptr, - col_a, + a_col, tmp_bytes, ); } } - fn vec_znx_add(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize, b: &VecZnx, col_b: usize) { + fn vec_znx_add(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize, b: &VecZnx, b_col: usize) { #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -184,27 +184,27 @@ impl VecZnxOps for Module { unsafe { vec_znx::vec_znx_add( self.ptr, - res.at_mut_ptr(col_res, 0), + res.at_mut_ptr(res_col, 0), res.size() as u64, res.sl() as u64, - a.at_ptr(col_a, 0), + a.at_ptr(a_col, 0), a.size() as u64, a.sl() as u64, - b.at_ptr(col_b, 0), + b.at_ptr(b_col, 0), b.size() as u64, b.sl() as u64, ) } } - fn vec_znx_add_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) { + fn vec_znx_add_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { unsafe { let res_ptr: *mut VecZnx = res as *mut VecZnx; - Self::vec_znx_add(self, &mut *res_ptr, col_res, a, col_a, &*res_ptr, col_res); + Self::vec_znx_add(self, &mut *res_ptr, res_col, a, a_col, &*res_ptr, res_col); } } - fn vec_znx_sub(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize, b: &VecZnx, col_b: usize) { + fn vec_znx_sub(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize, b: &VecZnx, b_col: usize) { #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -215,34 +215,34 @@ impl VecZnxOps for Module { unsafe { vec_znx::vec_znx_sub( self.ptr, - res.at_mut_ptr(col_res, 0), + res.at_mut_ptr(res_col, 0), res.size() as u64, res.sl() as u64, - a.at_ptr(col_a, 0), + a.at_ptr(a_col, 0), a.size() as u64, a.sl() as u64, - b.at_ptr(col_b, 0), + b.at_ptr(b_col, 0), b.size() as u64, b.sl() as u64, ) } } - fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) { + fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { unsafe { let res_ptr: *mut VecZnx = res as *mut VecZnx; - Self::vec_znx_sub(self, &mut *res_ptr, col_res, a, col_a, &*res_ptr, col_res); + Self::vec_znx_sub(self, &mut *res_ptr, res_col, a, a_col, &*res_ptr, res_col); } } - fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) { + fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { unsafe { let res_ptr: *mut VecZnx = res as *mut VecZnx; - Self::vec_znx_sub(self, &mut *res_ptr, col_res, &*res_ptr, col_res, a, col_a); + Self::vec_znx_sub(self, &mut *res_ptr, res_col, &*res_ptr, res_col, a, a_col); } } - fn vec_znx_negate(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) { + fn vec_znx_negate(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -251,24 +251,24 @@ impl VecZnxOps for Module { unsafe { vec_znx::vec_znx_negate( self.ptr, - res.at_mut_ptr(col_res, 0), + res.at_mut_ptr(res_col, 0), res.size() as u64, res.sl() as u64, - a.at_ptr(col_a, 0), + a.at_ptr(a_col, 0), a.size() as u64, a.sl() as u64, ) } } - fn vec_znx_negate_inplace(&self, a: &mut VecZnx, col_a: usize) { + fn vec_znx_negate_inplace(&self, a: &mut VecZnx, a_col: usize) { unsafe { let a_ptr: *mut VecZnx = a as *mut VecZnx; - Self::vec_znx_negate(self, &mut *a_ptr, col_a, &*a_ptr, col_a); + Self::vec_znx_negate(self, &mut *a_ptr, a_col, &*a_ptr, a_col); } } - fn vec_znx_rotate(&self, k: i64, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) { + fn vec_znx_rotate(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -278,24 +278,24 @@ impl VecZnxOps for Module { vec_znx::vec_znx_rotate( self.ptr, k, - res.at_mut_ptr(col_res, 0), + res.at_mut_ptr(res_col, 0), res.size() as u64, res.sl() as u64, - a.at_ptr(col_a, 0), + a.at_ptr(a_col, 0), a.size() as u64, a.sl() as u64, ) } } - fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx, col_a: usize) { + fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize) { unsafe { let a_ptr: *mut VecZnx = a as *mut VecZnx; - Self::vec_znx_rotate(self, k, &mut *a_ptr, col_a, &*a_ptr, col_a); + Self::vec_znx_rotate(self, k, &mut *a_ptr, a_col, &*a_ptr, a_col); } } - fn vec_znx_automorphism(&self, k: i64, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) { + fn vec_znx_automorphism(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) { #[cfg(debug_assertions)] { assert_eq!(a.n(), self.n()); @@ -305,24 +305,24 @@ impl VecZnxOps for Module { vec_znx::vec_znx_automorphism( self.ptr, k, - res.at_mut_ptr(col_res, 0), + res.at_mut_ptr(res_col, 0), res.size() as u64, res.sl() as u64, - a.at_ptr(col_a, 0), + a.at_ptr(a_col, 0), a.size() as u64, a.sl() as u64, ) } } - fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, col_a: usize) { + fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize) { unsafe { let a_ptr: *mut VecZnx = a as *mut VecZnx; - Self::vec_znx_automorphism(self, k, &mut *a_ptr, col_a, &*a_ptr, col_a); + Self::vec_znx_automorphism(self, k, &mut *a_ptr, a_col, &*a_ptr, a_col); } } - fn vec_znx_split(&self, res: &mut Vec, col_res: usize, a: &VecZnx, col_a: usize, buf: &mut VecZnx) { + fn vec_znx_split(&self, res: &mut Vec, res_col: usize, a: &VecZnx, a_col: usize, buf: &mut VecZnx) { let (n_in, n_out) = (a.n(), res[0].n()); debug_assert!( @@ -339,16 +339,16 @@ impl VecZnxOps for Module { res.iter_mut().enumerate().for_each(|(i, bi)| { if i == 0 { - switch_degree(bi, col_res, a, col_a); - self.vec_znx_rotate(-1, buf, 0, a, col_a); + switch_degree(bi, res_col, a, a_col); + self.vec_znx_rotate(-1, buf, 0, a, a_col); } else { - switch_degree(bi, col_res, buf, col_a); - self.vec_znx_rotate_inplace(-1, buf, col_a); + switch_degree(bi, res_col, buf, a_col); + self.vec_znx_rotate_inplace(-1, buf, a_col); } }) } - fn vec_znx_merge(&self, res: &mut VecZnx, col_res: usize, a: &Vec, col_a: usize) { + fn vec_znx_merge(&self, res: &mut VecZnx, res_col: usize, a: &Vec, a_col: usize) { let (n_in, n_out) = (res.n(), a[0].n()); debug_assert!( @@ -364,10 +364,10 @@ impl VecZnxOps for Module { }); a.iter().enumerate().for_each(|(_, ai)| { - switch_degree(res, col_res, ai, col_a); - self.vec_znx_rotate_inplace(-1, res, col_res); + switch_degree(res, res_col, ai, a_col); + self.vec_znx_rotate_inplace(-1, res, res_col); }); - self.vec_znx_rotate_inplace(a.len() as i64, res, col_res); + self.vec_znx_rotate_inplace(a.len() as i64, res, res_col); } }