Applied discussed changes, everything working, but still to discuss

This commit is contained in:
Jean-Philippe Bossuat
2025-05-01 10:33:19 +02:00
parent 4e6fce3458
commit ca5e6d46c9
14 changed files with 710 additions and 508 deletions

View File

@@ -4,5 +4,8 @@
"plaintext": false, "plaintext": false,
"markdown": false, "markdown": false,
"scminput": false "scminput": false
},
"files.associations": {
"random": "c"
} }
} }

View File

@@ -13,7 +13,8 @@ fn main() {
let log_scale: usize = msg_size * log_base2k - 5; let log_scale: usize = msg_size * log_base2k - 5;
let module: Module<FFT64> = Module::<FFT64>::new(n); let module: Module<FFT64> = Module::<FFT64>::new(n);
let mut carry: Vec<u8> = alloc_aligned(module.vec_znx_big_normalize_tmp_bytes()); let mut tmp_bytes_norm: Vec<u8> = alloc_aligned(module.vec_znx_big_normalize_tmp_bytes());
let mut tmp_bytes_dft = alloc_aligned(module.bytes_of_vec_znx_dft(1, ct_size));
let seed: [u8; 32] = [0; 32]; let seed: [u8; 32] = [0; 32];
let mut source: Source = Source::new(seed); let mut source: Source = Source::new(seed);
@@ -38,9 +39,10 @@ fn main() {
module.fill_uniform(log_base2k, &mut ct, 1, ct_size, &mut source); module.fill_uniform(log_base2k, &mut ct, 1, ct_size, &mut source);
// Scratch space for DFT values // Scratch space for DFT values
let mut buf_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft( let mut buf_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft_from_bytes_borrow(
1, // Number of columns 1, // Number of columns
ct.size(), // Number of polynomials per column ct.size(), // Number of polynomials per column
&mut tmp_bytes_dft,
); );
// Applies DFT(ct[1]) * DFT(s) // Applies DFT(ct[1]) * DFT(s)
@@ -68,7 +70,7 @@ fn main() {
want.iter_mut() want.iter_mut()
.for_each(|x| *x = source.next_u64n(16, 15) as i64); .for_each(|x| *x = source.next_u64n(16, 15) as i64);
m.encode_vec_i64(0, log_base2k, log_scale, &want, 4); m.encode_vec_i64(0, log_base2k, log_scale, &want, 4);
m.normalize(log_base2k, &mut carry); m.normalize(log_base2k, 0, &mut tmp_bytes_norm);
// m - BIG(ct[1] * s) // m - BIG(ct[1] * s)
module.vec_znx_big_sub_small_a_inplace( module.vec_znx_big_sub_small_a_inplace(
@@ -81,9 +83,12 @@ fn main() {
// Normalizes back to VecZnx // Normalizes back to VecZnx
// ct[0] <- m - BIG(c1 * s) // ct[0] <- m - BIG(c1 * s)
module.vec_znx_big_normalize( module.vec_znx_big_normalize(
log_base2k, &mut ct, 0, // Selects the first column of ct (ct[0]) log_base2k,
&buf_big, 0, // Selects the first column of buf_big &mut ct,
&mut carry, 0, // Selects the first column of ct (ct[0])
&buf_big,
0, // Selects the first column of buf_big
&mut tmp_bytes_norm,
); );
// Add noise to ct[0] // Add noise to ct[0]
@@ -120,7 +125,7 @@ fn main() {
// m + e <- BIG(ct[1] * s + ct[0]) // m + e <- BIG(ct[1] * s + ct[0])
let mut res: VecZnx = module.new_vec_znx(1, ct_size); let mut res: VecZnx = module.new_vec_znx(1, ct_size);
module.vec_znx_big_normalize(log_base2k, &mut res, 0, &buf_big, 0, &mut carry); module.vec_znx_big_normalize(log_base2k, &mut res, 0, &buf_big, 0, &mut tmp_bytes_norm);
// have = m * 2^{log_scale} + e // have = m * 2^{log_scale} + e
let mut have: Vec<i64> = vec![i64::default(); n]; let mut have: Vec<i64> = vec![i64::default(); n];

View File

@@ -1,59 +0,0 @@
use base2k::{
Encoding, FFT64, MatZnxDft, MatZnxDftOps, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps,
ZnxInfos, ZnxLayout, alloc_aligned,
};
fn main() {
let log_n: i32 = 5;
let n: usize = 1 << log_n;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let log_base2k: usize = 15;
let limbs_vec: usize = 5;
let log_k: usize = log_base2k * limbs_vec - 5;
let rows_mat: usize = limbs_vec;
let limbs_mat: usize = limbs_vec + 1;
// Maximum size of the byte scratch needed
let tmp_bytes: usize = module.vmp_prepare_tmp_bytes(rows_mat, 1, limbs_mat)
| module.vmp_apply_dft_tmp_bytes(limbs_vec, limbs_vec, rows_mat, limbs_mat);
let mut buf: Vec<u8> = alloc_aligned(tmp_bytes);
let mut a_values: Vec<i64> = vec![i64::default(); n];
a_values[1] = (1 << log_base2k) + 1;
let mut a: VecZnx = module.new_vec_znx(1, limbs_vec);
a.encode_vec_i64(0, log_base2k, log_k, &a_values, 32);
a.normalize(log_base2k, &mut buf);
a.print(n);
println!();
let mut mat_znx_dft: MatZnxDft<FFT64> = module.new_mat_znx_dft(rows_mat, 1, limbs_mat);
(0..a.size()).for_each(|row_i| {
let mut tmp: VecZnx = module.new_vec_znx(1, limbs_mat);
tmp.at_limb_mut(row_i)[1] = 1 as i64;
module.vmp_prepare_row(&mut mat_znx_dft, tmp.raw(), row_i, &mut buf);
});
let mut c_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, limbs_mat);
module.vmp_apply_dft(&mut c_dft, &a, &mat_znx_dft, &mut buf);
let mut c_big: VecZnxBig<FFT64> = c_dft.alias_as_vec_znx_big();
module.vec_znx_idft_tmp_a(&mut c_big, 0, &mut c_dft, 0);
let mut res: VecZnx = module.new_vec_znx(1, limbs_vec);
module.vec_znx_big_normalize(log_base2k, &mut res, 0, &c_big, 0, &mut buf);
let mut values_res: Vec<i64> = vec![i64::default(); n];
res.decode_vec_i64(0, log_base2k, log_k, &mut values_res);
res.print(n);
module.free();
println!("{:?}", values_res)
}

78
base2k/examples/vmp.rs Normal file
View File

@@ -0,0 +1,78 @@
use base2k::{
Encoding, FFT64, MatZnxDft, MatZnxDftOps, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps,
ZnxInfos, ZnxLayout, alloc_aligned,
};
fn main() {
let log_n: i32 = 5;
let n: usize = 1 << log_n;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let log_base2k: usize = 15;
let a_cols: usize = 2;
let a_size: usize = 5;
let log_k: usize = log_base2k * a_size - 5;
let mat_rows: usize = a_size;
let mat_cols_in: usize = a_cols;
let mat_cols_out: usize = 2;
let mat_size: usize = a_size + 1;
let mut tmp_bytes_vmp: Vec<u8> = alloc_aligned(
module.vmp_prepare_row_tmp_bytes(mat_cols_out, mat_size)
| module.vmp_apply_dft_tmp_bytes(
a_size,
a_size,
mat_rows,
mat_cols_in,
mat_cols_out,
mat_size,
),
);
let mut tmp_bytes_dft: Vec<u8> = alloc_aligned(module.bytes_of_vec_znx_dft(mat_cols_out, mat_size));
let mut a: VecZnx = module.new_vec_znx(a_cols, a_size);
(0..a_cols).for_each(|i| {
let mut values: Vec<i64> = vec![i64::default(); n];
values[1 + i] = (1 << log_base2k) + 1;
a.encode_vec_i64(i, log_base2k, log_k, &values, 32);
a.normalize(log_base2k, i, &mut tmp_bytes_vmp);
a.print(n, i);
println!();
});
let mut mat_znx_dft: MatZnxDft<FFT64> = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size);
(0..a.size()).for_each(|row_i| {
let mut tmp: VecZnx = module.new_vec_znx(mat_cols_out, mat_size);
(0..mat_cols_out).for_each(|j| {
tmp.at_mut(j, row_i)[1 + j] = 1 as i64;
});
(0..mat_cols_in).for_each(|j| {
module.vmp_prepare_row(&mut mat_znx_dft, row_i, j, &tmp, &mut tmp_bytes_vmp);
})
});
let mut c_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft_from_bytes_borrow(mat_cols_out, mat_size, &mut tmp_bytes_dft);
module.vmp_apply_dft(&mut c_dft, &a, &mat_znx_dft, &mut tmp_bytes_vmp);
let mut res: VecZnx = module.new_vec_znx(mat_cols_out, a_size);
let mut c_big: VecZnxBig<FFT64> = c_dft.alias_as_vec_znx_big();
(0..mat_cols_out).for_each(|i| {
module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i);
module.vec_znx_big_normalize(log_base2k, &mut res, i, &c_big, i, &mut tmp_bytes_vmp);
let mut values_res: Vec<i64> = vec![i64::default(); n];
res.decode_vec_i64(i, log_base2k, log_k, &mut values_res);
res.print(n, i);
println!();
println!("{:?}", values_res);
println!();
});
module.free();
}

View File

@@ -1,4 +1,4 @@
use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize}; use crate::znx_base::{GetZnxBase, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize};
use crate::{Backend, FFT64, Module, alloc_aligned}; use crate::{Backend, FFT64, Module, alloc_aligned};
use std::marker::PhantomData; use std::marker::PhantomData;
@@ -10,6 +10,8 @@ use std::marker::PhantomData;
/// See the trait [MatZnxDftOps] for additional information. /// See the trait [MatZnxDftOps] for additional information.
pub struct MatZnxDft<B: Backend> { pub struct MatZnxDft<B: Backend> {
pub inner: ZnxBase, pub inner: ZnxBase,
pub cols_in: usize,
pub cols_out: usize,
_marker: PhantomData<B>, _marker: PhantomData<B>,
} }
@@ -35,18 +37,54 @@ impl ZnxLayout for MatZnxDft<FFT64> {
type Scalar = f64; type Scalar = f64;
} }
impl<B: Backend> ZnxAlloc<B> for MatZnxDft<B> { impl<B: Backend> MatZnxDft<B> {
type Scalar = u8; pub fn new(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
let bytes: Vec<u8> = alloc_aligned(Self::bytes_of(module, rows, cols_in, cols_out, size));
Self::from_bytes(module, rows, cols_in, cols_out, size, bytes)
}
fn from_bytes_borrow(module: &Module<B>, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self { pub fn from_bytes(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize, mut bytes: Vec<u8>) -> Self {
let mut mat: MatZnxDft<B> = Self::from_bytes_borrow(module, rows, cols_in, cols_out, size, &mut bytes);
mat.znx_mut().data = bytes;
mat
}
pub fn from_bytes_borrow(
module: &Module<B>,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
bytes: &mut [u8],
) -> Self {
debug_assert_eq!(
bytes.len(),
Self::bytes_of(module, rows, cols_in, cols_out, size)
);
Self { Self {
inner: ZnxBase::from_bytes_borrow(module.n(), rows, cols, size, bytes), inner: ZnxBase::from_bytes_borrow(module.n(), rows, cols_out, size, bytes),
cols_in: cols_in,
cols_out: cols_out,
_marker: PhantomData, _marker: PhantomData,
} }
} }
fn bytes_of(module: &Module<B>, rows: usize, cols: usize, size: usize) -> usize { pub fn bytes_of(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
unsafe { crate::ffi::vmp::bytes_of_vmp_pmat(module.ptr, rows as u64, size as u64) as usize * cols } unsafe {
crate::ffi::vmp::bytes_of_vmp_pmat(
module.ptr,
(rows * cols_in) as u64,
(size * cols_out) as u64,
) as usize
}
}
pub fn cols_in(&self) -> usize {
self.cols_in
}
pub fn cols_out(&self) -> usize {
self.cols_out
} }
} }

View File

@@ -1,8 +1,9 @@
use crate::ffi::vec_znx_big::vec_znx_big_t;
use crate::ffi::vec_znx_dft::vec_znx_dft_t; use crate::ffi::vec_znx_dft::vec_znx_dft_t;
use crate::ffi::vmp; use crate::ffi::vmp;
use crate::znx_base::{ZnxInfos, ZnxLayout}; use crate::znx_base::{ZnxInfos, ZnxLayout};
use crate::{Backend, FFT64, MatZnxDft, Module, VecZnx, VecZnxBig, VecZnxDft, ZnxAlloc, assert_alignement}; use crate::{
Backend, FFT64, MatZnxDft, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, assert_alignement, is_aligned,
};
/// This trait implements methods for vector matrix product, /// This trait implements methods for vector matrix product,
/// that is, multiplying a [VecZnx] with a [MatZnxDft]. /// that is, multiplying a [VecZnx] with a [MatZnxDft].
@@ -13,44 +14,45 @@ pub trait MatZnxDftOps<B: Backend> {
/// ///
/// * `rows`: number of rows (number of [VecZnxDft]). /// * `rows`: number of rows (number of [VecZnxDft]).
/// * `size`: number of size (number of size of each [VecZnxDft]). /// * `size`: number of size (number of size of each [VecZnxDft]).
fn new_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> MatZnxDft<B>; fn new_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxDft<B>;
fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> usize; fn bytes_of_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
fn new_mat_znx_dft_from_bytes(&self, rows: usize, cols: usize, size: usize, bytes: Vec<u8>) -> MatZnxDft<FFT64>; fn new_mat_znx_dft_from_bytes(
&self,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
bytes: Vec<u8>,
) -> MatZnxDft<FFT64>;
fn new_mat_znx_dft_from_bytes_borrow(&self, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> MatZnxDft<FFT64>; fn new_mat_znx_dft_from_bytes_borrow(
&self,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
bytes: &mut [u8],
) -> MatZnxDft<FFT64>;
/// Returns the number of bytes needed as scratch space for [MatZnxDftOps::vmp_prepare_contiguous]. /// Returns the of bytes needed as scratch space for [MatZnxDftOps::vmp_prepare_row]
/// fn vmp_prepare_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize;
/// # Arguments
///
/// * `rows`: number of rows of the [MatZnxDft] used in [MatZnxDftOps::vmp_prepare_contiguous].
/// * `size`: number of size of the [MatZnxDft] used in [MatZnxDftOps::vmp_prepare_contiguous].
fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize, size: usize) -> usize;
/// Prepares a [MatZnxDft] from a contiguous array of [i64].
/// The helper struct [Matrix3D] can be used to contruct and populate
/// the appropriate contiguous array.
///
/// # Arguments
///
/// * `b`: [MatZnxDft] on which the values are encoded.
/// * `a`: the contiguous array of [i64] of the 3D matrix to encode on the [MatZnxDft].
/// * `buf`: scratch space, the size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes].
fn vmp_prepare_contiguous(&self, b: &mut MatZnxDft<B>, a: &[i64], buf: &mut [u8]);
/// Prepares the ith-row of [MatZnxDft] from a [VecZnx]. /// Prepares the ith-row of [MatZnxDft] from a [VecZnx].
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `b`: [MatZnxDft] on which the values are encoded. /// * `b`: [MatZnxDft] on which the values are encoded.
/// * `a`: the vector of [VecZnx] to encode on the [MatZnxDft]. /// * `row_i`: the row of the [MatZnxDft] to prepare.
/// * `row_i`: the index of the row to prepare. /// * `a`: the [VecZnx] to encode on the i-th row of the [MatZnxDft].
/// * `buf`: scratch space, the size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes]. /// * `buf`: scratch space, the size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes].
/// ///
/// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes]. /// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes].
fn vmp_prepare_row(&self, b: &mut MatZnxDft<B>, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]); fn vmp_prepare_row(&self, b: &mut MatZnxDft<B>, b_row: usize, b_col_in: usize, a: &VecZnx, tmp_bytes: &mut [u8]);
/// Returns the of bytes needed as scratch space for [MatZnxDftOps::vmp_extract_row]
fn vmp_extract_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize;
/// Extracts the ith-row of [MatZnxDft] into a [VecZnxBig]. /// Extracts the ith-row of [MatZnxDft] into a [VecZnxBig].
/// ///
@@ -59,7 +61,15 @@ pub trait MatZnxDftOps<B: Backend> {
/// * `b`: the [VecZnxBig] to on which to extract the row of the [MatZnxDft]. /// * `b`: the [VecZnxBig] to on which to extract the row of the [MatZnxDft].
/// * `a`: [MatZnxDft] on which the values are encoded. /// * `a`: [MatZnxDft] on which the values are encoded.
/// * `row_i`: the index of the row to extract. /// * `row_i`: the index of the row to extract.
fn vmp_extract_row(&self, b: &mut VecZnxBig<B>, a: &MatZnxDft<B>, row_i: usize); fn vmp_extract_row(
&self,
log_base2k: usize,
b: &mut VecZnx,
a: &MatZnxDft<B>,
b_row: usize,
b_col_in: usize,
tmp_bytes: &mut [u8],
);
/// Prepares the ith-row of [MatZnxDft] from a [VecZnxDft]. /// Prepares the ith-row of [MatZnxDft] from a [VecZnxDft].
/// ///
@@ -70,7 +80,7 @@ pub trait MatZnxDftOps<B: Backend> {
/// * `row_i`: the index of the row to prepare. /// * `row_i`: the index of the row to prepare.
/// ///
/// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes]. /// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes].
fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft<B>, a: &VecZnxDft<B>, row_i: usize); fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft<B>, b_row: usize, b_col_in: usize, a: &VecZnxDft<B>);
/// Extracts the ith-row of [MatZnxDft] into a [VecZnxDft]. /// Extracts the ith-row of [MatZnxDft] into a [VecZnxDft].
/// ///
@@ -79,7 +89,7 @@ pub trait MatZnxDftOps<B: Backend> {
/// * `b`: the [VecZnxDft] to on which to extract the row of the [MatZnxDft]. /// * `b`: the [VecZnxDft] to on which to extract the row of the [MatZnxDft].
/// * `a`: [MatZnxDft] on which the values are encoded. /// * `a`: [MatZnxDft] on which the values are encoded.
/// * `row_i`: the index of the row to extract. /// * `row_i`: the index of the row to extract.
fn vmp_extract_row_dft(&self, b: &mut VecZnxDft<B>, row_i: usize, a: &MatZnxDft<B>); fn vmp_extract_row_dft(&self, b: &mut VecZnxDft<B>, a: &MatZnxDft<B>, a_row: usize, a_col_in: usize);
/// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft]. /// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft].
/// ///
@@ -89,7 +99,15 @@ pub trait MatZnxDftOps<B: Backend> {
/// * `a_size`: number of size of the input [VecZnx]. /// * `a_size`: number of size of the input [VecZnx].
/// * `rows`: number of rows of the input [MatZnxDft]. /// * `rows`: number of rows of the input [MatZnxDft].
/// * `size`: number of size of the input [MatZnxDft]. /// * `size`: number of size of the input [MatZnxDft].
fn vmp_apply_dft_tmp_bytes(&self, c_size: usize, a_size: usize, b_rows: usize, b_size: usize) -> usize; fn vmp_apply_dft_tmp_bytes(
&self,
c_size: usize,
a_size: usize,
b_rows: usize,
b_cols_in: usize,
b_cols_out: usize,
b_size: usize,
) -> usize;
/// Applies the vector matrix product [VecZnxDft] x [MatZnxDft]. /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft].
/// ///
@@ -117,32 +135,6 @@ pub trait MatZnxDftOps<B: Backend> {
/// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_tmp_bytes]. /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_tmp_bytes].
fn vmp_apply_dft(&self, c: &mut VecZnxDft<B>, a: &VecZnx, b: &MatZnxDft<B>, buf: &mut [u8]); fn vmp_apply_dft(&self, c: &mut VecZnxDft<B>, a: &VecZnx, b: &MatZnxDft<B>, buf: &mut [u8]);
/// Applies the vector matrix product [VecZnxDft] x [MatZnxDft] and adds on the receiver.
///
/// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
/// and each vector a [VecZnxDft] (row) of the [MatZnxDft].
///
/// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and
/// `j` size, the output is a [VecZnx] of `j` size.
///
/// If there is a mismatch between the dimensions the largest valid ones are used.
///
/// ```text
/// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p|
/// |h i j|
/// |k l m|
/// ```
/// where each element is a [VecZnxDft].
///
/// # Arguments
///
/// * `c`: the operand on which the output of the vector matrix product is added, as a [VecZnxDft].
/// * `a`: the left operand [VecZnx] of the vector matrix product.
/// * `b`: the right operand [MatZnxDft] of the vector matrix product.
/// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_tmp_bytes].
fn vmp_apply_dft_add(&self, c: &mut VecZnxDft<B>, a: &VecZnx, b: &MatZnxDft<B>, buf: &mut [u8]);
/// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft_to_dft]. /// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft_to_dft].
/// ///
/// # Arguments /// # Arguments
@@ -151,7 +143,17 @@ pub trait MatZnxDftOps<B: Backend> {
/// * `a_size`: number of size of the input [VecZnxDft]. /// * `a_size`: number of size of the input [VecZnxDft].
/// * `rows`: number of rows of the input [MatZnxDft]. /// * `rows`: number of rows of the input [MatZnxDft].
/// * `size`: number of size of the input [MatZnxDft]. /// * `size`: number of size of the input [MatZnxDft].
fn vmp_apply_dft_to_dft_tmp_bytes(&self, c_size: usize, a_size: usize, rows: usize, size: usize) -> usize; fn vmp_apply_dft_to_dft_tmp_bytes(
&self,
c_cols: usize,
c_size: usize,
a_cols: usize,
a_size: usize,
b_rows: usize,
b_cols_in: usize,
b_cols_out: usize,
b_size: usize,
) -> usize;
/// Applies the vector matrix product [VecZnxDft] x [MatZnxDft]. /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft].
/// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. /// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
@@ -179,308 +181,385 @@ pub trait MatZnxDftOps<B: Backend> {
/// * `b`: the right operand [MatZnxDft] of the vector matrix product. /// * `b`: the right operand [MatZnxDft] of the vector matrix product.
/// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft<B>, a: &VecZnxDft<B>, b: &MatZnxDft<B>, buf: &mut [u8]); fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft<B>, a: &VecZnxDft<B>, b: &MatZnxDft<B>, buf: &mut [u8]);
/// Applies the vector matrix product [VecZnxDft] x [MatZnxDft] and adds on top of the receiver instead of overwritting it.
/// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
///
/// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
/// and each vector a [VecZnxDft] (row) of the [MatZnxDft].
///
/// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and
/// `j` size, the output is a [VecZnx] of `j` size.
///
/// If there is a mismatch between the dimensions the largest valid ones are used.
///
/// ```text
/// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p|
/// |h i j|
/// |k l m|
/// ```
/// where each element is a [VecZnxDft].
///
/// # Arguments
///
/// * `c`: the operand on which the output of the vector matrix product is added, as a [VecZnxDft].
/// * `a`: the left operand [VecZnxDft] of the vector matrix product.
/// * `b`: the right operand [MatZnxDft] of the vector matrix product.
/// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
fn vmp_apply_dft_to_dft_add(&self, c: &mut VecZnxDft<B>, a: &VecZnxDft<B>, b: &MatZnxDft<B>, buf: &mut [u8]);
/// Applies the vector matrix product [VecZnxDft] x [MatZnxDft] in place.
/// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
///
/// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
/// and each vector a [VecZnxDft] (row) of the [MatZnxDft].
///
/// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and
/// `j` size, the output is a [VecZnx] of `j` size.
///
/// If there is a mismatch between the dimensions the largest valid ones are used.
///
/// ```text
/// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p|
/// |h i j|
/// |k l m|
/// ```
/// where each element is a [VecZnxDft].
///
/// # Arguments
///
/// * `b`: the input and output of the vector matrix product, as a [VecZnxDft].
/// * `a`: the right operand [MatZnxDft] of the vector matrix product.
/// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft<B>, a: &MatZnxDft<B>, buf: &mut [u8]);
} }
impl MatZnxDftOps<FFT64> for Module<FFT64> { impl MatZnxDftOps<FFT64> for Module<FFT64> {
fn new_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> MatZnxDft<FFT64> { fn new_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxDft<FFT64> {
MatZnxDft::<FFT64>::new(self, rows, cols, size) MatZnxDft::<FFT64>::new(self, rows, cols_in, cols_out, size)
} }
fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> usize { fn bytes_of_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
MatZnxDft::<FFT64>::bytes_of(self, rows, cols, size) MatZnxDft::<FFT64>::bytes_of(self, rows, cols_in, cols_out, size)
} }
fn new_mat_znx_dft_from_bytes(&self, rows: usize, cols: usize, size: usize, bytes: Vec<u8>) -> MatZnxDft<FFT64> { fn new_mat_znx_dft_from_bytes(
MatZnxDft::<FFT64>::from_bytes(self, rows, cols, size, bytes) &self,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
bytes: Vec<u8>,
) -> MatZnxDft<FFT64> {
MatZnxDft::<FFT64>::from_bytes(self, rows, cols_in, cols_out, size, bytes)
} }
fn new_mat_znx_dft_from_bytes_borrow(&self, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> MatZnxDft<FFT64> { fn new_mat_znx_dft_from_bytes_borrow(
MatZnxDft::<FFT64>::from_bytes_borrow(self, rows, cols, size, bytes) &self,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
bytes: &mut [u8],
) -> MatZnxDft<FFT64> {
MatZnxDft::<FFT64>::from_bytes_borrow(self, rows, cols_in, cols_out, size, bytes)
} }
fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize, size: usize) -> usize { fn vmp_prepare_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize {
unsafe { vmp::vmp_prepare_tmp_bytes(self.ptr, rows as u64, (size * cols) as u64) as usize } self.bytes_of_vec_znx_dft(cols_out, size)
} }
fn vmp_prepare_contiguous(&self, b: &mut MatZnxDft<FFT64>, a: &[i64], tmp_bytes: &mut [u8]) { fn vmp_prepare_row(&self, b: &mut MatZnxDft<FFT64>, b_row: usize, b_col_in: usize, a: &VecZnx, tmp_bytes: &mut [u8]) {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_eq!(a.len(), b.n() * b.poly_count()); assert_eq!(b.n(), self.n());
assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.size())); assert_eq!(a.n(), self.n());
assert_alignement(tmp_bytes.as_ptr()); assert_eq!(
a.cols(),
b.cols_out(),
"a.cols(): {} != b.cols_out(): {}",
a.cols(),
b.cols_out()
);
assert!(
b_row < b.rows(),
"b_row: {} >= b.rows(): {}",
b_row,
b.rows()
);
assert!(
b_col_in < b.cols_in(),
"b_col_in: {} >= b.cols_in(): {}",
b_col_in,
b.cols_in()
);
assert_eq!(
b.size(),
a.size(),
"b.size(): {} != a.size(): {}",
b.size(),
a.size()
);
assert!(tmp_bytes.len() >= self.vmp_prepare_row_tmp_bytes(a.cols(), a.size()));
assert!(is_aligned(tmp_bytes.as_ptr()))
} }
unsafe {
vmp::vmp_prepare_contiguous( let cols_out: usize = a.cols();
self.ptr, let a_size: usize = a.size();
b.as_mut_ptr() as *mut vmp::vmp_pmat_t,
a.as_ptr(), let (tmp_bytes_a_dft, _) = tmp_bytes.split_at_mut(self.bytes_of_vec_znx_dft(cols_out, a_size));
b.rows() as u64,
(b.size() * b.cols()) as u64, let mut a_dft: VecZnxDft<FFT64> = self.new_vec_znx_dft_from_bytes_borrow(cols_out, a_size, tmp_bytes_a_dft);
tmp_bytes.as_mut_ptr(), (0..cols_out).for_each(|i| self.vec_znx_dft(&mut a_dft, i, &a, i));
Self::vmp_prepare_row_dft(&self, b, b_row, b_col_in, &a_dft);
}
fn vmp_extract_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize {
self.bytes_of_vec_znx_dft(cols_out, size) + self.vec_znx_big_normalize_tmp_bytes()
}
fn vmp_extract_row(
&self,
log_base2k: usize,
b: &mut VecZnx,
a: &MatZnxDft<FFT64>,
a_row: usize,
a_col_in: usize,
tmp_bytes: &mut [u8],
) {
#[cfg(debug_assertions)]
{
assert_eq!(b.n(), self.n());
assert_eq!(a.n(), self.n());
assert_eq!(
b.cols(),
a.cols_out(),
"b.cols(): {} != a.cols_out(): {}",
b.cols(),
a.cols_out()
);
assert!(
a_row < a.rows(),
"a_row: {} >= a.rows(): {}",
a_row,
a.rows()
);
assert!(
a_col_in < a.cols_in(),
"a_col_in: {} >= a.cols_in(): {}",
a_col_in,
a.cols_in()
);
assert_eq!(
b.size(),
a.size(),
"b.size(): {} != a.size(): {}",
b.size(),
a.size()
);
assert!(tmp_bytes.len() >= self.vmp_extract_row_tmp_bytes(a.cols(), a.size()));
assert!(is_aligned(tmp_bytes.as_ptr()))
}
let cols_out: usize = b.cols();
let size: usize = b.size();
let (bytes_a_dft, tmp_bytes) = tmp_bytes.split_at_mut(self.bytes_of_vec_znx_dft(cols_out, size));
let mut b_dft: VecZnxDft<FFT64> = self.new_vec_znx_dft_from_bytes_borrow(cols_out, size, bytes_a_dft);
Self::vmp_extract_row_dft(&self, &mut b_dft, a, a_row, a_col_in);
let mut b_big: VecZnxBig<FFT64> = b_dft.alias_as_vec_znx_big();
(0..cols_out).for_each(|i| {
self.vec_znx_idft_tmp_a(&mut b_big, i, &mut b_dft, i);
self.vec_znx_big_normalize(log_base2k, b, i, &b_big, i, tmp_bytes);
});
}
fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft<FFT64>, b_row: usize, b_col_in: usize, a: &VecZnxDft<FFT64>) {
#[cfg(debug_assertions)]
{
assert_eq!(b.n(), self.n());
assert_eq!(a.n(), self.n());
assert_eq!(
a.cols(),
b.cols_out(),
"a.cols(): {} != b.cols_out(): {}",
a.cols(),
b.cols_out()
);
assert!(
b_row < b.rows(),
"b_row: {} >= b.rows(): {}",
b_row,
b.rows()
);
assert!(
b_col_in < b.cols_in(),
"b_col_in: {} >= b.cols_in(): {}",
b_col_in,
b.cols_in()
);
assert_eq!(
b.size(),
a.size(),
"b.size(): {} != a.size(): {}",
b.size(),
a.size()
); );
} }
}
fn vmp_prepare_row(&self, b: &mut MatZnxDft<FFT64>, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), b.size() * self.n() * b.cols());
assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.size()));
assert_alignement(tmp_bytes.as_ptr());
}
unsafe {
vmp::vmp_prepare_row(
self.ptr,
b.as_mut_ptr() as *mut vmp::vmp_pmat_t,
a.as_ptr(),
row_i as u64,
b.rows() as u64,
(b.size() * b.cols()) as u64,
tmp_bytes.as_mut_ptr(),
);
}
}
fn vmp_extract_row(&self, b: &mut VecZnxBig<FFT64>, a: &MatZnxDft<FFT64>, row_i: usize) {
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), b.n());
assert_eq!(a.size(), b.size());
assert_eq!(a.cols(), b.cols());
}
unsafe {
vmp::vmp_extract_row(
self.ptr,
b.as_mut_ptr() as *mut vec_znx_big_t,
a.as_ptr() as *const vmp::vmp_pmat_t,
row_i as u64,
a.rows() as u64,
(a.size() * a.cols()) as u64,
);
}
}
fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft<FFT64>, a: &VecZnxDft<FFT64>, row_i: usize) {
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), b.n());
assert_eq!(a.size(), b.size());
}
unsafe { unsafe {
vmp::vmp_prepare_row_dft( vmp::vmp_prepare_row_dft(
self.ptr, self.ptr,
b.as_mut_ptr() as *mut vmp::vmp_pmat_t, b.as_mut_ptr() as *mut vmp::vmp_pmat_t,
a.as_ptr() as *const vec_znx_dft_t, a.as_ptr() as *const vec_znx_dft_t,
row_i as u64, (b_row * b.cols_in() + b_col_in) as u64,
b.rows() as u64, (b.rows() * b.cols_in()) as u64,
b.size() as u64, (b.size() * b.cols_out()) as u64,
); );
} }
} }
fn vmp_extract_row_dft(&self, b: &mut VecZnxDft<FFT64>, row_i: usize, a: &MatZnxDft<FFT64>) { fn vmp_extract_row_dft(&self, b: &mut VecZnxDft<FFT64>, a: &MatZnxDft<FFT64>, a_row: usize, a_col_in: usize) {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_eq!(a.n(), b.n()); assert_eq!(b.n(), self.n());
assert_eq!(a.size(), b.size()); assert_eq!(a.n(), self.n());
assert_eq!(
b.cols(),
a.cols_out(),
"b.cols(): {} != a.cols_out(): {}",
b.cols(),
a.cols_out()
);
assert!(
a_row < a.rows(),
"a_row: {} >= a.rows(): {}",
a_row,
a.rows()
);
assert!(
a_col_in < a.cols_in(),
"a_col_in: {} >= a.cols_in(): {}",
a_col_in,
a.cols_in()
);
assert_eq!(
b.size(),
a.size(),
"b.size(): {} != a.size(): {}",
b.size(),
a.size()
);
} }
unsafe { unsafe {
vmp::vmp_extract_row_dft( vmp::vmp_extract_row_dft(
self.ptr, self.ptr,
b.as_mut_ptr() as *mut vec_znx_dft_t, b.as_mut_ptr() as *mut vec_znx_dft_t,
a.as_ptr() as *const vmp::vmp_pmat_t, a.as_ptr() as *const vmp::vmp_pmat_t,
row_i as u64, (a_row * a.cols_in() + a_col_in) as u64,
a.rows() as u64, (a.rows() * a.cols_in()) as u64,
a.size() as u64, (a.size() * a.cols_out()) as u64,
); );
} }
} }
fn vmp_apply_dft_tmp_bytes(&self, res_size: usize, a_size: usize, b_rows: usize, b_size: usize) -> usize { fn vmp_apply_dft_tmp_bytes(
&self,
res_size: usize,
a_size: usize,
b_rows: usize,
b_cols_in: usize,
b_cols_out: usize,
b_size: usize,
) -> usize {
unsafe { unsafe {
vmp::vmp_apply_dft_tmp_bytes( vmp::vmp_apply_dft_tmp_bytes(
self.ptr, self.ptr,
res_size as u64, res_size as u64,
a_size as u64, a_size as u64,
b_rows as u64, (b_rows * b_cols_in) as u64,
b_size as u64, (b_size * b_cols_out) as u64,
) as usize ) as usize
} }
} }
fn vmp_apply_dft(&self, c: &mut VecZnxDft<FFT64>, a: &VecZnx, b: &MatZnxDft<FFT64>, tmp_bytes: &mut [u8]) { fn vmp_apply_dft(&self, c: &mut VecZnxDft<FFT64>, a: &VecZnx, b: &MatZnxDft<FFT64>, tmp_bytes: &mut [u8]) {
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size())); debug_assert!(
tmp_bytes.len()
>= self.vmp_apply_dft_tmp_bytes(
c.size(),
a.size(),
b.rows(),
b.cols_in(),
b.cols_out(),
b.size()
)
);
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_eq!(c.n(), self.n());
assert_eq!(b.n(), self.n());
assert_eq!(a.n(), self.n());
assert_eq!(
c.cols(),
b.cols_out(),
"c.cols(): {} != b.cols_out: {}",
c.cols(),
b.cols_out()
);
assert_eq!(
a.cols(),
b.cols_in(),
"a.cols(): {} != b.cols_in: {}",
a.cols(),
b.cols_in()
);
assert!(
tmp_bytes.len()
>= self.vmp_apply_dft_tmp_bytes(
c.size(),
a.size(),
b.rows(),
b.cols_in(),
b.cols_out(),
b.size()
)
);
assert_alignement(tmp_bytes.as_ptr()); assert_alignement(tmp_bytes.as_ptr());
} }
unsafe { unsafe {
vmp::vmp_apply_dft( vmp::vmp_apply_dft(
self.ptr, self.ptr,
c.as_mut_ptr() as *mut vec_znx_dft_t, c.as_mut_ptr() as *mut vec_znx_dft_t,
c.size() as u64, (c.size() * c.cols()) as u64,
a.as_ptr(), a.as_ptr(),
a.size() as u64, (a.size() * a.cols()) as u64,
(a.n() * a.cols()) as u64, a.n() as u64,
b.as_ptr() as *const vmp::vmp_pmat_t, b.as_ptr() as *const vmp::vmp_pmat_t,
b.rows() as u64, (b.rows() * b.cols_in()) as u64,
b.size() as u64, (b.size() * b.cols_out()) as u64,
tmp_bytes.as_mut_ptr(), tmp_bytes.as_mut_ptr(),
) )
} }
} }
fn vmp_apply_dft_add(&self, c: &mut VecZnxDft<FFT64>, a: &VecZnx, b: &MatZnxDft<FFT64>, tmp_bytes: &mut [u8]) { fn vmp_apply_dft_to_dft_tmp_bytes(
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size())); &self,
#[cfg(debug_assertions)] res_cols: usize,
{ res_size: usize,
assert_alignement(tmp_bytes.as_ptr()); a_size: usize,
} a_cols: usize,
unsafe { b_rows: usize,
vmp::vmp_apply_dft_add( b_cols_in: usize,
self.ptr, b_cols_out: usize,
c.as_mut_ptr() as *mut vec_znx_dft_t, b_size: usize,
c.size() as u64, ) -> usize {
a.as_ptr(),
a.size() as u64,
(a.n() * a.size()) as u64,
b.as_ptr() as *const vmp::vmp_pmat_t,
b.rows() as u64,
b.size() as u64,
tmp_bytes.as_mut_ptr(),
)
}
}
fn vmp_apply_dft_to_dft_tmp_bytes(&self, res_size: usize, a_size: usize, gct_rows: usize, gct_size: usize) -> usize {
unsafe { unsafe {
vmp::vmp_apply_dft_to_dft_tmp_bytes( vmp::vmp_apply_dft_to_dft_tmp_bytes(
self.ptr, self.ptr,
res_size as u64, (res_size * res_cols) as u64,
a_size as u64, (a_size * a_cols) as u64,
gct_rows as u64, (b_rows * b_cols_in) as u64,
gct_size as u64, (b_size * b_cols_out) as u64,
) as usize ) as usize
} }
} }
fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft<FFT64>, a: &VecZnxDft<FFT64>, b: &MatZnxDft<FFT64>, tmp_bytes: &mut [u8]) { fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft<FFT64>, a: &VecZnxDft<FFT64>, b: &MatZnxDft<FFT64>, tmp_bytes: &mut [u8]) {
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size()));
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_eq!(c.n(), self.n());
assert_eq!(b.n(), self.n());
assert_eq!(a.n(), self.n());
assert_eq!(
c.cols(),
b.cols_out(),
"c.cols(): {} != b.cols_out: {}",
c.cols(),
b.cols_out()
);
assert_eq!(
a.cols(),
b.cols_in(),
"a.cols(): {} != b.cols_in: {}",
a.cols(),
b.cols_in()
);
assert!(
tmp_bytes.len()
>= self.vmp_apply_dft_to_dft_tmp_bytes(
c.cols(),
c.size(),
a.cols(),
a.size(),
b.rows(),
b.cols_in(),
b.cols_out(),
b.size()
)
);
assert_alignement(tmp_bytes.as_ptr()); assert_alignement(tmp_bytes.as_ptr());
} }
unsafe { unsafe {
vmp::vmp_apply_dft_to_dft( vmp::vmp_apply_dft_to_dft(
self.ptr, self.ptr,
c.as_mut_ptr() as *mut vec_znx_dft_t, c.as_mut_ptr() as *mut vec_znx_dft_t,
c.size() as u64, c.poly_count() as u64,
a.as_ptr() as *const vec_znx_dft_t, a.as_ptr() as *const vec_znx_dft_t,
a.size() as u64, a.poly_count() as u64,
b.as_ptr() as *const vmp::vmp_pmat_t, b.as_ptr() as *const vmp::vmp_pmat_t,
b.rows() as u64, b.rows() as u64,
b.size() as u64, (b.size() * b.cols()) as u64,
tmp_bytes.as_mut_ptr(),
)
}
}
fn vmp_apply_dft_to_dft_add(
&self,
c: &mut VecZnxDft<FFT64>,
a: &VecZnxDft<FFT64>,
b: &MatZnxDft<FFT64>,
tmp_bytes: &mut [u8],
) {
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size()));
#[cfg(debug_assertions)]
{
assert_alignement(tmp_bytes.as_ptr());
}
unsafe {
vmp::vmp_apply_dft_to_dft_add(
self.ptr,
c.as_mut_ptr() as *mut vec_znx_dft_t,
c.size() as u64,
a.as_ptr() as *const vec_znx_dft_t,
a.size() as u64,
b.as_ptr() as *const vmp::vmp_pmat_t,
b.rows() as u64,
b.size() as u64,
tmp_bytes.as_mut_ptr(),
)
}
}
fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft<FFT64>, a: &MatZnxDft<FFT64>, tmp_bytes: &mut [u8]) {
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(b.size(), b.size(), a.rows(), a.size()));
#[cfg(debug_assertions)]
{
assert_alignement(tmp_bytes.as_ptr());
}
unsafe {
vmp::vmp_apply_dft_to_dft(
self.ptr,
b.as_mut_ptr() as *mut vec_znx_dft_t,
b.size() as u64,
b.as_ptr() as *mut vec_znx_dft_t,
b.size() as u64,
a.as_ptr() as *const vmp::vmp_pmat_t,
a.rows() as u64,
a.size() as u64,
tmp_bytes.as_mut_ptr(), tmp_bytes.as_mut_ptr(),
) )
} }
@@ -497,38 +576,52 @@ mod tests {
#[test] #[test]
fn vmp_prepare_row_dft() { fn vmp_prepare_row_dft() {
let module: Module<FFT64> = Module::<FFT64>::new(32); let module: Module<FFT64> = Module::<FFT64>::new(16);
let vpmat_rows: usize = 4;
let vpmat_size: usize = 5;
let log_base2k: usize = 8; let log_base2k: usize = 8;
let mut a: VecZnx = module.new_vec_znx(1, vpmat_size); let mat_rows: usize = 4;
let mut a_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, vpmat_size); let mat_cols_in: usize = 2;
let mut a_big: VecZnxBig<FFT64> = module.new_vec_znx_big(1, vpmat_size); let mat_cols_out: usize = 2;
let mut b_big: VecZnxBig<FFT64> = module.new_vec_znx_big(1, vpmat_size); let mat_size: usize = 5;
let mut b_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, vpmat_size); let mut a: VecZnx = module.new_vec_znx(mat_cols_out, mat_size);
let mut vmpmat_0: MatZnxDft<FFT64> = module.new_mat_znx_dft(vpmat_rows, 1, vpmat_size); let mut b: VecZnx = module.new_vec_znx(mat_cols_out, mat_size);
let mut vmpmat_1: MatZnxDft<FFT64> = module.new_mat_znx_dft(vpmat_rows, 1, vpmat_size); let mut a_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size);
let mut a_big: VecZnxBig<FFT64> = module.new_vec_znx_big(mat_cols_out, mat_size);
let mut b_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size);
let mut vmpmat_0: MatZnxDft<FFT64> = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size);
let mut vmpmat_1: MatZnxDft<FFT64> = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size);
let mut tmp_bytes: Vec<u8> = alloc_aligned(module.vmp_prepare_tmp_bytes(vpmat_rows, 1, vpmat_size)); let mut tmp_bytes: Vec<u8> =
alloc_aligned(module.vmp_prepare_row_tmp_bytes(mat_cols_out, mat_size) | module.vec_znx_big_normalize_tmp_bytes());
for row_i in 0..vpmat_rows { for col_in in 0..mat_cols_in {
let mut source: Source = Source::new([0u8; 32]); for row_i in 0..mat_rows {
module.fill_uniform(log_base2k, &mut a, 0, vpmat_size, &mut source); let mut source: Source = Source::new([0u8; 32]);
module.vec_znx_dft(&mut a_dft, 0, &a, 0);
module.vmp_prepare_row(&mut vmpmat_0, &a.raw(), row_i, &mut tmp_bytes);
// Checks that prepare(mat_znx_dft, a) = prepare_dft(mat_znx_dft, a_dft) (0..mat_cols_out).for_each(|col_out| {
module.vmp_prepare_row_dft(&mut vmpmat_1, &a_dft, row_i); module.fill_uniform(log_base2k, &mut a, col_out, mat_size, &mut source);
assert_eq!(vmpmat_0.raw(), vmpmat_1.raw()); module.vec_znx_dft(&mut a_dft, col_out, &a, col_out);
});
// Checks that a_dft = extract_dft(prepare(mat_znx_dft, a), b_dft) module.vmp_prepare_row(&mut vmpmat_0, row_i, col_in, &a, &mut tmp_bytes);
module.vmp_extract_row_dft(&mut b_dft, row_i, &vmpmat_0);
assert_eq!(a_dft.raw(), b_dft.raw());
// Checks that a_big = extract(prepare_dft(mat_znx_dft, a_dft), b_big) // Checks that prepare(mat_znx_dft, a) = prepare_dft(mat_znx_dft, a_dft)
module.vmp_extract_row(&mut b_big, &vmpmat_0, row_i); module.vmp_prepare_row_dft(&mut vmpmat_1, row_i, col_in, &a_dft);
module.vec_znx_idft(&mut a_big, 0, &a_dft, 0, &mut tmp_bytes); assert_eq!(vmpmat_0.raw(), vmpmat_1.raw());
assert_eq!(a_big.raw(), b_big.raw());
// Checks that a_dft = extract_dft(prepare(mat_znx_dft, a), b_dft)
module.vmp_extract_row_dft(&mut b_dft, &vmpmat_0, row_i, col_in);
assert_eq!(a_dft.raw(), b_dft.raw());
// Checks that a_big = extract(prepare_dft(mat_znx_dft, a_dft), b_big)
module.vmp_extract_row(log_base2k, &mut b, &vmpmat_0, row_i, col_in, &mut tmp_bytes);
(0..mat_cols_out).for_each(|col_out| {
module.vec_znx_idft(&mut a_big, col_out, &a_dft, col_out, &mut tmp_bytes);
module.vec_znx_big_normalize(log_base2k, &mut a, col_out, &a_big, col_out, &mut tmp_bytes);
});
assert_eq!(a.raw(), b.raw());
}
} }
module.free(); module.free();

