mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
wip major refactoring (compiles & all test + example passing)
This commit is contained in:
@@ -23,10 +23,10 @@ fn main() {
|
||||
s.fill_ternary_prob(0.5, &mut source);
|
||||
|
||||
// Buffer to store s in the DFT domain
|
||||
let mut s_ppol: ScalarZnxDft<FFT64> = module.new_svp_ppol();
|
||||
let mut s_dft: ScalarZnxDft<FFT64> = module.new_scalar_znx_dft();
|
||||
|
||||
// s_ppol <- DFT(s)
|
||||
module.svp_prepare(&mut s_ppol, &s);
|
||||
// s_dft <- DFT(s)
|
||||
module.svp_prepare(&mut s_dft, &s);
|
||||
|
||||
// Allocates a VecZnx with two columns: ct=(0, 0)
|
||||
let mut ct: VecZnx = module.new_vec_znx(
|
||||
@@ -46,16 +46,17 @@ fn main() {
|
||||
// Applies DFT(ct[1]) * DFT(s)
|
||||
module.svp_apply_dft(
|
||||
&mut buf_dft, // DFT(ct[1] * s)
|
||||
&s_ppol, // DFT(s)
|
||||
0, // Selects the first column of res
|
||||
&s_dft, // DFT(s)
|
||||
&ct,
|
||||
1, // Selects the second column of ct
|
||||
);
|
||||
|
||||
// Alias scratch space (VecZnxDft<B> is always at least as big as VecZnxBig<B>)
|
||||
let mut buf_big: VecZnxBig<FFT64> = buf_dft.as_vec_znx_big();
|
||||
let mut buf_big: VecZnxBig<FFT64> = buf_dft.alias_as_vec_znx_big();
|
||||
|
||||
// BIG(ct[1] * s) <- IDFT(DFT(ct[1] * s)) (not normalized)
|
||||
module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft);
|
||||
module.vec_znx_idft_tmp_a(&mut buf_big, 0, &mut buf_dft, 0);
|
||||
|
||||
// Creates a plaintext: VecZnx with 1 column
|
||||
let mut m: VecZnx = module.new_vec_znx(
|
||||
@@ -103,13 +104,14 @@ fn main() {
|
||||
// DFT(ct[1] * s)
|
||||
module.svp_apply_dft(
|
||||
&mut buf_dft,
|
||||
&s_ppol,
|
||||
0, // Selects the first column of res.
|
||||
&s_dft,
|
||||
&ct,
|
||||
1, // Selects the second column of ct (ct[1])
|
||||
);
|
||||
|
||||
// BIG(c1 * s) = IDFT(DFT(c1 * s))
|
||||
module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft);
|
||||
module.vec_znx_idft_tmp_a(&mut buf_big, 0, &mut buf_dft, 0);
|
||||
|
||||
// BIG(c1 * s) + ct[0]
|
||||
module.vec_znx_big_add_small_inplace(&mut buf_big, 0, &ct, 0);
|
||||
|
||||
@@ -42,8 +42,8 @@ fn main() {
|
||||
let mut c_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, limbs_mat);
|
||||
module.vmp_apply_dft(&mut c_dft, &a, &mat_znx_dft, &mut buf);
|
||||
|
||||
let mut c_big: VecZnxBig<FFT64> = c_dft.as_vec_znx_big();
|
||||
module.vec_znx_idft_tmp_a(&mut c_big, &mut c_dft);
|
||||
let mut c_big: VecZnxBig<FFT64> = c_dft.alias_as_vec_znx_big();
|
||||
module.vec_znx_idft_tmp_a(&mut c_big, 0, &mut c_dft, 0);
|
||||
|
||||
let mut res: VecZnx = module.new_vec_znx(1, limbs_vec);
|
||||
module.vec_znx_big_normalize(log_base2k, &mut res, 0, &c_big, 0, &mut buf);
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use crate::ffi::znx::znx_zero_i64_ref;
|
||||
use crate::{VecZnx, ZnxInfos, ZnxLayout};
|
||||
use crate::znx_base::ZnxLayout;
|
||||
use crate::{VecZnx, znx_base::ZnxInfos};
|
||||
use itertools::izip;
|
||||
use rug::{Assign, Float};
|
||||
use std::cmp::min;
|
||||
@@ -262,7 +263,10 @@ fn decode_coeff_i64(a: &VecZnx, col_i: usize, log_base2k: usize, log_k: usize, i
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{Encoding, FFT64, Module, VecZnx, ZnxBase, ZnxInfos, ZnxLayout};
|
||||
use crate::{
|
||||
Encoding, FFT64, Module, VecZnx, VecZnxOps,
|
||||
znx_base::{ZnxInfos, ZnxLayout},
|
||||
};
|
||||
use itertools::izip;
|
||||
use sampling::source::Source;
|
||||
|
||||
@@ -273,7 +277,7 @@ mod tests {
|
||||
let log_base2k: usize = 17;
|
||||
let size: usize = 5;
|
||||
let log_k: usize = size * log_base2k - 5;
|
||||
let mut a: VecZnx = VecZnx::new(&module, 2, size);
|
||||
let mut a: VecZnx = module.new_vec_znx(2, size);
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
let raw: &mut [i64] = a.raw_mut();
|
||||
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
|
||||
@@ -295,7 +299,7 @@ mod tests {
|
||||
let log_base2k: usize = 17;
|
||||
let size: usize = 5;
|
||||
let log_k: usize = size * log_base2k - 5;
|
||||
let mut a: VecZnx = VecZnx::new(&module, 2, size);
|
||||
let mut a: VecZnx = module.new_vec_znx(2, size);
|
||||
let mut source = Source::new([0u8; 32]);
|
||||
let raw: &mut [i64] = a.raw_mut();
|
||||
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
pub mod commons;
|
||||
pub mod encoding;
|
||||
#[allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code, improper_ctypes)]
|
||||
// Other modules and exports
|
||||
@@ -12,9 +11,10 @@ pub mod vec_znx;
|
||||
pub mod vec_znx_big;
|
||||
pub mod vec_znx_big_ops;
|
||||
pub mod vec_znx_dft;
|
||||
pub mod vec_znx_dft_ops;
|
||||
pub mod vec_znx_ops;
|
||||
pub mod znx_base;
|
||||
|
||||
pub use commons::*;
|
||||
pub use encoding::*;
|
||||
pub use mat_znx_dft::*;
|
||||
pub use module::*;
|
||||
@@ -26,7 +26,9 @@ pub use vec_znx::*;
|
||||
pub use vec_znx_big::*;
|
||||
pub use vec_znx_big_ops::*;
|
||||
pub use vec_znx_dft::*;
|
||||
pub use vec_znx_dft_ops::*;
|
||||
pub use vec_znx_ops::*;
|
||||
pub use znx_base::*;
|
||||
|
||||
pub const GALOISGENERATOR: u64 = 5;
|
||||
pub const DEFAULTALIGN: usize = 64;
|
||||
@@ -110,14 +112,8 @@ pub fn alloc_aligned_custom<T>(size: usize, align: usize) -> Vec<T> {
|
||||
unsafe { Vec::from_raw_parts(ptr, len, cap) }
|
||||
}
|
||||
|
||||
// Allocates an aligned of size equal to the smallest power of two equal or greater to `size` that is
|
||||
// at least as bit as DEFAULTALIGN / std::mem::size_of::<T>().
|
||||
/// Allocates an aligned of size equal to the smallest multiple
|
||||
/// of [DEFAULTALIGN] that is equal or greater to `size`.
|
||||
pub fn alloc_aligned<T>(size: usize) -> Vec<T> {
|
||||
alloc_aligned_custom::<T>(
|
||||
std::cmp::max(
|
||||
size.next_power_of_two(),
|
||||
DEFAULTALIGN / std::mem::size_of::<T>(),
|
||||
),
|
||||
DEFAULTALIGN,
|
||||
)
|
||||
alloc_aligned_custom::<T>(size + (size % DEFAULTALIGN), DEFAULTALIGN)
|
||||
}
|
||||
|
||||
@@ -1,103 +1,75 @@
|
||||
use crate::ffi::vec_znx_big::vec_znx_big_t;
|
||||
use crate::ffi::vec_znx_dft::vec_znx_dft_t;
|
||||
use crate::ffi::vmp::{self, vmp_pmat_t};
|
||||
use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxDft, ZnxInfos, ZnxLayout, alloc_aligned, assert_alignement};
|
||||
use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize};
|
||||
use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxDft, alloc_aligned, assert_alignement};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx],
|
||||
/// stored as a 3D matrix in the DFT domain in a single contiguous array.
|
||||
/// Each col of the [VmpPMat] can be seen as a collection of [VecZnxDft].
|
||||
/// Each col of the [MatZnxDft] can be seen as a collection of [VecZnxDft].
|
||||
///
|
||||
/// [VmpPMat] is used to permform a vector matrix product between a [VecZnx]/[VecZnxDft] and a [VmpPMat].
|
||||
/// See the trait [VmpPMatOps] for additional information.
|
||||
/// [MatZnxDft] is used to permform a vector matrix product between a [VecZnx]/[VecZnxDft] and a [MatZnxDft].
|
||||
/// See the trait [MatZnxDftOps] for additional information.
|
||||
pub struct MatZnxDft<B: Backend> {
|
||||
/// Raw data, is empty if borrowing scratch space.
|
||||
data: Vec<u8>,
|
||||
/// Pointer to data. Can point to scratch space.
|
||||
ptr: *mut u8,
|
||||
/// The ring degree of each polynomial.
|
||||
n: usize,
|
||||
/// Number of rows
|
||||
rows: usize,
|
||||
/// Number of cols
|
||||
cols: usize,
|
||||
/// The number of small polynomials
|
||||
size: usize,
|
||||
pub inner: ZnxBase,
|
||||
_marker: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> ZnxInfos for MatZnxDft<B> {
|
||||
fn n(&self) -> usize {
|
||||
self.n
|
||||
impl<B: Backend> GetZnxBase for MatZnxDft<B> {
|
||||
fn znx(&self) -> &ZnxBase {
|
||||
&self.inner
|
||||
}
|
||||
|
||||
fn rows(&self) -> usize {
|
||||
self.rows
|
||||
}
|
||||
|
||||
fn cols(&self) -> usize {
|
||||
self.cols
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
self.size
|
||||
fn znx_mut(&mut self) -> &mut ZnxBase {
|
||||
&mut self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl MatZnxDft<FFT64> {
|
||||
fn new(module: &Module<FFT64>, rows: usize, cols: usize, size: usize) -> MatZnxDft<FFT64> {
|
||||
let mut data: Vec<u8> = alloc_aligned::<u8>(module.bytes_of_mat_znx_dft(rows, cols, size));
|
||||
let ptr: *mut u8 = data.as_mut_ptr();
|
||||
MatZnxDft::<FFT64> {
|
||||
data: data,
|
||||
ptr: ptr,
|
||||
n: module.n(),
|
||||
rows: rows,
|
||||
cols: cols,
|
||||
size: size,
|
||||
impl<B: Backend> ZnxInfos for MatZnxDft<B> {}
|
||||
|
||||
impl ZnxSliceSize for MatZnxDft<FFT64> {
|
||||
fn sl(&self) -> usize {
|
||||
self.n()
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxLayout for MatZnxDft<FFT64> {
|
||||
type Scalar = f64;
|
||||
}
|
||||
|
||||
impl<B: Backend> ZnxAlloc<B> for MatZnxDft<B> {
|
||||
type Scalar = u8;
|
||||
|
||||
fn from_bytes_borrow(module: &Module<B>, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
|
||||
Self {
|
||||
inner: ZnxBase::from_bytes_borrow(module.n(), rows, cols, size, bytes),
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_ptr(&self) -> *const u8 {
|
||||
self.ptr
|
||||
fn bytes_of(module: &Module<B>, rows: usize, cols: usize, size: usize) -> usize {
|
||||
unsafe { vmp::bytes_of_vmp_pmat(module.ptr, rows as u64, size as u64) as usize * cols }
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_mut_ptr(&self) -> *mut u8 {
|
||||
self.ptr
|
||||
}
|
||||
|
||||
pub fn borrowed(&self) -> bool {
|
||||
self.data.len() == 0
|
||||
}
|
||||
|
||||
/// Returns a non-mutable reference to the entire contiguous array of the [VmpPMat].
|
||||
pub fn raw(&self) -> &[f64] {
|
||||
let ptr: *const f64 = self.ptr as *const f64;
|
||||
let size: usize = self.n() * self.poly_count();
|
||||
unsafe { &std::slice::from_raw_parts(ptr, size) }
|
||||
}
|
||||
|
||||
/// Returns a mutable reference of to the entire contiguous array of the [VmpPMat].
|
||||
pub fn raw_mut(&self) -> &mut [f64] {
|
||||
let ptr: *mut f64 = self.ptr as *mut f64;
|
||||
let size: usize = self.n() * self.poly_count();
|
||||
unsafe { std::slice::from_raw_parts_mut(ptr, size) }
|
||||
}
|
||||
|
||||
/// Returns a copy of the backend array at index (i, j) of the [VmpPMat].
|
||||
impl MatZnxDft<FFT64> {
|
||||
/// Returns a copy of the backend array at index (i, j) of the [MatZnxDft].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `row`: row index (i).
|
||||
/// * `col`: col index (j).
|
||||
pub fn at(&self, row: usize, col: usize) -> Vec<f64> {
|
||||
let mut res: Vec<f64> = alloc_aligned(self.n);
|
||||
#[allow(dead_code)]
|
||||
fn at(&self, row: usize, col: usize) -> Vec<f64> {
|
||||
let n: usize = self.n();
|
||||
|
||||
if self.n < 8 {
|
||||
res.copy_from_slice(&self.raw()[(row + col * self.rows()) * self.n()..(row + col * self.rows()) * (self.n() + 1)]);
|
||||
let mut res: Vec<f64> = alloc_aligned(n);
|
||||
|
||||
if n < 8 {
|
||||
res.copy_from_slice(&self.raw()[(row + col * self.rows()) * n..(row + col * self.rows()) * (n + 1)]);
|
||||
} else {
|
||||
(0..self.n >> 3).for_each(|blk| {
|
||||
(0..n >> 3).for_each(|blk| {
|
||||
res[blk * 8..(blk + 1) * 8].copy_from_slice(&self.at_block(row, col, blk)[..8]);
|
||||
});
|
||||
}
|
||||
@@ -105,6 +77,7 @@ impl MatZnxDft<FFT64> {
|
||||
res
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn at_block(&self, row: usize, col: usize, blk: usize) -> &[f64] {
|
||||
let nrows: usize = self.rows();
|
||||
let nsize: usize = self.size();
|
||||
@@ -117,11 +90,11 @@ impl MatZnxDft<FFT64> {
|
||||
}
|
||||
|
||||
/// This trait implements methods for vector matrix product,
|
||||
/// that is, multiplying a [VecZnx] with a [VmpPMat].
|
||||
/// that is, multiplying a [VecZnx] with a [MatZnxDft].
|
||||
pub trait MatZnxDftOps<B: Backend> {
|
||||
fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> usize;
|
||||
|
||||
/// Allocates a new [VmpPMat] with the given number of rows and columns.
|
||||
/// Allocates a new [MatZnxDft] with the given number of rows and columns.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
@@ -129,83 +102,83 @@ pub trait MatZnxDftOps<B: Backend> {
|
||||
/// * `size`: number of size (number of size of each [VecZnxDft]).
|
||||
fn new_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> MatZnxDft<B>;
|
||||
|
||||
/// Returns the number of bytes needed as scratch space for [VmpPMatOps::vmp_prepare_contiguous].
|
||||
/// Returns the number of bytes needed as scratch space for [MatZnxDftOps::vmp_prepare_contiguous].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `rows`: number of rows of the [VmpPMat] used in [VmpPMatOps::vmp_prepare_contiguous].
|
||||
/// * `size`: number of size of the [VmpPMat] used in [VmpPMatOps::vmp_prepare_contiguous].
|
||||
/// * `rows`: number of rows of the [MatZnxDft] used in [MatZnxDftOps::vmp_prepare_contiguous].
|
||||
/// * `size`: number of size of the [MatZnxDft] used in [MatZnxDftOps::vmp_prepare_contiguous].
|
||||
fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize, size: usize) -> usize;
|
||||
|
||||
/// Prepares a [VmpPMat] from a contiguous array of [i64].
|
||||
/// Prepares a [MatZnxDft] from a contiguous array of [i64].
|
||||
/// The helper struct [Matrix3D] can be used to contruct and populate
|
||||
/// the appropriate contiguous array.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `b`: [VmpPMat] on which the values are encoded.
|
||||
/// * `a`: the contiguous array of [i64] of the 3D matrix to encode on the [VmpPMat].
|
||||
/// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
|
||||
/// * `b`: [MatZnxDft] on which the values are encoded.
|
||||
/// * `a`: the contiguous array of [i64] of the 3D matrix to encode on the [MatZnxDft].
|
||||
/// * `buf`: scratch space, the size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes].
|
||||
fn vmp_prepare_contiguous(&self, b: &mut MatZnxDft<B>, a: &[i64], buf: &mut [u8]);
|
||||
|
||||
/// Prepares the ith-row of [VmpPMat] from a [VecZnx].
|
||||
/// Prepares the ith-row of [MatZnxDft] from a [VecZnx].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `b`: [VmpPMat] on which the values are encoded.
|
||||
/// * `a`: the vector of [VecZnx] to encode on the [VmpPMat].
|
||||
/// * `b`: [MatZnxDft] on which the values are encoded.
|
||||
/// * `a`: the vector of [VecZnx] to encode on the [MatZnxDft].
|
||||
/// * `row_i`: the index of the row to prepare.
|
||||
/// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
|
||||
/// * `buf`: scratch space, the size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes].
|
||||
///
|
||||
/// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
|
||||
/// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes].
|
||||
fn vmp_prepare_row(&self, b: &mut MatZnxDft<B>, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]);
|
||||
|
||||
/// Extracts the ith-row of [VmpPMat] into a [VecZnxBig].
|
||||
/// Extracts the ith-row of [MatZnxDft] into a [VecZnxBig].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `b`: the [VecZnxBig] to on which to extract the row of the [VmpPMat].
|
||||
/// * `a`: [VmpPMat] on which the values are encoded.
|
||||
/// * `b`: the [VecZnxBig] to on which to extract the row of the [MatZnxDft].
|
||||
/// * `a`: [MatZnxDft] on which the values are encoded.
|
||||
/// * `row_i`: the index of the row to extract.
|
||||
fn vmp_extract_row(&self, b: &mut VecZnxBig<B>, a: &MatZnxDft<B>, row_i: usize);
|
||||
|
||||
/// Prepares the ith-row of [VmpPMat] from a [VecZnxDft].
|
||||
/// Prepares the ith-row of [MatZnxDft] from a [VecZnxDft].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `b`: [VmpPMat] on which the values are encoded.
|
||||
/// * `a`: the [VecZnxDft] to encode on the [VmpPMat].
|
||||
/// * `b`: [MatZnxDft] on which the values are encoded.
|
||||
/// * `a`: the [VecZnxDft] to encode on the [MatZnxDft].
|
||||
/// * `row_i`: the index of the row to prepare.
|
||||
///
|
||||
/// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
|
||||
/// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes].
|
||||
fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft<B>, a: &VecZnxDft<B>, row_i: usize);
|
||||
|
||||
/// Extracts the ith-row of [VmpPMat] into a [VecZnxDft].
|
||||
/// Extracts the ith-row of [MatZnxDft] into a [VecZnxDft].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `b`: the [VecZnxDft] to on which to extract the row of the [VmpPMat].
|
||||
/// * `a`: [VmpPMat] on which the values are encoded.
|
||||
/// * `b`: the [VecZnxDft] to on which to extract the row of the [MatZnxDft].
|
||||
/// * `a`: [MatZnxDft] on which the values are encoded.
|
||||
/// * `row_i`: the index of the row to extract.
|
||||
fn vmp_extract_row_dft(&self, b: &mut VecZnxDft<B>, a: &MatZnxDft<B>, row_i: usize);
|
||||
|
||||
/// Returns the size of the stratch space necessary for [VmpPMatOps::vmp_apply_dft].
|
||||
/// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `c_size`: number of size of the output [VecZnxDft].
|
||||
/// * `a_size`: number of size of the input [VecZnx].
|
||||
/// * `rows`: number of rows of the input [VmpPMat].
|
||||
/// * `size`: number of size of the input [VmpPMat].
|
||||
/// * `rows`: number of rows of the input [MatZnxDft].
|
||||
/// * `size`: number of size of the input [MatZnxDft].
|
||||
fn vmp_apply_dft_tmp_bytes(&self, c_size: usize, a_size: usize, rows: usize, size: usize) -> usize;
|
||||
|
||||
/// Applies the vector matrix product [VecZnxDft] x [VmpPMat].
|
||||
/// Applies the vector matrix product [VecZnxDft] x [MatZnxDft].
|
||||
///
|
||||
/// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
|
||||
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
|
||||
/// and each vector a [VecZnxDft] (row) of the [VmpPMat].
|
||||
/// and each vector a [VecZnxDft] (row) of the [MatZnxDft].
|
||||
///
|
||||
/// As such, given an input [VecZnx] of `i` size and a [VmpPMat] of `i` rows and
|
||||
/// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and
|
||||
/// `j` size, the output is a [VecZnx] of `j` size.
|
||||
///
|
||||
/// If there is a mismatch between the dimensions the largest valid ones are used.
|
||||
@@ -221,17 +194,17 @@ pub trait MatZnxDftOps<B: Backend> {
|
||||
///
|
||||
/// * `c`: the output of the vector matrix product, as a [VecZnxDft].
|
||||
/// * `a`: the left operand [VecZnx] of the vector matrix product.
|
||||
/// * `b`: the right operand [VmpPMat] of the vector matrix product.
|
||||
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_tmp_bytes].
|
||||
/// * `b`: the right operand [MatZnxDft] of the vector matrix product.
|
||||
/// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_tmp_bytes].
|
||||
fn vmp_apply_dft(&self, c: &mut VecZnxDft<B>, a: &VecZnx, b: &MatZnxDft<B>, buf: &mut [u8]);
|
||||
|
||||
/// Applies the vector matrix product [VecZnxDft] x [VmpPMat] and adds on the receiver.
|
||||
/// Applies the vector matrix product [VecZnxDft] x [MatZnxDft] and adds on the receiver.
|
||||
///
|
||||
/// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
|
||||
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
|
||||
/// and each vector a [VecZnxDft] (row) of the [VmpPMat].
|
||||
/// and each vector a [VecZnxDft] (row) of the [MatZnxDft].
|
||||
///
|
||||
/// As such, given an input [VecZnx] of `i` size and a [VmpPMat] of `i` rows and
|
||||
/// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and
|
||||
/// `j` size, the output is a [VecZnx] of `j` size.
|
||||
///
|
||||
/// If there is a mismatch between the dimensions the largest valid ones are used.
|
||||
@@ -247,28 +220,28 @@ pub trait MatZnxDftOps<B: Backend> {
|
||||
///
|
||||
/// * `c`: the operand on which the output of the vector matrix product is added, as a [VecZnxDft].
|
||||
/// * `a`: the left operand [VecZnx] of the vector matrix product.
|
||||
/// * `b`: the right operand [VmpPMat] of the vector matrix product.
|
||||
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_tmp_bytes].
|
||||
/// * `b`: the right operand [MatZnxDft] of the vector matrix product.
|
||||
/// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_tmp_bytes].
|
||||
fn vmp_apply_dft_add(&self, c: &mut VecZnxDft<B>, a: &VecZnx, b: &MatZnxDft<B>, buf: &mut [u8]);
|
||||
|
||||
/// Returns the size of the stratch space necessary for [VmpPMatOps::vmp_apply_dft_to_dft].
|
||||
/// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft_to_dft].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `c_size`: number of size of the output [VecZnxDft].
|
||||
/// * `a_size`: number of size of the input [VecZnxDft].
|
||||
/// * `rows`: number of rows of the input [VmpPMat].
|
||||
/// * `size`: number of size of the input [VmpPMat].
|
||||
/// * `rows`: number of rows of the input [MatZnxDft].
|
||||
/// * `size`: number of size of the input [MatZnxDft].
|
||||
fn vmp_apply_dft_to_dft_tmp_bytes(&self, c_size: usize, a_size: usize, rows: usize, size: usize) -> usize;
|
||||
|
||||
/// Applies the vector matrix product [VecZnxDft] x [VmpPMat].
|
||||
/// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
|
||||
/// Applies the vector matrix product [VecZnxDft] x [MatZnxDft].
|
||||
/// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
|
||||
///
|
||||
/// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
|
||||
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
|
||||
/// and each vector a [VecZnxDft] (row) of the [VmpPMat].
|
||||
/// and each vector a [VecZnxDft] (row) of the [MatZnxDft].
|
||||
///
|
||||
/// As such, given an input [VecZnx] of `i` size and a [VmpPMat] of `i` rows and
|
||||
/// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and
|
||||
/// `j` size, the output is a [VecZnx] of `j` size.
|
||||
///
|
||||
/// If there is a mismatch between the dimensions the largest valid ones are used.
|
||||
@@ -284,18 +257,18 @@ pub trait MatZnxDftOps<B: Backend> {
|
||||
///
|
||||
/// * `c`: the output of the vector matrix product, as a [VecZnxDft].
|
||||
/// * `a`: the left operand [VecZnxDft] of the vector matrix product.
|
||||
/// * `b`: the right operand [VmpPMat] of the vector matrix product.
|
||||
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
|
||||
/// * `b`: the right operand [MatZnxDft] of the vector matrix product.
|
||||
/// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
|
||||
fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft<B>, a: &VecZnxDft<B>, b: &MatZnxDft<B>, buf: &mut [u8]);
|
||||
|
||||
/// Applies the vector matrix product [VecZnxDft] x [VmpPMat] and adds on top of the receiver instead of overwritting it.
|
||||
/// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
|
||||
/// Applies the vector matrix product [VecZnxDft] x [MatZnxDft] and adds on top of the receiver instead of overwritting it.
|
||||
/// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
|
||||
///
|
||||
/// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
|
||||
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
|
||||
/// and each vector a [VecZnxDft] (row) of the [VmpPMat].
|
||||
/// and each vector a [VecZnxDft] (row) of the [MatZnxDft].
|
||||
///
|
||||
/// As such, given an input [VecZnx] of `i` size and a [VmpPMat] of `i` rows and
|
||||
/// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and
|
||||
/// `j` size, the output is a [VecZnx] of `j` size.
|
||||
///
|
||||
/// If there is a mismatch between the dimensions the largest valid ones are used.
|
||||
@@ -311,18 +284,18 @@ pub trait MatZnxDftOps<B: Backend> {
|
||||
///
|
||||
/// * `c`: the operand on which the output of the vector matrix product is added, as a [VecZnxDft].
|
||||
/// * `a`: the left operand [VecZnxDft] of the vector matrix product.
|
||||
/// * `b`: the right operand [VmpPMat] of the vector matrix product.
|
||||
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
|
||||
/// * `b`: the right operand [MatZnxDft] of the vector matrix product.
|
||||
/// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
|
||||
fn vmp_apply_dft_to_dft_add(&self, c: &mut VecZnxDft<B>, a: &VecZnxDft<B>, b: &MatZnxDft<B>, buf: &mut [u8]);
|
||||
|
||||
/// Applies the vector matrix product [VecZnxDft] x [VmpPMat] in place.
|
||||
/// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
|
||||
/// Applies the vector matrix product [VecZnxDft] x [MatZnxDft] in place.
|
||||
/// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
|
||||
///
|
||||
/// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
|
||||
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
|
||||
/// and each vector a [VecZnxDft] (row) of the [VmpPMat].
|
||||
/// and each vector a [VecZnxDft] (row) of the [MatZnxDft].
|
||||
///
|
||||
/// As such, given an input [VecZnx] of `i` size and a [VmpPMat] of `i` rows and
|
||||
/// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and
|
||||
/// `j` size, the output is a [VecZnx] of `j` size.
|
||||
///
|
||||
/// If there is a mismatch between the dimensions the largest valid ones are used.
|
||||
@@ -337,8 +310,8 @@ pub trait MatZnxDftOps<B: Backend> {
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `b`: the input and output of the vector matrix product, as a [VecZnxDft].
|
||||
/// * `a`: the right operand [VmpPMat] of the vector matrix product.
|
||||
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
|
||||
/// * `a`: the right operand [MatZnxDft] of the vector matrix product.
|
||||
/// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
|
||||
fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft<B>, a: &MatZnxDft<B>, buf: &mut [u8]);
|
||||
}
|
||||
|
||||
@@ -404,7 +377,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
|
||||
unsafe {
|
||||
vmp::vmp_extract_row(
|
||||
self.ptr,
|
||||
b.ptr as *mut vec_znx_big_t,
|
||||
b.as_mut_ptr() as *mut vec_znx_big_t,
|
||||
a.as_ptr() as *const vmp_pmat_t,
|
||||
row_i as u64,
|
||||
a.rows() as u64,
|
||||
@@ -423,7 +396,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
|
||||
vmp::vmp_prepare_row_dft(
|
||||
self.ptr,
|
||||
b.as_mut_ptr() as *mut vmp_pmat_t,
|
||||
a.ptr as *const vec_znx_dft_t,
|
||||
a.as_ptr() as *const vec_znx_dft_t,
|
||||
row_i as u64,
|
||||
b.rows() as u64,
|
||||
b.size() as u64,
|
||||
@@ -440,7 +413,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
|
||||
unsafe {
|
||||
vmp::vmp_extract_row_dft(
|
||||
self.ptr,
|
||||
b.ptr as *mut vec_znx_dft_t,
|
||||
b.as_mut_ptr() as *mut vec_znx_dft_t,
|
||||
a.as_ptr() as *const vmp_pmat_t,
|
||||
row_i as u64,
|
||||
a.rows() as u64,
|
||||
@@ -470,7 +443,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft(
|
||||
self.ptr,
|
||||
c.ptr as *mut vec_znx_dft_t,
|
||||
c.as_mut_ptr() as *mut vec_znx_dft_t,
|
||||
c.size() as u64,
|
||||
a.as_ptr(),
|
||||
a.size() as u64,
|
||||
@@ -492,7 +465,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_add(
|
||||
self.ptr,
|
||||
c.ptr as *mut vec_znx_dft_t,
|
||||
c.as_mut_ptr() as *mut vec_znx_dft_t,
|
||||
c.size() as u64,
|
||||
a.as_ptr(),
|
||||
a.size() as u64,
|
||||
@@ -526,9 +499,9 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_to_dft(
|
||||
self.ptr,
|
||||
c.ptr as *mut vec_znx_dft_t,
|
||||
c.as_mut_ptr() as *mut vec_znx_dft_t,
|
||||
c.size() as u64,
|
||||
a.ptr as *const vec_znx_dft_t,
|
||||
a.as_ptr() as *const vec_znx_dft_t,
|
||||
a.size() as u64,
|
||||
b.as_ptr() as *const vmp_pmat_t,
|
||||
b.rows() as u64,
|
||||
@@ -553,9 +526,9 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_to_dft_add(
|
||||
self.ptr,
|
||||
c.ptr as *mut vec_znx_dft_t,
|
||||
c.as_mut_ptr() as *mut vec_znx_dft_t,
|
||||
c.size() as u64,
|
||||
a.ptr as *const vec_znx_dft_t,
|
||||
a.as_ptr() as *const vec_znx_dft_t,
|
||||
a.size() as u64,
|
||||
b.as_ptr() as *const vmp_pmat_t,
|
||||
b.rows() as u64,
|
||||
@@ -574,9 +547,9 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_to_dft(
|
||||
self.ptr,
|
||||
b.ptr as *mut vec_znx_dft_t,
|
||||
b.as_mut_ptr() as *mut vec_znx_dft_t,
|
||||
b.size() as u64,
|
||||
b.ptr as *mut vec_znx_dft_t,
|
||||
b.as_ptr() as *mut vec_znx_dft_t,
|
||||
b.size() as u64,
|
||||
a.as_ptr() as *const vmp_pmat_t,
|
||||
a.rows() as u64,
|
||||
@@ -591,7 +564,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
|
||||
mod tests {
|
||||
use crate::{
|
||||
FFT64, MatZnxDft, MatZnxDftOps, Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps,
|
||||
ZnxLayout, alloc_aligned,
|
||||
alloc_aligned, znx_base::ZnxLayout,
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
@@ -614,7 +587,7 @@ mod tests {
|
||||
for row_i in 0..vpmat_rows {
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
module.fill_uniform(log_base2k, &mut a, 0, vpmat_size, &mut source);
|
||||
module.vec_znx_dft(&mut a_dft, &a);
|
||||
module.vec_znx_dft(&mut a_dft, 0, &a, 0);
|
||||
module.vmp_prepare_row(&mut vmpmat_0, &a.raw(), row_i, &mut tmp_bytes);
|
||||
|
||||
// Checks that prepare(mat_znx_dft, a) = prepare_dft(mat_znx_dft, a_dft)
|
||||
@@ -627,7 +600,7 @@ mod tests {
|
||||
|
||||
// Checks that a_big = extract(prepare_dft(mat_znx_dft, a_dft), b_big)
|
||||
module.vmp_extract_row(&mut b_big, &vmpmat_0, row_i);
|
||||
module.vec_znx_idft(&mut a_big, &a_dft, &mut tmp_bytes);
|
||||
module.vec_znx_idft(&mut a_big, 0, &a_dft, 0, &mut tmp_bytes);
|
||||
assert_eq!(a_big.raw(), b_big.raw());
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::{Backend, Module, VecZnx, ZnxLayout};
|
||||
use crate::{Backend, Module, VecZnx, znx_base::ZnxLayout};
|
||||
use rand_distr::{Distribution, Normal};
|
||||
use sampling::source::Source;
|
||||
|
||||
@@ -106,7 +106,7 @@ impl<B: Backend> Sampling for Module<B> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::Sampling;
|
||||
use crate::{FFT64, Module, Stats, VecZnx, ZnxBase, ZnxLayout};
|
||||
use crate::{FFT64, Module, Stats, VecZnx, VecZnxOps, znx_base::ZnxLayout};
|
||||
use sampling::source::Source;
|
||||
|
||||
#[test]
|
||||
@@ -120,7 +120,7 @@ mod tests {
|
||||
let zero: Vec<i64> = vec![0; n];
|
||||
let one_12_sqrt: f64 = 0.28867513459481287;
|
||||
(0..cols).for_each(|col_i| {
|
||||
let mut a: VecZnx = VecZnx::new(&module, cols, size);
|
||||
let mut a: VecZnx = module.new_vec_znx(cols, size);
|
||||
module.fill_uniform(log_base2k, &mut a, col_i, size, &mut source);
|
||||
(0..cols).for_each(|col_j| {
|
||||
if col_j != col_i {
|
||||
@@ -154,7 +154,7 @@ mod tests {
|
||||
let zero: Vec<i64> = vec![0; n];
|
||||
let k_f64: f64 = (1u64 << log_k as u64) as f64;
|
||||
(0..cols).for_each(|col_i| {
|
||||
let mut a: VecZnx = VecZnx::new(&module, cols, size);
|
||||
let mut a: VecZnx = module.new_vec_znx(cols, size);
|
||||
module.add_normal(log_base2k, &mut a, col_i, log_k, &mut source, sigma, bound);
|
||||
(0..cols).for_each(|col_j| {
|
||||
if col_j != col_i {
|
||||
|
||||
@@ -2,9 +2,8 @@ use std::marker::PhantomData;
|
||||
|
||||
use crate::ffi::svp::{self, svp_ppol_t};
|
||||
use crate::ffi::vec_znx_dft::vec_znx_dft_t;
|
||||
use crate::{Backend, FFT64, Module, VecZnx, VecZnxDft, ZnxLayout, assert_alignement};
|
||||
|
||||
use crate::{ZnxInfos, alloc_aligned, cast_mut};
|
||||
use crate::znx_base::{ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize};
|
||||
use crate::{Backend, FFT64, Module, VecZnx, VecZnxDft, alloc_aligned, assert_alignement, cast_mut};
|
||||
use rand::seq::SliceRandom;
|
||||
use rand_core::RngCore;
|
||||
use rand_distr::{Distribution, weighted::WeightedIndex};
|
||||
@@ -118,11 +117,14 @@ impl Scalar {
|
||||
|
||||
pub fn as_vec_znx(&self) -> VecZnx {
|
||||
VecZnx {
|
||||
inner: ZnxBase {
|
||||
n: self.n,
|
||||
rows: 1,
|
||||
cols: 1,
|
||||
size: 1,
|
||||
data: Vec::new(),
|
||||
ptr: self.ptr,
|
||||
ptr: self.ptr as *mut u8,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -159,7 +161,7 @@ pub struct ScalarZnxDft<B: Backend> {
|
||||
/// An [SvpPPol] an be seen as a [VecZnxDft] of one limb.
|
||||
impl ScalarZnxDft<FFT64> {
|
||||
pub fn new(module: &Module<FFT64>) -> Self {
|
||||
module.new_svp_ppol()
|
||||
module.new_scalar_znx_dft()
|
||||
}
|
||||
|
||||
/// Returns the ring degree of the [SvpPPol].
|
||||
@@ -168,14 +170,14 @@ impl ScalarZnxDft<FFT64> {
|
||||
}
|
||||
|
||||
pub fn bytes_of(module: &Module<FFT64>) -> usize {
|
||||
module.bytes_of_svp_ppol()
|
||||
module.bytes_of_scalar_znx_dft()
|
||||
}
|
||||
|
||||
pub fn from_bytes(module: &Module<FFT64>, bytes: &mut [u8]) -> Self {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(bytes.as_ptr());
|
||||
assert_eq!(bytes.len(), module.bytes_of_svp_ppol());
|
||||
assert_eq!(bytes.len(), module.bytes_of_scalar_znx_dft());
|
||||
}
|
||||
unsafe {
|
||||
Self {
|
||||
@@ -191,7 +193,7 @@ impl ScalarZnxDft<FFT64> {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
assert_eq!(tmp_bytes.len(), module.bytes_of_svp_ppol());
|
||||
assert_eq!(tmp_bytes.len(), module.bytes_of_scalar_znx_dft());
|
||||
}
|
||||
Self {
|
||||
n: module.n(),
|
||||
@@ -209,33 +211,33 @@ impl ScalarZnxDft<FFT64> {
|
||||
|
||||
pub trait ScalarZnxDftOps<B: Backend> {
|
||||
/// Allocates a new [SvpPPol].
|
||||
fn new_svp_ppol(&self) -> ScalarZnxDft<B>;
|
||||
fn new_scalar_znx_dft(&self) -> ScalarZnxDft<B>;
|
||||
|
||||
/// Returns the minimum number of bytes necessary to allocate
|
||||
/// a new [SvpPPol] through [SvpPPol::from_bytes] ro.
|
||||
fn bytes_of_svp_ppol(&self) -> usize;
|
||||
fn bytes_of_scalar_znx_dft(&self) -> usize;
|
||||
|
||||
/// Allocates a new [SvpPPol] from an array of bytes.
|
||||
/// The array of bytes is owned by the [SvpPPol].
|
||||
/// The method will panic if bytes.len() < [SvpPPolOps::bytes_of_svp_ppol]
|
||||
fn new_svp_ppol_from_bytes(&self, bytes: &mut [u8]) -> ScalarZnxDft<B>;
|
||||
fn new_scalar_znx_dft_from_bytes(&self, bytes: &mut [u8]) -> ScalarZnxDft<B>;
|
||||
|
||||
/// Allocates a new [SvpPPol] from an array of bytes.
|
||||
/// The array of bytes is borrowed by the [SvpPPol].
|
||||
/// The method will panic if bytes.len() < [SvpPPolOps::bytes_of_svp_ppol]
|
||||
fn new_svp_ppol_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> ScalarZnxDft<B>;
|
||||
fn new_scalar_znx_dft_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> ScalarZnxDft<B>;
|
||||
|
||||
/// Prepares a [crate::Scalar] for a [SvpPPolOps::svp_apply_dft].
|
||||
fn svp_prepare(&self, svp_ppol: &mut ScalarZnxDft<B>, a: &Scalar);
|
||||
|
||||
/// Applies the [SvpPPol] x [VecZnxDft] product, where each limb of
|
||||
/// the [VecZnxDft] is multiplied with [SvpPPol].
|
||||
fn svp_apply_dft(&self, c: &mut VecZnxDft<B>, a: &ScalarZnxDft<B>, b: &VecZnx, b_col: usize);
|
||||
fn svp_apply_dft(&self, res: &mut VecZnxDft<B>, res_col: usize, a: &ScalarZnxDft<B>, b: &VecZnx, b_col: usize);
|
||||
}
|
||||
|
||||
impl ScalarZnxDftOps<FFT64> for Module<FFT64> {
|
||||
fn new_svp_ppol(&self) -> ScalarZnxDft<FFT64> {
|
||||
let mut data: Vec<u8> = alloc_aligned::<u8>(self.bytes_of_svp_ppol());
|
||||
fn new_scalar_znx_dft(&self) -> ScalarZnxDft<FFT64> {
|
||||
let mut data: Vec<u8> = alloc_aligned::<u8>(self.bytes_of_scalar_znx_dft());
|
||||
let ptr: *mut u8 = data.as_mut_ptr();
|
||||
ScalarZnxDft::<FFT64> {
|
||||
data: data,
|
||||
@@ -245,28 +247,28 @@ impl ScalarZnxDftOps<FFT64> for Module<FFT64> {
|
||||
}
|
||||
}
|
||||
|
||||
fn bytes_of_svp_ppol(&self) -> usize {
|
||||
fn bytes_of_scalar_znx_dft(&self) -> usize {
|
||||
unsafe { svp::bytes_of_svp_ppol(self.ptr) as usize }
|
||||
}
|
||||
|
||||
fn new_svp_ppol_from_bytes(&self, bytes: &mut [u8]) -> ScalarZnxDft<FFT64> {
|
||||
fn new_scalar_znx_dft_from_bytes(&self, bytes: &mut [u8]) -> ScalarZnxDft<FFT64> {
|
||||
ScalarZnxDft::from_bytes(self, bytes)
|
||||
}
|
||||
|
||||
fn new_svp_ppol_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> ScalarZnxDft<FFT64> {
|
||||
fn new_scalar_znx_dft_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> ScalarZnxDft<FFT64> {
|
||||
ScalarZnxDft::from_bytes_borrow(self, tmp_bytes)
|
||||
}
|
||||
|
||||
fn svp_prepare(&self, svp_ppol: &mut ScalarZnxDft<FFT64>, a: &Scalar) {
|
||||
unsafe { svp::svp_prepare(self.ptr, svp_ppol.ptr as *mut svp_ppol_t, a.as_ptr()) }
|
||||
fn svp_prepare(&self, res: &mut ScalarZnxDft<FFT64>, a: &Scalar) {
|
||||
unsafe { svp::svp_prepare(self.ptr, res.ptr as *mut svp_ppol_t, a.as_ptr()) }
|
||||
}
|
||||
|
||||
fn svp_apply_dft(&self, c: &mut VecZnxDft<FFT64>, a: &ScalarZnxDft<FFT64>, b: &VecZnx, b_col: usize) {
|
||||
fn svp_apply_dft(&self, res: &mut VecZnxDft<FFT64>, res_col: usize, a: &ScalarZnxDft<FFT64>, b: &VecZnx, b_col: usize) {
|
||||
unsafe {
|
||||
svp::svp_apply_dft(
|
||||
self.ptr,
|
||||
c.ptr as *mut vec_znx_dft_t,
|
||||
c.size() as u64,
|
||||
res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
|
||||
res.size() as u64,
|
||||
a.ptr as *const svp_ppol_t,
|
||||
b.at_ptr(b_col, 0),
|
||||
b.size() as u64,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use crate::{Encoding, VecZnx, ZnxInfos};
|
||||
use crate::znx_base::ZnxInfos;
|
||||
use crate::{Encoding, VecZnx};
|
||||
use rug::Float;
|
||||
use rug::float::Round;
|
||||
use rug::ops::{AddAssignRound, DivAssignRound, SubAssignRound};
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
use crate::Backend;
|
||||
use crate::ZnxBase;
|
||||
use crate::Module;
|
||||
use crate::assert_alignement;
|
||||
use crate::cast_mut;
|
||||
use crate::ffi::znx;
|
||||
use crate::switch_degree;
|
||||
use crate::{Module, ZnxBasics, ZnxInfos, ZnxLayout};
|
||||
use crate::{alloc_aligned, assert_alignement};
|
||||
use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, ZnxSliceSize, switch_degree};
|
||||
use std::cmp::min;
|
||||
|
||||
pub const VEC_ZNX_ROWS: usize = 1;
|
||||
|
||||
/// [VecZnx] represents collection of contiguously stacked vector of small norm polynomials of
|
||||
/// Zn\[X\] with [i64] coefficients.
|
||||
/// A [VecZnx] is composed of multiple Zn\[X\] polynomials stored in a single contiguous array
|
||||
@@ -17,56 +18,54 @@ use std::cmp::min;
|
||||
/// Given 3 polynomials (a, b, c) of Zn\[X\], each with 4 columns, then the memory
|
||||
/// layout is: `[a0, b0, c0, a1, b1, c1, a2, b2, c2, a3, b3, c3]`, where ai, bi, ci
|
||||
/// are small polynomials of Zn\[X\].
|
||||
#[derive(Clone)]
|
||||
pub struct VecZnx {
|
||||
/// Polynomial degree.
|
||||
pub n: usize,
|
||||
|
||||
/// The number of polynomials
|
||||
pub cols: usize,
|
||||
|
||||
/// The number of size per polynomial (a.k.a small polynomials).
|
||||
pub size: usize,
|
||||
|
||||
/// Polynomial coefficients, as a contiguous array. Each col is equally spaced by n.
|
||||
pub data: Vec<i64>,
|
||||
|
||||
/// Pointer to data (data can be enpty if [VecZnx] borrows space instead of owning it).
|
||||
pub ptr: *mut i64,
|
||||
pub inner: ZnxBase,
|
||||
}
|
||||
|
||||
impl ZnxInfos for VecZnx {
|
||||
fn n(&self) -> usize {
|
||||
self.n
|
||||
impl GetZnxBase for VecZnx {
|
||||
fn znx(&self) -> &ZnxBase {
|
||||
&self.inner
|
||||
}
|
||||
|
||||
fn rows(&self) -> usize {
|
||||
1
|
||||
fn znx_mut(&mut self) -> &mut ZnxBase {
|
||||
&mut self.inner
|
||||
}
|
||||
}
|
||||
|
||||
fn cols(&self) -> usize {
|
||||
self.cols
|
||||
}
|
||||
impl ZnxInfos for VecZnx {}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
self.size
|
||||
impl ZnxSliceSize for VecZnx {
|
||||
fn sl(&self) -> usize {
|
||||
self.cols() * self.n()
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxLayout for VecZnx {
|
||||
type Scalar = i64;
|
||||
|
||||
fn as_ptr(&self) -> *const Self::Scalar {
|
||||
self.ptr
|
||||
}
|
||||
|
||||
fn as_mut_ptr(&mut self) -> *mut Self::Scalar {
|
||||
self.ptr
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxBasics for VecZnx {}
|
||||
|
||||
impl<B: Backend> ZnxAlloc<B> for VecZnx {
|
||||
type Scalar = i64;
|
||||
|
||||
fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnx {
|
||||
debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, size));
|
||||
VecZnx {
|
||||
inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_ROWS, cols, size, bytes),
|
||||
}
|
||||
}
|
||||
|
||||
fn bytes_of(module: &Module<B>, _rows: usize, cols: usize, size: usize) -> usize {
|
||||
debug_assert_eq!(
|
||||
_rows, VEC_ZNX_ROWS,
|
||||
"rows != {} not supported for VecZnx",
|
||||
VEC_ZNX_ROWS
|
||||
);
|
||||
module.n() * cols * size * size_of::<Self::Scalar>()
|
||||
}
|
||||
}
|
||||
|
||||
/// Copies the coefficients of `a` on the receiver.
|
||||
/// Copy is done with the minimum size matching both backing arrays.
|
||||
/// Panics if the cols do not match.
|
||||
@@ -78,80 +77,6 @@ pub fn copy_vec_znx_from(b: &mut VecZnx, a: &VecZnx) {
|
||||
data_b[..size].copy_from_slice(&data_a[..size])
|
||||
}
|
||||
|
||||
impl<B: Backend> ZnxBase<B> for VecZnx {
|
||||
type Scalar = i64;
|
||||
|
||||
/// Allocates a new [VecZnx] composed of #size polynomials of Z\[X\].
|
||||
fn new(module: &Module<B>, cols: usize, size: usize) -> Self {
|
||||
let n: usize = module.n();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(n > 0);
|
||||
assert!(n & (n - 1) == 0);
|
||||
assert!(cols > 0);
|
||||
assert!(size > 0);
|
||||
}
|
||||
let mut data: Vec<i64> = alloc_aligned::<i64>(Self::bytes_of(module, cols, size));
|
||||
let ptr: *mut i64 = data.as_mut_ptr();
|
||||
Self {
|
||||
n: n,
|
||||
cols: cols,
|
||||
size: size,
|
||||
data: data,
|
||||
ptr: ptr,
|
||||
}
|
||||
}
|
||||
|
||||
fn bytes_of(module: &Module<B>, cols: usize, size: usize) -> usize {
|
||||
module.n() * cols * size * size_of::<i64>()
|
||||
}
|
||||
|
||||
/// Returns a new struct implementing [VecZnx] with the provided data as backing array.
|
||||
///
|
||||
/// The struct will take ownership of buf[..[Self::bytes_of]]
|
||||
///
|
||||
/// User must ensure that data is properly alligned and that
|
||||
/// the size of data is equal to [Self::bytes_of].
|
||||
fn from_bytes(module: &Module<B>, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
|
||||
let n: usize = module.n();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(cols > 0);
|
||||
assert!(size > 0);
|
||||
assert_eq!(bytes.len(), Self::bytes_of(module, cols, size));
|
||||
assert_alignement(bytes.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
let bytes_i64: &mut [i64] = cast_mut::<u8, i64>(bytes);
|
||||
let ptr: *mut i64 = bytes_i64.as_mut_ptr();
|
||||
Self {
|
||||
n: n,
|
||||
cols: cols,
|
||||
size: size,
|
||||
data: Vec::from_raw_parts(ptr, bytes.len(), bytes.len()),
|
||||
ptr: ptr,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn from_bytes_borrow(module: &Module<B>, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(cols > 0);
|
||||
assert!(size > 0);
|
||||
assert!(bytes.len() >= Self::bytes_of(module, cols, size));
|
||||
assert_alignement(bytes.as_ptr());
|
||||
}
|
||||
Self {
|
||||
n: module.n(),
|
||||
cols: cols,
|
||||
size: size,
|
||||
data: Vec::new(),
|
||||
ptr: bytes.as_mut_ptr() as *mut i64,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl VecZnx {
|
||||
/// Truncates the precision of the [VecZnx] by k bits.
|
||||
///
|
||||
@@ -165,11 +90,12 @@ impl VecZnx {
|
||||
}
|
||||
|
||||
if !self.borrowing() {
|
||||
self.data
|
||||
self.inner
|
||||
.data
|
||||
.truncate(self.n() * self.cols() * (self.size() - k / log_base2k));
|
||||
}
|
||||
|
||||
self.size -= k / log_base2k;
|
||||
self.inner.size -= k / log_base2k;
|
||||
|
||||
let k_rem: usize = k % log_base2k;
|
||||
|
||||
@@ -185,10 +111,6 @@ impl VecZnx {
|
||||
copy_vec_znx_from(self, a);
|
||||
}
|
||||
|
||||
pub fn borrowing(&self) -> bool {
|
||||
self.data.len() == 0
|
||||
}
|
||||
|
||||
pub fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]) {
|
||||
normalize(log_base2k, self, carry)
|
||||
}
|
||||
|
||||
@@ -1,115 +1,71 @@
|
||||
use crate::ffi::vec_znx_big;
|
||||
use crate::{Backend, FFT64, Module, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, alloc_aligned, assert_alignement};
|
||||
use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, ZnxSliceSize};
|
||||
use crate::{Backend, FFT64, Module, NTT120};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
const VEC_ZNX_BIG_ROWS: usize = 1;
|
||||
|
||||
pub struct VecZnxBig<B: Backend> {
|
||||
pub data: Vec<u8>,
|
||||
pub ptr: *mut u8,
|
||||
pub n: usize,
|
||||
pub cols: usize,
|
||||
pub size: usize,
|
||||
pub inner: ZnxBase,
|
||||
pub _marker: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl ZnxBasics for VecZnxBig<FFT64> {}
|
||||
|
||||
impl<B: Backend> ZnxBase<B> for VecZnxBig<B> {
|
||||
type Scalar = u8;
|
||||
|
||||
fn new(module: &Module<B>, cols: usize, size: usize) -> Self {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(cols > 0);
|
||||
assert!(size > 0);
|
||||
}
|
||||
let mut data: Vec<Self::Scalar> = alloc_aligned(Self::bytes_of(module, cols, size));
|
||||
let ptr: *mut Self::Scalar = data.as_mut_ptr();
|
||||
Self {
|
||||
data: data,
|
||||
ptr: ptr,
|
||||
n: module.n(),
|
||||
cols: cols,
|
||||
size: size,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
impl<B: Backend> GetZnxBase for VecZnxBig<B> {
|
||||
fn znx(&self) -> &ZnxBase {
|
||||
&self.inner
|
||||
}
|
||||
|
||||
fn bytes_of(module: &Module<B>, cols: usize, size: usize) -> usize {
|
||||
unsafe { vec_znx_big::bytes_of_vec_znx_big(module.ptr, size as u64) as usize * cols }
|
||||
}
|
||||
|
||||
/// Returns a new [VecZnxBig] with the provided data as backing array.
|
||||
/// User must ensure that data is properly alligned and that
|
||||
/// the size of data is at least equal to [Module::bytes_of_vec_znx_big].
|
||||
fn from_bytes(module: &Module<B>, cols: usize, size: usize, bytes: &mut [Self::Scalar]) -> Self {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(cols > 0);
|
||||
assert!(size > 0);
|
||||
assert_eq!(bytes.len(), Self::bytes_of(module, cols, size));
|
||||
assert_alignement(bytes.as_ptr())
|
||||
};
|
||||
unsafe {
|
||||
Self {
|
||||
data: Vec::from_raw_parts(bytes.as_mut_ptr(), bytes.len(), bytes.len()),
|
||||
ptr: bytes.as_mut_ptr(),
|
||||
n: module.n(),
|
||||
cols: cols,
|
||||
size: size,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn from_bytes_borrow(module: &Module<B>, cols: usize, size: usize, bytes: &mut [Self::Scalar]) -> Self {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(cols > 0);
|
||||
assert!(size > 0);
|
||||
assert_eq!(bytes.len(), Self::bytes_of(module, cols, size));
|
||||
assert_alignement(bytes.as_ptr());
|
||||
}
|
||||
Self {
|
||||
data: Vec::new(),
|
||||
ptr: bytes.as_mut_ptr(),
|
||||
n: module.n(),
|
||||
cols: cols,
|
||||
size: size,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
fn znx_mut(&mut self) -> &mut ZnxBase {
|
||||
&mut self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ZnxInfos for VecZnxBig<B> {
|
||||
fn n(&self) -> usize {
|
||||
self.n
|
||||
impl<B: Backend> ZnxInfos for VecZnxBig<B> {}
|
||||
|
||||
impl<B: Backend> ZnxAlloc<B> for VecZnxBig<B> {
|
||||
type Scalar = u8;
|
||||
|
||||
fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
|
||||
VecZnxBig {
|
||||
inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_BIG_ROWS, cols, size, bytes),
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
fn cols(&self) -> usize {
|
||||
self.cols
|
||||
}
|
||||
|
||||
fn rows(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
self.size
|
||||
fn bytes_of(module: &Module<B>, _rows: usize, cols: usize, size: usize) -> usize {
|
||||
debug_assert_eq!(
|
||||
_rows, VEC_ZNX_BIG_ROWS,
|
||||
"rows != {} not supported for VecZnxBig",
|
||||
VEC_ZNX_BIG_ROWS
|
||||
);
|
||||
unsafe { vec_znx_big::bytes_of_vec_znx_big(module.ptr, size as u64) as usize * cols }
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxLayout for VecZnxBig<FFT64> {
|
||||
type Scalar = i64;
|
||||
}
|
||||
|
||||
fn as_ptr(&self) -> *const Self::Scalar {
|
||||
self.ptr as *const Self::Scalar
|
||||
}
|
||||
impl ZnxLayout for VecZnxBig<NTT120> {
|
||||
type Scalar = i128;
|
||||
}
|
||||
|
||||
fn as_mut_ptr(&mut self) -> *mut Self::Scalar {
|
||||
self.ptr as *mut Self::Scalar
|
||||
impl ZnxBasics for VecZnxBig<FFT64> {}
|
||||
|
||||
impl ZnxSliceSize for VecZnxBig<FFT64> {
|
||||
fn sl(&self) -> usize {
|
||||
self.n()
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxSliceSize for VecZnxBig<NTT120> {
|
||||
fn sl(&self) -> usize {
|
||||
self.n() * 4
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxBasics for VecZnxBig<NTT120> {}
|
||||
|
||||
impl VecZnxBig<FFT64> {
|
||||
pub fn print(&self, n: usize) {
|
||||
(0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n]));
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use crate::ffi::vec_znx;
|
||||
use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxOps, ZnxBase, ZnxInfos, ZnxLayout, assert_alignement};
|
||||
use crate::ffi::vec_znx_big::{self, vec_znx_big_t};
|
||||
use crate::znx_base::{ZnxAlloc, ZnxInfos, ZnxLayout, ZnxSliceSize};
|
||||
use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxOps, assert_alignement};
|
||||
|
||||
pub trait VecZnxBigOps<B: Backend> {
|
||||
/// Allocates a vector Z[X]/(X^N+1) that stores not normalized values.
|
||||
@@ -17,7 +18,7 @@ pub trait VecZnxBigOps<B: Backend> {
|
||||
///
|
||||
/// # Panics
|
||||
/// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
|
||||
fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxBig<B>;
|
||||
fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBig<B>;
|
||||
|
||||
/// Returns a new [VecZnxBig] with the provided bytes array as backing array.
|
||||
///
|
||||
@@ -41,74 +42,74 @@ pub trait VecZnxBigOps<B: Backend> {
|
||||
fn vec_znx_big_add(
|
||||
&self,
|
||||
res: &mut VecZnxBig<B>,
|
||||
col_res: usize,
|
||||
res_col: usize,
|
||||
a: &VecZnxBig<B>,
|
||||
col_a: usize,
|
||||
a_col: usize,
|
||||
b: &VecZnxBig<B>,
|
||||
col_b: usize,
|
||||
b_col: usize,
|
||||
);
|
||||
|
||||
/// Adds `a` to `b` and stores the result on `b`.
|
||||
fn vec_znx_big_add_inplace(&self, res: &mut VecZnxBig<B>, col_res: usize, a: &VecZnxBig<B>, col_a: usize);
|
||||
fn vec_znx_big_add_inplace(&self, res: &mut VecZnxBig<B>, res_col: usize, a: &VecZnxBig<B>, a_col: usize);
|
||||
|
||||
/// Adds `a` to `b` and stores the result on `c`.
|
||||
fn vec_znx_big_add_small(
|
||||
&self,
|
||||
res: &mut VecZnxBig<B>,
|
||||
col_res: usize,
|
||||
a: &VecZnx,
|
||||
col_a: usize,
|
||||
b: &VecZnxBig<B>,
|
||||
col_b: usize,
|
||||
res_col: usize,
|
||||
a: &VecZnxBig<B>,
|
||||
a_col: usize,
|
||||
b: &VecZnx,
|
||||
b_col: usize,
|
||||
);
|
||||
|
||||
/// Adds `a` to `b` and stores the result on `b`.
|
||||
fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig<B>, col_res: usize, a: &VecZnx, col_a: usize);
|
||||
fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig<B>, res_col: usize, a: &VecZnx, a_col: usize);
|
||||
|
||||
/// Subtracts `a` to `b` and stores the result on `c`.
|
||||
fn vec_znx_big_sub(
|
||||
&self,
|
||||
res: &mut VecZnxBig<B>,
|
||||
col_res: usize,
|
||||
res_col: usize,
|
||||
a: &VecZnxBig<B>,
|
||||
col_a: usize,
|
||||
a_col: usize,
|
||||
b: &VecZnxBig<B>,
|
||||
col_b: usize,
|
||||
b_col: usize,
|
||||
);
|
||||
|
||||
/// Subtracts `a` to `b` and stores the result on `b`.
|
||||
fn vec_znx_big_sub_ab_inplace(&self, res: &mut VecZnxBig<B>, col_res: usize, a: &VecZnxBig<B>, col_a: usize);
|
||||
fn vec_znx_big_sub_ab_inplace(&self, res: &mut VecZnxBig<B>, res_col: usize, a: &VecZnxBig<B>, a_col: usize);
|
||||
|
||||
/// Subtracts `b` to `a` and stores the result on `b`.
|
||||
fn vec_znx_big_sub_ba_inplace(&self, res: &mut VecZnxBig<B>, col_res: usize, a: &VecZnxBig<B>, col_a: usize);
|
||||
fn vec_znx_big_sub_ba_inplace(&self, res: &mut VecZnxBig<B>, res_col: usize, a: &VecZnxBig<B>, a_col: usize);
|
||||
|
||||
/// Subtracts `b` to `a` and stores the result on `c`.
|
||||
fn vec_znx_big_sub_small_a(
|
||||
&self,
|
||||
res: &mut VecZnxBig<B>,
|
||||
col_res: usize,
|
||||
res_col: usize,
|
||||
a: &VecZnx,
|
||||
col_a: usize,
|
||||
a_col: usize,
|
||||
b: &VecZnxBig<B>,
|
||||
col_b: usize,
|
||||
b_col: usize,
|
||||
);
|
||||
|
||||
/// Subtracts `a` to `b` and stores the result on `b`.
|
||||
fn vec_znx_big_sub_small_a_inplace(&self, res: &mut VecZnxBig<B>, col_res: usize, a: &VecZnx, col_a: usize);
|
||||
fn vec_znx_big_sub_small_a_inplace(&self, res: &mut VecZnxBig<B>, res_col: usize, a: &VecZnx, a_col: usize);
|
||||
|
||||
/// Subtracts `b` to `a` and stores the result on `c`.
|
||||
fn vec_znx_big_sub_small_b(
|
||||
&self,
|
||||
res: &mut VecZnxBig<B>,
|
||||
col_res: usize,
|
||||
res_col: usize,
|
||||
a: &VecZnxBig<B>,
|
||||
col_a: usize,
|
||||
a_col: usize,
|
||||
b: &VecZnx,
|
||||
col_b: usize,
|
||||
b_col: usize,
|
||||
);
|
||||
|
||||
/// Subtracts `b` to `a` and stores the result on `b`.
|
||||
fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig<B>, col_res: usize, a: &VecZnx, col_a: usize);
|
||||
fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig<B>, res_col: usize, a: &VecZnx, a_col: usize);
|
||||
|
||||
/// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_normalize].
|
||||
fn vec_znx_big_normalize_tmp_bytes(&self) -> usize;
|
||||
@@ -123,44 +124,44 @@ pub trait VecZnxBigOps<B: Backend> {
|
||||
&self,
|
||||
log_base2k: usize,
|
||||
res: &mut VecZnx,
|
||||
col_res: usize,
|
||||
res_col: usize,
|
||||
a: &VecZnxBig<B>,
|
||||
col_a: usize,
|
||||
a_col: usize,
|
||||
tmp_bytes: &mut [u8],
|
||||
);
|
||||
|
||||
/// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`.
|
||||
fn vec_znx_big_automorphism(&self, k: i64, res: &mut VecZnxBig<B>, col_res: usize, a: &VecZnxBig<B>, col_a: usize);
|
||||
fn vec_znx_big_automorphism(&self, k: i64, res: &mut VecZnxBig<B>, res_col: usize, a: &VecZnxBig<B>, a_col: usize);
|
||||
|
||||
/// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`.
|
||||
fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig<B>, col_a: usize);
|
||||
fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig<B>, a_col: usize);
|
||||
}
|
||||
|
||||
impl VecZnxBigOps<FFT64> for Module<FFT64> {
|
||||
fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBig<FFT64> {
|
||||
VecZnxBig::new(self, cols, size)
|
||||
VecZnxBig::new(self, 1, cols, size)
|
||||
}
|
||||
|
||||
fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxBig<FFT64> {
|
||||
VecZnxBig::from_bytes(self, cols, size, bytes)
|
||||
fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBig<FFT64> {
|
||||
VecZnxBig::from_bytes(self, 1, cols, size, bytes)
|
||||
}
|
||||
|
||||
fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig<FFT64> {
|
||||
VecZnxBig::from_bytes_borrow(self, cols, size, tmp_bytes)
|
||||
VecZnxBig::from_bytes_borrow(self, 1, cols, size, tmp_bytes)
|
||||
}
|
||||
|
||||
fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize {
|
||||
VecZnxBig::bytes_of(self, cols, size)
|
||||
VecZnxBig::bytes_of(self, 1, cols, size)
|
||||
}
|
||||
|
||||
fn vec_znx_big_add(
|
||||
&self,
|
||||
res: &mut VecZnxBig<FFT64>,
|
||||
col_res: usize,
|
||||
res_col: usize,
|
||||
a: &VecZnxBig<FFT64>,
|
||||
col_a: usize,
|
||||
a_col: usize,
|
||||
b: &VecZnxBig<FFT64>,
|
||||
col_b: usize,
|
||||
b_col: usize,
|
||||
) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
@@ -170,36 +171,33 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_add(
|
||||
vec_znx_big::vec_znx_big_add(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(col_res, 0),
|
||||
res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t,
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(col_a, 0),
|
||||
a.at_ptr(a_col * res.size(), 0) as *const vec_znx_big_t,
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(col_b, 0),
|
||||
b.at_ptr(b_col * res.size(), 0) as *const vec_znx_big_t,
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_add_inplace(&self, res: &mut VecZnxBig<FFT64>, col_res: usize, a: &VecZnxBig<FFT64>, col_a: usize) {
|
||||
fn vec_znx_big_add_inplace(&self, res: &mut VecZnxBig<FFT64>, res_col: usize, a: &VecZnxBig<FFT64>, a_col: usize) {
|
||||
unsafe {
|
||||
let res_ptr: *mut VecZnxBig<FFT64> = res as *mut VecZnxBig<FFT64>;
|
||||
Self::vec_znx_big_add(self, &mut *res_ptr, col_res, a, col_a, &*res_ptr, col_res);
|
||||
Self::vec_znx_big_add(self, &mut *res_ptr, res_col, a, a_col, &*res_ptr, res_col);
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_sub(
|
||||
&self,
|
||||
res: &mut VecZnxBig<FFT64>,
|
||||
col_res: usize,
|
||||
res_col: usize,
|
||||
a: &VecZnxBig<FFT64>,
|
||||
col_a: usize,
|
||||
a_col: usize,
|
||||
b: &VecZnxBig<FFT64>,
|
||||
col_b: usize,
|
||||
b_col: usize,
|
||||
) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
@@ -209,43 +207,40 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
vec_znx_big::vec_znx_big_sub(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(col_res, 0),
|
||||
res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t,
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(col_a, 0),
|
||||
a.at_ptr(a_col * res.size(), 0) as *const vec_znx_big_t,
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(col_b, 0),
|
||||
b.at_ptr(b_col * res.size(), 0) as *const vec_znx_big_t,
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_sub_ab_inplace(&self, res: &mut VecZnxBig<FFT64>, col_res: usize, a: &VecZnxBig<FFT64>, col_a: usize) {
|
||||
fn vec_znx_big_sub_ab_inplace(&self, res: &mut VecZnxBig<FFT64>, res_col: usize, a: &VecZnxBig<FFT64>, a_col: usize) {
|
||||
unsafe {
|
||||
let res_ptr: *mut VecZnxBig<FFT64> = res as *mut VecZnxBig<FFT64>;
|
||||
Self::vec_znx_big_sub(self, &mut *res_ptr, col_res, a, col_a, &*res_ptr, col_res);
|
||||
Self::vec_znx_big_sub(self, &mut *res_ptr, res_col, a, a_col, &*res_ptr, res_col);
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_sub_ba_inplace(&self, res: &mut VecZnxBig<FFT64>, col_res: usize, a: &VecZnxBig<FFT64>, col_a: usize) {
|
||||
fn vec_znx_big_sub_ba_inplace(&self, res: &mut VecZnxBig<FFT64>, res_col: usize, a: &VecZnxBig<FFT64>, a_col: usize) {
|
||||
unsafe {
|
||||
let res_ptr: *mut VecZnxBig<FFT64> = res as *mut VecZnxBig<FFT64>;
|
||||
Self::vec_znx_big_sub(self, &mut *res_ptr, col_res, &*res_ptr, col_res, a, col_a);
|
||||
Self::vec_znx_big_sub(self, &mut *res_ptr, res_col, &*res_ptr, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_sub_small_b(
|
||||
&self,
|
||||
res: &mut VecZnxBig<FFT64>,
|
||||
col_res: usize,
|
||||
res_col: usize,
|
||||
a: &VecZnxBig<FFT64>,
|
||||
col_a: usize,
|
||||
a_col: usize,
|
||||
b: &VecZnx,
|
||||
col_b: usize,
|
||||
b_col: usize,
|
||||
) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
@@ -255,36 +250,34 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
vec_znx_big::vec_znx_big_sub_small_b(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(col_res, 0),
|
||||
res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t,
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(col_a, 0),
|
||||
a.at_ptr(a_col * a.size(), 0) as *const vec_znx_big_t,
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(col_b, 0),
|
||||
b.at_ptr(b_col, 0),
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig<FFT64>, col_res: usize, a: &VecZnx, col_a: usize) {
|
||||
fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig<FFT64>, res_col: usize, a: &VecZnx, a_col: usize) {
|
||||
unsafe {
|
||||
let res_ptr: *mut VecZnxBig<FFT64> = res as *mut VecZnxBig<FFT64>;
|
||||
Self::vec_znx_big_sub_small_b(self, &mut *res_ptr, col_res, &*res_ptr, col_res, a, col_a);
|
||||
Self::vec_znx_big_sub_small_b(self, &mut *res_ptr, res_col, &*res_ptr, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_sub_small_a(
|
||||
&self,
|
||||
res: &mut VecZnxBig<FFT64>,
|
||||
col_res: usize,
|
||||
res_col: usize,
|
||||
a: &VecZnx,
|
||||
col_a: usize,
|
||||
a_col: usize,
|
||||
b: &VecZnxBig<FFT64>,
|
||||
col_b: usize,
|
||||
b_col: usize,
|
||||
) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
@@ -294,36 +287,34 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
vec_znx_big::vec_znx_big_sub_small_a(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(col_res, 0),
|
||||
res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t,
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(col_a, 0),
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(col_b, 0),
|
||||
b.at_ptr(b_col * b.size(), 0) as *const vec_znx_big_t,
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_sub_small_a_inplace(&self, res: &mut VecZnxBig<FFT64>, col_res: usize, a: &VecZnx, col_a: usize) {
|
||||
fn vec_znx_big_sub_small_a_inplace(&self, res: &mut VecZnxBig<FFT64>, res_col: usize, a: &VecZnx, a_col: usize) {
|
||||
unsafe {
|
||||
let res_ptr: *mut VecZnxBig<FFT64> = res as *mut VecZnxBig<FFT64>;
|
||||
Self::vec_znx_big_sub_small_a(self, &mut *res_ptr, col_res, a, col_a, &*res_ptr, col_res);
|
||||
Self::vec_znx_big_sub_small_a(self, &mut *res_ptr, res_col, a, a_col, &*res_ptr, res_col);
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_add_small(
|
||||
&self,
|
||||
res: &mut VecZnxBig<FFT64>,
|
||||
col_res: usize,
|
||||
a: &VecZnx,
|
||||
col_a: usize,
|
||||
b: &VecZnxBig<FFT64>,
|
||||
col_b: usize,
|
||||
res_col: usize,
|
||||
a: &VecZnxBig<FFT64>,
|
||||
a_col: usize,
|
||||
b: &VecZnx,
|
||||
b_col: usize,
|
||||
) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
@@ -333,25 +324,23 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_add(
|
||||
vec_znx_big::vec_znx_big_add_small(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(col_res, 0),
|
||||
res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t,
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(col_a, 0),
|
||||
a.at_ptr(a_col * a.size(), 0) as *const vec_znx_big_t,
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(col_b, 0),
|
||||
b.at_ptr(b_col, 0),
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig<FFT64>, col_res: usize, a: &VecZnx, a_col: usize) {
|
||||
fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig<FFT64>, res_col: usize, a: &VecZnx, a_col: usize) {
|
||||
unsafe {
|
||||
let res_ptr: *mut VecZnxBig<FFT64> = res as *mut VecZnxBig<FFT64>;
|
||||
Self::vec_znx_big_add_small(self, &mut *res_ptr, col_res, a, a_col, &*res_ptr, col_res);
|
||||
Self::vec_znx_big_add_small(self, &mut *res_ptr, res_col, &*res_ptr, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -363,9 +352,9 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
|
||||
&self,
|
||||
log_base2k: usize,
|
||||
res: &mut VecZnx,
|
||||
col_res: usize,
|
||||
res_col: usize,
|
||||
a: &VecZnxBig<FFT64>,
|
||||
col_a: usize,
|
||||
a_col: usize,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
#[cfg(debug_assertions)]
|
||||
@@ -376,44 +365,41 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_normalize_base2k(
|
||||
vec_znx_big::vec_znx_big_normalize_base2k(
|
||||
self.ptr,
|
||||
log_base2k as u64,
|
||||
res.at_mut_ptr(col_res, 0),
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(col_a, 0),
|
||||
a.at_ptr(a_col * a.size(), 0) as *const vec_znx_big_t,
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_automorphism(&self, k: i64, res: &mut VecZnxBig<FFT64>, col_res: usize, a: &VecZnxBig<FFT64>, col_a: usize) {
|
||||
fn vec_znx_big_automorphism(&self, k: i64, res: &mut VecZnxBig<FFT64>, res_col: usize, a: &VecZnxBig<FFT64>, a_col: usize) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_automorphism(
|
||||
vec_znx_big::vec_znx_big_automorphism(
|
||||
self.ptr,
|
||||
k,
|
||||
res.at_mut_ptr(col_res, 0),
|
||||
res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t,
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(col_a, 0),
|
||||
a.at_ptr(a_col * a.size(), 0) as *const vec_znx_big_t,
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig<FFT64>, col_a: usize) {
|
||||
fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig<FFT64>, a_col: usize) {
|
||||
unsafe {
|
||||
let a_ptr: *mut VecZnxBig<FFT64> = a as *mut VecZnxBig<FFT64>;
|
||||
Self::vec_znx_big_automorphism(self, k, &mut *a_ptr, col_a, &*a_ptr, col_a);
|
||||
Self::vec_znx_big_automorphism(self, k, &mut *a_ptr, a_col, &*a_ptr, a_col);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,129 +1,54 @@
|
||||
use crate::ffi::vec_znx_big::vec_znx_big_t;
|
||||
use crate::ffi::vec_znx_dft;
|
||||
use crate::ffi::vec_znx_dft::{bytes_of_vec_znx_dft, vec_znx_dft_t};
|
||||
use crate::{Backend, FFT64, Module, VecZnxBig, ZnxBase, ZnxInfos, ZnxLayout, assert_alignement};
|
||||
use crate::{VecZnx, alloc_aligned};
|
||||
use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize};
|
||||
use crate::{Backend, FFT64, Module, VecZnxBig};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
const VEC_ZNX_DFT_ROWS: usize = 1;
|
||||
|
||||
pub struct VecZnxDft<B: Backend> {
|
||||
pub data: Vec<u8>,
|
||||
pub ptr: *mut u8,
|
||||
pub n: usize,
|
||||
pub cols: usize,
|
||||
pub size: usize,
|
||||
inner: ZnxBase,
|
||||
pub _marker: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> ZnxBase<B> for VecZnxDft<B> {
|
||||
impl<B: Backend> GetZnxBase for VecZnxDft<B> {
|
||||
fn znx(&self) -> &ZnxBase {
|
||||
&self.inner
|
||||
}
|
||||
|
||||
fn znx_mut(&mut self) -> &mut ZnxBase {
|
||||
&mut self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ZnxInfos for VecZnxDft<B> {}
|
||||
|
||||
impl<B: Backend> ZnxAlloc<B> for VecZnxDft<B> {
|
||||
type Scalar = u8;
|
||||
|
||||
fn new(module: &Module<B>, cols: usize, size: usize) -> Self {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(cols > 0);
|
||||
assert!(size > 0);
|
||||
}
|
||||
let mut data: Vec<Self::Scalar> = alloc_aligned(Self::bytes_of(module, cols, size));
|
||||
let ptr: *mut Self::Scalar = data.as_mut_ptr();
|
||||
Self {
|
||||
data: data,
|
||||
ptr: ptr,
|
||||
n: module.n(),
|
||||
size: size,
|
||||
cols: cols,
|
||||
fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
|
||||
VecZnxDft {
|
||||
inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_DFT_ROWS, cols, size, bytes),
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
fn bytes_of(module: &Module<B>, cols: usize, size: usize) -> usize {
|
||||
unsafe { bytes_of_vec_znx_dft(module.ptr, size as u64) as usize * cols }
|
||||
}
|
||||
|
||||
/// Returns a new [VecZnxDft] with the provided data as backing array.
|
||||
/// User must ensure that data is properly alligned and that
|
||||
/// the size of data is at least equal to [Module::bytes_of_vec_znx_dft].
|
||||
fn from_bytes(module: &Module<B>, cols: usize, size: usize, bytes: &mut [Self::Scalar]) -> Self {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(cols > 0);
|
||||
assert!(size > 0);
|
||||
assert_eq!(bytes.len(), Self::bytes_of(module, cols, size));
|
||||
assert_alignement(bytes.as_ptr())
|
||||
}
|
||||
unsafe {
|
||||
Self {
|
||||
data: Vec::from_raw_parts(bytes.as_mut_ptr(), bytes.len(), bytes.len()),
|
||||
ptr: bytes.as_mut_ptr(),
|
||||
n: module.n(),
|
||||
cols: cols,
|
||||
size: size,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn from_bytes_borrow(module: &Module<B>, cols: usize, size: usize, bytes: &mut [Self::Scalar]) -> Self {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(cols > 0);
|
||||
assert!(size > 0);
|
||||
assert_eq!(bytes.len(), Self::bytes_of(module, cols, size));
|
||||
assert_alignement(bytes.as_ptr());
|
||||
}
|
||||
Self {
|
||||
data: Vec::new(),
|
||||
ptr: bytes.as_mut_ptr(),
|
||||
n: module.n(),
|
||||
cols: cols,
|
||||
size: size,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> VecZnxDft<B> {
|
||||
/// Cast a [VecZnxDft] into a [VecZnxBig].
|
||||
/// The returned [VecZnxBig] shares the backing array
|
||||
/// with the original [VecZnxDft].
|
||||
pub fn as_vec_znx_big(&mut self) -> VecZnxBig<B> {
|
||||
VecZnxBig::<B> {
|
||||
data: Vec::new(),
|
||||
ptr: self.ptr,
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: self.size,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ZnxInfos for VecZnxDft<B> {
|
||||
fn n(&self) -> usize {
|
||||
self.n
|
||||
}
|
||||
|
||||
fn rows(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn cols(&self) -> usize {
|
||||
self.cols
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
self.size
|
||||
fn bytes_of(module: &Module<B>, _rows: usize, cols: usize, size: usize) -> usize {
|
||||
debug_assert_eq!(
|
||||
_rows, VEC_ZNX_DFT_ROWS,
|
||||
"rows != {} not supported for VecZnxDft",
|
||||
VEC_ZNX_DFT_ROWS
|
||||
);
|
||||
unsafe { vec_znx_dft::bytes_of_vec_znx_dft(module.ptr, size as u64) as usize * cols }
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxLayout for VecZnxDft<FFT64> {
|
||||
type Scalar = f64;
|
||||
}
|
||||
|
||||
fn as_ptr(&self) -> *const Self::Scalar {
|
||||
self.ptr as *const Self::Scalar
|
||||
}
|
||||
|
||||
fn as_mut_ptr(&mut self) -> *mut Self::Scalar {
|
||||
self.ptr as *mut Self::Scalar
|
||||
impl ZnxSliceSize for VecZnxDft<FFT64> {
|
||||
fn sl(&self) -> usize {
|
||||
self.n()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -133,225 +58,21 @@ impl VecZnxDft<FFT64> {
|
||||
}
|
||||
}
|
||||
|
||||
pub trait VecZnxDftOps<B: Backend> {
|
||||
/// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space.
|
||||
fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDft<B>;
|
||||
|
||||
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
|
||||
///
|
||||
/// Behavior: takes ownership of the backing array.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `cols`: the number of cols of the [VecZnxDft].
|
||||
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft].
|
||||
///
|
||||
/// # Panics
|
||||
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
|
||||
fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft<B>;
|
||||
|
||||
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
|
||||
///
|
||||
/// Behavior: the backing array is only borrowed.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `cols`: the number of cols of the [VecZnxDft].
|
||||
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft].
|
||||
///
|
||||
/// # Panics
|
||||
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
|
||||
fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft<B>;
|
||||
|
||||
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `cols`: the number of cols of the [VecZnxDft].
|
||||
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft].
|
||||
///
|
||||
/// # Panics
|
||||
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
|
||||
fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize;
|
||||
|
||||
/// Returns the minimum number of bytes necessary to allocate
|
||||
/// a new [VecZnxDft] through [VecZnxDft::from_bytes].
|
||||
fn vec_znx_idft_tmp_bytes(&self) -> usize;
|
||||
|
||||
/// b <- IDFT(a), uses a as scratch space.
|
||||
fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig<B>, a: &mut VecZnxDft<B>);
|
||||
|
||||
fn vec_znx_idft(&self, b: &mut VecZnxBig<B>, a: &VecZnxDft<B>, tmp_bytes: &mut [u8]);
|
||||
|
||||
fn vec_znx_dft(&self, b: &mut VecZnxDft<B>, a: &VecZnx);
|
||||
|
||||
fn vec_znx_dft_automorphism(&self, k: i64, b: &mut VecZnxDft<B>, a: &VecZnxDft<B>);
|
||||
|
||||
fn vec_znx_dft_automorphism_inplace(&self, k: i64, a: &mut VecZnxDft<B>, tmp_bytes: &mut [u8]);
|
||||
|
||||
fn vec_znx_dft_automorphism_tmp_bytes(&self) -> usize;
|
||||
}
|
||||
|
||||
impl VecZnxDftOps<FFT64> for Module<FFT64> {
|
||||
fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDft<FFT64> {
|
||||
VecZnxDft::<FFT64>::new(&self, cols, size)
|
||||
impl<B: Backend> VecZnxDft<B> {
|
||||
/// Cast a [VecZnxDft] into a [VecZnxBig].
|
||||
/// The returned [VecZnxBig] shares the backing array
|
||||
/// with the original [VecZnxDft].
|
||||
pub fn alias_as_vec_znx_big(&mut self) -> VecZnxBig<B> {
|
||||
VecZnxBig::<B> {
|
||||
inner: ZnxBase {
|
||||
data: Vec::new(),
|
||||
ptr: self.ptr(),
|
||||
n: self.n(),
|
||||
rows: self.rows(),
|
||||
cols: self.cols(),
|
||||
size: self.size(),
|
||||
},
|
||||
_marker: PhantomData,
|
||||
}
|
||||
|
||||
fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxDft<FFT64> {
|
||||
VecZnxDft::from_bytes(self, cols, size, tmp_bytes)
|
||||
}
|
||||
|
||||
fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxDft<FFT64> {
|
||||
VecZnxDft::from_bytes_borrow(self, cols, size, tmp_bytes)
|
||||
}
|
||||
|
||||
fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize {
|
||||
VecZnxDft::bytes_of(&self, cols, size)
|
||||
}
|
||||
|
||||
fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig<FFT64>, a: &mut VecZnxDft<FFT64>) {
|
||||
unsafe {
|
||||
vec_znx_dft::vec_znx_idft_tmp_a(
|
||||
self.ptr,
|
||||
b.ptr as *mut vec_znx_big_t,
|
||||
b.poly_count() as u64,
|
||||
a.ptr as *mut vec_znx_dft_t,
|
||||
a.poly_count() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_idft_tmp_bytes(&self) -> usize {
|
||||
unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(self.ptr) as usize }
|
||||
}
|
||||
|
||||
/// b <- DFT(a)
|
||||
///
|
||||
/// # Panics
|
||||
/// If b.cols < a_cols
|
||||
fn vec_znx_dft(&self, b: &mut VecZnxDft<FFT64>, a: &VecZnx) {
|
||||
unsafe {
|
||||
vec_znx_dft::vec_znx_dft(
|
||||
self.ptr,
|
||||
b.ptr as *mut vec_znx_dft_t,
|
||||
b.size() as u64,
|
||||
a.as_ptr(),
|
||||
a.size() as u64,
|
||||
(a.n() * a.cols()) as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes].
|
||||
fn vec_znx_idft(&self, b: &mut VecZnxBig<FFT64>, a: &VecZnxDft<FFT64>, tmp_bytes: &mut [u8]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(
|
||||
tmp_bytes.len() >= Self::vec_znx_idft_tmp_bytes(self),
|
||||
"invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_idft_tmp_bytes()={}",
|
||||
tmp_bytes.len(),
|
||||
Self::vec_znx_idft_tmp_bytes(self)
|
||||
);
|
||||
assert_alignement(tmp_bytes.as_ptr())
|
||||
}
|
||||
unsafe {
|
||||
vec_znx_dft::vec_znx_idft(
|
||||
self.ptr,
|
||||
b.ptr as *mut vec_znx_big_t,
|
||||
b.poly_count() as u64,
|
||||
a.ptr as *const vec_znx_dft_t,
|
||||
a.poly_count() as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_dft_automorphism(&self, k: i64, b: &mut VecZnxDft<FFT64>, a: &VecZnxDft<FFT64>) {
|
||||
unsafe {
|
||||
vec_znx_dft::vec_znx_dft_automorphism(
|
||||
self.ptr,
|
||||
k,
|
||||
b.ptr as *mut vec_znx_dft_t,
|
||||
b.poly_count() as u64,
|
||||
a.ptr as *const vec_znx_dft_t,
|
||||
a.poly_count() as u64,
|
||||
[0u8; 0].as_mut_ptr(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_dft_automorphism_inplace(&self, k: i64, a: &mut VecZnxDft<FFT64>, tmp_bytes: &mut [u8]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(
|
||||
tmp_bytes.len() >= Self::vec_znx_dft_automorphism_tmp_bytes(self),
|
||||
"invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_dft_automorphism_tmp_bytes()={}",
|
||||
tmp_bytes.len(),
|
||||
Self::vec_znx_dft_automorphism_tmp_bytes(self)
|
||||
);
|
||||
assert_alignement(tmp_bytes.as_ptr())
|
||||
}
|
||||
println!("{}", a.poly_count());
|
||||
unsafe {
|
||||
vec_znx_dft::vec_znx_dft_automorphism(
|
||||
self.ptr,
|
||||
k,
|
||||
a.ptr as *mut vec_znx_dft_t,
|
||||
a.poly_count() as u64,
|
||||
a.ptr as *const vec_znx_dft_t,
|
||||
a.poly_count() as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_dft_automorphism_tmp_bytes(&self) -> usize {
|
||||
unsafe { vec_znx_dft::vec_znx_dft_automorphism_tmp_bytes(self.ptr) as usize }
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{FFT64, Module, Sampling, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, ZnxLayout, alloc_aligned};
|
||||
use itertools::izip;
|
||||
use sampling::source::Source;
|
||||
|
||||
#[test]
|
||||
fn test_automorphism_dft() {
|
||||
let n: usize = 8;
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
|
||||
let size: usize = 2;
|
||||
let log_base2k: usize = 17;
|
||||
let mut a: VecZnx = module.new_vec_znx(1, size);
|
||||
let mut a_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, size);
|
||||
let mut b_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, size);
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
module.fill_uniform(log_base2k, &mut a, 0, size, &mut source);
|
||||
|
||||
let mut tmp_bytes: Vec<u8> = alloc_aligned(module.vec_znx_dft_automorphism_tmp_bytes());
|
||||
|
||||
let p: i64 = -5;
|
||||
|
||||
// a_dft <- DFT(a)
|
||||
module.vec_znx_dft(&mut a_dft, &a);
|
||||
|
||||
// a_dft <- AUTO(a_dft)
|
||||
module.vec_znx_dft_automorphism_inplace(p, &mut a_dft, &mut tmp_bytes);
|
||||
|
||||
// a <- AUTO(a)
|
||||
module.vec_znx_automorphism_inplace(p, &mut a, 0);
|
||||
|
||||
// b_dft <- DFT(AUTO(a))
|
||||
module.vec_znx_dft(&mut b_dft, &a);
|
||||
|
||||
let a_f64: &[f64] = a_dft.raw();
|
||||
let b_f64: &[f64] = b_dft.raw();
|
||||
izip!(a_f64.iter(), b_f64.iter()).for_each(|(ai, bi)| {
|
||||
assert!((ai - bi).abs() <= 1e-9, "{:+e} > 1e-9", (ai - bi).abs());
|
||||
});
|
||||
|
||||
module.free()
|
||||
}
|
||||
}
|
||||
|
||||
140
base2k/src/vec_znx_dft_ops.rs
Normal file
140
base2k/src/vec_znx_dft_ops.rs
Normal file
@@ -0,0 +1,140 @@
|
||||
use crate::ffi::vec_znx_big;
|
||||
use crate::ffi::vec_znx_dft;
|
||||
use crate::znx_base::ZnxAlloc;
|
||||
use crate::znx_base::ZnxInfos;
|
||||
use crate::znx_base::ZnxLayout;
|
||||
use crate::znx_base::ZnxSliceSize;
|
||||
use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxDft, assert_alignement};
|
||||
|
||||
pub trait VecZnxDftOps<B: Backend> {
|
||||
/// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space.
|
||||
fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDft<B>;
|
||||
|
||||
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
|
||||
///
|
||||
/// Behavior: takes ownership of the backing array.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `cols`: the number of cols of the [VecZnxDft].
|
||||
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft].
|
||||
///
|
||||
/// # Panics
|
||||
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
|
||||
fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDft<B>;
|
||||
|
||||
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
|
||||
///
|
||||
/// Behavior: the backing array is only borrowed.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `cols`: the number of cols of the [VecZnxDft].
|
||||
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft].
|
||||
///
|
||||
/// # Panics
|
||||
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
|
||||
fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft<B>;
|
||||
|
||||
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `cols`: the number of cols of the [VecZnxDft].
|
||||
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft].
|
||||
///
|
||||
/// # Panics
|
||||
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
|
||||
fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize;
|
||||
|
||||
/// Returns the minimum number of bytes necessary to allocate
|
||||
/// a new [VecZnxDft] through [VecZnxDft::from_bytes].
|
||||
fn vec_znx_idft_tmp_bytes(&self) -> usize;
|
||||
|
||||
/// b <- IDFT(a), uses a as scratch space.
|
||||
fn vec_znx_idft_tmp_a(&self, res: &mut VecZnxBig<B>, res_col: usize, a: &mut VecZnxDft<B>, a_cols: usize);
|
||||
|
||||
fn vec_znx_idft(&self, res: &mut VecZnxBig<B>, res_col: usize, a: &VecZnxDft<B>, a_col: usize, tmp_bytes: &mut [u8]);
|
||||
|
||||
fn vec_znx_dft(&self, res: &mut VecZnxDft<B>, res_col: usize, a: &VecZnx, a_col: usize);
|
||||
}
|
||||
|
||||
impl VecZnxDftOps<FFT64> for Module<FFT64> {
|
||||
fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDft<FFT64> {
|
||||
VecZnxDft::<FFT64>::new(&self, 1, cols, size)
|
||||
}
|
||||
|
||||
fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDft<FFT64> {
|
||||
VecZnxDft::from_bytes(self, 1, cols, size, bytes)
|
||||
}
|
||||
|
||||
fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft<FFT64> {
|
||||
VecZnxDft::from_bytes_borrow(self, 1, cols, size, bytes)
|
||||
}
|
||||
|
||||
fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize {
|
||||
VecZnxDft::bytes_of(&self, 1, cols, size)
|
||||
}
|
||||
|
||||
fn vec_znx_idft_tmp_a(&self, res: &mut VecZnxBig<FFT64>, res_col: usize, a: &mut VecZnxDft<FFT64>, a_col: usize) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.poly_count(), a.poly_count());
|
||||
}
|
||||
|
||||
unsafe {
|
||||
vec_znx_dft::vec_znx_idft_tmp_a(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big::vec_znx_big_t,
|
||||
res.size() as u64,
|
||||
a.at_ptr(a_col * a.size(), 0) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
a.size() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_idft_tmp_bytes(&self) -> usize {
|
||||
unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(self.ptr) as usize }
|
||||
}
|
||||
|
||||
/// b <- DFT(a)
|
||||
///
|
||||
/// # Panics
|
||||
/// If b.cols < a_cols
|
||||
fn vec_znx_dft(&self, res: &mut VecZnxDft<FFT64>, res_col: usize, a: &VecZnx, a_col: usize) {
|
||||
unsafe {
|
||||
vec_znx_dft::vec_znx_dft(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
res.size() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes].
|
||||
fn vec_znx_idft(&self, res: &mut VecZnxBig<FFT64>, res_col: usize, a: &VecZnxDft<FFT64>, a_col: usize, tmp_bytes: &mut [u8]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(
|
||||
tmp_bytes.len() >= Self::vec_znx_idft_tmp_bytes(self),
|
||||
"invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_idft_tmp_bytes()={}",
|
||||
tmp_bytes.len(),
|
||||
Self::vec_znx_idft_tmp_bytes(self)
|
||||
);
|
||||
assert_alignement(tmp_bytes.as_ptr())
|
||||
}
|
||||
unsafe {
|
||||
vec_znx_dft::vec_znx_idft(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big::vec_znx_big_t,
|
||||
res.size() as u64,
|
||||
a.at_ptr(a_col * res.size(), 0) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
a.size() as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
use crate::ffi::vec_znx;
|
||||
use crate::{Backend, Module, VecZnx, ZnxBase, ZnxInfos, ZnxLayout, assert_alignement, switch_degree};
|
||||
use crate::znx_base::{ZnxAlloc, ZnxInfos, ZnxLayout, ZnxSliceSize, switch_degree};
|
||||
use crate::{Backend, Module, VEC_ZNX_ROWS, VecZnx, assert_alignement};
|
||||
pub trait VecZnxOps {
|
||||
/// Allocates a new [VecZnx].
|
||||
///
|
||||
@@ -19,7 +20,7 @@ pub trait VecZnxOps {
|
||||
///
|
||||
/// # Panic
|
||||
/// Requires the slice of bytes to be equal to [VecZnxOps::bytes_of_vec_znx].
|
||||
fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnx;
|
||||
fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnx;
|
||||
|
||||
/// Instantiates a new [VecZnx] from a slice of bytes.
|
||||
/// The returned [VecZnx] does take ownership of the slice of bytes.
|
||||
@@ -107,19 +108,19 @@ pub trait VecZnxOps {
|
||||
|
||||
impl<B: Backend> VecZnxOps for Module<B> {
|
||||
fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnx {
|
||||
VecZnx::new(self, cols, size)
|
||||
VecZnx::new(self, VEC_ZNX_ROWS, cols, size)
|
||||
}
|
||||
|
||||
fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize {
|
||||
VecZnx::bytes_of(self, cols, size)
|
||||
VecZnx::bytes_of(self, VEC_ZNX_ROWS, cols, size)
|
||||
}
|
||||
|
||||
fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnx {
|
||||
VecZnx::from_bytes(self, cols, size, bytes)
|
||||
fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnx {
|
||||
VecZnx::from_bytes(self, VEC_ZNX_ROWS, cols, size, bytes)
|
||||
}
|
||||
|
||||
fn new_vec_znx_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnx {
|
||||
VecZnx::from_bytes_borrow(self, cols, size, tmp_bytes)
|
||||
VecZnx::from_bytes_borrow(self, VEC_ZNX_ROWS, cols, size, tmp_bytes)
|
||||
}
|
||||
|
||||
fn vec_znx_normalize_tmp_bytes(&self) -> usize {
|
||||
|
||||
@@ -1,10 +1,37 @@
|
||||
use crate::{Backend, Module, assert_alignement, cast_mut};
|
||||
use crate::{Backend, Module, alloc_aligned, assert_alignement, cast_mut};
|
||||
use itertools::izip;
|
||||
use std::cmp::min;
|
||||
|
||||
pub trait ZnxInfos {
|
||||
pub struct ZnxBase {
|
||||
/// The ring degree
|
||||
pub n: usize,
|
||||
|
||||
/// The number of rows (in the third dimension)
|
||||
pub rows: usize,
|
||||
|
||||
/// The number of polynomials
|
||||
pub cols: usize,
|
||||
|
||||
/// The number of size per polynomial (a.k.a small polynomials).
|
||||
pub size: usize,
|
||||
|
||||
/// Polynomial coefficients, as a contiguous array. Each col is equally spaced by n.
|
||||
pub data: Vec<u8>,
|
||||
|
||||
/// Pointer to data (data can be enpty if [VecZnx] borrows space instead of owning it).
|
||||
pub ptr: *mut u8,
|
||||
}
|
||||
|
||||
pub trait GetZnxBase {
|
||||
fn znx(&self) -> &ZnxBase;
|
||||
fn znx_mut(&mut self) -> &mut ZnxBase;
|
||||
}
|
||||
|
||||
pub trait ZnxInfos: GetZnxBase {
|
||||
/// Returns the ring degree of the polynomials.
|
||||
fn n(&self) -> usize;
|
||||
fn n(&self) -> usize {
|
||||
self.znx().n
|
||||
}
|
||||
|
||||
/// Returns the base two logarithm of the ring dimension of the polynomials.
|
||||
fn log_n(&self) -> usize {
|
||||
@@ -12,41 +39,104 @@ pub trait ZnxInfos {
|
||||
}
|
||||
|
||||
/// Returns the number of rows.
|
||||
fn rows(&self) -> usize;
|
||||
|
||||
fn rows(&self) -> usize {
|
||||
self.znx().rows
|
||||
}
|
||||
/// Returns the number of polynomials in each row.
|
||||
fn cols(&self) -> usize;
|
||||
fn cols(&self) -> usize {
|
||||
self.znx().cols
|
||||
}
|
||||
|
||||
/// Returns the number of size per polynomial.
|
||||
fn size(&self) -> usize;
|
||||
fn size(&self) -> usize {
|
||||
self.znx().size
|
||||
}
|
||||
|
||||
fn data(&self) -> &[u8] {
|
||||
&self.znx().data
|
||||
}
|
||||
|
||||
fn ptr(&self) -> *mut u8 {
|
||||
self.znx().ptr
|
||||
}
|
||||
|
||||
/// Returns the total number of small polynomials.
|
||||
fn poly_count(&self) -> usize {
|
||||
self.rows() * self.cols() * self.size()
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ZnxSliceSize {
|
||||
/// Returns the slice size, which is the offset between
|
||||
/// two size of the same column.
|
||||
fn sl(&self) -> usize {
|
||||
self.n() * self.cols()
|
||||
fn sl(&self) -> usize;
|
||||
}
|
||||
|
||||
impl ZnxBase {
|
||||
pub fn from_bytes(n: usize, rows: usize, cols: usize, size: usize, mut bytes: Vec<u8>) -> Self {
|
||||
let mut res: Self = Self::from_bytes_borrow(n, rows, cols, size, &mut bytes);
|
||||
res.data = bytes;
|
||||
res
|
||||
}
|
||||
|
||||
pub fn from_bytes_borrow(n: usize, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(n & (n - 1), 0, "n must be a power of two");
|
||||
assert!(n > 0, "n must be greater than 0");
|
||||
assert!(rows > 0, "rows must be greater than 0");
|
||||
assert!(cols > 0, "cols must be greater than 0");
|
||||
assert!(size > 0, "size must be greater than 0");
|
||||
}
|
||||
Self {
|
||||
n: n,
|
||||
rows: rows,
|
||||
cols: cols,
|
||||
size: size,
|
||||
data: Vec::new(),
|
||||
ptr: bytes.as_mut_ptr(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ZnxBase<B: Backend> {
|
||||
pub trait ZnxAlloc<B: Backend>
|
||||
where
|
||||
Self: Sized + ZnxInfos,
|
||||
{
|
||||
type Scalar;
|
||||
fn new(module: &Module<B>, cols: usize, size: usize) -> Self;
|
||||
fn from_bytes(module: &Module<B>, cols: usize, size: usize, bytes: &mut [u8]) -> Self;
|
||||
fn from_bytes_borrow(module: &Module<B>, cols: usize, size: usize, bytes: &mut [u8]) -> Self;
|
||||
fn bytes_of(module: &Module<B>, cols: usize, size: usize) -> usize;
|
||||
fn new(module: &Module<B>, rows: usize, cols: usize, size: usize) -> Self {
|
||||
let bytes: Vec<u8> = alloc_aligned::<u8>(Self::bytes_of(module, rows, cols, size));
|
||||
Self::from_bytes(module, rows, cols, size, bytes)
|
||||
}
|
||||
|
||||
fn from_bytes(module: &Module<B>, rows: usize, cols: usize, size: usize, mut bytes: Vec<u8>) -> Self {
|
||||
let mut res: Self = Self::from_bytes_borrow(module, rows, cols, size, &mut bytes);
|
||||
res.znx_mut().data = bytes;
|
||||
res
|
||||
}
|
||||
|
||||
fn from_bytes_borrow(module: &Module<B>, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self;
|
||||
|
||||
fn bytes_of(module: &Module<B>, rows: usize, cols: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
pub trait ZnxLayout: ZnxInfos {
|
||||
type Scalar;
|
||||
|
||||
/// Returns true if the receiver is only borrowing the data.
|
||||
fn borrowing(&self) -> bool {
|
||||
self.znx().data.len() == 0
|
||||
}
|
||||
|
||||
/// Returns a non-mutable pointer to the underlying coefficients array.
|
||||
fn as_ptr(&self) -> *const Self::Scalar;
|
||||
fn as_ptr(&self) -> *const Self::Scalar {
|
||||
self.znx().ptr as *const Self::Scalar
|
||||
}
|
||||
|
||||
/// Returns a mutable pointer to the underlying coefficients array.
|
||||
fn as_mut_ptr(&mut self) -> *mut Self::Scalar;
|
||||
fn as_mut_ptr(&mut self) -> *mut Self::Scalar {
|
||||
self.znx_mut().ptr as *mut Self::Scalar
|
||||
}
|
||||
|
||||
/// Returns a non-mutable reference to the entire underlying coefficient array.
|
||||
fn raw(&self) -> &[Self::Scalar] {
|
||||
@@ -1,5 +1,3 @@
|
||||
cargo-features = ["edition2024"]
|
||||
|
||||
[package]
|
||||
name = "rlwe"
|
||||
version = "0.1.0"
|
||||
|
||||
@@ -20,7 +20,7 @@ pub struct AutomorphismKey {
|
||||
}
|
||||
|
||||
pub fn automorphis_key_new_tmp_bytes(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> usize {
|
||||
module.bytes_of_scalar() + module.bytes_of_svp_ppol() + encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q)
|
||||
module.bytes_of_scalar() + module.bytes_of_scalar_znx_dft() + encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q)
|
||||
}
|
||||
|
||||
impl Parameters {
|
||||
@@ -103,10 +103,10 @@ impl AutomorphismKey {
|
||||
tmp_bytes: &mut [u8],
|
||||
) -> Vec<Self> {
|
||||
let (sk_auto_bytes, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_scalar());
|
||||
let (sk_out_bytes, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_svp_ppol());
|
||||
let (sk_out_bytes, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_scalar_znx_dft());
|
||||
|
||||
let sk_auto: Scalar = module.new_scalar_from_bytes_borrow(sk_auto_bytes);
|
||||
let mut sk_out: ScalarZnxDft = module.new_svp_ppol_from_bytes_borrow(sk_out_bytes);
|
||||
let mut sk_out: ScalarZnxDft = module.new_scalar_znx_dft_from_bytes_borrow(sk_out_bytes);
|
||||
|
||||
let mut keys: Vec<AutomorphismKey> = Vec::new();
|
||||
|
||||
@@ -116,7 +116,7 @@ impl AutomorphismKey {
|
||||
let p_inv: i64 = module.galois_element_inv(*pi);
|
||||
|
||||
module.vec_znx_automorphism(p_inv, &mut sk_auto.as_vec_znx(), &sk.0.as_vec_znx());
|
||||
module.svp_prepare(&mut sk_out, &sk_auto);
|
||||
module.scalar_znx_dft_prepare(&mut sk_out, &sk_auto);
|
||||
encrypt_grlwe_sk(
|
||||
module, &mut value, &sk.0, &sk_out, source_xa, source_xe, sigma, tmp_bytes,
|
||||
);
|
||||
|
||||
@@ -20,7 +20,7 @@ impl SecretKey {
|
||||
}
|
||||
|
||||
pub fn prepare(&self, module: &Module, sk_ppol: &mut ScalarZnxDft) {
|
||||
module.svp_prepare(sk_ppol, &self.0)
|
||||
module.scalar_znx_dft_prepare(sk_ppol, &self.0)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user