wip major refactoring (compiles & all tests + examples passing)

Jean-Philippe Bossuat
2025-04-30 13:43:18 +02:00
parent 2cc51eee18
commit 6f7b93c7ca
18 changed files with 662 additions and 870 deletions

View File

@@ -23,10 +23,10 @@ fn main() {
s.fill_ternary_prob(0.5, &mut source);
// Buffer to store s in the DFT domain
let mut s_ppol: ScalarZnxDft<FFT64> = module.new_svp_ppol();
let mut s_dft: ScalarZnxDft<FFT64> = module.new_scalar_znx_dft();
// s_ppol <- DFT(s)
module.svp_prepare(&mut s_ppol, &s);
// s_dft <- DFT(s)
module.svp_prepare(&mut s_dft, &s);
// Allocates a VecZnx with two columns: ct=(0, 0)
let mut ct: VecZnx = module.new_vec_znx(
@@ -46,16 +46,17 @@ fn main() {
// Applies DFT(ct[1]) * DFT(s)
module.svp_apply_dft(
&mut buf_dft, // DFT(ct[1] * s)
&s_ppol, // DFT(s)
0, // Selects the first column of res
&s_dft, // DFT(s)
&ct,
1, // Selects the second column of ct
);
// Alias scratch space (VecZnxDft<B> is always at least as big as VecZnxBig<B>)
let mut buf_big: VecZnxBig<FFT64> = buf_dft.as_vec_znx_big();
let mut buf_big: VecZnxBig<FFT64> = buf_dft.alias_as_vec_znx_big();
// BIG(ct[1] * s) <- IDFT(DFT(ct[1] * s)) (not normalized)
module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft);
module.vec_znx_idft_tmp_a(&mut buf_big, 0, &mut buf_dft, 0);
// Creates a plaintext: VecZnx with 1 column
let mut m: VecZnx = module.new_vec_znx(
@@ -103,13 +104,14 @@ fn main() {
// DFT(ct[1] * s)
module.svp_apply_dft(
&mut buf_dft,
&s_ppol,
0, // Selects the first column of res.
&s_dft,
&ct,
1, // Selects the second column of ct (ct[1])
);
// BIG(c1 * s) = IDFT(DFT(c1 * s))
module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft);
module.vec_znx_idft_tmp_a(&mut buf_big, 0, &mut buf_dft, 0);
// BIG(c1 * s) + ct[0]
module.vec_znx_big_add_small_inplace(&mut buf_big, 0, &ct, 0);
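Pieced together, the example's updated flow is the following sketch (reassembled from the hunks above; `module`, `s_dft` and the two-column `ct` are assumed to be set up as earlier in the example, crate imports are omitted since the file paths are not shown, and `apply_secret` is an illustrative name):

fn apply_secret(
    module: &Module<FFT64>,
    s_dft: &ScalarZnxDft<FFT64>,
    ct: &VecZnx,                    // two columns: ct = (ct[0], ct[1])
    buf_dft: &mut VecZnxDft<FFT64>, // scratch, one column
) {
    // buf_dft[0] <- DFT(ct[1]) * DFT(s): the output column is now explicit
    module.svp_apply_dft(buf_dft, 0, s_dft, ct, 1);
    // Alias the DFT scratch as a VecZnxBig (VecZnxDft is at least as big)
    let mut buf_big: VecZnxBig<FFT64> = buf_dft.alias_as_vec_znx_big();
    // BIG(ct[1] * s) <- IDFT(DFT(ct[1] * s)), not normalized
    module.vec_znx_idft_tmp_a(&mut buf_big, 0, buf_dft, 0);
    // BIG(ct[1] * s) + ct[0]
    module.vec_znx_big_add_small_inplace(&mut buf_big, 0, ct, 0);
}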

View File

@@ -42,8 +42,8 @@ fn main() {
let mut c_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, limbs_mat);
module.vmp_apply_dft(&mut c_dft, &a, &mat_znx_dft, &mut buf);
let mut c_big: VecZnxBig<FFT64> = c_dft.as_vec_znx_big();
module.vec_znx_idft_tmp_a(&mut c_big, &mut c_dft);
let mut c_big: VecZnxBig<FFT64> = c_dft.alias_as_vec_znx_big();
module.vec_znx_idft_tmp_a(&mut c_big, 0, &mut c_dft, 0);
let mut res: VecZnx = module.new_vec_znx(1, limbs_vec);
module.vec_znx_big_normalize(log_base2k, &mut res, 0, &c_big, 0, &mut buf);
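Written out under the new signatures, the pipeline above reads as this sketch (`vmp_then_normalize` is an illustrative name; `buf` is assumed large enough for both the product and the normalization scratch):

fn vmp_then_normalize(
    module: &Module<FFT64>,
    a: &VecZnx,
    mat_znx_dft: &MatZnxDft<FFT64>,
    log_base2k: usize,
    limbs_mat: usize,
    limbs_vec: usize,
    buf: &mut [u8],
) -> VecZnx {
    let mut c_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, limbs_mat);
    module.vmp_apply_dft(&mut c_dft, a, mat_znx_dft, buf);
    // Reinterpret the DFT buffer in place, then leave the DFT domain
    let mut c_big: VecZnxBig<FFT64> = c_dft.alias_as_vec_znx_big();
    module.vec_znx_idft_tmp_a(&mut c_big, 0, &mut c_dft, 0);
    // Carry-propagate back into a normalized base-2^k representation
    let mut res: VecZnx = module.new_vec_znx(1, limbs_vec);
    module.vec_znx_big_normalize(log_base2k, &mut res, 0, &c_big, 0, buf);
    res
}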

View File

@@ -1,5 +1,6 @@
use crate::ffi::znx::znx_zero_i64_ref;
use crate::{VecZnx, ZnxInfos, ZnxLayout};
use crate::znx_base::ZnxLayout;
use crate::{VecZnx, znx_base::ZnxInfos};
use itertools::izip;
use rug::{Assign, Float};
use std::cmp::min;
@@ -262,7 +263,10 @@ fn decode_coeff_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, i
#[cfg(test)]
mod tests {
use crate::{Encoding, FFT64, Module, VecZnx, ZnxBase, ZnxInfos, ZnxLayout};
use crate::{
Encoding, FFT64, Module, VecZnx, VecZnxOps,
znx_base::{ZnxInfos, ZnxLayout},
};
use itertools::izip;
use sampling::source::Source;
@@ -273,7 +277,7 @@ mod tests {
let log_base2k: usize = 17;
let size: usize = 5;
let log_k: usize = size * log_base2k - 5;
let mut a: VecZnx = VecZnx::new(&module, 2, size);
let mut a: VecZnx = module.new_vec_znx(2, size);
let mut source: Source = Source::new([0u8; 32]);
let raw: &mut [i64] = a.raw_mut();
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
@@ -295,7 +299,7 @@ mod tests {
let log_base2k: usize = 17;
let size: usize = 5;
let log_k: usize = size * log_base2k - 5;
let mut a: VecZnx = VecZnx::new(&module, 2, size);
let mut a: VecZnx = module.new_vec_znx(2, size);
let mut source = Source::new([0u8; 32]);
let raw: &mut [i64] = a.raw_mut();
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);

View File

@@ -1,4 +1,3 @@
pub mod commons;
pub mod encoding;
#[allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code, improper_ctypes)]
// Other modules and exports
@@ -12,9 +11,10 @@ pub mod vec_znx;
pub mod vec_znx_big;
pub mod vec_znx_big_ops;
pub mod vec_znx_dft;
pub mod vec_znx_dft_ops;
pub mod vec_znx_ops;
pub mod znx_base;
pub use commons::*;
pub use encoding::*;
pub use mat_znx_dft::*;
pub use module::*;
@@ -26,7 +26,9 @@ pub use vec_znx::*;
pub use vec_znx_big::*;
pub use vec_znx_big_ops::*;
pub use vec_znx_dft::*;
pub use vec_znx_dft_ops::*;
pub use vec_znx_ops::*;
pub use znx_base::*;
pub const GALOISGENERATOR: u64 = 5;
pub const DEFAULTALIGN: usize = 64;
@@ -110,14 +112,8 @@ pub fn alloc_aligned_custom<T>(size: usize, align: usize) -> Vec<T> {
unsafe { Vec::from_raw_parts(ptr, len, cap) }
}
// Allocates an aligned vector of size equal to the smallest power of two equal to or greater than `size` that is
// at least as big as DEFAULTALIGN / std::mem::size_of::<T>().
/// Allocates an aligned vector whose size is the smallest multiple
/// of [DEFAULTALIGN] that is equal to or greater than `size`.
pub fn alloc_aligned<T>(size: usize) -> Vec<T> {
alloc_aligned_custom::<T>(
std::cmp::max(
size.next_power_of_two(),
DEFAULTALIGN / std::mem::size_of::<T>(),
),
DEFAULTALIGN,
)
alloc_aligned_custom::<T>(size.next_multiple_of(DEFAULTALIGN), DEFAULTALIGN)
}
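The rounding is plain next-multiple arithmetic; a minimal standalone check of the intended behavior (assuming the fixed rounding above):

fn main() {
    // DEFAULTALIGN = 64: requested lengths are padded to the next multiple of 64.
    assert_eq!(64usize.next_multiple_of(64), 64);
    assert_eq!(65usize.next_multiple_of(64), 128);
    assert_eq!(100usize.next_multiple_of(64), 128);
}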

View File

@@ -1,103 +1,75 @@
use crate::ffi::vec_znx_big::vec_znx_big_t;
use crate::ffi::vec_znx_dft::vec_znx_dft_t;
use crate::ffi::vmp::{self, vmp_pmat_t};
use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxDft, ZnxInfos, ZnxLayout, alloc_aligned, assert_alignement};
use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize};
use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxDft, alloc_aligned, assert_alignement};
use std::marker::PhantomData;
/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx],
/// stored as a 3D matrix in the DFT domain in a single contiguous array.
/// Each col of the [VmpPMat] can be seen as a collection of [VecZnxDft].
/// Each col of the [MatZnxDft] can be seen as a collection of [VecZnxDft].
///
/// [VmpPMat] is used to perform a vector matrix product between a [VecZnx]/[VecZnxDft] and a [VmpPMat].
/// See the trait [VmpPMatOps] for additional information.
/// [MatZnxDft] is used to perform a vector matrix product between a [VecZnx]/[VecZnxDft] and a [MatZnxDft].
/// See the trait [MatZnxDftOps] for additional information.
pub struct MatZnxDft<B: Backend> {
/// Raw data, is empty if borrowing scratch space.
data: Vec<u8>,
/// Pointer to data. Can point to scratch space.
ptr: *mut u8,
/// The ring degree of each polynomial.
n: usize,
/// Number of rows
rows: usize,
/// Number of cols
cols: usize,
/// The number of small polynomials
size: usize,
pub inner: ZnxBase,
_marker: PhantomData<B>,
}
impl<B: Backend> ZnxInfos for MatZnxDft<B> {
fn n(&self) -> usize {
self.n
impl<B: Backend> GetZnxBase for MatZnxDft<B> {
fn znx(&self) -> &ZnxBase {
&self.inner
}
fn rows(&self) -> usize {
self.rows
}
fn cols(&self) -> usize {
self.cols
}
fn size(&self) -> usize {
self.size
fn znx_mut(&mut self) -> &mut ZnxBase {
&mut self.inner
}
}
impl MatZnxDft<FFT64> {
fn new(module: &Module<FFT64>, rows: usize, cols: usize, size: usize) -> MatZnxDft<FFT64> {
let mut data: Vec<u8> = alloc_aligned::<u8>(module.bytes_of_mat_znx_dft(rows, cols, size));
let ptr: *mut u8 = data.as_mut_ptr();
MatZnxDft::<FFT64> {
data: data,
ptr: ptr,
n: module.n(),
rows: rows,
cols: cols,
size: size,
impl<B: Backend> ZnxInfos for MatZnxDft<B> {}
impl ZnxSliceSize for MatZnxDft<FFT64> {
fn sl(&self) -> usize {
self.n()
}
}
impl ZnxLayout for MatZnxDft<FFT64> {
type Scalar = f64;
}
impl<B: Backend> ZnxAlloc<B> for MatZnxDft<B> {
type Scalar = u8;
fn from_bytes_borrow(module: &Module<B>, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
Self {
inner: ZnxBase::from_bytes_borrow(module.n(), rows, cols, size, bytes),
_marker: PhantomData,
}
}
pub fn as_ptr(&self) -> *const u8 {
self.ptr
fn bytes_of(module: &Module<B>, rows: usize, cols: usize, size: usize) -> usize {
unsafe { vmp::bytes_of_vmp_pmat(module.ptr, rows as u64, size as u64) as usize * cols }
}
}
pub fn as_mut_ptr(&self) -> *mut u8 {
self.ptr
}
pub fn borrowed(&self) -> bool {
self.data.len() == 0
}
/// Returns a non-mutable reference to the entire contiguous array of the [VmpPMat].
pub fn raw(&self) -> &[f64] {
let ptr: *const f64 = self.ptr as *const f64;
let size: usize = self.n() * self.poly_count();
unsafe { &std::slice::from_raw_parts(ptr, size) }
}
/// Returns a mutable reference to the entire contiguous array of the [VmpPMat].
pub fn raw_mut(&self) -> &mut [f64] {
let ptr: *mut f64 = self.ptr as *mut f64;
let size: usize = self.n() * self.poly_count();
unsafe { std::slice::from_raw_parts_mut(ptr, size) }
}
/// Returns a copy of the backend array at index (i, j) of the [VmpPMat].
impl MatZnxDft<FFT64> {
/// Returns a copy of the backend array at index (i, j) of the [MatZnxDft].
///
/// # Arguments
///
/// * `row`: row index (i).
/// * `col`: col index (j).
pub fn at(&self, row: usize, col: usize) -> Vec<f64> {
let mut res: Vec<f64> = alloc_aligned(self.n);
#[allow(dead_code)]
fn at(&self, row: usize, col: usize) -> Vec<f64> {
let n: usize = self.n();
if self.n < 8 {
res.copy_from_slice(&self.raw()[(row + col * self.rows()) * self.n()..(row + col * self.rows() + 1) * self.n()]);
let mut res: Vec<f64> = alloc_aligned(n);
if n < 8 {
res.copy_from_slice(&self.raw()[(row + col * self.rows()) * n..(row + col * self.rows() + 1) * n]);
} else {
(0..self.n >> 3).for_each(|blk| {
(0..n >> 3).for_each(|blk| {
res[blk * 8..(blk + 1) * 8].copy_from_slice(&self.at_block(row, col, blk)[..8]);
});
}
@@ -105,6 +77,7 @@ impl MatZnxDft<FFT64> {
res
}
#[allow(dead_code)]
fn at_block(&self, row: usize, col: usize, blk: usize) -> &[f64] {
let nrows: usize = self.rows();
let nsize: usize = self.size();
@@ -117,11 +90,11 @@ impl MatZnxDft<FFT64> {
}
/// This trait implements methods for the vector matrix product,
/// that is, multiplying a [VecZnx] with a [VmpPMat].
/// that is, multiplying a [VecZnx] with a [MatZnxDft].
pub trait MatZnxDftOps<B: Backend> {
fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> usize;
/// Allocates a new [VmpPMat] with the given number of rows and columns.
/// Allocates a new [MatZnxDft] with the given number of rows and columns.
///
/// # Arguments
///
@@ -129,83 +102,83 @@ pub trait MatZnxDftOps<B: Backend> {
/// * `size`: number of limbs (`size`) of each [VecZnxDft].
fn new_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> MatZnxDft<B>;
/// Returns the number of bytes needed as scratch space for [VmpPMatOps::vmp_prepare_contiguous].
/// Returns the number of bytes needed as scratch space for [MatZnxDftOps::vmp_prepare_contiguous].
///
/// # Arguments
///
/// * `rows`: number of rows of the [VmpPMat] used in [VmpPMatOps::vmp_prepare_contiguous].
/// * `size`: number of limbs (`size`) of the [VmpPMat] used in [VmpPMatOps::vmp_prepare_contiguous].
/// * `rows`: number of rows of the [MatZnxDft] used in [MatZnxDftOps::vmp_prepare_contiguous].
/// * `size`: number of limbs (`size`) of the [MatZnxDft] used in [MatZnxDftOps::vmp_prepare_contiguous].
fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize, size: usize) -> usize;
/// Prepares a [VmpPMat] from a contiguous array of [i64].
/// Prepares a [MatZnxDft] from a contiguous array of [i64].
/// The helper struct [Matrix3D] can be used to construct and populate
/// the appropriate contiguous array.
///
/// # Arguments
///
/// * `b`: [VmpPMat] on which the values are encoded.
/// * `a`: the contiguous array of [i64] of the 3D matrix to encode on the [VmpPMat].
/// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
/// * `b`: [MatZnxDft] on which the values are encoded.
/// * `a`: the contiguous array of [i64] of the 3D matrix to encode on the [MatZnxDft].
/// * `buf`: scratch space, the size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes].
fn vmp_prepare_contiguous(&self, b: &mut MatZnxDft<B>, a: &[i64], buf: &mut [u8]);
/// Prepares the i-th row of [VmpPMat] from a [VecZnx].
/// Prepares the i-th row of [MatZnxDft] from a [VecZnx].
///
/// # Arguments
///
/// * `b`: [VmpPMat] on which the values are encoded.
/// * `a`: the vector of [VecZnx] to encode on the [VmpPMat].
/// * `b`: [MatZnxDft] on which the values are encoded.
/// * `a`: the vector of [VecZnx] to encode on the [MatZnxDft].
/// * `row_i`: the index of the row to prepare.
/// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
/// * `buf`: scratch space, the size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes].
///
/// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
/// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes].
fn vmp_prepare_row(&self, b: &mut MatZnxDft<B>, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]);
/// Extracts the i-th row of [VmpPMat] into a [VecZnxBig].
/// Extracts the i-th row of [MatZnxDft] into a [VecZnxBig].
///
/// # Arguments
///
/// * `b`: the [VecZnxBig] into which the row of the [VmpPMat] is extracted.
/// * `a`: [VmpPMat] on which the values are encoded.
/// * `b`: the [VecZnxBig] into which the row of the [MatZnxDft] is extracted.
/// * `a`: [MatZnxDft] on which the values are encoded.
/// * `row_i`: the index of the row to extract.
fn vmp_extract_row(&self, b: &mut VecZnxBig<B>, a: &MatZnxDft<B>, row_i: usize);
/// Prepares the i-th row of [VmpPMat] from a [VecZnxDft].
/// Prepares the i-th row of [MatZnxDft] from a [VecZnxDft].
///
/// # Arguments
///
/// * `b`: [VmpPMat] on which the values are encoded.
/// * `a`: the [VecZnxDft] to encode on the [VmpPMat].
/// * `b`: [MatZnxDft] on which the values are encoded.
/// * `a`: the [VecZnxDft] to encode on the [MatZnxDft].
/// * `row_i`: the index of the row to prepare.
///
/// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
/// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes].
fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft<B>, a: &VecZnxDft<B>, row_i: usize);
/// Extracts the i-th row of [VmpPMat] into a [VecZnxDft].
/// Extracts the i-th row of [MatZnxDft] into a [VecZnxDft].
///
/// # Arguments
///
/// * `b`: the [VecZnxDft] into which the row of the [VmpPMat] is extracted.
/// * `a`: [VmpPMat] on which the values are encoded.
/// * `b`: the [VecZnxDft] into which the row of the [MatZnxDft] is extracted.
/// * `a`: [MatZnxDft] on which the values are encoded.
/// * `row_i`: the index of the row to extract.
fn vmp_extract_row_dft(&self, b: &mut VecZnxDft<B>, a: &MatZnxDft<B>, row_i: usize);
/// Returns the size of the scratch space necessary for [VmpPMatOps::vmp_apply_dft].
/// Returns the size of the scratch space necessary for [MatZnxDftOps::vmp_apply_dft].
///
/// # Arguments
///
/// * `c_size`: number of limbs (`size`) of the output [VecZnxDft].
/// * `a_size`: number of limbs (`size`) of the input [VecZnx].
/// * `rows`: number of rows of the input [VmpPMat].
/// * `size`: number of limbs (`size`) of the input [VmpPMat].
/// * `rows`: number of rows of the input [MatZnxDft].
/// * `size`: number of limbs (`size`) of the input [MatZnxDft].
fn vmp_apply_dft_tmp_bytes(&self, c_size: usize, a_size: usize, rows: usize, size: usize) -> usize;
/// Applies the vector matrix product [VecZnxDft] x [VmpPMat].
/// Applies the vector matrix product [VecZnxDft] x [MatZnxDft].
///
/// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
/// and each vector a [VecZnxDft] (row) of the [VmpPMat].
/// and each vector a [VecZnxDft] (row) of the [MatZnxDft].
///
/// As such, given an input [VecZnx] of size `i` and a [VmpPMat] of `i` rows and
/// As such, given an input [VecZnx] of size `i` and a [MatZnxDft] of `i` rows and
/// size `j`, the output is a [VecZnx] of size `j`.
///
/// If there is a mismatch between the dimensions, the largest valid ones are used.
@@ -221,17 +194,17 @@ pub trait MatZnxDftOps<B: Backend> {
///
/// * `c`: the output of the vector matrix product, as a [VecZnxDft].
/// * `a`: the left operand [VecZnx] of the vector matrix product.
/// * `b`: the right operand [VmpPMat] of the vector matrix product.
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_tmp_bytes].
/// * `b`: the right operand [MatZnxDft] of the vector matrix product.
/// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_tmp_bytes].
fn vmp_apply_dft(&self, c: &mut VecZnxDft<B>, a: &VecZnx, b: &MatZnxDft<B>, buf: &mut [u8]);
/// Applies the vector matrix product [VecZnxDft] x [VmpPMat] and adds the result onto the receiver.
/// Applies the vector matrix product [VecZnxDft] x [MatZnxDft] and adds the result onto the receiver.
///
/// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
/// and each vector a [VecZnxDft] (row) of the [VmpPMat].
/// and each vector a [VecZnxDft] (row) of the [MatZnxDft].
///
/// As such, given an input [VecZnx] of size `i` and a [VmpPMat] of `i` rows and
/// As such, given an input [VecZnx] of size `i` and a [MatZnxDft] of `i` rows and
/// size `j`, the output is a [VecZnx] of size `j`.
///
/// If there is a mismatch between the dimensions, the largest valid ones are used.
@@ -247,28 +220,28 @@ pub trait MatZnxDftOps<B: Backend> {
///
/// * `c`: the operand on which the output of the vector matrix product is added, as a [VecZnxDft].
/// * `a`: the left operand [VecZnx] of the vector matrix product.
/// * `b`: the right operand [VmpPMat] of the vector matrix product.
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_tmp_bytes].
/// * `b`: the right operand [MatZnxDft] of the vector matrix product.
/// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_tmp_bytes].
fn vmp_apply_dft_add(&self, c: &mut VecZnxDft<B>, a: &VecZnx, b: &MatZnxDft<B>, buf: &mut [u8]);
/// Returns the size of the scratch space necessary for [VmpPMatOps::vmp_apply_dft_to_dft].
/// Returns the size of the scratch space necessary for [MatZnxDftOps::vmp_apply_dft_to_dft].
///
/// # Arguments
///
/// * `c_size`: number of limbs (`size`) of the output [VecZnxDft].
/// * `a_size`: number of limbs (`size`) of the input [VecZnxDft].
/// * `rows`: number of rows of the input [VmpPMat].
/// * `size`: number of limbs (`size`) of the input [VmpPMat].
/// * `rows`: number of rows of the input [MatZnxDft].
/// * `size`: number of limbs (`size`) of the input [MatZnxDft].
fn vmp_apply_dft_to_dft_tmp_bytes(&self, c_size: usize, a_size: usize, rows: usize, size: usize) -> usize;
/// Applies the vector matrix product [VecZnxDft] x [VmpPMat].
/// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
/// Applies the vector matrix product [VecZnxDft] x [MatZnxDft].
/// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
///
/// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
/// and each vector a [VecZnxDft] (row) of the [VmpPMat].
/// and each vector a [VecZnxDft] (row) of the [MatZnxDft].
///
/// As such, given an input [VecZnx] of size `i` and a [VmpPMat] of `i` rows and
/// As such, given an input [VecZnx] of size `i` and a [MatZnxDft] of `i` rows and
/// size `j`, the output is a [VecZnx] of size `j`.
///
/// If there is a mismatch between the dimensions, the largest valid ones are used.
@@ -284,18 +257,18 @@ pub trait MatZnxDftOps<B: Backend> {
///
/// * `c`: the output of the vector matrix product, as a [VecZnxDft].
/// * `a`: the left operand [VecZnxDft] of the vector matrix product.
/// * `b`: the right operand [VmpPMat] of the vector matrix product.
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
/// * `b`: the right operand [MatZnxDft] of the vector matrix product.
/// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft<B>, a: &VecZnxDft<B>, b: &MatZnxDft<B>, buf: &mut [u8]);
/// Applies the vector matrix product [VecZnxDft] x [VmpPMat] and adds on top of the receiver instead of overwriting it.
/// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
/// Applies the vector matrix product [VecZnxDft] x [MatZnxDft] and adds on top of the receiver instead of overwriting it.
/// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
///
/// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
/// and each vector a [VecZnxDft] (row) of the [VmpPMat].
/// and each vector a [VecZnxDft] (row) of the [MatZnxDft].
///
/// As such, given an input [VecZnx] of size `i` and a [VmpPMat] of `i` rows and
/// As such, given an input [VecZnx] of size `i` and a [MatZnxDft] of `i` rows and
/// size `j`, the output is a [VecZnx] of size `j`.
///
/// If there is a mismatch between the dimensions, the largest valid ones are used.
@@ -311,18 +284,18 @@ pub trait MatZnxDftOps<B: Backend> {
///
/// * `c`: the operand on which the output of the vector matrix product is added, as a [VecZnxDft].
/// * `a`: the left operand [VecZnxDft] of the vector matrix product.
/// * `b`: the right operand [VmpPMat] of the vector matrix product.
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
/// * `b`: the right operand [MatZnxDft] of the vector matrix product.
/// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
fn vmp_apply_dft_to_dft_add(&self, c: &mut VecZnxDft<B>, a: &VecZnxDft<B>, b: &MatZnxDft<B>, buf: &mut [u8]);
/// Applies the vector matrix product [VecZnxDft] x [VmpPMat] in place.
/// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
/// Applies the vector matrix product [VecZnxDft] x [MatZnxDft] in place.
/// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
///
/// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
/// and each vector a [VecZnxDft] (row) of the [VmpPMat].
/// and each vector a [VecZnxDft] (row) of the [MatZnxDft].
///
/// As such, given an input [VecZnx] of size `i` and a [VmpPMat] of `i` rows and
/// As such, given an input [VecZnx] of size `i` and a [MatZnxDft] of `i` rows and
/// size `j`, the output is a [VecZnx] of size `j`.
///
/// If there is a mismatch between the dimensions, the largest valid ones are used.
@@ -337,8 +310,8 @@ pub trait MatZnxDftOps<B: Backend> {
/// # Arguments
///
/// * `b`: the input and output of the vector matrix product, as a [VecZnxDft].
/// * `a`: the right operand [VmpPMat] of the vector matrix product.
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
/// * `a`: the right operand [MatZnxDft] of the vector matrix product.
/// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft<B>, a: &MatZnxDft<B>, buf: &mut [u8]);
}
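A hedged end-to-end sketch of the trait above, showing the dimension convention (an input of `rows` limbs yields an output of `size` limbs). `vmp_round_trip` is an illustrative name, the shape is arbitrary, and `raw_mat` is assumed to hold the rows x cols x size x n coefficients expected by vmp_prepare_contiguous:

fn vmp_round_trip(module: &Module<FFT64>, a: &VecZnx, raw_mat: &[i64]) -> VecZnxDft<FFT64> {
    let (rows, cols, size) = (a.size(), 1, 4); // illustrative shape
    let mut mat: MatZnxDft<FFT64> = module.new_mat_znx_dft(rows, cols, size);
    // One scratch buffer sized for both the prepare and the apply step
    let mut buf: Vec<u8> = alloc_aligned::<u8>(std::cmp::max(
        module.vmp_prepare_tmp_bytes(rows, cols, size),
        module.vmp_apply_dft_tmp_bytes(size, a.size(), rows, size),
    ));
    module.vmp_prepare_contiguous(&mut mat, raw_mat, &mut buf);
    // `a` has `rows` limbs; the product yields `size` limbs
    let mut c_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(cols, size);
    module.vmp_apply_dft(&mut c_dft, a, &mat, &mut buf);
    c_dft
}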
@@ -404,7 +377,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
unsafe {
vmp::vmp_extract_row(
self.ptr,
b.ptr as *mut vec_znx_big_t,
b.as_mut_ptr() as *mut vec_znx_big_t,
a.as_ptr() as *const vmp_pmat_t,
row_i as u64,
a.rows() as u64,
@@ -423,7 +396,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
vmp::vmp_prepare_row_dft(
self.ptr,
b.as_mut_ptr() as *mut vmp_pmat_t,
a.ptr as *const vec_znx_dft_t,
a.as_ptr() as *const vec_znx_dft_t,
row_i as u64,
b.rows() as u64,
b.size() as u64,
@@ -440,7 +413,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
unsafe {
vmp::vmp_extract_row_dft(
self.ptr,
b.ptr as *mut vec_znx_dft_t,
b.as_mut_ptr() as *mut vec_znx_dft_t,
a.as_ptr() as *const vmp_pmat_t,
row_i as u64,
a.rows() as u64,
@@ -470,7 +443,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
unsafe {
vmp::vmp_apply_dft(
self.ptr,
c.ptr as *mut vec_znx_dft_t,
c.as_mut_ptr() as *mut vec_znx_dft_t,
c.size() as u64,
a.as_ptr(),
a.size() as u64,
@@ -492,7 +465,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
unsafe {
vmp::vmp_apply_dft_add(
self.ptr,
c.ptr as *mut vec_znx_dft_t,
c.as_mut_ptr() as *mut vec_znx_dft_t,
c.size() as u64,
a.as_ptr(),
a.size() as u64,
@@ -526,9 +499,9 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
unsafe {
vmp::vmp_apply_dft_to_dft(
self.ptr,
c.ptr as *mut vec_znx_dft_t,
c.as_mut_ptr() as *mut vec_znx_dft_t,
c.size() as u64,
a.ptr as *const vec_znx_dft_t,
a.as_ptr() as *const vec_znx_dft_t,
a.size() as u64,
b.as_ptr() as *const vmp_pmat_t,
b.rows() as u64,
@@ -553,9 +526,9 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
unsafe {
vmp::vmp_apply_dft_to_dft_add(
self.ptr,
c.ptr as *mut vec_znx_dft_t,
c.as_mut_ptr() as *mut vec_znx_dft_t,
c.size() as u64,
a.ptr as *const vec_znx_dft_t,
a.as_ptr() as *const vec_znx_dft_t,
a.size() as u64,
b.as_ptr() as *const vmp_pmat_t,
b.rows() as u64,
@@ -574,9 +547,9 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
unsafe {
vmp::vmp_apply_dft_to_dft(
self.ptr,
b.ptr as *mut vec_znx_dft_t,
b.as_mut_ptr() as *mut vec_znx_dft_t,
b.size() as u64,
b.ptr as *mut vec_znx_dft_t,
b.as_ptr() as *mut vec_znx_dft_t,
b.size() as u64,
a.as_ptr() as *const vmp_pmat_t,
a.rows() as u64,
@@ -591,7 +564,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
mod tests {
use crate::{
FFT64, MatZnxDft, MatZnxDftOps, Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps,
ZnxLayout, alloc_aligned,
alloc_aligned, znx_base::ZnxLayout,
};
use sampling::source::Source;
@@ -614,7 +587,7 @@ mod tests {
for row_i in 0..vpmat_rows {
let mut source: Source = Source::new([0u8; 32]);
module.fill_uniform(log_base2k, &mut a, 0, vpmat_size, &mut source);
module.vec_znx_dft(&mut a_dft, &a);
module.vec_znx_dft(&mut a_dft, 0, &a, 0);
module.vmp_prepare_row(&mut vmpmat_0, &a.raw(), row_i, &mut tmp_bytes);
// Checks that prepare(mat_znx_dft, a) = prepare_dft(mat_znx_dft, a_dft)
@@ -627,7 +600,7 @@ mod tests {
// Checks that a_big = extract(prepare_dft(mat_znx_dft, a_dft), b_big)
module.vmp_extract_row(&mut b_big, &vmpmat_0, row_i);
module.vec_znx_idft(&mut a_big, &a_dft, &mut tmp_bytes);
module.vec_znx_idft(&mut a_big, 0, &a_dft, 0, &mut tmp_bytes);
assert_eq!(a_big.raw(), b_big.raw());
}
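Condensed, the property this test exercises is the following round trip (a sketch reusing the same calls; `row_round_trip` is an illustrative name, and `raw()` on [VecZnxDft] is assumed from its [ZnxLayout] impl):

fn row_round_trip(
    module: &Module<FFT64>,
    mat: &mut MatZnxDft<FFT64>,
    a: &VecZnx,
    a_dft: &mut VecZnxDft<FFT64>,
    b_dft: &mut VecZnxDft<FFT64>,
    row_i: usize,
    tmp_bytes: &mut [u8],
) {
    module.vec_znx_dft(a_dft, 0, a, 0); // a_dft <- DFT(a), columns now explicit
    // Preparing a row from coefficients and extracting it back in DFT form
    // must return DFT(a)
    module.vmp_prepare_row(mat, &a.raw(), row_i, tmp_bytes);
    module.vmp_extract_row_dft(b_dft, mat, row_i);
    assert_eq!(a_dft.raw(), b_dft.raw());
}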

View File

@@ -1,4 +1,4 @@
use crate::{Backend, Module, VecZnx, ZnxLayout};
use crate::{Backend, Module, VecZnx, znx_base::ZnxLayout};
use rand_distr::{Distribution, Normal};
use sampling::source::Source;
@@ -106,7 +106,7 @@ impl<B: Backend> Sampling for Module<B> {
#[cfg(test)]
mod tests {
use super::Sampling;
use crate::{FFT64, Module, Stats, VecZnx, ZnxBase, ZnxLayout};
use crate::{FFT64, Module, Stats, VecZnx, VecZnxOps, znx_base::ZnxLayout};
use sampling::source::Source;
#[test]
@@ -120,7 +120,7 @@ mod tests {
let zero: Vec<i64> = vec![0; n];
let one_12_sqrt: f64 = 0.28867513459481287;
(0..cols).for_each(|col_i| {
let mut a: VecZnx = VecZnx::new(&module, cols, size);
let mut a: VecZnx = module.new_vec_znx(cols, size);
module.fill_uniform(log_base2k, &mut a, col_i, size, &mut source);
(0..cols).for_each(|col_j| {
if col_j != col_i {
@@ -154,7 +154,7 @@ mod tests {
let zero: Vec<i64> = vec![0; n];
let k_f64: f64 = (1u64 << log_k as u64) as f64;
(0..cols).for_each(|col_i| {
let mut a: VecZnx = VecZnx::new(&module, cols, size);
let mut a: VecZnx = module.new_vec_znx(cols, size);
module.add_normal(log_base2k, &mut a, col_i, log_k, &mut source, sigma, bound);
(0..cols).for_each(|col_j| {
if col_j != col_i {

View File

@@ -2,9 +2,8 @@ use std::marker::PhantomData;
use crate::ffi::svp::{self, svp_ppol_t};
use crate::ffi::vec_znx_dft::vec_znx_dft_t;
use crate::{Backend, FFT64, Module, VecZnx, VecZnxDft, ZnxLayout, assert_alignement};
use crate::{ZnxInfos, alloc_aligned, cast_mut};
use crate::znx_base::{ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize};
use crate::{Backend, FFT64, Module, VecZnx, VecZnxDft, alloc_aligned, assert_alignement, cast_mut};
use rand::seq::SliceRandom;
use rand_core::RngCore;
use rand_distr::{Distribution, weighted::WeightedIndex};
@@ -118,11 +117,14 @@ impl Scalar {
pub fn as_vec_znx(&self) -> VecZnx {
VecZnx {
inner: ZnxBase {
n: self.n,
rows: 1,
cols: 1,
size: 1,
data: Vec::new(),
ptr: self.ptr,
ptr: self.ptr as *mut u8,
},
}
}
}
@@ -159,7 +161,7 @@ pub struct ScalarZnxDft<B: Backend> {
/// An [SvpPPol] can be seen as a [VecZnxDft] of one limb.
impl ScalarZnxDft<FFT64> {
pub fn new(module: &Module<FFT64>) -> Self {
module.new_svp_ppol()
module.new_scalar_znx_dft()
}
/// Returns the ring degree of the [SvpPPol].
@@ -168,14 +170,14 @@ impl ScalarZnxDft<FFT64> {
}
pub fn bytes_of(module: &Module<FFT64>) -> usize {
module.bytes_of_svp_ppol()
module.bytes_of_scalar_znx_dft()
}
pub fn from_bytes(module: &Module<FFT64>, bytes: &mut [u8]) -> Self {
#[cfg(debug_assertions)]
{
assert_alignement(bytes.as_ptr());
assert_eq!(bytes.len(), module.bytes_of_svp_ppol());
assert_eq!(bytes.len(), module.bytes_of_scalar_znx_dft());
}
unsafe {
Self {
@@ -191,7 +193,7 @@ impl ScalarZnxDft<FFT64> {
#[cfg(debug_assertions)]
{
assert_alignement(tmp_bytes.as_ptr());
assert_eq!(tmp_bytes.len(), module.bytes_of_svp_ppol());
assert_eq!(tmp_bytes.len(), module.bytes_of_scalar_znx_dft());
}
Self {
n: module.n(),
@@ -209,33 +211,33 @@ impl ScalarZnxDft<FFT64> {
pub trait ScalarZnxDftOps<B: Backend> {
/// Allocates a new [SvpPPol].
fn new_svp_ppol(&self) -> ScalarZnxDft<B>;
fn new_scalar_znx_dft(&self) -> ScalarZnxDft<B>;
/// Returns the minimum number of bytes necessary to allocate
/// a new [SvpPPol] through [SvpPPol::from_bytes].
fn bytes_of_svp_ppol(&self) -> usize;
fn bytes_of_scalar_znx_dft(&self) -> usize;
/// Allocates a new [SvpPPol] from an array of bytes.
/// The array of bytes is owned by the [SvpPPol].
/// The method will panic if bytes.len() < [SvpPPolOps::bytes_of_svp_ppol]
fn new_svp_ppol_from_bytes(&self, bytes: &mut [u8]) -> ScalarZnxDft<B>;
fn new_scalar_znx_dft_from_bytes(&self, bytes: &mut [u8]) -> ScalarZnxDft<B>;
/// Allocates a new [SvpPPol] from an array of bytes.
/// The array of bytes is borrowed by the [SvpPPol].
/// The method will panic if bytes.len() < [SvpPPolOps::bytes_of_svp_ppol]
fn new_svp_ppol_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> ScalarZnxDft<B>;
fn new_scalar_znx_dft_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> ScalarZnxDft<B>;
/// Prepares a [crate::Scalar] for a [SvpPPolOps::svp_apply_dft].
fn svp_prepare(&self, svp_ppol: &mut ScalarZnxDft<B>, a: &Scalar);
/// Applies the [SvpPPol] x [VecZnxDft] product, where each limb of
/// the [VecZnxDft] is multiplied with [SvpPPol].
fn svp_apply_dft(&self, c: &mut VecZnxDft<B>, a: &ScalarZnxDft<B>, b: &VecZnx, b_col: usize);
fn svp_apply_dft(&self, res: &mut VecZnxDft<B>, res_col: usize, a: &ScalarZnxDft<B>, b: &VecZnx, b_col: usize);
}
impl ScalarZnxDftOps<FFT64> for Module<FFT64> {
fn new_svp_ppol(&self) -> ScalarZnxDft<FFT64> {
let mut data: Vec<u8> = alloc_aligned::<u8>(self.bytes_of_svp_ppol());
fn new_scalar_znx_dft(&self) -> ScalarZnxDft<FFT64> {
let mut data: Vec<u8> = alloc_aligned::<u8>(self.bytes_of_scalar_znx_dft());
let ptr: *mut u8 = data.as_mut_ptr();
ScalarZnxDft::<FFT64> {
data: data,
@@ -245,28 +247,28 @@ impl ScalarZnxDftOps<FFT64> for Module<FFT64> {
}
}
fn bytes_of_svp_ppol(&self) -> usize {
fn bytes_of_scalar_znx_dft(&self) -> usize {
unsafe { svp::bytes_of_svp_ppol(self.ptr) as usize }
}
fn new_svp_ppol_from_bytes(&self, bytes: &mut [u8]) -> ScalarZnxDft<FFT64> {
fn new_scalar_znx_dft_from_bytes(&self, bytes: &mut [u8]) -> ScalarZnxDft<FFT64> {
ScalarZnxDft::from_bytes(self, bytes)
}
fn new_svp_ppol_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> ScalarZnxDft<FFT64> {
fn new_scalar_znx_dft_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> ScalarZnxDft<FFT64> {
ScalarZnxDft::from_bytes_borrow(self, tmp_bytes)
}
fn svp_prepare(&self, svp_ppol: &mut ScalarZnxDft<FFT64>, a: &Scalar) {
unsafe { svp::svp_prepare(self.ptr, svp_ppol.ptr as *mut svp_ppol_t, a.as_ptr()) }
fn svp_prepare(&self, res: &mut ScalarZnxDft<FFT64>, a: &Scalar) {
unsafe { svp::svp_prepare(self.ptr, res.ptr as *mut svp_ppol_t, a.as_ptr()) }
}
fn svp_apply_dft(&self, c: &mut VecZnxDft<FFT64>, a: &ScalarZnxDft<FFT64>, b: &VecZnx, b_col: usize) {
fn svp_apply_dft(&self, res: &mut VecZnxDft<FFT64>, res_col: usize, a: &ScalarZnxDft<FFT64>, b: &VecZnx, b_col: usize) {
unsafe {
svp::svp_apply_dft(
self.ptr,
c.ptr as *mut vec_znx_dft_t,
c.size() as u64,
res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
res.size() as u64,
a.ptr as *const svp_ppol_t,
b.at_ptr(b_col, 0),
b.size() as u64,
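Under the renamed API, preparing and applying a scalar looks like this sketch (`scalar_times_column` is an illustrative name; `s` is assumed to be a populated [Scalar]):

fn scalar_times_column(
    module: &Module<FFT64>,
    s: &Scalar,
    b: &VecZnx,
    res: &mut VecZnxDft<FFT64>,
) {
    // new_svp_ppol() is now new_scalar_znx_dft()
    let mut s_dft: ScalarZnxDft<FFT64> = module.new_scalar_znx_dft();
    // s_dft <- DFT(s)
    module.svp_prepare(&mut s_dft, s);
    // res[0] <- DFT(s) * DFT(b[0]); the output column index is now explicit
    module.svp_apply_dft(res, 0, &s_dft, b, 0);
}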

View File

@@ -1,4 +1,5 @@
use crate::{Encoding, VecZnx, ZnxInfos};
use crate::znx_base::ZnxInfos;
use crate::{Encoding, VecZnx};
use rug::Float;
use rug::float::Round;
use rug::ops::{AddAssignRound, DivAssignRound, SubAssignRound};

View File

@@ -1,12 +1,13 @@
use crate::Backend;
use crate::ZnxBase;
use crate::Module;
use crate::assert_alignement;
use crate::cast_mut;
use crate::ffi::znx;
use crate::switch_degree;
use crate::{Module, ZnxBasics, ZnxInfos, ZnxLayout};
use crate::{alloc_aligned, assert_alignement};
use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, ZnxSliceSize, switch_degree};
use std::cmp::min;
pub const VEC_ZNX_ROWS: usize = 1;
/// [VecZnx] represents a collection of contiguously stacked vectors of small-norm polynomials of
/// Zn\[X\] with [i64] coefficients.
/// A [VecZnx] is composed of multiple Zn\[X\] polynomials stored in a single contiguous array
@@ -17,56 +18,54 @@ use std::cmp::min;
/// Given 3 polynomials (a, b, c) of Zn\[X\], each of size 4 (i.e. 4 limbs), the memory
/// layout is: `[a0, b0, c0, a1, b1, c1, a2, b2, c2, a3, b3, c3]`, where ai, bi, ci
/// are small polynomials of Zn\[X\].
#[derive(Clone)]
pub struct VecZnx {
/// Polynomial degree.
pub n: usize,
/// The number of polynomials
pub cols: usize,
/// The number of limbs per polynomial (a.k.a. small polynomials).
pub size: usize,
/// Polynomial coefficients, as a contiguous array. Each col is equally spaced by n.
pub data: Vec<i64>,
/// Pointer to data (data can be empty if [VecZnx] borrows space instead of owning it).
pub ptr: *mut i64,
pub inner: ZnxBase,
}
impl ZnxInfos for VecZnx {
fn n(&self) -> usize {
self.n
impl GetZnxBase for VecZnx {
fn znx(&self) -> &ZnxBase {
&self.inner
}
fn rows(&self) -> usize {
1
fn znx_mut(&mut self) -> &mut ZnxBase {
&mut self.inner
}
}
fn cols(&self) -> usize {
self.cols
}
impl ZnxInfos for VecZnx {}
fn size(&self) -> usize {
self.size
impl ZnxSliceSize for VecZnx {
fn sl(&self) -> usize {
self.cols() * self.n()
}
}
impl ZnxLayout for VecZnx {
type Scalar = i64;
fn as_ptr(&self) -> *const Self::Scalar {
self.ptr
}
fn as_mut_ptr(&mut self) -> *mut Self::Scalar {
self.ptr
}
}
impl ZnxBasics for VecZnx {}
impl<B: Backend> ZnxAlloc<B> for VecZnx {
type Scalar = i64;
fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnx {
debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, size));
VecZnx {
inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_ROWS, cols, size, bytes),
}
}
fn bytes_of(module: &Module<B>, _rows: usize, cols: usize, size: usize) -> usize {
debug_assert_eq!(
_rows, VEC_ZNX_ROWS,
"rows != {} not supported for VecZnx",
VEC_ZNX_ROWS
);
module.n() * cols * size * size_of::<Self::Scalar>()
}
}
/// Copies the coefficients of `a` onto the receiver.
/// Copy is done with the minimum size matching both backing arrays.
/// Panics if the cols do not match.
@@ -78,80 +77,6 @@ pub fn copy_vec_znx_from(b: &mut VecZnx, a: &VecZnx) {
data_b[..size].copy_from_slice(&data_a[..size])
}
impl<B: Backend> ZnxBase<B> for VecZnx {
type Scalar = i64;
/// Allocates a new [VecZnx] composed of #size polynomials of Z\[X\].
fn new(module: &Module<B>, cols: usize, size: usize) -> Self {
let n: usize = module.n();
#[cfg(debug_assertions)]
{
assert!(n > 0);
assert!(n & (n - 1) == 0);
assert!(cols > 0);
assert!(size > 0);
}
let mut data: Vec<i64> = alloc_aligned::<i64>(Self::bytes_of(module, cols, size));
let ptr: *mut i64 = data.as_mut_ptr();
Self {
n: n,
cols: cols,
size: size,
data: data,
ptr: ptr,
}
}
fn bytes_of(module: &Module<B>, cols: usize, size: usize) -> usize {
module.n() * cols * size * size_of::<i64>()
}
/// Returns a new struct implementing [VecZnx] with the provided data as backing array.
///
/// The struct will take ownership of buf[..[Self::bytes_of]]
///
/// User must ensure that data is properly aligned and that
/// the size of data is equal to [Self::bytes_of].
fn from_bytes(module: &Module<B>, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
let n: usize = module.n();
#[cfg(debug_assertions)]
{
assert!(cols > 0);
assert!(size > 0);
assert_eq!(bytes.len(), Self::bytes_of(module, cols, size));
assert_alignement(bytes.as_ptr());
}
unsafe {
let bytes_i64: &mut [i64] = cast_mut::<u8, i64>(bytes);
let ptr: *mut i64 = bytes_i64.as_mut_ptr();
Self {
n: n,
cols: cols,
size: size,
data: Vec::from_raw_parts(ptr, bytes.len(), bytes.len()),
ptr: ptr,
}
}
}
fn from_bytes_borrow(module: &Module<B>, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
#[cfg(debug_assertions)]
{
assert!(cols > 0);
assert!(size > 0);
assert!(bytes.len() >= Self::bytes_of(module, cols, size));
assert_alignement(bytes.as_ptr());
}
Self {
n: module.n(),
cols: cols,
size: size,
data: Vec::new(),
ptr: bytes.as_mut_ptr() as *mut i64,
}
}
}
impl VecZnx {
/// Truncates the precision of the [VecZnx] by k bits.
///
@@ -165,11 +90,12 @@ impl VecZnx {
}
if !self.borrowing() {
self.data
self.inner
.data
.truncate(self.n() * self.cols() * (self.size() - k / log_base2k));
}
self.size -= k / log_base2k;
self.inner.size -= k / log_base2k;
let k_rem: usize = k % log_base2k;
@@ -185,10 +111,6 @@ impl VecZnx {
copy_vec_znx_from(self, a);
}
pub fn borrowing(&self) -> bool {
self.data.len() == 0
}
pub fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]) {
normalize(log_base2k, self, carry)
}
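The interleaved layout documented above reduces to simple index arithmetic (a standalone sketch derived from the layout comment and sl() = cols * n; `limb_offset` is an illustrative helper, not part of the crate):

// Limb `j` of column `i` starts at offset (j * cols + i) * n in the raw
// i64 array, so with cols = 3 the order is a0, b0, c0, a1, b1, c1, ...
fn limb_offset(n: usize, cols: usize, col: usize, limb: usize) -> usize {
    (limb * cols + col) * n
}

fn main() {
    let (n, cols) = (8, 3);
    assert_eq!(limb_offset(n, cols, 0, 0), 0); // a0
    assert_eq!(limb_offset(n, cols, 1, 0), n); // b0 sits right after a0
    assert_eq!(limb_offset(n, cols, 0, 1), 3 * n); // a1 after one full slice sl() = cols * n
}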

View File

@@ -1,115 +1,71 @@
use crate::ffi::vec_znx_big;
use crate::{Backend, FFT64, Module, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, alloc_aligned, assert_alignement};
use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, ZnxSliceSize};
use crate::{Backend, FFT64, Module, NTT120};
use std::marker::PhantomData;
const VEC_ZNX_BIG_ROWS: usize = 1;
pub struct VecZnxBig<B: Backend> {
pub data: Vec<u8>,
pub ptr: *mut u8,
pub n: usize,
pub cols: usize,
pub size: usize,
pub inner: ZnxBase,
pub _marker: PhantomData<B>,
}
impl ZnxBasics for VecZnxBig<FFT64> {}
impl<B: Backend> GetZnxBase for VecZnxBig<B> {
fn znx(&self) -> &ZnxBase {
&self.inner
}
impl<B: Backend> ZnxBase<B> for VecZnxBig<B> {
fn znx_mut(&mut self) -> &mut ZnxBase {
&mut self.inner
}
}
impl<B: Backend> ZnxInfos for VecZnxBig<B> {}
impl<B: Backend> ZnxAlloc<B> for VecZnxBig<B> {
type Scalar = u8;
fn new(module: &Module<B>, cols: usize, size: usize) -> Self {
#[cfg(debug_assertions)]
{
assert!(cols > 0);
assert!(size > 0);
}
let mut data: Vec<Self::Scalar> = alloc_aligned(Self::bytes_of(module, cols, size));
let ptr: *mut Self::Scalar = data.as_mut_ptr();
Self {
data: data,
ptr: ptr,
n: module.n(),
cols: cols,
size: size,
fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
VecZnxBig {
inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_BIG_ROWS, cols, size, bytes),
_marker: PhantomData,
}
}
fn bytes_of(module: &Module<B>, cols: usize, size: usize) -> usize {
fn bytes_of(module: &Module<B>, _rows: usize, cols: usize, size: usize) -> usize {
debug_assert_eq!(
_rows, VEC_ZNX_BIG_ROWS,
"rows != {} not supported for VecZnxBig",
VEC_ZNX_BIG_ROWS
);
unsafe { vec_znx_big::bytes_of_vec_znx_big(module.ptr, size as u64) as usize * cols }
}
/// Returns a new [VecZnxBig] with the provided data as backing array.
/// User must ensure that data is properly aligned and that
/// the size of data is at least equal to [Module::bytes_of_vec_znx_big].
fn from_bytes(module: &Module<B>, cols: usize, size: usize, bytes: &mut [Self::Scalar]) -> Self {
#[cfg(debug_assertions)]
{
assert!(cols > 0);
assert!(size > 0);
assert_eq!(bytes.len(), Self::bytes_of(module, cols, size));
assert_alignement(bytes.as_ptr())
};
unsafe {
Self {
data: Vec::from_raw_parts(bytes.as_mut_ptr(), bytes.len(), bytes.len()),
ptr: bytes.as_mut_ptr(),
n: module.n(),
cols: cols,
size: size,
_marker: PhantomData,
}
}
}
fn from_bytes_borrow(module: &Module<B>, cols: usize, size: usize, bytes: &mut [Self::Scalar]) -> Self {
#[cfg(debug_assertions)]
{
assert!(cols > 0);
assert!(size > 0);
assert_eq!(bytes.len(), Self::bytes_of(module, cols, size));
assert_alignement(bytes.as_ptr());
}
Self {
data: Vec::new(),
ptr: bytes.as_mut_ptr(),
n: module.n(),
cols: cols,
size: size,
_marker: PhantomData,
}
}
}
impl<B: Backend> ZnxInfos for VecZnxBig<B> {
fn n(&self) -> usize {
self.n
}
fn cols(&self) -> usize {
self.cols
}
fn rows(&self) -> usize {
1
}
fn size(&self) -> usize {
self.size
}
}
impl ZnxLayout for VecZnxBig<FFT64> {
type Scalar = i64;
fn as_ptr(&self) -> *const Self::Scalar {
self.ptr as *const Self::Scalar
}
fn as_mut_ptr(&mut self) -> *mut Self::Scalar {
self.ptr as *mut Self::Scalar
impl ZnxLayout for VecZnxBig<NTT120> {
type Scalar = i128;
}
impl ZnxBasics for VecZnxBig<FFT64> {}
impl ZnxSliceSize for VecZnxBig<FFT64> {
fn sl(&self) -> usize {
self.n()
}
}
impl ZnxSliceSize for VecZnxBig<NTT120> {
fn sl(&self) -> usize {
self.n() * 4
}
}
impl ZnxBasics for VecZnxBig<NTT120> {}
impl VecZnxBig<FFT64> {
pub fn print(&self, n: usize) {
(0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n]));

View File

@@ -1,5 +1,6 @@
use crate::ffi::vec_znx;
use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxOps, ZnxBase, ZnxInfos, ZnxLayout, assert_alignement};
use crate::ffi::vec_znx_big::{self, vec_znx_big_t};
use crate::znx_base::{ZnxAlloc, ZnxInfos, ZnxLayout, ZnxSliceSize};
use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxOps, assert_alignement};
pub trait VecZnxBigOps<B: Backend> {
/// Allocates a vector Z[X]/(X^N+1) that stores non-normalized values.
@@ -17,7 +18,7 @@ pub trait VecZnxBigOps<B: Backend> {
///
/// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxBig<B>;
fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBig<B>;
/// Returns a new [VecZnxBig] with the provided bytes array as backing array.
///
@@ -41,74 +42,74 @@ pub trait VecZnxBigOps<B: Backend> {
fn vec_znx_big_add(
&self,
res: &mut VecZnxBig<B>,
col_res: usize,
res_col: usize,
a: &VecZnxBig<B>,
col_a: usize,
a_col: usize,
b: &VecZnxBig<B>,
col_b: usize,
b_col: usize,
);
/// Adds `a` to `b` and stores the result in `b`.
fn vec_znx_big_add_inplace(&self, res: &mut VecZnxBig<B>, col_res: usize, a: &VecZnxBig<B>, col_a: usize);
fn vec_znx_big_add_inplace(&self, res: &mut VecZnxBig<B>, res_col: usize, a: &VecZnxBig<B>, a_col: usize);
/// Adds `a` to `b` and stores the result in `c`.
fn vec_znx_big_add_small(
&self,
res: &mut VecZnxBig<B>,
col_res: usize,
a: &VecZnx,
col_a: usize,
b: &VecZnxBig<B>,
col_b: usize,
res_col: usize,
a: &VecZnxBig<B>,
a_col: usize,
b: &VecZnx,
b_col: usize,
);
/// Adds `a` to `b` and stores the result in `b`.
fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig<B>, col_res: usize, a: &VecZnx, col_a: usize);
fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig<B>, res_col: usize, a: &VecZnx, a_col: usize);
/// Subtracts `a` from `b` and stores the result in `c`.
fn vec_znx_big_sub(
&self,
res: &mut VecZnxBig<B>,
col_res: usize,
res_col: usize,
a: &VecZnxBig<B>,
col_a: usize,
a_col: usize,
b: &VecZnxBig<B>,
col_b: usize,
b_col: usize,
);
/// Subtracts `a` from `b` and stores the result in `b`.
fn vec_znx_big_sub_ab_inplace(&self, res: &mut VecZnxBig<B>, col_res: usize, a: &VecZnxBig<B>, col_a: usize);
fn vec_znx_big_sub_ab_inplace(&self, res: &mut VecZnxBig<B>, res_col: usize, a: &VecZnxBig<B>, a_col: usize);
/// Subtracts `b` from `a` and stores the result in `b`.
fn vec_znx_big_sub_ba_inplace(&self, res: &mut VecZnxBig<B>, col_res: usize, a: &VecZnxBig<B>, col_a: usize);
fn vec_znx_big_sub_ba_inplace(&self, res: &mut VecZnxBig<B>, res_col: usize, a: &VecZnxBig<B>, a_col: usize);
/// Subtracts `b` from `a` and stores the result in `c`.
fn vec_znx_big_sub_small_a(
&self,
res: &mut VecZnxBig<B>,
col_res: usize,
res_col: usize,
a: &VecZnx,
col_a: usize,
a_col: usize,
b: &VecZnxBig<B>,
col_b: usize,
b_col: usize,
);
/// Subtracts `a` from `b` and stores the result in `b`.
fn vec_znx_big_sub_small_a_inplace(&self, res: &mut VecZnxBig<B>, col_res: usize, a: &VecZnx, col_a: usize);
fn vec_znx_big_sub_small_a_inplace(&self, res: &mut VecZnxBig<B>, res_col: usize, a: &VecZnx, a_col: usize);
/// Subtracts `b` from `a` and stores the result in `c`.
fn vec_znx_big_sub_small_b(
&self,
res: &mut VecZnxBig<B>,
col_res: usize,
res_col: usize,
a: &VecZnxBig<B>,
col_a: usize,
a_col: usize,
b: &VecZnx,
col_b: usize,
b_col: usize,
);
/// Subtracts `b` from `a` and stores the result in `b`.
fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig<B>, col_res: usize, a: &VecZnx, col_a: usize);
fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig<B>, res_col: usize, a: &VecZnx, a_col: usize);
/// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_normalize].
fn vec_znx_big_normalize_tmp_bytes(&self) -> usize;
@@ -123,44 +124,44 @@ pub trait VecZnxBigOps<B: Backend> {
&self,
log_base2k: usize,
res: &mut VecZnx,
col_res: usize,
res_col: usize,
a: &VecZnxBig<B>,
col_a: usize,
a_col: usize,
tmp_bytes: &mut [u8],
);
/// Applies the automorphism X^i -> X^ik on `a` and stores the result in `b`.
fn vec_znx_big_automorphism(&self, k: i64, res: &mut VecZnxBig<B>, col_res: usize, a: &VecZnxBig<B>, col_a: usize);
fn vec_znx_big_automorphism(&self, k: i64, res: &mut VecZnxBig<B>, res_col: usize, a: &VecZnxBig<B>, a_col: usize);
/// Applies the automorphism X^i -> X^ik on `a` and stores the result in `a`.
fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig<B>, col_a: usize);
fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig<B>, a_col: usize);
}
impl VecZnxBigOps<FFT64> for Module<FFT64> {
fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBig<FFT64> {
VecZnxBig::new(self, cols, size)
VecZnxBig::new(self, 1, cols, size)
}
fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxBig<FFT64> {
VecZnxBig::from_bytes(self, cols, size, bytes)
fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBig<FFT64> {
VecZnxBig::from_bytes(self, 1, cols, size, bytes)
}
fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig<FFT64> {
VecZnxBig::from_bytes_borrow(self, cols, size, tmp_bytes)
VecZnxBig::from_bytes_borrow(self, 1, cols, size, tmp_bytes)
}
fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize {
VecZnxBig::bytes_of(self, cols, size)
VecZnxBig::bytes_of(self, 1, cols, size)
}
fn vec_znx_big_add(
&self,
res: &mut VecZnxBig<FFT64>,
col_res: usize,
res_col: usize,
a: &VecZnxBig<FFT64>,
col_a: usize,
a_col: usize,
b: &VecZnxBig<FFT64>,
col_b: usize,
b_col: usize,
) {
#[cfg(debug_assertions)]
{
@@ -170,36 +171,33 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
assert_ne!(a.as_ptr(), b.as_ptr());
}
unsafe {
vec_znx::vec_znx_add(
vec_znx_big::vec_znx_big_add(
self.ptr,
res.at_mut_ptr(col_res, 0),
res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t,
res.size() as u64,
res.sl() as u64,
a.at_ptr(col_a, 0),
a.at_ptr(a_col * res.size(), 0) as *const vec_znx_big_t,
a.size() as u64,
a.sl() as u64,
b.at_ptr(col_b, 0),
b.at_ptr(b_col * res.size(), 0) as *const vec_znx_big_t,
b.size() as u64,
b.sl() as u64,
)
}
}
fn vec_znx_big_add_inplace(&self, res: &mut VecZnxBig<FFT64>, col_res: usize, a: &VecZnxBig<FFT64>, col_a: usize) {
fn vec_znx_big_add_inplace(&self, res: &mut VecZnxBig<FFT64>, res_col: usize, a: &VecZnxBig<FFT64>, a_col: usize) {
unsafe {
let res_ptr: *mut VecZnxBig<FFT64> = res as *mut VecZnxBig<FFT64>;
Self::vec_znx_big_add(self, &mut *res_ptr, col_res, a, col_a, &*res_ptr, col_res);
Self::vec_znx_big_add(self, &mut *res_ptr, res_col, a, a_col, &*res_ptr, res_col);
}
}
fn vec_znx_big_sub(
&self,
res: &mut VecZnxBig<FFT64>,
col_res: usize,
res_col: usize,
a: &VecZnxBig<FFT64>,
col_a: usize,
a_col: usize,
b: &VecZnxBig<FFT64>,
col_b: usize,
b_col: usize,
) {
#[cfg(debug_assertions)]
{
@@ -209,43 +207,40 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
assert_ne!(a.as_ptr(), b.as_ptr());
}
unsafe {
vec_znx::vec_znx_sub(
vec_znx_big::vec_znx_big_sub(
self.ptr,
res.at_mut_ptr(col_res, 0),
res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t,
res.size() as u64,
res.sl() as u64,
a.at_ptr(col_a, 0),
a.at_ptr(a_col * res.size(), 0) as *const vec_znx_big_t,
a.size() as u64,
a.sl() as u64,
b.at_ptr(col_b, 0),
b.at_ptr(b_col * res.size(), 0) as *const vec_znx_big_t,
b.size() as u64,
b.sl() as u64,
)
}
}
fn vec_znx_big_sub_ab_inplace(&self, res: &mut VecZnxBig<FFT64>, col_res: usize, a: &VecZnxBig<FFT64>, col_a: usize) {
fn vec_znx_big_sub_ab_inplace(&self, res: &mut VecZnxBig<FFT64>, res_col: usize, a: &VecZnxBig<FFT64>, a_col: usize) {
unsafe {
let res_ptr: *mut VecZnxBig<FFT64> = res as *mut VecZnxBig<FFT64>;
Self::vec_znx_big_sub(self, &mut *res_ptr, col_res, a, col_a, &*res_ptr, col_res);
Self::vec_znx_big_sub(self, &mut *res_ptr, res_col, a, a_col, &*res_ptr, res_col);
}
}
fn vec_znx_big_sub_ba_inplace(&self, res: &mut VecZnxBig<FFT64>, col_res: usize, a: &VecZnxBig<FFT64>, col_a: usize) {
fn vec_znx_big_sub_ba_inplace(&self, res: &mut VecZnxBig<FFT64>, res_col: usize, a: &VecZnxBig<FFT64>, a_col: usize) {
unsafe {
let res_ptr: *mut VecZnxBig<FFT64> = res as *mut VecZnxBig<FFT64>;
Self::vec_znx_big_sub(self, &mut *res_ptr, col_res, &*res_ptr, col_res, a, col_a);
Self::vec_znx_big_sub(self, &mut *res_ptr, res_col, &*res_ptr, res_col, a, a_col);
}
}
fn vec_znx_big_sub_small_b(
&self,
res: &mut VecZnxBig<FFT64>,
col_res: usize,
res_col: usize,
a: &VecZnxBig<FFT64>,
col_a: usize,
a_col: usize,
b: &VecZnx,
col_b: usize,
b_col: usize,
) {
#[cfg(debug_assertions)]
{
@@ -255,36 +250,34 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
assert_ne!(a.as_ptr(), b.as_ptr());
}
unsafe {
vec_znx::vec_znx_sub(
vec_znx_big::vec_znx_big_sub_small_b(
self.ptr,
res.at_mut_ptr(col_res, 0),
res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t,
res.size() as u64,
res.sl() as u64,
a.at_ptr(col_a, 0),
a.at_ptr(a_col * a.size(), 0) as *const vec_znx_big_t,
a.size() as u64,
a.sl() as u64,
b.at_ptr(col_b, 0),
b.at_ptr(b_col, 0),
b.size() as u64,
b.sl() as u64,
)
}
}
fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig<FFT64>, col_res: usize, a: &VecZnx, col_a: usize) {
fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig<FFT64>, res_col: usize, a: &VecZnx, a_col: usize) {
unsafe {
let res_ptr: *mut VecZnxBig<FFT64> = res as *mut VecZnxBig<FFT64>;
Self::vec_znx_big_sub_small_b(self, &mut *res_ptr, col_res, &*res_ptr, col_res, a, col_a);
Self::vec_znx_big_sub_small_b(self, &mut *res_ptr, res_col, &*res_ptr, res_col, a, a_col);
}
}
fn vec_znx_big_sub_small_a(
&self,
res: &mut VecZnxBig<FFT64>,
col_res: usize,
res_col: usize,
a: &VecZnx,
col_a: usize,
a_col: usize,
b: &VecZnxBig<FFT64>,
col_b: usize,
b_col: usize,
) {
#[cfg(debug_assertions)]
{
@@ -294,36 +287,34 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
assert_ne!(a.as_ptr(), b.as_ptr());
}
unsafe {
vec_znx::vec_znx_sub(
vec_znx_big::vec_znx_big_sub_small_a(
self.ptr,
res.at_mut_ptr(col_res, 0),
res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t,
res.size() as u64,
res.sl() as u64,
a.at_ptr(col_a, 0),
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
b.at_ptr(col_b, 0),
b.at_ptr(b_col * b.size(), 0) as *const vec_znx_big_t,
b.size() as u64,
b.sl() as u64,
)
}
}
fn vec_znx_big_sub_small_a_inplace(&self, res: &mut VecZnxBig<FFT64>, col_res: usize, a: &VecZnx, col_a: usize) {
fn vec_znx_big_sub_small_a_inplace(&self, res: &mut VecZnxBig<FFT64>, res_col: usize, a: &VecZnx, a_col: usize) {
unsafe {
let res_ptr: *mut VecZnxBig<FFT64> = res as *mut VecZnxBig<FFT64>;
Self::vec_znx_big_sub_small_a(self, &mut *res_ptr, col_res, a, col_a, &*res_ptr, col_res);
Self::vec_znx_big_sub_small_a(self, &mut *res_ptr, res_col, a, a_col, &*res_ptr, res_col);
}
}
fn vec_znx_big_add_small(
&self,
res: &mut VecZnxBig<FFT64>,
col_res: usize,
a: &VecZnx,
col_a: usize,
b: &VecZnxBig<FFT64>,
col_b: usize,
res_col: usize,
a: &VecZnxBig<FFT64>,
a_col: usize,
b: &VecZnx,
b_col: usize,
) {
#[cfg(debug_assertions)]
{
@@ -333,25 +324,23 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
assert_ne!(a.as_ptr(), b.as_ptr());
}
unsafe {
vec_znx::vec_znx_add(
vec_znx_big::vec_znx_big_add_small(
self.ptr,
res.at_mut_ptr(col_res, 0),
res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t,
res.size() as u64,
res.sl() as u64,
a.at_ptr(col_a, 0),
a.at_ptr(a_col * a.size(), 0) as *const vec_znx_big_t,
a.size() as u64,
a.sl() as u64,
b.at_ptr(col_b, 0),
b.at_ptr(b_col, 0),
b.size() as u64,
b.sl() as u64,
)
}
}
fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig<FFT64>, col_res: usize, a: &VecZnx, a_col: usize) {
fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig<FFT64>, res_col: usize, a: &VecZnx, a_col: usize) {
unsafe {
let res_ptr: *mut VecZnxBig<FFT64> = res as *mut VecZnxBig<FFT64>;
Self::vec_znx_big_add_small(self, &mut *res_ptr, col_res, a, a_col, &*res_ptr, col_res);
Self::vec_znx_big_add_small(self, &mut *res_ptr, res_col, &*res_ptr, res_col, a, a_col);
}
}
@@ -363,9 +352,9 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
&self,
log_base2k: usize,
res: &mut VecZnx,
col_res: usize,
res_col: usize,
a: &VecZnxBig<FFT64>,
col_a: usize,
a_col: usize,
tmp_bytes: &mut [u8],
) {
#[cfg(debug_assertions)]
@@ -376,44 +365,41 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
assert_alignement(tmp_bytes.as_ptr());
}
unsafe {
vec_znx::vec_znx_normalize_base2k(
vec_znx_big::vec_znx_big_normalize_base2k(
self.ptr,
log_base2k as u64,
res.at_mut_ptr(col_res, 0),
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(col_a, 0),
a.at_ptr(a_col * a.size(), 0) as *const vec_znx_big_t,
a.size() as u64,
a.sl() as u64,
tmp_bytes.as_mut_ptr(),
);
}
}
fn vec_znx_big_automorphism(&self, k: i64, res: &mut VecZnxBig<FFT64>, col_res: usize, a: &VecZnxBig<FFT64>, col_a: usize) {
fn vec_znx_big_automorphism(&self, k: i64, res: &mut VecZnxBig<FFT64>, res_col: usize, a: &VecZnxBig<FFT64>, a_col: usize) {
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
}
unsafe {
vec_znx::vec_znx_automorphism(
vec_znx_big::vec_znx_big_automorphism(
self.ptr,
k,
res.at_mut_ptr(col_res, 0),
res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t,
res.size() as u64,
res.sl() as u64,
a.at_ptr(col_a, 0),
a.at_ptr(a_col * a.size(), 0) as *const vec_znx_big_t,
a.size() as u64,
a.sl() as u64,
)
}
}
fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig<FFT64>, col_a: usize) {
fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig<FFT64>, a_col: usize) {
unsafe {
let a_ptr: *mut VecZnxBig<FFT64> = a as *mut VecZnxBig<FFT64>;
Self::vec_znx_big_automorphism(self, k, &mut *a_ptr, col_a, &*a_ptr, col_a);
Self::vec_znx_big_automorphism(self, k, &mut *a_ptr, a_col, &*a_ptr, a_col);
}
}
}
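Note the new operand order of the `*_small` methods: the [VecZnxBig] is now the `a` operand and the small [VecZnx] is `b`. A hedged sketch of the accumulate-then-normalize pattern (`accumulate_and_normalize` is an illustrative name):

fn accumulate_and_normalize(
    module: &Module<FFT64>,
    log_base2k: usize,
    acc: &mut VecZnxBig<FFT64>,
    small: &VecZnx,
    out: &mut VecZnx,
) {
    // acc[0] += small[0] (big operand first under the new signature)
    module.vec_znx_big_add_small_inplace(acc, 0, small, 0);
    // Normalize back into a base-2^k VecZnx
    let mut tmp: Vec<u8> = alloc_aligned::<u8>(module.vec_znx_big_normalize_tmp_bytes());
    module.vec_znx_big_normalize(log_base2k, out, 0, acc, 0, &mut tmp);
}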

View File

@@ -1,129 +1,54 @@
use crate::ffi::vec_znx_big::vec_znx_big_t;
use crate::ffi::vec_znx_dft;
use crate::ffi::vec_znx_dft::{bytes_of_vec_znx_dft, vec_znx_dft_t};
use crate::{Backend, FFT64, Module, VecZnxBig, ZnxBase, ZnxInfos, ZnxLayout, assert_alignement};
use crate::{VecZnx, alloc_aligned};
use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize};
use crate::{Backend, FFT64, Module, VecZnxBig};
use std::marker::PhantomData;
const VEC_ZNX_DFT_ROWS: usize = 1;
pub struct VecZnxDft<B: Backend> {
pub data: Vec<u8>,
pub ptr: *mut u8,
pub n: usize,
pub cols: usize,
pub size: usize,
inner: ZnxBase,
pub _marker: PhantomData<B>,
}
impl<B: Backend> ZnxBase<B> for VecZnxDft<B> {
impl<B: Backend> GetZnxBase for VecZnxDft<B> {
fn znx(&self) -> &ZnxBase {
&self.inner
}
fn znx_mut(&mut self) -> &mut ZnxBase {
&mut self.inner
}
}
impl<B: Backend> ZnxInfos for VecZnxDft<B> {}
impl<B: Backend> ZnxAlloc<B> for VecZnxDft<B> {
type Scalar = u8;
fn new(module: &Module<B>, cols: usize, size: usize) -> Self {
#[cfg(debug_assertions)]
{
assert!(cols > 0);
assert!(size > 0);
}
let mut data: Vec<Self::Scalar> = alloc_aligned(Self::bytes_of(module, cols, size));
let ptr: *mut Self::Scalar = data.as_mut_ptr();
Self {
data: data,
ptr: ptr,
n: module.n(),
size: size,
cols: cols,
fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
VecZnxDft {
inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_DFT_ROWS, cols, size, bytes),
_marker: PhantomData,
}
}
fn bytes_of(module: &Module<B>, cols: usize, size: usize) -> usize {
unsafe { bytes_of_vec_znx_dft(module.ptr, size as u64) as usize * cols }
}
/// Returns a new [VecZnxDft] with the provided data as backing array.
/// User must ensure that data is properly aligned and that
/// the size of data is at least equal to [Module::bytes_of_vec_znx_dft].
fn from_bytes(module: &Module<B>, cols: usize, size: usize, bytes: &mut [Self::Scalar]) -> Self {
#[cfg(debug_assertions)]
{
assert!(cols > 0);
assert!(size > 0);
assert_eq!(bytes.len(), Self::bytes_of(module, cols, size));
assert_alignement(bytes.as_ptr())
}
unsafe {
Self {
data: Vec::from_raw_parts(bytes.as_mut_ptr(), bytes.len(), bytes.len()),
ptr: bytes.as_mut_ptr(),
n: module.n(),
cols: cols,
size: size,
_marker: PhantomData,
}
}
}
fn from_bytes_borrow(module: &Module<B>, cols: usize, size: usize, bytes: &mut [Self::Scalar]) -> Self {
#[cfg(debug_assertions)]
{
assert!(cols > 0);
assert!(size > 0);
assert_eq!(bytes.len(), Self::bytes_of(module, cols, size));
assert_alignement(bytes.as_ptr());
}
Self {
data: Vec::new(),
ptr: bytes.as_mut_ptr(),
n: module.n(),
cols: cols,
size: size,
_marker: PhantomData,
}
}
}
impl<B: Backend> VecZnxDft<B> {
/// Cast a [VecZnxDft] into a [VecZnxBig].
/// The returned [VecZnxBig] shares the backing array
/// with the original [VecZnxDft].
pub fn as_vec_znx_big(&mut self) -> VecZnxBig<B> {
VecZnxBig::<B> {
data: Vec::new(),
ptr: self.ptr,
n: self.n,
cols: self.cols,
size: self.size,
_marker: PhantomData,
}
}
}
impl<B: Backend> ZnxInfos for VecZnxDft<B> {
fn n(&self) -> usize {
self.n
}
fn rows(&self) -> usize {
1
}
fn cols(&self) -> usize {
self.cols
}
fn size(&self) -> usize {
self.size
fn bytes_of(module: &Module<B>, _rows: usize, cols: usize, size: usize) -> usize {
debug_assert_eq!(
_rows, VEC_ZNX_DFT_ROWS,
"rows != {} not supported for VecZnxDft",
VEC_ZNX_DFT_ROWS
);
unsafe { vec_znx_dft::bytes_of_vec_znx_dft(module.ptr, size as u64) as usize * cols }
}
}
impl ZnxLayout for VecZnxDft<FFT64> {
type Scalar = f64;
fn as_ptr(&self) -> *const Self::Scalar {
self.ptr as *const Self::Scalar
}
fn as_mut_ptr(&mut self) -> *mut Self::Scalar {
self.ptr as *mut Self::Scalar
impl ZnxSliceSize for VecZnxDft<FFT64> {
fn sl(&self) -> usize {
self.n()
}
}
@@ -133,225 +58,21 @@ impl VecZnxDft<FFT64> {
}
}
pub trait VecZnxDftOps<B: Backend> {
/// Allocates a vector of polynomials of Z[X]/(X^N+1) stored in the DFT domain.
fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDft<B>;
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
///
/// Behavior: takes ownership of the backing array.
///
/// # Arguments
///
/// * `cols`: the number of cols of the [VecZnxDft].
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft].
///
/// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft<B>;
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
///
/// Behavior: the backing array is only borrowed.
///
/// # Arguments
///
/// * `cols`: the number of cols of the [VecZnxDft].
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft].
///
/// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft<B>;
/// Returns the number of bytes required to allocate
/// a [VecZnxDft] with the given `cols` and `size`.
fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize;
/// Returns the minimum number of scratch bytes required
/// by [VecZnxDftOps::vec_znx_idft].
fn vec_znx_idft_tmp_bytes(&self) -> usize;
/// b <- IDFT(a), uses a as scratch space.
fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig<B>, a: &mut VecZnxDft<B>);
fn vec_znx_idft(&self, b: &mut VecZnxBig<B>, a: &VecZnxDft<B>, tmp_bytes: &mut [u8]);
fn vec_znx_dft(&self, b: &mut VecZnxDft<B>, a: &VecZnx);
fn vec_znx_dft_automorphism(&self, k: i64, b: &mut VecZnxDft<B>, a: &VecZnxDft<B>);
fn vec_znx_dft_automorphism_inplace(&self, k: i64, a: &mut VecZnxDft<B>, tmp_bytes: &mut [u8]);
fn vec_znx_dft_automorphism_tmp_bytes(&self) -> usize;
}
impl VecZnxDftOps<FFT64> for Module<FFT64> {
fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDft<FFT64> {
VecZnxDft::<FFT64>::new(&self, cols, size)
}
fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxDft<FFT64> {
VecZnxDft::from_bytes(self, cols, size, tmp_bytes)
}
fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxDft<FFT64> {
VecZnxDft::from_bytes_borrow(self, cols, size, tmp_bytes)
}
fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize {
VecZnxDft::bytes_of(&self, cols, size)
}
fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig<FFT64>, a: &mut VecZnxDft<FFT64>) {
unsafe {
vec_znx_dft::vec_znx_idft_tmp_a(
self.ptr,
b.ptr as *mut vec_znx_big_t,
b.poly_count() as u64,
a.ptr as *mut vec_znx_dft_t,
a.poly_count() as u64,
)
impl<B: Backend> VecZnxDft<B> {
/// Cast a [VecZnxDft] into a [VecZnxBig].
/// The returned [VecZnxBig] shares the backing array
/// with the original [VecZnxDft].
pub fn alias_as_vec_znx_big(&mut self) -> VecZnxBig<B> {
VecZnxBig::<B> {
inner: ZnxBase {
data: Vec::new(),
ptr: self.ptr(),
n: self.n(),
rows: self.rows(),
cols: self.cols(),
size: self.size(),
},
_marker: PhantomData,
}
}
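An aliasing note on the cast above (sketch; `buf_dft` stands for any live [VecZnxDft]): only the raw pointer and the shape fields are copied and `data` is left empty, so the alias owns nothing, both views see each other's writes, and `buf_dft` must outlive the returned [VecZnxBig]:
let buf_big: VecZnxBig<FFT64> = buf_dft.alias_as_vec_znx_big();
// Same backing bytes, two typed views (ptr() comes from ZnxInfos):
assert_eq!(buf_big.ptr(), buf_dft.ptr());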
fn vec_znx_idft_tmp_bytes(&self) -> usize {
unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(self.ptr) as usize }
}
/// b <- DFT(a)
///
/// # Panics
/// If b.cols < a_cols
fn vec_znx_dft(&self, b: &mut VecZnxDft<FFT64>, a: &VecZnx) {
unsafe {
vec_znx_dft::vec_znx_dft(
self.ptr,
b.ptr as *mut vec_znx_dft_t,
b.size() as u64,
a.as_ptr(),
a.size() as u64,
(a.n() * a.cols()) as u64,
)
}
}
// b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes].
fn vec_znx_idft(&self, b: &mut VecZnxBig<FFT64>, a: &VecZnxDft<FFT64>, tmp_bytes: &mut [u8]) {
#[cfg(debug_assertions)]
{
assert!(
tmp_bytes.len() >= Self::vec_znx_idft_tmp_bytes(self),
"invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_idft_tmp_bytes()={}",
tmp_bytes.len(),
Self::vec_znx_idft_tmp_bytes(self)
);
assert_alignement(tmp_bytes.as_ptr())
}
unsafe {
vec_znx_dft::vec_znx_idft(
self.ptr,
b.ptr as *mut vec_znx_big_t,
b.poly_count() as u64,
a.ptr as *const vec_znx_dft_t,
a.poly_count() as u64,
tmp_bytes.as_mut_ptr(),
)
}
}
fn vec_znx_dft_automorphism(&self, k: i64, b: &mut VecZnxDft<FFT64>, a: &VecZnxDft<FFT64>) {
unsafe {
vec_znx_dft::vec_znx_dft_automorphism(
self.ptr,
k,
b.ptr as *mut vec_znx_dft_t,
b.poly_count() as u64,
a.ptr as *const vec_znx_dft_t,
a.poly_count() as u64,
[0u8; 0].as_mut_ptr(),
);
}
}
fn vec_znx_dft_automorphism_inplace(&self, k: i64, a: &mut VecZnxDft<FFT64>, tmp_bytes: &mut [u8]) {
#[cfg(debug_assertions)]
{
assert!(
tmp_bytes.len() >= Self::vec_znx_dft_automorphism_tmp_bytes(self),
"invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_dft_automorphism_tmp_bytes()={}",
tmp_bytes.len(),
Self::vec_znx_dft_automorphism_tmp_bytes(self)
);
assert_alignement(tmp_bytes.as_ptr())
}
println!("{}", a.poly_count());
unsafe {
vec_znx_dft::vec_znx_dft_automorphism(
self.ptr,
k,
a.ptr as *mut vec_znx_dft_t,
a.poly_count() as u64,
a.ptr as *const vec_znx_dft_t,
a.poly_count() as u64,
tmp_bytes.as_mut_ptr(),
);
}
}
fn vec_znx_dft_automorphism_tmp_bytes(&self) -> usize {
unsafe { vec_znx_dft::vec_znx_dft_automorphism_tmp_bytes(self.ptr) as usize }
}
}
#[cfg(test)]
mod tests {
use crate::{FFT64, Module, Sampling, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, ZnxLayout, alloc_aligned};
use itertools::izip;
use sampling::source::Source;
#[test]
fn test_automorphism_dft() {
let n: usize = 8;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let size: usize = 2;
let log_base2k: usize = 17;
let mut a: VecZnx = module.new_vec_znx(1, size);
let mut a_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, size);
let mut b_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, size);
let mut source: Source = Source::new([0u8; 32]);
module.fill_uniform(log_base2k, &mut a, 0, size, &mut source);
let mut tmp_bytes: Vec<u8> = alloc_aligned(module.vec_znx_dft_automorphism_tmp_bytes());
let p: i64 = -5;
// a_dft <- DFT(a)
module.vec_znx_dft(&mut a_dft, &a);
// a_dft <- AUTO(a_dft)
module.vec_znx_dft_automorphism_inplace(p, &mut a_dft, &mut tmp_bytes);
// a <- AUTO(a)
module.vec_znx_automorphism_inplace(p, &mut a, 0);
// b_dft <- DFT(AUTO(a))
module.vec_znx_dft(&mut b_dft, &a);
let a_f64: &[f64] = a_dft.raw();
let b_f64: &[f64] = b_dft.raw();
izip!(a_f64.iter(), b_f64.iter()).for_each(|(ai, bi)| {
assert!((ai - bi).abs() <= 1e-9, "{:+e} > 1e-9", (ai - bi).abs());
});
module.free()
}
}

View File

@@ -0,0 +1,140 @@
use crate::ffi::vec_znx_big;
use crate::ffi::vec_znx_dft;
use crate::znx_base::ZnxAlloc;
use crate::znx_base::ZnxInfos;
use crate::znx_base::ZnxLayout;
use crate::znx_base::ZnxSliceSize;
use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxDft, assert_alignement};
pub trait VecZnxDftOps<B: Backend> {
/// Allocates a vector of polynomials of Z[X]/(X^N+1) stored in the DFT domain.
fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDft<B>;
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
///
/// Behavior: takes ownership of the backing array.
///
/// # Arguments
///
/// * `cols`: the number of cols of the [VecZnxDft].
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft].
///
/// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDft<B>;
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
///
/// Behavior: the backing array is only borrowed.
///
/// # Arguments
///
/// * `cols`: the number of cols of the [VecZnxDft].
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft].
///
/// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft<B>;
/// Returns the number of bytes required to allocate
/// a [VecZnxDft] with the given `cols` and `size`.
fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize;
/// Returns the minimum number of scratch bytes required
/// by [VecZnxDftOps::vec_znx_idft].
fn vec_znx_idft_tmp_bytes(&self) -> usize;
/// res <- IDFT(a), uses `a` as scratch space.
fn vec_znx_idft_tmp_a(&self, res: &mut VecZnxBig<B>, res_col: usize, a: &mut VecZnxDft<B>, a_col: usize);
fn vec_znx_idft(&self, res: &mut VecZnxBig<B>, res_col: usize, a: &VecZnxDft<B>, a_col: usize, tmp_bytes: &mut [u8]);
fn vec_znx_dft(&self, res: &mut VecZnxDft<B>, res_col: usize, a: &VecZnx, a_col: usize);
}
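A round-trip sketch of the new column-indexed entry points (assumptions flagged inline: `alloc_aligned` is the crate helper used elsewhere in this commit, while `new_vec_znx_big` is a hypothetical allocator, since `VecZnxBig` construction is not shown in this file):
let module: Module<FFT64> = Module::<FFT64>::new(8);
let mut a: VecZnx = module.new_vec_znx(1, 2);
let mut a_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, 2);
module.vec_znx_dft(&mut a_dft, 0, &a, 0); // a_dft[0] <- DFT(a[0])
let mut a_big: VecZnxBig<FFT64> = module.new_vec_znx_big(1, 2); // hypothetical
let mut tmp: Vec<u8> = alloc_aligned(module.vec_znx_idft_tmp_bytes());
module.vec_znx_idft(&mut a_big, 0, &a_dft, 0, &mut tmp); // a_big[0] <- IDFT(a_dft[0]), not normalized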
impl VecZnxDftOps<FFT64> for Module<FFT64> {
fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDft<FFT64> {
VecZnxDft::<FFT64>::new(&self, 1, cols, size)
}
fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDft<FFT64> {
VecZnxDft::from_bytes(self, 1, cols, size, bytes)
}
fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft<FFT64> {
VecZnxDft::from_bytes_borrow(self, 1, cols, size, bytes)
}
fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize {
VecZnxDft::bytes_of(&self, 1, cols, size)
}
fn vec_znx_idft_tmp_a(&self, res: &mut VecZnxBig<FFT64>, res_col: usize, a: &mut VecZnxDft<FFT64>, a_col: usize) {
#[cfg(debug_assertions)]
{
assert_eq!(res.poly_count(), a.poly_count());
}
unsafe {
vec_znx_dft::vec_znx_idft_tmp_a(
self.ptr,
res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big::vec_znx_big_t,
res.size() as u64,
a.at_ptr(a_col * a.size(), 0) as *mut vec_znx_dft::vec_znx_dft_t,
a.size() as u64,
)
}
}
fn vec_znx_idft_tmp_bytes(&self) -> usize {
unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(self.ptr) as usize }
}
/// res <- DFT(a)
///
/// # Panics
/// If `res.cols()` < `a.cols()`
fn vec_znx_dft(&self, res: &mut VecZnxDft<FFT64>, res_col: usize, a: &VecZnx, a_col: usize) {
unsafe {
vec_znx_dft::vec_znx_dft(
self.ptr,
res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_dft::vec_znx_dft_t,
res.size() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
// res <- IDFT(a), scratch space size obtained with [VecZnxDftOps::vec_znx_idft_tmp_bytes].
fn vec_znx_idft(&self, res: &mut VecZnxBig<FFT64>, res_col: usize, a: &VecZnxDft<FFT64>, a_col: usize, tmp_bytes: &mut [u8]) {
#[cfg(debug_assertions)]
{
assert!(
tmp_bytes.len() >= Self::vec_znx_idft_tmp_bytes(self),
"invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_idft_tmp_bytes()={}",
tmp_bytes.len(),
Self::vec_znx_idft_tmp_bytes(self)
);
assert_alignement(tmp_bytes.as_ptr())
}
unsafe {
vec_znx_dft::vec_znx_idft(
self.ptr,
res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big::vec_znx_big_t,
res.size() as u64,
a.at_ptr(a_col * a.size(), 0) as *const vec_znx_dft::vec_znx_dft_t,
a.size() as u64,
tmp_bytes.as_mut_ptr(),
)
}
}
}

View File

@@ -1,5 +1,6 @@
use crate::ffi::vec_znx;
use crate::{Backend, Module, VecZnx, ZnxBase, ZnxInfos, ZnxLayout, assert_alignement, switch_degree};
use crate::znx_base::{ZnxAlloc, ZnxInfos, ZnxLayout, ZnxSliceSize, switch_degree};
use crate::{Backend, Module, VEC_ZNX_ROWS, VecZnx, assert_alignement};
pub trait VecZnxOps {
/// Allocates a new [VecZnx].
///
@@ -19,7 +20,7 @@ pub trait VecZnxOps {
///
/// # Panics
/// If the length of `bytes` differs from [VecZnxOps::bytes_of_vec_znx].
fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnx;
fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnx;
/// Instantiates a new [VecZnx] from a vector of bytes.
/// The returned [VecZnx] takes ownership of the vector of bytes.
@@ -107,19 +108,19 @@ pub trait VecZnxOps {
impl<B: Backend> VecZnxOps for Module<B> {
fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnx {
VecZnx::new(self, cols, size)
VecZnx::new(self, VEC_ZNX_ROWS, cols, size)
}
fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize {
VecZnx::bytes_of(self, cols, size)
VecZnx::bytes_of(self, VEC_ZNX_ROWS, cols, size)
}
fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnx {
VecZnx::from_bytes(self, cols, size, bytes)
fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnx {
VecZnx::from_bytes(self, VEC_ZNX_ROWS, cols, size, bytes)
}
fn new_vec_znx_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnx {
VecZnx::from_bytes_borrow(self, cols, size, tmp_bytes)
VecZnx::from_bytes_borrow(self, VEC_ZNX_ROWS, cols, size, tmp_bytes)
}
fn vec_znx_normalize_tmp_bytes(&self) -> usize {

View File

@@ -1,10 +1,37 @@
use crate::{Backend, Module, assert_alignement, cast_mut};
use crate::{Backend, Module, alloc_aligned, assert_alignement, cast_mut};
use itertools::izip;
use std::cmp::min;
pub trait ZnxInfos {
pub struct ZnxBase {
/// The ring degree
pub n: usize,
/// The number of rows (in the third dimension)
pub rows: usize,
/// The number of polynomials
pub cols: usize,
/// The number of small polynomials (the `size`) per polynomial.
pub size: usize,
/// Polynomial coefficients, as a contiguous array. Each col is equally spaced by n.
pub data: Vec<u8>,
/// Pointer to `data` (`data` can be empty if the owner borrows space instead of owning it).
pub ptr: *mut u8,
}
pub trait GetZnxBase {
fn znx(&self) -> &ZnxBase;
fn znx_mut(&mut self) -> &mut ZnxBase;
}
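The indirection above is the backbone of this refactor: every container embeds one `ZnxBase` and inherits all shape accessors through default methods. A minimal sketch of a type opting in (mirroring what `VecZnxDft` does in this commit):
struct MyZnx {
    inner: ZnxBase,
}

impl GetZnxBase for MyZnx {
    fn znx(&self) -> &ZnxBase {
        &self.inner
    }
    fn znx_mut(&mut self) -> &mut ZnxBase {
        &mut self.inner
    }
}

// n(), rows(), cols(), size(), ptr() etc. now come from the defaults below.
impl ZnxInfos for MyZnx {}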
pub trait ZnxInfos: GetZnxBase {
/// Returns the ring degree of the polynomials.
fn n(&self) -> usize;
fn n(&self) -> usize {
self.znx().n
}
/// Returns the base two logarithm of the ring dimension of the polynomials.
fn log_n(&self) -> usize {
@@ -12,41 +39,104 @@ pub trait ZnxInfos {
}
/// Returns the number of rows.
fn rows(&self) -> usize;
fn rows(&self) -> usize {
self.znx().rows
}
/// Returns the number of polynomials in each row.
fn cols(&self) -> usize;
fn cols(&self) -> usize {
self.znx().cols
}
/// Returns the number of small polynomials per polynomial.
fn size(&self) -> usize;
fn size(&self) -> usize {
self.znx().size
}
fn data(&self) -> &[u8] {
&self.znx().data
}
fn ptr(&self) -> *mut u8 {
self.znx().ptr
}
/// Returns the total number of small polynomials.
fn poly_count(&self) -> usize {
self.rows() * self.cols() * self.size()
}
}
pub trait ZnxSliceSize {
/// Returns the slice size, which is the offset between
/// two consecutive small polynomials of the same column.
fn sl(&self) -> usize {
self.n() * self.cols()
fn sl(&self) -> usize;
}
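Making `sl()` abstract (instead of the removed `n * cols` default) is what accommodates the two memory layouts in this commit; in comment form:
// VecZnx, n = 8, cols = 2: limb j of every column is stored side by side, so
// consecutive limbs of one column sit n * cols = 16 scalars apart (its impl is
// expected to keep the old formula; it is not shown in this section).
// VecZnxDft, same shape: each column is a contiguous block of `size` limbs,
// so the stride is just n = 8 (see `impl ZnxSliceSize for VecZnxDft<FFT64>` above).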
impl ZnxBase {
pub fn from_bytes(n: usize, rows: usize, cols: usize, size: usize, mut bytes: Vec<u8>) -> Self {
let mut res: Self = Self::from_bytes_borrow(n, rows, cols, size, &mut bytes);
res.data = bytes;
res
}
pub fn from_bytes_borrow(n: usize, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
#[cfg(debug_assertions)]
{
assert!(n > 0, "n must be greater than 0");
assert_eq!(n & (n - 1), 0, "n must be a power of two");
assert!(rows > 0, "rows must be greater than 0");
assert!(cols > 0, "cols must be greater than 0");
assert!(size > 0, "size must be greater than 0");
}
Self {
    n,
    rows,
    cols,
    size,
    data: Vec::new(),
    ptr: bytes.as_mut_ptr(),
}
}
}
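An ownership sketch for the two constructors above (hedged: `len` is a stand-in for the byte count the caller computed with the matching `bytes_of` helper; `alloc_aligned` is the crate helper):
let len: usize = 64; // stand-in for the matching bytes_of(...) result
// Owning: the Vec moves into `data`, so the buffer lives as long as the value.
let owned: ZnxBase = ZnxBase::from_bytes(8, 1, 1, 2, alloc_aligned::<u8>(len));
// Borrowing: only the pointer is recorded and `data` stays empty, so the
// caller must keep `scratch` alive (this is what `borrowing()` below checks).
let mut scratch: Vec<u8> = alloc_aligned::<u8>(len);
let borrowed: ZnxBase = ZnxBase::from_bytes_borrow(8, 1, 1, 2, &mut scratch);
assert!(borrowed.data.is_empty());
assert!(!owned.data.is_empty());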
pub trait ZnxBase<B: Backend> {
pub trait ZnxAlloc<B: Backend>
where
Self: Sized + ZnxInfos,
{
type Scalar;
fn new(module: &Module<B>, cols: usize, size: usize) -> Self;
fn from_bytes(module: &Module<B>, cols: usize, size: usize, bytes: &mut [u8]) -> Self;
fn from_bytes_borrow(module: &Module<B>, cols: usize, size: usize, bytes: &mut [u8]) -> Self;
fn bytes_of(module: &Module<B>, cols: usize, size: usize) -> usize;
fn new(module: &Module<B>, rows: usize, cols: usize, size: usize) -> Self {
let bytes: Vec<u8> = alloc_aligned::<u8>(Self::bytes_of(module, rows, cols, size));
Self::from_bytes(module, rows, cols, size, bytes)
}
fn from_bytes(module: &Module<B>, rows: usize, cols: usize, size: usize, mut bytes: Vec<u8>) -> Self {
let mut res: Self = Self::from_bytes_borrow(module, rows, cols, size, &mut bytes);
res.znx_mut().data = bytes;
res
}
fn from_bytes_borrow(module: &Module<B>, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self;
fn bytes_of(module: &Module<B>, rows: usize, cols: usize, size: usize) -> usize;
}
pub trait ZnxLayout: ZnxInfos {
type Scalar;
/// Returns true if the receiver is only borrowing the data.
fn borrowing(&self) -> bool {
self.znx().data.is_empty()
}
/// Returns a non-mutable pointer to the underlying coefficients array.
fn as_ptr(&self) -> *const Self::Scalar;
fn as_ptr(&self) -> *const Self::Scalar {
self.znx().ptr as *const Self::Scalar
}
/// Returns a mutable pointer to the underlying coefficients array.
fn as_mut_ptr(&mut self) -> *mut Self::Scalar;
fn as_mut_ptr(&mut self) -> *mut Self::Scalar {
self.znx_mut().ptr as *mut Self::Scalar
}
/// Returns a non-mutable reference to the entire underlying coefficient array.
fn raw(&self) -> &[Self::Scalar] {

View File

@@ -1,5 +1,3 @@
cargo-features = ["edition2024"]
[package]
name = "rlwe"
version = "0.1.0"

View File

@@ -20,7 +20,7 @@ pub struct AutomorphismKey {
}
pub fn automorphis_key_new_tmp_bytes(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> usize {
module.bytes_of_scalar() + module.bytes_of_svp_ppol() + encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q)
module.bytes_of_scalar() + module.bytes_of_scalar_znx_dft() + encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q)
}
impl Parameters {
@@ -103,10 +103,10 @@ impl AutomorphismKey {
tmp_bytes: &mut [u8],
) -> Vec<Self> {
let (sk_auto_bytes, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_scalar());
let (sk_out_bytes, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_svp_ppol());
let (sk_out_bytes, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_scalar_znx_dft());
let sk_auto: Scalar = module.new_scalar_from_bytes_borrow(sk_auto_bytes);
let mut sk_out: ScalarZnxDft = module.new_svp_ppol_from_bytes_borrow(sk_out_bytes);
let mut sk_out: ScalarZnxDft = module.new_scalar_znx_dft_from_bytes_borrow(sk_out_bytes);
let mut keys: Vec<AutomorphismKey> = Vec::new();
@@ -116,7 +116,7 @@ impl AutomorphismKey {
let p_inv: i64 = module.galois_element_inv(*pi);
module.vec_znx_automorphism(p_inv, &mut sk_auto.as_vec_znx(), &sk.0.as_vec_znx());
module.svp_prepare(&mut sk_out, &sk_auto);
module.scalar_znx_dft_prepare(&mut sk_out, &sk_auto);
encrypt_grlwe_sk(
module, &mut value, &sk.0, &sk_out, source_xa, source_xe, sigma, tmp_bytes,
);
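A scratch-accounting sketch (names from this file; `module`, `log_base2k`, `rows` and `log_q` assumed in scope): the regions peeled off with `split_at_mut` in `new` line up term by term with `automorphis_key_new_tmp_bytes`, so one aligned allocation covers the whole call:
let mut tmp: Vec<u8> = alloc_aligned(automorphis_key_new_tmp_bytes(&module, log_base2k, rows, log_q));
// bytes_of_scalar()              -> sk_auto_bytes
// bytes_of_scalar_znx_dft()      -> sk_out_bytes
// encrypt_grlwe_sk_tmp_bytes(..) -> the remainder handed to encrypt_grlwe_sk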

View File

@@ -20,7 +20,7 @@ impl SecretKey {
}
pub fn prepare(&self, module: &Module, sk_ppol: &mut ScalarZnxDft) {
module.svp_prepare(sk_ppol, &self.0)
module.scalar_znx_dft_prepare(sk_ppol, &self.0)
}
}