View File

@@ -28,6 +28,7 @@ impl<B: Backend> ZnxAlloc<B> for ScalarZnxDft<B> {
type Scalar = u8; type Scalar = u8;
fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, _size: usize, bytes: &mut [u8]) -> Self { fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, _size: usize, bytes: &mut [u8]) -> Self {
debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, _size));
Self { Self {
inner: ZnxBase::from_bytes_borrow( inner: ZnxBase::from_bytes_borrow(
module.n(), module.n(),
@@ -61,6 +62,6 @@ impl ZnxLayout for ScalarZnxDft<FFT64> {
impl ZnxSliceSize for ScalarZnxDft<FFT64> { impl ZnxSliceSize for ScalarZnxDft<FFT64> {
fn sl(&self) -> usize { fn sl(&self) -> usize {
self.n() self.n() * self.cols()
} }
} }

View File

@@ -3,7 +3,7 @@ use crate::Module;
use crate::assert_alignement; use crate::assert_alignement;
use crate::cast_mut; use crate::cast_mut;
use crate::ffi::znx; use crate::ffi::znx;
use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, ZnxSliceSize, switch_degree}; use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxRsh, ZnxSliceSize, ZnxZero, switch_degree};
use std::cmp::min; use std::cmp::min;
pub const VEC_ZNX_ROWS: usize = 1; pub const VEC_ZNX_ROWS: usize = 1;
@@ -44,7 +44,9 @@ impl ZnxLayout for VecZnx {
type Scalar = i64; type Scalar = i64;
} }
impl ZnxBasics for VecZnx {} impl ZnxZero for VecZnx {}
impl ZnxRsh for VecZnx {}
impl<B: Backend> ZnxAlloc<B> for VecZnx { impl<B: Backend> ZnxAlloc<B> for VecZnx {
type Scalar = i64; type Scalar = i64;
@@ -84,7 +86,7 @@ impl VecZnx {
/// ///
/// * `log_base2k`: the base two logarithm of the coefficients decomposition. /// * `log_base2k`: the base two logarithm of the coefficients decomposition.
/// * `k`: the number of bits of precision to drop. /// * `k`: the number of bits of precision to drop.
pub fn trunc_pow2(&mut self, log_base2k: usize, k: usize) { pub fn trunc_pow2(&mut self, log_base2k: usize, k: usize, col: usize) {
if k == 0 { if k == 0 {
return; return;
} }
@@ -101,7 +103,7 @@ impl VecZnx {
if k_rem != 0 { if k_rem != 0 {
let mask: i64 = ((1 << (log_base2k - k_rem - 1)) - 1) << k_rem; let mask: i64 = ((1 << (log_base2k - k_rem - 1)) - 1) << k_rem;
self.at_limb_mut(self.size() - 1) self.at_mut(col, self.size() - 1)
.iter_mut() .iter_mut()
.for_each(|x: &mut i64| *x &= mask) .for_each(|x: &mut i64| *x &= mask)
} }
@@ -111,8 +113,8 @@ impl VecZnx {
copy_vec_znx_from(self, a); copy_vec_znx_from(self, a);
} }
pub fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]) { pub fn normalize(&mut self, log_base2k: usize, col: usize, carry: &mut [u8]) {
normalize(log_base2k, self, carry) normalize(log_base2k, self, col, carry)
} }
pub fn switch_degree(&self, col: usize, a: &mut Self, col_a: usize) { pub fn switch_degree(&self, col: usize, a: &mut Self, col_a: usize) {
@@ -120,26 +122,25 @@ impl VecZnx {
} }
// Prints the first `n` coefficients of each limb // Prints the first `n` coefficients of each limb
pub fn print(&self, n: usize) { pub fn print(&self, n: usize, col: usize) {
(0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])) (0..self.size()).for_each(|j| println!("{}: {:?}", j, &self.at(col, j)[..n]));
} }
} }
fn normalize_tmp_bytes(n: usize, size: usize) -> usize { fn normalize_tmp_bytes(n: usize) -> usize {
n * size * std::mem::size_of::<i64>() n * std::mem::size_of::<i64>()
} }
fn normalize(log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]) { fn normalize(log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]) {
let n: usize = a.n(); let n: usize = a.n();
let cols: usize = a.cols();
debug_assert!( debug_assert!(
tmp_bytes.len() >= normalize_tmp_bytes(n, cols), tmp_bytes.len() >= normalize_tmp_bytes(n),
"invalid tmp_bytes: tmp_bytes.len()={} < normalize_tmp_bytes({}, {})", "invalid tmp_bytes: tmp_bytes.len()={} < normalize_tmp_bytes({})",
tmp_bytes.len(), tmp_bytes.len(),
n, n,
cols,
); );
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_alignement(tmp_bytes.as_ptr()) assert_alignement(tmp_bytes.as_ptr())
@@ -151,11 +152,11 @@ fn normalize(log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]) {
znx::znx_zero_i64_ref(n as u64, carry_i64.as_mut_ptr()); znx::znx_zero_i64_ref(n as u64, carry_i64.as_mut_ptr());
(0..a.size()).rev().for_each(|i| { (0..a.size()).rev().for_each(|i| {
znx::znx_normalize( znx::znx_normalize(
(n * cols) as u64, n as u64,
log_base2k as u64, log_base2k as u64,
a.at_mut_ptr(0, i), a.at_mut_ptr(a_col, i),
carry_i64.as_mut_ptr(), carry_i64.as_mut_ptr(),
a.at_mut_ptr(0, i), a.at_mut_ptr(a_col, i),
carry_i64.as_mut_ptr(), carry_i64.as_mut_ptr(),
) )
}); });

View File

@@ -1,5 +1,5 @@
use crate::ffi::vec_znx_big; use crate::ffi::vec_znx_big;
use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, ZnxSliceSize}; use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize, ZnxZero};
use crate::{Backend, FFT64, Module, NTT120}; use crate::{Backend, FFT64, Module, NTT120};
use std::marker::PhantomData; use std::marker::PhantomData;
@@ -26,6 +26,7 @@ impl<B: Backend> ZnxAlloc<B> for VecZnxBig<B> {
type Scalar = u8; type Scalar = u8;
fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self { fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, size));
VecZnxBig { VecZnxBig {
inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_BIG_ROWS, cols, size, bytes), inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_BIG_ROWS, cols, size, bytes),
_marker: PhantomData, _marker: PhantomData,
@@ -50,24 +51,24 @@ impl ZnxLayout for VecZnxBig<NTT120> {
type Scalar = i128; type Scalar = i128;
} }
impl ZnxBasics for VecZnxBig<FFT64> {} impl ZnxZero for VecZnxBig<FFT64> {}
impl ZnxSliceSize for VecZnxBig<FFT64> { impl ZnxSliceSize for VecZnxBig<FFT64> {
fn sl(&self) -> usize { fn sl(&self) -> usize {
self.n() self.n() * self.cols()
} }
} }
impl ZnxSliceSize for VecZnxBig<NTT120> { impl ZnxSliceSize for VecZnxBig<NTT120> {
fn sl(&self) -> usize { fn sl(&self) -> usize {
self.n() * 4 self.n() * 4 * self.cols()
} }
} }
impl ZnxBasics for VecZnxBig<NTT120> {} impl ZnxZero for VecZnxBig<NTT120> {}
impl VecZnxBig<FFT64> { impl VecZnxBig<FFT64> {
pub fn print(&self, n: usize) { pub fn print(&self, n: usize, col: usize) {
(0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])); (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at(col, i)[..n]));
} }
} }

