Mirror of https://github.com/arnaucube/poulpy.git (synced 2026-02-10 05:06:44 +01:00)

Commit: Applied discussed changes, everything working, but still to discuss
base2k/.vscode/settings.json (vendored) | 3 changed lines
@@ -4,5 +4,8 @@
         "plaintext": false,
         "markdown": false,
         "scminput": false
+    },
+    "files.associations": {
+        "random": "c"
     }
 }
@@ -13,7 +13,8 @@ fn main() {
     let log_scale: usize = msg_size * log_base2k - 5;
     let module: Module<FFT64> = Module::<FFT64>::new(n);

-    let mut carry: Vec<u8> = alloc_aligned(module.vec_znx_big_normalize_tmp_bytes());
+    let mut tmp_bytes_norm: Vec<u8> = alloc_aligned(module.vec_znx_big_normalize_tmp_bytes());
+    let mut tmp_bytes_dft = alloc_aligned(module.bytes_of_vec_znx_dft(1, ct_size));

     let seed: [u8; 32] = [0; 32];
     let mut source: Source = Source::new(seed);
@@ -38,9 +39,10 @@ fn main() {
     module.fill_uniform(log_base2k, &mut ct, 1, ct_size, &mut source);

     // Scratch space for DFT values
-    let mut buf_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(
+    let mut buf_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft_from_bytes_borrow(
         1,         // Number of columns
         ct.size(), // Number of polynomials per column
+        &mut tmp_bytes_dft,
     );

     // Applies DFT(ct[1]) * DFT(s)
@@ -68,7 +70,7 @@ fn main() {
     want.iter_mut()
         .for_each(|x| *x = source.next_u64n(16, 15) as i64);
     m.encode_vec_i64(0, log_base2k, log_scale, &want, 4);
-    m.normalize(log_base2k, &mut carry);
+    m.normalize(log_base2k, 0, &mut tmp_bytes_norm);

     // m - BIG(ct[1] * s)
     module.vec_znx_big_sub_small_a_inplace(
@@ -81,9 +83,12 @@ fn main() {
     // Normalizes back to VecZnx
     // ct[0] <- m - BIG(c1 * s)
     module.vec_znx_big_normalize(
-        log_base2k, &mut ct, 0, // Selects the first column of ct (ct[0])
-        &buf_big, 0, // Selects the first column of buf_big
-        &mut carry,
+        log_base2k,
+        &mut ct,
+        0, // Selects the first column of ct (ct[0])
+        &buf_big,
+        0, // Selects the first column of buf_big
+        &mut tmp_bytes_norm,
     );

     // Add noise to ct[0]
@@ -120,7 +125,7 @@ fn main() {

     // m + e <- BIG(ct[1] * s + ct[0])
     let mut res: VecZnx = module.new_vec_znx(1, ct_size);
-    module.vec_znx_big_normalize(log_base2k, &mut res, 0, &buf_big, 0, &mut carry);
+    module.vec_znx_big_normalize(log_base2k, &mut res, 0, &buf_big, 0, &mut tmp_bytes_norm);

     // have = m * 2^{log_scale} + e
     let mut have: Vec<i64> = vec![i64::default(); n];
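Aside on the hunks above: the single generic `carry` buffer becomes two purpose-named scratch buffers, and the DFT vector is now borrowed from caller-owned bytes instead of allocated per call. A minimal sketch of the resulting pattern, assuming the post-commit base2k API shown in these hunks (the Module setup is illustrative):

use base2k::{FFT64, Module, VecZnxBigOps, VecZnxDft, VecZnxDftOps, alloc_aligned};

fn main() {
    let module: Module<FFT64> = Module::<FFT64>::new(32);
    let ct_size: usize = 4;

    // One aligned scratch buffer per concern, both caller-owned.
    let mut tmp_bytes_norm: Vec<u8> = alloc_aligned(module.vec_znx_big_normalize_tmp_bytes());
    let mut tmp_bytes_dft: Vec<u8> = alloc_aligned(module.bytes_of_vec_znx_dft(1, ct_size));

    // The DFT vector borrows tmp_bytes_dft rather than allocating internally,
    // so repeated calls reuse the same backing memory.
    let mut buf_dft: VecZnxDft<FFT64> =
        module.new_vec_znx_dft_from_bytes_borrow(1, ct_size, &mut tmp_bytes_dft);

    let _ = (&mut buf_dft, &mut tmp_bytes_norm);
    module.free();
}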
@@ -1,59 +0,0 @@
-use base2k::{
-    Encoding, FFT64, MatZnxDft, MatZnxDftOps, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps,
-    ZnxInfos, ZnxLayout, alloc_aligned,
-};
-
-fn main() {
-    let log_n: i32 = 5;
-    let n: usize = 1 << log_n;
-
-    let module: Module<FFT64> = Module::<FFT64>::new(n);
-    let log_base2k: usize = 15;
-    let limbs_vec: usize = 5;
-    let log_k: usize = log_base2k * limbs_vec - 5;
-
-    let rows_mat: usize = limbs_vec;
-    let limbs_mat: usize = limbs_vec + 1;
-
-    // Maximum size of the byte scratch needed
-    let tmp_bytes: usize = module.vmp_prepare_tmp_bytes(rows_mat, 1, limbs_mat)
-        | module.vmp_apply_dft_tmp_bytes(limbs_vec, limbs_vec, rows_mat, limbs_mat);
-
-    let mut buf: Vec<u8> = alloc_aligned(tmp_bytes);
-
-    let mut a_values: Vec<i64> = vec![i64::default(); n];
-    a_values[1] = (1 << log_base2k) + 1;
-
-    let mut a: VecZnx = module.new_vec_znx(1, limbs_vec);
-    a.encode_vec_i64(0, log_base2k, log_k, &a_values, 32);
-    a.normalize(log_base2k, &mut buf);
-
-    a.print(n);
-    println!();
-
-    let mut mat_znx_dft: MatZnxDft<FFT64> = module.new_mat_znx_dft(rows_mat, 1, limbs_mat);
-
-    (0..a.size()).for_each(|row_i| {
-        let mut tmp: VecZnx = module.new_vec_znx(1, limbs_mat);
-        tmp.at_limb_mut(row_i)[1] = 1 as i64;
-        module.vmp_prepare_row(&mut mat_znx_dft, tmp.raw(), row_i, &mut buf);
-    });
-
-    let mut c_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, limbs_mat);
-    module.vmp_apply_dft(&mut c_dft, &a, &mat_znx_dft, &mut buf);
-
-    let mut c_big: VecZnxBig<FFT64> = c_dft.alias_as_vec_znx_big();
-    module.vec_znx_idft_tmp_a(&mut c_big, 0, &mut c_dft, 0);
-
-    let mut res: VecZnx = module.new_vec_znx(1, limbs_vec);
-    module.vec_znx_big_normalize(log_base2k, &mut res, 0, &c_big, 0, &mut buf);
-
-    let mut values_res: Vec<i64> = vec![i64::default(); n];
-    res.decode_vec_i64(0, log_base2k, log_k, &mut values_res);
-
-    res.print(n);
-
-    module.free();
-
-    println!("{:?}", values_res)
-}
base2k/examples/vmp.rs (new file) | 78 lines
@@ -0,0 +1,78 @@
+use base2k::{
+    Encoding, FFT64, MatZnxDft, MatZnxDftOps, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps,
+    ZnxInfos, ZnxLayout, alloc_aligned,
+};
+
+fn main() {
+    let log_n: i32 = 5;
+    let n: usize = 1 << log_n;
+
+    let module: Module<FFT64> = Module::<FFT64>::new(n);
+    let log_base2k: usize = 15;
+
+    let a_cols: usize = 2;
+    let a_size: usize = 5;
+
+    let log_k: usize = log_base2k * a_size - 5;
+
+    let mat_rows: usize = a_size;
+    let mat_cols_in: usize = a_cols;
+    let mat_cols_out: usize = 2;
+    let mat_size: usize = a_size + 1;
+
+    let mut tmp_bytes_vmp: Vec<u8> = alloc_aligned(
+        module.vmp_prepare_row_tmp_bytes(mat_cols_out, mat_size)
+            | module.vmp_apply_dft_tmp_bytes(
+                a_size,
+                a_size,
+                mat_rows,
+                mat_cols_in,
+                mat_cols_out,
+                mat_size,
+            ),
+    );
+
+    let mut tmp_bytes_dft: Vec<u8> = alloc_aligned(module.bytes_of_vec_znx_dft(mat_cols_out, mat_size));
+
+    let mut a: VecZnx = module.new_vec_znx(a_cols, a_size);
+
+    (0..a_cols).for_each(|i| {
+        let mut values: Vec<i64> = vec![i64::default(); n];
+        values[1 + i] = (1 << log_base2k) + 1;
+        a.encode_vec_i64(i, log_base2k, log_k, &values, 32);
+        a.normalize(log_base2k, i, &mut tmp_bytes_vmp);
+        a.print(n, i);
+        println!();
+    });
+
+    let mut mat_znx_dft: MatZnxDft<FFT64> = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size);
+
+    (0..a.size()).for_each(|row_i| {
+        let mut tmp: VecZnx = module.new_vec_znx(mat_cols_out, mat_size);
+        (0..mat_cols_out).for_each(|j| {
+            tmp.at_mut(j, row_i)[1 + j] = 1 as i64;
+        });
+        (0..mat_cols_in).for_each(|j| {
+            module.vmp_prepare_row(&mut mat_znx_dft, row_i, j, &tmp, &mut tmp_bytes_vmp);
+        })
+    });
+
+    let mut c_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft_from_bytes_borrow(mat_cols_out, mat_size, &mut tmp_bytes_dft);
+    module.vmp_apply_dft(&mut c_dft, &a, &mat_znx_dft, &mut tmp_bytes_vmp);
+
+    let mut res: VecZnx = module.new_vec_znx(mat_cols_out, a_size);
+    let mut c_big: VecZnxBig<FFT64> = c_dft.alias_as_vec_znx_big();
+    (0..mat_cols_out).for_each(|i| {
+        module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i);
+        module.vec_znx_big_normalize(log_base2k, &mut res, i, &c_big, i, &mut tmp_bytes_vmp);
+
+        let mut values_res: Vec<i64> = vec![i64::default(); n];
+        res.decode_vec_i64(i, log_base2k, log_k, &mut values_res);
+
+        res.print(n, i);
+        println!();
+        println!("{:?}", values_res);
+        println!();
+    });
+
+    module.free();
+}
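The dimension bookkeeping the new example relies on can be sanity-checked with a dependency-free model of the contraction (a sketch only, with plain integers standing in for the DFT-domain polynomials): a matrix with rows x cols_in prepared rows and cols_out output columns maps an input of cols_in columns to an output of cols_out columns.

// Hypothetical shape model of vmp_apply_dft after this change; names and
// dimensions mirror examples/vmp.rs above but nothing here is library code.
fn vmp_shape(rows: usize, cols_in: usize, cols_out: usize) {
    // mat[row][col_in][col_out]: scalar weights standing in for prepared rows.
    let mat = vec![vec![vec![1i64; cols_out]; cols_in]; rows];
    // a[col_in][row]: limbs of the decomposed input, one column per input poly.
    let a = vec![vec![1i64; rows]; cols_in];
    // c[col_out] = sum over (row, col_in) of a[col_in][row] * mat[row][col_in][col_out].
    let mut c = vec![0i64; cols_out];
    for row in 0..rows {
        for col_in in 0..cols_in {
            for col_out in 0..cols_out {
                c[col_out] += a[col_in][row] * mat[row][col_in][col_out];
            }
        }
    }
    // Every output column accumulates rows * cols_in contributions.
    assert!(c.iter().all(|&x| x == (rows * cols_in) as i64));
}

fn main() {
    vmp_shape(5, 2, 2); // the dimensions used in examples/vmp.rs above
}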
Submodule base2k/spqlios-arithmetic updated: e3d3247335...8135d85e7a
@@ -1,4 +1,4 @@
-use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize};
+use crate::znx_base::{GetZnxBase, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize};
 use crate::{Backend, FFT64, Module, alloc_aligned};
 use std::marker::PhantomData;

@@ -10,6 +10,8 @@ use std::marker::PhantomData;
 /// See the trait [MatZnxDftOps] for additional information.
 pub struct MatZnxDft<B: Backend> {
     pub inner: ZnxBase,
+    pub cols_in: usize,
+    pub cols_out: usize,
     _marker: PhantomData<B>,
 }

@@ -35,18 +37,54 @@ impl ZnxLayout for MatZnxDft<FFT64> {
     type Scalar = f64;
 }

-impl<B: Backend> ZnxAlloc<B> for MatZnxDft<B> {
-    type Scalar = u8;
-
-    fn from_bytes_borrow(module: &Module<B>, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
-        Self {
-            inner: ZnxBase::from_bytes_borrow(module.n(), rows, cols, size, bytes),
-            _marker: PhantomData,
-        }
-    }
-
-    fn bytes_of(module: &Module<B>, rows: usize, cols: usize, size: usize) -> usize {
-        unsafe { crate::ffi::vmp::bytes_of_vmp_pmat(module.ptr, rows as u64, size as u64) as usize * cols }
-    }
-}
+impl<B: Backend> MatZnxDft<B> {
+    pub fn new(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
+        let bytes: Vec<u8> = alloc_aligned(Self::bytes_of(module, rows, cols_in, cols_out, size));
+        Self::from_bytes(module, rows, cols_in, cols_out, size, bytes)
+    }
+
+    pub fn from_bytes(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize, mut bytes: Vec<u8>) -> Self {
+        let mut mat: MatZnxDft<B> = Self::from_bytes_borrow(module, rows, cols_in, cols_out, size, &mut bytes);
+        mat.znx_mut().data = bytes;
+        mat
+    }
+
+    pub fn from_bytes_borrow(
+        module: &Module<B>,
+        rows: usize,
+        cols_in: usize,
+        cols_out: usize,
+        size: usize,
+        bytes: &mut [u8],
+    ) -> Self {
+        debug_assert_eq!(
+            bytes.len(),
+            Self::bytes_of(module, rows, cols_in, cols_out, size)
+        );
+        Self {
+            inner: ZnxBase::from_bytes_borrow(module.n(), rows, cols_out, size, bytes),
+            cols_in: cols_in,
+            cols_out: cols_out,
+            _marker: PhantomData,
+        }
+    }
+
+    pub fn bytes_of(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
+        unsafe {
+            crate::ffi::vmp::bytes_of_vmp_pmat(
+                module.ptr,
+                (rows * cols_in) as u64,
+                (size * cols_out) as u64,
+            ) as usize
+        }
+    }
+
+    pub fn cols_in(&self) -> usize {
+        self.cols_in
+    }
+
+    pub fn cols_out(&self) -> usize {
+        self.cols_out
+    }
+}

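The three constructors above split allocation policy from initialization: new allocates, from_bytes takes ownership of a caller-provided buffer, and from_bytes_borrow only borrows. A hedged usage sketch against the signatures in this hunk (dimensions are illustrative):

use base2k::{FFT64, MatZnxDft, Module, alloc_aligned};

fn main() {
    let module: Module<FFT64> = Module::<FFT64>::new(32);
    let (rows, cols_in, cols_out, size) = (4, 2, 2, 5);

    // Owning constructor: allocates bytes_of(...) internally.
    let _owned: MatZnxDft<FFT64> = MatZnxDft::new(&module, rows, cols_in, cols_out, size);

    // Borrowing constructor: the caller keeps ownership of the aligned bytes,
    // whose length must equal bytes_of(...) (checked by the debug_assert above).
    let mut bytes: Vec<u8> =
        alloc_aligned(MatZnxDft::<FFT64>::bytes_of(&module, rows, cols_in, cols_out, size));
    let _borrowed: MatZnxDft<FFT64> =
        MatZnxDft::from_bytes_borrow(&module, rows, cols_in, cols_out, size, &mut bytes);

    module.free();
}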
@@ -1,8 +1,9 @@
-use crate::ffi::vec_znx_big::vec_znx_big_t;
 use crate::ffi::vec_znx_dft::vec_znx_dft_t;
 use crate::ffi::vmp;
 use crate::znx_base::{ZnxInfos, ZnxLayout};
-use crate::{Backend, FFT64, MatZnxDft, Module, VecZnx, VecZnxBig, VecZnxDft, ZnxAlloc, assert_alignement};
+use crate::{
+    Backend, FFT64, MatZnxDft, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, assert_alignement, is_aligned,
+};

 /// This trait implements methods for vector matrix product,
 /// that is, multiplying a [VecZnx] with a [MatZnxDft].
@@ -13,44 +14,45 @@ pub trait MatZnxDftOps<B: Backend> {
     ///
     /// * `rows`: number of rows (number of [VecZnxDft]).
     /// * `size`: number of size (number of size of each [VecZnxDft]).
-    fn new_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> MatZnxDft<B>;
+    fn new_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxDft<B>;

-    fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> usize;
+    fn bytes_of_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;

-    fn new_mat_znx_dft_from_bytes(&self, rows: usize, cols: usize, size: usize, bytes: Vec<u8>) -> MatZnxDft<FFT64>;
+    fn new_mat_znx_dft_from_bytes(
+        &self,
+        rows: usize,
+        cols_in: usize,
+        cols_out: usize,
+        size: usize,
+        bytes: Vec<u8>,
+    ) -> MatZnxDft<FFT64>;

-    fn new_mat_znx_dft_from_bytes_borrow(&self, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> MatZnxDft<FFT64>;
+    fn new_mat_znx_dft_from_bytes_borrow(
+        &self,
+        rows: usize,
+        cols_in: usize,
+        cols_out: usize,
+        size: usize,
+        bytes: &mut [u8],
+    ) -> MatZnxDft<FFT64>;

-    /// Returns the number of bytes needed as scratch space for [MatZnxDftOps::vmp_prepare_contiguous].
-    ///
-    /// # Arguments
-    ///
-    /// * `rows`: number of rows of the [MatZnxDft] used in [MatZnxDftOps::vmp_prepare_contiguous].
-    /// * `size`: number of size of the [MatZnxDft] used in [MatZnxDftOps::vmp_prepare_contiguous].
-    fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize, size: usize) -> usize;
-
-    /// Prepares a [MatZnxDft] from a contiguous array of [i64].
-    /// The helper struct [Matrix3D] can be used to contruct and populate
-    /// the appropriate contiguous array.
-    ///
-    /// # Arguments
-    ///
-    /// * `b`: [MatZnxDft] on which the values are encoded.
-    /// * `a`: the contiguous array of [i64] of the 3D matrix to encode on the [MatZnxDft].
-    /// * `buf`: scratch space, the size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes].
-    fn vmp_prepare_contiguous(&self, b: &mut MatZnxDft<B>, a: &[i64], buf: &mut [u8]);
+    /// Returns the number of bytes needed as scratch space for [MatZnxDftOps::vmp_prepare_row].
+    fn vmp_prepare_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize;

     /// Prepares the ith-row of [MatZnxDft] from a [VecZnx].
     ///
     /// # Arguments
     ///
     /// * `b`: [MatZnxDft] on which the values are encoded.
-    /// * `a`: the vector of [VecZnx] to encode on the [MatZnxDft].
-    /// * `row_i`: the index of the row to prepare.
+    /// * `row_i`: the row of the [MatZnxDft] to prepare.
+    /// * `a`: the [VecZnx] to encode on the i-th row of the [MatZnxDft].
     /// * `buf`: scratch space, the size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes].
     ///
     /// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes].
-    fn vmp_prepare_row(&self, b: &mut MatZnxDft<B>, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]);
+    fn vmp_prepare_row(&self, b: &mut MatZnxDft<B>, b_row: usize, b_col_in: usize, a: &VecZnx, tmp_bytes: &mut [u8]);
+
+    /// Returns the number of bytes needed as scratch space for [MatZnxDftOps::vmp_extract_row].
+    fn vmp_extract_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize;

     /// Extracts the ith-row of [MatZnxDft] into a [VecZnxBig].
     ///
@@ -59,7 +61,15 @@ pub trait MatZnxDftOps<B: Backend> {
     /// * `b`: the [VecZnxBig] on which to extract the row of the [MatZnxDft].
     /// * `a`: [MatZnxDft] on which the values are encoded.
     /// * `row_i`: the index of the row to extract.
-    fn vmp_extract_row(&self, b: &mut VecZnxBig<B>, a: &MatZnxDft<B>, row_i: usize);
+    fn vmp_extract_row(
+        &self,
+        log_base2k: usize,
+        b: &mut VecZnx,
+        a: &MatZnxDft<B>,
+        b_row: usize,
+        b_col_in: usize,
+        tmp_bytes: &mut [u8],
+    );

     /// Prepares the ith-row of [MatZnxDft] from a [VecZnxDft].
     ///
@@ -70,7 +80,7 @@ pub trait MatZnxDftOps<B: Backend> {
     /// * `row_i`: the index of the row to prepare.
     ///
     /// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes].
-    fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft<B>, a: &VecZnxDft<B>, row_i: usize);
+    fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft<B>, b_row: usize, b_col_in: usize, a: &VecZnxDft<B>);

     /// Extracts the ith-row of [MatZnxDft] into a [VecZnxDft].
     ///
@@ -79,7 +89,7 @@ pub trait MatZnxDftOps<B: Backend> {
     /// * `b`: the [VecZnxDft] on which to extract the row of the [MatZnxDft].
     /// * `a`: [MatZnxDft] on which the values are encoded.
     /// * `row_i`: the index of the row to extract.
-    fn vmp_extract_row_dft(&self, b: &mut VecZnxDft<B>, row_i: usize, a: &MatZnxDft<B>);
+    fn vmp_extract_row_dft(&self, b: &mut VecZnxDft<B>, a: &MatZnxDft<B>, a_row: usize, a_col_in: usize);

     /// Returns the size of the scratch space necessary for [MatZnxDftOps::vmp_apply_dft].
     ///
@@ -89,7 +99,15 @@ pub trait MatZnxDftOps<B: Backend> {
     /// * `a_size`: number of size of the input [VecZnx].
     /// * `rows`: number of rows of the input [MatZnxDft].
     /// * `size`: number of size of the input [MatZnxDft].
-    fn vmp_apply_dft_tmp_bytes(&self, c_size: usize, a_size: usize, b_rows: usize, b_size: usize) -> usize;
+    fn vmp_apply_dft_tmp_bytes(
+        &self,
+        c_size: usize,
+        a_size: usize,
+        b_rows: usize,
+        b_cols_in: usize,
+        b_cols_out: usize,
+        b_size: usize,
+    ) -> usize;

     /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft].
     ///
@@ -117,32 +135,6 @@ pub trait MatZnxDftOps<B: Backend> {
     /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_tmp_bytes].
     fn vmp_apply_dft(&self, c: &mut VecZnxDft<B>, a: &VecZnx, b: &MatZnxDft<B>, buf: &mut [u8]);
-
-    /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft] and adds on the receiver.
-    ///
-    /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
-    /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
-    /// and each vector a [VecZnxDft] (row) of the [MatZnxDft].
-    ///
-    /// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and
-    /// `j` size, the output is a [VecZnx] of `j` size.
-    ///
-    /// If there is a mismatch between the dimensions the largest valid ones are used.
-    ///
-    /// ```text
-    /// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p|
-    ///             |h i j|
-    ///             |k l m|
-    /// ```
-    /// where each element is a [VecZnxDft].
-    ///
-    /// # Arguments
-    ///
-    /// * `c`: the operand on which the output of the vector matrix product is added, as a [VecZnxDft].
-    /// * `a`: the left operand [VecZnx] of the vector matrix product.
-    /// * `b`: the right operand [MatZnxDft] of the vector matrix product.
-    /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_tmp_bytes].
-    fn vmp_apply_dft_add(&self, c: &mut VecZnxDft<B>, a: &VecZnx, b: &MatZnxDft<B>, buf: &mut [u8]);

     /// Returns the size of the scratch space necessary for [MatZnxDftOps::vmp_apply_dft_to_dft].
     ///
     /// # Arguments
@@ -151,7 +143,17 @@ pub trait MatZnxDftOps<B: Backend> {
     /// * `a_size`: number of size of the input [VecZnxDft].
     /// * `rows`: number of rows of the input [MatZnxDft].
     /// * `size`: number of size of the input [MatZnxDft].
-    fn vmp_apply_dft_to_dft_tmp_bytes(&self, c_size: usize, a_size: usize, rows: usize, size: usize) -> usize;
+    fn vmp_apply_dft_to_dft_tmp_bytes(
+        &self,
+        c_cols: usize,
+        c_size: usize,
+        a_cols: usize,
+        a_size: usize,
+        b_rows: usize,
+        b_cols_in: usize,
+        b_cols_out: usize,
+        b_size: usize,
+    ) -> usize;

     /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft].
     /// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
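A note on sizing shared scratch buffers for these *_tmp_bytes functions: the examples in this commit combine two requirements with a bitwise OR. For non-negative sizes, a | b is always at least max(a, b) (every set bit of each operand survives) and at most a + b, so a single buffer sized this way satisfies both call sites. A tiny standalone check:

fn main() {
    // a | b is an upper bound on max(a, b): OR only ever adds set bits,
    // and a + b = (a | b) + (a & b), so it never exceeds the sum.
    for &(a, b) in &[(96usize, 160usize), (1024, 768), (0, 512)] {
        let or = a | b;
        assert!(or >= a.max(b) && or <= a + b);
    }
    println!("bitwise-or sizing is a cheap upper bound on max");
}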
@@ -179,308 +181,385 @@ pub trait MatZnxDftOps<B: Backend> {
     /// * `b`: the right operand [MatZnxDft] of the vector matrix product.
     /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
     fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft<B>, a: &VecZnxDft<B>, b: &MatZnxDft<B>, buf: &mut [u8]);
-
-    /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft] and adds on top of the receiver instead of overwritting it.
-    /// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
-    ///
-    /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
-    /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
-    /// and each vector a [VecZnxDft] (row) of the [MatZnxDft].
-    ///
-    /// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and
-    /// `j` size, the output is a [VecZnx] of `j` size.
-    ///
-    /// If there is a mismatch between the dimensions the largest valid ones are used.
-    ///
-    /// ```text
-    /// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p|
-    ///             |h i j|
-    ///             |k l m|
-    /// ```
-    /// where each element is a [VecZnxDft].
-    ///
-    /// # Arguments
-    ///
-    /// * `c`: the operand on which the output of the vector matrix product is added, as a [VecZnxDft].
-    /// * `a`: the left operand [VecZnxDft] of the vector matrix product.
-    /// * `b`: the right operand [MatZnxDft] of the vector matrix product.
-    /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
-    fn vmp_apply_dft_to_dft_add(&self, c: &mut VecZnxDft<B>, a: &VecZnxDft<B>, b: &MatZnxDft<B>, buf: &mut [u8]);
-
-    /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft] in place.
-    /// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
-    ///
-    /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
-    /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
-    /// and each vector a [VecZnxDft] (row) of the [MatZnxDft].
-    ///
-    /// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and
-    /// `j` size, the output is a [VecZnx] of `j` size.
-    ///
-    /// If there is a mismatch between the dimensions the largest valid ones are used.
-    ///
-    /// ```text
-    /// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p|
-    ///             |h i j|
-    ///             |k l m|
-    /// ```
-    /// where each element is a [VecZnxDft].
-    ///
-    /// # Arguments
-    ///
-    /// * `b`: the input and output of the vector matrix product, as a [VecZnxDft].
-    /// * `a`: the right operand [MatZnxDft] of the vector matrix product.
-    /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
-    fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft<B>, a: &MatZnxDft<B>, buf: &mut [u8]);
 }

 impl MatZnxDftOps<FFT64> for Module<FFT64> {
-    fn new_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> MatZnxDft<FFT64> {
-        MatZnxDft::<FFT64>::new(self, rows, cols, size)
+    fn new_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxDft<FFT64> {
+        MatZnxDft::<FFT64>::new(self, rows, cols_in, cols_out, size)
     }

-    fn bytes_of_mat_znx_dft(&self, rows: usize, cols: usize, size: usize) -> usize {
-        MatZnxDft::<FFT64>::bytes_of(self, rows, cols, size)
+    fn bytes_of_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
+        MatZnxDft::<FFT64>::bytes_of(self, rows, cols_in, cols_out, size)
     }

-    fn new_mat_znx_dft_from_bytes(&self, rows: usize, cols: usize, size: usize, bytes: Vec<u8>) -> MatZnxDft<FFT64> {
-        MatZnxDft::<FFT64>::from_bytes(self, rows, cols, size, bytes)
+    fn new_mat_znx_dft_from_bytes(
+        &self,
+        rows: usize,
+        cols_in: usize,
+        cols_out: usize,
+        size: usize,
+        bytes: Vec<u8>,
+    ) -> MatZnxDft<FFT64> {
+        MatZnxDft::<FFT64>::from_bytes(self, rows, cols_in, cols_out, size, bytes)
     }

-    fn new_mat_znx_dft_from_bytes_borrow(&self, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> MatZnxDft<FFT64> {
-        MatZnxDft::<FFT64>::from_bytes_borrow(self, rows, cols, size, bytes)
+    fn new_mat_znx_dft_from_bytes_borrow(
+        &self,
+        rows: usize,
+        cols_in: usize,
+        cols_out: usize,
+        size: usize,
+        bytes: &mut [u8],
+    ) -> MatZnxDft<FFT64> {
+        MatZnxDft::<FFT64>::from_bytes_borrow(self, rows, cols_in, cols_out, size, bytes)
     }

-    fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize, size: usize) -> usize {
-        unsafe { vmp::vmp_prepare_tmp_bytes(self.ptr, rows as u64, (size * cols) as u64) as usize }
+    fn vmp_prepare_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize {
+        self.bytes_of_vec_znx_dft(cols_out, size)
     }

-    fn vmp_prepare_contiguous(&self, b: &mut MatZnxDft<FFT64>, a: &[i64], tmp_bytes: &mut [u8]) {
-        #[cfg(debug_assertions)]
-        {
-            assert_eq!(a.len(), b.n() * b.poly_count());
-            assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.size()));
-            assert_alignement(tmp_bytes.as_ptr());
-        }
-        unsafe {
-            vmp::vmp_prepare_contiguous(
-                self.ptr,
-                b.as_mut_ptr() as *mut vmp::vmp_pmat_t,
-                a.as_ptr(),
-                b.rows() as u64,
-                (b.size() * b.cols()) as u64,
-                tmp_bytes.as_mut_ptr(),
-            );
-        }
-    }
-
-    fn vmp_prepare_row(&self, b: &mut MatZnxDft<FFT64>, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]) {
-        #[cfg(debug_assertions)]
-        {
-            assert_eq!(a.len(), b.size() * self.n() * b.cols());
-            assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols(), b.size()));
-            assert_alignement(tmp_bytes.as_ptr());
-        }
-        unsafe {
-            vmp::vmp_prepare_row(
-                self.ptr,
-                b.as_mut_ptr() as *mut vmp::vmp_pmat_t,
-                a.as_ptr(),
-                row_i as u64,
-                b.rows() as u64,
-                (b.size() * b.cols()) as u64,
-                tmp_bytes.as_mut_ptr(),
-            );
-        }
-    }
-
-    fn vmp_extract_row(&self, b: &mut VecZnxBig<FFT64>, a: &MatZnxDft<FFT64>, row_i: usize) {
-        #[cfg(debug_assertions)]
-        {
-            assert_eq!(a.n(), b.n());
-            assert_eq!(a.size(), b.size());
-            assert_eq!(a.cols(), b.cols());
-        }
-        unsafe {
-            vmp::vmp_extract_row(
-                self.ptr,
-                b.as_mut_ptr() as *mut vec_znx_big_t,
-                a.as_ptr() as *const vmp::vmp_pmat_t,
-                row_i as u64,
-                a.rows() as u64,
-                (a.size() * a.cols()) as u64,
-            );
-        }
-    }
-
-    fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft<FFT64>, a: &VecZnxDft<FFT64>, row_i: usize) {
-        #[cfg(debug_assertions)]
-        {
-            assert_eq!(a.n(), b.n());
-            assert_eq!(a.size(), b.size());
-        }
+    fn vmp_prepare_row(&self, b: &mut MatZnxDft<FFT64>, b_row: usize, b_col_in: usize, a: &VecZnx, tmp_bytes: &mut [u8]) {
+        #[cfg(debug_assertions)]
+        {
+            assert_eq!(b.n(), self.n());
+            assert_eq!(a.n(), self.n());
+            assert_eq!(
+                a.cols(),
+                b.cols_out(),
+                "a.cols(): {} != b.cols_out(): {}",
+                a.cols(),
+                b.cols_out()
+            );
+            assert!(
+                b_row < b.rows(),
+                "b_row: {} >= b.rows(): {}",
+                b_row,
+                b.rows()
+            );
+            assert!(
+                b_col_in < b.cols_in(),
+                "b_col_in: {} >= b.cols_in(): {}",
+                b_col_in,
+                b.cols_in()
+            );
+            assert_eq!(
+                b.size(),
+                a.size(),
+                "b.size(): {} != a.size(): {}",
+                b.size(),
+                a.size()
+            );
+            assert!(tmp_bytes.len() >= self.vmp_prepare_row_tmp_bytes(a.cols(), a.size()));
+            assert!(is_aligned(tmp_bytes.as_ptr()))
+        }
+        let cols_out: usize = a.cols();
+        let a_size: usize = a.size();
+
+        let (tmp_bytes_a_dft, _) = tmp_bytes.split_at_mut(self.bytes_of_vec_znx_dft(cols_out, a_size));
+
+        let mut a_dft: VecZnxDft<FFT64> = self.new_vec_znx_dft_from_bytes_borrow(cols_out, a_size, tmp_bytes_a_dft);
+        (0..cols_out).for_each(|i| self.vec_znx_dft(&mut a_dft, i, &a, i));
+
+        Self::vmp_prepare_row_dft(&self, b, b_row, b_col_in, &a_dft);
+    }
+
+    fn vmp_extract_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize {
+        self.bytes_of_vec_znx_dft(cols_out, size) + self.vec_znx_big_normalize_tmp_bytes()
+    }
+
+    fn vmp_extract_row(
+        &self,
+        log_base2k: usize,
+        b: &mut VecZnx,
+        a: &MatZnxDft<FFT64>,
+        a_row: usize,
+        a_col_in: usize,
+        tmp_bytes: &mut [u8],
+    ) {
+        #[cfg(debug_assertions)]
+        {
+            assert_eq!(b.n(), self.n());
+            assert_eq!(a.n(), self.n());
+            assert_eq!(
+                b.cols(),
+                a.cols_out(),
+                "b.cols(): {} != a.cols_out(): {}",
+                b.cols(),
+                a.cols_out()
+            );
+            assert!(
+                a_row < a.rows(),
+                "a_row: {} >= a.rows(): {}",
+                a_row,
+                a.rows()
+            );
+            assert!(
+                a_col_in < a.cols_in(),
+                "a_col_in: {} >= a.cols_in(): {}",
+                a_col_in,
+                a.cols_in()
+            );
+            assert_eq!(
+                b.size(),
+                a.size(),
+                "b.size(): {} != a.size(): {}",
+                b.size(),
+                a.size()
+            );
+            assert!(tmp_bytes.len() >= self.vmp_extract_row_tmp_bytes(a.cols(), a.size()));
+            assert!(is_aligned(tmp_bytes.as_ptr()))
+        }
+
+        let cols_out: usize = b.cols();
+        let size: usize = b.size();
+
+        let (bytes_a_dft, tmp_bytes) = tmp_bytes.split_at_mut(self.bytes_of_vec_znx_dft(cols_out, size));
+        let mut b_dft: VecZnxDft<FFT64> = self.new_vec_znx_dft_from_bytes_borrow(cols_out, size, bytes_a_dft);
+        Self::vmp_extract_row_dft(&self, &mut b_dft, a, a_row, a_col_in);
+        let mut b_big: VecZnxBig<FFT64> = b_dft.alias_as_vec_znx_big();
+        (0..cols_out).for_each(|i| {
+            self.vec_znx_idft_tmp_a(&mut b_big, i, &mut b_dft, i);
+            self.vec_znx_big_normalize(log_base2k, b, i, &b_big, i, tmp_bytes);
+        });
+    }
+
+    fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft<FFT64>, b_row: usize, b_col_in: usize, a: &VecZnxDft<FFT64>) {
+        #[cfg(debug_assertions)]
+        {
+            assert_eq!(b.n(), self.n());
+            assert_eq!(a.n(), self.n());
+            assert_eq!(
+                a.cols(),
+                b.cols_out(),
+                "a.cols(): {} != b.cols_out(): {}",
+                a.cols(),
+                b.cols_out()
+            );
+            assert!(
+                b_row < b.rows(),
+                "b_row: {} >= b.rows(): {}",
+                b_row,
+                b.rows()
+            );
+            assert!(
+                b_col_in < b.cols_in(),
+                "b_col_in: {} >= b.cols_in(): {}",
+                b_col_in,
+                b.cols_in()
+            );
+            assert_eq!(
+                b.size(),
+                a.size(),
+                "b.size(): {} != a.size(): {}",
+                b.size(),
+                a.size()
+            );
+        }
         unsafe {
             vmp::vmp_prepare_row_dft(
                 self.ptr,
                 b.as_mut_ptr() as *mut vmp::vmp_pmat_t,
                 a.as_ptr() as *const vec_znx_dft_t,
-                row_i as u64,
-                b.rows() as u64,
-                b.size() as u64,
+                (b_row * b.cols_in() + b_col_in) as u64,
+                (b.rows() * b.cols_in()) as u64,
+                (b.size() * b.cols_out()) as u64,
             );
         }
     }

-    fn vmp_extract_row_dft(&self, b: &mut VecZnxDft<FFT64>, row_i: usize, a: &MatZnxDft<FFT64>) {
+    fn vmp_extract_row_dft(&self, b: &mut VecZnxDft<FFT64>, a: &MatZnxDft<FFT64>, a_row: usize, a_col_in: usize) {
         #[cfg(debug_assertions)]
         {
-            assert_eq!(a.n(), b.n());
-            assert_eq!(a.size(), b.size());
+            assert_eq!(b.n(), self.n());
+            assert_eq!(a.n(), self.n());
+            assert_eq!(
+                b.cols(),
+                a.cols_out(),
+                "b.cols(): {} != a.cols_out(): {}",
+                b.cols(),
+                a.cols_out()
+            );
+            assert!(
+                a_row < a.rows(),
+                "a_row: {} >= a.rows(): {}",
+                a_row,
+                a.rows()
+            );
+            assert!(
+                a_col_in < a.cols_in(),
+                "a_col_in: {} >= a.cols_in(): {}",
+                a_col_in,
+                a.cols_in()
+            );
+            assert_eq!(
+                b.size(),
+                a.size(),
+                "b.size(): {} != a.size(): {}",
+                b.size(),
+                a.size()
+            );
         }
         unsafe {
             vmp::vmp_extract_row_dft(
                 self.ptr,
                 b.as_mut_ptr() as *mut vec_znx_dft_t,
                 a.as_ptr() as *const vmp::vmp_pmat_t,
-                row_i as u64,
-                a.rows() as u64,
-                a.size() as u64,
+                (a_row * a.cols_in() + a_col_in) as u64,
+                (a.rows() * a.cols_in()) as u64,
+                (a.size() * a.cols_out()) as u64,
             );
         }
     }

-    fn vmp_apply_dft_tmp_bytes(&self, res_size: usize, a_size: usize, b_rows: usize, b_size: usize) -> usize {
+    fn vmp_apply_dft_tmp_bytes(
+        &self,
+        res_size: usize,
+        a_size: usize,
+        b_rows: usize,
+        b_cols_in: usize,
+        b_cols_out: usize,
+        b_size: usize,
+    ) -> usize {
         unsafe {
             vmp::vmp_apply_dft_tmp_bytes(
                 self.ptr,
                 res_size as u64,
                 a_size as u64,
-                b_rows as u64,
-                b_size as u64,
+                (b_rows * b_cols_in) as u64,
+                (b_size * b_cols_out) as u64,
             ) as usize
         }
     }

     fn vmp_apply_dft(&self, c: &mut VecZnxDft<FFT64>, a: &VecZnx, b: &MatZnxDft<FFT64>, tmp_bytes: &mut [u8]) {
-        debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size()));
+        debug_assert!(
+            tmp_bytes.len()
+                >= self.vmp_apply_dft_tmp_bytes(
+                    c.size(),
+                    a.size(),
+                    b.rows(),
+                    b.cols_in(),
+                    b.cols_out(),
+                    b.size()
+                )
+        );
         #[cfg(debug_assertions)]
         {
+            assert_eq!(c.n(), self.n());
+            assert_eq!(b.n(), self.n());
+            assert_eq!(a.n(), self.n());
+            assert_eq!(
+                c.cols(),
+                b.cols_out(),
+                "c.cols(): {} != b.cols_out: {}",
+                c.cols(),
+                b.cols_out()
+            );
+            assert_eq!(
+                a.cols(),
+                b.cols_in(),
+                "a.cols(): {} != b.cols_in: {}",
+                a.cols(),
+                b.cols_in()
+            );
+            assert!(
+                tmp_bytes.len()
+                    >= self.vmp_apply_dft_tmp_bytes(
+                        c.size(),
+                        a.size(),
+                        b.rows(),
+                        b.cols_in(),
+                        b.cols_out(),
+                        b.size()
+                    )
+            );
             assert_alignement(tmp_bytes.as_ptr());
         }
         unsafe {
             vmp::vmp_apply_dft(
                 self.ptr,
                 c.as_mut_ptr() as *mut vec_znx_dft_t,
-                c.size() as u64,
+                (c.size() * c.cols()) as u64,
                 a.as_ptr(),
-                a.size() as u64,
-                (a.n() * a.cols()) as u64,
+                (a.size() * a.cols()) as u64,
+                a.n() as u64,
                 b.as_ptr() as *const vmp::vmp_pmat_t,
-                b.rows() as u64,
-                b.size() as u64,
+                (b.rows() * b.cols_in()) as u64,
+                (b.size() * b.cols_out()) as u64,
                 tmp_bytes.as_mut_ptr(),
             )
         }
     }

-    fn vmp_apply_dft_add(&self, c: &mut VecZnxDft<FFT64>, a: &VecZnx, b: &MatZnxDft<FFT64>, tmp_bytes: &mut [u8]) {
-        debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size()));
-        #[cfg(debug_assertions)]
-        {
-            assert_alignement(tmp_bytes.as_ptr());
-        }
-        unsafe {
-            vmp::vmp_apply_dft_add(
-                self.ptr,
-                c.as_mut_ptr() as *mut vec_znx_dft_t,
-                c.size() as u64,
-                a.as_ptr(),
-                a.size() as u64,
-                (a.n() * a.size()) as u64,
-                b.as_ptr() as *const vmp::vmp_pmat_t,
-                b.rows() as u64,
-                b.size() as u64,
-                tmp_bytes.as_mut_ptr(),
-            )
-        }
-    }
-
-    fn vmp_apply_dft_to_dft_tmp_bytes(&self, res_size: usize, a_size: usize, gct_rows: usize, gct_size: usize) -> usize {
+    fn vmp_apply_dft_to_dft_tmp_bytes(
+        &self,
+        res_cols: usize,
+        res_size: usize,
+        a_size: usize,
+        a_cols: usize,
+        b_rows: usize,
+        b_cols_in: usize,
+        b_cols_out: usize,
+        b_size: usize,
+    ) -> usize {
         unsafe {
             vmp::vmp_apply_dft_to_dft_tmp_bytes(
                 self.ptr,
-                res_size as u64,
-                a_size as u64,
-                gct_rows as u64,
-                gct_size as u64,
+                (res_size * res_cols) as u64,
+                (a_size * a_cols) as u64,
+                (b_rows * b_cols_in) as u64,
+                (b_size * b_cols_out) as u64,
             ) as usize
         }
     }

     fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft<FFT64>, a: &VecZnxDft<FFT64>, b: &MatZnxDft<FFT64>, tmp_bytes: &mut [u8]) {
-        debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size()));
         #[cfg(debug_assertions)]
         {
+            assert_eq!(c.n(), self.n());
+            assert_eq!(b.n(), self.n());
+            assert_eq!(a.n(), self.n());
+            assert_eq!(
+                c.cols(),
+                b.cols_out(),
+                "c.cols(): {} != b.cols_out: {}",
+                c.cols(),
+                b.cols_out()
+            );
+            assert_eq!(
+                a.cols(),
+                b.cols_in(),
+                "a.cols(): {} != b.cols_in: {}",
+                a.cols(),
+                b.cols_in()
+            );
+            assert!(
+                tmp_bytes.len()
+                    >= self.vmp_apply_dft_to_dft_tmp_bytes(
+                        c.cols(),
+                        c.size(),
+                        a.cols(),
+                        a.size(),
+                        b.rows(),
+                        b.cols_in(),
+                        b.cols_out(),
+                        b.size()
+                    )
+            );
             assert_alignement(tmp_bytes.as_ptr());
         }
         unsafe {
             vmp::vmp_apply_dft_to_dft(
                 self.ptr,
                 c.as_mut_ptr() as *mut vec_znx_dft_t,
-                c.size() as u64,
+                c.poly_count() as u64,
                 a.as_ptr() as *const vec_znx_dft_t,
-                a.size() as u64,
+                a.poly_count() as u64,
                 b.as_ptr() as *const vmp::vmp_pmat_t,
                 b.rows() as u64,
-                b.size() as u64,
+                (b.size() * b.cols()) as u64,
                 tmp_bytes.as_mut_ptr(),
             )
         }
     }
-
-    fn vmp_apply_dft_to_dft_add(
-        &self,
-        c: &mut VecZnxDft<FFT64>,
-        a: &VecZnxDft<FFT64>,
-        b: &MatZnxDft<FFT64>,
-        tmp_bytes: &mut [u8],
-    ) {
-        debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.size(), a.size(), b.rows(), b.size()));
-        #[cfg(debug_assertions)]
-        {
-            assert_alignement(tmp_bytes.as_ptr());
-        }
-        unsafe {
-            vmp::vmp_apply_dft_to_dft_add(
-                self.ptr,
-                c.as_mut_ptr() as *mut vec_znx_dft_t,
-                c.size() as u64,
-                a.as_ptr() as *const vec_znx_dft_t,
-                a.size() as u64,
-                b.as_ptr() as *const vmp::vmp_pmat_t,
-                b.rows() as u64,
-                b.size() as u64,
-                tmp_bytes.as_mut_ptr(),
-            )
-        }
-    }
-
-    fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft<FFT64>, a: &MatZnxDft<FFT64>, tmp_bytes: &mut [u8]) {
-        debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(b.size(), b.size(), a.rows(), a.size()));
-        #[cfg(debug_assertions)]
-        {
-            assert_alignement(tmp_bytes.as_ptr());
-        }
-        unsafe {
-            vmp::vmp_apply_dft_to_dft(
-                self.ptr,
-                b.as_mut_ptr() as *mut vec_znx_dft_t,
-                b.size() as u64,
-                b.as_ptr() as *mut vec_znx_dft_t,
-                b.size() as u64,
-                a.as_ptr() as *const vmp::vmp_pmat_t,
-                a.rows() as u64,
-                a.size() as u64,
-                tmp_bytes.as_mut_ptr(),
-            )
-        }
-    }
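The FFI calls above flatten the (row, col_in) pair into a single row index of the underlying vmp_pmat_t via (b_row * b.cols_in() + b_col_in), with (rows * cols_in) total rows of (size * cols_out) polynomials each. A small standalone model of that indexing (the actual storage lives inside spqlios-arithmetic):

// Row-major flattening of (row, col_in) into one index, as used by the
// vmp_prepare_row_dft / vmp_extract_row_dft calls above.
fn flat_index(row: usize, col_in: usize, cols_in: usize) -> usize {
    row * cols_in + col_in
}

fn main() {
    let (rows, cols_in) = (4usize, 2usize);
    let mut seen = vec![false; rows * cols_in];
    for row in 0..rows {
        for col_in in 0..cols_in {
            let i = flat_index(row, col_in, cols_in);
            assert!(!seen[i]); // each (row, col_in) maps to a distinct slot
            seen[i] = true;
        }
    }
    assert!(seen.iter().all(|&s| s)); // and the mapping covers 0..rows*cols_in
}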
@@ -497,38 +576,52 @@ mod tests {

     #[test]
     fn vmp_prepare_row_dft() {
-        let module: Module<FFT64> = Module::<FFT64>::new(32);
-        let vpmat_rows: usize = 4;
-        let vpmat_size: usize = 5;
-        let log_base2k: usize = 8;
-        let mut a: VecZnx = module.new_vec_znx(1, vpmat_size);
-        let mut a_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, vpmat_size);
-        let mut a_big: VecZnxBig<FFT64> = module.new_vec_znx_big(1, vpmat_size);
-        let mut b_big: VecZnxBig<FFT64> = module.new_vec_znx_big(1, vpmat_size);
-        let mut b_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(1, vpmat_size);
-        let mut vmpmat_0: MatZnxDft<FFT64> = module.new_mat_znx_dft(vpmat_rows, 1, vpmat_size);
-        let mut vmpmat_1: MatZnxDft<FFT64> = module.new_mat_znx_dft(vpmat_rows, 1, vpmat_size);
-
-        let mut tmp_bytes: Vec<u8> = alloc_aligned(module.vmp_prepare_tmp_bytes(vpmat_rows, 1, vpmat_size));
-
-        for row_i in 0..vpmat_rows {
-            let mut source: Source = Source::new([0u8; 32]);
-            module.fill_uniform(log_base2k, &mut a, 0, vpmat_size, &mut source);
-            module.vec_znx_dft(&mut a_dft, 0, &a, 0);
-            module.vmp_prepare_row(&mut vmpmat_0, &a.raw(), row_i, &mut tmp_bytes);
-
-            // Checks that prepare(mat_znx_dft, a) = prepare_dft(mat_znx_dft, a_dft)
-            module.vmp_prepare_row_dft(&mut vmpmat_1, &a_dft, row_i);
-            assert_eq!(vmpmat_0.raw(), vmpmat_1.raw());
-
-            // Checks that a_dft = extract_dft(prepare(mat_znx_dft, a), b_dft)
-            module.vmp_extract_row_dft(&mut b_dft, row_i, &vmpmat_0);
-            assert_eq!(a_dft.raw(), b_dft.raw());
-
-            // Checks that a_big = extract(prepare_dft(mat_znx_dft, a_dft), b_big)
-            module.vmp_extract_row(&mut b_big, &vmpmat_0, row_i);
-            module.vec_znx_idft(&mut a_big, 0, &a_dft, 0, &mut tmp_bytes);
-            assert_eq!(a_big.raw(), b_big.raw());
-        }
+        let module: Module<FFT64> = Module::<FFT64>::new(16);
+        let log_base2k: usize = 8;
+        let mat_rows: usize = 4;
+        let mat_cols_in: usize = 2;
+        let mat_cols_out: usize = 2;
+        let mat_size: usize = 5;
+        let mut a: VecZnx = module.new_vec_znx(mat_cols_out, mat_size);
+        let mut b: VecZnx = module.new_vec_znx(mat_cols_out, mat_size);
+        let mut a_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size);
+        let mut a_big: VecZnxBig<FFT64> = module.new_vec_znx_big(mat_cols_out, mat_size);
+        let mut b_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size);
+        let mut vmpmat_0: MatZnxDft<FFT64> = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size);
+        let mut vmpmat_1: MatZnxDft<FFT64> = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size);
+
+        let mut tmp_bytes: Vec<u8> =
+            alloc_aligned(module.vmp_prepare_row_tmp_bytes(mat_cols_out, mat_size) | module.vec_znx_big_normalize_tmp_bytes());
+
+        for col_in in 0..mat_cols_in {
+            for row_i in 0..mat_rows {
+                let mut source: Source = Source::new([0u8; 32]);
+
+                (0..mat_cols_out).for_each(|col_out| {
+                    module.fill_uniform(log_base2k, &mut a, col_out, mat_size, &mut source);
+                    module.vec_znx_dft(&mut a_dft, col_out, &a, col_out);
+                });
+
+                module.vmp_prepare_row(&mut vmpmat_0, row_i, col_in, &a, &mut tmp_bytes);
+
+                // Checks that prepare(mat_znx_dft, a) = prepare_dft(mat_znx_dft, a_dft)
+                module.vmp_prepare_row_dft(&mut vmpmat_1, row_i, col_in, &a_dft);
+                assert_eq!(vmpmat_0.raw(), vmpmat_1.raw());
+
+                // Checks that a_dft = extract_dft(prepare(mat_znx_dft, a), b_dft)
+                module.vmp_extract_row_dft(&mut b_dft, &vmpmat_0, row_i, col_in);
+                assert_eq!(a_dft.raw(), b_dft.raw());
+
+                // Checks that a_big = extract(prepare_dft(mat_znx_dft, a_dft), b_big)
+                module.vmp_extract_row(log_base2k, &mut b, &vmpmat_0, row_i, col_in, &mut tmp_bytes);
+
+                (0..mat_cols_out).for_each(|col_out| {
+                    module.vec_znx_idft(&mut a_big, col_out, &a_dft, col_out, &mut tmp_bytes);
+                    module.vec_znx_big_normalize(log_base2k, &mut a, col_out, &a_big, col_out, &mut tmp_bytes);
+                });
+
+                assert_eq!(a.raw(), b.raw());
+            }
+        }

         module.free();
@@ -28,6 +28,7 @@ impl<B: Backend> ZnxAlloc<B> for ScalarZnxDft<B> {
     type Scalar = u8;

     fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, _size: usize, bytes: &mut [u8]) -> Self {
+        debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, _size));
         Self {
             inner: ZnxBase::from_bytes_borrow(
                 module.n(),
@@ -61,6 +62,6 @@ impl ZnxLayout for ScalarZnxDft<FFT64> {

 impl ZnxSliceSize for ScalarZnxDft<FFT64> {
     fn sl(&self) -> usize {
-        self.n()
+        self.n() * self.cols()
     }
 }
@@ -3,7 +3,7 @@ use crate::Module;
 use crate::assert_alignement;
 use crate::cast_mut;
 use crate::ffi::znx;
-use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, ZnxSliceSize, switch_degree};
+use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxRsh, ZnxSliceSize, ZnxZero, switch_degree};
 use std::cmp::min;

 pub const VEC_ZNX_ROWS: usize = 1;
@@ -44,7 +44,9 @@ impl ZnxLayout for VecZnx {
     type Scalar = i64;
 }

-impl ZnxBasics for VecZnx {}
+impl ZnxZero for VecZnx {}
+
+impl ZnxRsh for VecZnx {}

 impl<B: Backend> ZnxAlloc<B> for VecZnx {
     type Scalar = i64;
@@ -84,7 +86,7 @@ impl VecZnx {
     ///
     /// * `log_base2k`: the base two logarithm of the coefficients decomposition.
     /// * `k`: the number of bits of precision to drop.
-    pub fn trunc_pow2(&mut self, log_base2k: usize, k: usize) {
+    pub fn trunc_pow2(&mut self, log_base2k: usize, k: usize, col: usize) {
         if k == 0 {
             return;
         }
@@ -101,7 +103,7 @@ impl VecZnx {

         if k_rem != 0 {
             let mask: i64 = ((1 << (log_base2k - k_rem - 1)) - 1) << k_rem;
-            self.at_limb_mut(self.size() - 1)
+            self.at_mut(col, self.size() - 1)
                 .iter_mut()
                 .for_each(|x: &mut i64| *x &= mask)
         }
@@ -111,8 +113,8 @@ impl VecZnx {
         copy_vec_znx_from(self, a);
     }

-    pub fn normalize(&mut self, log_base2k: usize, carry: &mut [u8]) {
-        normalize(log_base2k, self, carry)
+    pub fn normalize(&mut self, log_base2k: usize, col: usize, carry: &mut [u8]) {
+        normalize(log_base2k, self, col, carry)
     }

     pub fn switch_degree(&self, col: usize, a: &mut Self, col_a: usize) {
@@ -120,26 +122,25 @@ impl VecZnx {
     }

     // Prints the first `n` coefficients of each limb
-    pub fn print(&self, n: usize) {
-        (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n]))
+    pub fn print(&self, n: usize, col: usize) {
+        (0..self.size()).for_each(|j| println!("{}: {:?}", j, &self.at(col, j)[..n]));
     }
 }

-fn normalize_tmp_bytes(n: usize, size: usize) -> usize {
-    n * size * std::mem::size_of::<i64>()
+fn normalize_tmp_bytes(n: usize) -> usize {
+    n * std::mem::size_of::<i64>()
 }

-fn normalize(log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]) {
+fn normalize(log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]) {
     let n: usize = a.n();
-    let cols: usize = a.cols();

     debug_assert!(
-        tmp_bytes.len() >= normalize_tmp_bytes(n, cols),
-        "invalid tmp_bytes: tmp_bytes.len()={} < normalize_tmp_bytes({}, {})",
+        tmp_bytes.len() >= normalize_tmp_bytes(n),
+        "invalid tmp_bytes: tmp_bytes.len()={} < normalize_tmp_bytes({})",
         tmp_bytes.len(),
         n,
-        cols,
     );

     #[cfg(debug_assertions)]
     {
         assert_alignement(tmp_bytes.as_ptr())
@@ -151,11 +152,11 @@ fn normalize(log_base2k: usize, a: &mut VecZnx, tmp_bytes: &mut [u8]) {
|
|||||||
znx::znx_zero_i64_ref(n as u64, carry_i64.as_mut_ptr());
|
znx::znx_zero_i64_ref(n as u64, carry_i64.as_mut_ptr());
|
||||||
(0..a.size()).rev().for_each(|i| {
|
(0..a.size()).rev().for_each(|i| {
|
||||||
znx::znx_normalize(
|
znx::znx_normalize(
|
||||||
(n * cols) as u64,
|
n as u64,
|
||||||
log_base2k as u64,
|
log_base2k as u64,
|
||||||
a.at_mut_ptr(0, i),
|
a.at_mut_ptr(a_col, i),
|
||||||
carry_i64.as_mut_ptr(),
|
carry_i64.as_mut_ptr(),
|
||||||
a.at_mut_ptr(0, i),
|
a.at_mut_ptr(a_col, i),
|
||||||
carry_i64.as_mut_ptr(),
|
carry_i64.as_mut_ptr(),
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
|
|||||||
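The normalization pass above is the core of the base-2^k representation: limbs are visited from least to most significant, every coefficient is folded into the centered interval [-2^(k-1), 2^(k-1)) for k = log_base2k, and the quotient is carried into the next limb. A standalone sketch of that carry rule, as an independent re-implementation for illustration only (the crate delegates the real work to the znx_normalize FFI call above; limb ordering follows trunc_pow2, which masks precision bits in the last limb, so limbs[0] is taken as most significant):

// Illustrative re-implementation of per-column base-2^k normalization.
// Iterating in reverse walks from the least to the most significant limb.
fn normalize_column(log_base2k: usize, limbs: &mut [Vec<i64>]) {
    let base: i64 = 1 << log_base2k;
    let half: i64 = base >> 1;
    let n: usize = limbs[0].len();
    let mut carry: Vec<i64> = vec![0; n]; // plays the role of the tmp_bytes scratch
    for limb in limbs.iter_mut().rev() {
        for (x, c) in limb.iter_mut().zip(carry.iter_mut()) {
            let v: i64 = *x + *c;
            let mut r: i64 = v % base; // centered remainder in [-2^(k-1), 2^(k-1))
            if r >= half {
                r -= base;
            } else if r < -half {
                r += base;
            }
            *c = (v - r) >> log_base2k; // exact: v - r is a multiple of 2^k
            *x = r;
        }
    }
}

This also shows why the scratch now only needs n i64 values: the carry lives per coefficient of one column, no longer per coefficient times cols.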
@@ -1,5 +1,5 @@
 use crate::ffi::vec_znx_big;
-use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxBasics, ZnxInfos, ZnxLayout, ZnxSliceSize};
+use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize, ZnxZero};
 use crate::{Backend, FFT64, Module, NTT120};
 use std::marker::PhantomData;

@@ -26,6 +26,7 @@ impl<B: Backend> ZnxAlloc<B> for VecZnxBig<B> {
     type Scalar = u8;

     fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
+        debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, size));
         VecZnxBig {
             inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_BIG_ROWS, cols, size, bytes),
             _marker: PhantomData,
@@ -50,24 +51,24 @@ impl ZnxLayout for VecZnxBig<NTT120> {
     type Scalar = i128;
 }

-impl ZnxBasics for VecZnxBig<FFT64> {}
+impl ZnxZero for VecZnxBig<FFT64> {}

 impl ZnxSliceSize for VecZnxBig<FFT64> {
     fn sl(&self) -> usize {
-        self.n()
+        self.n() * self.cols()
     }
 }

 impl ZnxSliceSize for VecZnxBig<NTT120> {
     fn sl(&self) -> usize {
-        self.n() * 4
+        self.n() * 4 * self.cols()
     }
 }

-impl ZnxBasics for VecZnxBig<NTT120> {}
+impl ZnxZero for VecZnxBig<NTT120> {}

 impl VecZnxBig<FFT64> {
-    pub fn print(&self, n: usize) {
-        (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n]));
+    pub fn print(&self, n: usize, col: usize) {
+        (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at(col, i)[..n]));
     }
 }
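The two `sl()` fixes follow from the interleaved layout used throughout the crate: limb j of column i starts at scalar offset n * (j * cols + i), so stepping from one limb of a column to the next crosses every column. A quick standalone check of the stride arithmetic (illustrative; the NTT120 variant additionally multiplies by its 4 machine words per coefficient slot, taken from the diff as-is):

// The offset formula mirrors ZnxLayout::at_ptr further down in this commit.
fn offset(n: usize, cols: usize, i: usize, j: usize) -> usize {
    n * (j * cols + i)
}

fn main() {
    let (n, cols) = (1024, 2);
    // Distance between limb j and limb j + 1 of the same column:
    let stride = offset(n, cols, 0, 1) - offset(n, cols, 0, 0);
    assert_eq!(stride, n * cols); // what the FFT64 sl() now returns
}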
@@ -1,4 +1,4 @@
-use crate::ffi::vec_znx_big::{self, vec_znx_big_t};
+use crate::ffi::vec_znx;
 use crate::znx_base::{ZnxAlloc, ZnxInfos, ZnxLayout, ZnxSliceSize};
 use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxOps, assert_alignement};

@@ -171,14 +171,17 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
             assert_ne!(a.as_ptr(), b.as_ptr());
         }
         unsafe {
-            vec_znx_big::vec_znx_big_add(
+            vec_znx::vec_znx_add(
                 self.ptr,
-                res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t,
+                res.at_mut_ptr(res_col, 0),
                 res.size() as u64,
-                a.at_ptr(a_col * res.size(), 0) as *const vec_znx_big_t,
+                res.sl() as u64,
+                a.at_ptr(a_col, 0),
                 a.size() as u64,
-                b.at_ptr(b_col * res.size(), 0) as *const vec_znx_big_t,
+                a.sl() as u64,
+                b.at_ptr(b_col, 0),
                 b.size() as u64,
+                b.sl() as u64,
             )
         }
     }
@@ -207,14 +210,17 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
             assert_ne!(a.as_ptr(), b.as_ptr());
         }
         unsafe {
-            vec_znx_big::vec_znx_big_sub(
+            vec_znx::vec_znx_sub(
                 self.ptr,
-                res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t,
+                res.at_mut_ptr(res_col, 0),
                 res.size() as u64,
-                a.at_ptr(a_col * res.size(), 0) as *const vec_znx_big_t,
+                res.sl() as u64,
+                a.at_ptr(a_col, 0),
                 a.size() as u64,
-                b.at_ptr(b_col * res.size(), 0) as *const vec_znx_big_t,
+                a.sl() as u64,
+                b.at_ptr(b_col, 0),
                 b.size() as u64,
+                b.sl() as u64,
             )
         }
     }
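With the switch from the `vec_znx_big_*` entry points to the generic `vec_znx_*` ones, every operand is passed as a (pointer to the column's first limb, limb count, slice stride) triple rather than as a big-vector handle; the stride is what lets the C side walk a single column of an interleaved buffer. A safe-Rust model of the resulting access pattern (illustrative only; the actual arithmetic happens inside the FFI):

// res[col res_col] = a[col a_col] + b[col b_col], limb by limb, where each
// operand lives in an interleaved buffer with slice stride `sl` scalars.
fn column_add(
    n: usize,
    size: usize,
    res: &mut [i64], res_sl: usize, res_col: usize,
    a: &[i64], a_sl: usize, a_col: usize,
    b: &[i64], b_sl: usize, b_col: usize,
) {
    for j in 0..size {
        let (ro, ao, bo) = (j * res_sl + res_col * n, j * a_sl + a_col * n, j * b_sl + b_col * n);
        for t in 0..n {
            res[ro + t] = a[ao + t] + b[bo + t];
        }
    }
}

fn main() {
    let (n, cols, size) = (2, 2, 2);
    let sl = n * cols;
    let a = vec![1i64; n * cols * size];
    let b = vec![2i64; n * cols * size];
    let mut res = vec![0i64; n * cols * size];
    column_add(n, size, &mut res, sl, 0, &a, sl, 0, &b, sl, 1);
    assert_eq!(&res[..n], &[3, 3]); // column 0, limb 0
}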
@@ -250,12 +256,14 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
             assert_ne!(a.as_ptr(), b.as_ptr());
         }
         unsafe {
-            vec_znx_big::vec_znx_big_sub_small_b(
+            vec_znx::vec_znx_sub(
                 self.ptr,
-                res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t,
+                res.at_mut_ptr(res_col, 0),
                 res.size() as u64,
-                a.at_ptr(a_col * a.size(), 0) as *const vec_znx_big_t,
+                res.sl() as u64,
+                a.at_ptr(a_col, 0),
                 a.size() as u64,
+                a.sl() as u64,
                 b.at_ptr(b_col, 0),
                 b.size() as u64,
                 b.sl() as u64,
@@ -287,15 +295,17 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
             assert_ne!(a.as_ptr(), b.as_ptr());
         }
         unsafe {
-            vec_znx_big::vec_znx_big_sub_small_a(
+            vec_znx::vec_znx_sub(
                 self.ptr,
-                res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t,
+                res.at_mut_ptr(res_col, 0),
                 res.size() as u64,
+                res.sl() as u64,
                 a.at_ptr(a_col, 0),
                 a.size() as u64,
                 a.sl() as u64,
-                b.at_ptr(b_col * b.size(), 0) as *const vec_znx_big_t,
+                b.at_ptr(b_col, 0),
                 b.size() as u64,
+                b.sl() as u64,
             )
         }
     }
@@ -324,12 +334,14 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
             assert_ne!(a.as_ptr(), b.as_ptr());
         }
         unsafe {
-            vec_znx_big::vec_znx_big_add_small(
+            vec_znx::vec_znx_add(
                 self.ptr,
-                res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t,
+                res.at_mut_ptr(res_col, 0),
                 res.size() as u64,
-                a.at_ptr(a_col * a.size(), 0) as *const vec_znx_big_t,
+                res.sl() as u64,
+                a.at_ptr(a_col, 0),
                 a.size() as u64,
+                a.sl() as u64,
                 b.at_ptr(b_col, 0),
                 b.size() as u64,
                 b.sl() as u64,
@@ -365,14 +377,15 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
             assert_alignement(tmp_bytes.as_ptr());
         }
         unsafe {
-            vec_znx_big::vec_znx_big_normalize_base2k(
+            vec_znx::vec_znx_normalize_base2k(
                 self.ptr,
                 log_base2k as u64,
                 res.at_mut_ptr(res_col, 0),
                 res.size() as u64,
                 res.sl() as u64,
-                a.at_ptr(a_col * a.size(), 0) as *const vec_znx_big_t,
+                a.at_ptr(a_col, 0),
                 a.size() as u64,
+                a.sl() as u64,
                 tmp_bytes.as_mut_ptr(),
             );
         }
@@ -385,13 +398,15 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
             assert_eq!(res.n(), self.n());
         }
         unsafe {
-            vec_znx_big::vec_znx_big_automorphism(
+            vec_znx::vec_znx_automorphism(
                 self.ptr,
                 k,
-                res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big_t,
+                res.at_mut_ptr(res_col, 0),
                 res.size() as u64,
-                a.at_ptr(a_col * a.size(), 0) as *const vec_znx_big_t,
+                res.sl() as u64,
+                a.at_ptr(a_col, 0),
                 a.size() as u64,
+                a.sl() as u64,
             )
         }
     }
@@ -1,5 +1,5 @@
 use crate::ffi::vec_znx_dft;
-use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize};
+use crate::znx_base::{GetZnxBase, ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize, ZnxZero};
 use crate::{Backend, FFT64, Module, VecZnxBig};
 use std::marker::PhantomData;

@@ -26,6 +26,7 @@ impl<B: Backend> ZnxAlloc<B> for VecZnxDft<B> {
     type Scalar = u8;

     fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
+        debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, size));
         Self {
             inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_DFT_ROWS, cols, size, bytes),
             _marker: PhantomData,
@@ -46,6 +47,8 @@ impl ZnxLayout for VecZnxDft<FFT64> {
     type Scalar = f64;
 }

+impl ZnxZero for VecZnxDft<FFT64> {}
+
 impl ZnxSliceSize for VecZnxDft<FFT64> {
     fn sl(&self) -> usize {
         self.n()
@@ -53,8 +56,8 @@ impl ZnxSliceSize for VecZnxDft<FFT64> {
 }

 impl VecZnxDft<FFT64> {
-    pub fn print(&self, n: usize) {
-        (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at_limb(i)[..n]));
+    pub fn print(&self, n: usize, col: usize) {
+        (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at(col, i)[..n]));
     }
 }
@@ -63,6 +66,10 @@ impl<B: Backend> VecZnxDft<B> {
     /// The returned [VecZnxBig] shares the backing array
     /// with the original [VecZnxDft].
     pub fn alias_as_vec_znx_big(&mut self) -> VecZnxBig<B> {
+        assert!(
+            self.data().len() == 0,
+            "cannot alias VecZnxDft into VecZnxBig if it owns the data"
+        );
         VecZnxBig::<B> {
             inner: ZnxBase {
                 data: Vec::new(),
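The new assertion makes the aliasing rule explicit: a `VecZnxDft` built with `from_bytes_borrow` holds an empty `data` vector and a raw `ptr` into someone else's buffer, and only such borrowing instances may be reinterpreted, because the alias copies `ptr` without taking over ownership. A toy model of the invariant (not the crate's types):

// Toy model: a view either owns its bytes (data non-empty) or borrows them
// (data empty, ptr set elsewhere). Aliasing an owning view would leave two
// owners of one allocation, hence the assert.
struct View {
    data: Vec<u8>,
    ptr: *mut u8,
}

impl View {
    fn alias(&mut self) -> View {
        assert!(self.data.len() == 0, "cannot alias a view that owns its data");
        View { data: Vec::new(), ptr: self.ptr }
    }
}

fn main() {
    let mut backing = vec![0u8; 64];
    let mut borrowed = View { data: Vec::new(), ptr: backing.as_mut_ptr() };
    let _alias = borrowed.alias(); // fine: both views borrow `backing`
}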
@@ -4,7 +4,8 @@ use crate::znx_base::ZnxAlloc;
 use crate::znx_base::ZnxInfos;
 use crate::znx_base::ZnxLayout;
 use crate::znx_base::ZnxSliceSize;
-use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxDft, assert_alignement};
+use crate::{Backend, FFT64, Module, VecZnx, VecZnxBig, VecZnxDft, ZnxZero, assert_alignement};
+use std::cmp::min;

 pub trait VecZnxDftOps<B: Backend> {
     /// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space.
@@ -77,19 +78,21 @@ impl VecZnxDftOps<FFT64> for Module<FFT64> {
     }

     fn vec_znx_idft_tmp_a(&self, res: &mut VecZnxBig<FFT64>, res_col: usize, a: &mut VecZnxDft<FFT64>, a_col: usize) {
-        #[cfg(debug_assertions)]
-        {
-            assert_eq!(res.poly_count(), a.poly_count());
-        }
+        let min_size: usize = min(res.size(), a.size());

         unsafe {
-            vec_znx_dft::vec_znx_idft_tmp_a(
-                self.ptr,
-                res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big::vec_znx_big_t,
-                res.size() as u64,
-                a.at_ptr(a_col * a.size(), 0) as *mut vec_znx_dft::vec_znx_dft_t,
-                a.size() as u64,
-            )
+            (0..min_size).for_each(|j| {
+                vec_znx_dft::vec_znx_idft_tmp_a(
+                    self.ptr,
+                    res.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t,
+                    1 as u64,
+                    a.at_ptr(a_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
+                    1 as u64,
+                )
+            });
+            (min_size..res.size()).for_each(|j| {
+                res.zero_at(res_col, j);
+            })
         }
     }
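`vec_znx_idft_tmp_a` no longer insists on equal limb counts: the first min(res.size(), a.size()) limbs are converted one at a time, and any surplus limbs of the result are cleared through `zero_at`. The same pattern is applied to `vec_znx_dft` and `vec_znx_idft` below. Its effective semantics, sketched on plain vectors (a stand-in for the per-limb FFI calls, not the crate's code):

use std::cmp::min;

// Convert what both sides have in common; zero the tail of the result.
fn convert_with_padding(res: &mut [Vec<f64>], a: &[Vec<f64>]) {
    let min_size = min(res.len(), a.len());
    for j in 0..min_size {
        res[j].copy_from_slice(&a[j]); // stand-in for the per-limb FFI call
    }
    for j in min_size..res.len() {
        res[j].iter_mut().for_each(|x| *x = 0.0); // stand-in for zero_at
    }
}

fn main() {
    let mut res = vec![vec![1.0; 4]; 3];
    let a = vec![vec![2.0; 4]; 2];
    convert_with_padding(&mut res, &a);
    assert_eq!(res[2], vec![0.0; 4]); // the extra limb was zeroed
}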
@@ -102,15 +105,22 @@
     /// # Panics
     /// If b.cols < a_cols
     fn vec_znx_dft(&self, res: &mut VecZnxDft<FFT64>, res_col: usize, a: &VecZnx, a_col: usize) {
+        let min_size: usize = min(res.size(), a.size());
+
         unsafe {
-            vec_znx_dft::vec_znx_dft(
-                self.ptr,
-                res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_dft::vec_znx_dft_t,
-                res.size() as u64,
-                a.at_ptr(a_col, 0),
-                a.size() as u64,
-                a.sl() as u64,
-            )
+            (0..min_size).for_each(|j| {
+                vec_znx_dft::vec_znx_dft(
+                    self.ptr,
+                    res.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
+                    1 as u64,
+                    a.at_ptr(a_col, j),
+                    1 as u64,
+                    a.sl() as u64,
+                )
+            });
+            (min_size..res.size()).for_each(|j| {
+                res.zero_at(res_col, j);
+            });
         }
     }

@@ -126,15 +136,23 @@
             );
             assert_alignement(tmp_bytes.as_ptr())
         }

+        let min_size: usize = min(res.size(), a.size());
+
         unsafe {
-            vec_znx_dft::vec_znx_idft(
-                self.ptr,
-                res.at_mut_ptr(res_col * res.size(), 0) as *mut vec_znx_big::vec_znx_big_t,
-                res.size() as u64,
-                a.at_ptr(a_col * res.size(), 0) as *const vec_znx_dft::vec_znx_dft_t,
-                a.size() as u64,
-                tmp_bytes.as_mut_ptr(),
-            )
+            (0..min_size).for_each(|j| {
+                vec_znx_dft::vec_znx_idft(
+                    self.ptr,
+                    res.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t,
+                    1 as u64,
+                    a.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
+                    1 as u64,
+                    tmp_bytes.as_mut_ptr(),
+                )
+            });
+            (min_size..res.size()).for_each(|j| {
+                res.zero_at(res_col, j);
+            });
         }
     }
 }
@@ -22,6 +22,33 @@ pub struct ZnxBase {
     pub ptr: *mut u8,
 }

+impl ZnxBase {
+    pub fn from_bytes(n: usize, rows: usize, cols: usize, size: usize, mut bytes: Vec<u8>) -> Self {
+        let mut res: Self = Self::from_bytes_borrow(n, rows, cols, size, &mut bytes);
+        res.data = bytes;
+        res
+    }
+
+    pub fn from_bytes_borrow(n: usize, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
+        #[cfg(debug_assertions)]
+        {
+            assert_eq!(n & (n - 1), 0, "n must be a power of two");
+            assert!(n > 0, "n must be greater than 0");
+            assert!(rows > 0, "rows must be greater than 0");
+            assert!(cols > 0, "cols must be greater than 0");
+            assert!(size > 0, "size must be greater than 0");
+        }
+        Self {
+            n: n,
+            rows: rows,
+            cols: cols,
+            size: size,
+            data: Vec::new(),
+            ptr: bytes.as_mut_ptr(),
+        }
+    }
+}
+
 pub trait GetZnxBase {
     fn znx(&self) -> &ZnxBase;
     fn znx_mut(&mut self) -> &mut ZnxBase;
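`from_bytes` is `from_bytes_borrow` plus a transfer of ownership: the struct is first assembled around the slice's pointer, then the `Vec` is moved into `data` so the allocation lives exactly as long as the struct. This is sound because moving a `Vec` moves its (pointer, length, capacity) header, not the heap buffer the pointer refers to. A reduced model of the two constructors (ownership scheme only, none of the shape checks):

// `data` is empty when the bytes are borrowed, and holds the allocation
// when they are owned; `ptr` is valid in both cases.
struct Base {
    data: Vec<u8>,
    ptr: *mut u8,
}

impl Base {
    fn from_bytes_borrow(bytes: &mut [u8]) -> Self {
        Base { data: Vec::new(), ptr: bytes.as_mut_ptr() }
    }

    fn from_bytes(mut bytes: Vec<u8>) -> Self {
        let mut res = Self::from_bytes_borrow(&mut bytes);
        res.data = bytes; // take ownership; the heap buffer does not move
        res
    }
}

fn main() {
    let owned = Base::from_bytes(vec![0u8; 32]);
    assert_eq!(owned.ptr, owned.data.as_ptr() as *mut u8);
}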
@@ -52,10 +79,12 @@ pub trait ZnxInfos: GetZnxBase {
         self.znx().size
     }

+    /// Returns the underlying raw bytes array.
     fn data(&self) -> &[u8] {
         &self.znx().data
     }

+    /// Returns a pointer to the underlying raw bytes array.
     fn ptr(&self) -> *mut u8 {
         self.znx().ptr
     }
@@ -72,33 +101,6 @@ pub trait ZnxSliceSize {
     fn sl(&self) -> usize;
 }

-impl ZnxBase {
-    pub fn from_bytes(n: usize, rows: usize, cols: usize, size: usize, mut bytes: Vec<u8>) -> Self {
-        let mut res: Self = Self::from_bytes_borrow(n, rows, cols, size, &mut bytes);
-        res.data = bytes;
-        res
-    }
-
-    pub fn from_bytes_borrow(n: usize, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
-        #[cfg(debug_assertions)]
-        {
-            assert_eq!(n & (n - 1), 0, "n must be a power of two");
-            assert!(n > 0, "n must be greater than 0");
-            assert!(rows > 0, "rows must be greater than 0");
-            assert!(cols > 0, "cols must be greater than 0");
-            assert!(size > 0, "size must be greater than 0");
-        }
-        Self {
-            n: n,
-            rows: rows,
-            cols: cols,
-            size: size,
-            data: Vec::new(),
-            ptr: bytes.as_mut_ptr(),
-        }
-    }
-}
-
 pub trait ZnxAlloc<B: Backend>
 where
     Self: Sized + ZnxInfos,
@@ -148,25 +150,25 @@ pub trait ZnxLayout: ZnxInfos {
         unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.n() * self.poly_count()) }
     }

-    /// Returns a non-mutable pointer starting at the (i, j)-th small polynomial.
+    /// Returns a non-mutable pointer starting at the j-th small polynomial of the i-th column.
     fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar {
         #[cfg(debug_assertions)]
         {
             assert!(i < self.cols());
             assert!(j < self.size());
         }
-        let offset = self.n() * (j * self.cols() + i);
+        let offset: usize = self.n() * (j * self.cols() + i);
         unsafe { self.as_ptr().add(offset) }
     }

-    /// Returns a mutable pointer starting at the (i, j)-th small polynomial.
+    /// Returns a mutable pointer starting at the j-th small polynomial of the i-th column.
     fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar {
         #[cfg(debug_assertions)]
         {
             assert!(i < self.cols());
             assert!(j < self.size());
         }
-        let offset = self.n() * (j * self.cols() + i);
+        let offset: usize = self.n() * (j * self.cols() + i);
         unsafe { self.as_mut_ptr().add(offset) }
     }

@@ -179,16 +181,6 @@ pub trait ZnxLayout: ZnxInfos {
     fn at_mut(&mut self, i: usize, j: usize) -> &mut [Self::Scalar] {
         unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n()) }
     }
-
-    /// Returns non-mutable reference to the i-th limb.
-    fn at_limb(&self, j: usize) -> &[Self::Scalar] {
-        unsafe { std::slice::from_raw_parts(self.at_ptr(0, j), self.n() * self.cols()) }
-    }
-
-    /// Returns mutable reference to the i-th limb.
-    fn at_limb_mut(&mut self, j: usize) -> &mut [Self::Scalar] {
-        unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(0, j), self.n() * self.cols()) }
-    }
 }

 use std::convert::TryFrom;
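The removed `at_limb`/`at_limb_mut` accessors viewed limb j across all columns (n * cols scalars), which is exactly the cross-column access the per-column refactor eliminates; `at(i, j)` views the n scalars of one column's limb. The contrast, on a small interleaved buffer (illustrative, using the at_ptr offset formula above):

// n = 4 coefficients, 2 columns, 3 limbs, laid out as n * (j * cols + i).
fn main() {
    let (n, cols, size) = (4usize, 2usize, 3usize);
    let data: Vec<i64> = (0..(n * cols * size) as i64).collect();
    let at = |i: usize, j: usize| {
        let o = n * (j * cols + i);
        &data[o..o + n]
    };
    // Limb 1 of columns 0 and 1 sit side by side in memory; the old
    // at_limb(1) returned both of them as one 8-element slice.
    assert_eq!(at(0, 1), &[8, 9, 10, 11]);
    assert_eq!(at(1, 1), &[12, 13, 14, 15]);
}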
@@ -221,14 +213,17 @@ impl IntegerType for i128 {
     const BITS: u32 = 128;
 }

-pub trait ZnxBasics: ZnxLayout
+pub trait ZnxZero: ZnxLayout
 where
     Self: Sized,
-    Self::Scalar: IntegerType,
 {
     fn zero(&mut self) {
         unsafe {
-            std::ptr::write_bytes(self.as_mut_ptr(), 0, self.n() * size_of::<Self::Scalar>());
+            std::ptr::write_bytes(
+                self.as_mut_ptr(),
+                0,
+                self.n() * size_of::<Self::Scalar>() * self.poly_count(),
+            );
         }
     }
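The `zero` fix deserves a note: the old byte count, n * size_of::<Scalar>(), covers a single small polynomial, while the buffer holds poly_count of them, so only the first polynomial was actually cleared. Counting it out (this sketch assumes poly_count is the total number of small polynomials, rows * cols * size; the crate's definition of poly_count is not shown in this diff):

fn main() {
    let (n, rows, cols, size) = (1024usize, 1usize, 2usize, 3usize);
    let poly_count = rows * cols * size; // assumption for this sketch
    let buffer_bytes = n * std::mem::size_of::<i64>() * poly_count;
    let old_cleared = n * std::mem::size_of::<i64>();
    // the old call left everything past the first polynomial untouched:
    assert_eq!(buffer_bytes - old_cleared, old_cleared * (poly_count - 1));
}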
@@ -241,13 +236,19 @@ where
             );
         }
     }
+}

-    fn rsh(&mut self, log_base2k: usize, k: usize, carry: &mut [u8]) {
-        rsh(log_base2k, self, k, carry)
+pub trait ZnxRsh: ZnxLayout + ZnxZero
+where
+    Self: Sized,
+    Self::Scalar: IntegerType,
+{
+    fn rsh(&mut self, k: usize, log_base2k: usize, col: usize, carry: &mut [u8]) {
+        rsh(k, log_base2k, self, col, carry)
     }
 }

-pub fn rsh<V: ZnxBasics>(log_base2k: usize, a: &mut V, k: usize, tmp_bytes: &mut [u8])
+pub fn rsh<V: ZnxRsh + ZnxZero>(k: usize, log_base2k: usize, a: &mut V, a_col: usize, tmp_bytes: &mut [u8])
 where
     V::Scalar: IntegerType,
 {
@@ -258,7 +259,7 @@ where
     #[cfg(debug_assertions)]
     {
         assert!(
-            tmp_bytes.len() >= rsh_tmp_bytes::<V::Scalar>(n, cols),
+            tmp_bytes.len() >= rsh_tmp_bytes::<V::Scalar>(n),
             "invalid carry: carry.len()/size_ofSelf::Scalar={} < rsh_tmp_bytes({}, {})",
             tmp_bytes.len() / size_of::<V::Scalar>(),
             n,
@@ -291,7 +292,7 @@ where
     let k_rem_t: V::Scalar = V::Scalar::try_from(k_rem).unwrap();

     (steps..size).for_each(|i| {
-        izip!(carry.iter_mut(), a.at_limb_mut(i).iter_mut()).for_each(|(ci, xi)| {
+        izip!(carry.iter_mut(), a.at_mut(a_col, i).iter_mut()).for_each(|(ci, xi)| {
             *xi += *ci << log_base2k_t;
             *ci = get_base_k_carry(*xi, shift);
             *xi = (*xi - *ci) >> k_rem_t;
@@ -305,11 +306,11 @@ fn get_base_k_carry<T: IntegerType>(x: T, shift: T) -> T {
     (x << shift) >> shift
 }

-pub fn rsh_tmp_bytes<T: IntegerType>(n: usize, cols: usize) -> usize {
-    n * cols * std::mem::size_of::<T>()
+pub fn rsh_tmp_bytes<T: IntegerType>(n: usize) -> usize {
+    n * std::mem::size_of::<T>()
 }

-pub fn switch_degree<T: ZnxLayout + ZnxBasics>(b: &mut T, col_b: usize, a: &T, col_a: usize)
+pub fn switch_degree<T: ZnxLayout + ZnxZero>(b: &mut T, col_b: usize, a: &T, col_a: usize)
 where
     <T as ZnxLayout>::Scalar: IntegerType,
 {
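The carry rule in `rsh` hinges on get_base_k_carry(x, shift) = (x << shift) >> shift: for shift = BITS - k_rem this sign-extends the low k_rem bits of x, i.e. the part that has to be pushed into the next limb so that the remaining arithmetic shift by k_rem is exact. Worked out for k_rem = 4 (a concrete i64 instance of the generic helper above):

fn get_base_k_carry(x: i64, shift: i64) -> i64 {
    (x << shift) >> shift // arithmetic right shift sign-extends
}

fn main() {
    let k_rem: i64 = 4;
    let shift: i64 = 64 - k_rem;
    // low 4 bits of 0b1_0111 are 0b0111 = 7
    assert_eq!(get_base_k_carry(0b1_0111, shift), 7);
    // low 4 bits of 0b1_1001 are 0b1001, i.e. -7 as a signed 4-bit value
    assert_eq!(get_base_k_carry(0b1_1001, shift), -7);
    // subtracting the carry makes the shift exact: (25 - (-7)) / 16 = 2
    assert_eq!((0b1_1001 - get_base_k_carry(0b1_1001, shift)) >> k_rem, 2);
}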