wip major refactoring (compiles & all test + example passing)

This commit is contained in:
Jean-Philippe Bossuat
2025-04-30 13:43:18 +02:00
parent 2cc51eee18
commit 6f7b93c7ca
18 changed files with 662 additions and 870 deletions

View File

@@ -23,10 +23,10 @@ fn main() {
s.fill_ternary_prob(0.5, &mut source); s.fill_ternary_prob(0.5, &mut source);
// Buffer to store s in the DFT domain // Buffer to store s in the DFT domain
let mut s_ppol: ScalarZnxDft<FFT64> = module.new_svp_ppol(); let mut s_dft: ScalarZnxDft<FFT64> = module.new_scalar_znx_dft();
// s_ppol <- DFT(s) // s_dft <- DFT(s)
module.svp_prepare(&mut s_ppol, &s); module.svp_prepare(&mut s_dft, &s);
// Allocates a VecZnx with two columns: ct=(0, 0) // Allocates a VecZnx with two columns: ct=(0, 0)
let mut ct: VecZnx = module.new_vec_znx( let mut ct: VecZnx = module.new_vec_znx(
@@ -46,16 +46,17 @@ fn main() {
// Applies DFT(ct[1]) * DFT(s) // Applies DFT(ct[1]) * DFT(s)
module.svp_apply_dft( module.svp_apply_dft(
&mut buf_dft, // DFT(ct[1] * s) &mut buf_dft, // DFT(ct[1] * s)
&s_ppol, // DFT(s) 0, // Selects the first column of res
&s_dft, // DFT(s)
&ct, &ct,
1, // Selects the second column of ct 1, // Selects the second column of ct
); );
// Alias scratch space (VecZnxDft<B> is always at least as big as VecZnxBig<B>) // Alias scratch space (VecZnxDft<B> is always at least as big as VecZnxBig<B>)
let mut buf_big: VecZnxBig<FFT64> = buf_dft.as_vec_znx_big(); let mut buf_big: VecZnxBig<FFT64> = buf_dft.alias_as_vec_znx_big();
// BIG(ct[1] * s) <- IDFT(DFT(ct[1] * s)) (not normalized) // BIG(ct[1] * s) <- IDFT(DFT(ct[1] * s)) (not normalized)
module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft); module.vec_znx_idft_tmp_a(&mut buf_big, 0, &mut buf_dft, 0);
// Creates a plaintext: VecZnx with 1 column // Creates a plaintext: VecZnx with 1 column
let mut m: VecZnx = module.new_vec_znx( let mut m: VecZnx = module.new_vec_znx(
@@ -103,13 +104,14 @@ fn main() {
// DFT(ct[1] * s) // DFT(ct[1] * s)
module.svp_apply_dft( module.svp_apply_dft(
&mut buf_dft, &mut buf_dft,
&s_ppol, 0, // Selects the first column of res.
&s_dft,
&ct, &ct,
1, // Selects the second column of ct (ct[1]) 1, // Selects the second column of ct (ct[1])
); );
// BIG(c1 * s) = IDFT(DFT(c1 * s)) // BIG(c1 * s) = IDFT(DFT(c1 * s))
module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft); module.vec_znx_idft_tmp_a(&mut buf_big, 0, &mut buf_dft, 0);
// BIG(c1 * s) + ct[0] // BIG(c1 * s) + ct[0]
module.vec_znx_big_add_small_inplace(&mut buf_big, 0, &ct, 0); module.vec_znx_big_add_small_inplace(&mut buf_big, 0, &ct, 0);

View File

@@ -42,8 +42,8 @@ fn main() {
let mut c_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, limbs_mat); let mut c_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, limbs_mat);
module.vmp_apply_dft(&mut c_dft, &a, &mat_znx_dft, &mut buf); module.vmp_apply_dft(&mut c_dft, &a, &mat_znx_dft, &mut buf);
let mut c_big: VecZnxBig<FFT64> = c_dft.as_vec_znx_big(); let mut c_big: VecZnxBig<FFT64> = c_dft.alias_as_vec_znx_big();
module.vec_znx_idft_tmp_a(&mut c_big, &mut c_dft); module.vec_znx_idft_tmp_a(&mut c_big, 0, &mut c_dft, 0);
let mut res: VecZnx = module.new_vec_znx(1, limbs_vec); let mut res: VecZnx = module.new_vec_znx(1, limbs_vec);
module.vec_znx_big_normalize(log_base2k, &mut res, 0, &c_big, 0, &mut buf); module.vec_znx_big_normalize(log_base2k, &mut res, 0, &c_big, 0, &mut buf);

View File

@@ -1,5 +1,6 @@
use crate::ffi::znx::znx_zero_i64_ref; use crate::ffi::znx::znx_zero_i64_ref;
use crate::{VecZnx, ZnxInfos, ZnxLayout}; use crate::znx_base::ZnxLayout;
use crate::{VecZnx, znx_base::ZnxInfos};
use itertools::izip; use itertools::izip;
use rug::{Assign, Float}; use rug::{Assign, Float};
use std::cmp::min; use std::cmp::min;
@@ -262,7 +263,10 @@ fn decode_coeff_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, i
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::{Encoding, FFT64, Module, VecZnx, ZnxBase, ZnxInfos, ZnxLayout}; use crate::{
Encoding, FFT64, Module, VecZnx, VecZnxOps,
znx_base::{ZnxInfos, ZnxLayout},
};
use itertools::izip; use itertools::izip;
use sampling::source::Source; use sampling::source::Source;
@@ -273,7 +277,7 @@ mod tests {
let log_base2k: usize = 17; let log_base2k: usize = 17;
let size: usize = 5; let size: usize = 5;
let log_k: usize = size * log_base2k - 5; let log_k: usize = size * log_base2k - 5;
let mut a: VecZnx = VecZnx::new(&module, 2, size); let mut a: VecZnx = module.new_vec_znx(2, size);
let mut source: Source = Source::new([0u8; 32]); let mut source: Source = Source::new([0u8; 32]);
let raw: &mut [i64] = a.raw_mut(); let raw: &mut [i64] = a.raw_mut();
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
@@ -295,7 +299,7 @@ mod tests {
let log_base2k: usize = 17; let log_base2k: usize = 17;
let size: usize = 5; let size: usize = 5;
let log_k: usize = size * log_base2k - 5; let log_k: usize = size * log_base2k - 5;
let mut a: VecZnx = VecZnx::new(&module, 2, size); let mut a: VecZnx = module.new_vec_znx(2, size);
let mut source = Source::new([0u8; 32]); let mut source = Source::new([0u8; 32]);
let raw: &mut [i64] = a.raw_mut(); let raw: &mut [i64] = a.raw_mut();
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);

View File

@@ -1,4 +1,3 @@
pub mod commons;
pub mod encoding; pub mod encoding;
#[allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code, improper_ctypes)] #[allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code, improper_ctypes)]
// Other modules and exports // Other modules and exports
@@ -12,9 +11,10 @@ pub mod vec_znx;
pub mod vec_znx_big; pub mod vec_znx_big;
pub mod vec_znx_big_ops; pub mod vec_znx_big_ops;
pub mod vec_znx_dft; pub mod vec_znx_dft;
pub mod vec_znx_dft_ops;
pub mod vec_znx_ops; pub mod vec_znx_ops;
pub mod znx_base;
pub use commons::*;
pub use encoding::*; pub use encoding::*;
pub use mat_znx_dft::*; pub use mat_znx_dft::*;
pub use module::*; pub use module::*;
@@ -26,7 +26,9 @@ pub use vec_znx::*;
pub use vec_znx_big::*; pub use vec_znx_big::*;
pub use vec_znx_big_ops::*; pub use vec_znx_big_ops::*;
pub use vec_znx_dft::*; pub use vec_znx_dft::*;
pub use vec_znx_dft_ops::*;
pub use vec_znx_ops::*; pub use vec_znx_ops::*;
pub use znx_base::*;
pub const GALOISGENERATOR: u64 = 5; pub const GALOISGENERATOR: u64 = 5;
pub const DEFAULTALIGN: usize = 64; pub const DEFAULTALIGN: usize = 64;
@@ -110,14 +112,8 @@ pub fn alloc_aligned_custom<T>(size: usize, align: usize) -> Vec<T> {
unsafe { Vec::from_raw_parts(ptr, len, cap) } unsafe { Vec::from_raw_parts(ptr, len, cap) }
} }
// Allocates an aligned of size equal to the smallest power of two equal or greater to `size` that is /// Allocates an aligned of size equal to the smallest multiple
// at least as bit as DEFAULTALIGN / std::mem::size_of::<T>(). /// of [DEFAULTALIGN] that is equal or greater to `size`.
pub fn alloc_aligned<T>(size: usize) -> Vec<T> { pub fn alloc_aligned<T>(size: usize) -> Vec<T> {
alloc_aligned_custom::<T>( alloc_aligned_custom::<T>(size + (size % DEFAULTALIGN), DEFAULTALIGN)
std::cmp::max(
size.next_power_of_two(),
DEFAULTALIGN / std::mem::size_of::<T>(),
),
DEFAULTALIGN,
)
} }

View File

@@ -1,103 +1,75 @@
use crate::ffi::vec_znx_big::vec_znx_big_t; use crate::ffi::vec_znx_big::vec_znx_big_t;
use crate::ffi::vec_znx_dft::vec_znx_dft_t; use crate::ffi::vec_znx_dft::vec_znx_dft_t;
use crate::ffi::vmp::{self, vmp_pmat_t}; use crate::ffi::vmp::{self, vmp_pmat_t};
use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxDft, ZnxInfos, ZnxLayout, alloc_aligned, assert_alignement}; use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize};
use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxDft, alloc_aligned, assert_alignement};
use std::marker::PhantomData; use std::marker::PhantomData;
/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx], /// Vector Matrix Product Prepared Matrix: a vector of [VecZnx],
/// stored as a 3D matrix in the DFT domain in a single contiguous array. /// stored as a 3D matrix in the DFT domain in a single contiguous array.
/// Each col of the [VmpPMat] can be seen as a collection of [VecZnxDft]. /// Each col of the [MatZnxDft] can be seen as a collection of [VecZnxDft].
/// ///
/// [VmpPMat] is used to permform a vector matrix product between a [VecZnx]/[VecZnxDft] and a [VmpPMat]. /// [MatZnxDft] is used to permform a vector matrix product between a [VecZnx]/[VecZnxDft] and a [MatZnxDft].
/// See the trait [VmpPMatOps] for additional information. /// See the trait [MatZnxDftOps] for additional information.
pub struct MatZnxDft<B: Backend> { pub struct MatZnxDft<B: Backend> {
/// Raw data, is empty if borrowing scratch space. pub inner: ZnxBase,
data: Vec<u8>,
/// Pointer to data. Can point to scratch space.
ptr: *mut u8,
/// The ring degree of each polynomial.
n: usize,
/// Number of rows
rows: usize,
/// Number of cols
cols: usize,
/// The number of small polynomials
size: usize,
_marker: PhantomData<B>, _marker: PhantomData<B>,
} }
impl<B: Backend> ZnxInfos for MatZnxDft<B> { impl<B: Backend> GetZnxBase for MatZnxDft<B> {
fn n(&self) -> usize { fn znx(&self) -> &ZnxBase {
self.n &self.inner
} }
fn rows(&self) -> usize { fn znx_mut(&mut self) -> &mut ZnxBase {
self.rows &mut self.inner
}
fn cols(&self) -> usize {
self.cols
}
fn size(&self) -> usize {
self.size
} }
} }
impl MatZnxDft<FFT64> { impl<B: Backend> ZnxInfos for MatZnxDft<B> {}
fn new(module: &Module<FFT64>, rows: usize, cols: usize, size: usize) -> MatZnxDft<FFT64> {
let mut data: Vec<u8> = alloc_aligned::<u8>(module.bytes_of_mat_znx_dft(rows, cols, size)); impl ZnxSliceSize for MatZnxDft<FFT64> {
let ptr: *mut u8 = data.as_mut_ptr(); fn sl(&self) -> usize {
MatZnxDft::<FFT64> { self.n()
data: data, }
ptr: ptr, }
n: module.n(),
rows: rows, impl ZnxLayout for MatZnxDft<FFT64> {
cols: cols, type Scalar = f64;
size: size, }
impl<B: Backend> ZnxAlloc<B> for MatZnxDft<B> {
type Scalar = u8;
fn from_bytes_borrow(module: &Module<B>, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
Self {
inner: ZnxBase::from_bytes_borrow(module.n(), rows, cols, size, bytes),
_marker: PhantomData, _marker: PhantomData,
} }
} }
pub fn as_ptr(&self) -> *const u8 { fn bytes_of(module: &Module<B>, rows: usize, cols: usize, size: usize) -> usize {
self.ptr unsafe { vmp::bytes_of_vmp_pmat(module.ptr, rows as u64, size as u64) as usize * cols }
} }
}
pub fn as_mut_ptr(&self) -> *mut u8 { impl MatZnxDft<FFT64> {
self.ptr /// Returns a copy of the backend array at index (i, j) of the [MatZnxDft].
}
pub fn borrowed(&self) -> bool {
self.data.len() == 0
}
/// Returns a non-mutable reference to the entire contiguous array of the [VmpPMat].
pub fn raw(&self) -> &[f64] {
let ptr: *const f64 = self.ptr as *const f64;
let size: usize = self.n() * self.poly_count();
unsafe { &std::slice::from_raw_parts(ptr, size) }
}
/// Returns a mutable reference of to the entire contiguous array of the [VmpPMat].
pub fn raw_mut(&self) -> &mut [f64] {
let ptr: *mut f64 = self.ptr as *mut f64;
let size: usize = self.n() * self.poly_count();
unsafe { std::slice::from_raw_parts_mut(ptr, size) }
}
/// Returns a copy of the backend array at index (i, j) of the [VmpPMat].
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `row`: row index (i). /// * `row`: row index (i).
/// * `col`: col index (j). /// * `col`: col index (j).
pub fn at(&self, row: usize, col: usize) -> Vec<f64> { #[allow(dead_code)]
let mut res: Vec<f64> = alloc_aligned(self.n); fn at(&self, row: usize, col: usize) -> Vec<f64> {
let n: usize = self.n();
if self.n < 8 { let mut res: Vec<f64> = alloc_aligned(n);
res.copy_from_slice(&self.raw()[(row + col * self.rows()) * self.n()..(row + col * self.rows()) * (self.n() + 1)]);
if n < 8 {
res.copy_from_slice(&self.raw()[(row + col * self.rows()) * n..(row + col * self.rows()) * (n + 1)]);
} else { } else {
(0..self.n >> 3).for_each(|blk| { (0..n >> 3).for_each(|blk| {
res[blk * 8..(blk + 1) * 8].copy_from_slice(&self.at_block(row, col, blk)[..8]); res[blk * 8..(blk + 1) * 8].copy_from_slice(&self.at_block(row, col, blk)[..8]);
}); });
} }
@@ -105,6 +77,7 @@ impl MatZnxDft<FFT64> {
res res
} }
#[allow(dead_code)]
fn at_block(&self, row: usize, col: usize, blk: usize) -> &[f64] { fn at_block(&self, row: usize, col: usize, blk: usize) -> &[f64] {
let nrows: usize = self.rows(); let nrows: usize = self.rows();
let nsize: usize = self.size(); let nsize: usize = self.size();
@@ -117,11 +90,11 @@ impl MatZnxDft<FFT64> {
} }
/// This trait implements methods for vector matrix product, /// This trait implements methods for vector matrix product,
/// that is, multiplying a [VecZnx] with a [VmpPMat]. /// that is, multiplying a [VecZnx] with a [MatZnxDft].
pub trait MatZnxDftOps<B: Backend> { pub trait MatZnxDftOps<B: Backend> {
fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> usize; fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> usize;
/// Allocates a new [VmpPMat] with the given number of rows and columns. /// Allocates a new [MatZnxDft] with the given number of rows and columns.
/// ///
/// # Arguments /// # Arguments
/// ///
@@ -129,83 +102,83 @@ pub trait MatZnxDftOps<B: Backend> {
/// * `size`: number of size (number of size of each [VecZnxDft]). /// * `size`: number of size (number of size of each [VecZnxDft]).
fn new_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> MatZnxDft<B>; fn new_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> MatZnxDft<B>;
/// Returns the number of bytes needed as scratch space for [VmpPMatOps::vmp_prepare_contiguous]. /// Returns the number of bytes needed as scratch space for [MatZnxDftOps::vmp_prepare_contiguous].
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `rows`: number of rows of the [VmpPMat] used in [VmpPMatOps::vmp_prepare_contiguous]. /// * `rows`: number of rows of the [MatZnxDft] used in [MatZnxDftOps::vmp_prepare_contiguous].
/// * `size`: number of size of the [VmpPMat] used in [VmpPMatOps::vmp_prepare_contiguous]. /// * `size`: number of size of the [MatZnxDft] used in [MatZnxDftOps::vmp_prepare_contiguous].
fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize, size: usize) -> usize; fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize, size: usize) -> usize;
/// Prepares a [VmpPMat] from a contiguous array of [i64]. /// Prepares a [MatZnxDft] from a contiguous array of [i64].
/// The helper struct [Matrix3D] can be used to contruct and populate /// The helper struct [Matrix3D] can be used to contruct and populate
/// the appropriate contiguous array. /// the appropriate contiguous array.
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `b`: [VmpPMat] on which the values are encoded. /// * `b`: [MatZnxDft] on which the values are encoded.
/// * `a`: the contiguous array of [i64] of the 3D matrix to encode on the [VmpPMat]. /// * `a`: the contiguous array of [i64] of the 3D matrix to encode on the [MatZnxDft].
/// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. /// * `buf`: scratch space, the size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes].
fn vmp_prepare_contiguous(&self, b: &mut MatZnxDft<B>, a: &[i64], buf: &mut [u8]); fn vmp_prepare_contiguous(&self, b: &mut MatZnxDft<B>, a: &[i64], buf: &mut [u8]);
/// Prepares the ith-row of [VmpPMat] from a [VecZnx]. /// Prepares the ith-row of [MatZnxDft] from a [VecZnx].
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `b`: [VmpPMat] on which the values are encoded. /// * `b`: [MatZnxDft] on which the values are encoded.
/// * `a`: the vector of [VecZnx] to encode on the [VmpPMat]. /// * `a`: the vector of [VecZnx] to encode on the [MatZnxDft].
/// * `row_i`: the index of the row to prepare. /// * `row_i`: the index of the row to prepare.
/// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. /// * `buf`: scratch space, the size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes].
/// ///
/// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. /// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes].
fn vmp_prepare_row(&self, b: &mut MatZnxDft<B>, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]); fn vmp_prepare_row(&self, b: &mut MatZnxDft<B>, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]);
/// Extracts the ith-row of [VmpPMat] into a [VecZnxBig]. /// Extracts the ith-row of [MatZnxDft] into a [VecZnxBig].
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `b`: the [VecZnxBig] to on which to extract the row of the [VmpPMat]. /// * `b`: the [VecZnxBig] to on which to extract the row of the [MatZnxDft].
/// * `a`: [VmpPMat] on which the values are encoded. /// * `a`: [MatZnxDft] on which the values are encoded.
/// * `row_i`: the index of the row to extract. /// * `row_i`: the index of the row to extract.
fn vmp_extract_row(&self, b: &mut VecZnxBig<B>, a: &MatZnxDft<B>, row_i: usize); fn vmp_extract_row(&self, b: &mut VecZnxBig<B>, a: &MatZnxDft<B>, row_i: usize);
/// Prepares the ith-row of [VmpPMat] from a [VecZnxDft]. /// Prepares the ith-row of [MatZnxDft] from a [VecZnxDft].
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `b`: [VmpPMat] on which the values are encoded. /// * `b`: [MatZnxDft] on which the values are encoded.
/// * `a`: the [VecZnxDft] to encode on the [VmpPMat]. /// * `a`: the [VecZnxDft] to encode on the [MatZnxDft].
/// * `row_i`: the index of the row to prepare. /// * `row_i`: the index of the row to prepare.
/// ///
/// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes]. /// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes].
fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft<B>, a: &VecZnxDft<B>, row_i: usize); fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft<B>, a: &VecZnxDft<B>, row_i: usize);
/// Extracts the ith-row of [VmpPMat] into a [VecZnxDft]. /// Extracts the ith-row of [MatZnxDft] into a [VecZnxDft].
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `b`: the [VecZnxDft] to on which to extract the row of the [VmpPMat]. /// * `b`: the [VecZnxDft] to on which to extract the row of the [MatZnxDft].
/// * `a`: [VmpPMat] on which the values are encoded. /// * `a`: [MatZnxDft] on which the values are encoded.
/// * `row_i`: the index of the row to extract. /// * `row_i`: the index of the row to extract.
fn vmp_extract_row_dft(&self, b: &mut VecZnxDft<B>, a: &MatZnxDft<B>, row_i: usize); fn vmp_extract_row_dft(&self, b: &mut VecZnxDft<B>, a: &MatZnxDft<B>, row_i: usize);
/// Returns the size of the stratch space necessary for [VmpPMatOps::vmp_apply_dft]. /// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft].
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `c_size`: number of size of the output [VecZnxDft]. /// * `c_size`: number of size of the output [VecZnxDft].
/// * `a_size`: number of size of the input [VecZnx]. /// * `a_size`: number of size of the input [VecZnx].
/// * `rows`: number of rows of the input [VmpPMat]. /// * `rows`: number of rows of the input [MatZnxDft].
/// * `size`: number of size of the input [VmpPMat]. /// * `size`: number of size of the input [MatZnxDft].
fn vmp_apply_dft_tmp_bytes(&self, c_size: usize, a_size: usize, rows: usize, size: usize) -> usize; fn vmp_apply_dft_tmp_bytes(&self, c_size: usize, a_size: usize, rows: usize, size: usize) -> usize;
/// Applies the vector matrix product [VecZnxDft] x [VmpPMat]. /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft].
/// ///
/// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft] /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
/// and each vector a [VecZnxDft] (row) of the [VmpPMat]. /// and each vector a [VecZnxDft] (row) of the [MatZnxDft].
/// ///
/// As such, given an input [VecZnx] of `i` size and a [VmpPMat] of `i` rows and /// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and
/// `j` size, the output is a [VecZnx] of `j` size. /// `j` size, the output is a [VecZnx] of `j` size.
/// ///
/// If there is a mismatch between the dimensions the largest valid ones are used. /// If there is a mismatch between the dimensions the largest valid ones are used.
@@ -221,17 +194,17 @@ pub trait MatZnxDftOps<B: Backend> {
/// ///
/// * `c`: the output of the vector matrix product, as a [VecZnxDft]. /// * `c`: the output of the vector matrix product, as a [VecZnxDft].
/// * `a`: the left operand [VecZnx] of the vector matrix product. /// * `a`: the left operand [VecZnx] of the vector matrix product.
/// * `b`: the right operand [VmpPMat] of the vector matrix product. /// * `b`: the right operand [MatZnxDft] of the vector matrix product.
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_tmp_bytes]. /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_tmp_bytes].
fn vmp_apply_dft(&self, c: &mut VecZnxDft<B>, a: &VecZnx, b: &MatZnxDft<B>, buf: &mut [u8]); fn vmp_apply_dft(&self, c: &mut VecZnxDft<B>, a: &VecZnx, b: &MatZnxDft<B>, buf: &mut [u8]);
/// Applies the vector matrix product [VecZnxDft] x [VmpPMat] and adds on the receiver. /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft] and adds on the receiver.
/// ///
/// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft] /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
/// and each vector a [VecZnxDft] (row) of the [VmpPMat]. /// and each vector a [VecZnxDft] (row) of the [MatZnxDft].
/// ///
/// As such, given an input [VecZnx] of `i` size and a [VmpPMat] of `i` rows and /// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and
/// `j` size, the output is a [VecZnx] of `j` size. /// `j` size, the output is a [VecZnx] of `j` size.
/// ///
/// If there is a mismatch between the dimensions the largest valid ones are used. /// If there is a mismatch between the dimensions the largest valid ones are used.
@@ -247,28 +220,28 @@ pub trait MatZnxDftOps<B: Backend> {
/// ///
/// * `c`: the operand on which the output of the vector matrix product is added, as a [VecZnxDft]. /// * `c`: the operand on which the output of the vector matrix product is added, as a [VecZnxDft].
/// * `a`: the left operand [VecZnx] of the vector matrix product. /// * `a`: the left operand [VecZnx] of the vector matrix product.
/// * `b`: the right operand [VmpPMat] of the vector matrix product. /// * `b`: the right operand [MatZnxDft] of the vector matrix product.
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_tmp_bytes]. /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_tmp_bytes].
fn vmp_apply_dft_add(&self, c: &mut VecZnxDft<B>, a: &VecZnx, b: &MatZnxDft<B>, buf: &mut [u8]); fn vmp_apply_dft_add(&self, c: &mut VecZnxDft<B>, a: &VecZnx, b: &MatZnxDft<B>, buf: &mut [u8]);
/// Returns the size of the stratch space necessary for [VmpPMatOps::vmp_apply_dft_to_dft]. /// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft_to_dft].
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `c_size`: number of size of the output [VecZnxDft]. /// * `c_size`: number of size of the output [VecZnxDft].
/// * `a_size`: number of size of the input [VecZnxDft]. /// * `a_size`: number of size of the input [VecZnxDft].
/// * `rows`: number of rows of the input [VmpPMat]. /// * `rows`: number of rows of the input [MatZnxDft].
/// * `size`: number of size of the input [VmpPMat]. /// * `size`: number of size of the input [MatZnxDft].
fn vmp_apply_dft_to_dft_tmp_bytes(&self, c_size: usize, a_size: usize, rows: usize, size: usize) -> usize; fn vmp_apply_dft_to_dft_tmp_bytes(&self, c_size: usize, a_size: usize, rows: usize, size: usize) -> usize;
/// Applies the vector matrix product [VecZnxDft] x [VmpPMat]. /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft].
/// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. /// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
/// ///
/// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft] /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
/// and each vector a [VecZnxDft] (row) of the [VmpPMat]. /// and each vector a [VecZnxDft] (row) of the [MatZnxDft].
/// ///
/// As such, given an input [VecZnx] of `i` size and a [VmpPMat] of `i` rows and /// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and
/// `j` size, the output is a [VecZnx] of `j` size. /// `j` size, the output is a [VecZnx] of `j` size.
/// ///
/// If there is a mismatch between the dimensions the largest valid ones are used. /// If there is a mismatch between the dimensions the largest valid ones are used.
@@ -284,18 +257,18 @@ pub trait MatZnxDftOps<B: Backend> {
/// ///
/// * `c`: the output of the vector matrix product, as a [VecZnxDft]. /// * `c`: the output of the vector matrix product, as a [VecZnxDft].
/// * `a`: the left operand [VecZnxDft] of the vector matrix product. /// * `a`: the left operand [VecZnxDft] of the vector matrix product.
/// * `b`: the right operand [VmpPMat] of the vector matrix product. /// * `b`: the right operand [MatZnxDft] of the vector matrix product.
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft<B>, a: &VecZnxDft<B>, b: &MatZnxDft<B>, buf: &mut [u8]); fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft<B>, a: &VecZnxDft<B>, b: &MatZnxDft<B>, buf: &mut [u8]);
/// Applies the vector matrix product [VecZnxDft] x [VmpPMat] and adds on top of the receiver instead of overwritting it. /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft] and adds on top of the receiver instead of overwritting it.
/// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. /// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
/// ///
/// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft] /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
/// and each vector a [VecZnxDft] (row) of the [VmpPMat]. /// and each vector a [VecZnxDft] (row) of the [MatZnxDft].
/// ///
/// As such, given an input [VecZnx] of `i` size and a [VmpPMat] of `i` rows and /// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and
/// `j` size, the output is a [VecZnx] of `j` size. /// `j` size, the output is a [VecZnx] of `j` size.
/// ///
/// If there is a mismatch between the dimensions the largest valid ones are used. /// If there is a mismatch between the dimensions the largest valid ones are used.
@@ -311,18 +284,18 @@ pub trait MatZnxDftOps<B: Backend> {
/// ///
/// * `c`: the operand on which the output of the vector matrix product is added, as a [VecZnxDft]. /// * `c`: the operand on which the output of the vector matrix product is added, as a [VecZnxDft].
/// * `a`: the left operand [VecZnxDft] of the vector matrix product. /// * `a`: the left operand [VecZnxDft] of the vector matrix product.
/// * `b`: the right operand [VmpPMat] of the vector matrix product. /// * `b`: the right operand [MatZnxDft] of the vector matrix product.
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
fn vmp_apply_dft_to_dft_add(&self, c: &mut VecZnxDft<B>, a: &VecZnxDft<B>, b: &MatZnxDft<B>, buf: &mut [u8]); fn vmp_apply_dft_to_dft_add(&self, c: &mut VecZnxDft<B>, a: &VecZnxDft<B>, b: &MatZnxDft<B>, buf: &mut [u8]);
/// Applies the vector matrix product [VecZnxDft] x [VmpPMat] in place. /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft] in place.
/// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. /// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
/// ///
/// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft] /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol]) /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
/// and each vector a [VecZnxDft] (row) of the [VmpPMat]. /// and each vector a [VecZnxDft] (row) of the [MatZnxDft].
/// ///
/// As such, given an input [VecZnx] of `i` size and a [VmpPMat] of `i` rows and /// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and
/// `j` size, the output is a [VecZnx] of `j` size. /// `j` size, the output is a [VecZnx] of `j` size.
/// ///
/// If there is a mismatch between the dimensions the largest valid ones are used. /// If there is a mismatch between the dimensions the largest valid ones are used.
@@ -337,8 +310,8 @@ pub trait MatZnxDftOps<B: Backend> {
/// # Arguments /// # Arguments
/// ///
/// * `b`: the input and output of the vector matrix product, as a [VecZnxDft]. /// * `b`: the input and output of the vector matrix product, as a [VecZnxDft].
/// * `a`: the right operand [VmpPMat] of the vector matrix product. /// * `a`: the right operand [MatZnxDft] of the vector matrix product.
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes]. /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft<B>, a: &MatZnxDft<B>, buf: &mut [u8]); fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft<B>, a: &MatZnxDft<B>, buf: &mut [u8]);
} }
@@ -404,7 +377,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
unsafe { unsafe {
vmp::vmp_extract_row( vmp::vmp_extract_row(
self.ptr, self.ptr,
b.ptr as *mut vec_znx_big_t, b.as_mut_ptr() as *mut vec_znx_big_t,
a.as_ptr() as *const vmp_pmat_t, a.as_ptr() as *const vmp_pmat_t,
row_i as u64, row_i as u64,
a.rows() as u64, a.rows() as u64,
@@ -423,7 +396,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
vmp::vmp_prepare_row_dft( vmp::vmp_prepare_row_dft(
self.ptr, self.ptr,
b.as_mut_ptr() as *mut vmp_pmat_t, b.as_mut_ptr() as *mut vmp_pmat_t,
a.ptr as *const vec_znx_dft_t, a.as_ptr() as *const vec_znx_dft_t,
row_i as u64, row_i as u64,
b.rows() as u64, b.rows() as u64,
b.size() as u64, b.size() as u64,
@@ -440,7 +413,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
unsafe { unsafe {
vmp::vmp_extract_row_dft( vmp::vmp_extract_row_dft(
self.ptr, self.ptr,
b.ptr as *mut vec_znx_dft_t, b.as_mut_ptr() as *mut vec_znx_dft_t,
a.as_ptr() as *const vmp_pmat_t, a.as_ptr() as *const vmp_pmat_t,
row_i as u64, row_i as u64,
a.rows() as u64, a.rows() as u64,
@@ -470,7 +443,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
unsafe { unsafe {
vmp::vmp_apply_dft( vmp::vmp_apply_dft(
self.ptr, self.ptr,
c.ptr as *mut vec_znx_dft_t, c.as_mut_ptr() as *mut vec_znx_dft_t,
c.size() as u64, c.size() as u64,
a.as_ptr(), a.as_ptr(),
a.size() as u64, a.size() as u64,
@@ -492,7 +465,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
unsafe { unsafe {
vmp::vmp_apply_dft_add( vmp::vmp_apply_dft_add(
self.ptr, self.ptr,
c.ptr as *mut vec_znx_dft_t, c.as_mut_ptr() as *mut vec_znx_dft_t,
c.size() as u64, c.size() as u64,
a.as_ptr(), a.as_ptr(),
a.size() as u64, a.size() as u64,
@@ -526,9 +499,9 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
unsafe { unsafe {
vmp::vmp_apply_dft_to_dft( vmp::vmp_apply_dft_to_dft(
self.ptr, self.ptr,
c.ptr as *mut vec_znx_dft_t, c.as_mut_ptr() as *mut vec_znx_dft_t,
c.size() as u64, c.size() as u64,
a.ptr as *const vec_znx_dft_t, a.as_ptr() as *const vec_znx_dft_t,
a.size() as u64, a.size() as u64,
b.as_ptr() as *const vmp_pmat_t, b.as_ptr() as *const vmp_pmat_t,
b.rows() as u64, b.rows() as u64,
@@ -553,9 +526,9 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
unsafe { unsafe {
vmp::vmp_apply_dft_to_dft_add( vmp::vmp_apply_dft_to_dft_add(
self.ptr, self.ptr,
c.ptr as *mut vec_znx_dft_t, c.as_mut_ptr() as *mut vec_znx_dft_t,
c.size() as u64, c.size() as u64,
a.ptr as *const vec_znx_dft_t, a.as_ptr() as *const vec_znx_dft_t,
a.size() as u64, a.size() as u64,
b.as_ptr() as *const vmp_pmat_t, b.as_ptr() as *const vmp_pmat_t,
b.rows() as u64, b.rows() as u64,
@@ -574,9 +547,9 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
unsafe { unsafe {
vmp::vmp_apply_dft_to_dft( vmp::vmp_apply_dft_to_dft(
self.ptr, self.ptr,
b.ptr as *mut vec_znx_dft_t, b.as_mut_ptr() as *mut vec_znx_dft_t,
b.size() as u64, b.size() as u64,
b.ptr as *mut vec_znx_dft_t, b.as_ptr() as *mut vec_znx_dft_t,
b.size() as u64, b.size() as u64,
a.as_ptr() as *const vmp_pmat_t, a.as_ptr() as *const vmp_pmat_t,
a.rows() as u64, a.rows() as u64,
@@ -591,7 +564,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
mod tests { mod tests {
use crate::{ use crate::{
FFT64, MatZnxDft, MatZnxDftOps, Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, FFT64, MatZnxDft, MatZnxDftOps, Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps,
ZnxLayout, alloc_aligned, alloc_aligned, znx_base::ZnxLayout,
}; };
use sampling::source::Source; use sampling::source::Source;
@@ -614,7 +587,7 @@ mod tests {
for row_i in 0..vpmat_rows { for row_i in 0..vpmat_rows {
let mut source: Source = Source::new([0u8; 32]); let mut source: Source = Source::new([0u8; 32]);
module.fill_uniform(log_base2k, &mut a, 0, vpmat_size, &mut source); module.fill_uniform(log_base2k, &mut a, 0, vpmat_size, &mut source);
module.vec_znx_dft(&mut a_dft, &a); module.vec_znx_dft(&mut a_dft, 0, &a, 0);
module.vmp_prepare_row(&mut vmpmat_0, &a.raw(), row_i, &mut tmp_bytes); module.vmp_prepare_row(&mut vmpmat_0, &a.raw(), row_i, &mut tmp_bytes);
// Checks that prepare(mat_znx_dft, a) = prepare_dft(mat_znx_dft, a_dft) // Checks that prepare(mat_znx_dft, a) = prepare_dft(mat_znx_dft, a_dft)
@@ -627,7 +600,7 @@ mod tests {
// Checks that a_big = extract(prepare_dft(mat_znx_dft, a_dft), b_big) // Checks that a_big = extract(prepare_dft(mat_znx_dft, a_dft), b_big)
module.vmp_extract_row(&mut b_big, &vmpmat_0, row_i); module.vmp_extract_row(&mut b_big, &vmpmat_0, row_i);
module.vec_znx_idft(&mut a_big, &a_dft, &mut tmp_bytes); module.vec_znx_idft(&mut a_big, 0, &a_dft, 0, &mut tmp_bytes);
assert_eq!(a_big.raw(), b_big.raw()); assert_eq!(a_big.raw(), b_big.raw());
} }

View File

@@ -1,4 +1,4 @@
use crate::{Backend, Module, VecZnx, ZnxLayout}; use crate::{Backend, Module, VecZnx, znx_base::ZnxLayout};
use rand_distr::{Distribution, Normal}; use rand_distr::{Distribution, Normal};
use sampling::source::Source; use sampling::source::Source;
@@ -106,7 +106,7 @@ impl<B: Backend> Sampling for Module<B> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::Sampling; use super::Sampling;
use crate::{FFT64, Module, Stats, VecZnx, ZnxBase, ZnxLayout}; use crate::{FFT64, Module, Stats, VecZnx, VecZnxOps, znx_base::ZnxLayout};
use sampling::source::Source; use sampling::source::Source;
#[test] #[test]
@@ -120,7 +120,7 @@ mod tests {
let zero: Vec<i64> = vec![0; n]; let zero: Vec<i64> = vec![0; n];
let one_12_sqrt: f64 = 0.28867513459481287; let one_12_sqrt: f64 = 0.28867513459481287;
(0..cols).for_each(|col_i| { (0..cols).for_each(|col_i| {
let mut a: VecZnx = VecZnx::new(&module, cols, size); let mut a: VecZnx = module.new_vec_znx(cols, size);
module.fill_uniform(log_base2k, &mut a, col_i, size, &mut source); module.fill_uniform(log_base2k, &mut a, col_i, size, &mut source);
(0..cols).for_each(|col_j| { (0..cols).for_each(|col_j| {
if col_j != col_i { if col_j != col_i {
@@ -154,7 +154,7 @@ mod tests {
let zero: Vec<i64> = vec![0; n]; let zero: Vec<i64> = vec![0; n];
let k_f64: f64 = (1u64 << log_k as u64) as f64; let k_f64: f64 = (1u64 << log_k as u64) as f64;
(0..cols).for_each(|col_i| { (0..cols).for_each(|col_i| {
let mut a: VecZnx = VecZnx::new(&module, cols, size); let mut a: VecZnx = module.new_vec_znx(cols, size);
module.add_normal(log_base2k, &mut a, col_i, log_k, &mut source, sigma, bound); module.add_normal(log_base2k, &mut a, col_i, log_k, &mut source, sigma, bound);
(0..cols).for_each(|col_j| { (0..cols).for_each(|col_j| {
if col_j != col_i { if col_j != col_i {

View File

@@ -2,9 +2,8 @@ use std::marker::PhantomData;
use crate::ffi::svp::{self, svp_ppol_t}; use crate::ffi::svp::{self, svp_ppol_t};
use crate::ffi::vec_znx_dft::vec_znx_dft_t; use crate::ffi::vec_znx_dft::vec_znx_dft_t;
use crate::{Backend, FFT64, Module, VecZnx, VecZnxDft, ZnxLayout, assert_alignement}; use crate::znx_base::{ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize};
use crate::{Backend, FFT64, Module, VecZnx, VecZnxDft, alloc_aligned, assert_alignement, cast_mut};
use crate::{ZnxInfos, alloc_aligned, cast_mut};
use rand::seq::SliceRandom; use rand::seq::SliceRandom;
use rand_core::RngCore; use rand_core::RngCore;
use rand_distr::{Distribution, weighted::WeightedIndex}; use rand_distr::{Distribution, weighted::WeightedIndex};
@@ -118,11 +117,14 @@ impl Scalar {
pub fn as_vec_znx(&self) -> VecZnx { pub fn as_vec_znx(&self) -> VecZnx {
VecZnx { VecZnx {
n: self.n, inner: ZnxBase {
cols: 1, n: self.n,
size: 1, rows: 1,
data: Vec::new(), cols: 1,
ptr: self.ptr, size: 1,
data: Vec::new(),
ptr: self.ptr as *mut u8,
},
} }
} }
} }
@@ -159,7 +161,7 @@ pub struct ScalarZnxDft<B: Backend> {
/// An [SvpPPol] an be seen as a [VecZnxDft] of one limb. /// An [SvpPPol] an be seen as a [VecZnxDft] of one limb.
impl ScalarZnxDft<FFT64> { impl ScalarZnxDft<FFT64> {
pub fn new(module: &Module<FFT64>) -> Self { pub fn new(module: &Module<FFT64>) -> Self {
module.new_svp_ppol() module.new_scalar_znx_dft()
} }
/// Returns the ring degree of the [SvpPPol]. /// Returns the ring degree of the [SvpPPol].
@@ -168,14 +170,14 @@ impl ScalarZnxDft<FFT64> {
} }
pub fn bytes_of(module: &Module<FFT64>) -> usize { pub fn bytes_of(module: &Module<FFT64>) -> usize {
module.bytes_of_svp_ppol() module.bytes_of_scalar_znx_dft()
} }
pub fn from_bytes(module: &Module<FFT64>, bytes: &mut [u8]) -> Self { pub fn from_bytes(module: &Module<FFT64>, bytes: &mut [u8]) -> Self {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_alignement(bytes.as_ptr()); assert_alignement(bytes.as_ptr());
assert_eq!(bytes.len(), module.bytes_of_svp_ppol()); assert_eq!(bytes.len(), module.bytes_of_scalar_znx_dft());
} }
unsafe { unsafe {
Self { Self {
@@ -191,7 +193,7 @@ impl ScalarZnxDft<FFT64> {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_alignement(tmp_bytes.as_ptr()); assert_alignement(tmp_bytes.as_ptr());
assert_eq!(tmp_bytes.len(), module.bytes_of_svp_ppol()); assert_eq!(tmp_bytes.len(), module.bytes_of_scalar_znx_dft());
} }
Self { Self {
n: module.n(), n: module.n(),
@@ -209,33 +211,33 @@ impl ScalarZnxDft<FFT64> {
pub trait ScalarZnxDftOps<B: Backend> { pub trait ScalarZnxDftOps<B: Backend> {
/// Allocates a new [SvpPPol]. /// Allocates a new [SvpPPol].
fn new_svp_ppol(&self) -> ScalarZnxDft<B>; fn new_scalar_znx_dft(&self) -> ScalarZnxDft<B>;
/// Returns the minimum number of bytes necessary to allocate /// Returns the minimum number of bytes necessary to allocate
/// a new [SvpPPol] through [SvpPPol::from_bytes] ro. /// a new [SvpPPol] through [SvpPPol::from_bytes] ro.
fn bytes_of_svp_ppol(&self) -> usize; fn bytes_of_scalar_znx_dft(&self) -> usize;
/// Allocates a new [SvpPPol] from an array of bytes. /// Allocates a new [SvpPPol] from an array of bytes.
/// The array of bytes is owned by the [SvpPPol]. /// The array of bytes is owned by the [SvpPPol].
/// The method will panic if bytes.len() < [SvpPPolOps::bytes_of_svp_ppol] /// The method will panic if bytes.len() < [SvpPPolOps::bytes_of_svp_ppol]
fn new_svp_ppol_from_bytes(&self, bytes: &mut [u8]) -> ScalarZnxDft<B>; fn new_scalar_znx_dft_from_bytes(&self, bytes: &mut [u8]) -> ScalarZnxDft<B>;
/// Allocates a new [SvpPPol] from an array of bytes. /// Allocates a new [SvpPPol] from an array of bytes.
/// The array of bytes is borrowed by the [SvpPPol]. /// The array of bytes is borrowed by the [SvpPPol].
/// The method will panic if bytes.len() < [SvpPPolOps::bytes_of_svp_ppol] /// The method will panic if bytes.len() < [SvpPPolOps::bytes_of_svp_ppol]
fn new_svp_ppol_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> ScalarZnxDft<B>; fn new_scalar_znx_dft_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> ScalarZnxDft<B>;
/// Prepares a [crate::Scalar] for a [SvpPPolOps::svp_apply_dft]. /// Prepares a [crate::Scalar] for a [SvpPPolOps::svp_apply_dft].
fn svp_prepare(&self, svp_ppol: &mut ScalarZnxDft<B>, a: &Scalar); fn svp_prepare(&self, svp_ppol: &mut ScalarZnxDft<B>, a: &Scalar);
/// Applies the [SvpPPol] x [VecZnxDft] product, where each limb of /// Applies the [SvpPPol] x [VecZnxDft] product, where each limb of
/// the [VecZnxDft] is multiplied with [SvpPPol]. /// the [VecZnxDft] is multiplied with [SvpPPol].
fn svp_apply_dft(&self, c: &mut VecZnxDft<B>, a: &ScalarZnxDft<B>, b: &VecZnx, b_col: usize); fn svp_apply_dft(&self, res: &mut VecZnxDft<B>, res_col: usize, a: &ScalarZnxDft<B>, b: &VecZnx, b_col: usize);
} }
impl ScalarZnxDftOps<FFT64> for Module<FFT64> { impl ScalarZnxDftOps<FFT64> for Module<FFT64> {
fn new_svp_ppol(&self) -> ScalarZnxDft<FFT64> { fn new_scalar_znx_dft(&self) -> ScalarZnxDft<FFT64> {
let mut data: Vec<u8> = alloc_aligned::<u8>(self.bytes_of_svp_ppol()); let mut data: Vec<u8> = alloc_aligned::<u8>(self.bytes_of_scalar_znx_dft());
let ptr: *mut u8 = data.as_mut_ptr(); let ptr: *mut u8 = data.as_mut_ptr();
ScalarZnxDft::<FFT64> { ScalarZnxDft::<FFT64> {
data: data, data: data,
@@ -245,28 +247,28 @@ impl ScalarZnxDftOps<FFT64> for Module<FFT64> {
} }
} }
fn bytes_of_svp_ppol(&self) -> usize { fn bytes_of_scalar_znx_dft(&self) -> usize {
unsafe { svp::bytes_of_svp_ppol(self.ptr) as usize } unsafe { svp::bytes_of_svp_ppol(self.ptr) as usize }
} }
fn new_svp_ppol_from_bytes(&self, bytes: &mut [u8]) -> ScalarZnxDft<FFT64> { fn new_scalar_znx_dft_from_bytes(&self, bytes: &mut [u8]) -> ScalarZnxDft<FFT64> {
ScalarZnxDft::from_bytes(self, bytes) ScalarZnxDft::from_bytes(self, bytes)
} }
fn new_svp_ppol_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> ScalarZnxDft<FFT64> { fn new_scalar_znx_dft_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> ScalarZnxDft<FFT64> {
ScalarZnxDft::from_bytes_borrow(self, tmp_bytes) ScalarZnxDft::from_bytes_borrow(self, tmp_bytes)
} }
fn svp_prepare(&self, svp_ppol: &mut ScalarZnxDft<FFT64>, a: &Scalar) { fn svp_prepare(&self, res: &mut ScalarZnxDft<FFT64>, a: &Scalar) {
unsafe { svp::svp_prepare(self.ptr, svp_ppol.ptr as *mut svp_ppol_t, a.as_ptr()) } unsafe { svp::svp_prepare(self.ptr, res.ptr as *mut svp_ppol_t, a.as_ptr()) }
} }
fn svp_apply_dft(&self, c: &mut VecZnxDft<FFT64>, a: &ScalarZnxDft<FFT64>, b: &VecZnx, b_col: usize) { fn svp_apply_dft(&self, res: &mut VecZnxDft<FFT64>, res_col: usize, a: &ScalarZnxDft<FFT64>, b: &VecZnx, b_col: usize) {
unsafe { unsafe {
svp::svp_apply_dft( svp::svp_apply_dft(
self.ptr, self.ptr,
c.ptr as *mut vec_znx_dft_t, res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
c.size() as u64, res.size() as u64,
a.ptr as *const svp_ppol_t, a.ptr as *const svp_ppol_t,
b.at_ptr(b_col, 0), b.at_ptr(b_col, 0),
b.size() as u64, b.size() as u64,

View File

@@ -1,4 +1,5 @@
use crate::{Encoding, VecZnx, ZnxInfos}; use crate::znx_base::ZnxInfos;
use crate::{Encoding, VecZnx};
use rug::Float; use rug::Float;
use rug::float::Round; use rug::float::Round;
use rug::ops::{AddAssignRound, DivAssignRound, SubAssignRound}; use rug::ops::{AddAssignRound, DivAssignRound, SubAssignRound};

View File

@@ -1,12 +1,13 @@
use crate::Backend; use crate::Backend;
use crate::ZnxBase; use crate::Module;
use crate::assert_alignement;
use crate::cast_mut; use crate::cast_mut;
use crate::ffi::znx; use crate::ffi::znx;
use crate::switch_degree; use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, ZnxSliceSize, switch_degree};
use crate::{Module, ZnxBasics, ZnxInfos, ZnxLayout};
use crate::{alloc_aligned, assert_alignement};
use std::cmp::min; use std::cmp::min;
pub const VEC_ZNX_ROWS: usize = 1;
/// [VecZnx] represents collection of contiguously stacked vector of small norm polynomials of /// [VecZnx] represents collection of contiguously stacked vector of small norm polynomials of
/// Zn\[X\] with [i64] coefficients. /// Zn\[X\] with [i64] coefficients.
/// A [VecZnx] is composed of multiple Zn\[X\] polynomials stored in a single contiguous array /// A [VecZnx] is composed of multiple Zn\[X\] polynomials stored in a single contiguous array
@@ -17,56 +18,54 @@ use std::cmp::min;
/// Given 3 polynomials (a, b, c) of Zn\[X\], each with 4 columns, then the memory /// Given 3 polynomials (a, b, c) of Zn\[X\], each with 4 columns, then the memory
/// layout is: `[a0, b0, c0, a1, b1, c1, a2, b2, c2, a3, b3, c3]`, where ai, bi, ci /// layout is: `[a0, b0, c0, a1, b1, c1, a2, b2, c2, a3, b3, c3]`, where ai, bi, ci
/// are small polynomials of Zn\[X\]. /// are small polynomials of Zn\[X\].
#[derive(Clone)]
pub struct VecZnx { pub struct VecZnx {
/// Polynomial degree. pub inner: ZnxBase,
pub n: usize,
/// The number of polynomials
pub cols: usize,
/// The number of size per polynomial (a.k.a small polynomials).
pub size: usize,
/// Polynomial coefficients, as a contiguous array. Each col is equally spaced by n.
pub data: Vec<i64>,
/// Pointer to data (data can be enpty if [VecZnx] borrows space instead of owning it).
pub ptr: *mut i64,
} }
impl ZnxInfos for VecZnx { impl GetZnxBase for VecZnx {
fn n(&self) -> usize { fn znx(&self) -> &ZnxBase {
self.n &self.inner
} }
fn rows(&self) -> usize { fn znx_mut(&mut self) -> &mut ZnxBase {
1 &mut self.inner
} }
}
fn cols(&self) -> usize { impl ZnxInfos for VecZnx {}
self.cols
}
fn size(&self) -> usize { impl ZnxSliceSize for VecZnx {
self.size fn sl(&self) -> usize {
self.cols() * self.n()
} }
} }
impl ZnxLayout for VecZnx { impl ZnxLayout for VecZnx {
type Scalar = i64; type Scalar = i64;
fn as_ptr(&self) -> *const Self::Scalar {
self.ptr
}
fn as_mut_ptr(&mut self) -> *mut Self::Scalar {
self.ptr
}
} }
impl ZnxBasics for VecZnx {} impl ZnxBasics for VecZnx {}
impl<B: Backend> ZnxAlloc<B> for VecZnx {
type Scalar = i64;
fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnx {
debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, size));
VecZnx {
inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_ROWS, cols, size, bytes),
}
}
fn bytes_of(module: &Module<B>, _rows: usize, cols: usize, size: usize) -> usize {
debug_assert_eq!(
_rows, VEC_ZNX_ROWS,
"rows != {} not supported for VecZnx",
VEC_ZNX_ROWS
);
module.n() * cols * size * size_of::<Self::Scalar>()
}
}
/// Copies the coefficients of `a` on the receiver. /// Copies the coefficients of `a` on the receiver.
/// Copy is done with the minimum size matching both backing arrays. /// Copy is done with the minimum size matching both backing arrays.
/// Panics if the cols do not match. /// Panics if the cols do not match.
@@ -78,80 +77,6 @@ pub fn copy_vec_znx_from(b: &mut VecZnx, a: &VecZnx) {
data_b[..size].copy_from_slice(&data_a[..size]) data_b[..size].copy_from_slice(&data_a[..size])
} }
impl<B: Backend> ZnxBase<B> for VecZnx {
type Scalar = i64;
/// Allocates a new [VecZnx] composed of #size polynomials of Z\[X\].
fn new(module: &Module<B>, cols: usize, size: usize) -> Self {
let n: usize = module.n();
#[cfg(debug_assertions)]
{
assert!(n > 0);
assert!(n & (n - 1) == 0);
assert!(cols > 0);
assert!(size > 0);
}
let mut data: Vec<i64> = alloc_aligned::<i64>(Self::bytes_of(module, cols, size));
let ptr: *mut i64 = data.as_mut_ptr();
Self {
n: n,
cols: cols,
size: size,
data: data,
ptr: ptr,
}
}
fn bytes_of(module: &Module<B>, cols: usize, size: usize) -> usize {
module.n() * cols * size * size_of::<i64>()
}
/// Returns a new struct implementing [VecZnx] with the provided data as backing array.
///
/// The struct will take ownership of buf[..[Self::bytes_of]]
///
/// User must ensure that data is properly alligned and that
/// the size of data is equal to [Self::bytes_of].
fn from_bytes(module: &Module<B>, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
let n: usize = module.n();
#[cfg(debug_assertions)]
{
assert!(cols > 0);
assert!(size > 0);
assert_eq!(bytes.len(), Self::bytes_of(module, cols, size));
assert_alignement(bytes.as_ptr());
}
unsafe {
let bytes_i64: &mut [i64] = cast_mut::<u8, i64>(bytes);
let ptr: *mut i64 = bytes_i64.as_mut_ptr();
Self {
n: n,
cols: cols,
size: size,
data: Vec::from_raw_parts(ptr, bytes.len(), bytes.len()),
ptr: ptr,
}
}
}
fn from_bytes_borrow(module: &Module<B>, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
#[cfg(debug_assertions)]
{
assert!(cols > 0);
assert!(size > 0);
assert!(bytes.len() >= Self::bytes_of(module, cols, size));
assert_alignement(bytes.as_ptr());
}
Self {
n: module.n(),
cols: cols,
size: size,
data: Vec::new(),
ptr: bytes.as_mut_ptr() as *mut i64,
}
}
}
impl VecZnx { impl VecZnx {
/// Truncates the precision of the [VecZnx] by k bits. /// Truncates the precision of the [VecZnx] by k bits.
/// ///
@@ -165,11 +90,12 @@ impl VecZnx {
} }
if !self.borrowing() { if !self.borrowing() {
self.data self.inner
.data
.truncate(self.n() * self.cols() * (self.size() - k / log_base2k)); .truncate(self.n() * self.cols() * (self.size() - k / log_base2k));
} }
self.size -= k / log_base2k; self.inner.size -= k / log_base2k;
let k_rem: usize = k % log_base2k; let k_rem: usize = k % log_base2k;
@@ -185,10 +111,6 @@ impl VecZnx {
copy_vec_znx_from(self, a); copy_vec_znx_from(self, a);
} }
pub fn borrowing(&self) -> bool {
self.data.len() == 0
}
pub fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]) { pub fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]) {
normalize(log_base2k, self, carry) normalize(log_base2k, self, carry)
} }

View File

@@ -1,115 +1,71 @@
use crate::ffi::vec_znx_big; use crate::ffi::vec_znx_big;
use crate::{Backend, FFT64, Module, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, alloc_aligned, assert_alignement}; use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, ZnxSliceSize};
use crate::{Backend, FFT64, Module, NTT120};
use std::marker::PhantomData; use std::marker::PhantomData;
const VEC_ZNX_BIG_ROWS: usize = 1;
pub struct VecZnxBig<B: Backend> { pub struct VecZnxBig<B: Backend> {
pub data: Vec<u8>, pub inner: ZnxBase,
pub ptr: *mut u8,
pub n: usize,
pub cols: usize,
pub size: usize,
pub _marker: PhantomData<B>, pub _marker: PhantomData<B>,
} }
impl ZnxBasics for VecZnxBig<FFT64> {} impl<B: Backend> GetZnxBase for VecZnxBig<B> {
fn znx(&self) -> &ZnxBase {
impl<B: Backend> ZnxBase<B> for VecZnxBig<B> { &self.inner
type Scalar = u8;
fn new(module: &Module<B>, cols: usize, size: usize) -> Self {
#[cfg(debug_assertions)]
{
assert!(cols > 0);
assert!(size > 0);
}
let mut data: Vec<Self::Scalar> = alloc_aligned(Self::bytes_of(module, cols, size));
let ptr: *mut Self::Scalar = data.as_mut_ptr();
Self {
data: data,
ptr: ptr,
n: module.n(),
cols: cols,
size: size,
_marker: PhantomData,
}
} }
fn bytes_of(module: &Module<B>, cols: usize, size: usize) -> usize { fn znx_mut(&mut self) -> &mut ZnxBase {
unsafe { vec_znx_big::bytes_of_vec_znx_big(module.ptr, size as u64) as usize * cols } &mut self.inner
}
/// Returns a new [VecZnxBig] with the provided data as backing array.
/// User must ensure that data is properly alligned and that
/// the size of data is at least equal to [Module::bytes_of_vec_znx_big].
fn from_bytes(module: &Module<B>, cols: usize, size: usize, bytes: &mut [Self::Scalar]) -> Self {
#[cfg(debug_assertions)]
{
assert!(cols > 0);
assert!(size > 0);
assert_eq!(bytes.len(), Self::bytes_of(module, cols, size));
assert_alignement(bytes.as_ptr())
};
unsafe {
Self {
data: Vec::from_raw_parts(bytes.as_mut_ptr(), bytes.len(), bytes.len()),
ptr: bytes.as_mut_ptr(),
n: module.n(),
cols: cols,
size: size,
_marker: PhantomData,
}
}
}
fn from_bytes_borrow(module: &Module<B>, cols: usize, size: usize, bytes: &mut [Self::Scalar]) -> Self {
#[cfg(debug_assertions)]
{
assert!(cols > 0);
assert!(size > 0);
assert_eq!(bytes.len(), Self::bytes_of(module, cols, size));
assert_alignement(bytes.as_ptr());
}
Self {
data: Vec::new(),
ptr: bytes.as_mut_ptr(),
n: module.n(),
cols: cols,
size: size,
_marker: PhantomData,
}
} }
} }
impl<B: Backend> ZnxInfos for VecZnxBig<B> { impl<B: Backend> ZnxInfos for VecZnxBig<B> {}
fn n(&self) -> usize {
self.n impl<B: Backend> ZnxAlloc<B> for VecZnxBig<B> {
type Scalar = u8;
fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
VecZnxBig {
inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_BIG_ROWS, cols, size, bytes),
_marker: PhantomData,
}
} }
fn cols(&self) -> usize { fn bytes_of(module: &Module<B>, _rows: usize, cols: usize, size: usize) -> usize {
self.cols debug_assert_eq!(
} _rows, VEC_ZNX_BIG_ROWS,
"rows != {} not supported for VecZnxBig",
fn rows(&self) -> usize { VEC_ZNX_BIG_ROWS
1 );
} unsafe { vec_znx_big::bytes_of_vec_znx_big(module.ptr, size as u64) as usize * cols }
fn size(&self) -> usize {
self.size
} }
} }
impl ZnxLayout for VecZnxBig<FFT64> { impl ZnxLayout for VecZnxBig<FFT64> {
type Scalar = i64; type Scalar = i64;
}
fn as_ptr(&self) -> *const Self::Scalar { impl ZnxLayout for VecZnxBig<NTT120> {
self.ptr as *const Self::Scalar type Scalar = i128;
} }
fn as_mut_ptr(&mut self) -> *mut Self::Scalar { impl ZnxBasics for VecZnxBig<FFT64> {}
self.ptr as *mut Self::Scalar
impl ZnxSliceSize for VecZnxBig<FFT64> {
fn sl(&self) -> usize {
self.n()
} }
} }
impl ZnxSliceSize for VecZnxBig<NTT120> {
fn sl(&self) -> usize {
self.n() * 4
}
}
impl ZnxBasics for VecZnxBig<NTT120> {}
impl VecZnxBig<FFT64> { impl VecZnxBig<FFT64> {
pub fn print(&self, n: usize) { pub fn print(&self, n: usize) {
(0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])); (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n]));

View File

@@ -1,5 +1,6 @@
use crate::ffi::vec_znx; use crate::ffi::vec_znx_big::{self, vec_znx_big_t};
use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxOps, ZnxBase, ZnxInfos, ZnxLayout, assert_alignement}; use crate::znx_base::{ZnxAlloc, ZnxInfos, ZnxLayout, ZnxSliceSize};
use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxOps, assert_alignement};
pub trait VecZnxBigOps<B: Backend> { pub trait VecZnxBigOps<B: Backend> {
/// Allocates a vector Z[X]/(X^N+1) that stores not normalized values. /// Allocates a vector Z[X]/(X^N+1) that stores not normalized values.
@@ -17,7 +18,7 @@ pub trait VecZnxBigOps<B: Backend> {
/// ///
/// # Panics /// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_big]. /// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxBig<B>; fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBig<B>;
/// Returns a new [VecZnxBig] with the provided bytes array as backing array. /// Returns a new [VecZnxBig] with the provided bytes array as backing array.
/// ///
@@ -41,74 +42,74 @@ pub trait VecZnxBigOps<B: Backend> {
fn vec_znx_big_add( fn vec_znx_big_add(
&self, &self,
res: &mut VecZnxBig<B>, res: &mut VecZnxBig<B>,
col_res: usize, res_col: usize,
a: &VecZnxBig<B>, a: &VecZnxBig<B>,
col_a: usize, a_col: usize,
b: &VecZnxBig<B>, b: &VecZnxBig<B>,
col_b: usize, b_col: usize,
); );
/// Adds `a` to `b` and stores the result on `b`. /// Adds `a` to `b` and stores the result on `b`.
fn vec_znx_big_add_inplace(&self, res: &mut VecZnxBig<B>, col_res: usize, a: &VecZnxBig<B>, col_a: usize); fn vec_znx_big_add_inplace(&self, res: &mut VecZnxBig<B>, res_col: usize, a: &VecZnxBig<B>, a_col: usize);
/// Adds `a` to `b` and stores the result on `c`. /// Adds `a` to `b` and stores the result on `c`.
fn vec_znx_big_add_small( fn vec_znx_big_add_small(
&self, &self,
res: &mut VecZnxBig<B>, res: &mut VecZnxBig<B>,
col_res: usize, res_col: usize,
a: &VecZnx, a: &VecZnxBig<B>,
col_a: usize, a_col: usize,
b: &VecZnxBig<B>, b: &VecZnx,
col_b: usize, b_col: usize,
); );
/// Adds `a` to `b` and stores the result on `b`. /// Adds `a` to `b` and stores the result on `b`.
fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig<B>, col_res: usize, a: &VecZnx, col_a: usize); fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig<B>, res_col: usize, a: &VecZnx, a_col: usize);
/// Subtracts `a` to `b` and stores the result on `c`. /// Subtracts `a` to `b` and stores the result on `c`.
fn vec_znx_big_sub( fn vec_znx_big_sub(
&self, &self,
res: &mut VecZnxBig<B>, res: &mut VecZnxBig<B>,
col_res: usize, res_col: usize,
a: &VecZnxBig<B>, a: &VecZnxBig<B>,
col_a: usize, a_col: usize,
b: &VecZnxBig<B>, b: &VecZnxBig<B>,
col_b: usize, b_col: usize,
); );
/// Subtracts `a` to `b` and stores the result on `b`. /// Subtracts `a` to `b` and stores the result on `b`.
fn vec_znx_big_sub_ab_inplace(&self, res: &mut VecZnxBig<B>, col_res: usize, a: &VecZnxBig<B>, col_a: usize); fn vec_znx_big_sub_ab_inplace(&self, res: &mut VecZnxBig<B>, res_col: usize, a: &VecZnxBig<B>, a_col: usize);
/// Subtracts `b` to `a` and stores the result on `b`. /// Subtracts `b` to `a` and stores the result on `b`.
fn vec_znx_big_sub_ba_inplace(&self, res: &mut VecZnxBig<B>, col_res: usize, a: &VecZnxBig<B>, col_a: usize); fn vec_znx_big_sub_ba_inplace(&self, res: &mut VecZnxBig<B>, res_col: usize, a: &VecZnxBig<B>, a_col: usize);
/// Subtracts `b` to `a` and stores the result on `c`. /// Subtracts `b` to `a` and stores the result on `c`.
fn vec_znx_big_sub_small_a( fn vec_znx_big_sub_small_a(
&self, &self,
res: &mut VecZnxBig<B>, res: &mut VecZnxBig<B>,
col_res: usize, res_col: usize,
a: &VecZnx, a: &VecZnx,
col_a: usize, a_col: usize,
b: &VecZnxBig<B>, b: &VecZnxBig<B>,
col_b: usize, b_col: usize,
); );
/// Subtracts `a` to `b` and stores the result on `b`. /// Subtracts `a` to `b` and stores the result on `b`.
fn vec_znx_big_sub_small_a_inplace(&self, res: &mut VecZnxBig<B>, col_res: usize, a: &VecZnx, col_a: usize); fn vec_znx_big_sub_small_a_inplace(&self, res: &mut VecZnxBig<B>, res_col: usize, a: &VecZnx, a_col: usize);
/// Subtracts `b` to `a` and stores the result on `c`. /// Subtracts `b` to `a` and stores the result on `c`.
fn vec_znx_big_sub_small_b( fn vec_znx_big_sub_small_b(
&self, &self,
res: &mut VecZnxBig<B>, res: &mut VecZnxBig<B>,
col_res: usize, res_col: usize,
a: &VecZnxBig<B>, a: &VecZnxBig<B>,
col_a: usize, a_col: usize,
b: &VecZnx, b: &VecZnx,
col_b: usize, b_col: usize,
); );
/// Subtracts `b` to `a` and stores the result on `b`. /// Subtracts `b` to `a` and stores the result on `b`.
fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig<B>, col_res: usize, a: &VecZnx, col_a: usize); fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig<B>, res_col: usize, a: &VecZnx, a_col: usize);
/// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_normalize]. /// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_normalize].
fn vec_znx_big_normalize_tmp_bytes(&self) -> usize; fn vec_znx_big_normalize_tmp_bytes(&self) -> usize;
@@ -123,44 +124,44 @@ pub trait VecZnxBigOps<B: Backend> {
&self, &self,
log_base2k: usize, log_base2k: usize,
res: &mut VecZnx, res: &mut VecZnx,
col_res: usize, res_col: usize,
a: &VecZnxBig<B>, a: &VecZnxBig<B>,
col_a: usize, a_col: usize,
tmp_bytes: &mut [u8], tmp_bytes: &mut [u8],
); );
/// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`. /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`.
fn vec_znx_big_automorphism(&self, k: i64, res: &mut VecZnxBig<B>, col_res: usize, a: &VecZnxBig<B>, col_a: usize); fn vec_znx_big_automorphism(&self, k: i64, res: &mut VecZnxBig<B>, res_col: usize, a: &VecZnxBig<B>, a_col: usize);
/// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`. /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`.
fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig<B>, col_a: usize); fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig<B>, a_col: usize);
} }
impl VecZnxBigOps<FFT64> for Module<FFT64> { impl VecZnxBigOps<FFT64> for Module<FFT64> {
fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBig<FFT64> { fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBig<FFT64> {
VecZnxBig::new(self, cols, size) VecZnxBig::new(self, 1, cols, size)
} }
fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxBig<FFT64> { fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBig<FFT64> {
VecZnxBig::from_bytes(self, cols, size, bytes) VecZnxBig::from_bytes(self, 1, cols, size, bytes)
} }
fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig<FFT64> { fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig<FFT64> {
VecZnxBig::from_bytes_borrow(self, cols, size, tmp_bytes) VecZnxBig::from_bytes_borrow(self, 1, cols, size, tmp_bytes)
} }
fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize { fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize {
VecZnxBig::bytes_of(self, cols, size) VecZnxBig::bytes_of(self, 1, cols, size)
} }
fn vec_znx_big_add( fn vec_znx_big_add(
&self, &self,
res: &mut VecZnxBig<FFT64>, res: &mut VecZnxBig<FFT64>,
col_res: usize, res_col: usize,
a: &VecZnxBig<FFT64>, a: &VecZnxBig<FFT64>,
col_a: usize, a_col: usize,
b: &VecZnxBig<FFT64>, b: &VecZnxBig<FFT64>,
col_b: usize, b_col: usize,
) { ) {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
@@ -170,36 +171,33 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
assert_ne!(a.as_ptr(), b.as_ptr()); assert_ne!(a.as_ptr(), b.as_ptr());
} }
unsafe { unsafe {
vec_znx::vec_znx_add( vec_znx_big::vec_znx_big_add(
self.ptr, self.ptr,
res.at_mut_ptr(col_res, 0), res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t,
res.size() as u64, res.size() as u64,
res.sl() as u64, a.at_ptr(a_col * res.size(), 0) as *const vec_znx_big_t,
a.at_ptr(col_a, 0),
a.size() as u64, a.size() as u64,
a.sl() as u64, b.at_ptr(b_col * res.size(), 0) as *const vec_znx_big_t,
b.at_ptr(col_b, 0),
b.size() as u64, b.size() as u64,
b.sl() as u64,
) )
} }
} }
fn vec_znx_big_add_inplace(&self, res: &mut VecZnxBig<FFT64>, col_res: usize, a: &VecZnxBig<FFT64>, col_a: usize) { fn vec_znx_big_add_inplace(&self, res: &mut VecZnxBig<FFT64>, res_col: usize, a: &VecZnxBig<FFT64>, a_col: usize) {
unsafe { unsafe {
let res_ptr: *mut VecZnxBig<FFT64> = res as *mut VecZnxBig<FFT64>; let res_ptr: *mut VecZnxBig<FFT64> = res as *mut VecZnxBig<FFT64>;
Self::vec_znx_big_add(self, &mut *res_ptr, col_res, a, col_a, &*res_ptr, col_res); Self::vec_znx_big_add(self, &mut *res_ptr, res_col, a, a_col, &*res_ptr, res_col);
} }
} }
fn vec_znx_big_sub( fn vec_znx_big_sub(
&self, &self,
res: &mut VecZnxBig<FFT64>, res: &mut VecZnxBig<FFT64>,
col_res: usize, res_col: usize,
a: &VecZnxBig<FFT64>, a: &VecZnxBig<FFT64>,
col_a: usize, a_col: usize,
b: &VecZnxBig<FFT64>, b: &VecZnxBig<FFT64>,
col_b: usize, b_col: usize,
) { ) {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
@@ -209,43 +207,40 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
assert_ne!(a.as_ptr(), b.as_ptr()); assert_ne!(a.as_ptr(), b.as_ptr());
} }
unsafe { unsafe {
vec_znx::vec_znx_sub( vec_znx_big::vec_znx_big_sub(
self.ptr, self.ptr,
res.at_mut_ptr(col_res, 0), res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t,
res.size() as u64, res.size() as u64,
res.sl() as u64, a.at_ptr(a_col * res.size(), 0) as *const vec_znx_big_t,
a.at_ptr(col_a, 0),
a.size() as u64, a.size() as u64,
a.sl() as u64, b.at_ptr(b_col * res.size(), 0) as *const vec_znx_big_t,
b.at_ptr(col_b, 0),
b.size() as u64, b.size() as u64,
b.sl() as u64,
) )
} }
} }
fn vec_znx_big_sub_ab_inplace(&self, res: &mut VecZnxBig<FFT64>, col_res: usize, a: &VecZnxBig<FFT64>, col_a: usize) { fn vec_znx_big_sub_ab_inplace(&self, res: &mut VecZnxBig<FFT64>, res_col: usize, a: &VecZnxBig<FFT64>, a_col: usize) {
unsafe { unsafe {
let res_ptr: *mut VecZnxBig<FFT64> = res as *mut VecZnxBig<FFT64>; let res_ptr: *mut VecZnxBig<FFT64> = res as *mut VecZnxBig<FFT64>;
Self::vec_znx_big_sub(self, &mut *res_ptr, col_res, a, col_a, &*res_ptr, col_res); Self::vec_znx_big_sub(self, &mut *res_ptr, res_col, a, a_col, &*res_ptr, res_col);
} }
} }
fn vec_znx_big_sub_ba_inplace(&self, res: &mut VecZnxBig<FFT64>, col_res: usize, a: &VecZnxBig<FFT64>, col_a: usize) { fn vec_znx_big_sub_ba_inplace(&self, res: &mut VecZnxBig<FFT64>, res_col: usize, a: &VecZnxBig<FFT64>, a_col: usize) {
unsafe { unsafe {
let res_ptr: *mut VecZnxBig<FFT64> = res as *mut VecZnxBig<FFT64>; let res_ptr: *mut VecZnxBig<FFT64> = res as *mut VecZnxBig<FFT64>;
Self::vec_znx_big_sub(self, &mut *res_ptr, col_res, &*res_ptr, col_res, a, col_a); Self::vec_znx_big_sub(self, &mut *res_ptr, res_col, &*res_ptr, res_col, a, a_col);
} }
} }
fn vec_znx_big_sub_small_b( fn vec_znx_big_sub_small_b(
&self, &self,
res: &mut VecZnxBig<FFT64>, res: &mut VecZnxBig<FFT64>,
col_res: usize, res_col: usize,
a: &VecZnxBig<FFT64>, a: &VecZnxBig<FFT64>,
col_a: usize, a_col: usize,
b: &VecZnx, b: &VecZnx,
col_b: usize, b_col: usize,
) { ) {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
@@ -255,36 +250,34 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
assert_ne!(a.as_ptr(), b.as_ptr()); assert_ne!(a.as_ptr(), b.as_ptr());
} }
unsafe { unsafe {
vec_znx::vec_znx_sub( vec_znx_big::vec_znx_big_sub_small_b(
self.ptr, self.ptr,
res.at_mut_ptr(col_res, 0), res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t,
res.size() as u64, res.size() as u64,
res.sl() as u64, a.at_ptr(a_col * a.size(), 0) as *const vec_znx_big_t,
a.at_ptr(col_a, 0),
a.size() as u64, a.size() as u64,
a.sl() as u64, b.at_ptr(b_col, 0),
b.at_ptr(col_b, 0),
b.size() as u64, b.size() as u64,
b.sl() as u64, b.sl() as u64,
) )
} }
} }
fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig<FFT64>, col_res: usize, a: &VecZnx, col_a: usize) { fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig<FFT64>, res_col: usize, a: &VecZnx, a_col: usize) {
unsafe { unsafe {
let res_ptr: *mut VecZnxBig<FFT64> = res as *mut VecZnxBig<FFT64>; let res_ptr: *mut VecZnxBig<FFT64> = res as *mut VecZnxBig<FFT64>;
Self::vec_znx_big_sub_small_b(self, &mut *res_ptr, col_res, &*res_ptr, col_res, a, col_a); Self::vec_znx_big_sub_small_b(self, &mut *res_ptr, res_col, &*res_ptr, res_col, a, a_col);
} }
} }
fn vec_znx_big_sub_small_a( fn vec_znx_big_sub_small_a(
&self, &self,
res: &mut VecZnxBig<FFT64>, res: &mut VecZnxBig<FFT64>,
col_res: usize, res_col: usize,
a: &VecZnx, a: &VecZnx,
col_a: usize, a_col: usize,
b: &VecZnxBig<FFT64>, b: &VecZnxBig<FFT64>,
col_b: usize, b_col: usize,
) { ) {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
@@ -294,36 +287,34 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
assert_ne!(a.as_ptr(), b.as_ptr()); assert_ne!(a.as_ptr(), b.as_ptr());
} }
unsafe { unsafe {
vec_znx::vec_znx_sub( vec_znx_big::vec_znx_big_sub_small_a(
self.ptr, self.ptr,
res.at_mut_ptr(col_res, 0), res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t,
res.size() as u64, res.size() as u64,
res.sl() as u64, a.at_ptr(a_col, 0),
a.at_ptr(col_a, 0),
a.size() as u64, a.size() as u64,
a.sl() as u64, a.sl() as u64,
b.at_ptr(col_b, 0), b.at_ptr(b_col * b.size(), 0) as *const vec_znx_big_t,
b.size() as u64, b.size() as u64,
b.sl() as u64,
) )
} }
} }
fn vec_znx_big_sub_small_a_inplace(&self, res: &mut VecZnxBig<FFT64>, col_res: usize, a: &VecZnx, col_a: usize) { fn vec_znx_big_sub_small_a_inplace(&self, res: &mut VecZnxBig<FFT64>, res_col: usize, a: &VecZnx, a_col: usize) {
unsafe { unsafe {
let res_ptr: *mut VecZnxBig<FFT64> = res as *mut VecZnxBig<FFT64>; let res_ptr: *mut VecZnxBig<FFT64> = res as *mut VecZnxBig<FFT64>;
Self::vec_znx_big_sub_small_a(self, &mut *res_ptr, col_res, a, col_a, &*res_ptr, col_res); Self::vec_znx_big_sub_small_a(self, &mut *res_ptr, res_col, a, a_col, &*res_ptr, res_col);
} }
} }
fn vec_znx_big_add_small( fn vec_znx_big_add_small(
&self, &self,
res: &mut VecZnxBig<FFT64>, res: &mut VecZnxBig<FFT64>,
col_res: usize, res_col: usize,
a: &VecZnx, a: &VecZnxBig<FFT64>,
col_a: usize, a_col: usize,
b: &VecZnxBig<FFT64>, b: &VecZnx,
col_b: usize, b_col: usize,
) { ) {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
@@ -333,25 +324,23 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
assert_ne!(a.as_ptr(), b.as_ptr()); assert_ne!(a.as_ptr(), b.as_ptr());
} }
unsafe { unsafe {
vec_znx::vec_znx_add( vec_znx_big::vec_znx_big_add_small(
self.ptr, self.ptr,
res.at_mut_ptr(col_res, 0), res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t,
res.size() as u64, res.size() as u64,
res.sl() as u64, a.at_ptr(a_col * a.size(), 0) as *const vec_znx_big_t,
a.at_ptr(col_a, 0),
a.size() as u64, a.size() as u64,
a.sl() as u64, b.at_ptr(b_col, 0),
b.at_ptr(col_b, 0),
b.size() as u64, b.size() as u64,
b.sl() as u64, b.sl() as u64,
) )
} }
} }
fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig<FFT64>, col_res: usize, a: &VecZnx, a_col: usize) { fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig<FFT64>, res_col: usize, a: &VecZnx, a_col: usize) {
unsafe { unsafe {
let res_ptr: *mut VecZnxBig<FFT64> = res as *mut VecZnxBig<FFT64>; let res_ptr: *mut VecZnxBig<FFT64> = res as *mut VecZnxBig<FFT64>;
Self::vec_znx_big_add_small(self, &mut *res_ptr, col_res, a, a_col, &*res_ptr, col_res); Self::vec_znx_big_add_small(self, &mut *res_ptr, res_col, &*res_ptr, res_col, a, a_col);
} }
} }
@@ -363,9 +352,9 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
&self, &self,
log_base2k: usize, log_base2k: usize,
res: &mut VecZnx, res: &mut VecZnx,
col_res: usize, res_col: usize,
a: &VecZnxBig<FFT64>, a: &VecZnxBig<FFT64>,
col_a: usize, a_col: usize,
tmp_bytes: &mut [u8], tmp_bytes: &mut [u8],
) { ) {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
@@ -376,44 +365,41 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
assert_alignement(tmp_bytes.as_ptr()); assert_alignement(tmp_bytes.as_ptr());
} }
unsafe { unsafe {
vec_znx::vec_znx_normalize_base2k( vec_znx_big::vec_znx_big_normalize_base2k(
self.ptr, self.ptr,
log_base2k as u64, log_base2k as u64,
res.at_mut_ptr(col_res, 0), res.at_mut_ptr(res_col, 0),
res.size() as u64, res.size() as u64,
res.sl() as u64, res.sl() as u64,
a.at_ptr(col_a, 0), a.at_ptr(a_col * a.size(), 0) as *const vec_znx_big_t,
a.size() as u64, a.size() as u64,
a.sl() as u64,
tmp_bytes.as_mut_ptr(), tmp_bytes.as_mut_ptr(),
); );
} }
} }
fn vec_znx_big_automorphism(&self, k: i64, res: &mut VecZnxBig<FFT64>, col_res: usize, a: &VecZnxBig<FFT64>, col_a: usize) { fn vec_znx_big_automorphism(&self, k: i64, res: &mut VecZnxBig<FFT64>, res_col: usize, a: &VecZnxBig<FFT64>, a_col: usize) {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_eq!(a.n(), self.n()); assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n()); assert_eq!(res.n(), self.n());
} }
unsafe { unsafe {
vec_znx::vec_znx_automorphism( vec_znx_big::vec_znx_big_automorphism(
self.ptr, self.ptr,
k, k,
res.at_mut_ptr(col_res, 0), res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t,
res.size() as u64, res.size() as u64,
res.sl() as u64, a.at_ptr(a_col * a.size(), 0) as *const vec_znx_big_t,
a.at_ptr(col_a, 0),
a.size() as u64, a.size() as u64,
a.sl() as u64,
) )
} }
} }
fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig<FFT64>, col_a: usize) { fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig<FFT64>, a_col: usize) {
unsafe { unsafe {
let a_ptr: *mut VecZnxBig<FFT64> = a as *mut VecZnxBig<FFT64>; let a_ptr: *mut VecZnxBig<FFT64> = a as *mut VecZnxBig<FFT64>;
Self::vec_znx_big_automorphism(self, k, &mut *a_ptr, col_a, &*a_ptr, col_a); Self::vec_znx_big_automorphism(self, k, &mut *a_ptr, a_col, &*a_ptr, a_col);
} }
} }
} }

View File

@@ -1,129 +1,54 @@
use crate::ffi::vec_znx_big::vec_znx_big_t;
use crate::ffi::vec_znx_dft; use crate::ffi::vec_znx_dft;
use crate::ffi::vec_znx_dft::{bytes_of_vec_znx_dft, vec_znx_dft_t}; use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize};
use crate::{Backend, FFT64, Module, VecZnxBig, ZnxBase, ZnxInfos, ZnxLayout, assert_alignement}; use crate::{Backend, FFT64, Module, VecZnxBig};
use crate::{VecZnx, alloc_aligned};
use std::marker::PhantomData; use std::marker::PhantomData;
const VEC_ZNX_DFT_ROWS: usize = 1;
pub struct VecZnxDft<B: Backend> { pub struct VecZnxDft<B: Backend> {
pub data: Vec<u8>, inner: ZnxBase,
pub ptr: *mut u8,
pub n: usize,
pub cols: usize,
pub size: usize,
pub _marker: PhantomData<B>, pub _marker: PhantomData<B>,
} }
impl<B: Backend> ZnxBase<B> for VecZnxDft<B> { impl<B: Backend> GetZnxBase for VecZnxDft<B> {
fn znx(&self) -> &ZnxBase {
&self.inner
}
fn znx_mut(&mut self) -> &mut ZnxBase {
&mut self.inner
}
}
impl<B: Backend> ZnxInfos for VecZnxDft<B> {}
impl<B: Backend> ZnxAlloc<B> for VecZnxDft<B> {
type Scalar = u8; type Scalar = u8;
fn new(module: &Module<B>, cols: usize, size: usize) -> Self { fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
#[cfg(debug_assertions)] VecZnxDft {
{ inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_DFT_ROWS, cols, size, bytes),
assert!(cols > 0);
assert!(size > 0);
}
let mut data: Vec<Self::Scalar> = alloc_aligned(Self::bytes_of(module, cols, size));
let ptr: *mut Self::Scalar = data.as_mut_ptr();
Self {
data: data,
ptr: ptr,
n: module.n(),
size: size,
cols: cols,
_marker: PhantomData, _marker: PhantomData,
} }
} }
fn bytes_of(module: &Module<B>, cols: usize, size: usize) -> usize { fn bytes_of(module: &Module<B>, _rows: usize, cols: usize, size: usize) -> usize {
unsafe { bytes_of_vec_znx_dft(module.ptr, size as u64) as usize * cols } debug_assert_eq!(
} _rows, VEC_ZNX_DFT_ROWS,
"rows != {} not supported for VecZnxDft",
/// Returns a new [VecZnxDft] with the provided data as backing array. VEC_ZNX_DFT_ROWS
/// User must ensure that data is properly alligned and that );
/// the size of data is at least equal to [Module::bytes_of_vec_znx_dft]. unsafe { vec_znx_dft::bytes_of_vec_znx_dft(module.ptr, size as u64) as usize * cols }
fn from_bytes(module: &Module<B>, cols: usize, size: usize, bytes: &mut [Self::Scalar]) -> Self {
#[cfg(debug_assertions)]
{
assert!(cols > 0);
assert!(size > 0);
assert_eq!(bytes.len(), Self::bytes_of(module, cols, size));
assert_alignement(bytes.as_ptr())
}
unsafe {
Self {
data: Vec::from_raw_parts(bytes.as_mut_ptr(), bytes.len(), bytes.len()),
ptr: bytes.as_mut_ptr(),
n: module.n(),
cols: cols,
size: size,
_marker: PhantomData,
}
}
}
fn from_bytes_borrow(module: &Module<B>, cols: usize, size: usize, bytes: &mut [Self::Scalar]) -> Self {
#[cfg(debug_assertions)]
{
assert!(cols > 0);
assert!(size > 0);
assert_eq!(bytes.len(), Self::bytes_of(module, cols, size));
assert_alignement(bytes.as_ptr());
}
Self {
data: Vec::new(),
ptr: bytes.as_mut_ptr(),
n: module.n(),
cols: cols,
size: size,
_marker: PhantomData,
}
}
}
impl<B: Backend> VecZnxDft<B> {
/// Cast a [VecZnxDft] into a [VecZnxBig].
/// The returned [VecZnxBig] shares the backing array
/// with the original [VecZnxDft].
pub fn as_vec_znx_big(&mut self) -> VecZnxBig<B> {
VecZnxBig::<B> {
data: Vec::new(),
ptr: self.ptr,
n: self.n,
cols: self.cols,
size: self.size,
_marker: PhantomData,
}
}
}
impl<B: Backend> ZnxInfos for VecZnxDft<B> {
fn n(&self) -> usize {
self.n
}
fn rows(&self) -> usize {
1
}
fn cols(&self) -> usize {
self.cols
}
fn size(&self) -> usize {
self.size
} }
} }
impl ZnxLayout for VecZnxDft<FFT64> { impl ZnxLayout for VecZnxDft<FFT64> {
type Scalar = f64; type Scalar = f64;
}
fn as_ptr(&self) -> *const Self::Scalar { impl ZnxSliceSize for VecZnxDft<FFT64> {
self.ptr as *const Self::Scalar fn sl(&self) -> usize {
} self.n()
fn as_mut_ptr(&mut self) -> *mut Self::Scalar {
self.ptr as *mut Self::Scalar
} }
} }
@@ -133,225 +58,21 @@ impl VecZnxDft<FFT64> {
} }
} }
pub trait VecZnxDftOps<B: Backend> { impl<B: Backend> VecZnxDft<B> {
/// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space. /// Cast a [VecZnxDft] into a [VecZnxBig].
fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDft<B>; /// The returned [VecZnxBig] shares the backing array
/// with the original [VecZnxDft].
/// Returns a new [VecZnxDft] with the provided bytes array as backing array. pub fn alias_as_vec_znx_big(&mut self) -> VecZnxBig<B> {
/// VecZnxBig::<B> {
/// Behavior: takes ownership of the backing array. inner: ZnxBase {
/// data: Vec::new(),
/// # Arguments ptr: self.ptr(),
/// n: self.n(),
/// * `cols`: the number of cols of the [VecZnxDft]. rows: self.rows(),
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft]. cols: self.cols(),
/// size: self.size(),
/// # Panics },
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. _marker: PhantomData,
fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft<B>;
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
///
/// Behavior: the backing array is only borrowed.
///
/// # Arguments
///
/// * `cols`: the number of cols of the [VecZnxDft].
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft].
///
/// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft<B>;
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
///
/// # Arguments
///
/// * `cols`: the number of cols of the [VecZnxDft].
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft].
///
/// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize;
/// Returns the minimum number of bytes necessary to allocate
/// a new [VecZnxDft] through [VecZnxDft::from_bytes].
fn vec_znx_idft_tmp_bytes(&self) -> usize;
/// b <- IDFT(a), uses a as scratch space.
fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig<B>, a: &mut VecZnxDft<B>);
fn vec_znx_idft(&self, b: &mut VecZnxBig<B>, a: &VecZnxDft<B>, tmp_bytes: &mut [u8]);
fn vec_znx_dft(&self, b: &mut VecZnxDft<B>, a: &VecZnx);
fn vec_znx_dft_automorphism(&self, k: i64, b: &mut VecZnxDft<B>, a: &VecZnxDft<B>);
fn vec_znx_dft_automorphism_inplace(&self, k: i64, a: &mut VecZnxDft<B>, tmp_bytes: &mut [u8]);
fn vec_znx_dft_automorphism_tmp_bytes(&self) -> usize;
}
impl VecZnxDftOps<FFT64> for Module<FFT64> {
fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDft<FFT64> {
VecZnxDft::<FFT64>::new(&self, cols, size)
}
fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxDft<FFT64> {
VecZnxDft::from_bytes(self, cols, size, tmp_bytes)
}
fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxDft<FFT64> {
VecZnxDft::from_bytes_borrow(self, cols, size, tmp_bytes)
}
fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize {
VecZnxDft::bytes_of(&self, cols, size)
}
fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig<FFT64>, a: &mut VecZnxDft<FFT64>) {
unsafe {
vec_znx_dft::vec_znx_idft_tmp_a(
self.ptr,
b.ptr as *mut vec_znx_big_t,
b.poly_count() as u64,
a.ptr as *mut vec_znx_dft_t,
a.poly_count() as u64,
)
} }
} }
fn vec_znx_idft_tmp_bytes(&self) -> usize {
unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(self.ptr) as usize }
}
/// b <- DFT(a)
///
/// # Panics
/// If b.cols < a_cols
fn vec_znx_dft(&self, b: &mut VecZnxDft<FFT64>, a: &VecZnx) {
unsafe {
vec_znx_dft::vec_znx_dft(
self.ptr,
b.ptr as *mut vec_znx_dft_t,
b.size() as u64,
a.as_ptr(),
a.size() as u64,
(a.n() * a.cols()) as u64,
)
}
}
// b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes].
fn vec_znx_idft(&self, b: &mut VecZnxBig<FFT64>, a: &VecZnxDft<FFT64>, tmp_bytes: &mut [u8]) {
#[cfg(debug_assertions)]
{
assert!(
tmp_bytes.len() >= Self::vec_znx_idft_tmp_bytes(self),
"invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_idft_tmp_bytes()={}",
tmp_bytes.len(),
Self::vec_znx_idft_tmp_bytes(self)
);
assert_alignement(tmp_bytes.as_ptr())
}
unsafe {
vec_znx_dft::vec_znx_idft(
self.ptr,
b.ptr as *mut vec_znx_big_t,
b.poly_count() as u64,
a.ptr as *const vec_znx_dft_t,
a.poly_count() as u64,
tmp_bytes.as_mut_ptr(),
)
}
}
fn vec_znx_dft_automorphism(&self, k: i64, b: &mut VecZnxDft<FFT64>, a: &VecZnxDft<FFT64>) {
unsafe {
vec_znx_dft::vec_znx_dft_automorphism(
self.ptr,
k,
b.ptr as *mut vec_znx_dft_t,
b.poly_count() as u64,
a.ptr as *const vec_znx_dft_t,
a.poly_count() as u64,
[0u8; 0].as_mut_ptr(),
);
}
}
fn vec_znx_dft_automorphism_inplace(&self, k: i64, a: &mut VecZnxDft<FFT64>, tmp_bytes: &mut [u8]) {
#[cfg(debug_assertions)]
{
assert!(
tmp_bytes.len() >= Self::vec_znx_dft_automorphism_tmp_bytes(self),
"invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_dft_automorphism_tmp_bytes()={}",
tmp_bytes.len(),
Self::vec_znx_dft_automorphism_tmp_bytes(self)
);
assert_alignement(tmp_bytes.as_ptr())
}
println!("{}", a.poly_count());
unsafe {
vec_znx_dft::vec_znx_dft_automorphism(
self.ptr,
k,
a.ptr as *mut vec_znx_dft_t,
a.poly_count() as u64,
a.ptr as *const vec_znx_dft_t,
a.poly_count() as u64,
tmp_bytes.as_mut_ptr(),
);
}
}
fn vec_znx_dft_automorphism_tmp_bytes(&self) -> usize {
unsafe { vec_znx_dft::vec_znx_dft_automorphism_tmp_bytes(self.ptr) as usize }
}
}
#[cfg(test)]
mod tests {
use crate::{FFT64, Module, Sampling, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, ZnxLayout, alloc_aligned};
use itertools::izip;
use sampling::source::Source;
#[test]
fn test_automorphism_dft() {
let n: usize = 8;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let size: usize = 2;
let log_base2k: usize = 17;
let mut a: VecZnx = module.new_vec_znx(1, size);
let mut a_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, size);
let mut b_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, size);
let mut source: Source = Source::new([0u8; 32]);
module.fill_uniform(log_base2k, &mut a, 0, size, &mut source);
let mut tmp_bytes: Vec<u8> = alloc_aligned(module.vec_znx_dft_automorphism_tmp_bytes());
let p: i64 = -5;
// a_dft <- DFT(a)
module.vec_znx_dft(&mut a_dft, &a);
// a_dft <- AUTO(a_dft)
module.vec_znx_dft_automorphism_inplace(p, &mut a_dft, &mut tmp_bytes);
// a <- AUTO(a)
module.vec_znx_automorphism_inplace(p, &mut a, 0);
// b_dft <- DFT(AUTO(a))
module.vec_znx_dft(&mut b_dft, &a);
let a_f64: &[f64] = a_dft.raw();
let b_f64: &[f64] = b_dft.raw();
izip!(a_f64.iter(), b_f64.iter()).for_each(|(ai, bi)| {
assert!((ai - bi).abs() <= 1e-9, "{:+e} > 1e-9", (ai - bi).abs());
});
module.free()
}
} }

View File

@@ -0,0 +1,140 @@
use crate::ffi::vec_znx_big;
use crate::ffi::vec_znx_dft;
use crate::znx_base::ZnxAlloc;
use crate::znx_base::ZnxInfos;
use crate::znx_base::ZnxLayout;
use crate::znx_base::ZnxSliceSize;
use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxDft, assert_alignement};
pub trait VecZnxDftOps<B: Backend> {
/// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space.
fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDft<B>;
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
///
/// Behavior: takes ownership of the backing array.
///
/// # Arguments
///
/// * `cols`: the number of cols of the [VecZnxDft].
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft].
///
/// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDft<B>;
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
///
/// Behavior: the backing array is only borrowed.
///
/// # Arguments
///
/// * `cols`: the number of cols of the [VecZnxDft].
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft].
///
/// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft<B>;
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
///
/// # Arguments
///
/// * `cols`: the number of cols of the [VecZnxDft].
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft].
///
/// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize;
/// Returns the minimum number of bytes necessary to allocate
/// a new [VecZnxDft] through [VecZnxDft::from_bytes].
fn vec_znx_idft_tmp_bytes(&self) -> usize;
/// b <- IDFT(a), uses a as scratch space.
fn vec_znx_idft_tmp_a(&self, res: &mut VecZnxBig<B>, res_col: usize, a: &mut VecZnxDft<B>, a_cols: usize);
fn vec_znx_idft(&self, res: &mut VecZnxBig<B>, res_col: usize, a: &VecZnxDft<B>, a_col: usize, tmp_bytes: &mut [u8]);
fn vec_znx_dft(&self, res: &mut VecZnxDft<B>, res_col: usize, a: &VecZnx, a_col: usize);
}
impl VecZnxDftOps<FFT64> for Module<FFT64> {
fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDft<FFT64> {
VecZnxDft::<FFT64>::new(&self, 1, cols, size)
}
fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDft<FFT64> {
VecZnxDft::from_bytes(self, 1, cols, size, bytes)
}
fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft<FFT64> {
VecZnxDft::from_bytes_borrow(self, 1, cols, size, bytes)
}
fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize {
VecZnxDft::bytes_of(&self, 1, cols, size)
}
fn vec_znx_idft_tmp_a(&self, res: &mut VecZnxBig<FFT64>, res_col: usize, a: &mut VecZnxDft<FFT64>, a_col: usize) {
#[cfg(debug_assertions)]
{
assert_eq!(res.poly_count(), a.poly_count());
}
unsafe {
vec_znx_dft::vec_znx_idft_tmp_a(
self.ptr,
res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big::vec_znx_big_t,
res.size() as u64,
a.at_ptr(a_col * a.size(), 0) as *mut vec_znx_dft::vec_znx_dft_t,
a.size() as u64,
)
}
}
fn vec_znx_idft_tmp_bytes(&self) -> usize {
unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(self.ptr) as usize }
}
/// b <- DFT(a)
///
/// # Panics
/// If b.cols < a_cols
fn vec_znx_dft(&self, res: &mut VecZnxDft<FFT64>, res_col: usize, a: &VecZnx, a_col: usize) {
unsafe {
vec_znx_dft::vec_znx_dft(
self.ptr,
res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_dft::vec_znx_dft_t,
res.size() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
// b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes].
fn vec_znx_idft(&self, res: &mut VecZnxBig<FFT64>, res_col: usize, a: &VecZnxDft<FFT64>, a_col: usize, tmp_bytes: &mut [u8]) {
#[cfg(debug_assertions)]
{
assert!(
tmp_bytes.len() >= Self::vec_znx_idft_tmp_bytes(self),
"invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_idft_tmp_bytes()={}",
tmp_bytes.len(),
Self::vec_znx_idft_tmp_bytes(self)
);
assert_alignement(tmp_bytes.as_ptr())
}
unsafe {
vec_znx_dft::vec_znx_idft(
self.ptr,
res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big::vec_znx_big_t,
res.size() as u64,
a.at_ptr(a_col * res.size(), 0) as *const vec_znx_dft::vec_znx_dft_t,
a.size() as u64,
tmp_bytes.as_mut_ptr(),
)
}
}
}

View File

@@ -1,5 +1,6 @@
use crate::ffi::vec_znx; use crate::ffi::vec_znx;
use crate::{Backend, Module, VecZnx, ZnxBase, ZnxInfos, ZnxLayout, assert_alignement, switch_degree}; use crate::znx_base::{ZnxAlloc, ZnxInfos, ZnxLayout, ZnxSliceSize, switch_degree};
use crate::{Backend, Module, VEC_ZNX_ROWS, VecZnx, assert_alignement};
pub trait VecZnxOps { pub trait VecZnxOps {
/// Allocates a new [VecZnx]. /// Allocates a new [VecZnx].
/// ///
@@ -19,7 +20,7 @@ pub trait VecZnxOps {
/// ///
/// # Panic /// # Panic
/// Requires the slice of bytes to be equal to [VecZnxOps::bytes_of_vec_znx]. /// Requires the slice of bytes to be equal to [VecZnxOps::bytes_of_vec_znx].
fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnx; fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnx;
/// Instantiates a new [VecZnx] from a slice of bytes. /// Instantiates a new [VecZnx] from a slice of bytes.
/// The returned [VecZnx] does take ownership of the slice of bytes. /// The returned [VecZnx] does take ownership of the slice of bytes.
@@ -107,19 +108,19 @@ pub trait VecZnxOps {
impl<B: Backend> VecZnxOps for Module<B> { impl<B: Backend> VecZnxOps for Module<B> {
fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnx { fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnx {
VecZnx::new(self, cols, size) VecZnx::new(self, VEC_ZNX_ROWS, cols, size)
} }
fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize { fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize {
VecZnx::bytes_of(self, cols, size) VecZnx::bytes_of(self, VEC_ZNX_ROWS, cols, size)
} }
fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnx { fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnx {
VecZnx::from_bytes(self, cols, size, bytes) VecZnx::from_bytes(self, VEC_ZNX_ROWS, cols, size, bytes)
} }
fn new_vec_znx_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnx { fn new_vec_znx_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnx {
VecZnx::from_bytes_borrow(self, cols, size, tmp_bytes) VecZnx::from_bytes_borrow(self, VEC_ZNX_ROWS, cols, size, tmp_bytes)
} }
fn vec_znx_normalize_tmp_bytes(&self) -> usize { fn vec_znx_normalize_tmp_bytes(&self) -> usize {

View File

@@ -1,10 +1,37 @@
use crate::{Backend, Module, assert_alignement, cast_mut}; use crate::{Backend, Module, alloc_aligned, assert_alignement, cast_mut};
use itertools::izip; use itertools::izip;
use std::cmp::min; use std::cmp::min;
pub trait ZnxInfos { pub struct ZnxBase {
/// The ring degree
pub n: usize,
/// The number of rows (in the third dimension)
pub rows: usize,
/// The number of polynomials
pub cols: usize,
/// The number of size per polynomial (a.k.a small polynomials).
pub size: usize,
/// Polynomial coefficients, as a contiguous array. Each col is equally spaced by n.
pub data: Vec<u8>,
/// Pointer to data (data can be enpty if [VecZnx] borrows space instead of owning it).
pub ptr: *mut u8,
}
pub trait GetZnxBase {
fn znx(&self) -> &ZnxBase;
fn znx_mut(&mut self) -> &mut ZnxBase;
}
pub trait ZnxInfos: GetZnxBase {
/// Returns the ring degree of the polynomials. /// Returns the ring degree of the polynomials.
fn n(&self) -> usize; fn n(&self) -> usize {
self.znx().n
}
/// Returns the base two logarithm of the ring dimension of the polynomials. /// Returns the base two logarithm of the ring dimension of the polynomials.
fn log_n(&self) -> usize { fn log_n(&self) -> usize {
@@ -12,41 +39,104 @@ pub trait ZnxInfos {
} }
/// Returns the number of rows. /// Returns the number of rows.
fn rows(&self) -> usize; fn rows(&self) -> usize {
self.znx().rows
}
/// Returns the number of polynomials in each row. /// Returns the number of polynomials in each row.
fn cols(&self) -> usize; fn cols(&self) -> usize {
self.znx().cols
}
/// Returns the number of size per polynomial. /// Returns the number of size per polynomial.
fn size(&self) -> usize; fn size(&self) -> usize {
self.znx().size
}
fn data(&self) -> &[u8] {
&self.znx().data
}
fn ptr(&self) -> *mut u8 {
self.znx().ptr
}
/// Returns the total number of small polynomials. /// Returns the total number of small polynomials.
fn poly_count(&self) -> usize { fn poly_count(&self) -> usize {
self.rows() * self.cols() * self.size() self.rows() * self.cols() * self.size()
} }
}
pub trait ZnxSliceSize {
/// Returns the slice size, which is the offset between /// Returns the slice size, which is the offset between
/// two size of the same column. /// two size of the same column.
fn sl(&self) -> usize { fn sl(&self) -> usize;
self.n() * self.cols() }
impl ZnxBase {
pub fn from_bytes(n: usize, rows: usize, cols: usize, size: usize, mut bytes: Vec<u8>) -> Self {
let mut res: Self = Self::from_bytes_borrow(n, rows, cols, size, &mut bytes);
res.data = bytes;
res
}
pub fn from_bytes_borrow(n: usize, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
#[cfg(debug_assertions)]
{
assert_eq!(n & (n - 1), 0, "n must be a power of two");
assert!(n > 0, "n must be greater than 0");
assert!(rows > 0, "rows must be greater than 0");
assert!(cols > 0, "cols must be greater than 0");
assert!(size > 0, "size must be greater than 0");
}
Self {
n: n,
rows: rows,
cols: cols,
size: size,
data: Vec::new(),
ptr: bytes.as_mut_ptr(),
}
} }
} }
pub trait ZnxBase<B: Backend> { pub trait ZnxAlloc<B: Backend>
where
Self: Sized + ZnxInfos,
{
type Scalar; type Scalar;
fn new(module: &Module<B>, cols: usize, size: usize) -> Self; fn new(module: &Module<B>, rows: usize, cols: usize, size: usize) -> Self {
fn from_bytes(module: &Module<B>, cols: usize, size: usize, bytes: &mut [u8]) -> Self; let bytes: Vec<u8> = alloc_aligned::<u8>(Self::bytes_of(module, rows, cols, size));
fn from_bytes_borrow(module: &Module<B>, cols: usize, size: usize, bytes: &mut [u8]) -> Self; Self::from_bytes(module, rows, cols, size, bytes)
fn bytes_of(module: &Module<B>, cols: usize, size: usize) -> usize; }
fn from_bytes(module: &Module<B>, rows: usize, cols: usize, size: usize, mut bytes: Vec<u8>) -> Self {
let mut res: Self = Self::from_bytes_borrow(module, rows, cols, size, &mut bytes);
res.znx_mut().data = bytes;
res
}
fn from_bytes_borrow(module: &Module<B>, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self;
fn bytes_of(module: &Module<B>, rows: usize, cols: usize, size: usize) -> usize;
} }
pub trait ZnxLayout: ZnxInfos { pub trait ZnxLayout: ZnxInfos {
type Scalar; type Scalar;
/// Returns true if the receiver is only borrowing the data.
fn borrowing(&self) -> bool {
self.znx().data.len() == 0
}
/// Returns a non-mutable pointer to the underlying coefficients array. /// Returns a non-mutable pointer to the underlying coefficients array.
fn as_ptr(&self) -> *const Self::Scalar; fn as_ptr(&self) -> *const Self::Scalar {
self.znx().ptr as *const Self::Scalar
}
/// Returns a mutable pointer to the underlying coefficients array. /// Returns a mutable pointer to the underlying coefficients array.
fn as_mut_ptr(&mut self) -> *mut Self::Scalar; fn as_mut_ptr(&mut self) -> *mut Self::Scalar {
self.znx_mut().ptr as *mut Self::Scalar
}
/// Returns a non-mutable reference to the entire underlying coefficient array. /// Returns a non-mutable reference to the entire underlying coefficient array.
fn raw(&self) -> &[Self::Scalar] { fn raw(&self) -> &[Self::Scalar] {

View File

@@ -1,5 +1,3 @@
cargo-features = ["edition2024"]
[package] [package]
name = "rlwe" name = "rlwe"
version = "0.1.0" version = "0.1.0"

View File

@@ -20,7 +20,7 @@ pub struct AutomorphismKey {
} }
pub fn automorphis_key_new_tmp_bytes(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> usize { pub fn automorphis_key_new_tmp_bytes(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> usize {
module.bytes_of_scalar() + module.bytes_of_svp_ppol() + encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q) module.bytes_of_scalar() + module.bytes_of_scalar_znx_dft() + encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q)
} }
impl Parameters { impl Parameters {
@@ -103,10 +103,10 @@ impl AutomorphismKey {
tmp_bytes: &mut [u8], tmp_bytes: &mut [u8],
) -> Vec<Self> { ) -> Vec<Self> {
let (sk_auto_bytes, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_scalar()); let (sk_auto_bytes, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_scalar());
let (sk_out_bytes, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_svp_ppol()); let (sk_out_bytes, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_scalar_znx_dft());
let sk_auto: Scalar = module.new_scalar_from_bytes_borrow(sk_auto_bytes); let sk_auto: Scalar = module.new_scalar_from_bytes_borrow(sk_auto_bytes);
let mut sk_out: ScalarZnxDft = module.new_svp_ppol_from_bytes_borrow(sk_out_bytes); let mut sk_out: ScalarZnxDft = module.new_scalar_znx_dft_from_bytes_borrow(sk_out_bytes);
let mut keys: Vec<AutomorphismKey> = Vec::new(); let mut keys: Vec<AutomorphismKey> = Vec::new();
@@ -116,7 +116,7 @@ impl AutomorphismKey {
let p_inv: i64 = module.galois_element_inv(*pi); let p_inv: i64 = module.galois_element_inv(*pi);
module.vec_znx_automorphism(p_inv, &mut sk_auto.as_vec_znx(), &sk.0.as_vec_znx()); module.vec_znx_automorphism(p_inv, &mut sk_auto.as_vec_znx(), &sk.0.as_vec_znx());
module.svp_prepare(&mut sk_out, &sk_auto); module.scalar_znx_dft_prepare(&mut sk_out, &sk_auto);
encrypt_grlwe_sk( encrypt_grlwe_sk(
module, &mut value, &sk.0, &sk_out, source_xa, source_xe, sigma, tmp_bytes, module, &mut value, &sk.0, &sk_out, source_xa, source_xe, sigma, tmp_bytes,
); );

View File

@@ -20,7 +20,7 @@ impl SecretKey {
} }
pub fn prepare(&self, module: &Module, sk_ppol: &mut ScalarZnxDft) { pub fn prepare(&self, module: &Module, sk_ppol: &mut ScalarZnxDft) {
module.svp_prepare(sk_ppol, &self.0) module.scalar_znx_dft_prepare(sk_ppol, &self.0)
} }
} }