View File

@@ -1,4 +1,4 @@
use crate::ffi::vec_znx_big::{self, vec_znx_big_t}; use crate::ffi::vec_znx;
use crate::znx_base::{ZnxAlloc, ZnxInfos, ZnxLayout, ZnxSliceSize}; use crate::znx_base::{ZnxAlloc, ZnxInfos, ZnxLayout, ZnxSliceSize};
use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxOps, assert_alignement}; use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxOps, assert_alignement};
@@ -171,14 +171,17 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
assert_ne!(a.as_ptr(), b.as_ptr()); assert_ne!(a.as_ptr(), b.as_ptr());
} }
unsafe { unsafe {
vec_znx_big::vec_znx_big_add( vec_znx::vec_znx_add(
self.ptr, self.ptr,
res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t, res.at_mut_ptr(res_col, 0),
res.size() as u64, res.size() as u64,
a.at_ptr(a_col * res.size(), 0) as *const vec_znx_big_t, res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64, a.size() as u64,
b.at_ptr(b_col * res.size(), 0) as *const vec_znx_big_t, a.sl() as u64,
b.at_ptr(b_col, 0),
b.size() as u64, b.size() as u64,
b.sl() as u64,
) )
} }
} }
@@ -207,14 +210,17 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
assert_ne!(a.as_ptr(), b.as_ptr()); assert_ne!(a.as_ptr(), b.as_ptr());
} }
unsafe { unsafe {
vec_znx_big::vec_znx_big_sub( vec_znx::vec_znx_sub(
self.ptr, self.ptr,
res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t, res.at_mut_ptr(res_col, 0),
res.size() as u64, res.size() as u64,
a.at_ptr(a_col * res.size(), 0) as *const vec_znx_big_t, res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64, a.size() as u64,
b.at_ptr(b_col * res.size(), 0) as *const vec_znx_big_t, a.sl() as u64,
b.at_ptr(b_col, 0),
b.size() as u64, b.size() as u64,
b.sl() as u64,
) )
} }
} }
@@ -250,12 +256,14 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
assert_ne!(a.as_ptr(), b.as_ptr()); assert_ne!(a.as_ptr(), b.as_ptr());
} }
unsafe { unsafe {
vec_znx_big::vec_znx_big_sub_small_b( vec_znx::vec_znx_sub(
self.ptr, self.ptr,
res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t, res.at_mut_ptr(res_col, 0),
res.size() as u64, res.size() as u64,
a.at_ptr(a_col * a.size(), 0) as *const vec_znx_big_t, res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64, a.size() as u64,
a.sl() as u64,
b.at_ptr(b_col, 0), b.at_ptr(b_col, 0),
b.size() as u64, b.size() as u64,
b.sl() as u64, b.sl() as u64,
@@ -287,15 +295,17 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
assert_ne!(a.as_ptr(), b.as_ptr()); assert_ne!(a.as_ptr(), b.as_ptr());
} }
unsafe { unsafe {
vec_znx_big::vec_znx_big_sub_small_a( vec_znx::vec_znx_sub(
self.ptr, self.ptr,
res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t, res.at_mut_ptr(res_col, 0),
res.size() as u64, res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0), a.at_ptr(a_col, 0),
a.size() as u64, a.size() as u64,
a.sl() as u64, a.sl() as u64,
b.at_ptr(b_col * b.size(), 0) as *const vec_znx_big_t, b.at_ptr(b_col, 0),
b.size() as u64, b.size() as u64,
b.sl() as u64,
) )
} }
} }
@@ -324,12 +334,14 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
assert_ne!(a.as_ptr(), b.as_ptr()); assert_ne!(a.as_ptr(), b.as_ptr());
} }
unsafe { unsafe {
vec_znx_big::vec_znx_big_add_small( vec_znx::vec_znx_add(
self.ptr, self.ptr,
res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t, res.at_mut_ptr(res_col, 0),
res.size() as u64, res.size() as u64,
a.at_ptr(a_col * a.size(), 0) as *const vec_znx_big_t, res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64, a.size() as u64,
a.sl() as u64,
b.at_ptr(b_col, 0), b.at_ptr(b_col, 0),
b.size() as u64, b.size() as u64,
b.sl() as u64, b.sl() as u64,
@@ -365,14 +377,15 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
assert_alignement(tmp_bytes.as_ptr()); assert_alignement(tmp_bytes.as_ptr());
} }
unsafe { unsafe {
vec_znx_big::vec_znx_big_normalize_base2k( vec_znx::vec_znx_normalize_base2k(
self.ptr, self.ptr,
log_base2k as u64, log_base2k as u64,
res.at_mut_ptr(res_col, 0), res.at_mut_ptr(res_col, 0),
res.size() as u64, res.size() as u64,
res.sl() as u64, res.sl() as u64,
a.at_ptr(a_col * a.size(), 0) as *const vec_znx_big_t, a.at_ptr(a_col, 0),
a.size() as u64, a.size() as u64,
a.sl() as u64,
tmp_bytes.as_mut_ptr(), tmp_bytes.as_mut_ptr(),
); );
} }
@@ -385,13 +398,15 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
assert_eq!(res.n(), self.n()); assert_eq!(res.n(), self.n());
} }
unsafe { unsafe {
vec_znx_big::vec_znx_big_automorphism( vec_znx::vec_znx_automorphism(
self.ptr, self.ptr,
k, k,
res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t, res.at_mut_ptr(res_col, 0),
res.size() as u64, res.size() as u64,
a.at_ptr(a_col * a.size(), 0) as *const vec_znx_big_t, res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64, a.size() as u64,
a.sl() as u64,
) )
} }
} }

