Mirror of https://github.com/arnaucube/poulpy.git (synced 2026-02-10 05:06:44 +01:00)
Support for bivariate convolution & normalization with offset (#126)
* Add bivariate-convolution
* Add pair-wise convolution + tests + benches
* Add take_cnv_pvec_[left/right] to Scratch & updated CHANGELOG.md
* cross-base2k normalization with positive offset
* clippy & fix CI doctest avx compile error
* more streamlined bounds derivation for normalization
* Working cross-base2k normalization with pos/neg offset
* Update normalization API & tests
* Add glwe tensoring test
* Add relinearization + preliminary test
* Fix GGLWEToGGSW key infos
* Add (X,Y) convolution by const (1, Y) poly
* Faster normalization test + add bench for cnv_by_const
* Update changelog
committed by GitHub
parent 76424d0ab5
commit 4e90e08a71
@@ -1,120 +1,97 @@
|
||||
use crate::{
|
||||
api::{
|
||||
ModuleN, ScratchTakeBasic, SvpApplyDftToDft, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxDftAddScaledInplace,
|
||||
VecZnxDftBytesOf, VecZnxDftZero,
|
||||
},
|
||||
layouts::{Backend, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, ZnxInfos},
|
||||
use crate::layouts::{
|
||||
Backend, CnvPVecL, CnvPVecLToMut, CnvPVecLToRef, CnvPVecR, CnvPVecRToMut, CnvPVecRToRef, Scratch, VecZnxBigToMut,
|
||||
VecZnxDftToMut, VecZnxToRef, ZnxInfos, ZnxViewMut,
|
||||
};
|
||||
|
||||
impl<BE: Backend> BivariateTensoring<BE> for Module<BE>
|
||||
where
|
||||
Self: BivariateConvolution<BE>,
|
||||
Scratch<BE>: ScratchTakeBasic,
|
||||
{
|
||||
pub trait CnvPVecAlloc<BE: Backend> {
|
||||
fn cnv_pvec_left_alloc(&self, cols: usize, size: usize) -> CnvPVecL<Vec<u8>, BE>;
|
||||
fn cnv_pvec_right_alloc(&self, cols: usize, size: usize) -> CnvPVecR<Vec<u8>, BE>;
|
||||
}
|
||||
|
||||
pub trait BivariateTensoring<BE: Backend>
|
||||
where
|
||||
Self: BivariateConvolution<BE>,
|
||||
Scratch<BE>: ScratchTakeBasic,
|
||||
{
|
||||
fn bivariate_tensoring<R, A, B>(&self, k: i64, res: &mut R, a: &A, b: &B, scratch: &mut Scratch<BE>)
|
||||
pub trait CnvPVecBytesOf {
|
||||
fn bytes_of_cnv_pvec_left(&self, cols: usize, size: usize) -> usize;
|
||||
fn bytes_of_cnv_pvec_right(&self, cols: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
pub trait Convolution<BE: Backend> {
|
||||
fn cnv_prepare_left_tmp_bytes(&self, res_size: usize, a_size: usize) -> usize;
|
||||
fn cnv_prepare_left<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<BE>)
|
||||
where
|
||||
R: VecZnxDftToMut<BE>,
|
||||
A: VecZnxToRef,
|
||||
B: VecZnxDftToRef<BE>,
|
||||
{
|
||||
let res: &mut crate::layouts::VecZnxDft<&mut [u8], BE> = &mut res.to_mut();
|
||||
let a: &crate::layouts::VecZnx<&[u8]> = &a.to_ref();
|
||||
let b: &crate::layouts::VecZnxDft<&[u8], BE> = &b.to_ref();
|
||||
R: CnvPVecLToMut<BE> + ZnxInfos + ZnxViewMut<Scalar = BE::ScalarPrep>,
|
||||
A: VecZnxToRef + ZnxInfos;
|
||||
|
||||
let res_cols: usize = res.cols();
|
||||
let a_cols: usize = a.cols();
|
||||
let b_cols: usize = b.cols();
|
||||
fn cnv_prepare_right_tmp_bytes(&self, res_size: usize, a_size: usize) -> usize;
|
||||
fn cnv_prepare_right<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<BE>)
|
||||
where
|
||||
R: CnvPVecRToMut<BE> + ZnxInfos + ZnxViewMut<Scalar = BE::ScalarPrep>,
|
||||
A: VecZnxToRef + ZnxInfos;
|
||||
|
||||
assert!(res_cols >= a_cols + b_cols - 1);
|
||||
fn cnv_apply_dft_tmp_bytes(&self, res_size: usize, res_offset: usize, a_size: usize, b_size: usize) -> usize;
|
||||
|
||||
for res_col in 0..res_cols {
|
||||
self.vec_znx_dft_zero(res, res_col);
|
||||
}
|
||||
fn cnv_by_const_apply_tmp_bytes(&self, res_size: usize, res_offset: usize, a_size: usize, b_size: usize) -> usize;
|
||||
|
||||
for a_col in 0..a_cols {
|
||||
for b_col in 0..b_cols {
|
||||
self.bivariate_convolution_add(k, res, a_col + b_col, a, a_col, b, b_col, scratch);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<BE: Backend> BivariateConvolution<BE> for Module<BE>
|
||||
where
|
||||
Self: Sized
|
||||
+ ModuleN
|
||||
+ SvpPPolAlloc<BE>
|
||||
+ SvpApplyDftToDft<BE>
|
||||
+ SvpPrepare<BE>
|
||||
+ SvpPPolBytesOf
|
||||
+ VecZnxDftBytesOf
|
||||
+ VecZnxDftAddScaledInplace<BE>
|
||||
+ VecZnxDftZero<BE>,
|
||||
Scratch<BE>: ScratchTakeBasic,
|
||||
{
|
||||
}
|
||||
|
||||
pub trait BivariateConvolution<BE: Backend>
|
||||
where
|
||||
Self: Sized
|
||||
+ ModuleN
|
||||
+ SvpPPolAlloc<BE>
|
||||
+ SvpApplyDftToDft<BE>
|
||||
+ SvpPrepare<BE>
|
||||
+ SvpPPolBytesOf
|
||||
+ VecZnxDftBytesOf
|
||||
+ VecZnxDftAddScaledInplace<BE>
|
||||
+ VecZnxDftZero<BE>,
|
||||
Scratch<BE>: ScratchTakeBasic,
|
||||
{
|
||||
fn convolution_tmp_bytes(&self, b_size: usize) -> usize {
|
||||
self.bytes_of_svp_ppol(1) + self.bytes_of_vec_znx_dft(1, b_size)
|
||||
}
|
||||
/// Evaluates a bivariate convolution over Z[X, Y] (x) Z[Y] mod (X^N + 1) where Y = 2^-K over the
/// selected columns and stores the result on the selected column, scaled by 2^{res_offset * Base2K}.
///
/// Behavior is identical to [Convolution::cnv_apply_dft] with `b` treated as a constant polynomial
/// in the X variable, for example:
///```text
/// 1 X X^2 X^3
/// a = 1 [a00, a10, a20, a30] = (a00 + a01 * 2^-K) + (a10 + a11 * 2^-K) * X ...
/// Y [a01, a11, a21, a31]
///
/// b = 1 [b00] = (b00 + b01 * 2^-K)
/// Y [b01]
/// ```
/// This method is intended for multiplications by constants that are greater than 2^{Base2K}.
#[allow(clippy::too_many_arguments)]
fn cnv_by_const_apply<R, A>(
|
||||
&self,
|
||||
res: &mut R,
|
||||
res_offset: usize,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &[i64],
|
||||
scratch: &mut Scratch<BE>,
|
||||
) where
|
||||
R: VecZnxBigToMut<BE>,
|
||||
A: VecZnxToRef;
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// Evaluates a bivariate convolution over Z[X, Y] / (X^N + 1) where Y = 2^-K over the
|
||||
/// selected columns and stores the result on the selected column, scaled by 2^{k * Base2K}
|
||||
/// Evaluates a bivariate convolution over Z[X, Y] (x) Z[X, Y] mod (X^N + 1) where Y = 2^-K over the
|
||||
/// selected columns and stores the result on the selected column, scaled by 2^{res_offset * Base2K}
|
||||
///
|
||||
/// # Example
|
||||
/// a = [a00, a10, a20, a30] = (a00 * 2^-K + a01 * 2^-2K) + (a10 * 2^-K + a11 * 2^-2K) * X ...
|
||||
/// [a01, a11, a21, a31]
|
||||
///```text
|
||||
/// 1 X X^2 X^3
|
||||
/// a = 1 [a00, a10, a20, a30] = (a00 + a01 * 2^-K) + (a10 + a11 * 2^-K) * X ...
|
||||
/// Y [a01, a11, a21, a31]
|
||||
///
|
||||
/// b = [b00, b10, b20, b30] = (b00 * 2^-K + b01 * 2^-2K) + (b10 * 2^-K + b11 * 2^-2K) * X ...
|
||||
/// [b01, b11, b21, b31]
|
||||
/// b = 1 [b00, b10, b20, b30] = (b00 + b01 * 2^-K) + (b10 + b11 * 2^-K) * X ...
|
||||
/// Y [b01, b11, b21, b31]
|
||||
///
|
||||
/// If k = 0:
|
||||
/// res = [ 0, 0, 0, 0] = (r01 * 2^-2K + r02 * 2^-3K + r03 * 2^-4K + r04 * 2^-5K) + ...
|
||||
/// [r01, r11, r21, r31]
|
||||
/// [r02, r12, r22, r32]
|
||||
/// [r03, r13, r23, r33]
|
||||
/// [r04, r14, r24, r34]
|
||||
/// If res_offset = 0:
|
||||
///
|
||||
/// If k = 1:
|
||||
/// res = [r01, r11, r21, r31] = (r01 * 2^-K + r02 * 2^-2K + r03 * 2^-3K + r04 * 2^-4K + r05 * 2^-5K) + ...
|
||||
/// [r02, r12, r22, r32]
|
||||
/// [r03, r13, r23, r33]
|
||||
/// [r04, r14, r24, r34]
|
||||
/// [r05, r15, r25, r35]
|
||||
/// 1 X X^2 X^3
|
||||
/// res = 1 [r00, r10, r20, r30] = (r00 + r01 * 2^-K + r02 * 2^-2K + r03 * 2^-3K) + ... * X + ...
|
||||
/// Y [r01, r11, r21, r31]
|
||||
/// Y^2[r02, r12, r22, r32]
|
||||
/// Y^3[r03, r13, r23, r33]
|
||||
///
|
||||
/// If k = -1:
|
||||
/// res = [ 0, 0, 0, 0] = (r01 * 2^-3K + r02 * 2^-4K + r03 * 2^-5K) + ...
|
||||
/// [ 0, 0, 0, 0]
|
||||
/// [r01, r11, r21, r31]
|
||||
/// [r02, r12, r22, r32]
|
||||
/// [r03, r13, r23, r33]
|
||||
/// If res_offset = 1:
|
||||
///
|
||||
/// If res.size() < a.size() + b.size() + 1 + k, result is truncated accordingly in the Y dimension.
|
||||
fn bivariate_convolution_add<R, A, B>(
|
||||
/// 1 X X^2 X^3
|
||||
/// res = 1 [r01, r11, r21, r31] = (r01 + r02 * 2^-K + r03 * 2^-2K) + ... * X + ...
|
||||
/// Y [r02, r12, r22, r32]
|
||||
/// Y^2[r03, r13, r23, r33]
|
||||
/// Y^3[ 0, 0, 0 , 0]
|
||||
/// ```
|
||||
/// If res.size() < a.size() + b.size() + k, result is truncated accordingly in the Y dimension.
|
||||
fn cnv_apply_dft<R, A, B>(
|
||||
&self,
|
||||
k: i64,
|
||||
res: &mut R,
|
||||
res_offset: usize,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
@@ -123,40 +100,27 @@ where
|
||||
scratch: &mut Scratch<BE>,
|
||||
) where
|
||||
R: VecZnxDftToMut<BE>,
|
||||
A: VecZnxToRef,
|
||||
B: VecZnxDftToRef<BE>,
|
||||
{
|
||||
let res: &mut crate::layouts::VecZnxDft<&mut [u8], BE> = &mut res.to_mut();
|
||||
let a: &crate::layouts::VecZnx<&[u8]> = &a.to_ref();
|
||||
let b: &crate::layouts::VecZnxDft<&[u8], BE> = &b.to_ref();
|
||||
A: CnvPVecLToRef<BE>,
|
||||
B: CnvPVecRToRef<BE>;
|
||||
|
||||
let (mut ppol, scratch_1) = scratch.take_svp_ppol(self, 1);
|
||||
let (mut res_tmp, _) = scratch_1.take_vec_znx_dft(self, 1, b.size());
|
||||
|
||||
for a_limb in 0..a.size() {
|
||||
self.svp_prepare(&mut ppol, 0, &a.as_scalar_znx_ref(a_col, a_limb), 0);
|
||||
self.svp_apply_dft_to_dft(&mut res_tmp, 0, &ppol, 0, b, b_col);
|
||||
self.vec_znx_dft_add_scaled_inplace(res, res_col, &res_tmp, 0, -(1 + a_limb as i64) + k);
|
||||
}
|
||||
}
|
||||
fn cnv_pairwise_apply_dft_tmp_bytes(&self, res_size: usize, res_offset: usize, a_size: usize, b_size: usize) -> usize;
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn bivariate_convolution<R, A, B>(
|
||||
/// Evaluates the bivariate pair-wise convolution res = (a[i] + a[j]) * (b[i] + b[j]).
|
||||
/// If i == j then calls [Convolution::cnv_apply_dft], i.e. res = a[i] * b[i].
|
||||
/// See [Convolution::cnv_apply_dft] for information about the bivariate convolution.
|
||||
fn cnv_pairwise_apply_dft<R, A, B>(
|
||||
&self,
|
||||
k: i64,
|
||||
res: &mut R,
|
||||
res_offset: usize,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
i: usize,
|
||||
j: usize,
|
||||
scratch: &mut Scratch<BE>,
|
||||
) where
|
||||
R: VecZnxDftToMut<BE>,
|
||||
A: VecZnxToRef,
|
||||
B: VecZnxDftToRef<BE>,
|
||||
{
|
||||
self.vec_znx_dft_zero(res, res_col);
|
||||
self.bivariate_convolution_add(k, res, res_col, a, a_col, b, b_col, scratch);
|
||||
}
|
||||
A: CnvPVecLToRef<BE>,
|
||||
B: CnvPVecRToRef<BE>;
|
||||
}
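As a purely illustrative aside (standalone Rust, independent of this crate's types): per coefficient of X, the Y = 2^-K convolution documented above is ordinary convolution of base-2^K limb vectors, using the limb weighting of the diagrams (limb i carries weight 2^{-K*i}). A minimal sketch:

```rust
// Standalone sketch of the Y = 2^-K limb arithmetic; not the library implementation.
fn value(limbs: &[i64], k: u32) -> f64 {
    // Limb i carries weight 2^{-k*i}, matching the diagrams above.
    limbs.iter().enumerate().map(|(i, &x)| x as f64 / 2f64.powi((k * i as u32) as i32)).sum()
}

fn convolve(a: &[i64], b: &[i64]) -> Vec<i64> {
    let mut res = vec![0i64; a.len() + b.len() - 1];
    for (i, &ai) in a.iter().enumerate() {
        for (j, &bj) in b.iter().enumerate() {
            res[i + j] += ai * bj;
        }
    }
    res
}

fn main() {
    let k = 12u32;
    let (a, b) = ([37, 11], [5, 3]);
    let c = convolve(&a, &b); // limbs of the product, before any renormalization
    assert!((value(&a, k) * value(&b, k) - value(&c, k)).abs() < 1e-9);
}
```

As described in the doc comment, when the result has fewer limbs than a.size() + b.size() - 1 the trailing limbs are truncated, and `res_offset` re-scales the stored result by a power of 2^{Base2K}.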
@@ -1,6 +1,6 @@
|
||||
use crate::{
|
||||
api::{ModuleN, SvpPPolBytesOf, VecZnxBigBytesOf, VecZnxDftBytesOf, VmpPMatBytesOf},
|
||||
layouts::{Backend, MatZnx, ScalarZnx, Scratch, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat},
|
||||
api::{CnvPVecBytesOf, ModuleN, SvpPPolBytesOf, VecZnxBigBytesOf, VecZnxDftBytesOf, VmpPMatBytesOf},
|
||||
layouts::{Backend, CnvPVecL, CnvPVecR, MatZnx, ScalarZnx, Scratch, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat},
|
||||
};
|
||||
|
||||
/// Allocates a new [crate::layouts::ScratchOwned] of `size` aligned bytes.
|
||||
@@ -56,6 +56,22 @@ pub trait ScratchTakeBasic
|
||||
where
|
||||
Self: TakeSlice,
|
||||
{
|
||||
fn take_cnv_pvec_left<M, B: Backend>(&mut self, module: &M, cols: usize, size: usize) -> (CnvPVecL<&mut [u8], B>, &mut Self)
where
M: ModuleN + CnvPVecBytesOf,
{
let (take_slice, rem_slice) = self.take_slice(module.bytes_of_cnv_pvec_left(cols, size));
(CnvPVecL::from_data(take_slice, module.n(), cols, size), rem_slice)
}

fn take_cnv_pvec_right<M, B: Backend>(&mut self, module: &M, cols: usize, size: usize) -> (CnvPVecR<&mut [u8], B>, &mut Self)
where
M: ModuleN + CnvPVecBytesOf,
{
let (take_slice, rem_slice) = self.take_slice(module.bytes_of_cnv_pvec_right(cols, size));
(CnvPVecR::from_data(take_slice, module.n(), cols, size), rem_slice)
}
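A hedged usage sketch of these two helpers (a fragment, not compiled here: `module`, `cols` and `size` are assumed to be in scope, and the module type must satisfy the `ModuleN + CnvPVecBytesOf` bound above):

```rust
// Sketch only: carve both prepared-operand buffers for a convolution out of one
// scratch allocation, sized with the bytes_of_cnv_pvec_* accessors added in this commit.
// Real code may need extra slack for alignment.
let bytes = module.bytes_of_cnv_pvec_left(cols, size) + module.bytes_of_cnv_pvec_right(cols, size);
let mut owned: ScratchOwned<BE> = ScratchOwned::alloc(bytes);
let scratch = owned.borrow();
let (mut a_prep, scratch) = scratch.take_cnv_pvec_left(&module, cols, size);
let (mut b_prep, _scratch) = scratch.take_cnv_pvec_right(&module, cols, size);
// a_prep / b_prep are then filled with cnv_prepare_left / cnv_prepare_right.
```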
fn take_scalar_znx(&mut self, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) {
|
||||
let (take_slice, rem_slice) = self.take_slice(ScalarZnx::bytes_of(n, cols));
|
||||
(ScalarZnx::from_data(take_slice, n, cols), rem_slice)
|
||||
@@ -79,10 +95,7 @@ where
|
||||
M: VecZnxBigBytesOf + ModuleN,
|
||||
{
|
||||
let (take_slice, rem_slice) = self.take_slice(module.bytes_of_vec_znx_big(cols, size));
|
||||
(
|
||||
VecZnxBig::from_data(take_slice, module.n(), cols, size),
|
||||
rem_slice,
|
||||
)
|
||||
(VecZnxBig::from_data(take_slice, module.n(), cols, size), rem_slice)
|
||||
}
|
||||
|
||||
fn take_vec_znx_dft<M, B: Backend>(&mut self, module: &M, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, &mut Self)
|
||||
@@ -91,10 +104,7 @@ where
|
||||
{
|
||||
let (take_slice, rem_slice) = self.take_slice(module.bytes_of_vec_znx_dft(cols, size));
|
||||
|
||||
(
|
||||
VecZnxDft::from_data(take_slice, module.n(), cols, size),
|
||||
rem_slice,
|
||||
)
|
||||
(VecZnxDft::from_data(take_slice, module.n(), cols, size), rem_slice)
|
||||
}
|
||||
|
||||
fn take_vec_znx_dft_slice<M, B: Backend>(
|
||||
@@ -155,9 +165,6 @@ where
|
||||
size: usize,
|
||||
) -> (MatZnx<&mut [u8]>, &mut Self) {
|
||||
let (take_slice, rem_slice) = self.take_slice(MatZnx::bytes_of(n, rows, cols_in, cols_out, size));
|
||||
(
|
||||
MatZnx::from_data(take_slice, n, rows, cols_in, cols_out, size),
|
||||
rem_slice,
|
||||
)
|
||||
(MatZnx::from_data(take_slice, n, rows, cols_in, cols_out, size), rem_slice)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,11 +19,12 @@ pub trait VecZnxNormalize<B: Backend> {
|
||||
/// Normalizes the selected column of `a` and stores the result into the selected column of `res`.
|
||||
fn vec_znx_normalize<R, A>(
|
||||
&self,
|
||||
res_basek: usize,
|
||||
res: &mut R,
|
||||
res_base2k: usize,
|
||||
res_offset: i64,
|
||||
res_col: usize,
|
||||
a_basek: usize,
|
||||
a: &A,
|
||||
a_base2k: usize,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
|
||||
@@ -164,11 +164,12 @@ pub trait VecZnxBigNormalizeTmpBytes {
|
||||
pub trait VecZnxBigNormalize<B: Backend> {
|
||||
fn vec_znx_big_normalize<R, A>(
|
||||
&self,
|
||||
res_base2k: usize,
|
||||
res: &mut R,
|
||||
res_base2k: usize,
|
||||
res_offset: i64,
|
||||
res_col: usize,
|
||||
a_base2k: usize,
|
||||
a: &A,
|
||||
a_base2k: usize,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
|
||||
poulpy-hal/src/bench_suite/convolution.rs (new file, 268 lines)
@@ -0,0 +1,268 @@
|
||||
use std::hint::black_box;
|
||||
|
||||
use criterion::{BenchmarkId, Criterion};
|
||||
|
||||
use crate::{
|
||||
api::{CnvPVecAlloc, Convolution, ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxBigAlloc, VecZnxDftAlloc},
|
||||
layouts::{Backend, CnvPVecL, CnvPVecR, FillUniform, Module, ScratchOwned, VecZnx, VecZnxBig},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
pub fn bench_cnv_prepare_left<BE: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<BE>: ModuleNew<BE> + Convolution<BE> + CnvPVecAlloc<BE>,
|
||||
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
|
||||
{
|
||||
let group_name: String = format!("cnv_prepare_left::{label}");
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<BE: Backend>(n: usize, size: usize) -> impl FnMut()
|
||||
where
|
||||
Module<BE>: ModuleNew<BE> + Convolution<BE> + CnvPVecAlloc<BE>,
|
||||
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
|
||||
{
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let base2k: usize = 12;
|
||||
|
||||
let c_size: usize = size + size - 1;
|
||||
|
||||
let module: Module<BE> = Module::<BE>::new(n as u64);
|
||||
|
||||
let mut a_prep: CnvPVecL<Vec<u8>, BE> = module.cnv_pvec_left_alloc(1, size);
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), 1, size);
|
||||
|
||||
a.fill_uniform(base2k, &mut source);
|
||||
|
||||
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(module.cnv_prepare_left_tmp_bytes(c_size, size));
|
||||
|
||||
move || {
|
||||
module.cnv_prepare_left(&mut a_prep, &a, scratch.borrow());
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 1], [11, 2], [12, 4], [13, 8], [14, 16], [15, 32], [16, 64]] {
|
||||
let log_n: usize = params[0];
|
||||
let size: usize = params[1];
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x{}", 1 << log_n, size));
|
||||
let mut runner = runner(1 << log_n, size);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_cnv_prepare_right<BE: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<BE>: ModuleNew<BE> + Convolution<BE> + CnvPVecAlloc<BE>,
|
||||
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
|
||||
{
|
||||
let group_name: String = format!("cnv_prepare_right::{label}");
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<BE: Backend>(n: usize, size: usize) -> impl FnMut()
|
||||
where
|
||||
Module<BE>: ModuleNew<BE> + Convolution<BE> + CnvPVecAlloc<BE>,
|
||||
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
|
||||
{
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let base2k: usize = 12;
|
||||
|
||||
let c_size: usize = size + size - 1;
|
||||
|
||||
let module: Module<BE> = Module::<BE>::new(n as u64);
|
||||
|
||||
let mut a_prep: CnvPVecR<Vec<u8>, BE> = module.cnv_pvec_right_alloc(1, size);
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), 1, size);
|
||||
|
||||
a.fill_uniform(base2k, &mut source);
|
||||
|
||||
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(module.cnv_prepare_right_tmp_bytes(c_size, size));
|
||||
|
||||
move || {
|
||||
module.cnv_prepare_right(&mut a_prep, &a, scratch.borrow());
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 1], [11, 2], [12, 4], [13, 8], [14, 16], [15, 32], [16, 64]] {
|
||||
let log_n: usize = params[0];
|
||||
let size: usize = params[1];
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x{}", 1 << log_n, size));
|
||||
let mut runner = runner(1 << log_n, size);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_cnv_apply_dft<BE: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<BE>: ModuleNew<BE> + Convolution<BE> + VecZnxDftAlloc<BE> + CnvPVecAlloc<BE>,
|
||||
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
|
||||
{
|
||||
let group_name: String = format!("cnv_apply_dft::{label}");
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<BE: Backend>(n: usize, size: usize) -> impl FnMut()
|
||||
where
|
||||
Module<BE>: ModuleNew<BE> + Convolution<BE> + VecZnxDftAlloc<BE> + CnvPVecAlloc<BE>,
|
||||
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
|
||||
{
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let base2k: usize = 12;
|
||||
|
||||
let c_size: usize = size + size - 1;
|
||||
|
||||
let module: Module<BE> = Module::<BE>::new(n as u64);
|
||||
|
||||
let mut a_prep: CnvPVecL<Vec<u8>, BE> = module.cnv_pvec_left_alloc(1, size);
|
||||
let mut b_prep: CnvPVecR<Vec<u8>, BE> = module.cnv_pvec_right_alloc(1, size);
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), 1, size);
|
||||
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), 1, size);
|
||||
let mut c_dft = module.vec_znx_dft_alloc(1, c_size);
|
||||
|
||||
a.fill_uniform(base2k, &mut source);
|
||||
b.fill_uniform(base2k, &mut source);
|
||||
|
||||
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(
|
||||
module
|
||||
.cnv_apply_dft_tmp_bytes(c_size, 0, size, size)
|
||||
.max(module.cnv_prepare_left_tmp_bytes(c_size, size))
|
||||
.max(module.cnv_prepare_right_tmp_bytes(c_size, size)),
|
||||
);
|
||||
module.cnv_prepare_left(&mut a_prep, &a, scratch.borrow());
|
||||
module.cnv_prepare_right(&mut b_prep, &b, scratch.borrow());
|
||||
move || {
|
||||
module.cnv_apply_dft(&mut c_dft, 0, 0, &a_prep, 0, &b_prep, 0, scratch.borrow());
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 1], [11, 2], [12, 4], [13, 8], [14, 16], [15, 32], [16, 64]] {
|
||||
let log_n: usize = params[0];
|
||||
let size: usize = params[1];
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x{}", 1 << log_n, size));
|
||||
let mut runner = runner(1 << log_n, size);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
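These bench functions take a `Criterion` handle plus a label, so a backend crate can wire them into its own harness. A minimal sketch of that wiring, with module paths assumed from the file layout in this diff and a hypothetical backend type `FFT64Ref` supplied by the consuming crate:

```rust
use criterion::{criterion_group, criterion_main, Criterion};
use poulpy_hal::bench_suite::convolution::{bench_cnv_apply_dft, bench_cnv_prepare_left};
// use poulpy_backend::FFT64Ref; // hypothetical backend type, named here only for illustration

fn convolution_benches(c: &mut Criterion) {
    bench_cnv_prepare_left::<FFT64Ref>(c, "fft64_ref");
    bench_cnv_apply_dft::<FFT64Ref>(c, "fft64_ref");
}

criterion_group!(benches, convolution_benches);
criterion_main!(benches);
```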
pub fn bench_cnv_pairwise_apply_dft<BE: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<BE>: ModuleNew<BE> + Convolution<BE> + VecZnxDftAlloc<BE> + CnvPVecAlloc<BE>,
|
||||
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
|
||||
{
|
||||
let group_name: String = format!("cnv_pairwise_apply_dft::{label}");
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<BE: Backend>(n: usize, size: usize) -> impl FnMut()
|
||||
where
|
||||
Module<BE>: ModuleNew<BE> + Convolution<BE> + VecZnxDftAlloc<BE> + CnvPVecAlloc<BE>,
|
||||
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
|
||||
{
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let base2k: usize = 12;
|
||||
|
||||
let module: Module<BE> = Module::<BE>::new(n as u64);
|
||||
|
||||
let cols = 2;
|
||||
let c_size: usize = size + size - 1;
|
||||
|
||||
let mut a_prep: CnvPVecL<Vec<u8>, BE> = module.cnv_pvec_left_alloc(cols, size);
|
||||
let mut b_prep: CnvPVecR<Vec<u8>, BE> = module.cnv_pvec_right_alloc(cols, size);
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), cols, size);
|
||||
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), cols, size);
|
||||
let mut c_dft = module.vec_znx_dft_alloc(1, c_size);
|
||||
|
||||
a.fill_uniform(base2k, &mut source);
|
||||
b.fill_uniform(base2k, &mut source);
|
||||
|
||||
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(
|
||||
module
|
||||
.cnv_pairwise_apply_dft_tmp_bytes(c_size, 0, size, size)
|
||||
.max(module.cnv_prepare_left_tmp_bytes(c_size, size))
|
||||
.max(module.cnv_prepare_right_tmp_bytes(c_size, size)),
|
||||
);
|
||||
module.cnv_prepare_left(&mut a_prep, &a, scratch.borrow());
|
||||
module.cnv_prepare_right(&mut b_prep, &b, scratch.borrow());
|
||||
move || {
|
||||
module.cnv_pairwise_apply_dft(&mut c_dft, 0, 0, &a_prep, &b_prep, 0, 1, scratch.borrow());
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 1], [11, 2], [12, 4], [13, 8], [14, 16], [15, 32], [16, 64]] {
|
||||
let log_n: usize = params[0];
|
||||
let size: usize = params[1];
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x{}", 1 << log_n, size));
|
||||
let mut runner = runner(1 << log_n, size);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_cnv_by_const_apply<BE: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<BE>: ModuleNew<BE> + Convolution<BE> + VecZnxBigAlloc<BE> + CnvPVecAlloc<BE>,
|
||||
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
|
||||
{
|
||||
let group_name: String = format!("cnv_by_const::{label}");
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner<BE: Backend>(n: usize, size: usize) -> impl FnMut()
|
||||
where
|
||||
Module<BE>: ModuleNew<BE> + Convolution<BE> + VecZnxBigAlloc<BE> + CnvPVecAlloc<BE>,
|
||||
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
|
||||
{
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let base2k: usize = 12;
|
||||
|
||||
let module: Module<BE> = Module::<BE>::new(n as u64);
|
||||
|
||||
let cols = 2;
|
||||
let c_size: usize = size + size - 1;
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), cols, size);
|
||||
let mut c_big: VecZnxBig<Vec<u8>, BE> = module.vec_znx_big_alloc(1, c_size);
|
||||
|
||||
a.fill_uniform(base2k, &mut source);
|
||||
let mut b = vec![0i64; size];
|
||||
for x in &mut b {
|
||||
*x = source.next_i64();
|
||||
}
|
||||
|
||||
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(module.cnv_by_const_apply_tmp_bytes(c_size, 0, size, size));
|
||||
move || {
|
||||
module.cnv_by_const_apply(&mut c_big, 0, 0, &a, 0, &b, scratch.borrow());
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for params in [[10, 1], [11, 2], [12, 4], [13, 8], [14, 16], [15, 32], [16, 64]] {
|
||||
let log_n: usize = params[0];
|
||||
let size: usize = params[1];
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x{}", 1 << log_n, size));
|
||||
let mut runner = runner(1 << log_n, size);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod convolution;
|
||||
pub mod svp;
|
||||
pub mod vec_znx;
|
||||
pub mod vec_znx_big;
|
||||
|
||||
@@ -404,7 +404,7 @@ where
|
||||
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_big_normalize(base2k, &mut res, i, base2k, &a, i, scratch.borrow());
|
||||
module.vec_znx_big_normalize(&mut res, base2k, 0, i, &a, base2k, i, scratch.borrow());
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
|
||||
poulpy-hal/src/delegates/convolution.rs (new file, 125 lines)
@@ -0,0 +1,125 @@
|
||||
use crate::{
|
||||
api::{CnvPVecAlloc, CnvPVecBytesOf, Convolution},
|
||||
layouts::{
|
||||
Backend, CnvPVecL, CnvPVecLToMut, CnvPVecLToRef, CnvPVecR, CnvPVecRToMut, CnvPVecRToRef, Module, Scratch, VecZnxBigToMut,
|
||||
VecZnxDftToMut, VecZnxToRef, ZnxInfos, ZnxViewMut,
|
||||
},
|
||||
oep::{CnvPVecBytesOfImpl, CnvPVecLAllocImpl, ConvolutionImpl},
|
||||
};
|
||||
|
||||
impl<BE: Backend> CnvPVecAlloc<BE> for Module<BE>
|
||||
where
|
||||
BE: CnvPVecLAllocImpl<BE>,
|
||||
{
|
||||
fn cnv_pvec_left_alloc(&self, cols: usize, size: usize) -> CnvPVecL<Vec<u8>, BE> {
|
||||
BE::cnv_pvec_left_alloc_impl(self.n(), cols, size)
|
||||
}
|
||||
|
||||
fn cnv_pvec_right_alloc(&self, cols: usize, size: usize) -> CnvPVecR<Vec<u8>, BE> {
|
||||
BE::cnv_pvec_right_alloc_impl(self.n(), cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<BE: Backend> CnvPVecBytesOf for Module<BE>
|
||||
where
|
||||
BE: CnvPVecBytesOfImpl,
|
||||
{
|
||||
fn bytes_of_cnv_pvec_left(&self, cols: usize, size: usize) -> usize {
|
||||
BE::bytes_of_cnv_pvec_left_impl(self.n(), cols, size)
|
||||
}
|
||||
|
||||
fn bytes_of_cnv_pvec_right(&self, cols: usize, size: usize) -> usize {
|
||||
BE::bytes_of_cnv_pvec_right_impl(self.n(), cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<BE: Backend> Convolution<BE> for Module<BE>
|
||||
where
|
||||
BE: ConvolutionImpl<BE>,
|
||||
{
|
||||
fn cnv_prepare_left_tmp_bytes(&self, res_size: usize, a_size: usize) -> usize {
|
||||
BE::cnv_prepare_left_tmp_bytes_impl(self, res_size, a_size)
|
||||
}
|
||||
fn cnv_prepare_left<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<BE>)
|
||||
where
|
||||
R: CnvPVecLToMut<BE> + ZnxInfos + ZnxViewMut<Scalar = BE::ScalarPrep>,
|
||||
A: VecZnxToRef + ZnxInfos,
|
||||
{
|
||||
BE::cnv_prepare_left_impl(self, res, a, scratch);
|
||||
}
|
||||
|
||||
fn cnv_prepare_right_tmp_bytes(&self, res_size: usize, a_size: usize) -> usize {
|
||||
BE::cnv_prepare_right_tmp_bytes_impl(self, res_size, a_size)
|
||||
}
|
||||
fn cnv_prepare_right<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<BE>)
|
||||
where
|
||||
R: CnvPVecRToMut<BE> + ZnxInfos + ZnxViewMut<Scalar = BE::ScalarPrep>,
|
||||
A: VecZnxToRef + ZnxInfos,
|
||||
{
|
||||
BE::cnv_prepare_right_impl(self, res, a, scratch);
|
||||
}
|
||||
|
||||
fn cnv_apply_dft_tmp_bytes(&self, res_size: usize, res_offset: usize, a_size: usize, b_size: usize) -> usize {
|
||||
BE::cnv_apply_dft_tmp_bytes_impl(self, res_size, res_offset, a_size, b_size)
|
||||
}
|
||||
|
||||
fn cnv_by_const_apply_tmp_bytes(&self, res_size: usize, res_offset: usize, a_size: usize, b_size: usize) -> usize {
|
||||
BE::cnv_by_const_apply_tmp_bytes_impl(self, res_size, res_offset, a_size, b_size)
|
||||
}
|
||||
|
||||
fn cnv_by_const_apply<R, A>(
|
||||
&self,
|
||||
res: &mut R,
|
||||
res_offset: usize,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &[i64],
|
||||
scratch: &mut Scratch<BE>,
|
||||
) where
|
||||
R: VecZnxBigToMut<BE>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
BE::cnv_by_const_apply_impl(self, res, res_offset, res_col, a, a_col, b, scratch);
|
||||
}
|
||||
|
||||
fn cnv_apply_dft<R, A, B>(
|
||||
&self,
|
||||
res: &mut R,
|
||||
res_offset: usize,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
scratch: &mut Scratch<BE>,
|
||||
) where
|
||||
R: VecZnxDftToMut<BE>,
|
||||
A: CnvPVecLToRef<BE>,
|
||||
B: CnvPVecRToRef<BE>,
|
||||
{
|
||||
BE::cnv_apply_dft_impl(self, res, res_offset, res_col, a, a_col, b, b_col, scratch);
|
||||
}
|
||||
|
||||
fn cnv_pairwise_apply_dft_tmp_bytes(&self, res_size: usize, res_offset: usize, a_size: usize, b_size: usize) -> usize {
|
||||
BE::cnv_pairwise_apply_dft_tmp_bytes(self, res_size, res_offset, a_size, b_size)
|
||||
}
|
||||
|
||||
fn cnv_pairwise_apply_dft<R, A, B>(
|
||||
&self,
|
||||
res: &mut R,
|
||||
res_offset: usize,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
b: &B,
|
||||
i: usize,
|
||||
j: usize,
|
||||
scratch: &mut Scratch<BE>,
|
||||
) where
|
||||
R: VecZnxDftToMut<BE>,
|
||||
A: CnvPVecLToRef<BE>,
|
||||
B: CnvPVecRToRef<BE>,
|
||||
{
|
||||
BE::cnv_pairwise_apply_dft_impl(self, res, res_offset, res_col, a, b, i, j, scratch);
|
||||
}
|
||||
}
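The impl above is the usual delegation layer: the public `Convolution` trait is implemented once on `Module<BE>` and forwards every call to the backend's `ConvolutionImpl` open extension point. A standalone, simplified sketch of that pattern (names are illustrative, not taken from the crate):

```rust
use std::marker::PhantomData;

// Backend-side extension point (the crate's *Impl traits play this role).
trait FrobImpl {
    fn frob_impl(x: u32) -> u32;
}

// Public API trait, implemented once on the module and delegating to the backend.
trait Frob {
    fn frob(&self, x: u32) -> u32;
}

struct Module<BE>(PhantomData<BE>);

impl<BE: FrobImpl> Frob for Module<BE> {
    fn frob(&self, x: u32) -> u32 {
        BE::frob_impl(x)
    }
}

struct CpuBackend;
impl FrobImpl for CpuBackend {
    fn frob_impl(x: u32) -> u32 {
        x + 1
    }
}

fn main() {
    let m: Module<CpuBackend> = Module(PhantomData);
    assert_eq!(m.frob(41), 42);
}
```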
@@ -1,3 +1,4 @@
|
||||
mod convolution;
|
||||
mod module;
|
||||
mod scratch;
|
||||
mod svp_ppol;
|
||||
|
||||
@@ -51,18 +51,19 @@ where
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn vec_znx_normalize<R, A>(
|
||||
&self,
|
||||
res_basek: usize,
|
||||
res: &mut R,
|
||||
res_base2k: usize,
|
||||
res_offset: i64,
|
||||
res_col: usize,
|
||||
a_basek: usize,
|
||||
a: &A,
|
||||
a_base2k: usize,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
B::vec_znx_normalize_impl(self, res_basek, res, res_col, a_basek, a, a_col, scratch)
|
||||
B::vec_znx_normalize_impl(self, res, res_base2k, res_offset, res_col, a, a_base2k, a_col, scratch)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ where
|
||||
|
||||
impl<B> VecZnxBigBytesOf for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigAllocBytesImpl<B>,
|
||||
B: Backend + VecZnxBigAllocBytesImpl,
|
||||
{
|
||||
fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize {
|
||||
B::vec_znx_big_bytes_of_impl(self.n(), cols, size)
|
||||
@@ -264,18 +264,19 @@ where
|
||||
{
|
||||
fn vec_znx_big_normalize<R, A>(
|
||||
&self,
|
||||
res_basek: usize,
|
||||
res: &mut R,
|
||||
res_base2k: usize,
|
||||
res_offset: i64,
|
||||
res_col: usize,
|
||||
a_basek: usize,
|
||||
a: &A,
|
||||
a_base2k: usize,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxBigToRef<B>,
|
||||
{
|
||||
B::vec_znx_big_normalize_impl(self, res_basek, res, res_col, a_basek, a, a_col, scratch);
|
||||
B::vec_znx_big_normalize_impl(self, res, res_base2k, res_offset, res_col, a, a_base2k, a_col, scratch);
|
||||
}
|
||||
}
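With the new signature, `res` and `a` may use different base-2^k values and the stored result is additionally scaled according to `res_offset`. A standalone numeric sketch of cross-base renormalization in that spirit (illustrative only; the exact sign convention of `res_offset` used here is an assumption, not taken from this diff):

```rust
// Standalone sketch: re-express a value given as limbs in base 2^a_base2k as limbs
// in base 2^res_base2k, scaled by 2^{offset * res_base2k}. Not the library code.
// Limb i carries weight 2^{-base2k*i}, as in the convolution sketch earlier.
fn to_float(limbs: &[i64], base2k: u32) -> f64 {
    limbs.iter().enumerate().map(|(i, &x)| x as f64 / 2f64.powi((base2k * i as u32) as i32)).sum()
}

fn renormalize(a: &[i64], a_base2k: u32, res_size: usize, res_base2k: u32, offset: i32) -> Vec<i64> {
    let mut v = to_float(a, a_base2k) * 2f64.powi(offset * res_base2k as i32);
    let mut res = vec![0i64; res_size];
    for (i, r) in res.iter_mut().enumerate() {
        if i > 0 {
            v *= 2f64.powi(res_base2k as i32);
        }
        let limb = v.round();
        *r = limb as i64;
        v -= limb;
    }
    res
}

fn main() {
    let a = [123, -45, 7]; // limbs in base 2^12
    let res = renormalize(&a, 12, 4, 17, 0); // same value, re-expressed in base 2^17
    assert!((to_float(&a, 12) - to_float(&res, 17)).abs() < 1e-12);
}
```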
@@ -76,9 +76,7 @@ where
|
||||
b_cols_out: usize,
|
||||
b_size: usize,
|
||||
) -> usize {
|
||||
B::vmp_apply_dft_tmp_bytes_impl(
|
||||
self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
|
||||
)
|
||||
B::vmp_apply_dft_tmp_bytes_impl(self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -109,9 +107,7 @@ where
|
||||
b_cols_out: usize,
|
||||
b_size: usize,
|
||||
) -> usize {
|
||||
B::vmp_apply_dft_to_dft_tmp_bytes_impl(
|
||||
self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
|
||||
)
|
||||
B::vmp_apply_dft_to_dft_tmp_bytes_impl(self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -142,9 +138,7 @@ where
|
||||
b_cols_out: usize,
|
||||
b_size: usize,
|
||||
) -> usize {
|
||||
B::vmp_apply_dft_to_dft_add_tmp_bytes_impl(
|
||||
self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
|
||||
)
|
||||
B::vmp_apply_dft_to_dft_add_tmp_bytes_impl(self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
poulpy-hal/src/layouts/convolution.rs (new file, 237 lines)
@@ -0,0 +1,237 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::{
|
||||
alloc_aligned,
|
||||
layouts::{Backend, Data, DataMut, DataRef, DataView, DataViewMut, ZnxInfos, ZnxView},
|
||||
oep::CnvPVecBytesOfImpl,
|
||||
};
|
||||
|
||||
pub struct CnvPVecR<D: Data, BE: Backend> {
|
||||
data: D,
|
||||
n: usize,
|
||||
size: usize,
|
||||
cols: usize,
|
||||
_phantom: PhantomData<BE>,
|
||||
}
|
||||
|
||||
impl<D: Data, BE: Backend> ZnxInfos for CnvPVecR<D, BE> {
|
||||
fn cols(&self) -> usize {
|
||||
self.cols
|
||||
}
|
||||
|
||||
fn n(&self) -> usize {
|
||||
self.n
|
||||
}
|
||||
|
||||
fn rows(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data, BE: Backend> DataView for CnvPVecR<D, BE> {
|
||||
type D = D;
|
||||
fn data(&self) -> &Self::D {
|
||||
&self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data, B: Backend> DataViewMut for CnvPVecR<D, B> {
|
||||
fn data_mut(&mut self) -> &mut Self::D {
|
||||
&mut self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef, BE: Backend> ZnxView for CnvPVecR<D, BE> {
|
||||
type Scalar = BE::ScalarPrep;
|
||||
}
|
||||
|
||||
impl<D: DataRef + From<Vec<u8>>, B: Backend> CnvPVecR<D, B>
|
||||
where
|
||||
B: CnvPVecBytesOfImpl,
|
||||
{
|
||||
pub fn alloc(n: usize, cols: usize, size: usize) -> Self {
|
||||
let data: Vec<u8> = alloc_aligned::<u8>(B::bytes_of_cnv_pvec_right_impl(n, cols, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n,
|
||||
size,
|
||||
cols,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
|
||||
let data: Vec<u8> = bytes.into();
|
||||
assert!(data.len() == B::bytes_of_cnv_pvec_right_impl(n, cols, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n,
|
||||
size,
|
||||
cols,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data, B: Backend> CnvPVecR<D, B> {
|
||||
pub fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
|
||||
Self {
|
||||
data,
|
||||
n,
|
||||
cols,
|
||||
size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CnvPVecL<D: Data, BE: Backend> {
|
||||
data: D,
|
||||
n: usize,
|
||||
size: usize,
|
||||
cols: usize,
|
||||
_phantom: PhantomData<BE>,
|
||||
}
|
||||
|
||||
impl<D: Data, BE: Backend> ZnxInfos for CnvPVecL<D, BE> {
|
||||
fn cols(&self) -> usize {
|
||||
self.cols
|
||||
}
|
||||
|
||||
fn n(&self) -> usize {
|
||||
self.n
|
||||
}
|
||||
|
||||
fn rows(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data, BE: Backend> DataView for CnvPVecL<D, BE> {
|
||||
type D = D;
|
||||
fn data(&self) -> &Self::D {
|
||||
&self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data, B: Backend> DataViewMut for CnvPVecL<D, B> {
|
||||
fn data_mut(&mut self) -> &mut Self::D {
|
||||
&mut self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef, BE: Backend> ZnxView for CnvPVecL<D, BE> {
|
||||
type Scalar = BE::ScalarPrep;
|
||||
}
|
||||
|
||||
impl<D: DataRef + From<Vec<u8>>, B: Backend> CnvPVecL<D, B>
|
||||
where
|
||||
B: CnvPVecBytesOfImpl,
|
||||
{
|
||||
pub fn alloc(n: usize, cols: usize, size: usize) -> Self {
|
||||
let data: Vec<u8> = alloc_aligned::<u8>(B::bytes_of_cnv_pvec_left_impl(n, cols, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n,
|
||||
size,
|
||||
cols,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
|
||||
let data: Vec<u8> = bytes.into();
|
||||
assert!(data.len() == B::bytes_of_cnv_pvec_left_impl(n, cols, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n,
|
||||
size,
|
||||
cols,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data, B: Backend> CnvPVecL<D, B> {
|
||||
pub fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
|
||||
Self {
|
||||
data,
|
||||
n,
|
||||
cols,
|
||||
size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait CnvPVecRToRef<BE: Backend> {
|
||||
fn to_ref(&self) -> CnvPVecR<&[u8], BE>;
|
||||
}
|
||||
|
||||
impl<D: DataRef, BE: Backend> CnvPVecRToRef<BE> for CnvPVecR<D, BE> {
|
||||
fn to_ref(&self) -> CnvPVecR<&[u8], BE> {
|
||||
CnvPVecR {
|
||||
data: self.data.as_ref(),
|
||||
n: self.n,
|
||||
size: self.size,
|
||||
cols: self.cols,
|
||||
_phantom: self._phantom,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait CnvPVecRToMut<BE: Backend> {
|
||||
fn to_mut(&mut self) -> CnvPVecR<&mut [u8], BE>;
|
||||
}
|
||||
|
||||
impl<D: DataMut, BE: Backend> CnvPVecRToMut<BE> for CnvPVecR<D, BE> {
|
||||
fn to_mut(&mut self) -> CnvPVecR<&mut [u8], BE> {
|
||||
CnvPVecR {
|
||||
data: self.data.as_mut(),
|
||||
n: self.n,
|
||||
size: self.size,
|
||||
cols: self.cols,
|
||||
_phantom: self._phantom,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait CnvPVecLToRef<BE: Backend> {
|
||||
fn to_ref(&self) -> CnvPVecL<&[u8], BE>;
|
||||
}
|
||||
|
||||
impl<D: DataRef, BE: Backend> CnvPVecLToRef<BE> for CnvPVecL<D, BE> {
|
||||
fn to_ref(&self) -> CnvPVecL<&[u8], BE> {
|
||||
CnvPVecL {
|
||||
data: self.data.as_ref(),
|
||||
n: self.n,
|
||||
size: self.size,
|
||||
cols: self.cols,
|
||||
_phantom: self._phantom,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait CnvPVecLToMut<BE: Backend> {
|
||||
fn to_mut(&mut self) -> CnvPVecL<&mut [u8], BE>;
|
||||
}
|
||||
|
||||
impl<D: DataMut, BE: Backend> CnvPVecLToMut<BE> for CnvPVecL<D, BE> {
|
||||
fn to_mut(&mut self) -> CnvPVecL<&mut [u8], BE> {
|
||||
CnvPVecL {
|
||||
data: self.data.as_mut(),
|
||||
n: self.n,
|
||||
size: self.size,
|
||||
cols: self.cols,
|
||||
_phantom: self._phantom,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -223,22 +223,22 @@ impl<D: DataRef> VecZnx<D> {
|
||||
|
||||
let a: VecZnx<&[u8]> = self.to_ref();
|
||||
let size: usize = a.size();
|
||||
let prec: u32 = (base2k * size) as u32;
|
||||
let prec: u32 = data[0].prec();
|
||||
|
||||
// 2^{base2k}
|
||||
let base: Float = Float::with_val(prec, (1u64 << base2k) as f64);
|
||||
let scale: Float = Float::with_val(prec, Float::u_pow_u(2, base2k as u32));
|
||||
|
||||
// y[i] = sum x[j][i] * 2^{-base2k*j}
|
||||
(0..size).for_each(|i| {
|
||||
if i == 0 {
|
||||
izip!(a.at(col, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
|
||||
y.assign(*x);
|
||||
*y /= &base;
|
||||
*y /= &scale;
|
||||
});
|
||||
} else {
|
||||
izip!(a.at(col, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
|
||||
*y += Float::with_val(prec, *x);
|
||||
*y /= &base;
|
||||
*y /= &scale;
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
mod convolution;
|
||||
mod encoding;
|
||||
mod mat_znx;
|
||||
mod module;
|
||||
@@ -12,6 +13,7 @@ mod vec_znx_dft;
|
||||
mod vmp_pmat;
|
||||
mod znx_base;
|
||||
|
||||
pub use convolution::*;
|
||||
pub use mat_znx::*;
|
||||
pub use module::*;
|
||||
pub use scalar_znx::*;
|
||||
|
||||
@@ -123,10 +123,8 @@ where
|
||||
panic!("cannot invert 0")
|
||||
}
|
||||
|
||||
let g_exp: u64 = mod_exp_u64(
|
||||
gal_el.unsigned_abs(),
|
||||
(self.cyclotomic_order() - 1) as usize,
|
||||
) & (self.cyclotomic_order() - 1) as u64;
|
||||
let g_exp: u64 =
|
||||
mod_exp_u64(gal_el.unsigned_abs(), (self.cyclotomic_order() - 1) as usize) & (self.cyclotomic_order() - 1) as u64;
|
||||
g_exp as i64 * gal_el.signum()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -187,11 +187,7 @@ impl<D: Data> VecZnx<D> {
|
||||
|
||||
impl<D: DataRef> fmt::Display for VecZnx<D> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
writeln!(
|
||||
f,
|
||||
"VecZnx(n={}, cols={}, size={})",
|
||||
self.n, self.cols, self.size
|
||||
)?;
|
||||
writeln!(f, "VecZnx(n={}, cols={}, size={})", self.n, self.cols, self.size)?;
|
||||
|
||||
for col in 0..self.cols {
|
||||
writeln!(f, "Column {col}:")?;
|
||||
|
||||
@@ -93,7 +93,7 @@ where
|
||||
|
||||
impl<D: DataRef + From<Vec<u8>>, B: Backend> VecZnxBig<D, B>
|
||||
where
|
||||
B: VecZnxBigAllocBytesImpl<B>,
|
||||
B: VecZnxBigAllocBytesImpl,
|
||||
{
|
||||
pub fn alloc(n: usize, cols: usize, size: usize) -> Self {
|
||||
let data = alloc_aligned::<u8>(B::vec_znx_big_bytes_of_impl(n, cols, size));
|
||||
@@ -172,11 +172,7 @@ impl<D: DataMut, B: Backend> VecZnxBigToMut<B> for VecZnxBig<D, B> {
|
||||
|
||||
impl<D: DataRef, B: Backend> fmt::Display for VecZnxBig<D, B> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
writeln!(
|
||||
f,
|
||||
"VecZnxBig(n={}, cols={}, size={})",
|
||||
self.n, self.cols, self.size
|
||||
)?;
|
||||
writeln!(f, "VecZnxBig(n={}, cols={}, size={})", self.n, self.cols, self.size)?;
|
||||
|
||||
for col in 0..self.cols {
|
||||
writeln!(f, "Column {col}:")?;
|
||||
|
||||
@@ -192,11 +192,7 @@ impl<D: DataMut, B: Backend> VecZnxDftToMut<B> for VecZnxDft<D, B> {
|
||||
|
||||
impl<D: DataRef, B: Backend> fmt::Display for VecZnxDft<D, B> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
writeln!(
|
||||
f,
|
||||
"VecZnxDft(n={}, cols={}, size={})",
|
||||
self.n, self.cols, self.size
|
||||
)?;
|
||||
writeln!(f, "VecZnxDft(n={}, cols={}, size={})", self.n, self.cols, self.size)?;
|
||||
|
||||
for col in 0..self.cols {
|
||||
writeln!(f, "Column {col}:")?;
|
||||
|
||||
@@ -65,11 +65,8 @@ pub trait ZnxView: ZnxInfos + DataView<D: DataRef> {
|
||||
|
||||
/// Returns a non-mutable pointer starting at the j-th small polynomial of the i-th column.
|
||||
fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(i < self.cols(), "cols: {} >= {}", i, self.cols());
|
||||
assert!(j < self.size(), "size: {} >= {}", j, self.size());
|
||||
}
|
||||
assert!(i < self.cols(), "cols: {} >= self.cols(): {}", i, self.cols());
|
||||
assert!(j < self.size(), "size: {} >= self.size(): {}", j, self.size());
|
||||
let offset: usize = self.n() * (j * self.cols() + i);
|
||||
unsafe { self.as_ptr().add(offset) }
|
||||
}
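The `offset` computation above pins down the in-memory layout: data is stored limb-major, and within each limb the columns occupy contiguous blocks of n scalars. A standalone check of that indexing (illustrative only):

```rust
// Mirrors the offset formula used by at_ptr / at_mut_ptr above.
fn limb_offset(n: usize, cols: usize, i: usize, j: usize) -> usize {
    n * (j * cols + i)
}

fn main() {
    let (n, cols) = (8, 3);
    assert_eq!(limb_offset(n, cols, 0, 0), 0);  // column 0 of limb 0
    assert_eq!(limb_offset(n, cols, 2, 0), 16); // column 2 of limb 0
    assert_eq!(limb_offset(n, cols, 0, 1), 24); // limb 1 starts after all columns of limb 0
}
```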
@@ -93,11 +90,8 @@ pub trait ZnxViewMut: ZnxView + DataViewMut<D: DataMut> {
|
||||
|
||||
/// Returns a mutable pointer starting at the j-th small polynomial of the i-th column.
|
||||
fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(i < self.cols(), "cols: {} >= {}", i, self.cols());
|
||||
assert!(j < self.size(), "size: {} >= {}", j, self.size());
|
||||
}
|
||||
assert!(i < self.cols(), "cols: {} >= self.cols(): {}", i, self.cols());
|
||||
assert!(j < self.size(), "size: {} >= self.size(): {}", j, self.size());
|
||||
let offset: usize = self.n() * (j * self.cols() + i);
|
||||
unsafe { self.as_mut_ptr().add(offset) }
|
||||
}
|
||||
|
||||
@@ -54,10 +54,7 @@ pub fn cast_mut<T, V>(data: &[T]) -> &mut [V] {
|
||||
/// Alignment must be a power of two and size a multiple of the alignment.
|
||||
/// Allocated memory is initialized to zero.
|
||||
fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec<u8> {
|
||||
assert!(
|
||||
align.is_power_of_two(),
|
||||
"Alignment must be a power of two but is {align}"
|
||||
);
|
||||
assert!(align.is_power_of_two(), "Alignment must be a power of two but is {align}");
|
||||
assert_eq!(
|
||||
(size * size_of::<u8>()) % align,
|
||||
0,
|
||||
@@ -82,10 +79,7 @@ fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec<u8> {
|
||||
/// Allocates a block of T aligned with [DEFAULTALIGN].
|
||||
/// `size_of::<T>() * size` must be a multiple of [DEFAULTALIGN].
|
||||
pub fn alloc_aligned_custom<T>(size: usize, align: usize) -> Vec<T> {
|
||||
assert!(
|
||||
align.is_power_of_two(),
|
||||
"Alignment must be a power of two but is {align}"
|
||||
);
|
||||
assert!(align.is_power_of_two(), "Alignment must be a power of two but is {align}");
|
||||
|
||||
assert_eq!(
|
||||
(size * size_of::<T>()) % align,
|
||||
|
||||
poulpy-hal/src/oep/convolution.rs (new file, 106 lines)
@@ -0,0 +1,106 @@
|
||||
use crate::layouts::{
|
||||
Backend, CnvPVecL, CnvPVecLToMut, CnvPVecLToRef, CnvPVecR, CnvPVecRToMut, CnvPVecRToRef, Module, Scratch, VecZnxBigToMut,
|
||||
VecZnxDftToMut, VecZnxToRef, ZnxInfos,
|
||||
};
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See the TODO reference implementation.
|
||||
/// * See [crate::api::CnvPVecAlloc] for the corresponding public API.
/// # Safety
/// See [crate::doc::backend_safety] for the safety contract.
|
||||
pub unsafe trait CnvPVecLAllocImpl<BE: Backend> {
|
||||
fn cnv_pvec_left_alloc_impl(n: usize, cols: usize, size: usize) -> CnvPVecL<Vec<u8>, BE>;
|
||||
fn cnv_pvec_right_alloc_impl(n: usize, cols: usize, size: usize) -> CnvPVecR<Vec<u8>, BE>;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See the TODO reference implementation.
|
||||
/// * See [crate::api::CnvPVecBytesOf] for the corresponding public API.
/// # Safety
/// See [crate::doc::backend_safety] for the safety contract.
|
||||
pub unsafe trait CnvPVecBytesOfImpl {
|
||||
fn bytes_of_cnv_pvec_left_impl(n: usize, cols: usize, size: usize) -> usize;
|
||||
fn bytes_of_cnv_pvec_right_impl(n: usize, cols: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See the TODO reference implementation.
|
||||
/// * See [crate::api::Convolution] for the corresponding public API.
/// # Safety
/// See [crate::doc::backend_safety] for the safety contract.
|
||||
pub unsafe trait ConvolutionImpl<BE: Backend> {
|
||||
fn cnv_prepare_left_tmp_bytes_impl(module: &Module<BE>, res_size: usize, a_size: usize) -> usize;
|
||||
fn cnv_prepare_left_impl<R, A>(module: &Module<BE>, res: &mut R, a: &A, scratch: &mut Scratch<BE>)
|
||||
where
|
||||
R: CnvPVecLToMut<BE>,
|
||||
A: VecZnxToRef;
|
||||
fn cnv_prepare_right_tmp_bytes_impl(module: &Module<BE>, res_size: usize, a_size: usize) -> usize;
|
||||
fn cnv_prepare_right_impl<R, A>(module: &Module<BE>, res: &mut R, a: &A, scratch: &mut Scratch<BE>)
|
||||
where
|
||||
R: CnvPVecRToMut<BE>,
|
||||
A: VecZnxToRef + ZnxInfos;
|
||||
fn cnv_apply_dft_tmp_bytes_impl(
|
||||
module: &Module<BE>,
|
||||
res_size: usize,
|
||||
res_offset: usize,
|
||||
a_size: usize,
|
||||
b_size: usize,
|
||||
) -> usize;
|
||||
|
||||
fn cnv_by_const_apply_tmp_bytes_impl(
|
||||
module: &Module<BE>,
|
||||
res_size: usize,
|
||||
res_offset: usize,
|
||||
a_size: usize,
|
||||
b_size: usize,
|
||||
) -> usize;
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn cnv_by_const_apply_impl<R, A>(
|
||||
module: &Module<BE>,
|
||||
res: &mut R,
|
||||
res_offset: usize,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &[i64],
|
||||
scratch: &mut Scratch<BE>,
|
||||
) where
|
||||
R: VecZnxBigToMut<BE>,
|
||||
A: VecZnxToRef;
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn cnv_apply_dft_impl<R, A, B>(
|
||||
module: &Module<BE>,
|
||||
res: &mut R,
|
||||
res_offset: usize,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
scratch: &mut Scratch<BE>,
|
||||
) where
|
||||
R: VecZnxDftToMut<BE>,
|
||||
A: CnvPVecLToRef<BE>,
|
||||
B: CnvPVecRToRef<BE>;
|
||||
fn cnv_pairwise_apply_dft_tmp_bytes(
|
||||
module: &Module<BE>,
|
||||
res_size: usize,
|
||||
res_offset: usize,
|
||||
a_size: usize,
|
||||
b_size: usize,
|
||||
) -> usize;
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn cnv_pairwise_apply_dft_impl<R, A, B>(
|
||||
module: &Module<BE>,
|
||||
res: &mut R,
|
||||
res_offset: usize,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
b: &B,
|
||||
i: usize,
|
||||
j: usize,
|
||||
scratch: &mut Scratch<BE>,
|
||||
) where
|
||||
R: VecZnxDftToMut<BE>,
|
||||
A: CnvPVecLToRef<BE>,
|
||||
B: CnvPVecRToRef<BE>;
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
mod convolution;
|
||||
mod module;
|
||||
mod scratch;
|
||||
mod svp_ppol;
|
||||
@@ -6,6 +7,7 @@ mod vec_znx_big;
|
||||
mod vec_znx_dft;
|
||||
mod vmp_pmat;
|
||||
|
||||
pub use convolution::*;
|
||||
pub use module::*;
|
||||
pub use scratch::*;
|
||||
pub use svp_ppol::*;
|
||||
|
||||
@@ -29,11 +29,12 @@ pub unsafe trait VecZnxNormalizeImpl<B: Backend> {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn vec_znx_normalize_impl<R, A>(
|
||||
module: &Module<B>,
|
||||
res_basek: usize,
|
||||
res: &mut R,
|
||||
res_base2k: usize,
|
||||
res_offset: i64,
|
||||
res_col: usize,
|
||||
a_basek: usize,
|
||||
a: &A,
|
||||
a_base2k: usize,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
|
||||
@@ -34,7 +34,7 @@ pub unsafe trait VecZnxBigFromBytesImpl<B: Backend> {
|
||||
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation.
|
||||
/// * See [crate::api::VecZnxBigAllocBytes] for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxBigAllocBytesImpl<B: Backend> {
|
||||
pub unsafe trait VecZnxBigAllocBytesImpl {
|
||||
fn vec_znx_big_bytes_of_impl(n: usize, cols: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
@@ -46,7 +46,7 @@ pub unsafe trait VecZnxBigAllocBytesImpl<B: Backend> {
|
||||
pub unsafe trait VecZnxBigAddNormalImpl<B: Backend> {
|
||||
fn add_normal_impl<R: VecZnxBigToMut<B>>(
|
||||
module: &Module<B>,
|
||||
res_basek: usize,
|
||||
res_base2k: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
@@ -240,11 +240,12 @@ pub unsafe trait VecZnxBigNormalizeTmpBytesImpl<B: Backend> {
|
||||
pub unsafe trait VecZnxBigNormalizeImpl<B: Backend> {
|
||||
fn vec_znx_big_normalize_impl<R, A>(
|
||||
module: &Module<B>,
|
||||
res_basek: usize,
|
||||
res: &mut R,
|
||||
res_base2k: usize,
|
||||
res_offset: i64,
|
||||
res_col: usize,
|
||||
a_basek: usize,
|
||||
a: &A,
|
||||
a_base2k: usize,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
|
||||
poulpy-hal/src/reference/fft64/convolution.rs (new file, 405 lines)
@@ -0,0 +1,405 @@
|
||||
use crate::{
|
||||
layouts::{
|
||||
Backend, CnvPVecL, CnvPVecLToMut, CnvPVecLToRef, CnvPVecR, CnvPVecRToMut, CnvPVecRToRef, VecZnx, VecZnxBig,
|
||||
VecZnxBigToMut, VecZnxDft, VecZnxDftToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut, ZnxZero,
|
||||
},
|
||||
reference::fft64::{
|
||||
reim::{ReimAdd, ReimCopy, ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimZero},
|
||||
reim4::{
|
||||
Reim4Convolution, Reim4Convolution1Coeff, Reim4Convolution2Coeffs, Reim4Extract1BlkContiguous,
|
||||
Reim4Save1BlkContiguous,
|
||||
},
|
||||
vec_znx_dft::vec_znx_dft_apply,
|
||||
},
|
||||
};
|
||||
|
||||
pub fn convolution_prepare_left<R, A, T, BE>(table: &ReimFFTTable<f64>, res: &mut R, a: &A, tmp: &mut T)
|
||||
where
|
||||
BE: Backend<ScalarPrep = f64>
|
||||
+ ReimZero
|
||||
+ Reim4Extract1BlkContiguous
|
||||
+ ReimDFTExecute<ReimFFTTable<f64>, f64>
|
||||
+ ReimFromZnx
|
||||
+ ReimZero,
|
||||
R: CnvPVecLToMut<BE> + ZnxInfos + ZnxViewMut<Scalar = BE::ScalarPrep>,
|
||||
A: VecZnxToRef,
|
||||
T: VecZnxDftToMut<BE>,
|
||||
{
|
||||
convolution_prepare(table, res, a, tmp)
|
||||
}
|
||||
|
||||
pub fn convolution_prepare_right<R, A, T, BE>(table: &ReimFFTTable<f64>, res: &mut R, a: &A, tmp: &mut T)
|
||||
where
|
||||
BE: Backend<ScalarPrep = f64>
|
||||
+ ReimZero
|
||||
+ Reim4Extract1BlkContiguous
|
||||
+ ReimDFTExecute<ReimFFTTable<f64>, f64>
|
||||
+ ReimFromZnx
|
||||
+ ReimZero,
|
||||
R: CnvPVecRToMut<BE> + ZnxInfos + ZnxViewMut<Scalar = BE::ScalarPrep>,
|
||||
A: VecZnxToRef,
|
||||
T: VecZnxDftToMut<BE>,
|
||||
{
|
||||
convolution_prepare(table, res, a, tmp)
|
||||
}
|
||||
|
||||
fn convolution_prepare<R, A, T, BE>(table: &ReimFFTTable<f64>, res: &mut R, a: &A, tmp: &mut T)
|
||||
where
|
||||
BE: Backend<ScalarPrep = f64>
|
||||
+ ReimZero
|
||||
+ Reim4Extract1BlkContiguous
|
||||
+ ReimDFTExecute<ReimFFTTable<f64>, f64>
|
||||
+ ReimFromZnx
|
||||
+ ReimZero,
|
||||
R: ZnxInfos + ZnxViewMut<Scalar = BE::ScalarPrep>,
|
||||
A: VecZnxToRef,
|
||||
T: VecZnxDftToMut<BE>,
|
||||
{
|
||||
let a: &VecZnx<&[u8]> = &a.to_ref();
|
||||
let tmp: &mut VecZnxDft<&mut [u8], BE> = &mut tmp.to_mut();
|
||||
|
||||
let cols: usize = res.cols();
|
||||
assert_eq!(a.cols(), cols, "a.cols():{} != res.cols():{cols}", a.cols());
|
||||
|
||||
let res_size: usize = res.size();
|
||||
let min_size: usize = res_size.min(a.size());
|
||||
|
||||
let m: usize = a.n() >> 1;
|
||||
|
||||
let n: usize = table.m() << 1;
|
||||
|
||||
let res_raw: &mut [f64] = res.raw_mut();
|
||||
|
||||
for i in 0..cols {
|
||||
vec_znx_dft_apply(table, 1, 0, tmp, 0, a, i);
|
||||
|
||||
let tmp_raw: &[f64] = tmp.raw();
|
||||
let res_col: &mut [f64] = &mut res_raw[i * n * res_size..];
|
||||
|
||||
for blk_i in 0..m / 4 {
|
||||
BE::reim4_extract_1blk_contiguous(m, min_size, blk_i, &mut res_col[blk_i * res_size * 8..], tmp_raw);
|
||||
BE::reim_zero(&mut res_col[blk_i * res_size * 8 + min_size * 8..(blk_i + 1) * res_size * 8]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn convolution_by_const_apply_tmp_bytes(res_size: usize, a_size: usize, b_size: usize) -> usize {
|
||||
let min_size: usize = res_size.min(a_size + b_size - 1);
|
||||
size_of::<i64>() * (min_size + a_size) * 8
|
||||
}
|
||||
|
||||
pub fn convolution_by_const_apply<R, A, BE>(
|
||||
res: &mut R,
|
||||
res_offset: usize,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &[i64],
|
||||
tmp: &mut [i64],
|
||||
) where
|
||||
BE: Backend<ScalarBig = i64>
|
||||
+ I64ConvolutionByConst1Coeff
|
||||
+ I64ConvolutionByConst2Coeffs
|
||||
+ I64Extract1BlkContiguous
|
||||
+ I64Save1BlkContiguous,
|
||||
R: VecZnxBigToMut<BE>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let res: &mut VecZnxBig<&mut [u8], BE> = &mut res.to_mut();
|
||||
let a: &VecZnx<&[u8]> = &a.to_ref();
|
||||
|
||||
let n: usize = res.n();
|
||||
assert_eq!(a.n(), n);
|
||||
|
||||
let res_size: usize = res.size();
|
||||
let a_size: usize = a.size();
|
||||
let b_size: usize = b.len();
|
||||
|
||||
let bound: usize = a_size + b_size - 1;
|
||||
let min_size: usize = res_size.min(bound);
|
||||
let offset: usize = res_offset.min(bound);
|
||||
|
||||
let a_sl: usize = n * a.cols();
|
||||
let res_sl: usize = n * res.cols();
|
||||
|
||||
let res_raw: &mut [i64] = res.raw_mut();
|
||||
let a_raw: &[i64] = a.raw();
|
||||
|
||||
let a_idx: usize = n * a_col;
|
||||
let res_idx: usize = n * res_col;
|
||||
|
||||
let (res_blk, a_blk) = tmp[..(min_size + a_size) * 8].split_at_mut(min_size * 8);
|
||||
|
||||
for blk_i in 0..n / 8 {
|
||||
BE::i64_extract_1blk_contiguous(a_sl, a_idx, a_size, blk_i, a_blk, a_raw);
|
||||
BE::i64_convolution_by_const(res_blk, min_size, offset, a_blk, a_size, b);
|
||||
BE::i64_save_1blk_contiguous(res_sl, res_idx, min_size, blk_i, res_raw, res_blk);
|
||||
}
|
||||
|
||||
for j in min_size..res_size {
|
||||
res.zero_at(res_col, j);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn convolution_apply_dft_tmp_bytes(res_size: usize, a_size: usize, b_size: usize) -> usize {
|
||||
let min_size: usize = res_size.min(a_size + b_size - 1);
|
||||
size_of::<f64>() * 8 * min_size
|
||||
}
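For intuition on the sizing above: the temporary buffer holds one Reim4 block of 8 f64 lanes per retained output limb. A small standalone restatement (not library code):

```rust
// Restates convolution_apply_dft_tmp_bytes: 8 f64 lanes per retained output limb.
fn tmp_bytes(res_size: usize, a_size: usize, b_size: usize) -> usize {
    let min_size = res_size.min(a_size + b_size - 1);
    std::mem::size_of::<f64>() * 8 * min_size
}

fn main() {
    // a_size = b_size = 4 limbs, with room for all 7 output limbs: 8 bytes * 8 lanes * 7 limbs.
    assert_eq!(tmp_bytes(7, 4, 4), 448);
}
```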
#[allow(clippy::too_many_arguments)]
|
||||
pub fn convolution_apply_dft<R, A, B, BE>(
|
||||
res: &mut R,
|
||||
res_offset: usize,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
tmp: &mut [f64],
|
||||
) where
|
||||
BE: Backend<ScalarPrep = f64> + Reim4Save1BlkContiguous + Reim4Convolution1Coeff + Reim4Convolution2Coeffs,
|
||||
R: VecZnxDftToMut<BE>,
|
||||
A: CnvPVecLToRef<BE>,
|
||||
B: CnvPVecRToRef<BE>,
|
||||
{
|
||||
let res: &mut VecZnxDft<&mut [u8], BE> = &mut res.to_mut();
|
||||
let a: &CnvPVecL<&[u8], BE> = &a.to_ref();
|
||||
let b: &CnvPVecR<&[u8], BE> = &b.to_ref();
|
||||
|
||||
let n: usize = res.n();
|
||||
assert_eq!(a.n(), n);
|
||||
assert_eq!(b.n(), n);
|
||||
let m: usize = n >> 1;
|
||||
|
||||
let res_size: usize = res.size();
|
||||
let a_size: usize = a.size();
|
||||
let b_size: usize = b.size();
|
||||
|
||||
let bound: usize = a_size + b_size - 1;
|
||||
let min_size: usize = res_size.min(bound);
|
||||
let offset: usize = res_offset.min(bound);
|
||||
|
||||
let dst: &mut [f64] = res.raw_mut();
|
||||
let a_raw: &[f64] = a.raw();
|
||||
let b_raw: &[f64] = b.raw();
|
||||
|
||||
let mut a_idx: usize = a_col * n * a_size;
|
||||
let mut b_idx: usize = b_col * n * b_size;
|
||||
let a_offset: usize = a_size * 8;
|
||||
let b_offset: usize = b_size * 8;
|
||||
for blk_i in 0..m / 4 {
|
||||
BE::reim4_convolution(tmp, min_size, offset, &a_raw[a_idx..], a_size, &b_raw[b_idx..], b_size);
|
||||
BE::reim4_save_1blk_contiguous(m, min_size, blk_i, dst, tmp);
|
||||
a_idx += a_offset;
|
||||
b_idx += b_offset;
|
||||
}
|
||||
|
||||
for j in min_size..res_size {
|
||||
res.zero_at(res_col, j);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn convolution_pairwise_apply_dft_tmp_bytes(res_size: usize, a_size: usize, b_size: usize) -> usize {
|
||||
convolution_apply_dft_tmp_bytes(res_size, a_size, b_size) + (a_size + b_size) * size_of::<f64>() * 8
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn convolution_pairwise_apply_dft<R, A, B, BE>(
|
||||
res: &mut R,
|
||||
res_offset: usize,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
b: &B,
|
||||
col_i: usize,
|
||||
col_j: usize,
|
||||
tmp: &mut [f64],
|
||||
) where
|
||||
BE: Backend<ScalarPrep = f64>
|
||||
+ ReimZero
|
||||
+ ReimAdd
|
||||
+ ReimCopy
|
||||
+ Reim4Save1BlkContiguous
|
||||
+ Reim4Convolution1Coeff
|
||||
+ Reim4Convolution2Coeffs,
|
||||
R: VecZnxDftToMut<BE>,
|
||||
A: CnvPVecLToRef<BE>,
|
||||
B: CnvPVecRToRef<BE>,
|
||||
{
|
||||
if col_i == col_j {
|
||||
convolution_apply_dft(res, res_offset, res_col, a, col_i, b, col_j, tmp);
|
||||
return;
|
||||
}
|
||||
|
||||
let res: &mut VecZnxDft<&mut [u8], BE> = &mut res.to_mut();
|
||||
let a: &CnvPVecL<&[u8], BE> = &a.to_ref();
|
||||
let b: &CnvPVecR<&[u8], BE> = &b.to_ref();
|
||||
|
||||
let n: usize = res.n();
|
||||
let m: usize = n >> 1;
|
||||
|
||||
assert_eq!(a.n(), n);
|
||||
assert_eq!(b.n(), n);
|
||||
|
||||
let res_size: usize = res.size();
|
||||
let a_size: usize = a.size();
|
||||
let b_size: usize = b.size();
|
||||
|
||||
assert_eq!(
|
||||
tmp.len(),
|
||||
convolution_pairwise_apply_dft_tmp_bytes(res_size, a_size, b_size) / size_of::<f64>()
|
||||
);
|
||||
|
||||
let bound: usize = a_size + b_size - 1;
|
||||
let min_size: usize = res_size.min(bound);
|
||||
let offset: usize = res_offset.min(bound);
|
||||
|
||||
let res_raw: &mut [f64] = res.raw_mut();
|
||||
let a_raw: &[f64] = a.raw();
|
||||
let b_raw: &[f64] = b.raw();
|
||||
|
||||
let a_row_size: usize = a_size * 8;
|
||||
let b_row_size: usize = b_size * 8;
|
||||
|
||||
let mut a0_idx: usize = col_i * n * a_size;
|
||||
let mut a1_idx: usize = col_j * n * a_size;
|
||||
let mut b0_idx: usize = col_i * n * b_size;
|
||||
let mut b1_idx: usize = col_j * n * b_size;
|
||||
|
||||
let (tmp_a, tmp) = tmp.split_at_mut(a_row_size);
|
||||
let (tmp_b, tmp_res) = tmp.split_at_mut(b_row_size);
|
||||
|
||||
for blk_i in 0..m / 4 {
|
||||
let a0: &[f64] = &a_raw[a0_idx..];
|
||||
let a1: &[f64] = &a_raw[a1_idx..];
|
||||
let b0: &[f64] = &b_raw[b0_idx..];
|
||||
let b1: &[f64] = &b_raw[b1_idx..];
|
||||
|
||||
BE::reim_add(tmp_a, &a0[..a_row_size], &a1[..a_row_size]);
|
||||
BE::reim_add(tmp_b, &b0[..b_row_size], &b1[..b_row_size]);
|
||||
|
||||
BE::reim4_convolution(tmp_res, min_size, offset, tmp_a, a_size, tmp_b, b_size);
|
||||
BE::reim4_save_1blk_contiguous(m, min_size, blk_i, res_raw, tmp_res);
|
||||
|
||||
a0_idx += a_row_size;
|
||||
a1_idx += a_row_size;
|
||||
b0_idx += b_row_size;
|
||||
b1_idx += b_row_size;
|
||||
}
|
||||
|
||||
for j in min_size..res_size {
|
||||
res.zero_at(res_col, j);
|
||||
}
|
||||
}
|
||||
|
||||
pub trait I64Extract1BlkContiguous {
|
||||
fn i64_extract_1blk_contiguous(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]);
|
||||
}
|
||||
|
||||
pub trait I64Save1BlkContiguous {
|
||||
fn i64_save_1blk_contiguous(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]);
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn i64_extract_1blk_contiguous_ref(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) {
|
||||
debug_assert!(blk < (n >> 3));
|
||||
debug_assert!(dst.len() >= rows * 8, "dst.len(): {} < rows*8: {}", dst.len(), 8 * rows);
|
||||
|
||||
let offset: usize = offset + (blk << 3);
|
||||
|
||||
// src = 8-value chunks spaced by n, dst = sequential 8-value chunks
|
||||
let src_rows = src.chunks_exact(n).take(rows);
|
||||
let dst_chunks = dst.chunks_exact_mut(8).take(rows);
|
||||
|
||||
for (dst_chunk, src_row) in dst_chunks.zip(src_rows) {
|
||||
dst_chunk.copy_from_slice(&src_row[offset..offset + 8]);
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn i64_save_1blk_contiguous_ref(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) {
|
||||
debug_assert!(blk < (n >> 3));
|
||||
debug_assert!(src.len() >= rows * 8);
|
||||
|
||||
let offset: usize = offset + (blk << 3);
|
||||
|
||||
// dst = 8-value chunks spaced by n, src = sequential 8-value chunks
|
||||
let dst_rows = dst.chunks_exact_mut(n).take(rows);
|
||||
let src_chunks = src.chunks_exact(8).take(rows);
|
||||
|
||||
for (dst_row, src_chunk) in dst_rows.zip(src_chunks) {
|
||||
dst_row[offset..offset + 8].copy_from_slice(src_chunk);
|
||||
}
|
||||
}
|
||||
|
||||
pub trait I64ConvolutionByConst1Coeff {
|
||||
fn i64_convolution_by_const_1coeff(k: usize, dst: &mut [i64; 8], a: &[i64], a_size: usize, b: &[i64]);
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn i64_convolution_by_const_1coeff_ref(k: usize, dst: &mut [i64; 8], a: &[i64], a_size: usize, b: &[i64]) {
|
||||
dst.fill(0);
|
||||
|
||||
let b_size: usize = b.len();
|
||||
|
||||
if k >= a_size + b_size {
|
||||
return;
|
||||
}
|
||||
let j_min: usize = k.saturating_sub(a_size - 1);
|
||||
let j_max: usize = (k + 1).min(b_size);
|
||||
|
||||
for j in j_min..j_max {
|
||||
let ai: &[i64] = &a[8 * (k - j)..];
|
||||
let bi: i64 = b[j];
|
||||
|
||||
dst[0] = dst[0].wrapping_add(ai[0].wrapping_mul(bi));
|
||||
dst[1] = dst[1].wrapping_add(ai[1].wrapping_mul(bi));
|
||||
dst[2] = dst[2].wrapping_add(ai[2].wrapping_mul(bi));
|
||||
dst[3] = dst[3].wrapping_add(ai[3].wrapping_mul(bi));
|
||||
dst[4] = dst[4].wrapping_add(ai[4].wrapping_mul(bi));
|
||||
dst[5] = dst[5].wrapping_add(ai[5].wrapping_mul(bi));
|
||||
dst[6] = dst[6].wrapping_add(ai[6].wrapping_mul(bi));
|
||||
dst[7] = dst[7].wrapping_add(ai[7].wrapping_mul(bi));
|
||||
}
|
||||
}
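// Illustrative reference (not part of the patch): the kernel above computes, for
// output coefficient `k` and lane `l`, the diagonal sum
// dst[l] = sum_j a[8*(k - j) + l] * b[j], i.e. a plain (acyclic) convolution over
// the limb index, evaluated lane-wise on blocks of 8 contiguous lanes. The
// hypothetical helper below spells out that semantics.
#[cfg(test)]
fn i64_schoolbook_lanewise(a: &[i64], a_size: usize, b: &[i64]) -> Vec<[i64; 8]> {
    let b_size: usize = b.len();
    let mut out: Vec<[i64; 8]> = vec![[0i64; 8]; a_size + b_size - 1];
    for i in 0..a_size {
        for (j, &bj) in b.iter().enumerate() {
            for l in 0..8 {
                out[i + j][l] = out[i + j][l].wrapping_add(a[8 * i + l].wrapping_mul(bj));
            }
        }
    }
    // For every k < a_size + b_size - 1, `i64_convolution_by_const_1coeff_ref(k, ..)`
    // is expected to reproduce out[k]; for larger k it writes zeros.
    out
}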
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn as_arr_i64<const size: usize>(x: &[i64]) -> &[i64; size] {
|
||||
debug_assert!(x.len() >= size, "x.len():{} < size:{}", x.len(), size);
|
||||
unsafe { &*(x.as_ptr() as *const [i64; size]) }
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn as_arr_i64_mut<const size: usize>(x: &mut [i64]) -> &mut [i64; size] {
|
||||
debug_assert!(x.len() >= size, "x.len():{} < size:{}", x.len(), size);
|
||||
unsafe { &mut *(x.as_mut_ptr() as *mut [i64; size]) }
|
||||
}
|
||||
|
||||
pub trait I64ConvolutionByConst2Coeffs {
|
||||
fn i64_convolution_by_const_2coeffs(k: usize, dst: &mut [i64; 16], a: &[i64], a_size: usize, b: &[i64]);
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn i64_convolution_by_const_2coeffs_ref(k: usize, dst: &mut [i64; 16], a: &[i64], a_size: usize, b: &[i64]) {
|
||||
i64_convolution_by_const_1coeff_ref(k, as_arr_i64_mut(&mut dst[..8]), a, a_size, b);
|
||||
i64_convolution_by_const_1coeff_ref(k + 1, as_arr_i64_mut(&mut dst[8..]), a, a_size, b);
|
||||
}
|
||||
|
||||
impl<BE: Backend> I64ConvolutionByConst<BE> for BE where Self: I64ConvolutionByConst1Coeff + I64ConvolutionByConst2Coeffs {}
|
||||
|
||||
pub trait I64ConvolutionByConst<BE: Backend>
|
||||
where
|
||||
BE: I64ConvolutionByConst1Coeff + I64ConvolutionByConst2Coeffs,
|
||||
{
|
||||
fn i64_convolution_by_const(dst: &mut [i64], dst_size: usize, offset: usize, a: &[i64], a_size: usize, b: &[i64]) {
|
||||
assert!(a_size > 0);
|
||||
|
||||
for k in (0..dst_size - 1).step_by(2) {
|
||||
BE::i64_convolution_by_const_2coeffs(k + offset, as_arr_i64_mut(&mut dst[8 * k..]), a, a_size, b);
|
||||
}
|
||||
|
||||
if !dst_size.is_multiple_of(2) {
|
||||
let k: usize = dst_size - 1;
|
||||
BE::i64_convolution_by_const_1coeff(k + offset, as_arr_i64_mut(&mut dst[8 * k..]), a, a_size, b);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod convolution;
|
||||
pub mod reim;
|
||||
pub mod reim4;
|
||||
pub mod svp;
|
||||
|
||||
@@ -12,26 +12,10 @@ pub fn fft_ref<R: Float + FloatConst + Debug>(m: usize, omg: &[R], data: &mut [R
|
||||
if m <= 16 {
|
||||
match m {
|
||||
1 => {}
|
||||
2 => fft2_ref(
|
||||
as_arr_mut::<2, R>(re),
|
||||
as_arr_mut::<2, R>(im),
|
||||
*as_arr::<2, R>(omg),
|
||||
),
|
||||
4 => fft4_ref(
|
||||
as_arr_mut::<4, R>(re),
|
||||
as_arr_mut::<4, R>(im),
|
||||
*as_arr::<4, R>(omg),
|
||||
),
|
||||
8 => fft8_ref(
|
||||
as_arr_mut::<8, R>(re),
|
||||
as_arr_mut::<8, R>(im),
|
||||
*as_arr::<8, R>(omg),
|
||||
),
|
||||
16 => fft16_ref(
|
||||
as_arr_mut::<16, R>(re),
|
||||
as_arr_mut::<16, R>(im),
|
||||
*as_arr::<16, R>(omg),
|
||||
),
|
||||
2 => fft2_ref(as_arr_mut::<2, R>(re), as_arr_mut::<2, R>(im), *as_arr::<2, R>(omg)),
|
||||
4 => fft4_ref(as_arr_mut::<4, R>(re), as_arr_mut::<4, R>(im), *as_arr::<4, R>(omg)),
|
||||
8 => fft8_ref(as_arr_mut::<8, R>(re), as_arr_mut::<8, R>(im), *as_arr::<8, R>(omg)),
|
||||
16 => fft16_ref(as_arr_mut::<16, R>(re), as_arr_mut::<16, R>(im), *as_arr::<16, R>(omg)),
|
||||
_ => {}
|
||||
}
|
||||
} else if m <= 2048 {
|
||||
@@ -257,12 +241,7 @@ fn fft_bfs_16_ref<R: Float + FloatConst + Debug>(m: usize, re: &mut [R], im: &mu
|
||||
while mm > 16 {
|
||||
let h: usize = mm >> 2;
|
||||
for off in (0..m).step_by(mm) {
|
||||
bitwiddle_fft_ref(
|
||||
h,
|
||||
&mut re[off..],
|
||||
&mut im[off..],
|
||||
as_arr::<4, R>(&omg[pos..]),
|
||||
);
|
||||
bitwiddle_fft_ref(h, &mut re[off..], &mut im[off..], as_arr::<4, R>(&omg[pos..]));
|
||||
pos += 4;
|
||||
}
|
||||
mm = h
|
||||
@@ -289,14 +268,7 @@ fn twiddle_fft_ref<R: Float + FloatConst>(h: usize, re: &mut [R], im: &mut [R],
|
||||
let (im_lhs, im_rhs) = im.split_at_mut(h);
|
||||
|
||||
for i in 0..h {
|
||||
cplx_twiddle(
|
||||
&mut re_lhs[i],
|
||||
&mut im_lhs[i],
|
||||
&mut re_rhs[i],
|
||||
&mut im_rhs[i],
|
||||
romg,
|
||||
iomg,
|
||||
);
|
||||
cplx_twiddle(&mut re_lhs[i], &mut im_lhs[i], &mut re_rhs[i], &mut im_rhs[i], romg, iomg);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -11,26 +11,10 @@ pub fn ifft_ref<R: Float + FloatConst + Debug>(m: usize, omg: &[R], data: &mut [
|
||||
if m <= 16 {
|
||||
match m {
|
||||
1 => {}
|
||||
2 => ifft2_ref(
|
||||
as_arr_mut::<2, R>(re),
|
||||
as_arr_mut::<2, R>(im),
|
||||
*as_arr::<2, R>(omg),
|
||||
),
|
||||
4 => ifft4_ref(
|
||||
as_arr_mut::<4, R>(re),
|
||||
as_arr_mut::<4, R>(im),
|
||||
*as_arr::<4, R>(omg),
|
||||
),
|
||||
8 => ifft8_ref(
|
||||
as_arr_mut::<8, R>(re),
|
||||
as_arr_mut::<8, R>(im),
|
||||
*as_arr::<8, R>(omg),
|
||||
),
|
||||
16 => ifft16_ref(
|
||||
as_arr_mut::<16, R>(re),
|
||||
as_arr_mut::<16, R>(im),
|
||||
*as_arr::<16, R>(omg),
|
||||
),
|
||||
2 => ifft2_ref(as_arr_mut::<2, R>(re), as_arr_mut::<2, R>(im), *as_arr::<2, R>(omg)),
|
||||
4 => ifft4_ref(as_arr_mut::<4, R>(re), as_arr_mut::<4, R>(im), *as_arr::<4, R>(omg)),
|
||||
8 => ifft8_ref(as_arr_mut::<8, R>(re), as_arr_mut::<8, R>(im), *as_arr::<8, R>(omg)),
|
||||
16 => ifft16_ref(as_arr_mut::<16, R>(re), as_arr_mut::<16, R>(im), *as_arr::<16, R>(omg)),
|
||||
_ => {}
|
||||
}
|
||||
} else if m <= 2048 {
|
||||
@@ -72,12 +56,7 @@ fn ifft_bfs_16_ref<R: Float + FloatConst>(m: usize, re: &mut [R], im: &mut [R],
|
||||
while h < m_half {
|
||||
let mm: usize = h << 2;
|
||||
for off in (0..m).step_by(mm) {
|
||||
inv_bitwiddle_ifft_ref(
|
||||
h,
|
||||
&mut re[off..],
|
||||
&mut im[off..],
|
||||
as_arr::<4, R>(&omg[pos..]),
|
||||
);
|
||||
inv_bitwiddle_ifft_ref(h, &mut re[off..], &mut im[off..], as_arr::<4, R>(&omg[pos..]));
|
||||
pos += 4;
|
||||
}
|
||||
h = mm;
|
||||
@@ -284,14 +263,7 @@ fn inv_twiddle_ifft_ref<R: Float + FloatConst>(h: usize, re: &mut [R], im: &mut
|
||||
let (im_lhs, im_rhs) = im.split_at_mut(h);
|
||||
|
||||
for i in 0..h {
|
||||
inv_twiddle(
|
||||
&mut re_lhs[i],
|
||||
&mut im_lhs[i],
|
||||
&mut re_rhs[i],
|
||||
&mut im_rhs[i],
|
||||
romg,
|
||||
iomg,
|
||||
);
|
||||
inv_twiddle(&mut re_lhs[i], &mut im_lhs[i], &mut re_rhs[i], &mut im_rhs[i], romg, iomg);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ pub use zero::*;
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn as_arr<const size: usize, R: Float + FloatConst>(x: &[R]) -> &[R; size] {
|
||||
debug_assert!(x.len() >= size);
|
||||
debug_assert!(x.len() >= size, "x.len():{} < size:{}", x.len(), size);
|
||||
unsafe { &*(x.as_ptr() as *const [R; size]) }
|
||||
}
|
||||
|
||||
|
||||
@@ -1,15 +1,34 @@
|
||||
use crate::reference::fft64::reim::as_arr;
|
||||
use crate::reference::fft64::reim::{as_arr, as_arr_mut, reim_zero_ref};
|
||||
|
||||
#[inline(always)]
|
||||
pub fn reim4_extract_1blk_from_reim_ref(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
|
||||
let mut offset: usize = blk << 2;
|
||||
|
||||
pub fn reim4_extract_1blk_from_reim_contiguous_ref(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
|
||||
debug_assert!(blk < (m >> 2));
|
||||
debug_assert!(dst.len() >= 2 * rows * 4);
|
||||
|
||||
for chunk in dst.chunks_exact_mut(4).take(2 * rows) {
|
||||
chunk.copy_from_slice(&src[offset..offset + 4]);
|
||||
offset += m
|
||||
let offset: usize = blk << 2;
|
||||
|
||||
// src = 4-value chunks spaced by m, dst = sequential 4-value chunks
|
||||
let src_rows = src.chunks_exact(m).take(2 * rows);
|
||||
let dst_chunks = dst.chunks_exact_mut(4).take(2 * rows);
|
||||
|
||||
for (dst_chunk, src_row) in dst_chunks.zip(src_rows) {
|
||||
dst_chunk.copy_from_slice(&src_row[offset..offset + 4]);
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn reim4_save_1blk_to_reim_contiguous_ref(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
|
||||
debug_assert!(blk < (m >> 2));
|
||||
debug_assert!(src.len() >= 2 * rows * 4);
|
||||
|
||||
let offset: usize = blk << 2;
|
||||
|
||||
// dst = 4-value chunks spaced by m, src = sequential 4-value chunks
|
||||
let dst_rows = dst.chunks_exact_mut(m).take(2 * rows);
|
||||
let src_chunks = src.chunks_exact(4).take(2 * rows);
|
||||
|
||||
for (dst_row, src_chunk) in dst_rows.zip(src_chunks) {
|
||||
dst_row[offset..offset + 4].copy_from_slice(src_chunk);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,7 +72,7 @@ pub fn reim4_save_2blk_to_reim_ref<const OVERWRITE: bool>(m: usize, blk: usize,
|
||||
debug_assert!(dst.len() >= offset + 3 * m + 4);
|
||||
debug_assert!(src.len() >= 16);
|
||||
|
||||
let dst_off = &mut dst[offset..offset + 4];
|
||||
let dst_off: &mut [f64] = &mut dst[offset..offset + 4];
|
||||
if OVERWRITE {
|
||||
dst_off.copy_from_slice(&src[0..4]);
|
||||
} else {
|
||||
@@ -64,7 +83,7 @@ pub fn reim4_save_2blk_to_reim_ref<const OVERWRITE: bool>(m: usize, blk: usize,
|
||||
}
|
||||
|
||||
offset += m;
|
||||
let dst_off = &mut dst[offset..offset + 4];
|
||||
let dst_off: &mut [f64] = &mut dst[offset..offset + 4];
|
||||
if OVERWRITE {
|
||||
dst_off.copy_from_slice(&src[4..8]);
|
||||
} else {
|
||||
@@ -76,7 +95,7 @@ pub fn reim4_save_2blk_to_reim_ref<const OVERWRITE: bool>(m: usize, blk: usize,
|
||||
|
||||
offset += m;
|
||||
|
||||
let dst_off = &mut dst[offset..offset + 4];
|
||||
let dst_off: &mut [f64] = &mut dst[offset..offset + 4];
|
||||
if OVERWRITE {
|
||||
dst_off.copy_from_slice(&src[8..12]);
|
||||
} else {
|
||||
@@ -87,7 +106,7 @@ pub fn reim4_save_2blk_to_reim_ref<const OVERWRITE: bool>(m: usize, blk: usize,
|
||||
}
|
||||
|
||||
offset += m;
|
||||
let dst_off = &mut dst[offset..offset + 4];
|
||||
let dst_off: &mut [f64] = &mut dst[offset..offset + 4];
|
||||
if OVERWRITE {
|
||||
dst_off.copy_from_slice(&src[12..16]);
|
||||
} else {
|
||||
@@ -132,10 +151,7 @@ pub fn reim4_vec_mat2cols_product_ref(
|
||||
{
|
||||
assert_eq!(dst.len(), 16, "dst must have 16 doubles");
|
||||
assert!(u.len() >= nrows * 8, "u must be at least nrows * 8 doubles");
|
||||
assert!(
|
||||
v.len() >= nrows * 16,
|
||||
"v must be at least nrows * 16 doubles"
|
||||
);
|
||||
assert!(v.len() >= nrows * 16, "v must be at least nrows * 16 doubles");
|
||||
}
|
||||
|
||||
// zero accumulators
|
||||
@@ -161,11 +177,7 @@ pub fn reim4_vec_mat2cols_2ndcol_product_ref(
|
||||
) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(
|
||||
dst.len() >= 8,
|
||||
"dst must be at least 8 doubles but is {}",
|
||||
dst.len()
|
||||
);
|
||||
assert!(dst.len() >= 8, "dst must be at least 8 doubles but is {}", dst.len());
|
||||
assert!(
|
||||
u.len() >= nrows * 8,
|
||||
"u must be at least nrows={} * 8 doubles but is {}",
|
||||
@@ -201,3 +213,57 @@ pub fn reim4_add_mul(dst: &mut [f64; 8], a: &[f64; 8], b: &[f64; 8]) {
|
||||
dst[k + 4] += ar * bi + ai * br;
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn reim4_convolution_1coeff_ref(k: usize, dst: &mut [f64; 8], a: &[f64], a_size: usize, b: &[f64], b_size: usize) {
|
||||
reim_zero_ref(dst);
|
||||
|
||||
if k >= a_size + b_size {
|
||||
return;
|
||||
}
|
||||
let j_min: usize = k.saturating_sub(a_size - 1);
|
||||
let j_max: usize = (k + 1).min(b_size);
|
||||
|
||||
for j in j_min..j_max {
|
||||
reim4_add_mul(dst, as_arr(&a[8 * (k - j)..]), as_arr(&b[8 * j..]));
|
||||
}
|
||||
}
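// Illustrative note (not part of the patch): a reim4 block packs 4 complex values
// as [re0, re1, re2, re3, im0, im1, im2, im3], and `reim4_add_mul` performs the
// complex multiply-accumulate on each of the 4 lanes. A hypothetical single-lane
// scalar equivalent:
#[cfg(test)]
fn complex_mac(acc: (f64, f64), a: (f64, f64), b: (f64, f64)) -> (f64, f64) {
    // (ar + i*ai) * (br + i*bi) accumulated onto acc = (acc_re, acc_im)
    (acc.0 + a.0 * b.0 - a.1 * b.1, acc.1 + a.0 * b.1 + a.1 * b.0)
}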
|
||||
|
||||
#[inline(always)]
|
||||
pub fn reim4_convolution_2coeffs_ref(k: usize, dst: &mut [f64; 16], a: &[f64], a_size: usize, b: &[f64], b_size: usize) {
|
||||
reim4_convolution_1coeff_ref(k, as_arr_mut(dst), a, a_size, b, b_size);
|
||||
reim4_convolution_1coeff_ref(k + 1, as_arr_mut(&mut dst[8..]), a, a_size, b, b_size);
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn reim4_add_mul_b_real_const(dst: &mut [f64; 8], a: &[f64; 8], b: f64) {
|
||||
for k in 0..4 {
|
||||
let ar: f64 = a[k];
|
||||
let ai: f64 = a[k + 4];
|
||||
dst[k] += ar * b;
|
||||
dst[k + 4] += ai * b;
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn reim4_convolution_by_real_const_1coeff_ref(k: usize, dst: &mut [f64; 8], a: &[f64], a_size: usize, b: &[f64]) {
|
||||
reim_zero_ref(dst);
|
||||
|
||||
let b_size: usize = b.len();
|
||||
|
||||
if k >= a_size + b_size {
|
||||
return;
|
||||
}
|
||||
let j_min: usize = k.saturating_sub(a_size - 1);
|
||||
let j_max: usize = (k + 1).min(b_size);
|
||||
|
||||
for j in j_min..j_max {
|
||||
reim4_add_mul_b_real_const(dst, as_arr(&a[8 * (k - j)..]), b[j]);
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn reim4_convolution_by_real_const_2coeffs_ref(k: usize, dst: &mut [f64; 16], a: &[f64], a_size: usize, b: &[f64]) {
|
||||
reim4_convolution_by_real_const_1coeff_ref(k, as_arr_mut(dst), a, a_size, b);
|
||||
reim4_convolution_by_real_const_1coeff_ref(k + 1, as_arr_mut(&mut dst[8..]), a, a_size, b);
|
||||
}
|
||||
|
||||
@@ -2,8 +2,14 @@ mod arithmetic_ref;
|
||||
|
||||
pub use arithmetic_ref::*;
|
||||
|
||||
pub trait Reim4Extract1Blk {
|
||||
fn reim4_extract_1blk(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]);
|
||||
use crate::{layouts::Backend, reference::fft64::reim::as_arr_mut};
|
||||
|
||||
pub trait Reim4Extract1BlkContiguous {
|
||||
fn reim4_extract_1blk_contiguous(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]);
|
||||
}
|
||||
|
||||
pub trait Reim4Save1BlkContiguous {
|
||||
fn reim4_save_1blk_contiguous(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]);
|
||||
}
|
||||
|
||||
pub trait Reim4Save1Blk {
|
||||
@@ -25,3 +31,63 @@ pub trait Reim4Mat2ColsProd {
|
||||
pub trait Reim4Mat2Cols2ndColProd {
|
||||
fn reim4_mat2cols_2ndcol_prod(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]);
|
||||
}
|
||||
|
||||
pub trait Reim4Convolution1Coeff {
|
||||
fn reim4_convolution_1coeff(k: usize, dst: &mut [f64; 8], a: &[f64], a_size: usize, b: &[f64], b_size: usize);
|
||||
}
|
||||
|
||||
pub trait Reim4Convolution2Coeffs {
|
||||
fn reim4_convolution_2coeffs(k: usize, dst: &mut [f64; 16], a: &[f64], a_size: usize, b: &[f64], b_size: usize);
|
||||
}
|
||||
|
||||
pub trait Reim4ConvolutionByRealConst1Coeff {
|
||||
fn reim4_convolution_by_real_const_1coeff(k: usize, dst: &mut [f64; 8], a: &[f64], a_size: usize, b: &[f64]);
|
||||
}
|
||||
|
||||
pub trait Reim4ConvolutionByRealConst2Coeffs {
|
||||
fn reim4_convolution_by_real_const_2coeffs(k: usize, dst: &mut [f64; 16], a: &[f64], a_size: usize, b: &[f64]);
|
||||
}
|
||||
|
||||
impl<BE: Backend> Reim4Convolution<BE> for BE where Self: Reim4Convolution1Coeff + Reim4Convolution2Coeffs {}
|
||||
|
||||
pub trait Reim4Convolution<BE: Backend>
|
||||
where
|
||||
BE: Reim4Convolution1Coeff + Reim4Convolution2Coeffs,
|
||||
{
|
||||
fn reim4_convolution(dst: &mut [f64], dst_size: usize, offset: usize, a: &[f64], a_size: usize, b: &[f64], b_size: usize) {
|
||||
assert!(a_size > 0);
|
||||
assert!(b_size > 0);
|
||||
|
||||
for k in (0..dst_size - 1).step_by(2) {
|
||||
BE::reim4_convolution_2coeffs(k + offset, as_arr_mut(&mut dst[8 * k..]), a, a_size, b, b_size);
|
||||
}
|
||||
|
||||
if !dst_size.is_multiple_of(2) {
|
||||
let k: usize = dst_size - 1;
|
||||
BE::reim4_convolution_1coeff(k + offset, as_arr_mut(&mut dst[8 * k..]), a, a_size, b, b_size);
|
||||
}
|
||||
}
|
||||
}
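// Layout note (illustrative): `dst` is assumed to hold at least 8 * dst_size
// doubles; the block dst[8*k .. 8*k + 8] receives the reim4 coefficient of degree
// k + offset of a * b. A caller wanting the full product would pass
// dst_size = a_size + b_size - 1 and offset = 0, as `convolution_apply_dft` does
// when res is large enough.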
|
||||
|
||||
impl<BE: Backend> Reim4ConvolutionByRealConst<BE> for BE where
|
||||
Self: Reim4ConvolutionByRealConst1Coeff + Reim4ConvolutionByRealConst2Coeffs
|
||||
{
|
||||
}
|
||||
|
||||
pub trait Reim4ConvolutionByRealConst<BE: Backend>
|
||||
where
|
||||
BE: Reim4ConvolutionByRealConst1Coeff + Reim4ConvolutionByRealConst2Coeffs,
|
||||
{
|
||||
fn reim4_convolution_by_real_const(dst: &mut [f64], dst_size: usize, offset: usize, a: &[f64], a_size: usize, b: &[f64]) {
|
||||
assert!(a_size > 0);
|
||||
|
||||
for k in (0..dst_size - 1).step_by(2) {
|
||||
BE::reim4_convolution_by_real_const_2coeffs(k + offset, as_arr_mut(&mut dst[8 * k..]), a, a_size, b);
|
||||
}
|
||||
|
||||
if !dst_size.is_multiple_of(2) {
|
||||
let k: usize = dst_size - 1;
|
||||
BE::reim4_convolution_by_real_const_1coeff(k + offset, as_arr_mut(&mut dst[8 * k..]), a, a_size, b);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,13 +9,14 @@ use crate::{
|
||||
reference::{
|
||||
vec_znx::{
|
||||
vec_znx_add, vec_znx_add_inplace, vec_znx_automorphism, vec_znx_automorphism_inplace, vec_znx_negate,
|
||||
vec_znx_negate_inplace, vec_znx_normalize, vec_znx_sub, vec_znx_sub_inplace, vec_znx_sub_negate_inplace,
|
||||
vec_znx_negate_inplace, vec_znx_normalize, vec_znx_normalize_tmp_bytes, vec_znx_sub, vec_znx_sub_inplace,
|
||||
vec_znx_sub_negate_inplace,
|
||||
},
|
||||
znx::{
|
||||
ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxExtractDigitAddMul, ZnxMulPowerOfTwoInplace, ZnxNegate,
|
||||
ZnxNegateInplace, ZnxNormalizeDigit, ZnxNormalizeFinalStep, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly,
|
||||
ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxSub, ZnxSubInplace, ZnxSubNegateInplace, ZnxZero,
|
||||
znx_add_normal_f64_ref,
|
||||
ZnxNegateInplace, ZnxNormalizeDigit, ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep,
|
||||
ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly,
|
||||
ZnxNormalizeMiddleStepInplace, ZnxSub, ZnxSubInplace, ZnxSubNegateInplace, ZnxZero, znx_add_normal_f64_ref,
|
||||
},
|
||||
},
|
||||
source::Source,
|
||||
@@ -231,15 +232,17 @@ where
|
||||
}
|
||||
|
||||
pub fn vec_znx_big_normalize_tmp_bytes(n: usize) -> usize {
|
||||
2 * n * size_of::<i64>()
|
||||
vec_znx_normalize_tmp_bytes(n)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn vec_znx_big_normalize<R, A, BE>(
|
||||
res_base2k: usize,
|
||||
res: &mut R,
|
||||
res_base2k: usize,
|
||||
res_offset: i64,
|
||||
res_col: usize,
|
||||
a_base2k: usize,
|
||||
a: &A,
|
||||
a_base2k: usize,
|
||||
a_col: usize,
|
||||
carry: &mut [i64],
|
||||
) where
|
||||
@@ -256,7 +259,9 @@ pub fn vec_znx_big_normalize<R, A, BE>(
|
||||
+ ZnxNormalizeFinalStep
|
||||
+ ZnxNormalizeFirstStep
|
||||
+ ZnxExtractDigitAddMul
|
||||
+ ZnxNormalizeDigit,
|
||||
+ ZnxNormalizeDigit
|
||||
+ ZnxNormalizeMiddleStepInplace
|
||||
+ ZnxNormalizeFinalStepInplace,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], _> = a.to_ref();
|
||||
let a_vznx: VecZnx<&[u8]> = VecZnx {
|
||||
@@ -267,7 +272,7 @@ pub fn vec_znx_big_normalize<R, A, BE>(
|
||||
max_size: a.max_size,
|
||||
};
|
||||
|
||||
vec_znx_normalize::<_, _, BE>(res_base2k, res, res_col, a_base2k, &a_vznx, a_col, carry);
|
||||
vec_znx_normalize::<_, _, BE>(res, res_base2k, res_offset, res_col, &a_vznx, a_base2k, a_col, carry);
|
||||
}
|
||||
|
||||
pub fn vec_znx_big_add_normal_ref<R, B: Backend<ScalarBig = i64>>(
|
||||
@@ -290,18 +295,13 @@ pub fn vec_znx_big_add_normal_ref<R, B: Backend<ScalarBig = i64>>(
|
||||
|
||||
let limb: usize = k.div_ceil(base2k) - 1;
|
||||
let scale: f64 = (1 << ((limb + 1) * base2k - k)) as f64;
|
||||
znx_add_normal_f64_ref(
|
||||
res.at_mut(res_col, limb),
|
||||
sigma * scale,
|
||||
bound * scale,
|
||||
source,
|
||||
)
|
||||
znx_add_normal_f64_ref(res.at_mut(res_col, limb), sigma * scale, bound * scale, source)
|
||||
}
|
||||
|
||||
pub fn test_vec_znx_big_add_normal<B>(module: &Module<B>)
|
||||
where
|
||||
Module<B>: VecZnxBigAddNormal<B>,
|
||||
B: Backend<ScalarBig = i64> + VecZnxBigAllocBytesImpl<B>,
|
||||
B: Backend<ScalarBig = i64> + VecZnxBigAllocBytesImpl,
|
||||
{
|
||||
let n: usize = module.n();
|
||||
let base2k: usize = 17;
|
||||
@@ -325,12 +325,7 @@ where
|
||||
})
|
||||
} else {
|
||||
let std: f64 = a.stats(base2k, col_i).std() * k_f64;
|
||||
assert!(
|
||||
(std - sigma * sqrt2).abs() < 0.1,
|
||||
"std={} ~!= {}",
|
||||
std,
|
||||
sigma * sqrt2
|
||||
);
|
||||
assert!((std - sigma * sqrt2).abs() < 0.1, "std={} ~!= {}", std, sigma * sqrt2);
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
@@ -4,7 +4,10 @@ use crate::{
|
||||
oep::VecZnxDftAllocBytesImpl,
|
||||
reference::fft64::{
|
||||
reim::{ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimZero},
|
||||
reim4::{Reim4Extract1Blk, Reim4Mat1ColProd, Reim4Mat2Cols2ndColProd, Reim4Mat2ColsProd, Reim4Save1Blk, Reim4Save2Blks},
|
||||
reim4::{
|
||||
Reim4Extract1BlkContiguous, Reim4Mat1ColProd, Reim4Mat2Cols2ndColProd, Reim4Mat2ColsProd, Reim4Save1Blk,
|
||||
Reim4Save2Blks,
|
||||
},
|
||||
vec_znx_dft::vec_znx_dft_apply,
|
||||
},
|
||||
};
|
||||
@@ -17,7 +20,7 @@ pub fn vmp_prepare_tmp_bytes(n: usize) -> usize {
|
||||
|
||||
pub fn vmp_prepare<R, A, BE>(table: &ReimFFTTable<f64>, pmat: &mut R, mat: &A, tmp: &mut [f64])
|
||||
where
|
||||
BE: Backend<ScalarPrep = f64> + ReimDFTExecute<ReimFFTTable<f64>, f64> + ReimFromZnx + Reim4Extract1Blk,
|
||||
BE: Backend<ScalarPrep = f64> + ReimDFTExecute<ReimFFTTable<f64>, f64> + ReimFromZnx + Reim4Extract1BlkContiguous,
|
||||
R: VmpPMatToMut<BE>,
|
||||
A: MatZnxToRef,
|
||||
{
|
||||
@@ -34,13 +37,7 @@ where
|
||||
res.cols_in(),
|
||||
a.cols_in()
|
||||
);
|
||||
assert_eq!(
|
||||
res.rows(),
|
||||
a.rows(),
|
||||
"res.rows: {} != a.rows: {}",
|
||||
res.rows(),
|
||||
a.rows()
|
||||
);
|
||||
assert_eq!(res.rows(), a.rows(), "res.rows: {} != a.rows: {}", res.rows(), a.rows());
|
||||
assert_eq!(
|
||||
res.cols_out(),
|
||||
a.cols_out(),
|
||||
@@ -48,13 +45,7 @@ where
|
||||
res.cols_out(),
|
||||
a.cols_out()
|
||||
);
|
||||
assert_eq!(
|
||||
res.size(),
|
||||
a.size(),
|
||||
"res.size: {} != a.size: {}",
|
||||
res.size(),
|
||||
a.size()
|
||||
);
|
||||
assert_eq!(res.size(), a.size(), "res.size: {} != a.size: {}", res.size(), a.size());
|
||||
}
|
||||
|
||||
let nrows: usize = a.cols_in() * a.rows();
|
||||
@@ -70,7 +61,7 @@ pub(crate) fn vmp_prepare_core<REIM>(
|
||||
ncols: usize,
|
||||
tmp: &mut [f64],
|
||||
) where
|
||||
REIM: ReimDFTExecute<ReimFFTTable<f64>, f64> + ReimFromZnx + Reim4Extract1Blk,
|
||||
REIM: ReimDFTExecute<ReimFFTTable<f64>, f64> + ReimFromZnx + Reim4Extract1BlkContiguous,
|
||||
{
|
||||
let m: usize = table.m();
|
||||
let n: usize = m << 1;
|
||||
@@ -99,7 +90,7 @@ pub(crate) fn vmp_prepare_core<REIM>(
|
||||
};
|
||||
|
||||
for blk_i in 0..m >> 2 {
|
||||
REIM::reim4_extract_1blk(m, 1, blk_i, &mut dst[blk_i * offset..], tmp);
|
||||
REIM::reim4_extract_1blk_contiguous(m, 1, blk_i, &mut dst[blk_i * offset..], tmp);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -116,7 +107,7 @@ where
|
||||
+ VecZnxDftAllocBytesImpl<BE>
|
||||
+ ReimDFTExecute<ReimFFTTable<f64>, f64>
|
||||
+ ReimZero
|
||||
+ Reim4Extract1Blk
|
||||
+ Reim4Extract1BlkContiguous
|
||||
+ Reim4Mat1ColProd
|
||||
+ Reim4Mat2Cols2ndColProd
|
||||
+ Reim4Mat2ColsProd
|
||||
@@ -168,7 +159,7 @@ pub fn vmp_apply_dft_to_dft<R, A, M, BE>(res: &mut R, a: &A, pmat: &M, tmp_bytes
|
||||
where
|
||||
BE: Backend<ScalarPrep = f64>
|
||||
+ ReimZero
|
||||
+ Reim4Extract1Blk
|
||||
+ Reim4Extract1BlkContiguous
|
||||
+ Reim4Mat1ColProd
|
||||
+ Reim4Mat2Cols2ndColProd
|
||||
+ Reim4Mat2ColsProd
|
||||
@@ -207,7 +198,7 @@ pub fn vmp_apply_dft_to_dft_add<R, A, M, BE>(res: &mut R, a: &A, pmat: &M, limb_
|
||||
where
|
||||
BE: Backend<ScalarPrep = f64>
|
||||
+ ReimZero
|
||||
+ Reim4Extract1Blk
|
||||
+ Reim4Extract1BlkContiguous
|
||||
+ Reim4Mat1ColProd
|
||||
+ Reim4Mat2Cols2ndColProd
|
||||
+ Reim4Mat2ColsProd
|
||||
@@ -239,16 +230,7 @@ where
|
||||
let a_raw: &[f64] = a.raw();
|
||||
let res_raw: &mut [f64] = res.raw_mut();
|
||||
|
||||
vmp_apply_dft_to_dft_core::<false, BE>(
|
||||
n,
|
||||
res_raw,
|
||||
a_raw,
|
||||
pmat_raw,
|
||||
limb_offset,
|
||||
nrows,
|
||||
ncols,
|
||||
tmp_bytes,
|
||||
)
|
||||
vmp_apply_dft_to_dft_core::<false, BE>(n, res_raw, a_raw, pmat_raw, limb_offset, nrows, ncols, tmp_bytes)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
@@ -263,7 +245,7 @@ fn vmp_apply_dft_to_dft_core<const OVERWRITE: bool, REIM>(
|
||||
tmp_bytes: &mut [f64],
|
||||
) where
|
||||
REIM: ReimZero
|
||||
+ Reim4Extract1Blk
|
||||
+ Reim4Extract1BlkContiguous
|
||||
+ Reim4Mat1ColProd
|
||||
+ Reim4Mat2Cols2ndColProd
|
||||
+ Reim4Mat2ColsProd
|
||||
@@ -299,41 +281,23 @@ fn vmp_apply_dft_to_dft_core<const OVERWRITE: bool, REIM>(
|
||||
for blk_i in 0..(m >> 2) {
|
||||
let mat_blk_start: &[f64] = &pmat[blk_i * (8 * nrows * ncols)..];
|
||||
|
||||
REIM::reim4_extract_1blk(m, row_max, blk_i, extracted_blk, a);
|
||||
REIM::reim4_extract_1blk_contiguous(m, row_max, blk_i, extracted_blk, a);
|
||||
|
||||
if limb_offset.is_multiple_of(2) {
|
||||
for (col_res, col_pmat) in (0..).step_by(2).zip((limb_offset..col_max - 1).step_by(2)) {
|
||||
let col_offset: usize = col_pmat * (8 * nrows);
|
||||
REIM::reim4_mat2cols_prod(
|
||||
row_max,
|
||||
mat2cols_output,
|
||||
extracted_blk,
|
||||
&mat_blk_start[col_offset..],
|
||||
);
|
||||
REIM::reim4_mat2cols_prod(row_max, mat2cols_output, extracted_blk, &mat_blk_start[col_offset..]);
|
||||
REIM::reim4_save_2blks::<OVERWRITE>(m, blk_i, &mut res[col_res * n..], mat2cols_output);
|
||||
}
|
||||
} else {
|
||||
let col_offset: usize = (limb_offset - 1) * (8 * nrows);
|
||||
REIM::reim4_mat2cols_2ndcol_prod(
|
||||
row_max,
|
||||
mat2cols_output,
|
||||
extracted_blk,
|
||||
&mat_blk_start[col_offset..],
|
||||
);
|
||||
REIM::reim4_mat2cols_2ndcol_prod(row_max, mat2cols_output, extracted_blk, &mat_blk_start[col_offset..]);
|
||||
|
||||
REIM::reim4_save_1blk::<OVERWRITE>(m, blk_i, res, mat2cols_output);
|
||||
|
||||
for (col_res, col_pmat) in (1..)
|
||||
.step_by(2)
|
||||
.zip((limb_offset + 1..col_max - 1).step_by(2))
|
||||
{
|
||||
for (col_res, col_pmat) in (1..).step_by(2).zip((limb_offset + 1..col_max - 1).step_by(2)) {
|
||||
let col_offset: usize = col_pmat * (8 * nrows);
|
||||
REIM::reim4_mat2cols_prod(
|
||||
row_max,
|
||||
mat2cols_output,
|
||||
extracted_blk,
|
||||
&mat_blk_start[col_offset..],
|
||||
);
|
||||
REIM::reim4_mat2cols_prod(row_max, mat2cols_output, extracted_blk, &mat_blk_start[col_offset..]);
|
||||
REIM::reim4_save_2blks::<OVERWRITE>(m, blk_i, &mut res[col_res * n..], mat2cols_output);
|
||||
}
|
||||
}
|
||||
@@ -344,26 +308,11 @@ fn vmp_apply_dft_to_dft_core<const OVERWRITE: bool, REIM>(
|
||||
|
||||
if last_col >= limb_offset {
|
||||
if ncols == col_max {
|
||||
REIM::reim4_mat1col_prod(
|
||||
row_max,
|
||||
mat2cols_output,
|
||||
extracted_blk,
|
||||
&mat_blk_start[col_offset..],
|
||||
);
|
||||
REIM::reim4_mat1col_prod(row_max, mat2cols_output, extracted_blk, &mat_blk_start[col_offset..]);
|
||||
} else {
|
||||
REIM::reim4_mat2cols_prod(
|
||||
row_max,
|
||||
mat2cols_output,
|
||||
extracted_blk,
|
||||
&mat_blk_start[col_offset..],
|
||||
);
|
||||
REIM::reim4_mat2cols_prod(row_max, mat2cols_output, extracted_blk, &mat_blk_start[col_offset..]);
|
||||
}
|
||||
REIM::reim4_save_1blk::<OVERWRITE>(
|
||||
m,
|
||||
blk_i,
|
||||
&mut res[(last_col - limb_offset) * n..],
|
||||
mat2cols_output,
|
||||
);
|
||||
REIM::reim4_save_1blk::<OVERWRITE>(m, blk_i, &mut res[(last_col - limb_offset) * n..], mat2cols_output);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,10 +24,7 @@ where
|
||||
{
|
||||
assert_eq!(tmp.len(), res.n());
|
||||
|
||||
debug_assert!(
|
||||
_n_out > _n_in,
|
||||
"invalid a: output ring degree should be greater"
|
||||
);
|
||||
debug_assert!(_n_out > _n_in, "invalid a: output ring degree should be greater");
|
||||
a[1..].iter().for_each(|ai| {
|
||||
debug_assert_eq!(
|
||||
ai.to_ref().n(),
|
||||
|
||||
@@ -14,15 +14,17 @@ use crate::{
|
||||
};
|
||||
|
||||
pub fn vec_znx_normalize_tmp_bytes(n: usize) -> usize {
|
||||
2 * n * size_of::<i64>()
|
||||
3 * n * size_of::<i64>()
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn vec_znx_normalize<R, A, ZNXARI>(
|
||||
res_base2k: usize,
|
||||
res: &mut R,
|
||||
res_base2k: usize,
|
||||
res_offset: i64,
|
||||
res_col: usize,
|
||||
a_base2k: usize,
|
||||
a: &A,
|
||||
a_base2k: usize,
|
||||
a_col: usize,
|
||||
carry: &mut [i64],
|
||||
) where
|
||||
@@ -38,14 +40,40 @@ pub fn vec_znx_normalize<R, A, ZNXARI>(
|
||||
+ ZnxNormalizeFinalStep
|
||||
+ ZnxNormalizeFirstStep
|
||||
+ ZnxExtractDigitAddMul
|
||||
+ ZnxNormalizeMiddleStepInplace
|
||||
+ ZnxNormalizeFinalStepInplace
|
||||
+ ZnxNormalizeDigit,
|
||||
{
|
||||
match res_base2k == a_base2k {
|
||||
true => vec_znx_normalize_inter_base2k::<R, A, ZNXARI>(res_base2k, res, res_offset, res_col, a, a_col, carry),
|
||||
false => vec_znx_normalize_cross_base2k::<R, A, ZNXARI>(res, res_base2k, res_offset, res_col, a, a_base2k, a_col, carry),
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_normalize_inter_base2k<R, A, ZNXARI>(
|
||||
base2k: usize,
|
||||
res: &mut R,
|
||||
res_offset: i64,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
carry: &mut [i64],
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
ZNXARI: ZnxZero
|
||||
+ ZnxNormalizeFirstStepCarryOnly
|
||||
+ ZnxNormalizeMiddleStepCarryOnly
|
||||
+ ZnxNormalizeMiddleStep
|
||||
+ ZnxNormalizeFinalStepInplace
|
||||
+ ZnxNormalizeMiddleStepInplace,
|
||||
{
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(carry.len() >= 2 * res.n());
|
||||
assert!(carry.len() >= vec_znx_normalize_tmp_bytes(res.n()) / size_of::<i64>());
|
||||
assert_eq!(res.n(), a.n());
|
||||
}
|
||||
|
||||
@@ -53,153 +81,323 @@ pub fn vec_znx_normalize<R, A, ZNXARI>(
|
||||
let res_size: usize = res.size();
|
||||
let a_size: usize = a.size();
|
||||
|
||||
let carry: &mut [i64] = &mut carry[..2 * n];
|
||||
let (carry, _) = carry.split_at_mut(n);
|
||||
|
||||
if res_base2k == a_base2k {
|
||||
if a_size > res_size {
|
||||
for j in (res_size..a_size).rev() {
|
||||
if j == a_size - 1 {
|
||||
ZNXARI::znx_normalize_first_step_carry_only(res_base2k, 0, a.at(a_col, j), carry);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step_carry_only(res_base2k, 0, a.at(a_col, j), carry);
|
||||
}
|
||||
}
|
||||
let mut lsh: i64 = res_offset % base2k as i64;
|
||||
let mut limbs_offset: i64 = res_offset / base2k as i64;
|
||||
|
||||
for j in (1..res_size).rev() {
|
||||
ZNXARI::znx_normalize_middle_step(res_base2k, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
|
||||
}
|
||||
// If res_offset is negative, makes lsh positive
// and compensates by decrementing the limb offset.
|
||||
if res_offset < 0 && lsh != 0 {
|
||||
lsh = (lsh + base2k as i64) % (base2k as i64);
|
||||
limbs_offset -= 1;
|
||||
}
|
||||
|
||||
ZNXARI::znx_normalize_final_step(res_base2k, 0, res.at_mut(res_col, 0), a.at(a_col, 0), carry);
|
||||
let lsh_pos: usize = lsh as usize;
|
||||
|
||||
let res_end: usize = (-limbs_offset).clamp(0, res_size as i64) as usize;
|
||||
let res_start: usize = (a_size as i64 - limbs_offset).clamp(0, res_size as i64) as usize;
|
||||
let a_end: usize = limbs_offset.clamp(0, a_size as i64) as usize;
|
||||
let a_start: usize = (res_size as i64 + limbs_offset).clamp(0, a_size as i64) as usize;
|
||||
|
||||
let a_out_range: usize = a_size.saturating_sub(a_start);
|
||||
|
||||
// Computes the carry over the discarded limbs of a
|
||||
for j in 0..a_out_range {
|
||||
if j == 0 {
|
||||
ZNXARI::znx_normalize_first_step_carry_only(base2k, lsh_pos, a.at(a_col, a_size - j - 1), carry);
|
||||
} else {
|
||||
for j in (0..a_size).rev() {
|
||||
if j == a_size - 1 {
|
||||
ZNXARI::znx_normalize_first_step(res_base2k, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
|
||||
} else if j == 0 {
|
||||
ZNXARI::znx_normalize_final_step(res_base2k, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step(res_base2k, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
|
||||
}
|
||||
}
|
||||
|
||||
for j in a_size..res_size {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
}
|
||||
ZNXARI::znx_normalize_middle_step_carry_only(base2k, lsh_pos, a.at(a_col, a_size - j - 1), carry);
|
||||
}
|
||||
} else {
|
||||
let (a_norm, carry) = carry.split_at_mut(n);
|
||||
}
|
||||
|
||||
// Relevant limbs of res
|
||||
let res_min_size: usize = (a_size * a_base2k).div_ceil(res_base2k).min(res_size);
|
||||
// If no limbs were discarded, initialize carry to zero
|
||||
if a_out_range == 0 {
|
||||
ZNXARI::znx_zero(carry);
|
||||
}
|
||||
|
||||
// Relevant limbs of a
|
||||
let a_min_size: usize = (res_size * res_base2k).div_ceil(a_base2k).min(a_size);
|
||||
// Zeroes bottom limbs that will not be interacted with
|
||||
for j in res_start..res_size {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
}
|
||||
|
||||
// Get carry for limbs of a that have higher precision than res
|
||||
for j in (a_min_size..a_size).rev() {
|
||||
if j == a_size - 1 {
|
||||
ZNXARI::znx_normalize_first_step_carry_only(a_base2k, 0, a.at(a_col, j), carry);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step_carry_only(a_base2k, 0, a.at(a_col, j), carry);
|
||||
let mid_range: usize = a_start.saturating_sub(a_end);
|
||||
|
||||
// Regular normalization over the overlapping limbs of res and a.
|
||||
for j in 0..mid_range {
|
||||
ZNXARI::znx_normalize_middle_step(
|
||||
base2k,
|
||||
lsh_pos,
|
||||
res.at_mut(res_col, res_start - j - 1),
|
||||
a.at(a_col, a_start - j - 1),
|
||||
carry,
|
||||
);
|
||||
}
|
||||
|
||||
// Propagates the carry over the non-overlapping limbs between res and a
|
||||
for j in 0..res_end {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, res_end - j - 1));
|
||||
if j == res_end - 1 {
|
||||
ZNXARI::znx_normalize_final_step_inplace(base2k, lsh_pos, res.at_mut(res_col, res_end - j - 1), carry);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step_inplace(base2k, lsh_pos, res.at_mut(res_col, res_end - j - 1), carry);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn vec_znx_normalize_cross_base2k<R, A, ZNXARI>(
|
||||
res: &mut R,
|
||||
res_base2k: usize,
|
||||
res_offset: i64,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_base2k: usize,
|
||||
a_col: usize,
|
||||
carry: &mut [i64],
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
ZNXARI: ZnxZero
|
||||
+ ZnxCopy
|
||||
+ ZnxAddInplace
|
||||
+ ZnxMulPowerOfTwoInplace
|
||||
+ ZnxNormalizeFirstStepCarryOnly
|
||||
+ ZnxNormalizeMiddleStepCarryOnly
|
||||
+ ZnxNormalizeMiddleStep
|
||||
+ ZnxNormalizeFinalStep
|
||||
+ ZnxNormalizeFirstStep
|
||||
+ ZnxExtractDigitAddMul
|
||||
+ ZnxNormalizeMiddleStepInplace
|
||||
+ ZnxNormalizeFinalStepInplace
|
||||
+ ZnxNormalizeDigit,
|
||||
{
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(carry.len() >= vec_znx_normalize_tmp_bytes(res.n()) / size_of::<i64>());
|
||||
assert_eq!(res.n(), a.n());
|
||||
}
|
||||
|
||||
let n: usize = res.n();
|
||||
let res_size: usize = res.size();
|
||||
let a_size: usize = a.size();
|
||||
|
||||
let (a_norm, carry) = carry.split_at_mut(n);
|
||||
let (res_carry, a_carry) = carry[..2 * n].split_at_mut(n);
|
||||
ZNXARI::znx_zero(res_carry);
|
||||
|
||||
// Total precision (in bits) that `a` and `res` can represent.
|
||||
let a_tot_bits: usize = a_size * a_base2k;
|
||||
let res_tot_bits: usize = res_size * res_base2k;
|
||||
|
||||
// Derive intra-limb shift and cross-limb offset.
|
||||
let mut lsh: i64 = res_offset % a_base2k as i64;
|
||||
let mut limbs_offset: i64 = res_offset / a_base2k as i64;
|
||||
|
||||
// If res_offset is negative, ensures lsh is positive
// and compensates by decrementing the cross-limb offset.
|
||||
if res_offset < 0 && lsh != 0 {
|
||||
lsh = (lsh + a_base2k as i64) % (a_base2k as i64);
|
||||
limbs_offset -= 1;
|
||||
}
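// Worked example (illustrative): with a_base2k = 5 and res_offset = -8, the
// truncating split gives lsh = -3 and limbs_offset = -1; the correction above
// yields lsh = 2 and limbs_offset = -2, preserving
// res_offset = limbs_offset * a_base2k + lsh = -10 + 2 = -8, with 0 <= lsh < a_base2k.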
|
||||
|
||||
let lsh_pos: usize = lsh as usize;
|
||||
|
||||
// Derive the start/stop bit indexes of the overlap between `a` and `res` (after taking the offset into account).
|
||||
let res_end_bit: usize = (-limbs_offset * a_base2k as i64).clamp(0, res_tot_bits as i64) as usize; // Stop bit
|
||||
let res_start_bit: usize = (a_tot_bits as i64 - limbs_offset * a_base2k as i64).clamp(0, res_tot_bits as i64) as usize; // Start bit
|
||||
let a_end_bit: usize = (limbs_offset * a_base2k as i64).clamp(0, a_tot_bits as i64) as usize; // Stop bit
|
||||
let a_start_bit: usize = (res_tot_bits as i64 + limbs_offset * a_base2k as i64).clamp(0, a_tot_bits as i64) as usize; // Start bit
|
||||
|
||||
// Convert bits to limb indexes.
|
||||
let res_end: usize = res_end_bit / res_base2k;
|
||||
let res_start: usize = res_start_bit.div_ceil(res_base2k);
|
||||
let a_end: usize = a_end_bit / a_base2k;
|
||||
let a_start: usize = a_start_bit.div_ceil(a_base2k);
|
||||
|
||||
// Zero all limbs of `res`. Unlike the simple case
|
||||
// where `res_base2k` is equal to `a_base2k`, we also
|
||||
// need to ensure that the limbs starting from `res_end`
|
||||
// are zero.
|
||||
for j in 0..res_size {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
}
|
||||
|
||||
// Case where offset is positive and greater or equal
|
||||
// to the precision of a.
|
||||
if res_start == 0 {
|
||||
return;
|
||||
}
|
||||
|
||||
// Limbs of `a` that have a greater precision than `res`.
|
||||
let a_out_range: usize = a_size.saturating_sub(a_start);
|
||||
|
||||
for j in 0..a_out_range {
|
||||
if j == 0 {
|
||||
ZNXARI::znx_normalize_first_step_carry_only(a_base2k, lsh_pos, a.at(a_col, a_size - j - 1), a_carry);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step_carry_only(a_base2k, lsh_pos, a.at(a_col, a_size - j - 1), a_carry);
|
||||
}
|
||||
}
|
||||
|
||||
// Zero carry if the above loop didn't trigger.
|
||||
if a_out_range == 0 {
|
||||
ZNXARI::znx_zero(a_carry);
|
||||
}
|
||||
|
||||
// How much is left to accumulate to fill a limb of `res`.
|
||||
let mut res_acc_left: usize = res_base2k;
|
||||
|
||||
// Starting limb of `res`.
|
||||
let mut res_limb: usize = res_start - 1;
|
||||
|
||||
// How many limbs of `a` overlap with `res` (after taking into account the offset).
|
||||
let mid_range: usize = a_start.saturating_sub(a_end);
|
||||
|
||||
// Regular normalization over the overlapping limbs of res and a.
|
||||
'outer: for j in 0..mid_range {
|
||||
let a_limb: usize = a_start - j - 1;
|
||||
|
||||
// Current limb of `a`
|
||||
let a_slice: &[i64] = a.at(a_col, a_limb);
|
||||
|
||||
// Trackers: how much of a_norm is left to
// be flushed onto res.
|
||||
let mut a_take_left: usize = a_base2k;
|
||||
|
||||
// Normalizes the j-th limb of a and stores the result into `a_norm`.
|
||||
// This step is required to avoid overflow in the next step,
|
||||
// which assumes that |a| is bounded by 2^{a_base2k -1} (i.e. normalized).
|
||||
ZNXARI::znx_normalize_middle_step(a_base2k, lsh_pos, a_norm, a_slice, a_carry);
|
||||
|
||||
// In the first iteration we need to match the precision of `res` and `a`.
|
||||
if j == 0 {
|
||||
// Case where `a` has more precision than `res` (after taking into account the offset)
|
||||
//
|
||||
// For example:
|
||||
//
|
||||
// a: [x x x x x][x x x x x][x x x x x][x x x x x]
|
||||
// res: [x x x x x x][x x x x x x][x x x x x x]
|
||||
if !(a_tot_bits - a_start_bit).is_multiple_of(a_base2k) {
|
||||
let take: usize = (a_tot_bits - a_start_bit) % a_base2k;
|
||||
ZNXARI::znx_mul_power_of_two_inplace(-(take as i64), a_norm);
|
||||
a_take_left -= take;
|
||||
// Case where `res` has more precision than `a` (after taking into account the offset)
|
||||
//
|
||||
// For example:
|
||||
//
|
||||
// a: [x x x x x][x x x x x][x x x x x][x x x x x]
|
||||
// res: [x x x x x x][x x x x x x][x x x x x x]
|
||||
} else if !(res_tot_bits - res_start_bit).is_multiple_of(res_base2k) {
|
||||
res_acc_left -= (res_tot_bits - res_start_bit) % res_base2k;
|
||||
}
|
||||
}
|
||||
|
||||
if a_min_size == a_size {
|
||||
ZNXARI::znx_zero(carry);
|
||||
}
|
||||
// Extract bits of `a_norm` and accumulate them onto res[res_limb] until
// res_base2k bits have been accumulated or until all bits of `a` are
// extracted.
|
||||
'inner: loop {
|
||||
// Current limb of res
|
||||
let res_slice: &mut [i64] = res.at_mut(res_col, res_limb);
|
||||
|
||||
// Maximum relevant precision of a
|
||||
let a_prec: usize = a_min_size * a_base2k;
|
||||
// We can take at most a_base2k bits
|
||||
// but not more than what is left on a_norm or what is left to
|
||||
// fully populate the current limb of res.
|
||||
let a_take: usize = a_base2k.min(a_take_left).min(res_acc_left);
|
||||
|
||||
// Maximum relevant precision of res
|
||||
let res_prec: usize = res_min_size * res_base2k;
|
||||
|
||||
// Res limb index
|
||||
let mut res_idx: usize = res_min_size - 1;
|
||||
|
||||
// Trackers: how much of res is left to be populated
// for the current limb.
|
||||
let mut res_left: usize = res_base2k;
|
||||
|
||||
for j in 0..res_size {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
}
|
||||
|
||||
for j in (0..a_min_size).rev() {
|
||||
// Trackers: how much of a_norm is left to
// be flushed onto res.
|
||||
let mut a_left: usize = a_base2k;
|
||||
|
||||
// Normalizes the j-th limb of a and store the results into a_norm.
|
||||
// This step is required to avoid overflow in the next step,
|
||||
// which assumes that |a| is bounded by 2^{a_base2k -1}.
|
||||
if j != 0 {
|
||||
ZNXARI::znx_normalize_middle_step(a_base2k, 0, a_norm, a.at(a_col, j), carry);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_final_step(a_base2k, 0, a_norm, a.at(a_col, j), carry);
|
||||
if a_take != 0 {
|
||||
// Extract `a_take` bits from a_norm and accumulates them on `res_slice`.
|
||||
let scale: usize = res_base2k - res_acc_left;
|
||||
ZNXARI::znx_extract_digit_addmul(a_take, scale, res_slice, a_norm);
|
||||
a_take_left -= a_take;
|
||||
res_acc_left -= a_take;
|
||||
}
|
||||
|
||||
// In the first iteration we need to match the precision of the input/output.
|
||||
// If a_min_size * a_base2k > res_min_size * res_base2k
|
||||
// then divround a_norm by the difference of precision and
// acts as if a_norm has already been partially consumed.
// Otherwise, acts as if res has already been populated
// by the difference.
|
||||
if j == a_min_size - 1 {
|
||||
if a_prec > res_prec {
|
||||
ZNXARI::znx_mul_power_of_two_inplace(res_prec as i64 - a_prec as i64, a_norm);
|
||||
a_left -= a_prec - res_prec;
|
||||
} else if res_prec > a_prec {
|
||||
res_left -= res_prec - a_prec;
|
||||
}
|
||||
}
|
||||
// If either:
|
||||
// * At least `res_base2k` bits have been accumulated
|
||||
// * We have reached the last limb of a
|
||||
// Then: Flushes them onto res
|
||||
if res_acc_left == 0 || a_limb == 0 {
|
||||
// This case happens only if `res_offset` is negative.
|
||||
// If `res_offset` is negative, we need to apply the offset BEFORE
|
||||
// the normalization to ensure the `res_offset` overflowing bits of `a`
|
||||
// are in the MSB of `res` instead of being discarded.
|
||||
if a_limb == 0 && a_take_left == 0 {
|
||||
// TODO: prove no overflow can happen here (should not intuitively)
|
||||
ZNXARI::znx_add_inplace(a_carry, a_norm);
|
||||
|
||||
// Flushes a into res
|
||||
loop {
|
||||
// Selects the maximum amount of a that can be flushed
|
||||
let a_take: usize = a_base2k.min(a_left).min(res_left);
|
||||
|
||||
// Output limb
|
||||
let res_slice: &mut [i64] = res.at_mut(res_col, res_idx);
|
||||
|
||||
// Scaling of the value to flush
|
||||
let lsh: usize = res_base2k - res_left;
|
||||
|
||||
// Extract the bits to flush on the output and updates
|
||||
// a_norm accordingly.
|
||||
ZNXARI::znx_extract_digit_addmul(a_take, lsh, res_slice, a_norm);
|
||||
|
||||
// Updates the trackers
|
||||
a_left -= a_take;
|
||||
res_left -= a_take;
|
||||
|
||||
// If the current limb of res is full,
|
||||
// then normalizes this limb and adds
|
||||
// the carry on a_norm.
|
||||
if res_left == 0 {
|
||||
// Updates tracker
|
||||
res_left += res_base2k;
|
||||
|
||||
// Normalizes res and propagates the carry on a.
|
||||
ZNXARI::znx_normalize_digit(res_base2k, res_slice, a_norm);
|
||||
|
||||
// If we reached the last limb of res breaks,
|
||||
// but we might rerun the above loop if the
|
||||
// base2k of a is much smaller than the base2k
|
||||
// of res.
|
||||
if res_idx == 0 {
|
||||
ZNXARI::znx_add_inplace(carry, a_norm);
|
||||
break;
|
||||
// Usual case where for example
|
||||
// a: [ overflow ][x x x x x][x x x x x][x x x x x][x x x x x]
|
||||
// res: [x x x x x x][x x x x x x][x x x x x x][x x x x x x]
|
||||
//
|
||||
// where [overflow] are the overflowing bits of `a` (note that they are not a limb, but
|
||||
// stored in a[0] & carry from a[1]) that are moved into the MSB of `res` due to the
|
||||
// negative offset.
|
||||
//
|
||||
// In this case we populate what is left of `res_acc_left` using `a_carry`
|
||||
//
|
||||
// TODO: see if this can be simplified (e.g. just add).
|
||||
if res_acc_left != 0 {
|
||||
let scale: usize = res_base2k - res_acc_left;
|
||||
ZNXARI::znx_extract_digit_addmul(res_acc_left, scale, res_slice, a_carry);
|
||||
}
|
||||
|
||||
// Else updates the limb index of res.
|
||||
res_idx -= 1
|
||||
ZNXARI::znx_normalize_middle_step_inplace(res_base2k, 0, res_slice, res_carry);
|
||||
|
||||
// Previous step might not consume all bits of a_carry
|
||||
// TODO: prove no overflow can happen here
|
||||
ZNXARI::znx_add_inplace(res_carry, a_carry);
|
||||
|
||||
// We are done, so breaks out of the loop (yes we are at a[0], but
|
||||
// this avoids possible over/under flows of tracking variables)
|
||||
break 'outer;
|
||||
}
|
||||
|
||||
// If a_norm is exhausted, breaks the loop.
|
||||
if a_left == 0 {
|
||||
ZNXARI::znx_add_inplace(carry, a_norm);
|
||||
break;
|
||||
// If we reached the last limb of res
|
||||
if res_limb == 0 {
|
||||
break 'outer;
|
||||
}
|
||||
|
||||
res_acc_left += res_base2k;
|
||||
res_limb -= 1;
|
||||
}
|
||||
|
||||
// If a_norm is exhausted, breaks the inner loop.
|
||||
if a_take_left == 0 {
|
||||
ZNXARI::znx_add_inplace(a_carry, a_norm);
|
||||
break 'inner;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This case will happen if offset is negative.
|
||||
if res_end != 0 {
|
||||
// If there are no overlapping limbs between `res` and `a`
|
||||
// (can happen if offset is negative), then we propagate the
|
||||
// carry of `a` on res. Note that the carry of `a` can be
|
||||
// greater than the precision of res.
|
||||
//
|
||||
// For example with offset = -8:
|
||||
// a carry a[0] a[1] a[2] a[3]
|
||||
// a: [---------------------- ][x x x][x x x][x x x][x x x]
|
||||
// b: [x x x x][x x x x ]
|
||||
// res[0] res[1]
|
||||
//
|
||||
// If there are overlapping limbs between `res` and `a`,
|
||||
// we can use `res_carry`, which contains the carry of propagating
|
||||
// the shifted reconstruction of `a` in `res_base2k` along with
|
||||
// the carry of a[0].
|
||||
let carry_to_use = if a_start == a_end { a_carry } else { res_carry };
|
||||
|
||||
for j in 0..res_end {
|
||||
if j == res_end - 1 {
|
||||
ZNXARI::znx_normalize_final_step_inplace(res_base2k, 0, res.at_mut(res_col, res_end - j - 1), carry_to_use);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step_inplace(res_base2k, 0, res.at_mut(res_col, res_end - j - 1), carry_to_use);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -229,6 +427,191 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_vec_znx_normalize_cross_base2k() {
|
||||
let n: usize = 8;
|
||||
|
||||
let mut carry: Vec<i64> = vec![0i64; vec_znx_normalize_tmp_bytes(n) / size_of::<i64>()];
|
||||
|
||||
use crate::reference::znx::ZnxRef;
|
||||
use rug::ops::SubAssignRound;
|
||||
use rug::{Float, float::Round};
|
||||
|
||||
let prec: usize = 128;
|
||||
|
||||
for in_base2k in 1..=51 {
|
||||
for out_base2k in 1..=51 {
|
||||
for offset in [
|
||||
-(prec as i64),
|
||||
-(prec as i64 - 1),
|
||||
-(prec as i64 - in_base2k as i64),
|
||||
-(in_base2k as i64 + 1),
|
||||
in_base2k as i64,
|
||||
-(in_base2k as i64 - 1),
|
||||
0,
|
||||
(in_base2k as i64 - 1),
|
||||
in_base2k as i64,
|
||||
(in_base2k as i64 + 1),
|
||||
(prec as i64 - in_base2k as i64),
|
||||
(prec - 1) as i64,
|
||||
prec as i64,
|
||||
] {
|
||||
let mut source: Source = Source::new([1u8; 32]);
|
||||
|
||||
let in_size: usize = prec.div_ceil(in_base2k);
|
||||
let in_prec: u32 = (in_size * in_base2k) as u32;
|
||||
|
||||
// Ensures no loss of precision (mostly for testing purposes)
|
||||
let out_size: usize = (in_prec as usize).div_ceil(out_base2k);
|
||||
|
||||
let out_prec: u32 = (out_size * out_base2k) as u32;
|
||||
let min_prec: u32 = (in_size * in_base2k).min(out_size * out_base2k) as u32;
|
||||
let mut want: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, in_size);
|
||||
want.fill_uniform(60, &mut source);
|
||||
|
||||
let mut have: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, out_size);
|
||||
have.fill_uniform(60, &mut source);
|
||||
vec_znx_normalize_cross_base2k::<_, _, ZnxRef>(&mut have, out_base2k, offset, 0, &want, in_base2k, 0, &mut carry);
|
||||
|
||||
let mut data_have: Vec<Float> = (0..n).map(|_| Float::with_val(out_prec + 60, 0)).collect();
|
||||
let mut data_want: Vec<Float> = (0..n).map(|_| Float::with_val(in_prec + 60, 0)).collect();
|
||||
|
||||
have.decode_vec_float(out_base2k, 0, &mut data_have);
|
||||
want.decode_vec_float(in_base2k, 0, &mut data_want);
|
||||
|
||||
let scale: Float = Float::with_val(out_prec + 60, Float::u_pow_u(2, offset.unsigned_abs() as u32));
|
||||
|
||||
if offset > 0 {
|
||||
for x in &mut data_want {
|
||||
*x *= &scale;
|
||||
*x %= 1;
|
||||
}
|
||||
} else if offset < 0 {
|
||||
for x in &mut data_want {
|
||||
*x /= &scale;
|
||||
*x %= 1;
|
||||
}
|
||||
} else {
|
||||
for x in &mut data_want {
|
||||
*x %= 1;
|
||||
}
|
||||
}
|
||||
|
||||
for x in &mut data_have {
|
||||
if *x >= 0.5 {
|
||||
*x -= 1;
|
||||
} else if *x < -0.5 {
|
||||
*x += 1;
|
||||
}
|
||||
}
|
||||
|
||||
for x in &mut data_want {
|
||||
if *x >= 0.5 {
|
||||
*x -= 1;
|
||||
} else if *x < -0.5 {
|
||||
*x += 1;
|
||||
}
|
||||
}
|
||||
|
||||
for i in 0..n {
|
||||
//println!("i:{i:02} {} {}", data_want[i], data_have[i]);
|
||||
|
||||
let mut err: Float = data_have[i].clone();
|
||||
err.sub_assign_round(&data_want[i], Round::Nearest);
|
||||
err = err.abs();
|
||||
|
||||
let err_log2: f64 = err.clone().max(&Float::with_val(prec as u32, 1e-60)).log2().to_f64();
|
||||
|
||||
assert!(err_log2 <= -(min_prec as f64) + 1.0, "{} {}", err_log2, -(min_prec as f64))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
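// Reading of the test above as a stand-alone sketch (the helper name is
// hypothetical, not part of the crate): normalizing with an `offset` is
// expected to act on the decoded torus value like a multiplication by
// 2^offset followed by a reduction into the centered interval [-0.5, 0.5).
fn expected_after_offset(x: f64, offset: i32) -> f64 {
    let scaled = x * (offset as f64).exp2();
    let mut frac = scaled - scaled.round();
    // Mirror the centering step of the test: keep the result in [-0.5, 0.5).
    if frac >= 0.5 {
        frac -= 1.0;
    } else if frac < -0.5 {
        frac += 1.0;
    }
    frac
}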
|
||||
|
||||
#[test]
|
||||
fn test_vec_znx_normalize_inter_base2k() {
|
||||
let n: usize = 8;
|
||||
|
||||
let mut carry: Vec<i64> = vec![0i64; vec_znx_normalize_tmp_bytes(n) / size_of::<i64>()];
|
||||
|
||||
use crate::reference::znx::ZnxRef;
|
||||
use rug::ops::SubAssignRound;
|
||||
use rug::{Float, float::Round};
|
||||
|
||||
let mut source: Source = Source::new([1u8; 32]);
|
||||
|
||||
let prec: usize = 128;
|
||||
let offset_range: i64 = prec as i64;
|
||||
|
||||
for base2k in 1..=51 {
|
||||
for offset in (-offset_range..=offset_range).step_by(base2k + 1) {
|
||||
let size: usize = prec.div_ceil(base2k);
|
||||
let out_prec: u32 = (size * base2k) as u32;
|
||||
|
||||
// Fills "want" with uniform values
|
||||
let mut want: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, size);
|
||||
want.fill_uniform(60, &mut source);
|
||||
|
||||
// Fills "have" with the shifted normalization of "want"
|
||||
let mut have: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, size);
|
||||
have.fill_uniform(60, &mut source);
|
||||
vec_znx_normalize_inter_base2k::<_, _, ZnxRef>(base2k, &mut have, offset, 0, &want, 0, &mut carry);
|
||||
|
||||
let mut data_have: Vec<Float> = (0..n).map(|_| Float::with_val(out_prec + 60, 0)).collect();
|
||||
let mut data_want: Vec<Float> = (0..n).map(|_| Float::with_val(out_prec + 60, 0)).collect();
|
||||
|
||||
have.decode_vec_float(base2k, 0, &mut data_have);
|
||||
want.decode_vec_float(base2k, 0, &mut data_want);
|
||||
|
||||
let scale: Float = Float::with_val(out_prec + 60, Float::u_pow_u(2, offset.unsigned_abs() as u32));
|
||||
|
||||
if offset > 0 {
|
||||
for x in &mut data_want {
|
||||
*x *= &scale;
|
||||
*x %= 1;
|
||||
}
|
||||
} else if offset < 0 {
|
||||
for x in &mut data_want {
|
||||
*x /= &scale;
|
||||
*x %= 1;
|
||||
}
|
||||
} else {
|
||||
for x in &mut data_want {
|
||||
*x %= 1;
|
||||
}
|
||||
}
|
||||
|
||||
for x in &mut data_have {
|
||||
if *x >= 0.5 {
|
||||
*x -= 1;
|
||||
} else if *x < -0.5 {
|
||||
*x += 1;
|
||||
}
|
||||
}
|
||||
|
||||
for x in &mut data_want {
|
||||
if *x >= 0.5 {
|
||||
*x -= 1;
|
||||
} else if *x < -0.5 {
|
||||
*x += 1;
|
||||
}
|
||||
}
|
||||
|
||||
for i in 0..n {
|
||||
//println!("i:{i:02} {} {}", data_want[i], data_have[i]);
|
||||
|
||||
let mut err: Float = data_have[i].clone();
|
||||
err.sub_assign_round(&data_want[i], Round::Nearest);
|
||||
err = err.abs();
|
||||
|
||||
let err_log2: f64 = err.clone().max(&Float::with_val(prec as u32, 1e-60)).log2().to_f64();
|
||||
|
||||
assert!(err_log2 <= -(out_prec as f64), "{} {}", err_log2, -(out_prec as f64))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
pub fn bench_vec_znx_normalize<B: Backend>(c: &mut Criterion, label: &str)
|
||||
where
|
||||
Module<B>: VecZnxNormalize<B> + ModuleNew<B> + VecZnxNormalizeTmpBytes,
|
||||
@@ -261,10 +644,10 @@ where
|
||||
res.fill_uniform(50, &mut source);
|
||||
|
||||
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(module.vec_znx_normalize_tmp_bytes());
|
||||
|
||||
let res_offset: i64 = 0;
|
||||
move || {
|
||||
for i in 0..cols {
|
||||
module.vec_znx_normalize(base2k, &mut res, i, base2k, &a, i, scratch.borrow());
|
||||
module.vec_znx_normalize(&mut res, base2k, res_offset, i, &a, base2k, i, scratch.borrow());
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
@@ -326,71 +709,3 @@ where
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_vec_znx_normalize_conv() {
|
||||
let n: usize = 8;
|
||||
|
||||
let mut carry: Vec<i64> = vec![0i64; 2 * n];
|
||||
|
||||
use crate::reference::znx::ZnxRef;
|
||||
use rug::ops::SubAssignRound;
|
||||
use rug::{Float, float::Round};
|
||||
|
||||
let mut source: Source = Source::new([1u8; 32]);
|
||||
|
||||
let prec: usize = 128;
|
||||
|
||||
let mut data: Vec<i128> = vec![0i128; n];
|
||||
|
||||
data.iter_mut().for_each(|x| *x = source.next_i128());
|
||||
|
||||
for start_base2k in 1..50 {
|
||||
for end_base2k in 1..50 {
|
||||
let end_size: usize = prec.div_ceil(end_base2k);
|
||||
|
||||
let mut want: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, end_size);
|
||||
want.encode_vec_i128(end_base2k, 0, prec, &data);
|
||||
vec_znx_normalize_inplace::<_, ZnxRef>(end_base2k, &mut want, 0, &mut carry);
|
||||
|
||||
// Creates a temporary poly where encoding is in start_base2k
|
||||
let mut tmp: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, prec.div_ceil(start_base2k));
|
||||
tmp.encode_vec_i128(start_base2k, 0, prec, &data);
|
||||
|
||||
vec_znx_normalize_inplace::<_, ZnxRef>(start_base2k, &mut tmp, 0, &mut carry);
|
||||
|
||||
let mut data_tmp: Vec<Float> = (0..n).map(|_| Float::with_val(prec as u32, 0)).collect();
|
||||
tmp.decode_vec_float(start_base2k, 0, &mut data_tmp);
|
||||
|
||||
let mut have: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, end_size);
|
||||
vec_znx_normalize::<_, _, ZnxRef>(end_base2k, &mut have, 0, start_base2k, &tmp, 0, &mut carry);
|
||||
|
||||
let out_prec: u32 = (end_size * end_base2k) as u32;
|
||||
|
||||
let mut data_want: Vec<Float> = (0..n).map(|_| Float::with_val(out_prec, 0)).collect();
|
||||
let mut data_res: Vec<Float> = (0..n).map(|_| Float::with_val(out_prec, 0)).collect();
|
||||
|
||||
have.decode_vec_float(end_base2k, 0, &mut data_want);
|
||||
want.decode_vec_float(end_base2k, 0, &mut data_res);
|
||||
|
||||
for i in 0..n {
|
||||
let mut err: Float = data_want[i].clone();
|
||||
err.sub_assign_round(&data_res[i], Round::Nearest);
|
||||
err = err.abs();
|
||||
|
||||
let err_log2: f64 = err
|
||||
.clone()
|
||||
.max(&Float::with_val(prec as u32, 1e-60))
|
||||
.log2()
|
||||
.to_f64();
|
||||
|
||||
assert!(
|
||||
err_log2 <= -(out_prec as f64) + 1.,
|
||||
"{} {}",
|
||||
err_log2,
|
||||
-(out_prec as f64) + 1.
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
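// Stand-alone sketch of the invariant checked above (the weight convention
// limb j ~ 2^-((j + 1) * base2k) is assumed here, not taken from the crate):
// re-encoding the same value from start_base2k limbs into end_base2k limbs
// must preserve it up to the coarser of the two precisions.
fn recompose_sketch(base2k: usize, limbs: &[i64]) -> f64 {
    limbs
        .iter()
        .enumerate()
        .map(|(j, &x)| x as f64 * (-(((j + 1) * base2k) as f64)).exp2())
        .sum()
}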
|
||||
|
||||
@@ -34,12 +34,7 @@ pub fn vec_znx_fill_normal_ref<R>(
|
||||
|
||||
let limb: usize = k.div_ceil(base2k) - 1;
|
||||
let scale: f64 = (1 << ((limb + 1) * base2k - k)) as f64;
|
||||
znx_fill_normal_f64_ref(
|
||||
res.at_mut(res_col, limb),
|
||||
sigma * scale,
|
||||
bound * scale,
|
||||
source,
|
||||
)
|
||||
znx_fill_normal_f64_ref(res.at_mut(res_col, limb), sigma * scale, bound * scale, source)
|
||||
}
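// Worked example of the limb/scale arithmetic used above (values assumed):
// with base2k = 12 and k = 30, limb = ceil(30 / 12) - 1 = 2 and
// scale = 2^(3 * 12 - 30) = 2^6 = 64, i.e. the noise is written into the
// limb that covers bit k and rescaled to compensate for the part of that
// limb lying below precision k.
fn noise_limb_and_scale(base2k: usize, k: usize) -> (usize, f64) {
    let limb = k.div_ceil(base2k) - 1;
    let scale = (1u64 << ((limb + 1) * base2k - k)) as f64;
    (limb, scale)
}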
|
||||
|
||||
pub fn vec_znx_add_normal_ref<R>(
|
||||
@@ -62,10 +57,5 @@ pub fn vec_znx_add_normal_ref<R>(
|
||||
|
||||
let limb: usize = k.div_ceil(base2k) - 1;
|
||||
let scale: f64 = (1 << ((limb + 1) * base2k - k)) as f64;
|
||||
znx_add_normal_f64_ref(
|
||||
res.at_mut(res_col, limb),
|
||||
sigma * scale,
|
||||
bound * scale,
|
||||
source,
|
||||
)
|
||||
znx_add_normal_f64_ref(res.at_mut(res_col, limb), sigma * scale, bound * scale, source)
|
||||
}
|
||||
|
||||
@@ -5,13 +5,10 @@ use criterion::{BenchmarkId, Criterion};
|
||||
use crate::{
|
||||
api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxLsh, VecZnxLshInplace, VecZnxRsh, VecZnxRshInplace},
|
||||
layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
|
||||
reference::{
|
||||
vec_znx::vec_znx_copy,
|
||||
znx::{
|
||||
ZnxCopy, ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly,
|
||||
ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace,
|
||||
ZnxZero,
|
||||
},
|
||||
reference::znx::{
|
||||
ZnxCopy, ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly,
|
||||
ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace,
|
||||
ZnxZero,
|
||||
},
|
||||
source::Source,
|
||||
};
|
||||
@@ -54,10 +51,7 @@ where
|
||||
|
||||
(0..size - steps).for_each(|j| {
|
||||
let (lhs, rhs) = res_raw.split_at_mut(slice_size * (j + steps));
|
||||
ZNXARI::znx_copy(
|
||||
&mut lhs[start + j * slice_size..end + j * slice_size],
|
||||
&rhs[start..end],
|
||||
);
|
||||
ZNXARI::znx_copy(&mut lhs[start + j * slice_size..end + j * slice_size], &rhs[start..end]);
|
||||
});
|
||||
|
||||
for j in size - steps..size {
|
||||
@@ -65,16 +59,13 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
// Inplace normalization with left shift of k % base2k
|
||||
if !k.is_multiple_of(base2k) {
|
||||
for j in (0..size - steps).rev() {
|
||||
if j == size - steps - 1 {
|
||||
ZNXARI::znx_normalize_first_step_inplace(base2k, k_rem, res.at_mut(res_col, j), carry);
|
||||
} else if j == 0 {
|
||||
ZNXARI::znx_normalize_final_step_inplace(base2k, k_rem, res.at_mut(res_col, j), carry);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step_inplace(base2k, k_rem, res.at_mut(res_col, j), carry);
|
||||
}
|
||||
for j in (0..size - steps).rev() {
|
||||
if j == size - steps - 1 {
|
||||
ZNXARI::znx_normalize_first_step_inplace(base2k, k_rem, res.at_mut(res_col, j), carry);
|
||||
} else if j == 0 {
|
||||
ZNXARI::znx_normalize_final_step_inplace(base2k, k_rem, res.at_mut(res_col, j), carry);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step_inplace(base2k, k_rem, res.at_mut(res_col, j), carry);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -104,38 +95,13 @@ where
|
||||
|
||||
// Simply a left shifted normalization of limbs
|
||||
// by k/base2k and intra-limb by base2k - k%base2k
|
||||
if !k.is_multiple_of(base2k) {
|
||||
for j in (0..min_size).rev() {
|
||||
if j == min_size - 1 {
|
||||
ZNXARI::znx_normalize_first_step(
|
||||
base2k,
|
||||
k_rem,
|
||||
res.at_mut(res_col, j),
|
||||
a.at(a_col, j + steps),
|
||||
carry,
|
||||
);
|
||||
} else if j == 0 {
|
||||
ZNXARI::znx_normalize_final_step(
|
||||
base2k,
|
||||
k_rem,
|
||||
res.at_mut(res_col, j),
|
||||
a.at(a_col, j + steps),
|
||||
carry,
|
||||
);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step(
|
||||
base2k,
|
||||
k_rem,
|
||||
res.at_mut(res_col, j),
|
||||
a.at(a_col, j + steps),
|
||||
carry,
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// If k % base2k = 0, then this is simply a copy.
|
||||
for j in (0..min_size).rev() {
|
||||
ZNXARI::znx_copy(res.at_mut(res_col, j), a.at(a_col, j + steps));
|
||||
for j in (0..min_size).rev() {
|
||||
if j == min_size - 1 {
|
||||
ZNXARI::znx_normalize_first_step(base2k, k_rem, res.at_mut(res_col, j), a.at(a_col, j + steps), carry);
|
||||
} else if j == 0 {
|
||||
ZNXARI::znx_normalize_final_step(base2k, k_rem, res.at_mut(res_col, j), a.at(a_col, j + steps), carry);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step(base2k, k_rem, res.at_mut(res_col, j), a.at(a_col, j + steps), carry);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -146,10 +112,10 @@ where
|
||||
}
|
||||
|
||||
pub fn vec_znx_rsh_tmp_bytes(n: usize) -> usize {
|
||||
n * size_of::<i64>()
|
||||
2 * n * size_of::<i64>()
|
||||
}
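// (The right-shift path below splits this scratch area into an n-word carry
// buffer and an n-word temporary, hence the factor of 2.)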
|
||||
|
||||
pub fn vec_znx_rsh_inplace<R, ZNXARI>(base2k: usize, k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
|
||||
pub fn vec_znx_rsh_inplace<R, ZNXARI>(base2k: usize, k: usize, res: &mut R, res_col: usize, tmp: &mut [i64])
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
ZNXARI: ZnxZero
|
||||
@@ -163,76 +129,48 @@ where
|
||||
{
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
let n: usize = res.n();
|
||||
let cols: usize = res.cols();
|
||||
|
||||
let size: usize = res.size();
|
||||
|
||||
let mut steps: usize = k / base2k;
|
||||
let k_rem: usize = k % base2k;
|
||||
|
||||
if k == 0 {
|
||||
return;
|
||||
}
|
||||
|
||||
if steps >= size {
|
||||
for j in 0..size {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
let start: usize = n * res_col;
|
||||
let end: usize = start + n;
|
||||
let slice_size: usize = n * cols;
|
||||
|
||||
if !k.is_multiple_of(base2k) {
    // We rsh by an additional base2k and then lsh by base2k - (k % base2k).
    // This lets us reuse the efficient normalization code, avoids overflows
    // and produces an output that is normalized.
    steps += 1;
}
|
||||
|
||||
// All limbs of a that would fall outside of the limbs of res are discarded,
// but the carry still needs to be computed.
|
||||
(size - steps..size).rev().for_each(|j| {
|
||||
if j == size - 1 {
|
||||
ZNXARI::znx_normalize_first_step_carry_only(base2k, base2k - k_rem, res.at(res_col, j), carry);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step_carry_only(base2k, base2k - k_rem, res.at(res_col, j), carry);
|
||||
}
|
||||
});
|
||||
let (carry, tmp) = tmp[..2 * n].split_at_mut(n);
|
||||
|
||||
// Continues with shifted normalization
|
||||
let res_raw: &mut [i64] = res.raw_mut();
|
||||
(steps..size).rev().for_each(|j| {
|
||||
let (lhs, rhs) = res_raw.split_at_mut(slice_size * j);
|
||||
let rhs_slice: &mut [i64] = &mut rhs[start..end];
|
||||
let lhs_slice: &[i64] = &lhs[(j - steps) * slice_size + start..(j - steps) * slice_size + end];
|
||||
ZNXARI::znx_normalize_middle_step(base2k, base2k - k_rem, rhs_slice, lhs_slice, carry);
|
||||
});
|
||||
let lsh: usize = (base2k - k_rem) % base2k;
|
||||
|
||||
// Propagates carry on the rest of the limbs of res
|
||||
for j in (0..steps).rev() {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
if j == 0 {
|
||||
ZNXARI::znx_normalize_final_step_inplace(base2k, base2k - k_rem, res.at_mut(res_col, j), carry);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step_inplace(base2k, base2k - k_rem, res.at_mut(res_col, j), carry);
|
||||
}
|
||||
// All limbs of a that would fall outside of the limbs of res are discarded,
// but the carry still needs to be computed.
|
||||
for j in 0..steps {
|
||||
if j == 0 {
|
||||
ZNXARI::znx_normalize_first_step_carry_only(base2k, lsh, res.at(res_col, size - j - 1), carry);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step_carry_only(base2k, lsh, res.at(res_col, size - j - 1), carry);
|
||||
}
|
||||
} else {
|
||||
// Shift by multiples of base2k
|
||||
let res_raw: &mut [i64] = res.raw_mut();
|
||||
(steps..size).rev().for_each(|j| {
|
||||
let (lhs, rhs) = res_raw.split_at_mut(slice_size * j);
|
||||
ZNXARI::znx_copy(
|
||||
&mut rhs[start..end],
|
||||
&lhs[(j - steps) * slice_size + start..(j - steps) * slice_size + end],
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
// Zeroes the top
|
||||
(0..steps).for_each(|j| {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
});
|
||||
// Continues with shifted normalization
|
||||
for j in 0..size - steps {
|
||||
ZNXARI::znx_copy(tmp, res.at(res_col, size - steps - j - 1));
|
||||
ZNXARI::znx_normalize_middle_step_inplace(base2k, lsh, tmp, carry);
|
||||
ZNXARI::znx_copy(res.at_mut(res_col, size - j - 1), tmp);
|
||||
}
|
||||
|
||||
// Propagates carry on the rest of the limbs of res
|
||||
for j in 0..steps {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
if j == 0 {
|
||||
ZNXARI::znx_normalize_final_step_inplace(base2k, lsh, res.at_mut(res_col, steps - j - 1), carry);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step_inplace(base2k, lsh, res.at_mut(res_col, steps - j - 1), carry);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -259,90 +197,59 @@ where
|
||||
let mut steps: usize = k / base2k;
|
||||
let k_rem: usize = k % base2k;
|
||||
|
||||
if k == 0 {
|
||||
vec_znx_copy::<_, _, ZNXARI>(&mut res, res_col, &a, a_col);
|
||||
return;
|
||||
}
|
||||
|
||||
if steps >= res_size {
|
||||
for j in 0..res_size {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if !k.is_multiple_of(base2k) {
    // We rsh by an additional base2k and then lsh by base2k - (k % base2k).
    // This lets us reuse the efficient normalization code, avoids overflows
    // and produces an output that is normalized.
    steps += 1;
}
|
||||
|
||||
// All limbs of a that are moved outside of the limbs of res are discarded,
// but the carry still needs to be computed.
|
||||
for j in (res_size..a_size + steps).rev() {
|
||||
if j == a_size + steps - 1 {
|
||||
ZNXARI::znx_normalize_first_step_carry_only(base2k, base2k - k_rem, a.at(a_col, j - steps), carry);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step_carry_only(base2k, base2k - k_rem, a.at(a_col, j - steps), carry);
|
||||
}
|
||||
let lsh: usize = (base2k - k_rem) % base2k; // 0 if k | base2k
|
||||
let res_end: usize = res_size.min(steps);
|
||||
let res_start: usize = res_size.min(a_size + steps);
|
||||
let a_start: usize = a_size.min(res_size.saturating_sub(steps));
|
||||
|
||||
// All limbs of a that are moved outside of the limbs of res are discarded,
// but the carry still needs to be computed.
|
||||
let a_out_range: usize = a_size.saturating_sub(a_start);
|
||||
|
||||
for j in 0..a_out_range {
|
||||
if j == 0 {
|
||||
ZNXARI::znx_normalize_first_step_carry_only(base2k, lsh, a.at(a_col, a_size - j - 1), carry);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step_carry_only(base2k, lsh, a.at(a_col, a_size - j - 1), carry);
|
||||
}
|
||||
}
|
||||
|
||||
// Avoids overflow of the limbs of res
|
||||
let min_size: usize = res_size.min(a_size + steps);
|
||||
if a_out_range == 0 {
|
||||
ZNXARI::znx_zero(carry);
|
||||
}
|
||||
|
||||
// Zeroes lower limbs of res if a_size + steps < res_size
|
||||
(min_size..res_size).for_each(|j| {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
});
|
||||
// Zeroes lower limbs of res if a_size + steps < res_size
|
||||
for j in 0..res_size {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
}
|
||||
|
||||
// Continues with shifted normalization
|
||||
for j in (steps..min_size).rev() {
|
||||
// Case where no limb of a was previously discarded
|
||||
if res_size.saturating_sub(steps) >= a_size && j == min_size - 1 {
|
||||
ZNXARI::znx_normalize_first_step(
|
||||
base2k,
|
||||
base2k - k_rem,
|
||||
res.at_mut(res_col, j),
|
||||
a.at(a_col, j - steps),
|
||||
carry,
|
||||
);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step(
|
||||
base2k,
|
||||
base2k - k_rem,
|
||||
res.at_mut(res_col, j),
|
||||
a.at(a_col, j - steps),
|
||||
carry,
|
||||
);
|
||||
}
|
||||
// Continues with shifted normalization
|
||||
let mid_range: usize = res_start.saturating_sub(res_end);
|
||||
|
||||
for j in 0..mid_range {
|
||||
ZNXARI::znx_normalize_middle_step(
|
||||
base2k,
|
||||
lsh,
|
||||
res.at_mut(res_col, res_start - j - 1),
|
||||
a.at(a_col, a_start - j - 1),
|
||||
carry,
|
||||
);
|
||||
}
|
||||
|
||||
// Propagates carry on the rest of the limbs of res
|
||||
for j in 0..res_end {
|
||||
if j == res_end - 1 {
|
||||
ZNXARI::znx_normalize_final_step_inplace(base2k, lsh, res.at_mut(res_col, res_end - j - 1), carry);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step_inplace(base2k, lsh, res.at_mut(res_col, res_end - j - 1), carry);
|
||||
}
|
||||
|
||||
// Propagates carry on the rest of the limbs of res
|
||||
for j in (0..steps).rev() {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
if j == 0 {
|
||||
ZNXARI::znx_normalize_final_step_inplace(base2k, base2k - k_rem, res.at_mut(res_col, j), carry);
|
||||
} else {
|
||||
ZNXARI::znx_normalize_middle_step_inplace(base2k, base2k - k_rem, res.at_mut(res_col, j), carry);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let min_size: usize = res_size.min(a_size + steps);
|
||||
|
||||
// Zeroes the top
|
||||
(0..steps).for_each(|j| {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
});
|
||||
|
||||
// Shift a into res, up to the maximum
|
||||
for j in (steps..min_size).rev() {
|
||||
ZNXARI::znx_copy(res.at_mut(res_col, j), a.at(a_col, j - steps));
|
||||
}
|
||||
|
||||
// Zeroes bottom if a_size + steps < res_size
|
||||
(min_size..res_size).for_each(|j| {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -373,7 +280,7 @@ where
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
|
||||
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(n * size_of::<i64>());
|
||||
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(vec_znx_lsh_tmp_bytes(n));
|
||||
|
||||
// Fill a with random i64
|
||||
a.fill_uniform(50, &mut source);
|
||||
@@ -423,7 +330,7 @@ where
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
|
||||
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(n * size_of::<i64>());
|
||||
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(vec_znx_lsh_tmp_bytes(n));
|
||||
|
||||
// Fill a with random i64
|
||||
a.fill_uniform(50, &mut source);
|
||||
@@ -473,7 +380,7 @@ where
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
|
||||
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(n * size_of::<i64>());
|
||||
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(vec_znx_rsh_tmp_bytes(n));
|
||||
|
||||
// Fill a with random i64
|
||||
a.fill_uniform(50, &mut source);
|
||||
@@ -523,7 +430,7 @@ where
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
|
||||
|
||||
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(n * size_of::<i64>());
|
||||
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(vec_znx_rsh_tmp_bytes(n));
|
||||
|
||||
// Fill a with random i64
|
||||
a.fill_uniform(50, &mut source);
|
||||
@@ -552,8 +459,8 @@ mod tests {
|
||||
layouts::{FillUniform, VecZnx, ZnxView},
|
||||
reference::{
|
||||
vec_znx::{
|
||||
vec_znx_copy, vec_znx_lsh, vec_znx_lsh_inplace, vec_znx_normalize_inplace, vec_znx_rsh, vec_znx_rsh_inplace,
|
||||
vec_znx_sub_inplace,
|
||||
vec_znx_copy, vec_znx_lsh, vec_znx_lsh_inplace, vec_znx_lsh_tmp_bytes, vec_znx_normalize_inplace, vec_znx_rsh,
|
||||
vec_znx_rsh_inplace, vec_znx_rsh_tmp_bytes, vec_znx_sub_inplace,
|
||||
},
|
||||
znx::ZnxRef,
|
||||
},
|
||||
@@ -572,7 +479,7 @@ mod tests {
|
||||
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut carry: Vec<i64> = vec![0i64; n];
|
||||
let mut carry: Vec<i64> = vec![0i64; vec_znx_lsh_tmp_bytes(n) / size_of::<i64>()];
|
||||
|
||||
let base2k: usize = 50;
|
||||
|
||||
@@ -604,7 +511,7 @@ mod tests {
|
||||
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
|
||||
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
|
||||
|
||||
let mut carry: Vec<i64> = vec![0i64; n];
|
||||
let mut carry: Vec<i64> = vec![0i64; vec_znx_rsh_tmp_bytes(n) / size_of::<i64>()];
|
||||
|
||||
let base2k: usize = 50;
|
||||
|
||||
|
||||
@@ -22,10 +22,7 @@ where
|
||||
{
|
||||
assert_eq!(tmp.len(), a.n());
|
||||
|
||||
assert!(
|
||||
_n_out < _n_in,
|
||||
"invalid a: output ring degree should be smaller"
|
||||
);
|
||||
assert!(_n_out < _n_in, "invalid a: output ring degree should be smaller");
|
||||
|
||||
res[1..].iter_mut().for_each(|bi| {
|
||||
assert_eq!(
|
||||
|
||||
@@ -12,10 +12,6 @@ pub fn znx_automorphism_ref(p: i64, res: &mut [i64], a: &[i64]) {
|
||||
res[0] = a[0];
|
||||
for ai in a.iter().take(n).skip(1) {
|
||||
k = (k + p_2n) & mask;
|
||||
if k < n {
|
||||
res[k] = *ai
|
||||
} else {
|
||||
res[k - n] = -*ai
|
||||
}
|
||||
if k < n { res[k] = *ai } else { res[k - n] = -*ai }
|
||||
}
|
||||
}
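// Self-contained sketch of what the loop above computes (semantics assumed
// from the surrounding code): the automorphism X -> X^p on Z[X]/(X^n + 1)
// sends coefficient i of `a` to index i * p mod 2n, negating it whenever the
// destination index wraps past n. Assumes p is odd so the map is a bijection.
fn znx_automorphism_sketch(p: i64, a: &[i64]) -> Vec<i64> {
    let n = a.len();
    let two_n = 2 * n as i64;
    let mut res = vec![0i64; n];
    for (i, &ai) in a.iter().enumerate() {
        let k = (i as i64 * p).rem_euclid(two_n) as usize;
        if k < n {
            res[k] = ai;
        } else {
            res[k - n] = -ai;
        }
    }
    res
}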
|
||||
|
||||
@@ -33,9 +33,9 @@ pub fn znx_normalize_first_step_carry_only_ref(base2k: usize, lsh: usize, x: &[i
|
||||
*c = get_carry_i64(base2k, *x, get_digit_i64(base2k, *x));
|
||||
});
|
||||
} else {
|
||||
let basek_lsh: usize = base2k - lsh;
|
||||
let base2k_lsh: usize = base2k - lsh;
|
||||
x.iter().zip(carry.iter_mut()).for_each(|(x, c)| {
|
||||
*c = get_carry_i64(basek_lsh, *x, get_digit_i64(basek_lsh, *x));
|
||||
*c = get_carry_i64(base2k_lsh, *x, get_digit_i64(base2k_lsh, *x));
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -55,10 +55,10 @@ pub fn znx_normalize_first_step_inplace_ref(base2k: usize, lsh: usize, x: &mut [
|
||||
*x = digit;
|
||||
});
|
||||
} else {
|
||||
let basek_lsh: usize = base2k - lsh;
|
||||
let base2k_lsh: usize = base2k - lsh;
|
||||
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
|
||||
let digit: i64 = get_digit_i64(basek_lsh, *x);
|
||||
*c = get_carry_i64(basek_lsh, *x, digit);
|
||||
let digit: i64 = get_digit_i64(base2k_lsh, *x);
|
||||
*c = get_carry_i64(base2k_lsh, *x, digit);
|
||||
*x = digit << lsh;
|
||||
});
|
||||
}
|
||||
@@ -80,10 +80,10 @@ pub fn znx_normalize_first_step_ref(base2k: usize, lsh: usize, x: &mut [i64], a:
|
||||
*x = digit;
|
||||
});
|
||||
} else {
|
||||
let basek_lsh: usize = base2k - lsh;
|
||||
let base2k_lsh: usize = base2k - lsh;
|
||||
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
|
||||
let digit: i64 = get_digit_i64(basek_lsh, *a);
|
||||
*c = get_carry_i64(basek_lsh, *a, digit);
|
||||
let digit: i64 = get_digit_i64(base2k_lsh, *a);
|
||||
*c = get_carry_i64(base2k_lsh, *a, digit);
|
||||
*x = digit << lsh;
|
||||
});
|
||||
}
|
||||
@@ -104,10 +104,10 @@ pub fn znx_normalize_middle_step_carry_only_ref(base2k: usize, lsh: usize, x: &[
|
||||
*c = carry + get_carry_i64(base2k, digit_plus_c, get_digit_i64(base2k, digit_plus_c));
|
||||
});
|
||||
} else {
|
||||
let basek_lsh: usize = base2k - lsh;
|
||||
let base2k_lsh: usize = base2k - lsh;
|
||||
x.iter().zip(carry.iter_mut()).for_each(|(x, c)| {
|
||||
let digit: i64 = get_digit_i64(basek_lsh, *x);
|
||||
let carry: i64 = get_carry_i64(basek_lsh, *x, digit);
|
||||
let digit: i64 = get_digit_i64(base2k_lsh, *x);
|
||||
let carry: i64 = get_carry_i64(base2k_lsh, *x, digit);
|
||||
let digit_plus_c: i64 = (digit << lsh) + *c;
|
||||
*c = carry + get_carry_i64(base2k, digit_plus_c, get_digit_i64(base2k, digit_plus_c));
|
||||
});
|
||||
@@ -131,10 +131,10 @@ pub fn znx_normalize_middle_step_inplace_ref(base2k: usize, lsh: usize, x: &mut
|
||||
*c = carry + get_carry_i64(base2k, digit_plus_c, *x);
|
||||
});
|
||||
} else {
|
||||
let basek_lsh: usize = base2k - lsh;
|
||||
let base2k_lsh: usize = base2k - lsh;
|
||||
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
|
||||
let digit: i64 = get_digit_i64(basek_lsh, *x);
|
||||
let carry: i64 = get_carry_i64(basek_lsh, *x, digit);
|
||||
let digit: i64 = get_digit_i64(base2k_lsh, *x);
|
||||
let carry: i64 = get_carry_i64(base2k_lsh, *x, digit);
|
||||
let digit_plus_c: i64 = (digit << lsh) + *c;
|
||||
*x = get_digit_i64(base2k, digit_plus_c);
|
||||
*c = carry + get_carry_i64(base2k, digit_plus_c, *x);
|
||||
@@ -178,10 +178,10 @@ pub fn znx_normalize_middle_step_ref(base2k: usize, lsh: usize, x: &mut [i64], a
|
||||
*c = carry + get_carry_i64(base2k, digit_plus_c, *x);
|
||||
});
|
||||
} else {
|
||||
let basek_lsh: usize = base2k - lsh;
|
||||
let base2k_lsh: usize = base2k - lsh;
|
||||
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
|
||||
let digit: i64 = get_digit_i64(basek_lsh, *a);
|
||||
let carry: i64 = get_carry_i64(basek_lsh, *a, digit);
|
||||
let digit: i64 = get_digit_i64(base2k_lsh, *a);
|
||||
let carry: i64 = get_carry_i64(base2k_lsh, *a, digit);
|
||||
let digit_plus_c: i64 = (digit << lsh) + *c;
|
||||
*x = get_digit_i64(base2k, digit_plus_c);
|
||||
*c = carry + get_carry_i64(base2k, digit_plus_c, *x);
|
||||
@@ -202,9 +202,9 @@ pub fn znx_normalize_final_step_inplace_ref(base2k: usize, lsh: usize, x: &mut [
|
||||
*x = get_digit_i64(base2k, get_digit_i64(base2k, *x) + *c);
|
||||
});
|
||||
} else {
|
||||
let basek_lsh: usize = base2k - lsh;
|
||||
let base2k_lsh: usize = base2k - lsh;
|
||||
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
|
||||
*x = get_digit_i64(base2k, (get_digit_i64(basek_lsh, *x) << lsh) + *c);
|
||||
*x = get_digit_i64(base2k, (get_digit_i64(base2k_lsh, *x) << lsh) + *c);
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -221,9 +221,9 @@ pub fn znx_normalize_final_step_ref(base2k: usize, lsh: usize, x: &mut [i64], a:
|
||||
*x = get_digit_i64(base2k, get_digit_i64(base2k, *a) + *c);
|
||||
});
|
||||
} else {
|
||||
let basek_lsh: usize = base2k - lsh;
|
||||
let base2k_lsh: usize = base2k - lsh;
|
||||
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
|
||||
*x = get_digit_i64(base2k, (get_digit_i64(basek_lsh, *a) << lsh) + *c);
|
||||
*x = get_digit_i64(base2k, (get_digit_i64(base2k_lsh, *a) << lsh) + *c);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,19 +1,80 @@
|
||||
use rand::RngCore;
|
||||
|
||||
use crate::{
|
||||
api::{
|
||||
BivariateTensoring, ModuleN, ScratchOwnedAlloc, ScratchOwnedBorrow, ScratchTakeBasic, TakeSlice, VecZnxBigAlloc,
|
||||
VecZnxBigNormalize, VecZnxDftAlloc, VecZnxDftApply, VecZnxIdftApplyTmpA, VecZnxNormalizeInplace,
|
||||
CnvPVecAlloc, Convolution, ModuleN, ScratchOwnedAlloc, ScratchOwnedBorrow, ScratchTakeBasic, TakeSlice, VecZnxAdd,
|
||||
VecZnxBigAlloc, VecZnxBigNormalize, VecZnxCopy, VecZnxDftAlloc, VecZnxDftApply, VecZnxIdftApplyTmpA,
|
||||
VecZnxNormalizeInplace,
|
||||
},
|
||||
layouts::{
|
||||
Backend, FillUniform, Scratch, ScratchOwned, VecZnx, VecZnxBig, VecZnxDft, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView,
|
||||
ZnxViewMut, ZnxZero,
|
||||
Backend, CnvPVecL, CnvPVecR, FillUniform, Scratch, ScratchOwned, VecZnx, VecZnxBig, VecZnxDft, VecZnxToMut, VecZnxToRef,
|
||||
ZnxInfos, ZnxView, ZnxViewMut, ZnxZero,
|
||||
},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
pub fn test_bivariate_tensoring<M, BE: Backend>(module: &M)
|
||||
pub fn test_convolution_by_const<M, BE: Backend>(module: &M)
|
||||
where
|
||||
M: ModuleN + Convolution<BE> + VecZnxBigNormalize<BE> + VecZnxNormalizeInplace<BE> + VecZnxBigAlloc<BE>,
|
||||
Scratch<BE>: ScratchTakeBasic,
|
||||
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
|
||||
{
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let base2k: usize = 12;
|
||||
|
||||
let a_cols: usize = 2;
|
||||
let a_size: usize = 15;
|
||||
let b_size: usize = 15;
|
||||
let res_size: usize = a_size + b_size;
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), a_cols, a_size);
|
||||
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), 1, b_size);
|
||||
|
||||
let mut res_want: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), 1, res_size);
|
||||
let mut res_have: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), 1, res_size);
|
||||
let mut res_big: VecZnxBig<Vec<u8>, BE> = module.vec_znx_big_alloc(1, res_size);
|
||||
|
||||
a.fill_uniform(base2k, &mut source);
|
||||
|
||||
let mut b_const = vec![0i64; b_size];
|
||||
let mask = (1 << base2k) - 1;
|
||||
for (j, x) in b_const[..1].iter_mut().enumerate() {
|
||||
let r = source.next_u64() & mask;
|
||||
*x = ((r << (64 - base2k)) as i64) >> (64 - base2k);
|
||||
b.at_mut(0, j)[0] = *x
|
||||
}
|
||||
|
||||
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(module.cnv_by_const_apply_tmp_bytes(res_size, 0, a_size, b_size));
|
||||
|
||||
for a_col in 0..a.cols() {
|
||||
for offset in 0..res_size {
|
||||
module.cnv_by_const_apply(&mut res_big, offset, 0, &a, a_col, &b_const, scratch.borrow());
|
||||
module.vec_znx_big_normalize(&mut res_have, base2k, 0, 0, &res_big, base2k, 0, scratch.borrow());
|
||||
|
||||
bivariate_convolution_naive(
|
||||
module,
|
||||
base2k,
|
||||
(offset + 1) as i64,
|
||||
&mut res_want,
|
||||
0,
|
||||
&a,
|
||||
a_col,
|
||||
&b,
|
||||
0,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
assert_eq!(res_want, res_have);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn test_convolution<M, BE: Backend>(module: &M)
|
||||
where
|
||||
M: ModuleN
|
||||
+ BivariateTensoring<BE>
|
||||
+ Convolution<BE>
|
||||
+ CnvPVecAlloc<BE>
|
||||
+ VecZnxDftAlloc<BE>
|
||||
+ VecZnxDftApply<BE>
|
||||
+ VecZnxIdftApplyTmpA<BE>
|
||||
@@ -27,56 +88,199 @@ where
|
||||
|
||||
let base2k: usize = 12;
|
||||
|
||||
let a_cols: usize = 3;
|
||||
let b_cols: usize = 3;
|
||||
let a_size: usize = 3;
|
||||
let b_size: usize = 3;
|
||||
let c_cols: usize = a_cols + b_cols - 1;
|
||||
let c_size: usize = a_size + b_size;
|
||||
let a_cols: usize = 2;
|
||||
let b_cols: usize = 2;
|
||||
let a_size: usize = 15;
|
||||
let b_size: usize = 15;
|
||||
let res_size: usize = a_size + b_size;
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), a_cols, a_size);
|
||||
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), b_cols, b_size);
|
||||
|
||||
let mut c_want: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), c_cols, c_size);
|
||||
let mut c_have: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), c_cols, c_size);
|
||||
let mut c_have_dft: VecZnxDft<Vec<u8>, BE> = module.vec_znx_dft_alloc(c_cols, c_size);
|
||||
let mut c_have_big: VecZnxBig<Vec<u8>, BE> = module.vec_znx_big_alloc(c_cols, c_size);
|
||||
|
||||
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(module.convolution_tmp_bytes(b_size));
|
||||
let mut res_want: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), 1, res_size);
|
||||
let mut res_have: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), 1, res_size);
|
||||
let mut res_dft: VecZnxDft<Vec<u8>, BE> = module.vec_znx_dft_alloc(1, res_size);
|
||||
let mut res_big: VecZnxBig<Vec<u8>, BE> = module.vec_znx_big_alloc(1, res_size);
|
||||
|
||||
a.fill_uniform(base2k, &mut source);
|
||||
b.fill_uniform(base2k, &mut source);
|
||||
|
||||
let mut b_dft: VecZnxDft<Vec<u8>, BE> = module.vec_znx_dft_alloc(b_cols, b_size);
|
||||
for i in 0..b.cols() {
|
||||
module.vec_znx_dft_apply(1, 0, &mut b_dft, i, &b, i);
|
||||
let mut a_prep: CnvPVecL<Vec<u8>, BE> = module.cnv_pvec_left_alloc(a_cols, a_size);
|
||||
let mut b_prep: CnvPVecR<Vec<u8>, BE> = module.cnv_pvec_right_alloc(b_cols, b_size);
|
||||
|
||||
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(
|
||||
module
|
||||
.cnv_apply_dft_tmp_bytes(res_size, 0, a_size, b_size)
|
||||
.max(module.cnv_prepare_left_tmp_bytes(res_size, a_size))
|
||||
.max(module.cnv_prepare_right_tmp_bytes(res_size, b_size)),
|
||||
);
|
||||
|
||||
module.cnv_prepare_left(&mut a_prep, &a, scratch.borrow());
|
||||
module.cnv_prepare_right(&mut b_prep, &b, scratch.borrow());
|
||||
|
||||
for a_col in 0..a.cols() {
|
||||
for b_col in 0..b.cols() {
|
||||
for offset in 0..res_size {
|
||||
module.cnv_apply_dft(&mut res_dft, offset, 0, &a_prep, a_col, &b_prep, b_col, scratch.borrow());
|
||||
|
||||
module.vec_znx_idft_apply_tmpa(&mut res_big, 0, &mut res_dft, 0);
|
||||
module.vec_znx_big_normalize(&mut res_have, base2k, 0, 0, &res_big, base2k, 0, scratch.borrow());
|
||||
|
||||
bivariate_convolution_naive(
|
||||
module,
|
||||
base2k,
|
||||
(offset + 1) as i64,
|
||||
&mut res_want,
|
||||
0,
|
||||
&a,
|
||||
a_col,
|
||||
&b,
|
||||
b_col,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
assert_eq!(res_want, res_have);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn test_convolution_pairwise<M, BE: Backend>(module: &M)
|
||||
where
|
||||
M: ModuleN
|
||||
+ Convolution<BE>
|
||||
+ CnvPVecAlloc<BE>
|
||||
+ VecZnxDftAlloc<BE>
|
||||
+ VecZnxDftApply<BE>
|
||||
+ VecZnxIdftApplyTmpA<BE>
|
||||
+ VecZnxBigNormalize<BE>
|
||||
+ VecZnxNormalizeInplace<BE>
|
||||
+ VecZnxBigAlloc<BE>
|
||||
+ VecZnxAdd
|
||||
+ VecZnxCopy,
|
||||
Scratch<BE>: ScratchTakeBasic,
|
||||
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
|
||||
{
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
let base2k: usize = 12;
|
||||
|
||||
let cols = 2;
|
||||
let a_size: usize = 15;
|
||||
let b_size: usize = 15;
|
||||
let res_size: usize = a_size + b_size;
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), cols, a_size);
|
||||
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), cols, b_size);
|
||||
let mut tmp_a: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), 1, a_size);
|
||||
let mut tmp_b: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), 1, b_size);
|
||||
|
||||
let mut res_want: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), 1, res_size);
|
||||
let mut res_have: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), 1, res_size);
|
||||
let mut res_dft: VecZnxDft<Vec<u8>, BE> = module.vec_znx_dft_alloc(1, res_size);
|
||||
let mut res_big: VecZnxBig<Vec<u8>, BE> = module.vec_znx_big_alloc(1, res_size);
|
||||
|
||||
a.fill_uniform(base2k, &mut source);
|
||||
b.fill_uniform(base2k, &mut source);
|
||||
|
||||
let mut a_prep: CnvPVecL<Vec<u8>, BE> = module.cnv_pvec_left_alloc(cols, a_size);
|
||||
let mut b_prep: CnvPVecR<Vec<u8>, BE> = module.cnv_pvec_right_alloc(cols, b_size);
|
||||
|
||||
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(
|
||||
module
|
||||
.cnv_pairwise_apply_dft_tmp_bytes(res_size, 0, a_size, b_size)
|
||||
.max(module.cnv_prepare_left_tmp_bytes(res_size, a_size))
|
||||
.max(module.cnv_prepare_right_tmp_bytes(res_size, b_size)),
|
||||
);
|
||||
|
||||
module.cnv_prepare_left(&mut a_prep, &a, scratch.borrow());
|
||||
module.cnv_prepare_right(&mut b_prep, &b, scratch.borrow());
|
||||
|
||||
for col_i in 0..cols {
|
||||
for col_j in 0..cols {
|
||||
for offset in 0..res_size {
|
||||
module.cnv_pairwise_apply_dft(&mut res_dft, offset, 0, &a_prep, &b_prep, col_i, col_j, scratch.borrow());
|
||||
|
||||
module.vec_znx_idft_apply_tmpa(&mut res_big, 0, &mut res_dft, 0);
|
||||
module.vec_znx_big_normalize(&mut res_have, base2k, 0, 0, &res_big, base2k, 0, scratch.borrow());
|
||||
|
||||
if col_i != col_j {
|
||||
module.vec_znx_add(&mut tmp_a, 0, &a, col_i, &a, col_j);
|
||||
module.vec_znx_add(&mut tmp_b, 0, &b, col_i, &b, col_j);
|
||||
} else {
|
||||
module.vec_znx_copy(&mut tmp_a, 0, &a, col_i);
|
||||
module.vec_znx_copy(&mut tmp_b, 0, &b, col_j);
|
||||
}
|
||||
|
||||
bivariate_convolution_naive(
|
||||
module,
|
||||
base2k,
|
||||
(offset + 1) as i64,
|
||||
&mut res_want,
|
||||
0,
|
||||
&tmp_a,
|
||||
0,
|
||||
&tmp_b,
|
||||
0,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
assert_eq!(res_want, res_have);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn bivariate_convolution_naive<R, A, B, M, BE: Backend>(
|
||||
module: &M,
|
||||
base2k: usize,
|
||||
k: i64,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
scratch: &mut Scratch<BE>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
B: VecZnxToRef,
|
||||
M: VecZnxNormalizeInplace<BE>,
|
||||
Scratch<BE>: TakeSlice,
|
||||
{
|
||||
let res: &mut VecZnx<&mut [u8]> = &mut res.to_mut();
|
||||
let a: &VecZnx<&[u8]> = &a.to_ref();
|
||||
let b: &VecZnx<&[u8]> = &b.to_ref();
|
||||
|
||||
for j in 0..res.size() {
|
||||
res.zero_at(res_col, j);
|
||||
}
|
||||
|
||||
for mut k in 0..(2 * c_size + 1) as i64 {
|
||||
k -= c_size as i64;
|
||||
for a_limb in 0..a.size() {
|
||||
for b_limb in 0..b.size() {
|
||||
let res_scale_abs = k.unsigned_abs() as usize;
|
||||
|
||||
module.bivariate_tensoring(k, &mut c_have_dft, &a, &b_dft, scratch.borrow());
|
||||
let mut res_limb: usize = a_limb + b_limb + 1;
|
||||
|
||||
for i in 0..c_cols {
|
||||
module.vec_znx_idft_apply_tmpa(&mut c_have_big, i, &mut c_have_dft, i);
|
||||
if k <= 0 {
|
||||
res_limb += res_scale_abs;
|
||||
|
||||
if res_limb < res.size() {
|
||||
negacyclic_convolution_naive_add(res.at_mut(res_col, res_limb), a.at(a_col, a_limb), b.at(b_col, b_limb));
|
||||
}
|
||||
} else if res_limb >= res_scale_abs {
|
||||
res_limb -= res_scale_abs;
|
||||
|
||||
if res_limb < res.size() {
|
||||
negacyclic_convolution_naive_add(res.at_mut(res_col, res_limb), a.at(a_col, a_limb), b.at(b_col, b_limb));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i in 0..c_cols {
|
||||
module.vec_znx_big_normalize(
|
||||
base2k,
|
||||
&mut c_have,
|
||||
i,
|
||||
base2k,
|
||||
&c_have_big,
|
||||
i,
|
||||
scratch.borrow(),
|
||||
);
|
||||
}
|
||||
|
||||
bivariate_tensoring_naive(module, base2k, k, &mut c_want, &a, &b, scratch.borrow());
|
||||
|
||||
assert_eq!(c_want, c_have);
|
||||
}
|
||||
|
||||
module.vec_znx_normalize_inplace(base2k, res, res_col, scratch);
|
||||
}
|
||||
|
||||
fn bivariate_tensoring_naive<R, A, B, M, BE: Backend>(
|
||||
@@ -154,3 +358,18 @@ fn negacyclic_convolution_naive_add(res: &mut [i64], a: &[i64], b: &[i64]) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn negacyclic_convolution_naive(res: &mut [i64], a: &[i64], b: &[i64]) {
|
||||
let n: usize = res.len();
|
||||
res.fill(0);
|
||||
for i in 0..n {
|
||||
let ai: i64 = a[i];
|
||||
let lim: usize = n - i;
|
||||
for j in 0..lim {
|
||||
res[i + j] += ai * b[j];
|
||||
}
|
||||
for j in lim..n {
|
||||
res[i + j - n] -= ai * b[j];
|
||||
}
|
||||
}
|
||||
}
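// Minimal sanity check of the naive negacyclic convolution above, added here
// purely as an illustration (not part of the original change): in
// Z[X]/(X^4 + 1) we have X^3 * X = X^4 = -1, so the wrapped coefficient
// flips its sign.
#[test]
fn test_negacyclic_convolution_naive_wraparound() {
    let a: [i64; 4] = [0, 0, 0, 1]; // X^3
    let b: [i64; 4] = [0, 1, 0, 0]; // X
    let mut res: [i64; 4] = [0; 4];
    negacyclic_convolution_naive(&mut res, &a, &b);
    assert_eq!(res, [-1, 0, 0, 0]);
}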
|
||||
|
||||
@@ -29,10 +29,7 @@ where
|
||||
receiver.read_from(&mut reader).expect("read_from failed");
|
||||
|
||||
// Ensure serialization round-trip correctness
|
||||
assert_eq!(
|
||||
&original, &receiver,
|
||||
"Deserialized object does not match the original"
|
||||
);
|
||||
assert_eq!(&original, &receiver, "Deserialized object does not match the original");
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -90,24 +90,8 @@ where
|
||||
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_big_normalize(
|
||||
base2k,
|
||||
&mut res_ref,
|
||||
j,
|
||||
base2k,
|
||||
&res_big_ref,
|
||||
j,
|
||||
scratch_ref.borrow(),
|
||||
);
|
||||
module_test.vec_znx_big_normalize(
|
||||
base2k,
|
||||
&mut res_test,
|
||||
j,
|
||||
base2k,
|
||||
&res_big_test,
|
||||
j,
|
||||
scratch_test.borrow(),
|
||||
);
|
||||
module_ref.vec_znx_big_normalize(&mut res_ref, base2k, 0, j, &res_big_ref, base2k, j, scratch_ref.borrow());
|
||||
module_test.vec_znx_big_normalize(&mut res_test, base2k, 0, j, &res_big_test, base2k, j, scratch_test.borrow());
|
||||
}
|
||||
|
||||
assert_eq!(res_ref, res_test);
|
||||
@@ -212,24 +196,8 @@ where
|
||||
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_big_normalize(
|
||||
base2k,
|
||||
&mut res_ref,
|
||||
j,
|
||||
base2k,
|
||||
&res_big_ref,
|
||||
j,
|
||||
scratch_ref.borrow(),
|
||||
);
|
||||
module_test.vec_znx_big_normalize(
|
||||
base2k,
|
||||
&mut res_test,
|
||||
j,
|
||||
base2k,
|
||||
&res_big_test,
|
||||
j,
|
||||
scratch_test.borrow(),
|
||||
);
|
||||
module_ref.vec_znx_big_normalize(&mut res_ref, base2k, 0, j, &res_big_ref, base2k, j, scratch_ref.borrow());
|
||||
module_test.vec_znx_big_normalize(&mut res_test, base2k, 0, j, &res_big_test, base2k, j, scratch_test.borrow());
|
||||
}
|
||||
|
||||
assert_eq!(res_ref, res_test);
|
||||
@@ -339,24 +307,8 @@ where
|
||||
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_big_normalize(
|
||||
base2k,
|
||||
&mut res_ref,
|
||||
j,
|
||||
base2k,
|
||||
&res_big_ref,
|
||||
j,
|
||||
scratch_ref.borrow(),
|
||||
);
|
||||
module_test.vec_znx_big_normalize(
|
||||
base2k,
|
||||
&mut res_test,
|
||||
j,
|
||||
base2k,
|
||||
&res_big_test,
|
||||
j,
|
||||
scratch_test.borrow(),
|
||||
);
|
||||
module_ref.vec_znx_big_normalize(&mut res_ref, base2k, 0, j, &res_big_ref, base2k, j, scratch_ref.borrow());
|
||||
module_test.vec_znx_big_normalize(&mut res_test, base2k, 0, j, &res_big_test, base2k, j, scratch_test.borrow());
|
||||
}
|
||||
|
||||
assert_eq!(res_ref, res_test);
|
||||
@@ -447,24 +399,8 @@ pub fn test_svp_apply_dft_to_dft_inplace<BR: Backend, BT: Backend>(
|
||||
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_big_normalize(
|
||||
base2k,
|
||||
&mut res_ref,
|
||||
j,
|
||||
base2k,
|
||||
&res_big_ref,
|
||||
j,
|
||||
scratch_ref.borrow(),
|
||||
);
|
||||
module_test.vec_znx_big_normalize(
|
||||
base2k,
|
||||
&mut res_test,
|
||||
j,
|
||||
base2k,
|
||||
&res_big_test,
|
||||
j,
|
||||
scratch_test.borrow(),
|
||||
);
|
||||
module_ref.vec_znx_big_normalize(&mut res_ref, base2k, 0, j, &res_big_ref, base2k, j, scratch_ref.borrow());
|
||||
module_test.vec_znx_big_normalize(&mut res_test, base2k, 0, j, &res_big_test, base2k, j, scratch_test.borrow());
|
||||
}
|
||||
|
||||
assert_eq!(res_ref, res_test);
|
||||
|
||||
@@ -7,8 +7,9 @@ use crate::{
|
||||
VecZnxFillNormal, VecZnxFillUniform, VecZnxLsh, VecZnxLshInplace, VecZnxLshTmpBytes, VecZnxMergeRings,
|
||||
VecZnxMergeRingsTmpBytes, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxMulXpMinusOneInplaceTmpBytes,
|
||||
VecZnxNegate, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate,
|
||||
VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, VecZnxRsh, VecZnxRshInplace, VecZnxSplitRing, VecZnxSplitRingTmpBytes,
|
||||
VecZnxSub, VecZnxSubInplace, VecZnxSubNegateInplace, VecZnxSubScalar, VecZnxSubScalarInplace, VecZnxSwitchRing,
|
||||
VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, VecZnxRsh, VecZnxRshInplace, VecZnxRshTmpBytes, VecZnxSplitRing,
|
||||
VecZnxSplitRingTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSubNegateInplace, VecZnxSubScalar, VecZnxSubScalarInplace,
|
||||
VecZnxSwitchRing,
|
||||
},
|
||||
layouts::{Backend, DigestU64, FillUniform, Module, ScalarZnx, ScratchOwned, VecZnx, ZnxInfos, ZnxView, ZnxViewMut},
|
||||
reference::znx::znx_copy_ref,
|
||||
@@ -341,10 +342,7 @@ where
|
||||
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_merge_rings_tmp_bytes());
|
||||
|
||||
for a_size in [1, 2, 3, 4] {
|
||||
let mut a: [VecZnx<Vec<u8>>; 2] = [
|
||||
VecZnx::alloc(n >> 1, cols, a_size),
|
||||
VecZnx::alloc(n >> 1, cols, a_size),
|
||||
];
|
||||
let mut a: [VecZnx<Vec<u8>>; 2] = [VecZnx::alloc(n >> 1, cols, a_size), VecZnx::alloc(n >> 1, cols, a_size)];
|
||||
|
||||
a.iter_mut().for_each(|ai| {
|
||||
ai.fill_uniform(base2k, &mut source);
|
||||
@@ -549,26 +547,20 @@ where
|
||||
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
|
||||
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
|
||||
|
||||
// Set d to garbage
|
||||
res_ref.fill_uniform(base2k, &mut source);
|
||||
res_test.fill_uniform(base2k, &mut source);
|
||||
for res_offset in -(base2k as i64)..=(base2k as i64) {
|
||||
// Set d to garbage
|
||||
res_ref.fill_uniform(base2k, &mut source);
|
||||
res_test.fill_uniform(base2k, &mut source);
|
||||
|
||||
// Reference
|
||||
for i in 0..cols {
|
||||
module_ref.vec_znx_normalize(base2k, &mut res_ref, i, base2k, &a, i, scratch_ref.borrow());
|
||||
module_test.vec_znx_normalize(
|
||||
base2k,
|
||||
&mut res_test,
|
||||
i,
|
||||
base2k,
|
||||
&a,
|
||||
i,
|
||||
scratch_test.borrow(),
|
||||
);
|
||||
// Reference
|
||||
for i in 0..cols {
|
||||
module_ref.vec_znx_normalize(&mut res_ref, base2k, res_offset, i, &a, base2k, i, scratch_ref.borrow());
|
||||
module_test.vec_znx_normalize(&mut res_test, base2k, res_offset, i, &a, base2k, i, scratch_test.borrow());
|
||||
}
|
||||
|
||||
assert_eq!(a.digest_u64(), a_digest);
|
||||
assert_eq!(res_ref, res_test);
|
||||
}
|
||||
|
||||
assert_eq!(a.digest_u64(), a_digest);
|
||||
assert_eq!(res_ref, res_test);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -718,10 +710,7 @@ where
|
||||
})
|
||||
} else {
|
||||
let std: f64 = a.stats(base2k, col_i).std();
|
||||
assert!(
|
||||
(std - one_12_sqrt).abs() < 0.01,
|
||||
"std={std} ~!= {one_12_sqrt}",
|
||||
);
|
||||
assert!((std - one_12_sqrt).abs() < 0.01, "std={std} ~!= {one_12_sqrt}",);
|
||||
}
|
||||
})
|
||||
});
|
||||
@@ -783,11 +772,7 @@ where
|
||||
})
|
||||
} else {
|
||||
let std: f64 = a.stats(base2k, col_i).std() * k_f64;
|
||||
assert!(
|
||||
(std - sigma * sqrt2).abs() < 0.1,
|
||||
"std={std} ~!= {}",
|
||||
sigma * sqrt2
|
||||
);
|
||||
assert!((std - sigma * sqrt2).abs() < 0.1, "std={std} ~!= {}", sigma * sqrt2);
|
||||
}
|
||||
})
|
||||
});
|
||||
@@ -872,9 +857,9 @@ where
|
||||
|
||||
pub fn test_vec_znx_rsh<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
|
||||
where
|
||||
Module<BR>: VecZnxRsh<BR> + VecZnxLshTmpBytes,
|
||||
Module<BR>: VecZnxRsh<BR> + VecZnxRshTmpBytes,
|
||||
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
|
||||
Module<BT>: VecZnxRsh<BT> + VecZnxLshTmpBytes,
|
||||
Module<BT>: VecZnxRsh<BT> + VecZnxRshTmpBytes,
|
||||
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
|
||||
{
|
||||
assert_eq!(module_ref.n(), module_test.n());
|
||||
@@ -882,8 +867,8 @@ where
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
let cols: usize = 2;
|
||||
|
||||
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_lsh_tmp_bytes());
|
||||
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_lsh_tmp_bytes());
|
||||
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_rsh_tmp_bytes());
|
||||
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_rsh_tmp_bytes());
|
||||
|
||||
for a_size in [1, 2, 3, 4] {
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
|
||||
@@ -914,9 +899,9 @@ where
|
||||
|
||||
pub fn test_vec_znx_rsh_inplace<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
|
||||
where
|
||||
Module<BR>: VecZnxRshInplace<BR> + VecZnxLshTmpBytes,
|
||||
Module<BR>: VecZnxRshInplace<BR> + VecZnxRshTmpBytes,
|
||||
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
|
||||
Module<BT>: VecZnxRshInplace<BT> + VecZnxLshTmpBytes,
|
||||
Module<BT>: VecZnxRshInplace<BT> + VecZnxRshTmpBytes,
|
||||
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
|
||||
{
|
||||
assert_eq!(module_ref.n(), module_test.n());
|
||||
@@ -924,8 +909,8 @@ where
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
let cols: usize = 2;
|
||||
|
||||
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_lsh_tmp_bytes());
|
||||
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_lsh_tmp_bytes());
|
||||
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_rsh_tmp_bytes());
|
||||
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_rsh_tmp_bytes());
|
||||
|
||||
for res_size in [1, 2, 3, 4] {
|
||||
for k in 0..base2k * res_size {
|
||||
@@ -966,15 +951,11 @@ where
|
||||
let a_digest = a.digest_u64();
|
||||
|
||||
for res_size in [1, 2, 3, 4] {
|
||||
let mut res_ref: [VecZnx<Vec<u8>>; 2] = [
|
||||
VecZnx::alloc(n >> 1, cols, res_size),
|
||||
VecZnx::alloc(n >> 1, cols, res_size),
|
||||
];
|
||||
let mut res_ref: [VecZnx<Vec<u8>>; 2] =
|
||||
[VecZnx::alloc(n >> 1, cols, res_size), VecZnx::alloc(n >> 1, cols, res_size)];
|
||||
|
||||
let mut res_test: [VecZnx<Vec<u8>>; 2] = [
|
||||
VecZnx::alloc(n >> 1, cols, res_size),
|
||||
VecZnx::alloc(n >> 1, cols, res_size),
|
||||
];
|
||||
let mut res_test: [VecZnx<Vec<u8>>; 2] =
|
||||
[VecZnx::alloc(n >> 1, cols, res_size), VecZnx::alloc(n >> 1, cols, res_size)];
|
||||
|
||||
res_ref.iter_mut().for_each(|ri| {
|
||||
ri.fill_uniform(base2k, &mut source);
|
||||
|
||||
@@ -93,20 +93,22 @@ where
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_big_normalize(
|
||||
base2k,
|
||||
&mut res_small_ref,
|
||||
j,
|
||||
base2k,
|
||||
0,
|
||||
j,
|
||||
&res_big_ref,
|
||||
base2k,
|
||||
j,
|
||||
scratch_ref.borrow(),
|
||||
);
|
||||
module_test.vec_znx_big_normalize(
|
||||
base2k,
|
||||
&mut res_small_test,
|
||||
j,
|
||||
base2k,
|
||||
0,
|
||||
j,
|
||||
&res_big_test,
|
||||
base2k,
|
||||
j,
|
||||
scratch_test.borrow(),
|
||||
);
|
||||
@@ -188,20 +190,22 @@ where
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_big_normalize(
|
||||
base2k,
|
||||
&mut res_small_ref,
|
||||
j,
|
||||
base2k,
|
||||
0,
|
||||
j,
|
||||
&res_big_ref,
|
||||
base2k,
|
||||
j,
|
||||
scratch_ref.borrow(),
|
||||
);
|
||||
module_test.vec_znx_big_normalize(
|
||||
base2k,
|
||||
&mut res_small_test,
|
||||
j,
|
||||
base2k,
|
||||
0,
|
||||
j,
|
||||
&res_big_test,
|
||||
base2k,
|
||||
j,
|
||||
scratch_test.borrow(),
|
||||
);
|
||||
@@ -279,20 +283,22 @@ where
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_big_normalize(
|
||||
base2k,
|
||||
&mut res_small_ref,
|
||||
j,
|
||||
base2k,
|
||||
0,
|
||||
j,
|
||||
&res_big_ref,
|
||||
base2k,
|
||||
j,
|
||||
scratch_ref.borrow(),
|
||||
);
|
||||
module_test.vec_znx_big_normalize(
|
||||
base2k,
|
||||
&mut res_small_test,
|
||||
j,
|
||||
base2k,
|
||||
0,
|
||||
j,
|
||||
&res_big_test,
|
||||
base2k,
|
||||
j,
|
||||
scratch_test.borrow(),
|
||||
);
|
||||
@@ -367,20 +373,22 @@ pub fn test_vec_znx_big_add_small_inplace<BR: Backend, BT: Backend>(
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_big_normalize(
|
||||
base2k,
|
||||
&mut res_small_ref,
|
||||
j,
|
||||
base2k,
|
||||
0,
|
||||
j,
|
||||
&res_big_ref,
|
||||
base2k,
|
||||
j,
|
||||
scratch_ref.borrow(),
|
||||
);
|
||||
module_test.vec_znx_big_normalize(
|
||||
base2k,
|
||||
&mut res_small_test,
|
||||
j,
|
||||
base2k,
|
||||
0,
|
||||
j,
|
||||
&res_big_test,
|
||||
base2k,
|
||||
j,
|
||||
scratch_test.borrow(),
|
||||
);
|
||||
@@ -459,20 +467,22 @@ where
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_big_normalize(
|
||||
base2k,
|
||||
&mut res_small_ref,
|
||||
j,
|
||||
base2k,
|
||||
0,
|
||||
j,
|
||||
&res_big_ref,
|
||||
base2k,
|
||||
j,
|
||||
scratch_ref.borrow(),
|
||||
);
|
||||
module_test.vec_znx_big_normalize(
|
||||
base2k,
|
||||
&mut res_small_test,
|
||||
j,
|
||||
base2k,
|
||||
0,
|
||||
j,
|
||||
&res_big_test,
|
||||
base2k,
|
||||
j,
|
||||
scratch_test.borrow(),
|
||||
);
|
||||
@@ -546,20 +556,22 @@ pub fn test_vec_znx_big_automorphism_inplace<BR: Backend, BT: Backend>(
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_big_normalize(
|
||||
base2k,
|
||||
&mut res_small_ref,
|
||||
j,
|
||||
base2k,
|
||||
0,
|
||||
j,
|
||||
&res_big_ref,
|
||||
base2k,
|
||||
j,
|
||||
scratch_ref.borrow(),
|
||||
);
|
||||
module_test.vec_znx_big_normalize(
|
||||
base2k,
|
||||
&mut res_small_test,
|
||||
j,
|
||||
base2k,
|
||||
0,
|
||||
j,
|
||||
&res_big_test,
|
||||
base2k,
|
||||
j,
|
||||
scratch_test.borrow(),
|
||||
);
|
||||
@@ -631,20 +643,22 @@ where
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_big_normalize(
|
||||
base2k,
|
||||
&mut res_small_ref,
|
||||
j,
|
||||
base2k,
|
||||
0,
|
||||
j,
|
||||
&res_big_ref,
|
||||
base2k,
|
||||
j,
|
||||
scratch_ref.borrow(),
|
||||
);
|
||||
module_test.vec_znx_big_normalize(
|
||||
base2k,
|
||||
&mut res_small_test,
|
||||
j,
|
||||
base2k,
|
||||
0,
|
||||
j,
|
||||
&res_big_test,
|
||||
base2k,
|
||||
j,
|
||||
scratch_test.borrow(),
|
||||
);
|
||||
@@ -709,20 +723,22 @@ where
|
||||
|
||||
for j in 0..cols {
|
||||
module_ref.vec_znx_big_normalize(
|
||||
base2k,
|
||||
&mut res_small_ref,
|
||||
j,
|
||||
base2k,
|
||||
0,
|
||||
j,
|
||||
&res_big_ref,
|
||||
base2k,
|
||||
j,
|
||||
scratch_ref.borrow(),
|
||||
);
|
||||
module_test.vec_znx_big_normalize(
|
||||
base2k,
|
||||
&mut res_small_test,
|
||||
j,
|
||||
base2k,
|
||||
0,
|
||||
j,
|
||||
&res_big_test,
|
||||
base2k,
|
||||
j,
|
||||
scratch_test.borrow(),
|
||||
);
|
||||
@@ -782,36 +798,40 @@ where

        let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
        let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);

        // Set d to garbage
        source.fill_bytes(res_ref.data_mut());
        source.fill_bytes(res_test.data_mut());
        for res_offset in -(base2k as i64)..=(base2k as i64) {
            // Set d to garbage
            source.fill_bytes(res_ref.data_mut());
            source.fill_bytes(res_test.data_mut());

            // Reference
            for j in 0..cols {
                module_ref.vec_znx_big_normalize(
                    base2k,
                    &mut res_ref,
                    j,
                    base2k,
                    &a_ref,
                    j,
                    scratch_ref.borrow(),
                );
                module_test.vec_znx_big_normalize(
                    base2k,
                    &mut res_test,
                    j,
                    base2k,
                    &a_test,
                    j,
                    scratch_test.borrow(),
                );
            // Reference
            for j in 0..cols {
                module_ref.vec_znx_big_normalize(
                    &mut res_ref,
                    base2k,
                    res_offset,
                    j,
                    &a_ref,
                    base2k,
                    j,
                    scratch_ref.borrow(),
                );
                module_test.vec_znx_big_normalize(
                    &mut res_test,
                    base2k,
                    res_offset,
                    j,
                    &a_test,
                    base2k,
                    j,
                    scratch_test.borrow(),
                );
            }

            assert_eq!(a_ref.digest_u64(), a_ref_digest);
            assert_eq!(a_test.digest_u64(), a_test_digest);

            assert_eq!(res_ref, res_test);
        }

        assert_eq!(a_ref.digest_u64(), a_ref_digest);
        assert_eq!(a_test.digest_u64(), a_test_digest);

        assert_eq!(res_ref, res_test);
        }
    }
}
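For quick reference, the two call shapes contained in the hunk above, condensed onto single lines (taken directly from this diff; only the formatting is new):

// Old call: (res_base2k, res, res_col, a_base2k, a, a_col, scratch), offset implicitly 0.
module_ref.vec_znx_big_normalize(base2k, &mut res_ref, j, base2k, &a_ref, j, scratch_ref.borrow());
// New call: the destination comes first, followed by its base2k and an explicit signed offset.
module_ref.vec_znx_big_normalize(&mut res_ref, base2k, res_offset, j, &a_ref, base2k, j, scratch_ref.borrow());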
@@ -891,20 +911,22 @@ where

        for j in 0..cols {
            module_ref.vec_znx_big_normalize(
                base2k,
                &mut res_small_ref,
                j,
                base2k,
                0,
                j,
                &res_big_ref,
                base2k,
                j,
                scratch_ref.borrow(),
            );
            module_test.vec_znx_big_normalize(
                base2k,
                &mut res_small_test,
                j,
                base2k,
                0,
                j,
                &res_big_test,
                base2k,
                j,
                scratch_test.borrow(),
            );
@@ -986,20 +1008,22 @@ where

        for j in 0..cols {
            module_ref.vec_znx_big_normalize(
                base2k,
                &mut res_small_ref,
                j,
                base2k,
                0,
                j,
                &res_big_ref,
                base2k,
                j,
                scratch_ref.borrow(),
            );
            module_test.vec_znx_big_normalize(
                base2k,
                &mut res_small_test,
                j,
                base2k,
                0,
                j,
                &res_big_test,
                base2k,
                j,
                scratch_test.borrow(),
            );
@@ -1083,20 +1107,22 @@ pub fn test_vec_znx_big_sub_negate_inplace<BR: Backend, BT: Backend>(

        for j in 0..cols {
            module_ref.vec_znx_big_normalize(
                base2k,
                &mut res_small_ref,
                j,
                base2k,
                0,
                j,
                &res_big_ref,
                base2k,
                j,
                scratch_ref.borrow(),
            );
            module_test.vec_znx_big_normalize(
                base2k,
                &mut res_small_test,
                j,
                base2k,
                0,
                j,
                &res_big_test,
                base2k,
                j,
                scratch_test.borrow(),
            );
@@ -1180,20 +1206,22 @@ where

        for j in 0..cols {
            module_ref.vec_znx_big_normalize(
                base2k,
                &mut res_small_ref,
                j,
                base2k,
                0,
                j,
                &res_big_ref,
                base2k,
                j,
                scratch_ref.borrow(),
            );
            module_test.vec_znx_big_normalize(
                base2k,
                &mut res_small_test,
                j,
                base2k,
                0,
                j,
                &res_big_test,
                base2k,
                j,
                scratch_test.borrow(),
            );
@@ -1278,20 +1306,22 @@ where

        for j in 0..cols {
            module_ref.vec_znx_big_normalize(
                base2k,
                &mut res_small_ref,
                j,
                base2k,
                0,
                j,
                &res_big_ref,
                base2k,
                j,
                scratch_ref.borrow(),
            );
            module_test.vec_znx_big_normalize(
                base2k,
                &mut res_small_test,
                j,
                base2k,
                0,
                j,
                &res_big_test,
                base2k,
                j,
                scratch_test.borrow(),
            );
@@ -1366,20 +1396,22 @@ pub fn test_vec_znx_big_sub_small_a_inplace<BR: Backend, BT: Backend>(

        for j in 0..cols {
            module_ref.vec_znx_big_normalize(
                base2k,
                &mut res_small_ref,
                j,
                base2k,
                0,
                j,
                &res_big_ref,
                base2k,
                j,
                scratch_ref.borrow(),
            );
            module_test.vec_znx_big_normalize(
                base2k,
                &mut res_small_test,
                j,
                base2k,
                0,
                j,
                &res_big_test,
                base2k,
                j,
                scratch_test.borrow(),
            );
@@ -1427,55 +1459,59 @@ pub fn test_vec_znx_big_sub_small_b_inplace<BR: Backend, BT: Backend>(
    let a_digest: u64 = a.digest_u64();

    for res_size in [1, 2, 3, 4] {
        let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
        res.fill_uniform(base2k, &mut source);
        for res_offset in -(base2k as i64)..=(base2k as i64) {
            let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
            res.fill_uniform(base2k, &mut source);

            let mut res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_big_alloc(cols, res_size);
            let mut res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_big_alloc(cols, res_size);
            let mut res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_big_alloc(cols, res_size);
            let mut res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_big_alloc(cols, res_size);

            for j in 0..cols {
                module_ref.vec_znx_big_from_small(&mut res_big_ref, j, &res, j);
                module_test.vec_znx_big_from_small(&mut res_big_test, j, &res, j);
            for j in 0..cols {
                module_ref.vec_znx_big_from_small(&mut res_big_ref, j, &res, j);
                module_test.vec_znx_big_from_small(&mut res_big_test, j, &res, j);
            }

            for i in 0..cols {
                module_ref.vec_znx_big_sub_small_negate_inplace(&mut res_big_ref, i, &a, i);
                module_test.vec_znx_big_sub_small_negate_inplace(&mut res_big_test, i, &a, i);
            }

            assert_eq!(a.digest_u64(), a_digest);

            let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
            let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);

            let res_ref_digest: u64 = res_big_ref.digest_u64();
            let res_test_digest: u64 = res_big_test.digest_u64();

            for j in 0..cols {
                module_ref.vec_znx_big_normalize(
                    &mut res_small_ref,
                    base2k,
                    res_offset,
                    j,
                    &res_big_ref,
                    base2k,
                    j,
                    scratch_ref.borrow(),
                );
                module_test.vec_znx_big_normalize(
                    &mut res_small_test,
                    base2k,
                    res_offset,
                    j,
                    &res_big_test,
                    base2k,
                    j,
                    scratch_test.borrow(),
                );
            }

            assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
            assert_eq!(res_big_test.digest_u64(), res_test_digest);

            assert_eq!(res_small_ref, res_small_test);
        }

        for i in 0..cols {
            module_ref.vec_znx_big_sub_small_negate_inplace(&mut res_big_ref, i, &a, i);
            module_test.vec_znx_big_sub_small_negate_inplace(&mut res_big_test, i, &a, i);
        }

        assert_eq!(a.digest_u64(), a_digest);

        let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
        let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);

        let res_ref_digest: u64 = res_big_ref.digest_u64();
        let res_test_digest: u64 = res_big_test.digest_u64();

        for j in 0..cols {
            module_ref.vec_znx_big_normalize(
                base2k,
                &mut res_small_ref,
                j,
                base2k,
                &res_big_ref,
                j,
                scratch_ref.borrow(),
            );
            module_test.vec_znx_big_normalize(
                base2k,
                &mut res_small_test,
                j,
                base2k,
                &res_big_test,
                j,
                scratch_test.borrow(),
            );
        }

        assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
        assert_eq!(res_big_test.digest_u64(), res_test_digest);

        assert_eq!(res_small_ref, res_small_test);
        }
    }
}
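A hedged sketch condensing the per-offset check that the updated tests repeat; the helper name and trait bounds are placeholders, the types are the ones used elsewhere in this diff, and imports plus required where-clauses are omitted.

// Hypothetical helper, not part of the crate: normalize every column of a
// VecZnxBig into a VecZnx for every signed offset in [-base2k, base2k].
fn normalize_all_offsets<BE: Backend>(
    module: &Module<BE>,
    base2k: usize,
    cols: usize,
    res_big: &VecZnxBig<Vec<u8>, BE>,
    res_small: &mut VecZnx<Vec<u8>>,
    scratch: &mut Scratch<BE>,
) {
    for res_offset in -(base2k as i64)..=(base2k as i64) {
        for j in 0..cols {
            // Same call shape as in the tests above, with an explicit offset.
            module.vec_znx_big_normalize(
                &mut *res_small,
                base2k,
                res_offset,
                j,
                res_big,
                base2k,
                j,
                &mut *scratch,
            );
        }
    }
}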
@@ -102,20 +102,22 @@ where

        for j in 0..cols {
            module_ref.vec_znx_big_normalize(
                base2k,
                &mut res_small_ref,
                j,
                base2k,
                0,
                j,
                &res_big_ref,
                base2k,
                j,
                scratch_ref.borrow(),
            );
            module_test.vec_znx_big_normalize(
                base2k,
                &mut res_small_test,
                j,
                base2k,
                0,
                j,
                &res_big_test,
                base2k,
                j,
                scratch_test.borrow(),
            );
@@ -208,20 +210,22 @@ where

        for j in 0..cols {
            module_ref.vec_znx_big_normalize(
                base2k,
                &mut res_small_ref,
                j,
                base2k,
                0,
                j,
                &res_big_ref,
                base2k,
                j,
                scratch_ref.borrow(),
            );
            module_test.vec_znx_big_normalize(
                base2k,
                &mut res_small_test,
                j,
                base2k,
                0,
                j,
                &res_big_test,
                base2k,
                j,
                scratch_test.borrow(),
            );
@@ -311,20 +315,22 @@ where

        for j in 0..cols {
            module_ref.vec_znx_big_normalize(
                base2k,
                &mut res_small_ref,
                j,
                base2k,
                0,
                j,
                &res_big_ref,
                base2k,
                j,
                scratch_ref.borrow(),
            );
            module_test.vec_znx_big_normalize(
                base2k,
                &mut res_small_test,
                j,
                base2k,
                0,
                j,
                &res_big_test,
                base2k,
                j,
                scratch_test.borrow(),
            );
@@ -392,13 +398,7 @@ where

        for j in 0..cols {
            module_ref.vec_znx_idft_apply(&mut res_big_ref, j, &res_dft_ref, j, scratch_ref.borrow());
            module_test.vec_znx_idft_apply(
                &mut res_big_test,
                j,
                &res_dft_test,
                j,
                scratch_test.borrow(),
            );
            module_test.vec_znx_idft_apply(&mut res_big_test, j, &res_dft_test, j, scratch_test.borrow());
        }

        assert_eq!(res_dft_ref.digest_u64(), res_dft_ref_digest);
@@ -412,20 +412,22 @@ where

        for j in 0..cols {
            module_ref.vec_znx_big_normalize(
                base2k,
                &mut res_small_ref,
                j,
                base2k,
                0,
                j,
                &res_big_ref,
                base2k,
                j,
                scratch_ref.borrow(),
            );
            module_test.vec_znx_big_normalize(
                base2k,
                &mut res_small_test,
                j,
                base2k,
                0,
                j,
                &res_big_test,
                base2k,
                j,
                scratch_test.borrow(),
            );
@@ -502,20 +504,22 @@ where

        for j in 0..cols {
            module_ref.vec_znx_big_normalize(
                base2k,
                &mut res_small_ref,
                j,
                base2k,
                0,
                j,
                &res_big_ref,
                base2k,
                j,
                scratch_ref.borrow(),
            );
            module_test.vec_znx_big_normalize(
                base2k,
                &mut res_small_test,
                j,
                base2k,
                0,
                j,
                &res_big_test,
                base2k,
                j,
                scratch_test.borrow(),
            );
@@ -589,20 +593,22 @@ where

        for j in 0..cols {
            module_ref.vec_znx_big_normalize(
                base2k,
                &mut res_small_ref,
                j,
                base2k,
                0,
                j,
                &res_big_ref,
                base2k,
                j,
                scratch_ref.borrow(),
            );
            module_test.vec_znx_big_normalize(
                base2k,
                &mut res_small_test,
                j,
                base2k,
                0,
                j,
                &res_big_test,
                base2k,
                j,
                scratch_test.borrow(),
            );
@@ -709,20 +715,22 @@ where

        for j in 0..cols {
            module_ref.vec_znx_big_normalize(
                base2k,
                &mut res_small_ref,
                j,
                base2k,
                0,
                j,
                &res_big_ref,
                base2k,
                j,
                scratch_ref.borrow(),
            );
            module_test.vec_znx_big_normalize(
                base2k,
                &mut res_small_test,
                j,
                base2k,
                0,
                j,
                &res_big_test,
                base2k,
                j,
                scratch_test.borrow(),
            );
@@ -815,20 +823,22 @@ where

        for j in 0..cols {
            module_ref.vec_znx_big_normalize(
                base2k,
                &mut res_small_ref,
                j,
                base2k,
                0,
                j,
                &res_big_ref,
                base2k,
                j,
                scratch_ref.borrow(),
            );
            module_test.vec_znx_big_normalize(
                base2k,
                &mut res_small_test,
                j,
                base2k,
                0,
                j,
                &res_big_test,
                base2k,
                j,
                scratch_test.borrow(),
            );
@@ -923,20 +933,22 @@ pub fn test_vec_znx_dft_sub_negate_inplace<BR: Backend, BT: Backend>(

        for j in 0..cols {
            module_ref.vec_znx_big_normalize(
                base2k,
                &mut res_small_ref,
                j,
                base2k,
                0,
                j,
                &res_big_ref,
                base2k,
                j,
                scratch_ref.borrow(),
            );
            module_test.vec_znx_big_normalize(
                base2k,
                &mut res_small_test,
                j,
                base2k,
                0,
                j,
                &res_big_test,
                base2k,
                j,
                scratch_test.borrow(),
            );

@@ -90,20 +90,22 @@ where

        for j in 0..cols_out {
            module_ref.vec_znx_big_normalize(
                base2k,
                &mut res_small_ref,
                j,
                base2k,
                0,
                j,
                &res_big_ref,
                base2k,
                j,
                scratch_ref.borrow(),
            );
            module_test.vec_znx_big_normalize(
                base2k,
                &mut res_small_test,
                j,
                base2k,
                0,
                j,
                &res_big_test,
                base2k,
                j,
                scratch_test.borrow(),
            );
@@ -205,18 +207,8 @@ where
        source.fill_bytes(res_dft_ref.data_mut());
        source.fill_bytes(res_dft_test.data_mut());

        module_ref.vmp_apply_dft_to_dft(
            &mut res_dft_ref,
            &a_dft_ref,
            &pmat_ref,
            scratch_ref.borrow(),
        );
        module_test.vmp_apply_dft_to_dft(
            &mut res_dft_test,
            &a_dft_test,
            &pmat_test,
            scratch_test.borrow(),
        );
        module_ref.vmp_apply_dft_to_dft(&mut res_dft_ref, &a_dft_ref, &pmat_ref, scratch_ref.borrow());
        module_test.vmp_apply_dft_to_dft(&mut res_dft_test, &a_dft_test, &pmat_test, scratch_test.borrow());

        let res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
        let res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
@@ -229,20 +221,22 @@ where

        for j in 0..cols_out {
            module_ref.vec_znx_big_normalize(
                base2k,
                &mut res_small_ref,
                j,
                base2k,
                0,
                j,
                &res_big_ref,
                base2k,
                j,
                scratch_ref.borrow(),
            );
            module_test.vec_znx_big_normalize(
                base2k,
                &mut res_small_test,
                j,
                base2k,
                0,
                j,
                &res_big_test,
                base2k,
                j,
                scratch_test.borrow(),
            );
@@ -379,20 +373,22 @@ where

        for j in 0..cols_out {
            module_ref.vec_znx_big_normalize(
                base2k,
                &mut res_small_ref,
                j,
                base2k,
                0,
                j,
                &res_big_ref,
                base2k,
                j,
                scratch_ref.borrow(),
            );
            module_test.vec_znx_big_normalize(
                base2k,
                &mut res_small_test,
                j,
                base2k,
                0,
                j,
                &res_big_test,
                base2k,
                j,
                scratch_test.borrow(),
            );
