Mirror of https://github.com/arnaucube/poulpy.git
Commit: rework as discussed
@@ -1,6 +1,7 @@
 use base2k::{
-    Encoding, FFT64, Module, Sampling, ScalarAlloc, ScalarZnxDftAlloc, ScalarZnxDftOps, ScratchOwned, VecZnxAlloc, VecZnxBigOps,
-    VecZnxBigScratch, VecZnxDftAlloc, VecZnxDftOps, ZnxInfos,
+    Encoding, FFT64, Module, Sampling, Scalar, ScalarAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScratchOwned,
+    VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps,
+    VecZnxOps, ZnxInfos,
 };
 use itertools::izip;
 use sampling::source::Source;
@@ -13,24 +14,23 @@ fn main() {
     let log_scale: usize = msg_size * log_base2k - 5;
     let module: Module<FFT64> = Module::<FFT64>::new(n);

-    let mut scratch =
-        ScratchOwned::new((2 * module.bytes_of_vec_znx_dft(1, ct_size)) + 2 * module.vec_znx_big_normalize_tmp_bytes());
+    let mut scratch: ScratchOwned = ScratchOwned::new(module.vec_znx_big_normalize_tmp_bytes());

     let seed: [u8; 32] = [0; 32];
     let mut source: Source = Source::new(seed);

     // s <- Z_{-1, 0, 1}[X]/(X^{N}+1)
-    let mut s = module.new_scalar(1);
+    let mut s: Scalar<Vec<u8>> = module.new_scalar(1);
     s.fill_ternary_prob(0, 0.5, &mut source);

     // Buffer to store s in the DFT domain
-    let mut s_dft = module.new_scalar_znx_dft(s.cols());
+    let mut s_dft: ScalarZnxDft<Vec<u8>, FFT64> = module.new_scalar_znx_dft(s.cols());

     // s_dft <- DFT(s)
     module.svp_prepare(&mut s_dft, 0, &s, 0);

     // Allocates a VecZnx with two columns: ct=(0, 0)
-    let mut ct = module.new_vec_znx(
+    let mut ct: VecZnx<Vec<u8>> = module.new_vec_znx(
         2,       // Number of columns
         ct_size, // Number of small poly per column
     );
@@ -38,12 +38,10 @@ fn main() {
     // Fill the second column with random values: ct = (0, a)
     module.fill_uniform(log_base2k, &mut ct, 1, ct_size, &mut source);

-    // Scratch space for DFT values
-    let scratch = scratch.borrow();
-    let (mut buf_dft, scratch) = scratch.tmp_vec_znx_dft(&module, 1, ct_size);
+    let mut buf_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct_size);

     // Applies DFT(ct[1]) * DFT(s)
-    module.svp_apply_dft(
+    module.svp_apply(
         &mut buf_dft, // DFT(ct[1] * s)
         0,            // Selects the first column of res
         &s_dft,       // DFT(s)
@@ -53,11 +51,10 @@ fn main() {
     );

     // Alias scratch space (VecZnxDft<B> is always at least as big as VecZnxBig<B>)
-    let (mut buf_big, scratch) = scratch.tmp_vec_znx_big(&module, 1, ct_size);

     // BIG(ct[1] * s) <- IDFT(DFT(ct[1] * s)) (not normalized)
-    // Note: Since `vec_znx_idft_tmp_a` takes no argument for generic `Data`, a fully qualified path seems necessary
-    <Module<_> as VecZnxDftOps<_, &[u8], _>>::vec_znx_idft_tmp_a(&module, &mut buf_big, 0, &mut buf_dft, 0);
+    let mut buf_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct_size);
+    module.vec_znx_idft_tmp_a(&mut buf_big, 0, &mut buf_dft, 0);

     // Creates a plaintext: VecZnx with 1 column
     let mut m = module.new_vec_znx(
@@ -68,8 +65,7 @@ fn main() {
     want.iter_mut()
         .for_each(|x| *x = source.next_u64n(16, 15) as i64);
     m.encode_vec_i64(0, log_base2k, log_scale, &want, 4);
-    let (tmp_bytes_norm, scratch) = scratch.tmp_scalar_slice(n * std::mem::size_of::<i64>());
-    m.normalize(log_base2k, 0, tmp_bytes_norm);
+    module.vec_znx_normalize_inplace(log_base2k, &mut m, 0, scratch.borrow());

     // m - BIG(ct[1] * s)
     module.vec_znx_big_sub_small_b_inplace(
@@ -82,9 +78,12 @@ fn main() {
     // Normalizes back to VecZnx
     // ct[0] <- m - BIG(c1 * s)
     module.vec_znx_big_normalize(
-        log_base2k, &mut ct, 0, // Selects the first column of ct (ct[0])
-        &buf_big, 0, // Selects the first column of buf_big
-        scratch,
+        log_base2k,
+        &mut ct,
+        0, // Selects the first column of ct (ct[0])
+        &buf_big,
+        0, // Selects the first column of buf_big
+        scratch.borrow(),
     );

     // Add noise to ct[0]
@@ -104,7 +103,7 @@ fn main() {
     // Decryption

     // DFT(ct[1] * s)
-    module.svp_apply_dft(
+    module.svp_apply(
         &mut buf_dft,
         0, // Selects the first column of res.
         &s_dft,
@@ -114,14 +113,14 @@ fn main() {
     );

     // BIG(c1 * s) = IDFT(DFT(c1 * s))
-    <Module<_> as VecZnxDftOps<_, &[u8], _>>::vec_znx_idft_tmp_a(&module, &mut buf_big, 0, &mut buf_dft, 0);
+    module.vec_znx_idft_tmp_a(&mut buf_big, 0, &mut buf_dft, 0);

     // BIG(c1 * s) + ct[0]
     module.vec_znx_big_add_small_inplace(&mut buf_big, 0, &ct, 0);

     // m + e <- BIG(ct[1] * s + ct[0])
     let mut res = module.new_vec_znx(1, ct_size);
-    module.vec_znx_big_normalize(log_base2k, &mut res, 0, &buf_big, 0, scratch);
+    module.vec_znx_big_normalize(log_base2k, &mut res, 0, &buf_big, 0, scratch.borrow());

     // have = m * 2^{log_scale} + e
     let mut have: Vec<i64> = vec![i64::default(); n];
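Reviewer note: taken together, the hunks above replace the scratch-backed temporaries with explicitly allocated `VecZnxDft`/`VecZnxBig` buffers, rename `svp_apply_dft` to `svp_apply`, and shrink the scratch budget to what `vec_znx_big_normalize` alone needs. A minimal sketch of the reworked decryption tail, using only calls visible in this diff:

    // m + e = ct[0] + IDFT(DFT(ct[1]) * DFT(s)), renormalized to base-2^k limbs.
    module.vec_znx_idft_tmp_a(&mut buf_big, 0, &mut buf_dft, 0);   // BIG <- IDFT(DFT(ct[1] * s))
    module.vec_znx_big_add_small_inplace(&mut buf_big, 0, &ct, 0); // BIG <- BIG + ct[0]
    let mut res = module.new_vec_znx(1, ct_size);
    module.vec_znx_big_normalize(log_base2k, &mut res, 0, &buf_big, 0, scratch.borrow());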
@@ -1,78 +0,0 @@
-// use base2k::{
-//     Encoding, FFT64, MatZnxDft, MatZnxDftOps, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps,
-//     ZnxInfos, ZnxLayout, alloc_aligned,
-// };
-
-fn main() {
-    // let log_n: i32 = 5;
-    // let n: usize = 1 << log_n;
-
-    // let module: Module<FFT64> = Module::<FFT64>::new(n);
-    // let log_base2k: usize = 15;
-
-    // let a_cols: usize = 2;
-    // let a_size: usize = 5;
-
-    // let log_k: usize = log_base2k * a_size - 5;
-
-    // let mat_rows: usize = a_size;
-    // let mat_cols_in: usize = a_cols;
-    // let mat_cols_out: usize = 2;
-    // let mat_size: usize = a_size + 1;
-
-    // let mut tmp_bytes_vmp: Vec<u8> = alloc_aligned(
-    //     module.vmp_prepare_row_tmp_bytes(mat_cols_out, mat_size)
-    //         | module.vmp_apply_dft_tmp_bytes(
-    //             a_size,
-    //             a_size,
-    //             mat_rows,
-    //             mat_cols_in,
-    //             mat_cols_out,
-    //             mat_size,
-    //         ),
-    // );
-
-    // let mut tmp_bytes_dft: Vec<u8> = alloc_aligned(module.bytes_of_vec_znx_dft(mat_cols_out, mat_size));
-
-    // let mut a: VecZnx = module.new_vec_znx(a_cols, a_size);
-
-    // (0..a_cols).for_each(|i| {
-    //     let mut values: Vec<i64> = vec![i64::default(); n];
-    //     values[1 + i] = (1 << log_base2k) + 1;
-    //     a.encode_vec_i64(i, log_base2k, log_k, &values, 32);
-    //     a.normalize(log_base2k, i, &mut tmp_bytes_vmp);
-    //     a.print(n, i);
-    //     println!();
-    // });
-
-    // let mut mat_znx_dft: MatZnxDft<FFT64> = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size);
-
-    // (0..a.size()).for_each(|row_i| {
-    //     let mut tmp: VecZnx = module.new_vec_znx(mat_cols_out, mat_size);
-    //     (0..mat_cols_out).for_each(|j| {
-    //         tmp.at_mut(j, row_i)[1 + j] = 1 as i64;
-    //     });
-    //     (0..mat_cols_in).for_each(|j| {
-    //         module.vmp_prepare_row(&mut mat_znx_dft, row_i, j, &tmp, &mut tmp_bytes_vmp);
-    //     })
-    // });
-
-    // let mut c_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft_from_bytes_borrow(mat_cols_out, mat_size, &mut tmp_bytes_dft);
-    // module.vmp_apply_dft(&mut c_dft, &a, &mat_znx_dft, &mut tmp_bytes_vmp);
-
-    // let mut res: VecZnx = module.new_vec_znx(mat_cols_out, a_size);
-    // let mut c_big: VecZnxBig<FFT64> = c_dft.alias_as_vec_znx_big();
-    // (0..mat_cols_out).for_each(|i| {
-    //     module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i);
-    //     module.vec_znx_big_normalize(log_base2k, &mut res, i, &c_big, i, &mut tmp_bytes_vmp);
-
-    //     let mut values_res: Vec<i64> = vec![i64::default(); n];
-    //     res.decode_vec_i64(i, log_base2k, log_k, &mut values_res);
-    //     res.print(n, i);
-    //     println!();
-    //     println!("{:?}", values_res);
-    //     println!();
-    // });
-
-    // module.free();
-}
@@ -215,4 +215,12 @@ impl Scratch {
             Self::new(rem_slice),
         )
     }
+
+    pub fn tmp_vec_znx<B: Backend>(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self) {
+        let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, module.bytes_of_vec_znx(cols, size));
+        (
+            VecZnx::from_data(take_slice, module.n(), cols, size),
+            Self::new(rem_slice),
+        )
+    }
 }
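Reviewer note: the new `tmp_vec_znx` mirrors the existing `tmp_vec_znx_dft`/`tmp_vec_znx_big`/`tmp_scalar_slice` helpers: it carves an aligned slice off the front of the scratch buffer and hands back the remainder as a fresh `Scratch`, so temporaries chain without extra allocation. A minimal sketch (the byte budget here is a caller-side assumption, not something fixed by this commit):

    let mut owned: ScratchOwned = ScratchOwned::new(2 * module.bytes_of_vec_znx(1, size)); // assumed budget
    let scratch = owned.borrow();
    let (mut t0, scratch) = scratch.tmp_vec_znx(&module, 1, size); // first temporary
    let (mut t1, _rest) = scratch.tmp_vec_znx(&module, 1, size);   // carved from the remainder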
@@ -1,5 +1,5 @@
 use crate::znx_base::ZnxInfos;
-use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxView, alloc_aligned};
+use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxView, alloc_aligned};
 use std::marker::PhantomData;

 /// Vector Matrix Product Prepared Matrix: a vector of [VecZnx],
@@ -8,17 +8,17 @@ use std::marker::PhantomData;
 ///
 /// [MatZnxDft] is used to perform a vector matrix product between a [VecZnx]/[VecZnxDft] and a [MatZnxDft].
 /// See the trait [MatZnxDftOps] for additional information.
-pub struct MatZnxDft<D, B> {
+pub struct MatZnxDft<D, B: Backend> {
     data: D,
     n: usize,
     size: usize,
     rows: usize,
     cols_in: usize,
     cols_out: usize,
-    _marker: PhantomData<B>,
+    _phantom: PhantomData<B>,
 }

-impl<D, B> ZnxInfos for MatZnxDft<D, B> {
+impl<D, B: Backend> ZnxInfos for MatZnxDft<D, B> {
     fn cols(&self) -> usize {
         self.cols_in
     }
@@ -34,20 +34,22 @@ impl<D, B> ZnxInfos for MatZnxDft<D, B> {
     fn size(&self) -> usize {
         self.size
     }
+}

+impl<D> ZnxSliceSize for MatZnxDft<D, FFT64> {
     fn sl(&self) -> usize {
-        self.n()
+        self.n() * self.cols_out()
     }
 }

-impl<D, B> DataView for MatZnxDft<D, B> {
+impl<D, B: Backend> DataView for MatZnxDft<D, B> {
     type D = D;
     fn data(&self) -> &Self::D {
         &self.data
     }
 }

-impl<D, B> DataViewMut for MatZnxDft<D, B> {
+impl<D, B: Backend> DataViewMut for MatZnxDft<D, B> {
     fn data_mut(&mut self) -> &mut Self::D {
         &mut self.data
     }
@@ -57,7 +59,7 @@ impl<D: AsRef<[u8]>> ZnxView for MatZnxDft<D, FFT64> {
     type Scalar = f64;
 }

-impl<D, B> MatZnxDft<D, B> {
+impl<D, B: Backend> MatZnxDft<D, B> {
     pub(crate) fn cols_in(&self) -> usize {
         self.cols_in
     }
@@ -87,7 +89,7 @@ impl<D: From<Vec<u8>>, B: Backend> MatZnxDft<D, B> {
             rows,
             cols_in,
             cols_out,
-            _marker: PhantomData,
+            _phantom: PhantomData,
         }
     }
@@ -108,7 +110,7 @@ impl<D: From<Vec<u8>>, B: Backend> MatZnxDft<D, B> {
             rows,
             cols_in,
             cols_out,
-            _marker: PhantomData,
+            _phantom: PhantomData,
         }
     }
 }
@@ -151,28 +153,80 @@ impl<D: AsRef<[u8]>> MatZnxDft<D, FFT64> {

 pub type MatZnxDftAllocOwned<B> = MatZnxDft<Vec<u8>, B>;

-impl<B> MatZnxDft<Vec<u8>, B> {
-    pub fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> {
+pub trait MatZnxDftToRef<B: Backend> {
+    fn to_ref(&self) -> MatZnxDft<&[u8], B>;
+}
+
+pub trait MatZnxDftToMut<B: Backend> {
+    fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B>;
+}
+
+impl<B: Backend> MatZnxDftToMut<B> for MatZnxDft<Vec<u8>, B> {
+    fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> {
         MatZnxDft {
             data: self.data.as_mut_slice(),
             n: self.n,
-            size: self.size,
             rows: self.rows,
             cols_in: self.cols_in,
             cols_out: self.cols_out,
-            _marker: PhantomData,
-        }
-    }
-
-    pub fn to_ref(&self) -> MatZnxDft<&[u8], B> {
-        MatZnxDft {
-            data: self.data.as_slice(),
-            n: self.n,
             size: self.size,
-            rows: self.rows,
-            cols_in: self.cols_in,
-            cols_out: self.cols_out,
-            _marker: PhantomData,
+            _phantom: PhantomData,
+        }
+    }
+}
+
+impl<B: Backend> MatZnxDftToRef<B> for MatZnxDft<Vec<u8>, B> {
+    fn to_ref(&self) -> MatZnxDft<&[u8], B> {
+        MatZnxDft {
+            data: self.data.as_slice(),
+            n: self.n,
+            rows: self.rows,
+            cols_in: self.cols_in,
+            cols_out: self.cols_out,
+            size: self.size,
+            _phantom: PhantomData,
+        }
+    }
+}
+
+impl<B: Backend> MatZnxDftToMut<B> for MatZnxDft<&mut [u8], B> {
+    fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> {
+        MatZnxDft {
+            data: self.data,
+            n: self.n,
+            rows: self.rows,
+            cols_in: self.cols_in,
+            cols_out: self.cols_out,
+            size: self.size,
+            _phantom: PhantomData,
+        }
+    }
+}
+
+impl<B: Backend> MatZnxDftToRef<B> for MatZnxDft<&mut [u8], B> {
+    fn to_ref(&self) -> MatZnxDft<&[u8], B> {
+        MatZnxDft {
+            data: self.data,
+            n: self.n,
+            rows: self.rows,
+            cols_in: self.cols_in,
+            cols_out: self.cols_out,
+            size: self.size,
+            _phantom: PhantomData,
+        }
+    }
+}
+
+impl<B: Backend> MatZnxDftToRef<B> for MatZnxDft<&[u8], B> {
+    fn to_ref(&self) -> MatZnxDft<&[u8], B> {
+        MatZnxDft {
+            data: self.data,
+            n: self.n,
+            rows: self.rows,
+            cols_in: self.cols_in,
+            cols_out: self.cols_out,
+            size: self.size,
+            _phantom: PhantomData,
         }
     }
 }
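Reviewer note: the inherent `to_mut`/`to_ref` methods on the owned matrix become the `MatZnxDftToMut`/`MatZnxDftToRef` traits, implemented for the `Vec<u8>`, `&mut [u8]` and `&[u8]` backings, so every variant normalizes to the same borrowed view; that is what lets the ops below take generic receivers. A sketch of a caller written against the traits (hypothetical helper, not part of this commit):

    fn shape<B: Backend, A: MatZnxDftToRef<B>>(a: &A) -> (usize, usize) {
        let a: MatZnxDft<&[u8], B> = a.to_ref(); // works for owned or borrowed backing data
        (a.rows(), a.size())
    }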
@@ -2,11 +2,11 @@ use crate::ffi::vec_znx_dft::vec_znx_dft_t;
 use crate::ffi::vmp;
 use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
 use crate::{
-    Backend, FFT64, MatZnxDft, MatZnxDftAllocOwned, Module, Scratch, VecZnx, VecZnxBigOps, VecZnxBigScratch, VecZnxDft,
-    VecZnxDftAlloc, VecZnxDftOps,
+    Backend, FFT64, MatZnxDft, MatZnxDftAllocOwned, MatZnxDftToMut, MatZnxDftToRef, Module, Scratch, VecZnxDft, VecZnxDftToMut,
+    VecZnxDftToRef,
 };

-pub trait MatZnxDftAlloc<B> {
+pub trait MatZnxDftAlloc<B: Backend> {
     /// Allocates a new [MatZnxDft] with the given number of rows and columns.
     ///
     /// # Arguments
@@ -28,43 +28,10 @@ pub trait MatZnxDftAlloc<B> {
 }

 pub trait MatZnxDftScratch {
-    /// Returns the number of bytes needed as scratch space for [MatZnxDftOps::vmp_prepare_row]
-    fn vmp_prepare_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize;
-
-    /// Returns the number of bytes needed as scratch space for [MatZnxDftOps::vmp_extract_row]
-    fn vmp_extract_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize;
-
-    /// Returns the size of the scratch space necessary for [MatZnxDftOps::vmp_apply_dft].
-    ///
-    /// # Arguments
-    ///
-    /// * `c_size`: size of the output [VecZnxDft].
-    /// * `a_size`: size of the input [VecZnx].
-    /// * `rows`: number of rows of the input [MatZnxDft].
-    /// * `size`: size of the input [MatZnxDft].
-    fn vmp_apply_dft_tmp_bytes(
-        &self,
-        c_size: usize,
-        a_size: usize,
-        b_rows: usize,
-        b_cols_in: usize,
-        b_cols_out: usize,
-        b_size: usize,
-    ) -> usize;
-
     /// Returns the size of the scratch space necessary for [MatZnxDftOps::vmp_apply_dft_to_dft].
-    ///
-    /// # Arguments
-    ///
-    /// * `c_size`: size of the output [VecZnxDft].
-    /// * `a_size`: size of the input [VecZnxDft].
-    /// * `rows`: number of rows of the input [MatZnxDft].
-    /// * `size`: size of the input [MatZnxDft].
-    fn vmp_apply_dft_to_dft_tmp_bytes(
+    fn vmp_apply_tmp_bytes(
         &self,
-        c_cols: usize,
-        c_size: usize,
-        a_cols: usize,
+        res_size: usize,
         a_size: usize,
         b_rows: usize,
         b_cols_in: usize,
@@ -75,43 +42,7 @@ pub trait MatZnxDftScratch {

 /// This trait implements methods for vector matrix product,
 /// that is, multiplying a [VecZnx] with a [MatZnxDft].
-pub trait MatZnxDftOps<DataMut, Data, B: Backend> {
-    /// Prepares the ith-row of [MatZnxDft] from a [VecZnx].
-    ///
-    /// # Arguments
-    ///
-    /// * `b`: [MatZnxDft] on which the values are encoded.
-    /// * `row_i`: the row of the [MatZnxDft] to prepare.
-    /// * `a`: the [VecZnx] to encode on the i-th row of the [MatZnxDft].
-    /// * `buf`: scratch space, the size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes].
-    ///
-    /// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes].
-    fn vmp_prepare_row(
-        &self,
-        b: &mut MatZnxDft<DataMut, B>,
-        b_row: usize,
-        b_col_in: usize,
-        a: &VecZnx<Data>,
-        scratch: &mut Scratch,
-    );
-
-    /// Extracts the ith-row of [MatZnxDft] into a [VecZnxBig].
-    ///
-    /// # Arguments
-    ///
-    /// * `b`: the [VecZnxBig] on which to extract the row of the [MatZnxDft].
-    /// * `a`: [MatZnxDft] on which the values are encoded.
-    /// * `row_i`: the index of the row to extract.
-    fn vmp_extract_row(
-        &self,
-        log_base2k: usize,
-        b: &mut VecZnx<DataMut>,
-        a: &MatZnxDft<Data, B>,
-        b_row: usize,
-        b_col_in: usize,
-        scratch: &mut Scratch,
-    );
-
+pub trait MatZnxDftOps<BACKEND: Backend> {
     /// Prepares the ith-row of [MatZnxDft] from a [VecZnxDft].
     ///
     /// # Arguments
@@ -121,7 +52,10 @@ pub trait MatZnxDftOps<DataMut, Data, B: Backend> {
     /// * `row_i`: the index of the row to prepare.
     ///
     /// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes].
-    fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft<DataMut, B>, b_row: usize, b_col_in: usize, a: &VecZnxDft<Data, B>);
+    fn vmp_prepare_row<R, A>(&self, res: &mut R, res_row: usize, res_col_in: usize, a: &A)
+    where
+        R: MatZnxDftToMut<BACKEND>,
+        A: VecZnxDftToRef<BACKEND>;

     /// Extracts the ith-row of [MatZnxDft] into a [VecZnxDft].
     ///
@@ -130,33 +64,10 @@ pub trait MatZnxDftOps<DataMut, Data, B: Backend> {
     /// * `b`: the [VecZnxDft] on which to extract the row of the [MatZnxDft].
     /// * `a`: [MatZnxDft] on which the values are encoded.
     /// * `row_i`: the index of the row to extract.
-    fn vmp_extract_row_dft(&self, b: &mut VecZnxDft<DataMut, B>, a: &MatZnxDft<Data, B>, a_row: usize, a_col_in: usize);
-
-    /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft].
-    ///
-    /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
-    /// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
-    /// and each vector a [VecZnxDft] (row) of the [MatZnxDft].
-    ///
-    /// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and
-    /// `j` size, the output is a [VecZnx] of `j` size.
-    ///
-    /// If there is a mismatch between the dimensions the largest valid ones are used.
-    ///
-    /// ```text
-    /// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p|
-    ///             |h i j|
-    ///             |k l m|
-    /// ```
-    /// where each element is a [VecZnxDft].
-    ///
-    /// # Arguments
-    ///
-    /// * `c`: the output of the vector matrix product, as a [VecZnxDft].
-    /// * `a`: the left operand [VecZnx] of the vector matrix product.
-    /// * `b`: the right operand [MatZnxDft] of the vector matrix product.
-    /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_tmp_bytes].
-    fn vmp_apply_dft(&self, c: &mut VecZnxDft<DataMut, B>, a: &VecZnx<Data>, b: &MatZnxDft<Data, B>, scratch: &mut Scratch);
-
+    fn vmp_extract_row<R, A>(&self, res: &mut R, a: &A, a_row: usize, a_col_in: usize)
+    where
+        R: VecZnxDftToMut<BACKEND>,
+        A: MatZnxDftToRef<BACKEND>;

     /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft].
     /// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
@@ -183,13 +94,11 @@ pub trait MatZnxDftOps<DataMut, Data, B: Backend> {
     /// * `a`: the left operand [VecZnxDft] of the vector matrix product.
     /// * `b`: the right operand [MatZnxDft] of the vector matrix product.
     /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
-    fn vmp_apply_dft_to_dft(
-        &self,
-        c: &mut VecZnxDft<DataMut, B>,
-        a: &VecZnxDft<Data, B>,
-        b: &MatZnxDft<Data, B>,
-        scratch: &mut Scratch,
-    );
+    fn vmp_apply<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch)
+    where
+        R: VecZnxDftToMut<BACKEND>,
+        A: VecZnxDftToRef<BACKEND>,
+        B: MatZnxDftToRef<BACKEND>;
 }

 impl<B: Backend> MatZnxDftAlloc<B> for Module<B> {
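Reviewer note: with the `R`/`A`/`B` bounds the caller no longer spells out the backing data type; owned buffers and borrowed views both satisfy the traits. The call shape, as exercised by the new `vmp_apply` test later in this commit:

    // c_dft: VecZnxDft<Vec<u8>, FFT64>, a_dft: VecZnxDft<Vec<u8>, FFT64>,
    // mat_znx_dft: MatZnxDft<Vec<u8>, FFT64>; scratch sized via vmp_apply_tmp_bytes.
    module.vmp_apply(&mut c_dft, &a_dft, &mat_znx_dft, scratch.borrow());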
@@ -213,40 +122,10 @@ impl<B: Backend> MatZnxDftAlloc<B> for Module<B> {
     }
 }

-impl<B: Backend> MatZnxDftScratch for Module<B> {
-    fn vmp_prepare_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize {
-        <Self as VecZnxDftAlloc<_>>::bytes_of_vec_znx_dft(self, cols_out, size)
-    }
-
-    fn vmp_extract_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize {
-        <Self as VecZnxDftAlloc<_>>::bytes_of_vec_znx_dft(self, cols_out, size)
-            + <Self as VecZnxBigScratch>::vec_znx_big_normalize_tmp_bytes(self)
-    }
-
-    fn vmp_apply_dft_tmp_bytes(
-        &self,
-        c_size: usize,
-        a_size: usize,
-        b_rows: usize,
-        b_cols_in: usize,
-        b_cols_out: usize,
-        b_size: usize,
-    ) -> usize {
-        unsafe {
-            vmp::vmp_apply_dft_tmp_bytes(
-                self.ptr,
-                c_size as u64,
-                a_size as u64,
-                (b_rows * b_cols_in) as u64,
-                (b_size * b_cols_out) as u64,
-            ) as usize
-        }
-    }
-    fn vmp_apply_dft_to_dft_tmp_bytes(
-        &self,
-        c_cols: usize,
-        c_size: usize,
-        a_cols: usize,
+impl<BACKEND: Backend> MatZnxDftScratch for Module<BACKEND> {
+    fn vmp_apply_tmp_bytes(
+        &self,
+        res_size: usize,
         a_size: usize,
         b_rows: usize,
         b_cols_in: usize,
@@ -256,8 +135,8 @@ impl<B: Backend> MatZnxDftScratch for Module<B> {
         unsafe {
             vmp::vmp_apply_dft_to_dft_tmp_bytes(
                 self.ptr,
-                (c_size * c_cols) as u64,
-                (a_size * a_cols) as u64,
+                (res_size * b_cols_out) as u64,
+                (a_size * b_cols_in) as u64,
                 (b_rows * b_cols_in) as u64,
                 (b_size * b_cols_out) as u64,
             ) as usize
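Reviewer note: `vmp_apply_tmp_bytes` keeps the same `vmp_apply_dft_to_dft_tmp_bytes` FFI entry point but derives the flattened operand lengths from the matrix shape (`a_size * b_cols_in`, `res_size * b_cols_out`) instead of taking separate `c_cols`/`a_cols` arguments. The new test sizes its scratch accordingly:

    let mut scratch: ScratchOwned = ScratchOwned::new(
        module.vmp_apply_tmp_bytes(res_size, a_size, mat_rows, mat_cols_in, mat_cols_out, mat_size)
            | module.vec_znx_big_normalize_tmp_bytes(),
    );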
@@ -265,152 +144,43 @@ impl<B: Backend> MatZnxDftScratch for Module<B> {
         }
     }
 }

-impl MatZnxDftOps<&mut [u8], &[u8], FFT64> for Module<FFT64> {
-    fn vmp_prepare_row(
-        &self,
-        b: &mut MatZnxDft<&mut [u8], FFT64>,
-        b_row: usize,
-        b_col_in: usize,
-        a: &VecZnx<&[u8]>,
-        scratch: &mut Scratch,
-    ) {
+impl MatZnxDftOps<FFT64> for Module<FFT64> {
+    fn vmp_prepare_row<R, A>(&self, res: &mut R, res_row: usize, res_col_in: usize, a: &A)
+    where
+        R: MatZnxDftToMut<FFT64>,
+        A: VecZnxDftToRef<FFT64>,
+    {
+        let mut res: MatZnxDft<&mut [u8], _> = res.to_mut();
+        let a: VecZnxDft<&[u8], _> = a.to_ref();
+
         #[cfg(debug_assertions)]
         {
-            assert_eq!(b.n(), self.n());
+            assert_eq!(res.n(), self.n());
             assert_eq!(a.n(), self.n());
             assert_eq!(
                 a.cols(),
-                b.cols_out(),
-                "a.cols(): {} != b.cols_out(): {}",
+                res.cols_out(),
+                "a.cols(): {} != res.cols_out(): {}",
                 a.cols(),
-                b.cols_out()
+                res.cols_out()
             );
             assert!(
-                b_row < b.rows(),
-                "b_row: {} >= b.rows(): {}",
-                b_row,
-                b.rows()
+                res_row < res.rows(),
+                "res_row: {} >= res.rows(): {}",
+                res_row,
+                res.rows()
             );
             assert!(
-                b_col_in < b.cols_in(),
-                "b_col_in: {} >= b.cols_in(): {}",
-                b_col_in,
-                b.cols_in()
+                res_col_in < res.cols_in(),
+                "res_col_in: {} >= res.cols_in(): {}",
+                res_col_in,
+                res.cols_in()
             );
             assert_eq!(
-                b.size(),
-                a.size(),
-                "b.size(): {} != a.size(): {}",
-                b.size(),
-                a.size()
-            );
-            // assert!(
-            //     tmp_bytes.len()
-            //         >= <Self as MatZnxDftOps<DataMut, Data, FFT64>>::vmp_prepare_row_tmp_bytes(self, a.cols(), a.size())
-            // );
-            // assert!(is_aligned(tmp_bytes.as_ptr()))
-        }
-
-        let cols_out: usize = a.cols();
-        let a_size: usize = a.size();
-
-        // let (tmp_bytes_a_dft, _) = tmp_bytes.split_at_mut(self.bytes_of_vec_znx_dft(cols_out, a_size));
-        let (mut a_dft, _) = scratch.tmp_vec_znx_dft::<_>(self, cols_out, a_size);
-        (0..cols_out).for_each(|i| self.vec_znx_dft(&mut a_dft, i, &a, i));
-        Self::vmp_prepare_row_dft(&self, b, b_row, b_col_in, &a_dft.to_ref());
-    }
-
-    fn vmp_extract_row(
-        &self,
-        log_base2k: usize,
-        b: &mut VecZnx<&mut [u8]>,
-        a: &MatZnxDft<&[u8], FFT64>,
-        a_row: usize,
-        a_col_in: usize,
-        scratch: &mut Scratch,
-    ) {
-        #[cfg(debug_assertions)]
-        {
-            assert_eq!(b.n(), self.n());
-            assert_eq!(a.n(), self.n());
-            assert_eq!(
-                b.cols(),
-                a.cols_out(),
-                "b.cols(): {} != a.cols_out(): {}",
-                b.cols(),
-                a.cols_out()
-            );
-            assert!(
-                a_row < a.rows(),
-                "a_row: {} >= a.rows(): {}",
-                a_row,
-                a.rows()
-            );
-            assert!(
-                a_col_in < a.cols_in(),
-                "a_col_in: {} >= a.cols_in(): {}",
-                a_col_in,
-                a.cols_in()
-            );
-            assert_eq!(
-                b.size(),
-                a.size(),
-                "b.size(): {} != a.size(): {}",
-                b.size(),
-                a.size()
-            );
-            // assert!(tmp_bytes.len() >= self.vmp_extract_row_tmp_bytes(a.cols(), a.size()));
-            // assert!(is_aligned(tmp_bytes.as_ptr()))
-        }
-
-        let cols_out: usize = b.cols();
-        let size: usize = b.size();
-
-        // let (bytes_a_dft, tmp_bytes) = tmp_bytes.split_at_mut(self.bytes_of_vec_znx_dft(cols_out, size));
-        let (mut b_dft, scratch) = scratch.tmp_vec_znx_dft(self, cols_out, size);
-        Self::vmp_extract_row_dft(&self, &mut b_dft, a, a_row, a_col_in);
-        let (mut b_big, scratch) = scratch.tmp_vec_znx_big(self, cols_out, size);
-        (0..cols_out).for_each(|i| {
-            <Self as VecZnxDftOps<&mut [u8], &[u8], FFT64>>::vec_znx_idft_tmp_a(self, &mut b_big, i, &mut b_dft, i);
-            self.vec_znx_big_normalize(log_base2k, b, i, &b_big, i, scratch);
-        });
-    }
-
-    fn vmp_prepare_row_dft(
-        &self,
-        b: &mut MatZnxDft<&mut [u8], FFT64>,
-        b_row: usize,
-        b_col_in: usize,
-        a: &VecZnxDft<&[u8], FFT64>,
-    ) {
-        #[cfg(debug_assertions)]
-        {
-            assert_eq!(b.n(), self.n());
-            assert_eq!(a.n(), self.n());
-            assert_eq!(
-                a.cols(),
-                b.cols_out(),
-                "a.cols(): {} != b.cols_out(): {}",
-                a.cols(),
-                b.cols_out()
-            );
-            assert!(
-                b_row < b.rows(),
-                "b_row: {} >= b.rows(): {}",
-                b_row,
-                b.rows()
-            );
-            assert!(
-                b_col_in < b.cols_in(),
-                "b_col_in: {} >= b.cols_in(): {}",
-                b_col_in,
-                b.cols_in()
-            );
-            assert_eq!(
-                b.size(),
-                a.size(),
-                "b.size(): {} != a.size(): {}",
-                b.size(),
+                res.size(),
+                a.size(),
+                "res.size(): {} != a.size(): {}",
+                res.size(),
                 a.size()
             );
         }
@@ -418,31 +188,32 @@ impl MatZnxDftOps<&mut [u8], &[u8], FFT64> for Module<FFT64> {
         unsafe {
             vmp::vmp_prepare_row_dft(
                 self.ptr,
-                b.as_mut_ptr() as *mut vmp::vmp_pmat_t,
+                res.as_mut_ptr() as *mut vmp::vmp_pmat_t,
                 a.as_ptr() as *const vec_znx_dft_t,
-                (b_row * b.cols_in() + b_col_in) as u64,
-                (b.rows() * b.cols_in()) as u64,
-                (b.size() * b.cols_out()) as u64,
+                (res_row * res.cols_in() + res_col_in) as u64,
+                (res.rows() * res.cols_in()) as u64,
+                (res.size() * res.cols_out()) as u64,
             );
         }
     }

-    fn vmp_extract_row_dft(
-        &self,
-        b: &mut VecZnxDft<&mut [u8], FFT64>,
-        a: &MatZnxDft<&[u8], FFT64>,
-        a_row: usize,
-        a_col_in: usize,
-    ) {
+    fn vmp_extract_row<R, A>(&self, res: &mut R, a: &A, a_row: usize, a_col_in: usize)
+    where
+        R: VecZnxDftToMut<FFT64>,
+        A: MatZnxDftToRef<FFT64>,
+    {
+        let mut res: VecZnxDft<&mut [u8], _> = res.to_mut();
+        let a: MatZnxDft<&[u8], _> = a.to_ref();

         #[cfg(debug_assertions)]
         {
-            assert_eq!(b.n(), self.n());
+            assert_eq!(res.n(), self.n());
             assert_eq!(a.n(), self.n());
             assert_eq!(
-                b.cols(),
+                res.cols(),
                 a.cols_out(),
-                "b.cols(): {} != a.cols_out(): {}",
-                b.cols(),
+                "res.cols(): {} != a.cols_out(): {}",
+                res.cols(),
                 a.cols_out()
             );
             assert!(
@@ -458,17 +229,17 @@ impl MatZnxDftOps<&mut [u8], &[u8], FFT64> for Module<FFT64> {
                 a.cols_in()
             );
             assert_eq!(
-                b.size(),
+                res.size(),
                 a.size(),
-                "b.size(): {} != a.size(): {}",
-                b.size(),
+                "res.size(): {} != a.size(): {}",
+                res.size(),
                 a.size()
             );
         }
         unsafe {
             vmp::vmp_extract_row_dft(
                 self.ptr,
-                b.as_mut_ptr() as *mut vec_znx_dft_t,
+                res.as_mut_ptr() as *mut vec_znx_dft_t,
                 a.as_ptr() as *const vmp::vmp_pmat_t,
                 (a_row * a.cols_in() + a_col_in) as u64,
                 (a.rows() * a.cols_in()) as u64,
@@ -477,23 +248,26 @@ impl MatZnxDftOps<&mut [u8], &[u8], FFT64> for Module<FFT64> {
         }
     }

-    fn vmp_apply_dft(
-        &self,
-        c: &mut VecZnxDft<&mut [u8], FFT64>,
-        a: &VecZnx<&[u8]>,
-        b: &MatZnxDft<&[u8], FFT64>,
-        scratch: &mut Scratch,
-    ) {
+    fn vmp_apply<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch)
+    where
+        R: VecZnxDftToMut<FFT64>,
+        A: VecZnxDftToRef<FFT64>,
+        B: MatZnxDftToRef<FFT64>,
+    {
+        let mut res: VecZnxDft<&mut [u8], _> = res.to_mut();
+        let a: VecZnxDft<&[u8], _> = a.to_ref();
+        let b: MatZnxDft<&[u8], _> = b.to_ref();

         #[cfg(debug_assertions)]
         {
-            assert_eq!(c.n(), self.n());
+            assert_eq!(res.n(), self.n());
             assert_eq!(b.n(), self.n());
             assert_eq!(a.n(), self.n());
             assert_eq!(
-                c.cols(),
+                res.cols(),
                 b.cols_out(),
-                "c.cols(): {} != b.cols_out: {}",
-                c.cols(),
+                "res.cols(): {} != b.cols_out: {}",
+                res.cols(),
                 b.cols_out()
             );
             assert_eq!(
@@ -503,92 +277,10 @@ impl MatZnxDftOps<&mut [u8], &[u8], FFT64> for Module<FFT64> {
                 a.cols(),
                 b.cols_in()
             );
-            // assert!(
-            //     tmp_bytes.len()
-            //         >= self.vmp_apply_dft_tmp_bytes(
-            //             c.size(),
-            //             a.size(),
-            //             b.rows(),
-            //             b.cols_in(),
-            //             b.cols_out(),
-            //             b.size()
-            //         )
-            // );
-            // assert_alignement(tmp_bytes.as_ptr());
-        }
-        let (tmp_bytes, _) = scratch.tmp_scalar_slice(<Self as MatZnxDftScratch>::vmp_apply_dft_tmp_bytes(
-            self,
-            c.size(),
-            a.size(),
-            b.rows(),
-            b.cols_in(),
-            b.cols_out(),
-            b.size(),
-        ));
-
-        unsafe {
-            vmp::vmp_apply_dft(
-                self.ptr,
-                c.as_mut_ptr() as *mut vec_znx_dft_t,
-                (c.size() * c.cols()) as u64,
-                a.as_ptr(),
-                (a.size() * a.cols()) as u64,
-                a.n() as u64,
-                b.as_ptr() as *const vmp::vmp_pmat_t,
-                (b.rows() * b.cols_in()) as u64,
-                (b.size() * b.cols_out()) as u64,
-                tmp_bytes.as_mut_ptr(),
-            )
-        }
-    }
-
-    fn vmp_apply_dft_to_dft(
-        &self,
-        c: &mut VecZnxDft<&mut [u8], FFT64>,
-        a: &VecZnxDft<&[u8], FFT64>,
-        b: &MatZnxDft<&[u8], FFT64>,
-        scratch: &mut Scratch,
-    ) {
-        {
-            #[cfg(debug_assertions)]
-            {
-                assert_eq!(c.n(), self.n());
-                assert_eq!(b.n(), self.n());
-                assert_eq!(a.n(), self.n());
-                assert_eq!(
-                    c.cols(),
-                    b.cols_out(),
-                    "c.cols(): {} != b.cols_out: {}",
-                    c.cols(),
-                    b.cols_out()
-                );
-                assert_eq!(
-                    a.cols(),
-                    b.cols_in(),
-                    "a.cols(): {} != b.cols_in: {}",
-                    a.cols(),
-                    b.cols_in()
-                );
-                // assert!(
-                //     tmp_bytes.len()
-                //         >= self.vmp_apply_dft_to_dft_tmp_bytes(
-                //             c.cols(),
-                //             c.size(),
-                //             a.cols(),
-                //             a.size(),
-                //             b.rows(),
-                //             b.cols_in(),
-                //             b.cols_out(),
-                //             b.size()
-                //         )
-                // );
-                // assert_alignement(tmp_bytes.as_ptr());
-            }
+        }

-        let (tmp_bytes, _) = scratch.tmp_scalar_slice(self.vmp_apply_dft_to_dft_tmp_bytes(
-            c.cols(),
-            c.size(),
-            a.cols(),
+        let (tmp_bytes, _) = scratch.tmp_scalar_slice(self.vmp_apply_tmp_bytes(
+            res.size(),
            a.size(),
            b.rows(),
            b.cols_in(),
@@ -598,107 +290,142 @@ impl MatZnxDftOps<&mut [u8], &[u8], FFT64> for Module<FFT64> {
         unsafe {
             vmp::vmp_apply_dft_to_dft(
                 self.ptr,
-                c.as_mut_ptr() as *mut vec_znx_dft_t,
-                c.poly_count() as u64,
+                res.as_mut_ptr() as *mut vec_znx_dft_t,
+                (res.size() * res.cols()) as u64,
                 a.as_ptr() as *const vec_znx_dft_t,
-                a.poly_count() as u64,
+                (a.size() * a.cols()) as u64,
                 b.as_ptr() as *const vmp::vmp_pmat_t,
-                b.rows() as u64,
-                (b.size() * b.cols()) as u64,
+                (b.rows() * b.cols_in()) as u64,
+                (b.size() * b.cols_out()) as u64,
                 tmp_bytes.as_mut_ptr(),
             )
         }
     }
-    }
 }

 #[cfg(test)]
 mod tests {
-    use crate::ScratchOwned;
-    use crate::mat_znx_dft_ops::*;
-    use crate::vec_znx_big_ops::*;
-    use crate::vec_znx_dft_ops::*;
-    use crate::vec_znx_ops::*;
     use crate::{
-        FFT64, MatZnxDft, MatZnxDftOps, Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, alloc_aligned,
+        Encoding, FFT64, MatZnxDft, MatZnxDftOps, Module, Sampling, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc,
+        VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, ZnxInfos, ZnxView, ZnxViewMut,
     };
     use sampling::source::Source;

+    use super::{MatZnxDftAlloc, MatZnxDftScratch};
+
     #[test]
-    fn vmp_prepare_row_dft() {
+    fn vmp_prepare_row() {
         let module: Module<FFT64> = Module::<FFT64>::new(16);
         let log_base2k: usize = 8;
         let mat_rows: usize = 4;
         let mat_cols_in: usize = 2;
         let mat_cols_out: usize = 2;
         let mat_size: usize = 5;
-        let mut a: VecZnx<_> = module.new_vec_znx(mat_cols_out, mat_size);
-        let mut b: VecZnx<_> = module.new_vec_znx(mat_cols_out, mat_size);
-        let mut a_dft: VecZnxDft<_, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size);
-        let mut a_big: VecZnxBig<_, FFT64> = module.new_vec_znx_big(mat_cols_out, mat_size);
-        let mut b_dft: VecZnxDft<_, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size);
-        let mut vmpmat_0: MatZnxDft<_, FFT64> = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size);
-        let mut vmpmat_1: MatZnxDft<_, FFT64> = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size);
-
-        // let mut tmp_bytes: Vec<u8> =
-        //     alloc_aligned(module.vmp_prepare_row_tmp_bytes(mat_cols_out, mat_size) | module.vec_znx_big_normalize_tmp_bytes());
-        let mut scratch = ScratchOwned::new(
-            2 * (module.vmp_prepare_row_tmp_bytes(mat_cols_out, mat_size) + module.vec_znx_big_normalize_tmp_bytes()),
-        );
-        let mut tmp_bytes: Vec<u8> =
-            alloc_aligned::<u8>(<Module<FFT64> as VecZnxDftOps<Vec<u8>, Vec<u8>, _>>::vec_znx_idft_tmp_bytes(&module));
+        let mut a: VecZnx<Vec<u8>> = module.new_vec_znx(mat_cols_out, mat_size);
+        let mut a_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size);
+        let mut b_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size);
+        let mut mat: MatZnxDft<Vec<u8>, FFT64> = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size);

         for col_in in 0..mat_cols_in {
             for row_i in 0..mat_rows {
                 let mut source: Source = Source::new([0u8; 32]);

                 (0..mat_cols_out).for_each(|col_out| {
                     module.fill_uniform(log_base2k, &mut a, col_out, mat_size, &mut source);
                     module.vec_znx_dft(&mut a_dft, col_out, &a, col_out);
                 });
-                module.vmp_prepare_row(
-                    &mut vmpmat_0.to_mut(),
-                    row_i,
-                    col_in,
-                    &a.to_ref(),
-                    scratch.borrow(),
-                );
-
-                // Checks that prepare(mat_znx_dft, a) = prepare_dft(mat_znx_dft, a_dft)
-                module.vmp_prepare_row_dft(&mut vmpmat_1.to_mut(), row_i, col_in, &a_dft.to_ref());
-                assert_eq!(vmpmat_0.raw(), vmpmat_1.raw());
-
-                // Checks that a_dft = extract_dft(prepare(mat_znx_dft, a), b_dft)
-                module.vmp_extract_row_dft(&mut b_dft.to_mut(), &vmpmat_0.to_ref(), row_i, col_in);
+                module.vmp_prepare_row(&mut mat, row_i, col_in, &a_dft);
+                module.vmp_extract_row(&mut b_dft, &mat, row_i, col_in);
                 assert_eq!(a_dft.raw(), b_dft.raw());
-
-                // Checks that a_big = extract(prepare_dft(mat_znx_dft, a_dft), b_big)
-                module.vmp_extract_row(
-                    log_base2k,
-                    &mut b.to_mut(),
-                    &vmpmat_0.to_ref(),
-                    row_i,
-                    col_in,
-                    scratch.borrow(),
-                );
-
-                (0..mat_cols_out).for_each(|col_out| {
-                    module.vec_znx_idft(&mut a_big, col_out, &a_dft, col_out, &mut tmp_bytes);
-                    module.vec_znx_big_normalize(
-                        log_base2k,
-                        &mut a.to_mut(),
-                        col_out,
-                        &a_big.to_ref(),
-                        col_out,
-                        scratch.borrow(),
-                    );
-                });
-
-                assert_eq!(a.raw(), b.raw());
             }
         }

         module.free();
     }
+
+    #[test]
+    fn vmp_apply() {
+        let log_n: i32 = 5;
+        let n: usize = 1 << log_n;
+
+        let module: Module<FFT64> = Module::<FFT64>::new(n);
+        let log_base2k: usize = 15;
+        let a_size: usize = 5;
+        let mat_size: usize = 6;
+        let res_size: usize = 5;
+
+        [1, 2].iter().for_each(|in_cols| {
+            [1, 2].iter().for_each(|out_cols| {
+                let a_cols: usize = *in_cols;
+                let res_cols: usize = *out_cols;
+
+                let mat_rows: usize = a_size;
+                let mat_cols_in: usize = a_cols;
+                let mat_cols_out: usize = res_cols;
+                let res_cols: usize = mat_cols_out;
+
+                let mut scratch: ScratchOwned = ScratchOwned::new(
+                    module.vmp_apply_tmp_bytes(
+                        res_size,
+                        a_size,
+                        mat_rows,
+                        mat_cols_in,
+                        mat_cols_out,
+                        mat_size,
+                    ) | module.vec_znx_big_normalize_tmp_bytes(),
+                );
+
+                let mut a: VecZnx<Vec<u8>> = module.new_vec_znx(a_cols, a_size);
+
+                (0..a_cols).for_each(|i| {
+                    a.at_mut(i, 2)[i + 1] = 1;
+                });
+
+                let mut mat_znx_dft: MatZnxDft<Vec<u8>, FFT64> =
+                    module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size);
+
+                let mut c_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size);
+                let mut c_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(mat_cols_out, mat_size);
+
+                let mut tmp: VecZnx<Vec<u8>> = module.new_vec_znx(mat_cols_out, mat_size);
+
+                // Constructs a [VecZnxMatDft] that performs cyclic rotations on each submatrix.
+                (0..a.size()).for_each(|row_i| {
+                    (0..mat_cols_in).for_each(|col_in_i| {
+                        (0..mat_cols_out).for_each(|col_out_i| {
+                            let idx = 1 + col_in_i * mat_cols_out + col_out_i;
+                            tmp.at_mut(col_out_i, row_i)[idx] = 1 as i64; // X^{idx}
+                            module.vec_znx_dft(&mut c_dft, col_out_i, &tmp, col_out_i);
+                            tmp.at_mut(col_out_i, row_i)[idx] = 0 as i64;
+                        });
+                        module.vmp_prepare_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft);
+                    });
+                });
+
+                let mut a_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(a_cols, a_size);
+                (0..a_cols).for_each(|i| {
+                    module.vec_znx_dft(&mut a_dft, i, &a, i);
+                });
+
+                module.vmp_apply(&mut c_dft, &a_dft, &mat_znx_dft, scratch.borrow());
+
+                let mut res_have_vi64: Vec<i64> = vec![i64::default(); n];
+
+                let mut res_have: VecZnx<Vec<u8>> = module.new_vec_znx(res_cols, res_size);
+                (0..mat_cols_out).for_each(|i| {
+                    module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i);
+                    module.vec_znx_big_normalize(log_base2k, &mut res_have, i, &c_big, i, scratch.borrow());
+                });
+
+                (0..mat_cols_out).for_each(|col_i| {
+                    let mut res_want_vi64: Vec<i64> = vec![i64::default(); n];
+                    (0..a_cols).for_each(|i| {
+                        res_want_vi64[(i + 1) + (1 + i * mat_cols_out + col_i)] = 1;
+                    });
+                    res_have.decode_vec_i64(col_i, log_base2k, log_base2k * 3, &mut res_have_vi64);
+                    assert_eq!(res_have_vi64, res_want_vi64);
+                });
+            });
+        });
+
+        module.free();
+    }
 }
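Reviewer note on the expected values in the new `vmp_apply` test: row `row_i`, input column `col_in_i` of the prepared matrix encodes the monomial X^{1 + col_in_i * mat_cols_out + col_out_i} in output column `col_out_i`, and the input `a` carries X^{i + 1} in column `i` (limb 2). The product of the two monomials is X^{(i + 1) + (1 + i * mat_cols_out + col_i)}, which is exactly the coefficient index set to 1 in `res_want_vi64`.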
@@ -1,53 +1,47 @@
 use crate::znx_base::ZnxViewMut;
-use crate::{Backend, Module, VecZnx};
+use crate::{Backend, Module, VecZnx, VecZnxToMut};
 use rand_distr::{Distribution, Normal};
 use sampling::source::Source;

 pub trait Sampling {
     /// Fills the first `size` small polys with uniform values in \[-2^{log_base2k-1}, 2^{log_base2k-1}\]
-    fn fill_uniform<DataMut: AsMut<[u8]> + AsRef<[u8]>>(
-        &self,
-        log_base2k: usize,
-        a: &mut VecZnx<DataMut>,
-        col_i: usize,
-        size: usize,
-        source: &mut Source,
-    );
+    fn fill_uniform<A>(&self, log_base2k: usize, a: &mut A, col_i: usize, size: usize, source: &mut Source)
+    where
+        A: VecZnxToMut;

     /// Adds a vector sampled according to the provided distribution, scaled by 2^{-log_k} and bounded to \[-bound, bound\].
-    fn add_dist_f64<DataMut: AsMut<[u8]> + AsRef<[u8]>, D: Distribution<f64>>(
+    fn add_dist_f64<A, D: Distribution<f64>>(
         &self,
         log_base2k: usize,
-        a: &mut VecZnx<DataMut>,
+        a: &mut A,
         col_i: usize,
         log_k: usize,
         source: &mut Source,
         dist: D,
         bound: f64,
-    );
+    ) where
+        A: VecZnxToMut;

     /// Adds a discrete normal vector scaled by 2^{-log_k} with the provided standard deviation and bounded to \[-bound, bound\].
-    fn add_normal<DataMut: AsMut<[u8]> + AsRef<[u8]>>(
+    fn add_normal<A>(
         &self,
         log_base2k: usize,
-        a: &mut VecZnx<DataMut>,
+        a: &mut A,
         col_i: usize,
         log_k: usize,
         source: &mut Source,
         sigma: f64,
         bound: f64,
-    );
+    ) where
+        A: VecZnxToMut;
 }

 impl<B: Backend> Sampling for Module<B> {
-    fn fill_uniform<DataMut: AsMut<[u8]> + AsRef<[u8]>>(
-        &self,
-        log_base2k: usize,
-        a: &mut VecZnx<DataMut>,
-        col_i: usize,
-        size: usize,
-        source: &mut Source,
-    ) {
+    fn fill_uniform<A>(&self, log_base2k: usize, a: &mut A, col_i: usize, size: usize, source: &mut Source)
+    where
+        A: VecZnxToMut,
+    {
+        let mut a: VecZnx<&mut [u8]> = a.to_mut();
         let base2k: u64 = 1 << log_base2k;
         let mask: u64 = base2k - 1;
         let base2k_half: i64 = (base2k >> 1) as i64;
@@ -58,16 +52,19 @@ impl<B: Backend> Sampling for Module<B> {
         })
     }

-    fn add_dist_f64<DataMut: AsMut<[u8]> + AsRef<[u8]>, D: Distribution<f64>>(
+    fn add_dist_f64<A, D: Distribution<f64>>(
         &self,
         log_base2k: usize,
-        a: &mut VecZnx<DataMut>,
+        a: &mut A,
         col_i: usize,
         log_k: usize,
         source: &mut Source,
         dist: D,
         bound: f64,
-    ) {
+    ) where
+        A: VecZnxToMut,
+    {
+        let mut a: VecZnx<&mut [u8]> = a.to_mut();
         assert!(
             (bound.log2().ceil() as i64) < 64,
             "invalid bound: ceil(log2(bound))={} > 63",
@@ -96,16 +93,10 @@ impl<B: Backend> Sampling for Module<B> {
         }
     }

-    fn add_normal<DataMut: AsMut<[u8]> + AsRef<[u8]>>(
-        &self,
-        log_base2k: usize,
-        a: &mut VecZnx<DataMut>,
-        col_i: usize,
-        log_k: usize,
-        source: &mut Source,
-        sigma: f64,
-        bound: f64,
-    ) {
+    fn add_normal<A>(&self, log_base2k: usize, a: &mut A, col_i: usize, log_k: usize, source: &mut Source, sigma: f64, bound: f64)
+    where
+        A: VecZnxToMut,
+    {
         self.add_dist_f64(
             log_base2k,
             a,
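Reviewer note: with the `VecZnxToMut` bound the sampling entry points accept anything convertible to a mutable `VecZnx` view, so existing call sites keep their shape. For instance, as in the reworked example at the top of this commit (the `sigma`/`bound` values below are illustrative, not taken from the diff):

    module.fill_uniform(log_base2k, &mut ct, 1, ct_size, &mut source);            // ct: VecZnx<Vec<u8>>
    module.add_normal(log_base2k, &mut ct, 0, log_scale, &mut source, 3.2, 19.0); // sigma = 3.2, bound = 19.0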
@@ -1,13 +1,10 @@
 use crate::znx_base::ZnxInfos;
-use crate::{Backend, DataView, DataViewMut, Module, ZnxView, ZnxViewMut, alloc_aligned};
+use crate::{Backend, DataView, DataViewMut, Module, ZnxSliceSize, ZnxView, ZnxViewMut, alloc_aligned};
 use rand::seq::SliceRandom;
 use rand_core::RngCore;
 use rand_distr::{Distribution, weighted::WeightedIndex};
 use sampling::source::Source;

-// pub const SCALAR_ZNX_ROWS: usize = 1;
-// pub const SCALAR_ZNX_SIZE: usize = 1;
-
 pub struct Scalar<D> {
     data: D,
     n: usize,
@@ -30,7 +27,9 @@ impl<D> ZnxInfos for Scalar<D> {
|
|||||||
fn size(&self) -> usize {
|
fn size(&self) -> usize {
|
||||||
1
|
1
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<D> ZnxSliceSize for Scalar<D> {
|
||||||
fn sl(&self) -> usize {
|
fn sl(&self) -> usize {
|
||||||
self.n()
|
self.n()
|
||||||
}
|
}
|
||||||
@@ -70,19 +69,6 @@ impl<D: AsMut<[u8]> + AsRef<[u8]>> Scalar<D> {
|
|||||||
.for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1);
|
.for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1);
|
||||||
self.at_mut(col, 0).shuffle(source);
|
self.at_mut(col, 0).shuffle(source);
|
||||||
}
|
}
|
||||||
|
|
||||||
// pub fn alias_as_vec_znx(&self) -> VecZnx {
|
|
||||||
// VecZnx {
|
|
||||||
// inner: ZnxBase {
|
|
||||||
// n: self.n(),
|
|
||||||
// rows: 1,
|
|
||||||
// cols: 1,
|
|
||||||
// size: 1,
|
|
||||||
// data: Vec::new(),
|
|
||||||
// ptr: self.ptr() as *mut u8,
|
|
||||||
// },
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<D: From<Vec<u8>>> Scalar<D> {
|
impl<D: From<Vec<u8>>> Scalar<D> {
|
||||||
@@ -116,7 +102,6 @@ pub trait ScalarAlloc {
|
|||||||
fn bytes_of_scalar(&self, cols: usize) -> usize;
|
fn bytes_of_scalar(&self, cols: usize) -> usize;
|
||||||
fn new_scalar(&self, cols: usize) -> ScalarOwned;
|
fn new_scalar(&self, cols: usize) -> ScalarOwned;
|
||||||
fn new_scalar_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarOwned;
|
fn new_scalar_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarOwned;
|
||||||
// fn new_scalar_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> Scalar;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<B: Backend> ScalarAlloc for Module<B> {
|
impl<B: Backend> ScalarAlloc for Module<B> {
|
||||||
@@ -129,31 +114,62 @@ impl<B: Backend> ScalarAlloc for Module<B> {
|
|||||||
fn new_scalar_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarOwned {
|
fn new_scalar_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarOwned {
|
||||||
ScalarOwned::new_from_bytes::<i64>(self.n(), cols, bytes)
|
ScalarOwned::new_from_bytes::<i64>(self.n(), cols, bytes)
|
||||||
}
|
}
|
||||||
// fn new_scalar_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> Scalar {
|
|
||||||
// Scalar::from_bytes_borrow(self, SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE, bytes)
|
|
||||||
// }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// impl<B: Backend> ZnxAlloc<B> for Scalar {
|
pub trait ScalarToRef {
|
||||||
// type Scalar = i64;
|
fn to_ref(&self) -> Scalar<&[u8]>;
|
||||||
|
}
|
||||||
|
|
||||||
// fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, _size: usize, bytes: &mut [u8]) -> Self {
|
pub trait ScalarToMut {
|
||||||
// Self {
|
fn to_mut(&mut self) -> Scalar<&mut [u8]>;
|
||||||
// inner: ZnxBase::from_bytes_borrow(module.n(), SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE, bytes),
|
}
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// fn bytes_of(module: &Module<B>, _rows: usize, cols: usize, _size: usize) -> usize {
|
impl ScalarToMut for Scalar<Vec<u8>> {
|
||||||
// debug_assert_eq!(
|
fn to_mut(&mut self) -> Scalar<&mut [u8]> {
|
||||||
// _rows, SCALAR_ZNX_ROWS,
|
Scalar {
|
||||||
// "rows != {} not supported for Scalar",
|
data: self.data.as_mut_slice(),
|
||||||
// SCALAR_ZNX_ROWS
|
n: self.n,
|
||||||
// );
|
cols: self.cols,
|
||||||
// debug_assert_eq!(
|
}
|
||||||
// _size, SCALAR_ZNX_SIZE,
|
}
|
||||||
// "rows != {} not supported for Scalar",
|
}
|
||||||
// SCALAR_ZNX_SIZE
|
|
||||||
// );
|
impl ScalarToRef for Scalar<Vec<u8>> {
|
||||||
// module.n() * cols * std::mem::size_of::<self::Scalar>()
|
fn to_ref(&self) -> Scalar<&[u8]> {
|
||||||
// }
|
Scalar {
|
||||||
// }
|
data: self.data.as_slice(),
|
||||||
|
n: self.n,
|
||||||
|
cols: self.cols,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ScalarToMut for Scalar<&mut [u8]> {
|
||||||
|
fn to_mut(&mut self) -> Scalar<&mut [u8]> {
|
||||||
|
Scalar {
|
||||||
|
data: self.data,
|
||||||
|
n: self.n,
|
||||||
|
cols: self.cols,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ScalarToRef for Scalar<&mut [u8]> {
|
||||||
|
fn to_ref(&self) -> Scalar<&[u8]> {
|
||||||
|
Scalar {
|
||||||
|
data: self.data,
|
||||||
|
n: self.n,
|
||||||
|
cols: self.cols,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ScalarToRef for Scalar<&[u8]> {
|
||||||
|
fn to_ref(&self) -> Scalar<&[u8]> {
|
||||||
|
Scalar {
|
||||||
|
data: self.data,
|
||||||
|
n: self.n,
|
||||||
|
cols: self.cols,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
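Note: the `ScalarToRef`/`ScalarToMut` traits added above are the core pattern of this commit. A self-contained sketch with simplified stand-in types (`Poly` and `add_assign` are illustrative, not the crate's API) shows why one generic signature then accepts owned and borrowed backings, mixed freely:

// Sketch of the To-Ref/To-Mut conversion-trait pattern; stand-in types.
struct Poly<D> {
    data: D,
}

trait PolyToRef {
    fn to_ref(&self) -> Poly<&[i64]>;
}
trait PolyToMut {
    fn to_mut(&mut self) -> Poly<&mut [i64]>;
}

impl PolyToRef for Poly<Vec<i64>> {
    fn to_ref(&self) -> Poly<&[i64]> {
        Poly { data: self.data.as_slice() }
    }
}
impl PolyToRef for Poly<&[i64]> {
    fn to_ref(&self) -> Poly<&[i64]> {
        Poly { data: self.data }
    }
}
impl PolyToMut for Poly<Vec<i64>> {
    fn to_mut(&mut self) -> Poly<&mut [i64]> {
        Poly { data: self.data.as_mut_slice() }
    }
}

// res and a may be owned or borrowed, independently of each other.
fn add_assign<R: PolyToMut, A: PolyToRef>(res: &mut R, a: &A) {
    let res = res.to_mut();
    let a = a.to_ref();
    res.data.iter_mut().zip(a.data.iter()).for_each(|(r, x)| *r += x);
}

fn main() {
    let mut res = Poly { data: vec![1i64, 2, 3] }; // owned backing
    let coeffs = [10i64, 20, 30];
    let a = Poly { data: &coeffs[..] }; // borrowed backing
    add_assign(&mut res, &a);
    assert_eq!(res.data, vec![11, 22, 33]);
}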
@@ -2,19 +2,16 @@ use std::marker::PhantomData;
 
 use crate::ffi::svp;
 use crate::znx_base::ZnxInfos;
-use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxView, alloc_aligned};
+use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxView, alloc_aligned};
 
-pub const SCALAR_ZNX_DFT_ROWS: usize = 1;
-pub const SCALAR_ZNX_DFT_SIZE: usize = 1;
-
-pub struct ScalarZnxDft<D, B> {
+pub struct ScalarZnxDft<D, B: Backend> {
     data: D,
     n: usize,
     cols: usize,
     _phantom: PhantomData<B>,
 }
 
-impl<D, B> ZnxInfos for ScalarZnxDft<D, B> {
+impl<D, B: Backend> ZnxInfos for ScalarZnxDft<D, B> {
     fn cols(&self) -> usize {
         self.cols
     }
@@ -30,20 +27,22 @@ impl<D, B> ZnxInfos for ScalarZnxDft<D, B> {
     fn size(&self) -> usize {
         1
     }
+}
 
+impl<D> ZnxSliceSize for ScalarZnxDft<D, FFT64> {
     fn sl(&self) -> usize {
         self.n()
     }
 }
 
-impl<D, B> DataView for ScalarZnxDft<D, B> {
+impl<D, B: Backend> DataView for ScalarZnxDft<D, B> {
     type D = D;
     fn data(&self) -> &Self::D {
         &self.data
     }
 }
 
-impl<D, B> DataViewMut for ScalarZnxDft<D, B> {
+impl<D, B: Backend> DataViewMut for ScalarZnxDft<D, B> {
     fn data_mut(&mut self) -> &mut Self::D {
         &mut self.data
     }
@@ -78,20 +77,69 @@ impl<D: From<Vec<u8>>, B: Backend> ScalarZnxDft<D, B> {
             _phantom: PhantomData,
         }
     }
-
-    // fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, _size: usize, bytes: &mut [u8]) -> Self {
-    //     debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, _size));
-    //     Self {
-    //         inner: ZnxBase::from_bytes_borrow(
-    //             module.n(),
-    //             SCALAR_ZNX_DFT_ROWS,
-    //             cols,
-    //             SCALAR_ZNX_DFT_SIZE,
-    //             bytes,
-    //         ),
-    //         _phantom: PhantomData,
-    //     }
-    // }
 }
 
 pub type ScalarZnxDftOwned<B> = ScalarZnxDft<Vec<u8>, B>;
+
+pub trait ScalarZnxDftToRef<B: Backend> {
+    fn to_ref(&self) -> ScalarZnxDft<&[u8], B>;
+}
+
+pub trait ScalarZnxDftToMut<B: Backend> {
+    fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B>;
+}
+
+impl<B: Backend> ScalarZnxDftToMut<B> for ScalarZnxDft<Vec<u8>, B> {
+    fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B> {
+        ScalarZnxDft {
+            data: self.data.as_mut_slice(),
+            n: self.n,
+            cols: self.cols,
+            _phantom: PhantomData,
+        }
+    }
+}
+
+impl<B: Backend> ScalarZnxDftToRef<B> for ScalarZnxDft<Vec<u8>, B> {
+    fn to_ref(&self) -> ScalarZnxDft<&[u8], B> {
+        ScalarZnxDft {
+            data: self.data.as_slice(),
+            n: self.n,
+            cols: self.cols,
+            _phantom: PhantomData,
+        }
+    }
+}
+
+impl<B: Backend> ScalarZnxDftToMut<B> for ScalarZnxDft<&mut [u8], B> {
+    fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B> {
+        ScalarZnxDft {
+            data: self.data,
+            n: self.n,
+            cols: self.cols,
+            _phantom: PhantomData,
+        }
+    }
+}
+
+impl<B: Backend> ScalarZnxDftToRef<B> for ScalarZnxDft<&mut [u8], B> {
+    fn to_ref(&self) -> ScalarZnxDft<&[u8], B> {
+        ScalarZnxDft {
+            data: self.data,
+            n: self.n,
+            cols: self.cols,
+            _phantom: PhantomData,
+        }
+    }
+}
+
+impl<B: Backend> ScalarZnxDftToRef<B> for ScalarZnxDft<&[u8], B> {
+    fn to_ref(&self) -> ScalarZnxDft<&[u8], B> {
+        ScalarZnxDft {
+            data: self.data,
+            n: self.n,
+            cols: self.cols,
+            _phantom: PhantomData,
+        }
+    }
+}
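Note: `ScalarZnxDft` carries its backend as a zero-sized `PhantomData<B>` tag, and the layout-dependent `ZnxSliceSize` is implemented only for the backend whose layout is known (`FFT64`). A standalone sketch of that pattern, with made-up backend names (`Fft64`, `Ntt120` are illustrative only):

use std::marker::PhantomData;

trait Backend {}
struct Fft64;
#[allow(dead_code)]
struct Ntt120;
impl Backend for Fft64 {}
impl Backend for Ntt120 {}

// Same byte container, different compile-time identity per backend.
struct DftPoly<D, B: Backend> {
    data: D,
    _phantom: PhantomData<B>, // zero-sized: no runtime cost
}

// A layout-dependent trait, implemented only where the layout is known.
trait SliceSize {
    fn sl(&self) -> usize;
}

// Only the Fft64-tagged type gets this impl; calling sl() on an
// Ntt120-tagged value is a compile error rather than a runtime surprise.
impl<D: AsRef<[u8]>> SliceSize for DftPoly<D, Fft64> {
    fn sl(&self) -> usize {
        self.data.as_ref().len() / std::mem::size_of::<f64>()
    }
}

fn main() {
    let p: DftPoly<Vec<u8>, Fft64> = DftPoly { data: vec![0u8; 64], _phantom: PhantomData };
    assert_eq!(p.sl(), 8);
}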
@@ -1,26 +1,28 @@
-use crate::ffi::svp::{self, svp_ppol_t};
+use crate::ffi::svp;
 use crate::ffi::vec_znx_dft::vec_znx_dft_t;
 use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
-use crate::{Backend, FFT64, Module, Scalar, ScalarZnxDft, ScalarZnxDftOwned, VecZnx, VecZnxDft};
+use crate::{
+    Backend, FFT64, Module, ScalarToRef, ScalarZnxDft, ScalarZnxDftOwned, ScalarZnxDftToMut, ScalarZnxDftToRef, VecZnx,
+    VecZnxDft, VecZnxDftToMut, VecZnxToRef, ZnxSliceSize,
+};
 
-pub trait ScalarZnxDftAlloc<B> {
+pub trait ScalarZnxDftAlloc<B: Backend> {
     fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned<B>;
     fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize;
     fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxDftOwned<B>;
     // fn new_scalar_znx_dft_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> ScalarZnxDft<B>;
 }
 
-pub trait ScalarZnxDftOps<DataMut, Data, B: Backend> {
-    fn svp_prepare(&self, res: &mut ScalarZnxDft<DataMut, B>, res_col: usize, a: &Scalar<Data>, a_col: usize);
-    fn svp_apply_dft(
-        &self,
-        res: &mut VecZnxDft<DataMut, B>,
-        res_col: usize,
-        a: &ScalarZnxDft<Data, B>,
-        a_col: usize,
-        b: &VecZnx<Data>,
-        b_col: usize,
-    );
+pub trait ScalarZnxDftOps<BACKEND: Backend> {
+    fn svp_prepare<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: ScalarZnxDftToMut<BACKEND>,
+        A: ScalarToRef;
+    fn svp_apply<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
+    where
+        R: VecZnxDftToMut<BACKEND>,
+        A: ScalarZnxDftToRef<BACKEND>,
+        B: VecZnxToRef;
 }
 
 impl<B: Backend> ScalarZnxDftAlloc<B> for Module<B> {
@@ -35,42 +37,38 @@ impl<B: Backend> ScalarZnxDftAlloc<B> for Module<B> {
     fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxDftOwned<B> {
         ScalarZnxDftOwned::new_from_bytes(self, cols, bytes)
     }
-
-    // fn new_scalar_znx_dft_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> ScalarZnxDft<FFT64> {
-    //     ScalarZnxDft::from_bytes_borrow(self, SCALAR_ZNX_DFT_ROWS, cols, SCALAR_ZNX_DFT_SIZE, bytes)
-    // }
 }
 
-impl<DataMut, Data> ScalarZnxDftOps<DataMut, Data, FFT64> for Module<FFT64>
-where
-    DataMut: AsMut<[u8]> + AsRef<[u8]>,
-    Data: AsRef<[u8]>,
-{
-    fn svp_prepare(&self, res: &mut ScalarZnxDft<DataMut, FFT64>, res_col: usize, a: &Scalar<Data>, a_col: usize) {
+impl ScalarZnxDftOps<FFT64> for Module<FFT64> {
+    fn svp_prepare<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: ScalarZnxDftToMut<FFT64>,
+        A: ScalarToRef,
+    {
         unsafe {
             svp::svp_prepare(
                 self.ptr,
-                res.at_mut_ptr(res_col, 0) as *mut svp_ppol_t,
-                a.at_ptr(a_col, 0),
+                res.to_mut().at_mut_ptr(res_col, 0) as *mut svp::svp_ppol_t,
+                a.to_ref().at_ptr(a_col, 0),
             )
         }
     }
 
-    fn svp_apply_dft(
-        &self,
-        res: &mut VecZnxDft<DataMut, FFT64>,
-        res_col: usize,
-        a: &ScalarZnxDft<Data, FFT64>,
-        a_col: usize,
-        b: &VecZnx<Data>,
-        b_col: usize,
-    ) {
+    fn svp_apply<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
+    where
+        R: VecZnxDftToMut<FFT64>,
+        A: ScalarZnxDftToRef<FFT64>,
+        B: VecZnxToRef,
+    {
+        let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
+        let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref();
+        let b: VecZnx<&[u8]> = b.to_ref();
         unsafe {
             svp::svp_apply_dft(
                 self.ptr,
                 res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
                 res.size() as u64,
-                a.at_ptr(a_col, 0) as *const svp_ppol_t,
+                a.at_ptr(a_col, 0) as *const svp::svp_ppol_t,
                 b.at_ptr(b_col, 0),
                 b.size() as u64,
                 b.sl() as u64,
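Note: with `svp_prepare`/`svp_apply` now bounded by conversion traits, downstream helpers can stay generic over the backing storage as well. A sketch of such a helper; `apply_secret` is an invented name, and this assumes the listed traits are re-exported from the `base2k` crate root like the other ops traits:

// Hypothetical downstream helper written against the new trait bounds.
use base2k::{FFT64, Module, ScalarZnxDftOps, ScalarZnxDftToRef, VecZnxDftToMut, VecZnxToRef};

fn apply_secret<R, A, B>(module: &Module<FFT64>, res: &mut R, s_dft: &A, ct: &B)
where
    R: VecZnxDftToMut<FFT64>,
    A: ScalarZnxDftToRef<FFT64>,
    B: VecZnxToRef,
{
    // Multiply the prepared secret (column 0) into column 1 of ct,
    // writing into column 0 of res; any backing storage works.
    module.svp_apply(res, 0, s_dft, 0, ct, 1);
}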
@@ -1,14 +1,13 @@
 use crate::DataView;
 use crate::DataViewMut;
+use crate::ZnxSliceSize;
 use crate::alloc_aligned;
 use crate::assert_alignement;
 use crate::cast_mut;
 use crate::ffi::znx;
-use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut, switch_degree};
+use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
 use std::{cmp::min, fmt};
 
-// pub const VEC_ZNX_ROWS: usize = 1;
-
 /// [VecZnx] represents collection of contiguously stacked vector of small norm polynomials of
 /// Zn\[X\] with [i64] coefficients.
 /// A [VecZnx] is composed of multiple Zn\[X\] polynomials stored in a single contiguous array
@@ -20,7 +19,7 @@ use std::{cmp::min, fmt};
 /// layout is: `[a0, b0, c0, a1, b1, c1, a2, b2, c2, a3, b3, c3]`, where ai, bi, ci
 /// are small polynomials of Zn\[X\].
 pub struct VecZnx<D> {
-    data: D,
+    pub data: D,
     n: usize,
     cols: usize,
     size: usize,
@@ -42,9 +41,11 @@ impl<D> ZnxInfos for VecZnx<D> {
     fn size(&self) -> usize {
         self.size
     }
+}
 
+impl<D> ZnxSliceSize for VecZnx<D> {
     fn sl(&self) -> usize {
-        self.cols() * self.n()
+        self.n() * self.cols()
     }
 }
 
@@ -66,10 +67,6 @@ impl<D: AsRef<[u8]>> ZnxView for VecZnx<D> {
 }
 
 impl<D: AsMut<[u8]> + AsRef<[u8]>> VecZnx<D> {
-    pub fn normalize(&mut self, log_base2k: usize, col: usize, carry: &mut [u8]) {
-        normalize(log_base2k, self, col, carry)
-    }
-
     /// Truncates the precision of the [VecZnx] by k bits.
     ///
     /// # Arguments
@@ -92,11 +89,6 @@ impl<D: AsMut<[u8]> + AsRef<[u8]>> VecZnx<D> {
             .for_each(|x: &mut i64| *x &= mask)
         }
     }
-
-    /// Switches degree of from `a.n()` to `self.n()` into `self`
-    pub fn switch_degree<Data: AsRef<[u8]>>(&mut self, col: usize, a: &VecZnx<Data>, col_a: usize) {
-        switch_degree(self, col_a, a, col)
-    }
 }
 
 impl<D: From<Vec<u8>>> VecZnx<D> {
@@ -126,6 +118,17 @@ impl<D: From<Vec<u8>>> VecZnx<D> {
     }
 }
 
+impl<D> VecZnx<D> {
+    pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
+        Self {
+            data,
+            n,
+            cols,
+            size,
+        }
+    }
+}
+
 /// Copies the coefficients of `a` on the receiver.
 /// Copy is done with the minimum size matching both backing arrays.
 /// Panics if the cols do not match.
@@ -141,10 +144,12 @@ where
     data_b[..size].copy_from_slice(&data_a[..size])
 }
 
+#[allow(dead_code)]
 fn normalize_tmp_bytes(n: usize) -> usize {
     n * std::mem::size_of::<i64>()
 }
 
+#[allow(dead_code)]
 fn normalize<D: AsMut<[u8]> + AsRef<[u8]>>(log_base2k: usize, a: &mut VecZnx<D>, a_col: usize, tmp_bytes: &mut [u8]) {
     let n: usize = a.n();
 
@@ -216,8 +221,16 @@ pub type VecZnxOwned = VecZnx<Vec<u8>>;
 pub type VecZnxMut<'a> = VecZnx<&'a mut [u8]>;
 pub type VecZnxRef<'a> = VecZnx<&'a [u8]>;
 
-impl VecZnx<Vec<u8>> {
-    pub fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
+pub trait VecZnxToRef {
+    fn to_ref(&self) -> VecZnx<&[u8]>;
+}
+
+pub trait VecZnxToMut {
+    fn to_mut(&mut self) -> VecZnx<&mut [u8]>;
+}
+
+impl VecZnxToMut for VecZnx<Vec<u8>> {
+    fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
         VecZnx {
             data: self.data.as_mut_slice(),
             n: self.n,
@@ -225,8 +238,10 @@ impl VecZnx<Vec<u8>> {
             size: self.size,
         }
     }
+}
 
-    pub fn to_ref(&self) -> VecZnx<&[u8]> {
+impl VecZnxToRef for VecZnx<Vec<u8>> {
+    fn to_ref(&self) -> VecZnx<&[u8]> {
         VecZnx {
             data: self.data.as_slice(),
             n: self.n,
@@ -236,10 +251,32 @@ impl VecZnx<Vec<u8>> {
     }
 }
 
-impl VecZnx<&mut [u8]> {
-    pub fn to_ref(&self) -> VecZnx<&[u8]> {
+impl VecZnxToMut for VecZnx<&mut [u8]> {
+    fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
         VecZnx {
-            data: &self.data,
+            data: self.data,
+            n: self.n,
+            cols: self.cols,
+            size: self.size,
+        }
+    }
+}
+
+impl VecZnxToRef for VecZnx<&mut [u8]> {
+    fn to_ref(&self) -> VecZnx<&[u8]> {
+        VecZnx {
+            data: self.data,
+            n: self.n,
+            cols: self.cols,
+            size: self.size,
+        }
+    }
+}
+
+impl VecZnxToRef for VecZnx<&[u8]> {
+    fn to_ref(&self) -> VecZnx<&[u8]> {
+        VecZnx {
+            data: self.data,
             n: self.n,
             cols: self.cols,
             size: self.size,
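Note: per the layout documented above (`[a0, b0, c0, a1, b1, c1, ...]`), `sl()` = n * cols is the coefficient stride between consecutive limb levels, so column `j` at level `i` starts at offset `i * sl() + j * n`. A standalone check of that arithmetic (not crate code):

// Offset arithmetic implied by the documented VecZnx layout.
fn offset(n: usize, cols: usize, col: usize, level: usize) -> usize {
    let sl = n * cols; // as in ZnxSliceSize::sl() for VecZnx
    level * sl + col * n
}

fn main() {
    let (n, cols) = (8, 3); // three columns a, b, c of degree-8 polynomials
    assert_eq!(offset(n, cols, 0, 0), 0); // a0 starts the buffer
    assert_eq!(offset(n, cols, 1, 0), 8); // b0 follows a0
    assert_eq!(offset(n, cols, 0, 1), 24); // a1 starts after a0, b0, c0
}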
@@ -1,12 +1,9 @@
 use crate::ffi::vec_znx_big;
 use crate::znx_base::{ZnxInfos, ZnxView};
-use crate::{Backend, DataView, DataViewMut, FFT64, Module, alloc_aligned};
+use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, alloc_aligned};
 use std::marker::PhantomData;
 
-// const VEC_ZNX_BIG_ROWS: usize = 1;
-
-/// VecZnxBig is `Backend` dependent, denoted with backend generic `B`
-pub struct VecZnxBig<D, B> {
+pub struct VecZnxBig<D, B: Backend> {
     data: D,
     n: usize,
     cols: usize,
@@ -14,7 +11,7 @@ pub struct VecZnxBig<D, B> {
     _phantom: PhantomData<B>,
 }
 
-impl<D, B> ZnxInfos for VecZnxBig<D, B> {
+impl<D, B: Backend> ZnxInfos for VecZnxBig<D, B> {
     fn cols(&self) -> usize {
         self.cols
     }
@@ -30,20 +27,22 @@ impl<D, B> ZnxInfos for VecZnxBig<D, B> {
     fn size(&self) -> usize {
         self.size
     }
+}
 
+impl<D> ZnxSliceSize for VecZnxBig<D, FFT64> {
     fn sl(&self) -> usize {
-        self.cols() * self.n()
+        self.n() * self.cols()
     }
 }
 
-impl<D, B> DataView for VecZnxBig<D, B> {
+impl<D, B: Backend> DataView for VecZnxBig<D, B> {
     type D = D;
     fn data(&self) -> &Self::D {
         &self.data
     }
 }
 
-impl<D, B> DataViewMut for VecZnxBig<D, B> {
+impl<D, B: Backend> DataViewMut for VecZnxBig<D, B> {
     fn data_mut(&mut self) -> &mut Self::D {
         &mut self.data
     }
@@ -82,7 +81,7 @@ impl<D: From<Vec<u8>>, B: Backend> VecZnxBig<D, B> {
     }
 }
 
-impl<D, B> VecZnxBig<D, B> {
+impl<D, B: Backend> VecZnxBig<D, B> {
     pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
         Self {
             data,
@@ -96,8 +95,16 @@ impl<D, B> VecZnxBig<D, B> {
 
 pub type VecZnxBigOwned<B> = VecZnxBig<Vec<u8>, B>;
 
-impl<B> VecZnxBig<Vec<u8>, B> {
-    pub fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B> {
+pub trait VecZnxBigToRef<B: Backend> {
+    fn to_ref(&self) -> VecZnxBig<&[u8], B>;
+}
+
+pub trait VecZnxBigToMut<B: Backend> {
+    fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B>;
+}
+
+impl<B: Backend> VecZnxBigToMut<B> for VecZnxBig<Vec<u8>, B> {
+    fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B> {
         VecZnxBig {
             data: self.data.as_mut_slice(),
             n: self.n,
@@ -106,8 +113,10 @@ impl<B> VecZnxBig<Vec<u8>, B> {
             _phantom: PhantomData,
         }
     }
+}
 
-    pub fn to_ref(&self) -> VecZnxBig<&[u8], B> {
+impl<B: Backend> VecZnxBigToRef<B> for VecZnxBig<Vec<u8>, B> {
+    fn to_ref(&self) -> VecZnxBig<&[u8], B> {
         VecZnxBig {
             data: self.data.as_slice(),
             n: self.n,
@@ -117,3 +126,39 @@ impl<B> VecZnxBig<Vec<u8>, B> {
         }
     }
 }
+
+impl<B: Backend> VecZnxBigToMut<B> for VecZnxBig<&mut [u8], B> {
+    fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B> {
+        VecZnxBig {
+            data: self.data,
+            n: self.n,
+            cols: self.cols,
+            size: self.size,
+            _phantom: PhantomData,
+        }
+    }
+}
+
+impl<B: Backend> VecZnxBigToRef<B> for VecZnxBig<&mut [u8], B> {
+    fn to_ref(&self) -> VecZnxBig<&[u8], B> {
+        VecZnxBig {
+            data: self.data,
+            n: self.n,
+            cols: self.cols,
+            size: self.size,
+            _phantom: PhantomData,
+        }
+    }
+}
+
+impl<B: Backend> VecZnxBigToRef<B> for VecZnxBig<&[u8], B> {
+    fn to_ref(&self) -> VecZnxBig<&[u8], B> {
+        VecZnxBig {
+            data: self.data,
+            n: self.n,
+            cols: self.cols,
+            size: self.size,
+            _phantom: PhantomData,
+        }
+    }
+}
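Note: `VecZnxBig::from_data` is `pub(crate)`: the fields stay private outside their module, while sibling modules inside the crate can still construct values. A minimal standalone sketch of that visibility pattern, with stand-in types:

// Sketch of the pub(crate) constructor pattern; stand-in types only.
mod container {
    pub struct Buf<D> {
        data: D,
        len: usize,
    }

    impl<D> Buf<D> {
        pub(crate) fn from_data(data: D, len: usize) -> Self {
            Self { data, len }
        }
        pub fn len(&self) -> usize {
            self.len
        }
        pub fn data(&self) -> &D {
            &self.data
        }
    }
}

mod ops {
    use super::container::Buf;

    // A sibling module can construct Buf even though the fields are private.
    pub fn wrap(bytes: Vec<u8>) -> Buf<Vec<u8>> {
        let len = bytes.len();
        Buf::from_data(bytes, len)
    }
}

fn main() {
    let b = ops::wrap(vec![0u8; 16]);
    assert_eq!(b.len(), 16);
    assert_eq!(b.data().len(), 16);
}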
@@ -1,8 +1,11 @@
 use crate::ffi::vec_znx;
 use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
-use crate::{Backend, FFT64, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxScratch, bytes_of_vec_znx_big};
+use crate::{
+    Backend, FFT64, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxScratch,
+    VecZnxToMut, VecZnxToRef, ZnxSliceSize, bytes_of_vec_znx_big,
+};
 
-pub trait VecZnxBigAlloc<B> {
+pub trait VecZnxBigAlloc<B: Backend> {
     /// Allocates a vector Z[X]/(X^N+1) that stores not normalized values.
     fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBigOwned<B>;
 
@@ -39,79 +42,77 @@ pub trait VecZnxBigAlloc<B> {
     fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize;
 }
 
-pub trait VecZnxBigOps<DataMut, Data, B> {
+pub trait VecZnxBigOps<BACKEND: Backend> {
     /// Adds `a` to `b` and stores the result on `c`.
-    fn vec_znx_big_add(
-        &self,
-        res: &mut VecZnxBig<DataMut, B>,
-        res_col: usize,
-        a: &VecZnxBig<Data, B>,
-        a_col: usize,
-        b: &VecZnxBig<Data, B>,
-        b_col: usize,
-    );
+    fn vec_znx_big_add<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
+    where
+        R: VecZnxBigToMut<BACKEND>,
+        A: VecZnxBigToRef<BACKEND>,
+        B: VecZnxBigToRef<BACKEND>;
 
     /// Adds `a` to `b` and stores the result on `b`.
-    fn vec_znx_big_add_inplace(&self, res: &mut VecZnxBig<DataMut, B>, res_col: usize, a: &VecZnxBig<Data, B>, a_col: usize);
+    fn vec_znx_big_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxBigToMut<BACKEND>,
+        A: VecZnxBigToRef<BACKEND>;
 
     /// Adds `a` to `b` and stores the result on `c`.
-    fn vec_znx_big_add_small(
-        &self,
-        res: &mut VecZnxBig<DataMut, B>,
-        res_col: usize,
-        a: &VecZnxBig<Data, B>,
-        a_col: usize,
-        b: &VecZnx<Data>,
-        b_col: usize,
-    );
+    fn vec_znx_big_add_small<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
+    where
+        R: VecZnxBigToMut<BACKEND>,
+        A: VecZnxBigToRef<BACKEND>,
+        B: VecZnxToRef;
 
     /// Adds `a` to `b` and stores the result on `b`.
-    fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig<DataMut, B>, res_col: usize, a: &VecZnx<Data>, a_col: usize);
+    fn vec_znx_big_add_small_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxBigToMut<BACKEND>,
+        A: VecZnxToRef;
 
     /// Subtracts `a` to `b` and stores the result on `c`.
-    fn vec_znx_big_sub(
-        &self,
-        res: &mut VecZnxBig<DataMut, B>,
-        res_col: usize,
-        a: &VecZnxBig<Data, B>,
-        a_col: usize,
-        b: &VecZnxBig<Data, B>,
-        b_col: usize,
-    );
+    fn vec_znx_big_sub<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
+    where
+        R: VecZnxBigToMut<BACKEND>,
+        A: VecZnxBigToRef<BACKEND>,
+        B: VecZnxBigToRef<BACKEND>;
 
     /// Subtracts `a` from `b` and stores the result on `b`.
-    fn vec_znx_big_sub_ab_inplace(&self, res: &mut VecZnxBig<DataMut, B>, res_col: usize, a: &VecZnxBig<Data, B>, a_col: usize);
+    fn vec_znx_big_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxBigToMut<BACKEND>,
+        A: VecZnxBigToRef<BACKEND>;
 
     /// Subtracts `b` from `a` and stores the result on `b`.
-    fn vec_znx_big_sub_ba_inplace(&self, res: &mut VecZnxBig<DataMut, B>, res_col: usize, a: &VecZnxBig<Data, B>, a_col: usize);
+    fn vec_znx_big_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxBigToMut<BACKEND>,
+        A: VecZnxBigToRef<BACKEND>;
 
     /// Subtracts `b` from `a` and stores the result on `c`.
-    fn vec_znx_big_sub_small_a(
-        &self,
-        res: &mut VecZnxBig<DataMut, B>,
-        res_col: usize,
-        a: &VecZnx<Data>,
-        a_col: usize,
-        b: &VecZnxBig<Data, B>,
-        b_col: usize,
-    );
+    fn vec_znx_big_sub_small_a<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
+    where
+        R: VecZnxBigToMut<BACKEND>,
+        A: VecZnxToRef,
+        B: VecZnxBigToRef<BACKEND>;
 
     /// Subtracts `a` from `res` and stores the result on `res`.
-    fn vec_znx_big_sub_small_a_inplace(&self, res: &mut VecZnxBig<DataMut, B>, res_col: usize, a: &VecZnx<Data>, a_col: usize);
+    fn vec_znx_big_sub_small_a_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxBigToMut<BACKEND>,
+        A: VecZnxToRef;
 
     /// Subtracts `b` from `a` and stores the result on `c`.
-    fn vec_znx_big_sub_small_b(
-        &self,
-        res: &mut VecZnxBig<DataMut, B>,
-        res_col: usize,
-        a: &VecZnxBig<Data, B>,
-        a_col: usize,
-        b: &VecZnx<Data>,
-        b_col: usize,
-    );
+    fn vec_znx_big_sub_small_b<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
+    where
+        R: VecZnxBigToMut<BACKEND>,
+        A: VecZnxBigToRef<BACKEND>,
+        B: VecZnxToRef;
 
     /// Subtracts `res` from `a` and stores the result on `res`.
-    fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig<DataMut, B>, res_col: usize, a: &VecZnx<Data>, a_col: usize);
+    fn vec_znx_big_sub_small_b_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxBigToMut<BACKEND>,
+        A: VecZnxToRef;
 
     /// Normalizes `a` and stores the result on `b`.
     ///
@@ -119,28 +120,28 @@ pub trait VecZnxBigOps<DataMut, Data, B> {
     ///
     /// * `log_base2k`: normalization basis.
     /// * `tmp_bytes`: scratch space of size at least [VecZnxBigOps::vec_znx_big_normalize].
-    fn vec_znx_big_normalize(
+    fn vec_znx_big_normalize<R, A>(
         &self,
         log_base2k: usize,
-        res: &mut VecZnx<DataMut>,
+        res: &mut R,
         res_col: usize,
-        a: &VecZnxBig<Data, B>,
+        a: &A,
         a_col: usize,
         scratch: &mut Scratch,
-    );
+    ) where
+        R: VecZnxToMut,
+        A: VecZnxBigToRef<BACKEND>;
 
     /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`.
-    fn vec_znx_big_automorphism(
-        &self,
-        k: i64,
-        res: &mut VecZnxBig<DataMut, B>,
-        res_col: usize,
-        a: &VecZnxBig<Data, B>,
-        a_col: usize,
-    );
+    fn vec_znx_big_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxBigToMut<BACKEND>,
+        A: VecZnxBigToRef<BACKEND>;
 
     /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`.
-    fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig<DataMut, B>, a_col: usize);
+    fn vec_znx_big_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
+    where
+        A: VecZnxBigToMut<BACKEND>;
 }
 
 pub trait VecZnxBigScratch {
@@ -157,29 +158,22 @@ impl VecZnxBigAlloc<FFT64> for Module<FFT64> {
         VecZnxBig::new_from_bytes(self, cols, size, bytes)
     }
 
-    // fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig<FFT64> {
-    //     VecZnxBig::from_bytes_borrow(self, 1, cols, size, tmp_bytes)
-    // }
-
     fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize {
         bytes_of_vec_znx_big(self, cols, size)
     }
 }
 
-impl<DataMut, Data> VecZnxBigOps<DataMut, Data, FFT64> for Module<FFT64>
-where
-    DataMut: AsMut<[u8]> + AsRef<[u8]>,
-    Data: AsRef<[u8]>,
-{
-    fn vec_znx_big_add(
-        &self,
-        res: &mut VecZnxBig<DataMut, FFT64>,
-        res_col: usize,
-        a: &VecZnxBig<Data, FFT64>,
-        a_col: usize,
-        b: &VecZnxBig<Data, FFT64>,
-        b_col: usize,
-    ) {
+impl VecZnxBigOps<FFT64> for Module<FFT64> {
+    fn vec_znx_big_add<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
+    where
+        R: VecZnxBigToMut<FFT64>,
+        A: VecZnxBigToRef<FFT64>,
+        B: VecZnxBigToRef<FFT64>,
+    {
+        let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
+        let b: VecZnxBig<&[u8], FFT64> = b.to_ref();
+        let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
+
         #[cfg(debug_assertions)]
         {
             assert_eq!(a.n(), self.n());
@@ -203,13 +197,14 @@ where
         }
     }
 
-    fn vec_znx_big_add_inplace(
-        &self,
-        res: &mut VecZnxBig<DataMut, FFT64>,
-        res_col: usize,
-        a: &VecZnxBig<Data, FFT64>,
-        a_col: usize,
-    ) {
+    fn vec_znx_big_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxBigToMut<FFT64>,
+        A: VecZnxBigToRef<FFT64>,
+    {
+        let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
+        let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
+
        #[cfg(debug_assertions)]
        {
             assert_eq!(a.n(), self.n());
@@ -231,15 +226,16 @@ where
         }
     }
 
-    fn vec_znx_big_sub(
-        &self,
-        res: &mut VecZnxBig<DataMut, FFT64>,
-        res_col: usize,
-        a: &VecZnxBig<Data, FFT64>,
-        a_col: usize,
-        b: &VecZnxBig<Data, FFT64>,
-        b_col: usize,
-    ) {
+    fn vec_znx_big_sub<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
+    where
+        R: VecZnxBigToMut<FFT64>,
+        A: VecZnxBigToRef<FFT64>,
+        B: VecZnxBigToRef<FFT64>,
+    {
+        let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
+        let b: VecZnxBig<&[u8], FFT64> = b.to_ref();
+        let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
+
         #[cfg(debug_assertions)]
         {
             assert_eq!(a.n(), self.n());
@@ -263,13 +259,14 @@ where
         }
     }
 
-    fn vec_znx_big_sub_ab_inplace(
-        &self,
-        res: &mut VecZnxBig<DataMut, FFT64>,
-        res_col: usize,
-        a: &VecZnxBig<Data, FFT64>,
-        a_col: usize,
-    ) {
+    fn vec_znx_big_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxBigToMut<FFT64>,
+        A: VecZnxBigToRef<FFT64>,
+    {
+        let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
+        let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
+
         #[cfg(debug_assertions)]
         {
             assert_eq!(a.n(), self.n());
@@ -291,13 +288,14 @@ where
         }
     }
 
-    fn vec_znx_big_sub_ba_inplace(
-        &self,
-        res: &mut VecZnxBig<DataMut, FFT64>,
-        res_col: usize,
-        a: &VecZnxBig<Data, FFT64>,
-        a_col: usize,
-    ) {
+    fn vec_znx_big_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxBigToMut<FFT64>,
+        A: VecZnxBigToRef<FFT64>,
+    {
+        let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
+        let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
+
         #[cfg(debug_assertions)]
         {
             assert_eq!(a.n(), self.n());
@@ -319,15 +317,16 @@ where
         }
     }
 
-    fn vec_znx_big_sub_small_b(
-        &self,
-        res: &mut VecZnxBig<DataMut, FFT64>,
-        res_col: usize,
-        a: &VecZnxBig<Data, FFT64>,
-        a_col: usize,
-        b: &VecZnx<Data>,
-        b_col: usize,
-    ) {
+    fn vec_znx_big_sub_small_b<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
+    where
+        R: VecZnxBigToMut<FFT64>,
+        A: VecZnxBigToRef<FFT64>,
+        B: VecZnxToRef,
+    {
+        let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
+        let b: VecZnx<&[u8]> = b.to_ref();
+        let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
+
         #[cfg(debug_assertions)]
         {
             assert_eq!(a.n(), self.n());
@@ -351,13 +350,14 @@ where
         }
     }
 
-    fn vec_znx_big_sub_small_b_inplace(
-        &self,
-        res: &mut VecZnxBig<DataMut, FFT64>,
-        res_col: usize,
-        a: &VecZnx<Data>,
-        a_col: usize,
-    ) {
+    fn vec_znx_big_sub_small_b_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxBigToMut<FFT64>,
+        A: VecZnxToRef,
+    {
+        let a: VecZnx<&[u8]> = a.to_ref();
+        let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
+
         #[cfg(debug_assertions)]
         {
             assert_eq!(a.n(), self.n());
@@ -379,15 +379,16 @@ where
         }
     }
 
-    fn vec_znx_big_sub_small_a(
-        &self,
-        res: &mut VecZnxBig<DataMut, FFT64>,
-        res_col: usize,
-        a: &VecZnx<Data>,
-        a_col: usize,
-        b: &VecZnxBig<Data, FFT64>,
-        b_col: usize,
-    ) {
+    fn vec_znx_big_sub_small_a<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
+    where
+        R: VecZnxBigToMut<FFT64>,
+        A: VecZnxToRef,
+        B: VecZnxBigToRef<FFT64>,
+    {
+        let a: VecZnx<&[u8]> = a.to_ref();
+        let b: VecZnxBig<&[u8], FFT64> = b.to_ref();
+        let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
+
         #[cfg(debug_assertions)]
         {
             assert_eq!(a.n(), self.n());
@@ -411,13 +412,14 @@ where
         }
     }
 
-    fn vec_znx_big_sub_small_a_inplace(
-        &self,
-        res: &mut VecZnxBig<DataMut, FFT64>,
-        res_col: usize,
-        a: &VecZnx<Data>,
-        a_col: usize,
-    ) {
+    fn vec_znx_big_sub_small_a_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxBigToMut<FFT64>,
+        A: VecZnxToRef,
+    {
+        let a: VecZnx<&[u8]> = a.to_ref();
+        let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
+
         #[cfg(debug_assertions)]
         {
             assert_eq!(a.n(), self.n());
@@ -439,15 +441,16 @@ where
         }
     }
 
-    fn vec_znx_big_add_small(
-        &self,
-        res: &mut VecZnxBig<DataMut, FFT64>,
-        res_col: usize,
-        a: &VecZnxBig<Data, FFT64>,
-        a_col: usize,
-        b: &VecZnx<Data>,
-        b_col: usize,
-    ) {
+    fn vec_znx_big_add_small<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
+    where
+        R: VecZnxBigToMut<FFT64>,
+        A: VecZnxBigToRef<FFT64>,
+        B: VecZnxToRef,
+    {
+        let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
+        let b: VecZnx<&[u8]> = b.to_ref();
+        let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
+
         #[cfg(debug_assertions)]
         {
             assert_eq!(a.n(), self.n());
@@ -471,7 +474,14 @@ where
         }
     }
 
-    fn vec_znx_big_add_small_inplace(&self, res: &mut VecZnxBig<DataMut, FFT64>, res_col: usize, a: &VecZnx<Data>, a_col: usize) {
+    fn vec_znx_big_add_small_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxBigToMut<FFT64>,
+        A: VecZnxToRef,
+    {
+        let a: VecZnx<&[u8]> = a.to_ref();
+        let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
+
         #[cfg(debug_assertions)]
         {
             assert_eq!(a.n(), self.n());
@@ -493,22 +503,28 @@ where
         }
     }
 
-    fn vec_znx_big_normalize(
+    fn vec_znx_big_normalize<R, A>(
         &self,
         log_base2k: usize,
-        res: &mut VecZnx<DataMut>,
+        res: &mut R,
         res_col: usize,
-        a: &VecZnxBig<Data, FFT64>,
+        a: &A,
         a_col: usize,
         scratch: &mut Scratch,
-    ) {
+    ) where
+        R: VecZnxToMut,
+        A: VecZnxBigToRef<FFT64>,
+    {
+        let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
+        let mut res: VecZnx<&mut [u8]> = res.to_mut();
+
         #[cfg(debug_assertions)]
         {
             assert_eq!(a.n(), self.n());
             assert_eq!(res.n(), self.n());
             //(Jay)Note: This is calling VezZnxOps::vec_znx_normalize_tmp_bytes and not VecZnxBigOps::vec_znx_big_normalize_tmp_bytes.
             // In the FFT backend the tmp sizes are same but will be different in the NTT backend
-            // assert!(tmp_bytes.len() >= <Self as VecZnxOps<DataMut, Data>>::vec_znx_normalize_tmp_bytes(&self));
+            // assert!(tmp_bytes.len() >= <Self as VecZnxOps<&mut [u8], & [u8]>>::vec_znx_normalize_tmp_bytes(&self));
             // assert_alignement(tmp_bytes.as_ptr());
         }
 
@@ -530,14 +546,14 @@ where
         }
     }
 
-    fn vec_znx_big_automorphism(
-        &self,
-        k: i64,
-        res: &mut VecZnxBig<DataMut, FFT64>,
-        res_col: usize,
-        a: &VecZnxBig<Data, FFT64>,
-        a_col: usize,
-    ) {
+    fn vec_znx_big_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxBigToMut<FFT64>,
+        A: VecZnxBigToRef<FFT64>,
+    {
+        let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
+        let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
+
         #[cfg(debug_assertions)]
         {
             assert_eq!(a.n(), self.n());
@@ -557,7 +573,12 @@ where
         }
     }
 
-    fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig<DataMut, FFT64>, a_col: usize) {
+    fn vec_znx_big_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
+    where
+        A: VecZnxBigToMut<FFT64>,
+    {
+        let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut();
+
         #[cfg(debug_assertions)]
         {
             assert_eq!(a.n(), self.n());
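Note: every FFT64 impl above follows the same shape: convert the generic arguments to concrete borrowed views once at the top, check cheap invariants only in debug builds, then do the work against a single concrete type, which keeps the monomorphized code small. A condensed standalone sketch of that shape (stand-in types; `debug_assert_eq!` stands in for the diff's explicit `#[cfg(debug_assertions)]` block):

// Condensed shape of the FFT64 impls; not the crate's actual code.
struct View<'a> {
    n: usize,
    data: &'a mut [i64],
}

trait ToView {
    fn to_view(&mut self) -> View<'_>;
}

fn negate_op<R: ToView>(module_n: usize, res: &mut R) {
    // 1) convert once: the rest of the body is monomorphic over View.
    let res: View<'_> = res.to_view();

    // 2) cheap invariants, compiled out of release builds.
    debug_assert_eq!(res.n, module_n);

    // 3) the actual work (stand-in for the unsafe FFI call).
    res.data.iter_mut().for_each(|x| *x = -*x);
}

struct Owned {
    n: usize,
    data: Vec<i64>,
}

impl ToView for Owned {
    fn to_view(&mut self) -> View<'_> {
        View { n: self.n, data: &mut self.data }
    }
}

fn main() {
    let mut v = Owned { n: 4, data: vec![1, -2, 3, -4] };
    negate_op(4, &mut v);
    assert_eq!(v.data, vec![-1, 2, -3, 4]);
}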
@@ -2,12 +2,9 @@ use std::marker::PhantomData;
 
 use crate::ffi::vec_znx_dft;
 use crate::znx_base::ZnxInfos;
-use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxView, alloc_aligned};
+use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxView, alloc_aligned};
 
-// const VEC_ZNX_DFT_ROWS: usize = 1;
-
-// VecZnxDft is `Backend` dependent denoted with generic `B`
-pub struct VecZnxDft<D, B> {
+pub struct VecZnxDft<D, B: Backend> {
     data: D,
     n: usize,
     cols: usize,
@@ -15,7 +12,7 @@ pub struct VecZnxDft<D, B> {
     _phantom: PhantomData<B>,
 }
 
-impl<D, B> ZnxInfos for VecZnxDft<D, B> {
+impl<D, B: Backend> ZnxInfos for VecZnxDft<D, B> {
     fn cols(&self) -> usize {
         self.cols
     }
@@ -31,20 +28,22 @@ impl<D, B> ZnxInfos for VecZnxDft<D, B> {
     fn size(&self) -> usize {
         self.size
     }
+}
 
+impl<D> ZnxSliceSize for VecZnxDft<D, FFT64> {
     fn sl(&self) -> usize {
-        self.cols() * self.n()
+        self.n() * self.cols()
     }
 }
 
-impl<D, B> DataView for VecZnxDft<D, B> {
+impl<D, B: Backend> DataView for VecZnxDft<D, B> {
     type D = D;
     fn data(&self) -> &Self::D {
         &self.data
     }
 }
 
-impl<D, B> DataViewMut for VecZnxDft<D, B> {
+impl<D, B: Backend> DataViewMut for VecZnxDft<D, B> {
     fn data_mut(&mut self) -> &mut Self::D {
         &mut self.data
     }
@@ -85,7 +84,7 @@ impl<D: From<Vec<u8>>, B: Backend> VecZnxDft<D, B> {
 
 pub type VecZnxDftOwned<B> = VecZnxDft<Vec<u8>, B>;
 
-impl<D, B> VecZnxDft<D, B> {
+impl<D, B: Backend> VecZnxDft<D, B> {
     pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
         Self {
             data,
@@ -97,8 +96,16 @@ impl<D, B> VecZnxDft<D, B> {
     }
 }
 
-impl<B> VecZnxDft<Vec<u8>, B> {
-    pub fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
+pub trait VecZnxDftToRef<B: Backend> {
+    fn to_ref(&self) -> VecZnxDft<&[u8], B>;
+}
+
+pub trait VecZnxDftToMut<B: Backend> {
+    fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B>;
+}
+
+impl<B: Backend> VecZnxDftToMut<B> for VecZnxDft<Vec<u8>, B> {
+    fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
         VecZnxDft {
             data: self.data.as_mut_slice(),
             n: self.n,
@@ -107,8 +114,10 @@ impl<B> VecZnxDft<Vec<u8>, B> {
             _phantom: PhantomData,
         }
     }
+}
 
-    pub fn to_ref(&self) -> VecZnxDft<&[u8], B> {
+impl<B: Backend> VecZnxDftToRef<B> for VecZnxDft<Vec<u8>, B> {
+    fn to_ref(&self) -> VecZnxDft<&[u8], B> {
         VecZnxDft {
             data: self.data.as_slice(),
             n: self.n,
@@ -119,10 +128,34 @@ impl<B> VecZnxDft<Vec<u8>, B> {
         }
     }
 }
 
-impl<B> VecZnxDft<&mut [u8], B> {
-    pub fn to_ref(&self) -> VecZnxDft<&[u8], B> {
+impl<B: Backend> VecZnxDftToMut<B> for VecZnxDft<&mut [u8], B> {
+    fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
         VecZnxDft {
-            data: &self.data,
+            data: self.data,
+            n: self.n,
+            cols: self.cols,
+            size: self.size,
+            _phantom: PhantomData,
+        }
+    }
+}
+
+impl<B: Backend> VecZnxDftToRef<B> for VecZnxDft<&mut [u8], B> {
+    fn to_ref(&self) -> VecZnxDft<&[u8], B> {
+        VecZnxDft {
+            data: self.data,
+            n: self.n,
+            cols: self.cols,
+            size: self.size,
+            _phantom: PhantomData,
+        }
+    }
+}
+
+impl<B: Backend> VecZnxDftToRef<B> for VecZnxDft<&[u8], B> {
+    fn to_ref(&self) -> VecZnxDft<&[u8], B> {
+        VecZnxDft {
+            data: self.data,
             n: self.n,
             cols: self.cols,
             size: self.size,
@@ -1,11 +1,11 @@
|
|||||||
use crate::ffi::{vec_znx_big, vec_znx_dft};
|
use crate::ffi::{vec_znx_big, vec_znx_dft};
|
||||||
use crate::vec_znx_dft::bytes_of_vec_znx_dft;
|
use crate::vec_znx_dft::bytes_of_vec_znx_dft;
|
||||||
use crate::znx_base::ZnxInfos;
|
use crate::znx_base::ZnxInfos;
|
||||||
use crate::{Backend, VecZnxDftOwned};
|
use crate::{Backend, Scratch, VecZnxBigToMut, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, ZnxSliceSize};
|
||||||
use crate::{FFT64, Module, VecZnx, VecZnxBig, VecZnxDft, ZnxView, ZnxViewMut, ZnxZero, assert_alignement};
|
use crate::{FFT64, Module, ZnxView, ZnxViewMut, ZnxZero};
|
||||||
use std::cmp::min;
|
use std::cmp::min;
|
||||||
|
|
||||||
pub trait VecZnxDftAlloc<B> {
|
pub trait VecZnxDftAlloc<B: Backend> {
|
||||||
/// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space.
|
/// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space.
|
||||||
 fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDftOwned<B>;
@@ -34,24 +34,26 @@ pub trait VecZnxDftAlloc<B> {
     fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize;
 }
 
-pub trait VecZnxDftOps<DataMut, Data, B> {
+pub trait VecZnxDftOps<B: Backend> {
     /// Returns the minimum number of bytes necessary to allocate
     /// a new [VecZnxDft] through [VecZnxDft::from_bytes].
     fn vec_znx_idft_tmp_bytes(&self) -> usize;
 
     /// b <- IDFT(a), uses a as scratch space.
-    fn vec_znx_idft_tmp_a(&self, res: &mut VecZnxBig<DataMut, B>, res_col: usize, a: &mut VecZnxDft<DataMut, B>, a_cols: usize);
+    fn vec_znx_idft_tmp_a<R, A>(&self, res: &mut R, res_col: usize, a: &mut A, a_cols: usize)
+    where
+        R: VecZnxBigToMut<B>,
+        A: VecZnxDftToMut<B>;
 
-    fn vec_znx_idft(
-        &self,
-        res: &mut VecZnxBig<DataMut, B>,
-        res_col: usize,
-        a: &VecZnxDft<Data, B>,
-        a_col: usize,
-        tmp_bytes: &mut [u8],
-    );
+    fn vec_znx_idft<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
+    where
+        R: VecZnxBigToMut<B>,
+        A: VecZnxDftToRef<B>;
 
-    fn vec_znx_dft(&self, res: &mut VecZnxDft<DataMut, B>, res_col: usize, a: &VecZnx<Data>, a_col: usize);
+    fn vec_znx_dft<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxDftToMut<B>,
+        A: VecZnxToRef;
 }
 
 impl<B: Backend> VecZnxDftAlloc<B> for Module<B> {
@@ -63,41 +65,34 @@ impl<B: Backend> VecZnxDftAlloc<B> for Module<B> {
         VecZnxDftOwned::new_from_bytes(self, cols, size, bytes)
     }
 
-    // fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft<FFT64> {
-    //     VecZnxDft::from_bytes_borrow(self, 1, cols, size, bytes)
-    // }
-
     fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize {
         bytes_of_vec_znx_dft(self, cols, size)
     }
 }
 
-impl<DataMut, Data> VecZnxDftOps<DataMut, Data, FFT64> for Module<FFT64>
-where
-    DataMut: AsMut<[u8]> + AsRef<[u8]>,
-    Data: AsRef<[u8]>,
-{
-    fn vec_znx_idft_tmp_a(
-        &self,
-        res: &mut VecZnxBig<DataMut, FFT64>,
-        res_col: usize,
-        a: &mut VecZnxDft<DataMut, FFT64>,
-        a_col: usize,
-    ) {
-        let min_size: usize = min(res.size(), a.size());
+impl VecZnxDftOps<FFT64> for Module<FFT64> {
+    fn vec_znx_idft_tmp_a<R, A>(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
+    where
+        R: VecZnxBigToMut<FFT64>,
+        A: VecZnxDftToMut<FFT64>,
+    {
+        let mut res_mut = res.to_mut();
+        let mut a_mut = a.to_mut();
+
+        let min_size: usize = min(res_mut.size(), a_mut.size());
 
         unsafe {
             (0..min_size).for_each(|j| {
                 vec_znx_dft::vec_znx_idft_tmp_a(
                     self.ptr,
-                    res.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t,
+                    res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t,
                     1 as u64,
-                    a.at_mut_ptr(a_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
+                    a_mut.at_mut_ptr(a_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
                     1 as u64,
                 )
             });
-            (min_size..res.size()).for_each(|j| {
-                res.zero_at(res_col, j);
+            (min_size..res_mut.size()).for_each(|j| {
+                res_mut.zero_at(res_col, j);
             })
         }
     }
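For orientation (this note and sketch are editorial, not part of the commit): the reworked trait lets any mutable DFT container pair with any readable coefficient container. A minimal round-trip, under the assumptions that ScratchOwned::borrow() yields the &mut Scratch consumed by vec_znx_idft and that VecZnxBigAlloc exposes a new_vec_znx_big allocator analogous to new_vec_znx_dft; sizes are illustrative.

use base2k::{
    FFT64, Module, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps,
};

fn main() {
    let module: Module<FFT64> = Module::<FFT64>::new(1 << 12);
    let mut scratch: ScratchOwned = ScratchOwned::new(module.vec_znx_idft_tmp_bytes());

    let a: VecZnx<Vec<u8>> = module.new_vec_znx(1, 4); // 1 column, 4 small polys
    let mut a_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, 4);
    let mut a_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, 4);

    // Forward DFT of column 0: any R: VecZnxDftToMut / A: VecZnxToRef pair is accepted.
    module.vec_znx_dft(&mut a_dft, 0, &a, 0);
    // The IDFT now draws its temporary bytes from the scratch arena.
    module.vec_znx_idft(&mut a_big, 0, &a_dft, 0, scratch.borrow());
}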
@@ -110,61 +105,59 @@ where
     ///
     /// # Panics
     /// If b.cols < a_cols
-    fn vec_znx_dft(&self, res: &mut VecZnxDft<DataMut, FFT64>, res_col: usize, a: &VecZnx<Data>, a_col: usize) {
-        let min_size: usize = min(res.size(), a.size());
+    fn vec_znx_dft<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxDftToMut<FFT64>,
+        A: VecZnxToRef,
+    {
+        let mut res_mut = res.to_mut();
+        let a_ref = a.to_ref();
+
+        let min_size: usize = min(res_mut.size(), a_ref.size());
 
         unsafe {
             (0..min_size).for_each(|j| {
                 vec_znx_dft::vec_znx_dft(
                     self.ptr,
-                    res.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
+                    res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
                     1 as u64,
-                    a.at_ptr(a_col, j),
+                    a_ref.at_ptr(a_col, j),
                     1 as u64,
-                    a.sl() as u64,
+                    a_ref.sl() as u64,
                 )
             });
-            (min_size..res.size()).for_each(|j| {
-                res.zero_at(res_col, j);
+            (min_size..res_mut.size()).for_each(|j| {
+                res_mut.zero_at(res_col, j);
             });
         }
     }
 
     // b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes].
-    fn vec_znx_idft(
-        &self,
-        res: &mut VecZnxBig<DataMut, FFT64>,
-        res_col: usize,
-        a: &VecZnxDft<Data, FFT64>,
-        a_col: usize,
-        tmp_bytes: &mut [u8],
-    ) {
-        #[cfg(debug_assertions)]
-        {
-            assert!(
-                tmp_bytes.len() >= <Self as VecZnxDftOps<DataMut, DataMut, FFT64>>::vec_znx_idft_tmp_bytes(self),
-                "invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_idft_tmp_bytes()={}",
-                tmp_bytes.len(),
-                <Self as VecZnxDftOps<DataMut, DataMut, FFT64>>::vec_znx_idft_tmp_bytes(self)
-            );
-            assert_alignement(tmp_bytes.as_ptr())
-        }
-
-        let min_size: usize = min(res.size(), a.size());
+    fn vec_znx_idft<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
+    where
+        R: VecZnxBigToMut<FFT64>,
+        A: VecZnxDftToRef<FFT64>,
+    {
+        let mut res_mut = res.to_mut();
+        let a_ref = a.to_ref();
+
+        let (tmp_bytes, _) = scratch.tmp_scalar_slice(self.vec_znx_idft_tmp_bytes());
+
+        let min_size: usize = min(res_mut.size(), a_ref.size());
 
         unsafe {
             (0..min_size).for_each(|j| {
                 vec_znx_dft::vec_znx_idft(
                     self.ptr,
-                    res.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t,
+                    res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t,
                     1 as u64,
-                    a.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
+                    a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
                     1 as u64,
                     tmp_bytes.as_mut_ptr(),
                 )
             });
-            (min_size..res.size()).for_each(|j| {
-                res.zero_at(res_col, j);
+            (min_size..res_mut.size()).for_each(|j| {
+                res_mut.zero_at(res_col, j);
             });
         }
     }
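The recurring change in this file is the calling convention for temporaries: instead of receiving tmp_bytes: &mut [u8] and asserting its length and alignment, each op now carves what it needs off a Scratch arena. A hedged sketch of the resulting pattern (the second return value of tmp_scalar_slice is assumed to be the remaining scratch):

use base2k::{FFT64, Module, Scratch, VecZnxDftOps};

// Hypothetical helper following the commit's pattern: the op sizes and takes
// its own temporary slice, so callers no longer manage tmp_bytes directly.
fn some_op(module: &Module<FFT64>, scratch: &mut Scratch) {
    // tmp_scalar_slice hands back the requested bytes plus (assumed) the
    // remaining scratch, so nested ops can keep threading the remainder.
    let (tmp_bytes, _rest) = scratch.tmp_scalar_slice(module.vec_znx_idft_tmp_bytes());
    // ... pass tmp_bytes.as_mut_ptr() to the FFI call ...
    let _ = tmp_bytes;
}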
@@ -1,6 +1,9 @@
 use crate::ffi::vec_znx;
-use crate::znx_base::{ZnxInfos, switch_degree};
-use crate::{Backend, Module, VecZnx, VecZnxOwned, ZnxView, ZnxViewMut, assert_alignement};
+use crate::{
+    Backend, Module, Scratch, VecZnx, VecZnxOwned, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
+};
+use itertools::izip;
+use std::cmp::min;
 
 pub trait VecZnxAlloc {
     /// Allocates a new [VecZnx].
@@ -29,73 +32,86 @@ pub trait VecZnxAlloc {
     fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize;
 }
 
-pub trait VecZnxOps<DataMut, Data> {
+pub trait VecZnxOps {
     /// Normalizes the selected column of `a` and stores the result into the selected column of `res`.
-    fn vec_znx_normalize(
-        &self,
-        log_base2k: usize,
-        res: &mut VecZnx<DataMut>,
-        res_col: usize,
-        a: &VecZnx<Data>,
-        a_col: usize,
-        tmp_bytes: &mut [u8],
-    );
+    fn vec_znx_normalize<R, A>(&self, log_base2k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef;
 
     /// Normalizes the selected column of `a`.
-    fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx<DataMut>, a_col: usize, tmp_bytes: &mut [u8]);
+    fn vec_znx_normalize_inplace<A>(&self, log_base2k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch)
+    where
+        A: VecZnxToMut;
 
     /// Adds the selected column of `a` to the selected column of `b` and writes the result on the selected column of `res`.
-    fn vec_znx_add(
-        &self,
-        res: &mut VecZnx<DataMut>,
-        res_col: usize,
-        a: &VecZnx<Data>,
-        a_col: usize,
-        b: &VecZnx<Data>,
-        b_col: usize,
-    );
+    fn vec_znx_add<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef,
+        B: VecZnxToRef;
 
     /// Adds the selected column of `a` to the selected column of `res` inplace.
-    fn vec_znx_add_inplace(&self, res: &mut VecZnx<DataMut>, res_col: usize, a: &VecZnx<Data>, a_col: usize);
+    fn vec_znx_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef;
 
     /// Subtracts the selected column of `b` from the selected column of `a` and writes the result on the selected column of `res`.
-    fn vec_znx_sub(
-        &self,
-        res: &mut VecZnx<DataMut>,
-        res_col: usize,
-        a: &VecZnx<Data>,
-        a_col: usize,
-        b: &VecZnx<Data>,
-        b_col: usize,
-    );
+    fn vec_znx_sub<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef,
+        B: VecZnxToRef;
 
     /// Subtracts the selected column of `a` from the selected column of `res` inplace.
     ///
     /// res[res_col] -= a[a_col]
-    fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx<DataMut>, res_col: usize, a: &VecZnx<Data>, a_col: usize);
+    fn vec_znx_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef;
 
     /// Subtracts the selected column of `res` from the selected column of `a` and inplace mutates `res`.
     ///
     /// res[res_col] = a[a_col] - res[res_col]
-    fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx<DataMut>, res_col: usize, a: &VecZnx<Data>, a_col: usize);
+    fn vec_znx_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef;
 
     /// Negates the selected column of `a` and stores the result in `res_col` of `res`.
-    fn vec_znx_negate(&self, res: &mut VecZnx<DataMut>, res_col: usize, a: &VecZnx<Data>, a_col: usize);
+    fn vec_znx_negate<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef;
 
     /// Negates the selected column of `a`.
-    fn vec_znx_negate_inplace(&self, a: &mut VecZnx<DataMut>, a_col: usize);
+    fn vec_znx_negate_inplace<A>(&self, a: &mut A, a_col: usize)
+    where
+        A: VecZnxToMut;
 
     /// Multiplies the selected column of `a` by X^k and stores the result in `res_col` of `res`.
-    fn vec_znx_rotate(&self, k: i64, res: &mut VecZnx<DataMut>, res_col: usize, a: &VecZnx<Data>, a_col: usize);
+    fn vec_znx_rotate<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef;
 
     /// Multiplies the selected column of `a` by X^k.
-    fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx<DataMut>, a_col: usize);
+    fn vec_znx_rotate_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
+    where
+        A: VecZnxToMut;
 
     /// Applies the automorphism X^i -> X^ik on the selected column of `a` and stores the result in the `res_col` column of `res`.
-    fn vec_znx_automorphism(&self, k: i64, res: &mut VecZnx<DataMut>, res_col: usize, a: &VecZnx<Data>, a_col: usize);
+    fn vec_znx_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef;
 
     /// Applies the automorphism X^i -> X^ik on the selected column of `a`.
-    fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx<DataMut>, a_col: usize);
+    fn vec_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
+    where
+        A: VecZnxToMut;
 
     /// Splits the selected columns of `b` into subrings and copies them into the selected column of `res`.
     ///
@@ -103,14 +119,10 @@ pub trait VecZnxOps<DataMut, Data> {
     ///
     /// This method requires that all [VecZnx] of b have the same ring degree
     /// and that b.n() * b.len() <= a.n()
-    fn vec_znx_split(
-        &self,
-        res: &mut Vec<VecZnx<DataMut>>,
-        res_col: usize,
-        a: &VecZnx<Data>,
-        a_col: usize,
-        buf: &mut VecZnx<DataMut>,
-    );
+    fn vec_znx_split<R, A>(&self, res: &mut Vec<R>, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef;
 
     /// Merges the subrings of the selected column of `a` into the selected column of `res`.
     ///
@@ -118,7 +130,15 @@ pub trait VecZnxOps<DataMut, Data> {
     ///
     /// This method requires that all [VecZnx] of a have the same ring degree
     /// and that a.n() * a.len() <= b.n()
-    fn vec_znx_merge(&self, res: &mut VecZnx<DataMut>, res_col: usize, a: &Vec<VecZnx<Data>>, a_col: usize);
+    fn vec_znx_merge<R, A>(&self, res: &mut R, res_col: usize, a: Vec<A>, a_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef;
+
+    fn switch_degree<R, A>(&self, r: &mut R, col_b: usize, a: &A, col_a: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef;
 }
 
 pub trait VecZnxScratch {
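Before the implementation hunks, a usage sketch of the reworked VecZnxOps surface (editorial, not from the commit; it assumes an owned VecZnx<Vec<u8>> satisfies the VecZnxToRef/VecZnxToMut bounds, as the conversions in the previous file suggest):

use base2k::{FFT64, Module, VecZnx, VecZnxAlloc, VecZnxOps};

fn demo(module: &Module<FFT64>) {
    // Owned vectors; borrowed views would satisfy the same To-Ref/To-Mut bounds.
    let mut r: VecZnx<Vec<u8>> = module.new_vec_znx(1, 4);
    let a: VecZnx<Vec<u8>> = module.new_vec_znx(1, 4);
    let b: VecZnx<Vec<u8>> = module.new_vec_znx(1, 4);

    module.vec_znx_add(&mut r, 0, &a, 0, &b, 0); // r[0] = a[0] + b[0]
    module.vec_znx_sub_ab_inplace(&mut r, 0, &b, 0); // r[0] -= b[0]
    module.vec_znx_rotate_inplace(1, &mut r, 0); // r[0] *= X
}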
@@ -140,27 +160,23 @@ impl<B: Backend> VecZnxAlloc for Module<B> {
     }
 }
 
-impl<B: Backend, DataMut, Data> VecZnxOps<DataMut, Data> for Module<B>
-where
-    Data: AsRef<[u8]>,
-    DataMut: AsRef<[u8]> + AsMut<[u8]>,
-{
-    fn vec_znx_normalize(
-        &self,
-        log_base2k: usize,
-        res: &mut VecZnx<DataMut>,
-        res_col: usize,
-        a: &VecZnx<Data>,
-        a_col: usize,
-        tmp_bytes: &mut [u8],
-    ) {
+impl<BACKEND: Backend> VecZnxOps for Module<BACKEND> {
+    fn vec_znx_normalize<R, A>(&self, log_base2k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef,
+    {
+        let a: VecZnx<&[u8]> = a.to_ref();
+        let mut res: VecZnx<&mut [u8]> = res.to_mut();
+
         #[cfg(debug_assertions)]
         {
             assert_eq!(a.n(), self.n());
             assert_eq!(res.n(), self.n());
-            assert!(tmp_bytes.len() >= <Self as VecZnxScratch>::vec_znx_normalize_tmp_bytes(&self));
-            assert_alignement(tmp_bytes.as_ptr());
         }
 
+        let (tmp_bytes, _) = scratch.tmp_scalar_slice(self.vec_znx_normalize_tmp_bytes());
+
         unsafe {
             vec_znx::vec_znx_normalize_base2k(
                 self.ptr,
@@ -176,22 +192,44 @@ where
         }
     }
 
-    fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx<DataMut>, a_col: usize, tmp_bytes: &mut [u8]) {
-        unsafe {
-            let a_ptr: *const VecZnx<_> = a;
-            Self::vec_znx_normalize(self, log_base2k, a, a_col, &*a_ptr, a_col, tmp_bytes);
+    fn vec_znx_normalize_inplace<A>(&self, log_base2k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch)
+    where
+        A: VecZnxToMut,
+    {
+        let mut a: VecZnx<&mut [u8]> = a.to_mut();
+
+        #[cfg(debug_assertions)]
+        {
+            assert_eq!(a.n(), self.n());
+        }
+
+        let (tmp_bytes, _) = scratch.tmp_scalar_slice(self.vec_znx_normalize_tmp_bytes());
+
+        unsafe {
+            vec_znx::vec_znx_normalize_base2k(
+                self.ptr,
+                log_base2k as u64,
+                a.at_mut_ptr(a_col, 0),
+                a.size() as u64,
+                a.sl() as u64,
+                a.at_ptr(a_col, 0),
+                a.size() as u64,
+                a.sl() as u64,
+                tmp_bytes.as_mut_ptr(),
+            );
         }
     }
 
-    fn vec_znx_add(
-        &self,
-        res: &mut VecZnx<DataMut>,
-        res_col: usize,
-        a: &VecZnx<Data>,
-        a_col: usize,
-        b: &VecZnx<Data>,
-        b_col: usize,
-    ) {
+    fn vec_znx_add<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef,
+        B: VecZnxToRef,
+    {
+        let a: VecZnx<&[u8]> = a.to_ref();
+        let b: VecZnx<&[u8]> = b.to_ref();
+        let mut res: VecZnx<&mut [u8]> = res.to_mut();
+
         #[cfg(debug_assertions)]
         {
             assert_eq!(a.n(), self.n());
@@ -215,7 +253,14 @@ where
         }
     }
 
-    fn vec_znx_add_inplace(&self, res: &mut VecZnx<DataMut>, res_col: usize, a: &VecZnx<Data>, a_col: usize) {
+    fn vec_znx_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef,
+    {
+        let a: VecZnx<&[u8]> = a.to_ref();
+        let mut res: VecZnx<&mut [u8]> = res.to_mut();
+
         #[cfg(debug_assertions)]
         {
             assert_eq!(a.n(), self.n());
@@ -237,15 +282,16 @@ where
         }
     }
 
-    fn vec_znx_sub(
-        &self,
-        res: &mut VecZnx<DataMut>,
-        res_col: usize,
-        a: &VecZnx<Data>,
-        a_col: usize,
-        b: &VecZnx<Data>,
-        b_col: usize,
-    ) {
+    fn vec_znx_sub<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef,
+        B: VecZnxToRef,
+    {
+        let a: VecZnx<&[u8]> = a.to_ref();
+        let b: VecZnx<&[u8]> = b.to_ref();
+        let mut res: VecZnx<&mut [u8]> = res.to_mut();
+
         #[cfg(debug_assertions)]
         {
             assert_eq!(a.n(), self.n());
@@ -269,7 +315,13 @@ where
         }
     }
 
-    fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx<DataMut>, res_col: usize, a: &VecZnx<Data>, a_col: usize) {
+    fn vec_znx_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef,
+    {
+        let a: VecZnx<&[u8]> = a.to_ref();
+        let mut res: VecZnx<&mut [u8]> = res.to_mut();
         #[cfg(debug_assertions)]
         {
             assert_eq!(a.n(), self.n());
@@ -291,7 +343,13 @@ where
         }
     }
 
-    fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx<DataMut>, res_col: usize, a: &VecZnx<Data>, a_col: usize) {
+    fn vec_znx_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef,
+    {
+        let a: VecZnx<&[u8]> = a.to_ref();
+        let mut res: VecZnx<&mut [u8]> = res.to_mut();
         #[cfg(debug_assertions)]
         {
             assert_eq!(a.n(), self.n());
@@ -313,7 +371,13 @@ where
         }
     }
 
-    fn vec_znx_negate(&self, res: &mut VecZnx<DataMut>, res_col: usize, a: &VecZnx<Data>, a_col: usize) {
+    fn vec_znx_negate<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef,
+    {
+        let a: VecZnx<&[u8]> = a.to_ref();
+        let mut res: VecZnx<&mut [u8]> = res.to_mut();
         #[cfg(debug_assertions)]
         {
             assert_eq!(a.n(), self.n());
@@ -332,14 +396,35 @@ where
         }
     }
 
-    fn vec_znx_negate_inplace(&self, a: &mut VecZnx<DataMut>, a_col: usize) {
+    fn vec_znx_negate_inplace<A>(&self, a: &mut A, a_col: usize)
+    where
+        A: VecZnxToMut,
+    {
+        let mut a: VecZnx<&mut [u8]> = a.to_mut();
+        #[cfg(debug_assertions)]
+        {
+            assert_eq!(a.n(), self.n());
+        }
         unsafe {
-            let a_ref: *const VecZnx<_> = a;
-            Self::vec_znx_negate(self, a, a_col, a_ref.as_ref().unwrap(), a_col);
+            vec_znx::vec_znx_negate(
+                self.ptr,
+                a.at_mut_ptr(a_col, 0),
+                a.size() as u64,
+                a.sl() as u64,
+                a.at_ptr(a_col, 0),
+                a.size() as u64,
+                a.sl() as u64,
+            )
         }
     }
 
-    fn vec_znx_rotate(&self, k: i64, res: &mut VecZnx<DataMut>, res_col: usize, a: &VecZnx<Data>, a_col: usize) {
+    fn vec_znx_rotate<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef,
+    {
+        let a: VecZnx<&[u8]> = a.to_ref();
+        let mut res: VecZnx<&mut [u8]> = res.to_mut();
         #[cfg(debug_assertions)]
         {
             assert_eq!(a.n(), self.n());
@@ -359,7 +444,11 @@ where
         }
     }
 
-    fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx<DataMut>, a_col: usize) {
+    fn vec_znx_rotate_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
+    where
+        A: VecZnxToMut,
+    {
+        let mut a: VecZnx<&mut [u8]> = a.to_mut();
         #[cfg(debug_assertions)]
         {
             assert_eq!(a.n(), self.n());
@@ -378,7 +467,13 @@ where
         }
     }
 
-    fn vec_znx_automorphism(&self, k: i64, res: &mut VecZnx<DataMut>, res_col: usize, a: &VecZnx<Data>, a_col: usize) {
+    fn vec_znx_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef,
+    {
+        let a: VecZnx<&[u8]> = a.to_ref();
+        let mut res: VecZnx<&mut [u8]> = res.to_mut();
         #[cfg(debug_assertions)]
         {
             assert_eq!(a.n(), self.n());
@@ -398,7 +493,11 @@ where
         }
     }
 
-    fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx<DataMut>, a_col: usize) {
+    fn vec_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
+    where
+        A: VecZnxToMut,
+    {
+        let mut a: VecZnx<&mut [u8]> = a.to_mut();
         #[cfg(debug_assertions)]
         {
             assert_eq!(a.n(), self.n());
@@ -417,23 +516,24 @@ where
         }
     }
 
-    fn vec_znx_split(
-        &self,
-        res: &mut Vec<VecZnx<DataMut>>,
-        res_col: usize,
-        a: &VecZnx<Data>,
-        a_col: usize,
-        buf: &mut VecZnx<DataMut>,
-    ) {
-        let (n_in, n_out) = (a.n(), res[0].n());
+    fn vec_znx_split<R, A>(&self, res: &mut Vec<R>, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef,
+    {
+        let a: VecZnx<&[u8]> = a.to_ref();
+
+        let (n_in, n_out) = (a.n(), res[0].to_mut().n());
+
+        let (mut buf, _) = scratch.tmp_vec_znx(self, 1, a.size());
 
         debug_assert!(
             n_out < n_in,
             "invalid a: output ring degree should be smaller"
         );
-        res[1..].iter().for_each(|bi| {
+        res[1..].iter_mut().for_each(|bi| {
             debug_assert_eq!(
-                bi.n(),
+                bi.to_mut().n(),
                 n_out,
                 "invalid input a: all VecZnx must have the same degree"
             )
@@ -441,17 +541,23 @@ where
 
         res.iter_mut().enumerate().for_each(|(i, bi)| {
             if i == 0 {
-                switch_degree(bi, res_col, a, a_col);
-                self.vec_znx_rotate(-1, buf, 0, a, a_col);
+                self.switch_degree(bi, res_col, &a, a_col);
+                self.vec_znx_rotate(-1, &mut buf, 0, &a, a_col);
             } else {
-                switch_degree(bi, res_col, buf, a_col);
-                <Self as VecZnxOps<DataMut, Data>>::vec_znx_rotate_inplace(self, -1, buf, a_col);
+                self.switch_degree(bi, res_col, &mut buf, a_col);
+                self.vec_znx_rotate_inplace(-1, &mut buf, a_col);
             }
         })
     }
 
-    fn vec_znx_merge(&self, res: &mut VecZnx<DataMut>, res_col: usize, a: &Vec<VecZnx<Data>>, a_col: usize) {
-        let (n_in, n_out) = (res.n(), a[0].n());
+    fn vec_znx_merge<R, A>(&self, res: &mut R, res_col: usize, a: Vec<A>, a_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef,
+    {
+        let mut res: VecZnx<&mut [u8]> = res.to_mut();
+
+        let (n_in, n_out) = (res.n(), a[0].to_ref().n());
+
         debug_assert!(
             n_out < n_in,
@@ -459,18 +565,47 @@ where
         );
         a[1..].iter().for_each(|ai| {
             debug_assert_eq!(
-                ai.n(),
+                ai.to_ref().n(),
                 n_out,
                 "invalid input a: all VecZnx must have the same degree"
             )
         });
 
         a.iter().enumerate().for_each(|(_, ai)| {
-            switch_degree(res, res_col, ai, a_col);
-            <Self as VecZnxOps<DataMut, Data>>::vec_znx_rotate_inplace(self, -1, res, res_col);
+            self.switch_degree(&mut res, res_col, ai, a_col);
+            self.vec_znx_rotate_inplace(-1, &mut res, res_col);
         });
 
-        <Self as VecZnxOps<DataMut, Data>>::vec_znx_rotate_inplace(self, a.len() as i64, res, res_col);
+        self.vec_znx_rotate_inplace(a.len() as i64, &mut res, res_col);
+    }
+
+    fn switch_degree<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
+    where
+        R: VecZnxToMut,
+        A: VecZnxToRef,
+    {
+        let a: VecZnx<&[u8]> = a.to_ref();
+        let mut res: VecZnx<&mut [u8]> = res.to_mut();
+
+        let (n_in, n_out) = (a.n(), res.n());
+        let (gap_in, gap_out): (usize, usize);
+
+        if n_in > n_out {
+            (gap_in, gap_out) = (n_in / n_out, 1)
+        } else {
+            (gap_in, gap_out) = (1, n_out / n_in);
+            res.zero();
+        }
+
+        let size: usize = min(a.size(), res.size());
+
+        (0..size).for_each(|i| {
+            izip!(
+                a.at(a_col, i).iter().step_by(gap_in),
+                res.at_mut(res_col, i).iter_mut().step_by(gap_out)
+            )
+            .for_each(|(x_in, x_out)| *x_out = *x_in);
+        });
     }
 }
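To make the split/merge pair above concrete, a hedged end-to-end sketch (editorial, not from the commit; ring degrees and sizes are illustrative, and the scratch arena is sized with bytes_of_vec_znx on the assumption that vec_znx_split's tmp_vec_znx draws exactly one 1-column VecZnx from it):

use base2k::{FFT64, Module, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxOps};

fn main() {
    // Split a degree-4096 column into two degree-2048 sub-rings, then merge back.
    let big: Module<FFT64> = Module::<FFT64>::new(4096);
    let small: Module<FFT64> = Module::<FFT64>::new(2048);

    // Enough scratch for the temporary VecZnx that vec_znx_split carves off.
    let mut scratch: ScratchOwned = ScratchOwned::new(big.bytes_of_vec_znx(1, 4));

    let a: VecZnx<Vec<u8>> = big.new_vec_znx(1, 4);
    let mut parts: Vec<VecZnx<Vec<u8>>> = (0..2).map(|_| small.new_vec_znx(1, 4)).collect();

    big.vec_znx_split(&mut parts, 0, &a, 0, scratch.borrow());

    // vec_znx_merge consumes the Vec of sub-ring vectors by value.
    let mut merged: VecZnx<Vec<u8>> = big.new_vec_znx(1, 4);
    big.vec_znx_merge(&mut merged, 0, parts, 0);
}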
@@ -1,6 +1,5 @@
 use itertools::izip;
 use rand_distr::num_traits::Zero;
-use std::cmp::min;
 
 pub trait ZnxInfos {
     /// Returns the ring degree of the polynomials.
@@ -24,7 +23,9 @@ pub trait ZnxInfos {
     fn poly_count(&self) -> usize {
         self.rows() * self.cols() * self.size()
     }
+}
 
+pub trait ZnxSliceSize {
     /// Returns the slice size, which is the offset between
     /// two sizes of the same column.
     fn sl(&self) -> usize;
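The commit splits sl() out of ZnxInfos into its own ZnxSliceSize trait. A toy implementor (hypothetical type, purely illustrative; the real implementors are the VecZnx family) showing the contract:

use base2k::ZnxSliceSize;

// Hypothetical contiguous layout, for illustration only: consecutive sizes of
// one column sit n * cols scalars apart, so sl() reports that stride.
struct Toy {
    n: usize,
    cols: usize,
}

impl ZnxSliceSize for Toy {
    fn sl(&self) -> usize {
        self.n * self.cols
    }
}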
@@ -129,33 +130,6 @@ where
 impl<T> ZnxZero for T where T: ZnxViewMut {}
 // impl<T> ZnxRsh for T where T: ZnxZero {}
 
-pub fn switch_degree<S: Copy, DMut: ZnxViewMut<Scalar = S> + ZnxZero, D: ZnxView<Scalar = S>>(
-    b: &mut DMut,
-    col_b: usize,
-    a: &D,
-    col_a: usize,
-) {
-    let (n_in, n_out) = (a.n(), b.n());
-    let (gap_in, gap_out): (usize, usize);
-
-    if n_in > n_out {
-        (gap_in, gap_out) = (n_in / n_out, 1)
-    } else {
-        (gap_in, gap_out) = (1, n_out / n_in);
-        b.zero();
-    }
-
-    let size: usize = min(a.size(), b.size());
-
-    (0..size).for_each(|i| {
-        izip!(
-            a.at(col_a, i).iter().step_by(gap_in),
-            b.at_mut(col_b, i).iter_mut().step_by(gap_out)
-        )
-        .for_each(|(x_in, x_out)| *x_out = *x_in);
-    });
-}
-
 use std::ops::{Add, AddAssign, Div, Mul, Neg, Shl, Shr, Sub};
 
 use crate::Scratch;
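The free switch_degree removed here survives as the method added in the previous file. Its stride rule is easy to verify in isolation; a standalone sketch (editorial, plain Rust, no library dependencies):

// Illustration of the switch_degree stride rule from this commit: going from
// degree n_in to n_out reads every (n_in/n_out)-th coefficient when shrinking,
// and writes every (n_out/n_in)-th slot (after zeroing) when growing.
fn gaps(n_in: usize, n_out: usize) -> (usize, usize) {
    if n_in > n_out { (n_in / n_out, 1) } else { (1, n_out / n_in) }
}

fn main() {
    assert_eq!(gaps(8, 2), (4, 1)); // shrink: keep every 4th coefficient
    assert_eq!(gaps(2, 8), (1, 4)); // grow: spread coefficients every 4 slots
}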