View File

@@ -1,5 +1,5 @@
use crate::ffi::vec_znx_dft; use crate::ffi::vec_znx_dft;
use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize}; use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize, ZnxZero};
use crate::{Backend, FFT64, Module, VecZnxBig}; use crate::{Backend, FFT64, Module, VecZnxBig};
use std::marker::PhantomData; use std::marker::PhantomData;
@@ -26,6 +26,7 @@ impl<B: Backend> ZnxAlloc<B> for VecZnxDft<B> {
type Scalar = u8; type Scalar = u8;
fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self { fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, size));
Self { Self {
inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_DFT_ROWS, cols, size, bytes), inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_DFT_ROWS, cols, size, bytes),
_marker: PhantomData, _marker: PhantomData,
@@ -46,6 +47,8 @@ impl ZnxLayout for VecZnxDft<FFT64> {
type Scalar = f64; type Scalar = f64;
} }
impl ZnxZero for VecZnxDft<FFT64> {}
impl ZnxSliceSize for VecZnxDft<FFT64> { impl ZnxSliceSize for VecZnxDft<FFT64> {
fn sl(&self) -> usize { fn sl(&self) -> usize {
self.n() self.n()
@@ -53,8 +56,8 @@ impl ZnxSliceSize for VecZnxDft<FFT64> {
} }
impl VecZnxDft<FFT64> { impl VecZnxDft<FFT64> {
pub fn print(&self, n: usize) { pub fn print(&self, n: usize, col: usize) {
(0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n])); (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at(col, i)[..n]));
} }
} }
@@ -63,6 +66,10 @@ impl<B: Backend> VecZnxDft<B> {
/// The returned [VecZnxBig] shares the backing array /// The returned [VecZnxBig] shares the backing array
/// with the original [VecZnxDft]. /// with the original [VecZnxDft].
pub fn alias_as_vec_znx_big(&mut self) -> VecZnxBig<B> { pub fn alias_as_vec_znx_big(&mut self) -> VecZnxBig<B> {
assert!(
self.data().len() == 0,
"cannot alias VecZnxDft into VecZnxBig if it owns the data"
);
VecZnxBig::<B> { VecZnxBig::<B> {
inner: ZnxBase { inner: ZnxBase {
data: Vec::new(), data: Vec::new(),

View File

@@ -4,7 +4,8 @@ use crate::znx_base::ZnxAlloc;
use crate::znx_base::ZnxInfos; use crate::znx_base::ZnxInfos;
use crate::znx_base::ZnxLayout; use crate::znx_base::ZnxLayout;
use crate::znx_base::ZnxSliceSize; use crate::znx_base::ZnxSliceSize;
use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxDft, assert_alignement}; use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxDft, ZnxZero, assert_alignement};
use std::cmp::min;
pub trait VecZnxDftOps<B: Backend> { pub trait VecZnxDftOps<B: Backend> {
/// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space. /// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space.
@@ -77,19 +78,21 @@ impl VecZnxDftOps<FFT64> for Module<FFT64> {
} }
fn vec_znx_idft_tmp_a(&self, res: &mut VecZnxBig<FFT64>, res_col: usize, a: &mut VecZnxDft<FFT64>, a_col: usize) { fn vec_znx_idft_tmp_a(&self, res: &mut VecZnxBig<FFT64>, res_col: usize, a: &mut VecZnxDft<FFT64>, a_col: usize) {
#[cfg(debug_assertions)] let min_size: usize = min(res.size(), a.size());
{
assert_eq!(res.poly_count(), a.poly_count());
}
unsafe { unsafe {
vec_znx_dft::vec_znx_idft_tmp_a( (0..min_size).for_each(|j| {
self.ptr, vec_znx_dft::vec_znx_idft_tmp_a(
res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big::vec_znx_big_t, self.ptr,
res.size() as u64, res.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t,
a.at_ptr(a_col * a.size(), 0) as *mut vec_znx_dft::vec_znx_dft_t, 1 as u64,
a.size() as u64, a.at_ptr(a_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
) 1 as u64,
)
});
(min_size..res.size()).for_each(|j| {
res.zero_at(res_col, j);
})
} }
} }
@@ -102,15 +105,22 @@ impl VecZnxDftOps<FFT64> for Module<FFT64> {
/// # Panics /// # Panics
/// If b.cols < a_cols /// If b.cols < a_cols
fn vec_znx_dft(&self, res: &mut VecZnxDft<FFT64>, res_col: usize, a: &VecZnx, a_col: usize) { fn vec_znx_dft(&self, res: &mut VecZnxDft<FFT64>, res_col: usize, a: &VecZnx, a_col: usize) {
let min_size: usize = min(res.size(), a.size());
unsafe { unsafe {
vec_znx_dft::vec_znx_dft( (0..min_size).for_each(|j| {
self.ptr, vec_znx_dft::vec_znx_dft(
res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_dft::vec_znx_dft_t, self.ptr,
res.size() as u64, res.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
a.at_ptr(a_col, 0), 1 as u64,
a.size() as u64, a.at_ptr(a_col, j),
a.sl() as u64, 1 as u64,
) a.sl() as u64,
)
});
(min_size..res.size()).for_each(|j| {
res.zero_at(res_col, j);
});
} }
} }
@@ -126,15 +136,23 @@ impl VecZnxDftOps<FFT64> for Module<FFT64> {
); );
assert_alignement(tmp_bytes.as_ptr()) assert_alignement(tmp_bytes.as_ptr())
} }
let min_size: usize = min(res.size(), a.size());
unsafe { unsafe {
vec_znx_dft::vec_znx_idft( (0..min_size).for_each(|j| {
self.ptr, vec_znx_dft::vec_znx_idft(
res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big::vec_znx_big_t, self.ptr,
res.size() as u64, res.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t,
a.at_ptr(a_col * res.size(), 0) as *const vec_znx_dft::vec_znx_dft_t, 1 as u64,
a.size() as u64, a.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
tmp_bytes.as_mut_ptr(), 1 as u64,
) tmp_bytes.as_mut_ptr(),
)
});
(min_size..res.size()).for_each(|j| {
res.zero_at(res_col, j);
});
} }
} }
} }

View File

@@ -22,6 +22,33 @@ pub struct ZnxBase {
pub ptr: *mut u8, pub ptr: *mut u8,
} }
impl ZnxBase {
pub fn from_bytes(n: usize, rows: usize, cols: usize, size: usize, mut bytes: Vec<u8>) -> Self {
let mut res: Self = Self::from_bytes_borrow(n, rows, cols, size, &mut bytes);
res.data = bytes;
res
}
pub fn from_bytes_borrow(n: usize, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
#[cfg(debug_assertions)]
{
assert_eq!(n & (n - 1), 0, "n must be a power of two");
assert!(n > 0, "n must be greater than 0");
assert!(rows > 0, "rows must be greater than 0");
assert!(cols > 0, "cols must be greater than 0");
assert!(size > 0, "size must be greater than 0");
}
Self {
n: n,
rows: rows,
cols: cols,
size: size,
data: Vec::new(),
ptr: bytes.as_mut_ptr(),
}
}
}
pub trait GetZnxBase { pub trait GetZnxBase {
fn znx(&self) -> &ZnxBase; fn znx(&self) -> &ZnxBase;
fn znx_mut(&mut self) -> &mut ZnxBase; fn znx_mut(&mut self) -> &mut ZnxBase;
@@ -52,10 +79,12 @@ pub trait ZnxInfos: GetZnxBase {
self.znx().size self.znx().size
} }
/// Returns the underlying raw bytes array.
fn data(&self) -> &[u8] { fn data(&self) -> &[u8] {
&self.znx().data &self.znx().data
} }
/// Returns a pointer to the underlying raw bytes array.
fn ptr(&self) -> *mut u8 { fn ptr(&self) -> *mut u8 {
self.znx().ptr self.znx().ptr
} }
@@ -72,33 +101,6 @@ pub trait ZnxSliceSize {
fn sl(&self) -> usize; fn sl(&self) -> usize;
} }
impl ZnxBase {
pub fn from_bytes(n: usize, rows: usize, cols: usize, size: usize, mut bytes: Vec<u8>) -> Self {
let mut res: Self = Self::from_bytes_borrow(n, rows, cols, size, &mut bytes);
res.data = bytes;
res
}
pub fn from_bytes_borrow(n: usize, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
#[cfg(debug_assertions)]
{
assert_eq!(n & (n - 1), 0, "n must be a power of two");
assert!(n > 0, "n must be greater than 0");
assert!(rows > 0, "rows must be greater than 0");
assert!(cols > 0, "cols must be greater than 0");
assert!(size > 0, "size must be greater than 0");
}
Self {
n: n,
rows: rows,
cols: cols,
size: size,
data: Vec::new(),
ptr: bytes.as_mut_ptr(),
}
}
}
pub trait ZnxAlloc<B: Backend> pub trait ZnxAlloc<B: Backend>
where where
Self: Sized + ZnxInfos, Self: Sized + ZnxInfos,
@@ -148,25 +150,25 @@ pub trait ZnxLayout: ZnxInfos {
unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.n() * self.poly_count()) } unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.n() * self.poly_count()) }
} }
/// Returns a non-mutable pointer starting at the (i, j)-th small polynomial. /// Returns a non-mutable pointer starting at the j-th small polynomial of the i-th column.
fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar { fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert!(i < self.cols()); assert!(i < self.cols());
assert!(j < self.size()); assert!(j < self.size());
} }
let offset = self.n() * (j * self.cols() + i); let offset: usize = self.n() * (j * self.cols() + i);
unsafe { self.as_ptr().add(offset) } unsafe { self.as_ptr().add(offset) }
} }
/// Returns a mutable pointer starting at the (i, j)-th small polynomial. /// Returns a mutable pointer starting at the j-th small polynomial of the i-th column.
fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar { fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert!(i < self.cols()); assert!(i < self.cols());
assert!(j < self.size()); assert!(j < self.size());
} }
let offset = self.n() * (j * self.cols() + i); let offset: usize = self.n() * (j * self.cols() + i);
unsafe { self.as_mut_ptr().add(offset) } unsafe { self.as_mut_ptr().add(offset) }
} }
@@ -179,16 +181,6 @@ pub trait ZnxLayout: ZnxInfos {
fn at_mut(&mut self, i: usize, j: usize) -> &mut [Self::Scalar] { fn at_mut(&mut self, i: usize, j: usize) -> &mut [Self::Scalar] {
unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n()) } unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n()) }
} }
/// Returns non-mutable reference to the i-th limb.
fn at_limb(&self, j: usize) -> &[Self::Scalar] {
unsafe { std::slice::from_raw_parts(self.at_ptr(0, j), self.n() * self.cols()) }
}
/// Returns mutable reference to the i-th limb.
fn at_limb_mut(&mut self, j: usize) -> &mut [Self::Scalar] {
unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(0, j), self.n() * self.cols()) }
}
} }
use std::convert::TryFrom; use std::convert::TryFrom;
@@ -221,14 +213,17 @@ impl IntegerType for i128 {
const BITS: u32 = 128; const BITS: u32 = 128;
} }
pub trait ZnxBasics: ZnxLayout pub trait ZnxZero: ZnxLayout
where where
Self: Sized, Self: Sized,
Self::Scalar: IntegerType,
{ {
fn zero(&mut self) { fn zero(&mut self) {
unsafe { unsafe {
std::ptr::write_bytes(self.as_mut_ptr(), 0, self.n() * size_of::<Self::Scalar>()); std::ptr::write_bytes(
self.as_mut_ptr(),
0,
self.n() * size_of::<Self::Scalar>() * self.poly_count(),
);
} }
} }
@@ -241,13 +236,19 @@ where
); );
} }
} }
}
fn rsh(&mut self, log_base2k: usize, k: usize, carry: &mut [u8]) { pub trait ZnxRsh: ZnxLayout + ZnxZero
rsh(log_base2k, self, k, carry) where
Self: Sized,
Self::Scalar: IntegerType,
{
fn rsh(&mut self, k: usize, log_base2k: usize, col: usize, carry: &mut [u8]) {
rsh(k, log_base2k, self, col, carry)
} }
} }
pub fn rsh<V: ZnxBasics>(log_base2k: usize, a: &mut V, k: usize, tmp_bytes: &mut [u8]) pub fn rsh<V: ZnxRsh + ZnxZero>(k: usize, log_base2k: usize, a: &mut V, a_col: usize, tmp_bytes: &mut [u8])
where where
V::Scalar: IntegerType, V::Scalar: IntegerType,
{ {
@@ -258,7 +259,7 @@ where
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert!( assert!(
tmp_bytes.len() >= rsh_tmp_bytes::<V::Scalar>(n, cols), tmp_bytes.len() >= rsh_tmp_bytes::<V::Scalar>(n),
"invalid carry: carry.len()/size_ofSelf::Scalar={} < rsh_tmp_bytes({}, {})", "invalid carry: carry.len()/size_ofSelf::Scalar={} < rsh_tmp_bytes({}, {})",
tmp_bytes.len() / size_of::<V::Scalar>(), tmp_bytes.len() / size_of::<V::Scalar>(),
n, n,
@@ -291,7 +292,7 @@ where
let k_rem_t: V::Scalar = V::Scalar::try_from(k_rem).unwrap(); let k_rem_t: V::Scalar = V::Scalar::try_from(k_rem).unwrap();
(steps..size).for_each(|i| { (steps..size).for_each(|i| {
izip!(carry.iter_mut(), a.at_limb_mut(i).iter_mut()).for_each(|(ci, xi)| { izip!(carry.iter_mut(), a.at_mut(a_col, i).iter_mut()).for_each(|(ci, xi)| {
*xi += *ci << log_base2k_t; *xi += *ci << log_base2k_t;
*ci = get_base_k_carry(*xi, shift); *ci = get_base_k_carry(*xi, shift);
*xi = (*xi - *ci) >> k_rem_t; *xi = (*xi - *ci) >> k_rem_t;
@@ -305,11 +306,11 @@ fn get_base_k_carry<T: IntegerType>(x: T, shift: T) -> T {
(x << shift) >> shift (x << shift) >> shift
} }
pub fn rsh_tmp_bytes<T: IntegerType>(n: usize, cols: usize) -> usize { pub fn rsh_tmp_bytes<T: IntegerType>(n: usize) -> usize {
n * cols * std::mem::size_of::<T>() n * std::mem::size_of::<T>()
} }
pub fn switch_degree<T: ZnxLayout + ZnxBasics>(b: &mut T, col_b: usize, a: &T, col_a: usize) pub fn switch_degree<T: ZnxLayout + ZnxZero>(b: &mut T, col_b: usize, a: &T, col_a: usize)
where where
<T as ZnxLayout>::Scalar: IntegerType, <T as ZnxLayout>::Scalar: IntegerType,
{ {