Support for bivariate convolution & normalization with offset (#126)

* Add bivariate-convolution
* Add pair-wise convolution + tests + benches
* Add take_cnv_pvec_[left/right] to Scratch & updated CHANGELOG.md
* cross-base2k normalization with positive offset
* clippy & fix CI doctest avx compile error
* more streamlined bounds derivation for normalization
* Working cross-base2k normalization with pos/neg offset
* Update normalization API & tests
* Add glwe tensoring test
* Add relinearization + preliminary test
* Fix GGLWEToGGSW key infos
* Add (X,Y) convolution by const (1, Y) poly
* Faster normalization test + add bench for cnv_by_const
* Update changelog
This commit is contained in:
Jean-Philippe Bossuat
2025-12-21 16:56:42 +01:00
committed by GitHub
parent 76424d0ab5
commit 4e90e08a71
219 changed files with 6571 additions and 5041 deletions

View File

@@ -1,120 +1,97 @@
use crate::{
api::{
ModuleN, ScratchTakeBasic, SvpApplyDftToDft, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxDftAddScaledInplace,
VecZnxDftBytesOf, VecZnxDftZero,
},
layouts::{Backend, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, ZnxInfos},
use crate::layouts::{
Backend, CnvPVecL, CnvPVecLToMut, CnvPVecLToRef, CnvPVecR, CnvPVecRToMut, CnvPVecRToRef, Scratch, VecZnxBigToMut,
VecZnxDftToMut, VecZnxToRef, ZnxInfos, ZnxViewMut,
};
impl<BE: Backend> BivariateTensoring<BE> for Module<BE>
where
Self: BivariateConvolution<BE>,
Scratch<BE>: ScratchTakeBasic,
{
pub trait CnvPVecAlloc<BE: Backend> {
fn cnv_pvec_left_alloc(&self, cols: usize, size: usize) -> CnvPVecL<Vec<u8>, BE>;
fn cnv_pvec_right_alloc(&self, cols: usize, size: usize) -> CnvPVecR<Vec<u8>, BE>;
}
pub trait BivariateTensoring<BE: Backend>
where
Self: BivariateConvolution<BE>,
Scratch<BE>: ScratchTakeBasic,
{
fn bivariate_tensoring<R, A, B>(&self, k: i64, res: &mut R, a: &A, b: &B, scratch: &mut Scratch<BE>)
pub trait CnvPVecBytesOf {
fn bytes_of_cnv_pvec_left(&self, cols: usize, size: usize) -> usize;
fn bytes_of_cnv_pvec_right(&self, cols: usize, size: usize) -> usize;
}
pub trait Convolution<BE: Backend> {
fn cnv_prepare_left_tmp_bytes(&self, res_size: usize, a_size: usize) -> usize;
fn cnv_prepare_left<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<BE>)
where
R: VecZnxDftToMut<BE>,
A: VecZnxToRef,
B: VecZnxDftToRef<BE>,
{
let res: &mut crate::layouts::VecZnxDft<&mut [u8], BE> = &mut res.to_mut();
let a: &crate::layouts::VecZnx<&[u8]> = &a.to_ref();
let b: &crate::layouts::VecZnxDft<&[u8], BE> = &b.to_ref();
R: CnvPVecLToMut<BE> + ZnxInfos + ZnxViewMut<Scalar = BE::ScalarPrep>,
A: VecZnxToRef + ZnxInfos;
let res_cols: usize = res.cols();
let a_cols: usize = a.cols();
let b_cols: usize = b.cols();
fn cnv_prepare_right_tmp_bytes(&self, res_size: usize, a_size: usize) -> usize;
fn cnv_prepare_right<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<BE>)
where
R: CnvPVecRToMut<BE> + ZnxInfos + ZnxViewMut<Scalar = BE::ScalarPrep>,
A: VecZnxToRef + ZnxInfos;
assert!(res_cols >= a_cols + b_cols - 1);
fn cnv_apply_dft_tmp_bytes(&self, res_size: usize, res_offset: usize, a_size: usize, b_size: usize) -> usize;
for res_col in 0..res_cols {
self.vec_znx_dft_zero(res, res_col);
}
fn cnv_by_const_apply_tmp_bytes(&self, res_size: usize, res_offset: usize, a_size: usize, b_size: usize) -> usize;
for a_col in 0..a_cols {
for b_col in 0..b_cols {
self.bivariate_convolution_add(k, res, a_col + b_col, a, a_col, b, b_col, scratch);
}
}
}
}
impl<BE: Backend> BivariateConvolution<BE> for Module<BE>
where
Self: Sized
+ ModuleN
+ SvpPPolAlloc<BE>
+ SvpApplyDftToDft<BE>
+ SvpPrepare<BE>
+ SvpPPolBytesOf
+ VecZnxDftBytesOf
+ VecZnxDftAddScaledInplace<BE>
+ VecZnxDftZero<BE>,
Scratch<BE>: ScratchTakeBasic,
{
}
pub trait BivariateConvolution<BE: Backend>
where
Self: Sized
+ ModuleN
+ SvpPPolAlloc<BE>
+ SvpApplyDftToDft<BE>
+ SvpPrepare<BE>
+ SvpPPolBytesOf
+ VecZnxDftBytesOf
+ VecZnxDftAddScaledInplace<BE>
+ VecZnxDftZero<BE>,
Scratch<BE>: ScratchTakeBasic,
{
fn convolution_tmp_bytes(&self, b_size: usize) -> usize {
self.bytes_of_svp_ppol(1) + self.bytes_of_vec_znx_dft(1, b_size)
}
/// Evaluates a bivariate convolution over Z[X, Y] (x) Z[Y] mod (X^N + 1) where Y = 2^-K over the
/// selected columns and stores the result in the selected column, scaled by 2^{res_offset * Base2K}
///
/// Behavior is identical to [Convolution::cnv_apply_dft] with `b` treated as a constant polynomial
/// in the X variable, for example:
///```text
/// 1 X X^2 X^3
/// a = 1 [a00, a10, a20, a30] = (a00 + a01 * 2^-K) + (a10 + a11 * 2^-K) * X ...
/// Y [a01, a11, a21, a31]
///
/// b = 1 [b00] = (b00 + b01 * 2^-K)
/// Y [b01]
/// ```
/// This method is intended for multiplications by constants that do not fit into a single base2k limb.
#[allow(clippy::too_many_arguments)]
fn cnv_by_const_apply<R, A>(
&self,
res: &mut R,
res_offset: usize,
res_col: usize,
a: &A,
a_col: usize,
b: &[i64],
scratch: &mut Scratch<BE>,
) where
R: VecZnxBigToMut<BE>,
A: VecZnxToRef;
#[allow(clippy::too_many_arguments)]
/// Evaluates a bivariate convolution over Z[X, Y] / (X^N + 1) where Y = 2^-K over the
/// selected columns and stores the result in the selected column, scaled by 2^{k * Base2K}
/// Evaluates a bivariate convolution over Z[X, Y] (x) Z[X, Y] mod (X^N + 1) where Y = 2^-K over the
/// selected columns and stores the result in the selected column, scaled by 2^{res_offset * Base2K}
///
/// # Example
/// a = [a00, a10, a20, a30] = (a00 * 2^-K + a01 * 2^-2K) + (a10 * 2^-K + a11 * 2^-2K) * X ...
/// [a01, a11, a21, a31]
///```text
/// 1 X X^2 X^3
/// a = 1 [a00, a10, a20, a30] = (a00 + a01 * 2^-K) + (a10 + a11 * 2^-K) * X ...
/// Y [a01, a11, a21, a31]
///
/// b = [b00, b10, b20, b30] = (b00 * 2^-K + b01 * 2^-2K) + (b10 * 2^-K + b11 * 2^-2K) * X ...
/// [b01, b11, b21, b31]
/// b = 1 [b00, b10, b20, b30] = (b00 + b01 * 2^-K) + (b10 + b11 * 2^-K) * X ...
/// Y [b01, b11, b21, b31]
///
/// If k = 0:
/// res = [ 0, 0, 0, 0] = (r01 * 2^-2K + r02 * 2^-3K + r03 * 2^-4K + r04 * 2^-5K) + ...
/// [r01, r11, r21, r31]
/// [r02, r12, r22, r32]
/// [r03, r13, r23, r33]
/// [r04, r14, r24, r34]
/// If res_offset = 0:
///
/// If k = 1:
/// res = [r01, r11, r21, r31] = (r01 * 2^-K + r02 * 2^-2K + r03 * 2^-3K + r04 * 2^-4K + r05 * 2^-5K) + ...
/// [r02, r12, r22, r32]
/// [r03, r13, r23, r33]
/// [r04, r14, r24, r34]
/// [r05, r15, r25, r35]
/// 1 X X^2 X^3
/// res = 1 [r00, r10, r20, r30] = (r00 + r01 * 2^-K + r02 * 2^-2K + r03 * 2^-3K) + ... * X + ...
/// Y [r01, r11, r21, r31]
/// Y^2[r02, r12, r22, r32]
/// Y^3[r03, r13, r23, r33]
///
/// If k = -1:
/// res = [ 0, 0, 0, 0] = (r01 * 2^-3K + r02 * 2^-4K + r03 * 2^-5K) + ...
/// [ 0, 0, 0, 0]
/// [r01, r11, r21, r31]
/// [r02, r12, r22, r32]
/// [r03, r13, r23, r33]
/// If res_offset = 1:
///
/// If res.size() < a.size() + b.size() + 1 + k, result is truncated accordingly in the Y dimension.
fn bivariate_convolution_add<R, A, B>(
/// 1 X X^2 X^3
/// res = 1 [r01, r11, r21, r31] = (r01 + r02 * 2^-K + r03 * 2^-2K) + ... * X + ...
/// Y [r02, r12, r22, r32]
/// Y^2[r03, r13, r23, r33]
/// Y^3[ 0, 0, 0 , 0]
/// ```
/// If res.size() < a.size() + b.size() - res_offset, the result is truncated accordingly in the Y dimension.
fn cnv_apply_dft<R, A, B>(
&self,
k: i64,
res: &mut R,
res_offset: usize,
res_col: usize,
a: &A,
a_col: usize,
@@ -123,40 +100,27 @@ where
scratch: &mut Scratch<BE>,
) where
R: VecZnxDftToMut<BE>,
A: VecZnxToRef,
B: VecZnxDftToRef<BE>,
{
let res: &mut crate::layouts::VecZnxDft<&mut [u8], BE> = &mut res.to_mut();
let a: &crate::layouts::VecZnx<&[u8]> = &a.to_ref();
let b: &crate::layouts::VecZnxDft<&[u8], BE> = &b.to_ref();
A: CnvPVecLToRef<BE>,
B: CnvPVecRToRef<BE>;
let (mut ppol, scratch_1) = scratch.take_svp_ppol(self, 1);
let (mut res_tmp, _) = scratch_1.take_vec_znx_dft(self, 1, b.size());
for a_limb in 0..a.size() {
self.svp_prepare(&mut ppol, 0, &a.as_scalar_znx_ref(a_col, a_limb), 0);
self.svp_apply_dft_to_dft(&mut res_tmp, 0, &ppol, 0, b, b_col);
self.vec_znx_dft_add_scaled_inplace(res, res_col, &res_tmp, 0, -(1 + a_limb as i64) + k);
}
}
fn cnv_pairwise_apply_dft_tmp_bytes(&self, res_size: usize, res_offset: usize, a_size: usize, b_size: usize) -> usize;
#[allow(clippy::too_many_arguments)]
fn bivariate_convolution<R, A, B>(
/// Evaluates the bivariate pair-wise convolution res = (a[i] + a[j]) * (b[i] + b[j]).
/// If i == j then calls [Convolution::cnv_apply_dft], i.e. res = a[i] * b[i].
/// See [Convolution::cnv_apply_dft] for information about the bivariate convolution.
fn cnv_pairwise_apply_dft<R, A, B>(
&self,
k: i64,
res: &mut R,
res_offset: usize,
res_col: usize,
a: &A,
a_col: usize,
b: &B,
b_col: usize,
i: usize,
j: usize,
scratch: &mut Scratch<BE>,
) where
R: VecZnxDftToMut<BE>,
A: VecZnxToRef,
B: VecZnxDftToRef<BE>,
{
self.vec_znx_dft_zero(res, res_col);
self.bivariate_convolution_add(k, res, res_col, a, a_col, b, b_col, scratch);
}
A: CnvPVecLToRef<BE>,
B: CnvPVecRToRef<BE>;
}
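// Illustrative usage sketch of the Convolution API above (a hedged, free-standing helper,
// not part of the crate): prepare both operands once, then evaluate the (X, Y)-convolution
// in the DFT domain, plus the by-constant variant. It assumes the same imports as the
// convolution benches further below (crate::api::{CnvPVecAlloc, Convolution, ScratchOwnedAlloc,
// ScratchOwnedBorrow, VecZnxBigAlloc, VecZnxDftAlloc} and crate::layouts::{Backend, CnvPVecL,
// CnvPVecR, Module, ScratchOwned, VecZnx, VecZnxBig}); `size` and the contents of `a`, `b`
// and `cst` are placeholders.
fn _convolution_usage_sketch<BE: Backend>(module: &Module<BE>, a: &VecZnx<Vec<u8>>, b: &VecZnx<Vec<u8>>, size: usize)
where
Module<BE>: Convolution<BE> + CnvPVecAlloc<BE> + VecZnxDftAlloc<BE> + VecZnxBigAlloc<BE>,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
{
let c_size: usize = 2 * size - 1; // as in the benches: enough Y-limbs for the product
let mut a_prep: CnvPVecL<Vec<u8>, BE> = module.cnv_pvec_left_alloc(1, size);
let mut b_prep: CnvPVecR<Vec<u8>, BE> = module.cnv_pvec_right_alloc(1, size);
let mut c_dft = module.vec_znx_dft_alloc(1, c_size);
let mut c_big: VecZnxBig<Vec<u8>, BE> = module.vec_znx_big_alloc(1, c_size);
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(
module
.cnv_apply_dft_tmp_bytes(c_size, 0, size, size)
.max(module.cnv_by_const_apply_tmp_bytes(c_size, 0, size, size))
.max(module.cnv_prepare_left_tmp_bytes(c_size, size))
.max(module.cnv_prepare_right_tmp_bytes(c_size, size)),
);
module.cnv_prepare_left(&mut a_prep, a, scratch.borrow());
module.cnv_prepare_right(&mut b_prep, b, scratch.borrow());
// res_offset = 0, single column on every operand; cnv_pairwise_apply_dft is used the same
// way on two selected columns of multi-column prepared vectors (see the benches).
module.cnv_apply_dft(&mut c_dft, 0, 0, &a_prep, 0, &b_prep, 0, scratch.borrow());
// By-constant variant: the constant is passed as a slice of i64 limbs (assumed here to be
// its base2k decomposition) and the result lands in a VecZnxBig.
let cst: Vec<i64> = vec![3; size];
module.cnv_by_const_apply(&mut c_big, 0, 0, a, 0, &cst, scratch.borrow());
}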

View File

@@ -1,6 +1,6 @@
use crate::{
api::{ModuleN, SvpPPolBytesOf, VecZnxBigBytesOf, VecZnxDftBytesOf, VmpPMatBytesOf},
layouts::{Backend, MatZnx, ScalarZnx, Scratch, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat},
api::{CnvPVecBytesOf, ModuleN, SvpPPolBytesOf, VecZnxBigBytesOf, VecZnxDftBytesOf, VmpPMatBytesOf},
layouts::{Backend, CnvPVecL, CnvPVecR, MatZnx, ScalarZnx, Scratch, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat},
};
/// Allocates a new [crate::layouts::ScratchOwned] of `size` aligned bytes.
@@ -56,6 +56,22 @@ pub trait ScratchTakeBasic
where
Self: TakeSlice,
{
fn take_cnv_pvec_left<M, B: Backend>(&mut self, module: &M, cols: usize, size: usize) -> (CnvPVecL<&mut [u8], B>, &mut Self)
where
M: ModuleN + CnvPVecBytesOf,
{
let (take_slice, rem_slice) = self.take_slice(module.bytes_of_cnv_pvec_left(cols, size));
(CnvPVecL::from_data(take_slice, module.n(), cols, size), rem_slice)
}
fn take_cnv_pvec_right<M, B: Backend>(&mut self, module: &M, cols: usize, size: usize) -> (CnvPVecR<&mut [u8], B>, &mut Self)
where
M: ModuleN + CnvPVecBytesOf,
{
let (take_slice, rem_slice) = self.take_slice(module.bytes_of_cnv_pvec_right(cols, size));
(CnvPVecR::from_data(take_slice, module.n(), cols, size), rem_slice)
}
fn take_scalar_znx(&mut self, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) {
let (take_slice, rem_slice) = self.take_slice(ScalarZnx::bytes_of(n, cols));
(ScalarZnx::from_data(take_slice, n, cols), rem_slice)
@@ -79,10 +95,7 @@ where
M: VecZnxBigBytesOf + ModuleN,
{
let (take_slice, rem_slice) = self.take_slice(module.bytes_of_vec_znx_big(cols, size));
(
VecZnxBig::from_data(take_slice, module.n(), cols, size),
rem_slice,
)
(VecZnxBig::from_data(take_slice, module.n(), cols, size), rem_slice)
}
fn take_vec_znx_dft<M, B: Backend>(&mut self, module: &M, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, &mut Self)
@@ -91,10 +104,7 @@ where
{
let (take_slice, rem_slice) = self.take_slice(module.bytes_of_vec_znx_dft(cols, size));
(
VecZnxDft::from_data(take_slice, module.n(), cols, size),
rem_slice,
)
(VecZnxDft::from_data(take_slice, module.n(), cols, size), rem_slice)
}
fn take_vec_znx_dft_slice<M, B: Backend>(
@@ -155,9 +165,6 @@ where
size: usize,
) -> (MatZnx<&mut [u8]>, &mut Self) {
let (take_slice, rem_slice) = self.take_slice(MatZnx::bytes_of(n, rows, cols_in, cols_out, size));
(
MatZnx::from_data(take_slice, n, rows, cols_in, cols_out, size),
rem_slice,
)
(MatZnx::from_data(take_slice, n, rows, cols_in, cols_out, size), rem_slice)
}
}
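// Illustrative sketch (placeholder helper name): CnvPVec buffers are carved out of a Scratch
// like the other take_* helpers above, threading the remaining scratch through the returned tail.
fn _take_cnv_pvec_sketch<M, BE: Backend>(module: &M, scratch: &mut Scratch<BE>, cols: usize, size: usize)
where
M: ModuleN + CnvPVecBytesOf,
Scratch<BE>: ScratchTakeBasic,
{
let (pv_left, scratch): (CnvPVecL<&mut [u8], BE>, _) = scratch.take_cnv_pvec_left(module, cols, size);
let (pv_right, _scratch): (CnvPVecR<&mut [u8], BE>, _) = scratch.take_cnv_pvec_right(module, cols, size);
let _ = (pv_left, pv_right);
}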

View File

@@ -19,11 +19,12 @@ pub trait VecZnxNormalize<B: Backend> {
/// Normalizes the selected column of `a` and stores the result into the selected column of `res`.
fn vec_znx_normalize<R, A>(
&self,
res_basek: usize,
res: &mut R,
res_base2k: usize,
res_offset: i64,
res_col: usize,
a_basek: usize,
a: &A,
a_base2k: usize,
a_col: usize,
scratch: &mut Scratch<B>,
) where

View File

@@ -164,11 +164,12 @@ pub trait VecZnxBigNormalizeTmpBytes {
pub trait VecZnxBigNormalize<B: Backend> {
fn vec_znx_big_normalize<R, A>(
&self,
res_base2k: usize,
res: &mut R,
res_base2k: usize,
res_offset: i64,
res_col: usize,
a_base2k: usize,
a: &A,
a_base2k: usize,
a_col: usize,
scratch: &mut Scratch<B>,
) where

View File

@@ -0,0 +1,268 @@
use std::hint::black_box;
use criterion::{BenchmarkId, Criterion};
use crate::{
api::{CnvPVecAlloc, Convolution, ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxBigAlloc, VecZnxDftAlloc},
layouts::{Backend, CnvPVecL, CnvPVecR, FillUniform, Module, ScratchOwned, VecZnx, VecZnxBig},
source::Source,
};
pub fn bench_cnv_prepare_left<BE: Backend>(c: &mut Criterion, label: &str)
where
Module<BE>: ModuleNew<BE> + Convolution<BE> + CnvPVecAlloc<BE>,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
{
let group_name: String = format!("cnv_prepare_left::{label}");
let mut group = c.benchmark_group(group_name);
fn runner<BE: Backend>(n: usize, size: usize) -> impl FnMut()
where
Module<BE>: ModuleNew<BE> + Convolution<BE> + CnvPVecAlloc<BE>,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
{
let mut source: Source = Source::new([0u8; 32]);
let base2k: usize = 12;
let c_size: usize = size + size - 1;
let module: Module<BE> = Module::<BE>::new(n as u64);
let mut a_prep: CnvPVecL<Vec<u8>, BE> = module.cnv_pvec_left_alloc(1, size);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), 1, size);
a.fill_uniform(base2k, &mut source);
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(module.cnv_prepare_left_tmp_bytes(c_size, size));
move || {
module.cnv_prepare_left(&mut a_prep, &a, scratch.borrow());
black_box(());
}
}
for params in [[10, 1], [11, 2], [12, 4], [13, 8], [14, 16], [15, 32], [16, 64]] {
let log_n: usize = params[0];
let size: usize = params[1];
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x{}", 1 << log_n, size));
let mut runner = runner(1 << log_n, size);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_cnv_prepare_right<BE: Backend>(c: &mut Criterion, label: &str)
where
Module<BE>: ModuleNew<BE> + Convolution<BE> + CnvPVecAlloc<BE>,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
{
let group_name: String = format!("cnv_prepare_right::{label}");
let mut group = c.benchmark_group(group_name);
fn runner<BE: Backend>(n: usize, size: usize) -> impl FnMut()
where
Module<BE>: ModuleNew<BE> + Convolution<BE> + CnvPVecAlloc<BE>,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
{
let mut source: Source = Source::new([0u8; 32]);
let base2k: usize = 12;
let c_size: usize = size + size - 1;
let module: Module<BE> = Module::<BE>::new(n as u64);
let mut a_prep: CnvPVecR<Vec<u8>, BE> = module.cnv_pvec_right_alloc(1, size);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), 1, size);
a.fill_uniform(base2k, &mut source);
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(module.cnv_prepare_right_tmp_bytes(c_size, size));
move || {
module.cnv_prepare_right(&mut a_prep, &a, scratch.borrow());
black_box(());
}
}
for params in [[10, 1], [11, 2], [12, 4], [13, 8], [14, 16], [15, 32], [16, 64]] {
let log_n: usize = params[0];
let size: usize = params[1];
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x{}", 1 << log_n, size));
let mut runner = runner(1 << log_n, size);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_cnv_apply_dft<BE: Backend>(c: &mut Criterion, label: &str)
where
Module<BE>: ModuleNew<BE> + Convolution<BE> + VecZnxDftAlloc<BE> + CnvPVecAlloc<BE>,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
{
let group_name: String = format!("cnv_apply_dft::{label}");
let mut group = c.benchmark_group(group_name);
fn runner<BE: Backend>(n: usize, size: usize) -> impl FnMut()
where
Module<BE>: ModuleNew<BE> + Convolution<BE> + VecZnxDftAlloc<BE> + CnvPVecAlloc<BE>,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
{
let mut source: Source = Source::new([0u8; 32]);
let base2k: usize = 12;
let c_size: usize = size + size - 1;
let module: Module<BE> = Module::<BE>::new(n as u64);
let mut a_prep: CnvPVecL<Vec<u8>, BE> = module.cnv_pvec_left_alloc(1, size);
let mut b_prep: CnvPVecR<Vec<u8>, BE> = module.cnv_pvec_right_alloc(1, size);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), 1, size);
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), 1, size);
let mut c_dft = module.vec_znx_dft_alloc(1, c_size);
a.fill_uniform(base2k, &mut source);
b.fill_uniform(base2k, &mut source);
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(
module
.cnv_apply_dft_tmp_bytes(c_size, 0, size, size)
.max(module.cnv_prepare_left_tmp_bytes(c_size, size))
.max(module.cnv_prepare_right_tmp_bytes(c_size, size)),
);
module.cnv_prepare_left(&mut a_prep, &a, scratch.borrow());
module.cnv_prepare_right(&mut b_prep, &b, scratch.borrow());
move || {
module.cnv_apply_dft(&mut c_dft, 0, 0, &a_prep, 0, &b_prep, 0, scratch.borrow());
black_box(());
}
}
for params in [[10, 1], [11, 2], [12, 4], [13, 8], [14, 16], [15, 32], [16, 64]] {
let log_n: usize = params[0];
let size: usize = params[1];
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x{}", 1 << log_n, size));
let mut runner = runner(1 << log_n, size);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_cnv_pairwise_apply_dft<BE: Backend>(c: &mut Criterion, label: &str)
where
Module<BE>: ModuleNew<BE> + Convolution<BE> + VecZnxDftAlloc<BE> + CnvPVecAlloc<BE>,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
{
let group_name: String = format!("cnv_pairwise_apply_dft::{label}");
let mut group = c.benchmark_group(group_name);
fn runner<BE: Backend>(n: usize, size: usize) -> impl FnMut()
where
Module<BE>: ModuleNew<BE> + Convolution<BE> + VecZnxDftAlloc<BE> + CnvPVecAlloc<BE>,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
{
let mut source: Source = Source::new([0u8; 32]);
let base2k: usize = 12;
let module: Module<BE> = Module::<BE>::new(n as u64);
let cols = 2;
let c_size: usize = size + size - 1;
let mut a_prep: CnvPVecL<Vec<u8>, BE> = module.cnv_pvec_left_alloc(cols, size);
let mut b_prep: CnvPVecR<Vec<u8>, BE> = module.cnv_pvec_right_alloc(cols, size);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), cols, size);
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), cols, size);
let mut c_dft = module.vec_znx_dft_alloc(1, c_size);
a.fill_uniform(base2k, &mut source);
b.fill_uniform(base2k, &mut source);
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(
module
.cnv_pairwise_apply_dft_tmp_bytes(c_size, 0, size, size)
.max(module.cnv_prepare_left_tmp_bytes(c_size, size))
.max(module.cnv_prepare_right_tmp_bytes(c_size, size)),
);
module.cnv_prepare_left(&mut a_prep, &a, scratch.borrow());
module.cnv_prepare_right(&mut b_prep, &b, scratch.borrow());
move || {
module.cnv_pairwise_apply_dft(&mut c_dft, 0, 0, &a_prep, &b_prep, 0, 1, scratch.borrow());
black_box(());
}
}
for params in [[10, 1], [11, 2], [12, 4], [13, 8], [14, 16], [15, 32], [16, 64]] {
let log_n: usize = params[0];
let size: usize = params[1];
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x{}", 1 << log_n, size));
let mut runner = runner(1 << log_n, size);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_cnv_by_const_apply<BE: Backend>(c: &mut Criterion, label: &str)
where
Module<BE>: ModuleNew<BE> + Convolution<BE> + VecZnxBigAlloc<BE> + CnvPVecAlloc<BE>,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
{
let group_name: String = format!("cnv_by_const::{label}");
let mut group = c.benchmark_group(group_name);
fn runner<BE: Backend>(n: usize, size: usize) -> impl FnMut()
where
Module<BE>: ModuleNew<BE> + Convolution<BE> + VecZnxBigAlloc<BE> + CnvPVecAlloc<BE>,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
{
let mut source: Source = Source::new([0u8; 32]);
let base2k: usize = 12;
let module: Module<BE> = Module::<BE>::new(n as u64);
let cols = 2;
let c_size: usize = size + size - 1;
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), cols, size);
let mut c_big: VecZnxBig<Vec<u8>, BE> = module.vec_znx_big_alloc(1, c_size);
a.fill_uniform(base2k, &mut source);
let mut b = vec![0i64; size];
for x in &mut b {
*x = source.next_i64();
}
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(module.cnv_by_const_apply_tmp_bytes(c_size, 0, size, size));
move || {
module.cnv_by_const_apply(&mut c_big, 0, 0, &a, 0, &b, scratch.borrow());
black_box(());
}
}
for params in [[10, 1], [11, 2], [12, 4], [13, 8], [14, 16], [15, 32], [16, 64]] {
let log_n: usize = params[0];
let size: usize = params[1];
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x{}", 1 << log_n, size));
let mut runner = runner(1 << log_n, size);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}

View File

@@ -1,3 +1,4 @@
pub mod convolution;
pub mod svp;
pub mod vec_znx;
pub mod vec_znx_big;

View File

@@ -404,7 +404,7 @@ where
move || {
for i in 0..cols {
module.vec_znx_big_normalize(base2k, &mut res, i, base2k, &a, i, scratch.borrow());
module.vec_znx_big_normalize(&mut res, base2k, 0, i, &a, base2k, i, scratch.borrow());
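// Updated argument order in the call above: (res, res_base2k, res_offset, res_col, a, a_base2k, a_col, scratch).
// res_offset is the new i64 parameter enabling cross-base2k normalization with an offset;
// passing 0, as here, keeps the previous behaviour.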
}
black_box(());
}

View File

@@ -0,0 +1,125 @@
use crate::{
api::{CnvPVecAlloc, CnvPVecBytesOf, Convolution},
layouts::{
Backend, CnvPVecL, CnvPVecLToMut, CnvPVecLToRef, CnvPVecR, CnvPVecRToMut, CnvPVecRToRef, Module, Scratch, VecZnxBigToMut,
VecZnxDftToMut, VecZnxToRef, ZnxInfos, ZnxViewMut,
},
oep::{CnvPVecBytesOfImpl, CnvPVecLAllocImpl, ConvolutionImpl},
};
impl<BE: Backend> CnvPVecAlloc<BE> for Module<BE>
where
BE: CnvPVecLAllocImpl<BE>,
{
fn cnv_pvec_left_alloc(&self, cols: usize, size: usize) -> CnvPVecL<Vec<u8>, BE> {
BE::cnv_pvec_left_alloc_impl(self.n(), cols, size)
}
fn cnv_pvec_right_alloc(&self, cols: usize, size: usize) -> CnvPVecR<Vec<u8>, BE> {
BE::cnv_pvec_right_alloc_impl(self.n(), cols, size)
}
}
impl<BE: Backend> CnvPVecBytesOf for Module<BE>
where
BE: CnvPVecBytesOfImpl,
{
fn bytes_of_cnv_pvec_left(&self, cols: usize, size: usize) -> usize {
BE::bytes_of_cnv_pvec_left_impl(self.n(), cols, size)
}
fn bytes_of_cnv_pvec_right(&self, cols: usize, size: usize) -> usize {
BE::bytes_of_cnv_pvec_right_impl(self.n(), cols, size)
}
}
impl<BE: Backend> Convolution<BE> for Module<BE>
where
BE: ConvolutionImpl<BE>,
{
fn cnv_prepare_left_tmp_bytes(&self, res_size: usize, a_size: usize) -> usize {
BE::cnv_prepare_left_tmp_bytes_impl(self, res_size, a_size)
}
fn cnv_prepare_left<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<BE>)
where
R: CnvPVecLToMut<BE> + ZnxInfos + ZnxViewMut<Scalar = BE::ScalarPrep>,
A: VecZnxToRef + ZnxInfos,
{
BE::cnv_prepare_left_impl(self, res, a, scratch);
}
fn cnv_prepare_right_tmp_bytes(&self, res_size: usize, a_size: usize) -> usize {
BE::cnv_prepare_right_tmp_bytes_impl(self, res_size, a_size)
}
fn cnv_prepare_right<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<BE>)
where
R: CnvPVecRToMut<BE> + ZnxInfos + ZnxViewMut<Scalar = BE::ScalarPrep>,
A: VecZnxToRef + ZnxInfos,
{
BE::cnv_prepare_right_impl(self, res, a, scratch);
}
fn cnv_apply_dft_tmp_bytes(&self, res_size: usize, res_offset: usize, a_size: usize, b_size: usize) -> usize {
BE::cnv_apply_dft_tmp_bytes_impl(self, res_size, res_offset, a_size, b_size)
}
fn cnv_by_const_apply_tmp_bytes(&self, res_size: usize, res_offset: usize, a_size: usize, b_size: usize) -> usize {
BE::cnv_by_const_apply_tmp_bytes_impl(self, res_size, res_offset, a_size, b_size)
}
fn cnv_by_const_apply<R, A>(
&self,
res: &mut R,
res_offset: usize,
res_col: usize,
a: &A,
a_col: usize,
b: &[i64],
scratch: &mut Scratch<BE>,
) where
R: VecZnxBigToMut<BE>,
A: VecZnxToRef,
{
BE::cnv_by_const_apply_impl(self, res, res_offset, res_col, a, a_col, b, scratch);
}
fn cnv_apply_dft<R, A, B>(
&self,
res: &mut R,
res_offset: usize,
res_col: usize,
a: &A,
a_col: usize,
b: &B,
b_col: usize,
scratch: &mut Scratch<BE>,
) where
R: VecZnxDftToMut<BE>,
A: CnvPVecLToRef<BE>,
B: CnvPVecRToRef<BE>,
{
BE::cnv_apply_dft_impl(self, res, res_offset, res_col, a, a_col, b, b_col, scratch);
}
fn cnv_pairwise_apply_dft_tmp_bytes(&self, res_size: usize, res_offset: usize, a_size: usize, b_size: usize) -> usize {
BE::cnv_pairwise_apply_dft_tmp_bytes(self, res_size, res_offset, a_size, b_size)
}
fn cnv_pairwise_apply_dft<R, A, B>(
&self,
res: &mut R,
res_offset: usize,
res_col: usize,
a: &A,
b: &B,
i: usize,
j: usize,
scratch: &mut Scratch<BE>,
) where
R: VecZnxDftToMut<BE>,
A: CnvPVecLToRef<BE>,
B: CnvPVecRToRef<BE>,
{
BE::cnv_pairwise_apply_dft_impl(self, res, res_offset, res_col, a, b, i, j, scratch);
}
}

View File

@@ -1,3 +1,4 @@
mod convolution;
mod module;
mod scratch;
mod svp_ppol;

View File

@@ -51,18 +51,19 @@ where
#[allow(clippy::too_many_arguments)]
fn vec_znx_normalize<R, A>(
&self,
res_basek: usize,
res: &mut R,
res_base2k: usize,
res_offset: i64,
res_col: usize,
a_basek: usize,
a: &A,
a_base2k: usize,
a_col: usize,
scratch: &mut Scratch<B>,
) where
R: VecZnxToMut,
A: VecZnxToRef,
{
B::vec_znx_normalize_impl(self, res_basek, res, res_col, a_basek, a, a_col, scratch)
B::vec_znx_normalize_impl(self, res, res_base2k, res_offset, res_col, a, a_base2k, a_col, scratch)
}
}

View File

@@ -51,7 +51,7 @@ where
impl<B> VecZnxBigBytesOf for Module<B>
where
B: Backend + VecZnxBigAllocBytesImpl<B>,
B: Backend + VecZnxBigAllocBytesImpl,
{
fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize {
B::vec_znx_big_bytes_of_impl(self.n(), cols, size)
@@ -264,18 +264,19 @@ where
{
fn vec_znx_big_normalize<R, A>(
&self,
res_basek: usize,
res: &mut R,
res_base2k: usize,
res_offset: i64,
res_col: usize,
a_basek: usize,
a: &A,
a_base2k: usize,
a_col: usize,
scratch: &mut Scratch<B>,
) where
R: VecZnxToMut,
A: VecZnxBigToRef<B>,
{
B::vec_znx_big_normalize_impl(self, res_basek, res, res_col, a_basek, a, a_col, scratch);
B::vec_znx_big_normalize_impl(self, res, res_base2k, res_offset, res_col, a, a_base2k, a_col, scratch);
}
}

View File

@@ -76,9 +76,7 @@ where
b_cols_out: usize,
b_size: usize,
) -> usize {
B::vmp_apply_dft_tmp_bytes_impl(
self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
)
B::vmp_apply_dft_tmp_bytes_impl(self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size)
}
}
@@ -109,9 +107,7 @@ where
b_cols_out: usize,
b_size: usize,
) -> usize {
B::vmp_apply_dft_to_dft_tmp_bytes_impl(
self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
)
B::vmp_apply_dft_to_dft_tmp_bytes_impl(self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size)
}
}
@@ -142,9 +138,7 @@ where
b_cols_out: usize,
b_size: usize,
) -> usize {
B::vmp_apply_dft_to_dft_add_tmp_bytes_impl(
self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
)
B::vmp_apply_dft_to_dft_add_tmp_bytes_impl(self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size)
}
}

View File

@@ -0,0 +1,237 @@
use std::marker::PhantomData;
use crate::{
alloc_aligned,
layouts::{Backend, Data, DataMut, DataRef, DataView, DataViewMut, ZnxInfos, ZnxView},
oep::CnvPVecBytesOfImpl,
};
pub struct CnvPVecR<D: Data, BE: Backend> {
data: D,
n: usize,
size: usize,
cols: usize,
_phantom: PhantomData<BE>,
}
impl<D: Data, BE: Backend> ZnxInfos for CnvPVecR<D, BE> {
fn cols(&self) -> usize {
self.cols
}
fn n(&self) -> usize {
self.n
}
fn rows(&self) -> usize {
1
}
fn size(&self) -> usize {
self.size
}
}
impl<D: Data, BE: Backend> DataView for CnvPVecR<D, BE> {
type D = D;
fn data(&self) -> &Self::D {
&self.data
}
}
impl<D: Data, B: Backend> DataViewMut for CnvPVecR<D, B> {
fn data_mut(&mut self) -> &mut Self::D {
&mut self.data
}
}
impl<D: DataRef, BE: Backend> ZnxView for CnvPVecR<D, BE> {
type Scalar = BE::ScalarPrep;
}
impl<D: DataRef + From<Vec<u8>>, B: Backend> CnvPVecR<D, B>
where
B: CnvPVecBytesOfImpl,
{
pub fn alloc(n: usize, cols: usize, size: usize) -> Self {
let data: Vec<u8> = alloc_aligned::<u8>(B::bytes_of_cnv_pvec_right_impl(n, cols, size));
Self {
data: data.into(),
n,
size,
cols,
_phantom: PhantomData,
}
}
pub fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
let data: Vec<u8> = bytes.into();
assert!(data.len() == B::bytes_of_cnv_pvec_right_impl(n, cols, size));
Self {
data: data.into(),
n,
size,
cols,
_phantom: PhantomData,
}
}
}
impl<D: Data, B: Backend> CnvPVecR<D, B> {
pub fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
Self {
data,
n,
cols,
size,
_phantom: PhantomData,
}
}
}
pub struct CnvPVecL<D: Data, BE: Backend> {
data: D,
n: usize,
size: usize,
cols: usize,
_phantom: PhantomData<BE>,
}
impl<D: Data, BE: Backend> ZnxInfos for CnvPVecL<D, BE> {
fn cols(&self) -> usize {
self.cols
}
fn n(&self) -> usize {
self.n
}
fn rows(&self) -> usize {
1
}
fn size(&self) -> usize {
self.size
}
}
impl<D: Data, BE: Backend> DataView for CnvPVecL<D, BE> {
type D = D;
fn data(&self) -> &Self::D {
&self.data
}
}
impl<D: Data, B: Backend> DataViewMut for CnvPVecL<D, B> {
fn data_mut(&mut self) -> &mut Self::D {
&mut self.data
}
}
impl<D: DataRef, BE: Backend> ZnxView for CnvPVecL<D, BE> {
type Scalar = BE::ScalarPrep;
}
impl<D: DataRef + From<Vec<u8>>, B: Backend> CnvPVecL<D, B>
where
B: CnvPVecBytesOfImpl,
{
pub fn alloc(n: usize, cols: usize, size: usize) -> Self {
let data: Vec<u8> = alloc_aligned::<u8>(B::bytes_of_cnv_pvec_left_impl(n, cols, size));
Self {
data: data.into(),
n,
size,
cols,
_phantom: PhantomData,
}
}
pub fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
let data: Vec<u8> = bytes.into();
assert!(data.len() == B::bytes_of_cnv_pvec_left_impl(n, cols, size));
Self {
data: data.into(),
n,
size,
cols,
_phantom: PhantomData,
}
}
}
impl<D: Data, B: Backend> CnvPVecL<D, B> {
pub fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
Self {
data,
n,
cols,
size,
_phantom: PhantomData,
}
}
}
pub trait CnvPVecRToRef<BE: Backend> {
fn to_ref(&self) -> CnvPVecR<&[u8], BE>;
}
impl<D: DataRef, BE: Backend> CnvPVecRToRef<BE> for CnvPVecR<D, BE> {
fn to_ref(&self) -> CnvPVecR<&[u8], BE> {
CnvPVecR {
data: self.data.as_ref(),
n: self.n,
size: self.size,
cols: self.cols,
_phantom: self._phantom,
}
}
}
pub trait CnvPVecRToMut<BE: Backend> {
fn to_mut(&mut self) -> CnvPVecR<&mut [u8], BE>;
}
impl<D: DataMut, BE: Backend> CnvPVecRToMut<BE> for CnvPVecR<D, BE> {
fn to_mut(&mut self) -> CnvPVecR<&mut [u8], BE> {
CnvPVecR {
data: self.data.as_mut(),
n: self.n,
size: self.size,
cols: self.cols,
_phantom: self._phantom,
}
}
}
pub trait CnvPVecLToRef<BE: Backend> {
fn to_ref(&self) -> CnvPVecL<&[u8], BE>;
}
impl<D: DataRef, BE: Backend> CnvPVecLToRef<BE> for CnvPVecL<D, BE> {
fn to_ref(&self) -> CnvPVecL<&[u8], BE> {
CnvPVecL {
data: self.data.as_ref(),
n: self.n,
size: self.size,
cols: self.cols,
_phantom: self._phantom,
}
}
}
pub trait CnvPVecLToMut<BE: Backend> {
fn to_mut(&mut self) -> CnvPVecL<&mut [u8], BE>;
}
impl<D: DataMut, BE: Backend> CnvPVecLToMut<BE> for CnvPVecL<D, BE> {
fn to_mut(&mut self) -> CnvPVecL<&mut [u8], BE> {
CnvPVecL {
data: self.data.as_mut(),
n: self.n,
size: self.size,
cols: self.cols,
_phantom: self._phantom,
}
}
}
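// Illustrative sketch (placeholder helper name): besides Module::cnv_pvec_left_alloc /
// cnv_pvec_right_alloc, a prepared convolution vector can be allocated directly from the
// backend byte count, like the other prepared layouts in this module.
fn _cnv_pvec_left_alloc_sketch<B>(n: usize, cols: usize, size: usize) -> CnvPVecL<Vec<u8>, B>
where
B: Backend + CnvPVecBytesOfImpl,
{
CnvPVecL::alloc(n, cols, size)
}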

View File

@@ -223,22 +223,22 @@ impl<D: DataRef> VecZnx<D> {
let a: VecZnx<&[u8]> = self.to_ref();
let size: usize = a.size();
let prec: u32 = (base2k * size) as u32;
let prec: u32 = data[0].prec();
// 2^{base2k}
let base: Float = Float::with_val(prec, (1u64 << base2k) as f64);
let scale: Float = Float::with_val(prec, Float::u_pow_u(2, base2k as u32));
// y[i] = sum x[j][i] * 2^{-base2k*j}
(0..size).for_each(|i| {
if i == 0 {
izip!(a.at(col, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
y.assign(*x);
*y /= &base;
*y /= &scale;
});
} else {
izip!(a.at(col, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
*y += Float::with_val(prec, *x);
*y /= &base;
*y /= &scale;
});
}
});

View File

@@ -1,3 +1,4 @@
mod convolution;
mod encoding;
mod mat_znx;
mod module;
@@ -12,6 +13,7 @@ mod vec_znx_dft;
mod vmp_pmat;
mod znx_base;
pub use convolution::*;
pub use mat_znx::*;
pub use module::*;
pub use scalar_znx::*;

View File

@@ -123,10 +123,8 @@ where
panic!("cannot invert 0")
}
let g_exp: u64 = mod_exp_u64(
gal_el.unsigned_abs(),
(self.cyclotomic_order() - 1) as usize,
) & (self.cyclotomic_order() - 1) as u64;
let g_exp: u64 =
mod_exp_u64(gal_el.unsigned_abs(), (self.cyclotomic_order() - 1) as usize) & (self.cyclotomic_order() - 1) as u64;
g_exp as i64 * gal_el.signum()
}
}

View File

@@ -187,11 +187,7 @@ impl<D: Data> VecZnx<D> {
impl<D: DataRef> fmt::Display for VecZnx<D> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(
f,
"VecZnx(n={}, cols={}, size={})",
self.n, self.cols, self.size
)?;
writeln!(f, "VecZnx(n={}, cols={}, size={})", self.n, self.cols, self.size)?;
for col in 0..self.cols {
writeln!(f, "Column {col}:")?;

View File

@@ -93,7 +93,7 @@ where
impl<D: DataRef + From<Vec<u8>>, B: Backend> VecZnxBig<D, B>
where
B: VecZnxBigAllocBytesImpl<B>,
B: VecZnxBigAllocBytesImpl,
{
pub fn alloc(n: usize, cols: usize, size: usize) -> Self {
let data = alloc_aligned::<u8>(B::vec_znx_big_bytes_of_impl(n, cols, size));
@@ -172,11 +172,7 @@ impl<D: DataMut, B: Backend> VecZnxBigToMut<B> for VecZnxBig<D, B> {
impl<D: DataRef, B: Backend> fmt::Display for VecZnxBig<D, B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(
f,
"VecZnxBig(n={}, cols={}, size={})",
self.n, self.cols, self.size
)?;
writeln!(f, "VecZnxBig(n={}, cols={}, size={})", self.n, self.cols, self.size)?;
for col in 0..self.cols {
writeln!(f, "Column {col}:")?;

View File

@@ -192,11 +192,7 @@ impl<D: DataMut, B: Backend> VecZnxDftToMut<B> for VecZnxDft<D, B> {
impl<D: DataRef, B: Backend> fmt::Display for VecZnxDft<D, B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(
f,
"VecZnxDft(n={}, cols={}, size={})",
self.n, self.cols, self.size
)?;
writeln!(f, "VecZnxDft(n={}, cols={}, size={})", self.n, self.cols, self.size)?;
for col in 0..self.cols {
writeln!(f, "Column {col}:")?;

View File

@@ -65,11 +65,8 @@ pub trait ZnxView: ZnxInfos + DataView<D: DataRef> {
/// Returns a non-mutable pointer starting at the j-th small polynomial of the i-th column.
fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar {
#[cfg(debug_assertions)]
{
assert!(i < self.cols(), "cols: {} >= {}", i, self.cols());
assert!(j < self.size(), "size: {} >= {}", j, self.size());
}
assert!(i < self.cols(), "cols: {} >= self.cols(): {}", i, self.cols());
assert!(j < self.size(), "size: {} >= self.size(): {}", j, self.size());
let offset: usize = self.n() * (j * self.cols() + i);
unsafe { self.as_ptr().add(offset) }
}
@@ -93,11 +90,8 @@ pub trait ZnxViewMut: ZnxView + DataViewMut<D: DataMut> {
/// Returns a mutable pointer starting at the j-th small polynomial of the i-th column.
fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar {
#[cfg(debug_assertions)]
{
assert!(i < self.cols(), "cols: {} >= {}", i, self.cols());
assert!(j < self.size(), "size: {} >= {}", j, self.size());
}
assert!(i < self.cols(), "cols: {} >= self.cols(): {}", i, self.cols());
assert!(j < self.size(), "size: {} >= self.size(): {}", j, self.size());
let offset: usize = self.n() * (j * self.cols() + i);
unsafe { self.as_mut_ptr().add(offset) }
}

View File

@@ -54,10 +54,7 @@ pub fn cast_mut<T, V>(data: &[T]) -> &mut [V] {
Alignment must be a power of two and size a multiple of the alignment.
/// Allocated memory is initialized to zero.
fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec<u8> {
assert!(
align.is_power_of_two(),
"Alignment must be a power of two but is {align}"
);
assert!(align.is_power_of_two(), "Alignment must be a power of two but is {align}");
assert_eq!(
(size * size_of::<u8>()) % align,
0,
@@ -82,10 +79,7 @@ fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec<u8> {
/// Allocates a block of T aligned with [DEFAULTALIGN].
/// Size of T * size must be a multiple of [DEFAULTALIGN].
pub fn alloc_aligned_custom<T>(size: usize, align: usize) -> Vec<T> {
assert!(
align.is_power_of_two(),
"Alignment must be a power of two but is {align}"
);
assert!(align.is_power_of_two(), "Alignment must be a power of two but is {align}");
assert_eq!(
(size * size_of::<T>()) % align,

View File

@@ -0,0 +1,106 @@
use crate::layouts::{
Backend, CnvPVecL, CnvPVecLToMut, CnvPVecLToRef, CnvPVecR, CnvPVecRToMut, CnvPVecRToRef, Module, Scratch, VecZnxBigToMut,
VecZnxDftToMut, VecZnxToRef, ZnxInfos,
};
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See the TODO reference implementation.
/// * See [crate::api::CnvPVecAlloc] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait CnvPVecLAllocImpl<BE: Backend> {
fn cnv_pvec_left_alloc_impl(n: usize, cols: usize, size: usize) -> CnvPVecL<Vec<u8>, BE>;
fn cnv_pvec_right_alloc_impl(n: usize, cols: usize, size: usize) -> CnvPVecR<Vec<u8>, BE>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See the TODO reference implementation.
/// * See [crate::api::CnvPVecBytesOf] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait CnvPVecBytesOfImpl {
fn bytes_of_cnv_pvec_left_impl(n: usize, cols: usize, size: usize) -> usize;
fn bytes_of_cnv_pvec_right_impl(n: usize, cols: usize, size: usize) -> usize;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See the TODO reference implementation.
/// * See [crate::api::Convolution] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait ConvolutionImpl<BE: Backend> {
fn cnv_prepare_left_tmp_bytes_impl(module: &Module<BE>, res_size: usize, a_size: usize) -> usize;
fn cnv_prepare_left_impl<R, A>(module: &Module<BE>, res: &mut R, a: &A, scratch: &mut Scratch<BE>)
where
R: CnvPVecLToMut<BE>,
A: VecZnxToRef;
fn cnv_prepare_right_tmp_bytes_impl(module: &Module<BE>, res_size: usize, a_size: usize) -> usize;
fn cnv_prepare_right_impl<R, A>(module: &Module<BE>, res: &mut R, a: &A, scratch: &mut Scratch<BE>)
where
R: CnvPVecRToMut<BE>,
A: VecZnxToRef + ZnxInfos;
fn cnv_apply_dft_tmp_bytes_impl(
module: &Module<BE>,
res_size: usize,
res_offset: usize,
a_size: usize,
b_size: usize,
) -> usize;
fn cnv_by_const_apply_tmp_bytes_impl(
module: &Module<BE>,
res_size: usize,
res_offset: usize,
a_size: usize,
b_size: usize,
) -> usize;
#[allow(clippy::too_many_arguments)]
fn cnv_by_const_apply_impl<R, A>(
module: &Module<BE>,
res: &mut R,
res_offset: usize,
res_col: usize,
a: &A,
a_col: usize,
b: &[i64],
scratch: &mut Scratch<BE>,
) where
R: VecZnxBigToMut<BE>,
A: VecZnxToRef;
#[allow(clippy::too_many_arguments)]
fn cnv_apply_dft_impl<R, A, B>(
module: &Module<BE>,
res: &mut R,
res_offset: usize,
res_col: usize,
a: &A,
a_col: usize,
b: &B,
b_col: usize,
scratch: &mut Scratch<BE>,
) where
R: VecZnxDftToMut<BE>,
A: CnvPVecLToRef<BE>,
B: CnvPVecRToRef<BE>;
fn cnv_pairwise_apply_dft_tmp_bytes(
module: &Module<BE>,
res_size: usize,
res_offset: usize,
a_size: usize,
b_size: usize,
) -> usize;
#[allow(clippy::too_many_arguments)]
fn cnv_pairwise_apply_dft_impl<R, A, B>(
module: &Module<BE>,
res: &mut R,
res_offset: usize,
res_col: usize,
a: &A,
b: &B,
i: usize,
j: usize,
scratch: &mut Scratch<BE>,
) where
R: VecZnxDftToMut<BE>,
A: CnvPVecLToRef<BE>,
B: CnvPVecRToRef<BE>;
}

View File

@@ -1,3 +1,4 @@
mod convolution;
mod module;
mod scratch;
mod svp_ppol;
@@ -6,6 +7,7 @@ mod vec_znx_big;
mod vec_znx_dft;
mod vmp_pmat;
pub use convolution::*;
pub use module::*;
pub use scratch::*;
pub use svp_ppol::*;

View File

@@ -29,11 +29,12 @@ pub unsafe trait VecZnxNormalizeImpl<B: Backend> {
#[allow(clippy::too_many_arguments)]
fn vec_znx_normalize_impl<R, A>(
module: &Module<B>,
res_basek: usize,
res: &mut R,
res_base2k: usize,
res_offset: i64,
res_col: usize,
a_basek: usize,
a: &A,
a_base2k: usize,
a_col: usize,
scratch: &mut Scratch<B>,
) where

View File

@@ -34,7 +34,7 @@ pub unsafe trait VecZnxBigFromBytesImpl<B: Backend> {
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation.
/// * See [crate::api::VecZnxBigAllocBytes] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigAllocBytesImpl<B: Backend> {
pub unsafe trait VecZnxBigAllocBytesImpl {
fn vec_znx_big_bytes_of_impl(n: usize, cols: usize, size: usize) -> usize;
}
@@ -46,7 +46,7 @@ pub unsafe trait VecZnxBigAllocBytesImpl<B: Backend> {
pub unsafe trait VecZnxBigAddNormalImpl<B: Backend> {
fn add_normal_impl<R: VecZnxBigToMut<B>>(
module: &Module<B>,
res_basek: usize,
res_base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -240,11 +240,12 @@ pub unsafe trait VecZnxBigNormalizeTmpBytesImpl<B: Backend> {
pub unsafe trait VecZnxBigNormalizeImpl<B: Backend> {
fn vec_znx_big_normalize_impl<R, A>(
module: &Module<B>,
res_basek: usize,
res: &mut R,
res_base2k: usize,
res_offset: i64,
res_col: usize,
a_basek: usize,
a: &A,
a_base2k: usize,
a_col: usize,
scratch: &mut Scratch<B>,
) where

View File

@@ -0,0 +1,405 @@
use crate::{
layouts::{
Backend, CnvPVecL, CnvPVecLToMut, CnvPVecLToRef, CnvPVecR, CnvPVecRToMut, CnvPVecRToRef, VecZnx, VecZnxBig,
VecZnxBigToMut, VecZnxDft, VecZnxDftToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut, ZnxZero,
},
reference::fft64::{
reim::{ReimAdd, ReimCopy, ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimZero},
reim4::{
Reim4Convolution, Reim4Convolution1Coeff, Reim4Convolution2Coeffs, Reim4Extract1BlkContiguous,
Reim4Save1BlkContiguous,
},
vec_znx_dft::vec_znx_dft_apply,
},
};
pub fn convolution_prepare_left<R, A, T, BE>(table: &ReimFFTTable<f64>, res: &mut R, a: &A, tmp: &mut T)
where
BE: Backend<ScalarPrep = f64>
+ ReimZero
+ Reim4Extract1BlkContiguous
+ ReimDFTExecute<ReimFFTTable<f64>, f64>
+ ReimFromZnx
+ ReimZero,
R: CnvPVecLToMut<BE> + ZnxInfos + ZnxViewMut<Scalar = BE::ScalarPrep>,
A: VecZnxToRef,
T: VecZnxDftToMut<BE>,
{
convolution_prepare(table, res, a, tmp)
}
pub fn convolution_prepare_right<R, A, T, BE>(table: &ReimFFTTable<f64>, res: &mut R, a: &A, tmp: &mut T)
where
BE: Backend<ScalarPrep = f64>
+ ReimZero
+ Reim4Extract1BlkContiguous
+ ReimDFTExecute<ReimFFTTable<f64>, f64>
+ ReimFromZnx
+ ReimZero,
R: CnvPVecRToMut<BE> + ZnxInfos + ZnxViewMut<Scalar = BE::ScalarPrep>,
A: VecZnxToRef,
T: VecZnxDftToMut<BE>,
{
convolution_prepare(table, res, a, tmp)
}
fn convolution_prepare<R, A, T, BE>(table: &ReimFFTTable<f64>, res: &mut R, a: &A, tmp: &mut T)
where
BE: Backend<ScalarPrep = f64>
+ ReimZero
+ Reim4Extract1BlkContiguous
+ ReimDFTExecute<ReimFFTTable<f64>, f64>
+ ReimFromZnx
+ ReimZero,
R: ZnxInfos + ZnxViewMut<Scalar = BE::ScalarPrep>,
A: VecZnxToRef,
T: VecZnxDftToMut<BE>,
{
let a: &VecZnx<&[u8]> = &a.to_ref();
let tmp: &mut VecZnxDft<&mut [u8], BE> = &mut tmp.to_mut();
let cols: usize = res.cols();
assert_eq!(a.cols(), cols, "a.cols():{} != res.cols():{cols}", a.cols());
let res_size: usize = res.size();
let min_size: usize = res_size.min(a.size());
let m: usize = a.n() >> 1;
let n: usize = table.m() << 1;
let res_raw: &mut [f64] = res.raw_mut();
for i in 0..cols {
vec_znx_dft_apply(table, 1, 0, tmp, 0, a, i);
let tmp_raw: &[f64] = tmp.raw();
let res_col: &mut [f64] = &mut res_raw[i * n * res_size..];
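// Re-block the DFT of column i: for each reim4 block (4 complex slots), the min_size limbs
// are laid out contiguously (8 f64 per limb: 4 real then 4 imaginary lanes), and the
// remaining limbs up to res_size are zeroed.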
for blk_i in 0..m / 4 {
BE::reim4_extract_1blk_contiguous(m, min_size, blk_i, &mut res_col[blk_i * res_size * 8..], tmp_raw);
BE::reim_zero(&mut res_col[blk_i * res_size * 8 + min_size * 8..(blk_i + 1) * res_size * 8]);
}
}
}
pub fn convolution_by_const_apply_tmp_bytes(res_size: usize, a_size: usize, b_size: usize) -> usize {
let min_size: usize = res_size.min(a_size + b_size - 1);
size_of::<i64>() * (min_size + a_size) * 8
}
pub fn convolution_by_const_apply<R, A, BE>(
res: &mut R,
res_offset: usize,
res_col: usize,
a: &A,
a_col: usize,
b: &[i64],
tmp: &mut [i64],
) where
BE: Backend<ScalarBig = i64>
+ I64ConvolutionByConst1Coeff
+ I64ConvolutionByConst2Coeffs
+ I64Extract1BlkContiguous
+ I64Save1BlkContiguous,
R: VecZnxBigToMut<BE>,
A: VecZnxToRef,
{
let res: &mut VecZnxBig<&mut [u8], BE> = &mut res.to_mut();
let a: &VecZnx<&[u8]> = &a.to_ref();
let n: usize = res.n();
assert_eq!(a.n(), n);
let res_size: usize = res.size();
let a_size: usize = a.size();
let b_size: usize = b.len();
let bound: usize = a_size + b_size - 1;
let min_size: usize = res_size.min(bound);
let offset: usize = res_offset.min(bound);
let a_sl: usize = n * a.cols();
let res_sl: usize = n * res.cols();
let res_raw: &mut [i64] = res.raw_mut();
let a_raw: &[i64] = a.raw();
let a_idx: usize = n * a_col;
let res_idx: usize = n * res_col;
let (res_blk, a_blk) = tmp[..(min_size + a_size) * 8].split_at_mut(min_size * 8);
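// Process the polynomial 8 X-coefficients at a time: gather the limbs of `a` for this block,
// convolve them with the constant's limbs over Y (coefficients starting at res_offset), and
// write the block back into the strided layout of `res`.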
for blk_i in 0..n / 8 {
BE::i64_extract_1blk_contiguous(a_sl, a_idx, a_size, blk_i, a_blk, a_raw);
BE::i64_convolution_by_const(res_blk, min_size, offset, a_blk, a_size, b);
BE::i64_save_1blk_contiguous(res_sl, res_idx, min_size, blk_i, res_raw, res_blk);
}
for j in min_size..res_size {
res.zero_at(res_col, j);
}
}
pub fn convolution_apply_dft_tmp_bytes(res_size: usize, a_size: usize, b_size: usize) -> usize {
let min_size: usize = res_size.min(a_size + b_size - 1);
size_of::<f64>() * 8 * min_size
}
#[allow(clippy::too_many_arguments)]
pub fn convolution_apply_dft<R, A, B, BE>(
res: &mut R,
res_offset: usize,
res_col: usize,
a: &A,
a_col: usize,
b: &B,
b_col: usize,
tmp: &mut [f64],
) where
BE: Backend<ScalarPrep = f64> + Reim4Save1BlkContiguous + Reim4Convolution1Coeff + Reim4Convolution2Coeffs,
R: VecZnxDftToMut<BE>,
A: CnvPVecLToRef<BE>,
B: CnvPVecRToRef<BE>,
{
let res: &mut VecZnxDft<&mut [u8], BE> = &mut res.to_mut();
let a: &CnvPVecL<&[u8], BE> = &a.to_ref();
let b: &CnvPVecR<&[u8], BE> = &b.to_ref();
let n: usize = res.n();
assert_eq!(a.n(), n);
assert_eq!(b.n(), n);
let m: usize = n >> 1;
let res_size: usize = res.size();
let a_size: usize = a.size();
let b_size: usize = b.size();
let bound: usize = a_size + b_size - 1;
let min_size: usize = res_size.min(bound);
let offset: usize = res_offset.min(bound);
let dst: &mut [f64] = res.raw_mut();
let a_raw: &[f64] = a.raw();
let b_raw: &[f64] = b.raw();
let mut a_idx: usize = a_col * n * a_size;
let mut b_idx: usize = b_col * n * b_size;
let a_offset: usize = a_size * 8;
let b_offset: usize = b_size * 8;
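// For each reim4 block (4 complex slots), convolve the limb vectors of `a` and `b` over Y
// into `tmp` (coefficients starting at res_offset), then scatter the block back into the
// strided layout of `res`.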
for blk_i in 0..m / 4 {
BE::reim4_convolution(tmp, min_size, offset, &a_raw[a_idx..], a_size, &b_raw[b_idx..], b_size);
BE::reim4_save_1blk_contiguous(m, min_size, blk_i, dst, tmp);
a_idx += a_offset;
b_idx += b_offset;
}
for j in min_size..res_size {
res.zero_at(res_col, j);
}
}
pub fn convolution_pairwise_apply_dft_tmp_bytes(res_size: usize, a_size: usize, b_size: usize) -> usize {
convolution_apply_dft_tmp_bytes(res_size, a_size, b_size) + (a_size + b_size) * size_of::<f64>() * 8
}
#[allow(clippy::too_many_arguments)]
pub fn convolution_pairwise_apply_dft<R, A, B, BE>(
res: &mut R,
res_offset: usize,
res_col: usize,
a: &A,
b: &B,
col_i: usize,
col_j: usize,
tmp: &mut [f64],
) where
BE: Backend<ScalarPrep = f64>
+ ReimZero
+ ReimAdd
+ ReimCopy
+ Reim4Save1BlkContiguous
+ Reim4Convolution1Coeff
+ Reim4Convolution2Coeffs,
R: VecZnxDftToMut<BE>,
A: CnvPVecLToRef<BE>,
B: CnvPVecRToRef<BE>,
{
if col_i == col_j {
convolution_apply_dft(res, res_offset, res_col, a, col_i, b, col_j, tmp);
return;
}
let res: &mut VecZnxDft<&mut [u8], BE> = &mut res.to_mut();
let a: &CnvPVecL<&[u8], BE> = &a.to_ref();
let b: &CnvPVecR<&[u8], BE> = &b.to_ref();
let n: usize = res.n();
let m: usize = n >> 1;
assert_eq!(a.n(), n);
assert_eq!(b.n(), n);
let res_size: usize = res.size();
let a_size: usize = a.size();
let b_size: usize = b.size();
assert_eq!(
tmp.len(),
convolution_pairwise_apply_dft_tmp_bytes(res_size, a_size, b_size) / size_of::<f64>()
);
let bound: usize = a_size + b_size - 1;
let min_size: usize = res_size.min(bound);
let offset: usize = res_offset.min(bound);
let res_raw: &mut [f64] = res.raw_mut();
let a_raw: &[f64] = a.raw();
let b_raw: &[f64] = b.raw();
let a_row_size: usize = a_size * 8;
let b_row_size: usize = b_size * 8;
let mut a0_idx: usize = col_i * n * a_size;
let mut a1_idx: usize = col_j * n * a_size;
let mut b0_idx: usize = col_i * n * b_size;
let mut b1_idx: usize = col_j * n * b_size;
let (tmp_a, tmp) = tmp.split_at_mut(a_row_size);
let (tmp_b, tmp_res) = tmp.split_at_mut(b_row_size);
for blk_i in 0..m / 4 {
let a0: &[f64] = &a_raw[a0_idx..];
let a1: &[f64] = &a_raw[a1_idx..];
let b0: &[f64] = &b_raw[b0_idx..];
let b1: &[f64] = &b_raw[b1_idx..];
BE::reim_add(tmp_a, &a0[..a_row_size], &a1[..a_row_size]);
BE::reim_add(tmp_b, &b0[..b_row_size], &b1[..b_row_size]);
BE::reim4_convolution(tmp_res, min_size, offset, tmp_a, a_size, tmp_b, b_size);
BE::reim4_save_1blk_contiguous(m, min_size, blk_i, res_raw, tmp_res);
a0_idx += a_row_size;
a1_idx += a_row_size;
b0_idx += b_row_size;
b1_idx += b_row_size;
}
for j in min_size..res_size {
res.zero_at(res_col, j);
}
}
pub trait I64Extract1BlkContiguous {
fn i64_extract_1blk_contiguous(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]);
}
pub trait I64Save1BlkContiguous {
fn i64_save_1blk_contiguous(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]);
}
#[inline(always)]
pub fn i64_extract_1blk_contiguous_ref(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) {
debug_assert!(blk < (n >> 3));
debug_assert!(dst.len() >= rows * 8, "dst.len(): {} < rows*8: {}", dst.len(), 8 * rows);
let offset: usize = offset + (blk << 3);
// src = 8-values chunks spaced by n, dst = sequential 8-values chunks
let src_rows = src.chunks_exact(n).take(rows);
let dst_chunks = dst.chunks_exact_mut(8).take(rows);
for (dst_chunk, src_row) in dst_chunks.zip(src_rows) {
dst_chunk.copy_from_slice(&src_row[offset..offset + 8]);
}
}
#[inline(always)]
pub fn i64_save_1blk_contiguous_ref(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) {
debug_assert!(blk < (n >> 3));
debug_assert!(src.len() >= rows * 8);
let offset: usize = offset + (blk << 3);
// dst = 8-values chunks spaced by n, src = sequential 8-values chunks
let dst_rows = dst.chunks_exact_mut(n).take(rows);
let src_chunks = src.chunks_exact(8).take(rows);
for (dst_row, src_chunk) in dst_rows.zip(src_chunks) {
dst_row[offset..offset + 8].copy_from_slice(src_chunk);
}
}
pub trait I64ConvolutionByConst1Coeff {
fn i64_convolution_by_const_1coeff(k: usize, dst: &mut [i64; 8], a: &[i64], a_size: usize, b: &[i64]);
}
#[inline(always)]
pub fn i64_convolution_by_const_1coeff_ref(k: usize, dst: &mut [i64; 8], a: &[i64], a_size: usize, b: &[i64]) {
dst.fill(0);
let b_size: usize = b.len();
if k >= a_size + b_size {
return;
}
let j_min: usize = k.saturating_sub(a_size - 1);
let j_max: usize = (k + 1).min(b_size);
for j in j_min..j_max {
let ai: &[i64] = &a[8 * (k - j)..];
let bi: i64 = b[j];
dst[0] = dst[0].wrapping_add(ai[0].wrapping_mul(bi));
dst[1] = dst[1].wrapping_add(ai[1].wrapping_mul(bi));
dst[2] = dst[2].wrapping_add(ai[2].wrapping_mul(bi));
dst[3] = dst[3].wrapping_add(ai[3].wrapping_mul(bi));
dst[4] = dst[4].wrapping_add(ai[4].wrapping_mul(bi));
dst[5] = dst[5].wrapping_add(ai[5].wrapping_mul(bi));
dst[6] = dst[6].wrapping_add(ai[6].wrapping_mul(bi));
dst[7] = dst[7].wrapping_add(ai[7].wrapping_mul(bi));
}
}
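// Illustrative check of the single-coefficient convolution above, with made-up values:
// with a = 2 + 3*Y (each coefficient broadcast over the 8 lanes) and b = [5, 7],
// coefficient k = 1 of a*b is a0*b1 + a1*b0 = 2*7 + 3*5 = 29 in every lane.
#[test]
fn i64_convolution_by_const_1coeff_example() {
let a: Vec<i64> = [2i64, 3].iter().flat_map(|&c| std::iter::repeat(c).take(8)).collect();
let b: [i64; 2] = [5, 7];
let mut dst: [i64; 8] = [0i64; 8];
i64_convolution_by_const_1coeff_ref(1, &mut dst, &a, 2, &b);
assert!(dst.iter().all(|&x| x == 2 * 7 + 3 * 5));
}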
#[inline(always)]
pub(crate) fn as_arr_i64<const size: usize>(x: &[i64]) -> &[i64; size] {
debug_assert!(x.len() >= size, "x.len():{} < size:{}", x.len(), size);
unsafe { &*(x.as_ptr() as *const [i64; size]) }
}
#[inline(always)]
pub(crate) fn as_arr_i64_mut<const size: usize>(x: &mut [i64]) -> &mut [i64; size] {
debug_assert!(x.len() >= size, "x.len():{} < size:{}", x.len(), size);
unsafe { &mut *(x.as_mut_ptr() as *mut [i64; size]) }
}
pub trait I64ConvolutionByConst2Coeffs {
fn i64_convolution_by_const_2coeffs(k: usize, dst: &mut [i64; 16], a: &[i64], a_size: usize, b: &[i64]);
}
#[inline(always)]
pub fn i64_convolution_by_const_2coeffs_ref(k: usize, dst: &mut [i64; 16], a: &[i64], a_size: usize, b: &[i64]) {
i64_convolution_by_const_1coeff_ref(k, as_arr_i64_mut(&mut dst[..8]), a, a_size, b);
i64_convolution_by_const_1coeff_ref(k + 1, as_arr_i64_mut(&mut dst[8..]), a, a_size, b);
}
impl<BE: Backend> I64ConvolutionByConst<BE> for BE where Self: I64ConvolutionByConst1Coeff + I64ConvolutionByConst2Coeffs {}
pub trait I64ConvolutionByConst<BE: Backend>
where
BE: I64ConvolutionByConst1Coeff + I64ConvolutionByConst2Coeffs,
{
fn i64_convolution_by_const(dst: &mut [i64], dst_size: usize, offset: usize, a: &[i64], a_size: usize, b: &[i64]) {
assert!(a_size > 0);
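// dst is a sequence of 8-lane blocks: dst[8*k..8*k + 8] receives convolution coefficient
// k + offset, computed two coefficients at a time with a scalar tail for odd dst_size.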
for k in (0..dst_size - 1).step_by(2) {
BE::i64_convolution_by_const_2coeffs(k + offset, as_arr_i64_mut(&mut dst[8 * k..]), a, a_size, b);
}
if !dst_size.is_multiple_of(2) {
let k: usize = dst_size - 1;
BE::i64_convolution_by_const_1coeff(k + offset, as_arr_i64_mut(&mut dst[8 * k..]), a, a_size, b);
}
}
}

View File

@@ -1,3 +1,4 @@
pub mod convolution;
pub mod reim;
pub mod reim4;
pub mod svp;

View File

@@ -12,26 +12,10 @@ pub fn fft_ref<R: Float + FloatConst + Debug>(m: usize, omg: &[R], data: &mut [R
if m <= 16 {
match m {
1 => {}
2 => fft2_ref(
as_arr_mut::<2, R>(re),
as_arr_mut::<2, R>(im),
*as_arr::<2, R>(omg),
),
4 => fft4_ref(
as_arr_mut::<4, R>(re),
as_arr_mut::<4, R>(im),
*as_arr::<4, R>(omg),
),
8 => fft8_ref(
as_arr_mut::<8, R>(re),
as_arr_mut::<8, R>(im),
*as_arr::<8, R>(omg),
),
16 => fft16_ref(
as_arr_mut::<16, R>(re),
as_arr_mut::<16, R>(im),
*as_arr::<16, R>(omg),
),
2 => fft2_ref(as_arr_mut::<2, R>(re), as_arr_mut::<2, R>(im), *as_arr::<2, R>(omg)),
4 => fft4_ref(as_arr_mut::<4, R>(re), as_arr_mut::<4, R>(im), *as_arr::<4, R>(omg)),
8 => fft8_ref(as_arr_mut::<8, R>(re), as_arr_mut::<8, R>(im), *as_arr::<8, R>(omg)),
16 => fft16_ref(as_arr_mut::<16, R>(re), as_arr_mut::<16, R>(im), *as_arr::<16, R>(omg)),
_ => {}
}
} else if m <= 2048 {
@@ -257,12 +241,7 @@ fn fft_bfs_16_ref<R: Float + FloatConst + Debug>(m: usize, re: &mut [R], im: &mu
while mm > 16 {
let h: usize = mm >> 2;
for off in (0..m).step_by(mm) {
bitwiddle_fft_ref(
h,
&mut re[off..],
&mut im[off..],
as_arr::<4, R>(&omg[pos..]),
);
bitwiddle_fft_ref(h, &mut re[off..], &mut im[off..], as_arr::<4, R>(&omg[pos..]));
pos += 4;
}
mm = h
@@ -289,14 +268,7 @@ fn twiddle_fft_ref<R: Float + FloatConst>(h: usize, re: &mut [R], im: &mut [R],
let (im_lhs, im_rhs) = im.split_at_mut(h);
for i in 0..h {
cplx_twiddle(
&mut re_lhs[i],
&mut im_lhs[i],
&mut re_rhs[i],
&mut im_rhs[i],
romg,
iomg,
);
cplx_twiddle(&mut re_lhs[i], &mut im_lhs[i], &mut re_rhs[i], &mut im_rhs[i], romg, iomg);
}
}

View File

@@ -11,26 +11,10 @@ pub fn ifft_ref<R: Float + FloatConst + Debug>(m: usize, omg: &[R], data: &mut [
if m <= 16 {
match m {
1 => {}
2 => ifft2_ref(
as_arr_mut::<2, R>(re),
as_arr_mut::<2, R>(im),
*as_arr::<2, R>(omg),
),
4 => ifft4_ref(
as_arr_mut::<4, R>(re),
as_arr_mut::<4, R>(im),
*as_arr::<4, R>(omg),
),
8 => ifft8_ref(
as_arr_mut::<8, R>(re),
as_arr_mut::<8, R>(im),
*as_arr::<8, R>(omg),
),
16 => ifft16_ref(
as_arr_mut::<16, R>(re),
as_arr_mut::<16, R>(im),
*as_arr::<16, R>(omg),
),
2 => ifft2_ref(as_arr_mut::<2, R>(re), as_arr_mut::<2, R>(im), *as_arr::<2, R>(omg)),
4 => ifft4_ref(as_arr_mut::<4, R>(re), as_arr_mut::<4, R>(im), *as_arr::<4, R>(omg)),
8 => ifft8_ref(as_arr_mut::<8, R>(re), as_arr_mut::<8, R>(im), *as_arr::<8, R>(omg)),
16 => ifft16_ref(as_arr_mut::<16, R>(re), as_arr_mut::<16, R>(im), *as_arr::<16, R>(omg)),
_ => {}
}
} else if m <= 2048 {
@@ -72,12 +56,7 @@ fn ifft_bfs_16_ref<R: Float + FloatConst>(m: usize, re: &mut [R], im: &mut [R],
while h < m_half {
let mm: usize = h << 2;
for off in (0..m).step_by(mm) {
inv_bitwiddle_ifft_ref(
h,
&mut re[off..],
&mut im[off..],
as_arr::<4, R>(&omg[pos..]),
);
inv_bitwiddle_ifft_ref(h, &mut re[off..], &mut im[off..], as_arr::<4, R>(&omg[pos..]));
pos += 4;
}
h = mm;
@@ -284,14 +263,7 @@ fn inv_twiddle_ifft_ref<R: Float + FloatConst>(h: usize, re: &mut [R], im: &mut
let (im_lhs, im_rhs) = im.split_at_mut(h);
for i in 0..h {
inv_twiddle(
&mut re_lhs[i],
&mut im_lhs[i],
&mut re_rhs[i],
&mut im_rhs[i],
romg,
iomg,
);
inv_twiddle(&mut re_lhs[i], &mut im_lhs[i], &mut re_rhs[i], &mut im_rhs[i], romg, iomg);
}
}

View File

@@ -35,7 +35,7 @@ pub use zero::*;
#[inline(always)]
pub(crate) fn as_arr<const size: usize, R: Float + FloatConst>(x: &[R]) -> &[R; size] {
debug_assert!(x.len() >= size);
debug_assert!(x.len() >= size, "x.len():{} < size:{}", x.len(), size);
unsafe { &*(x.as_ptr() as *const [R; size]) }
}

View File

@@ -1,15 +1,34 @@
use crate::reference::fft64::reim::as_arr;
use crate::reference::fft64::reim::{as_arr, as_arr_mut, reim_zero_ref};
#[inline(always)]
pub fn reim4_extract_1blk_from_reim_ref(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
let mut offset: usize = blk << 2;
pub fn reim4_extract_1blk_from_reim_contiguous_ref(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
debug_assert!(blk < (m >> 2));
debug_assert!(dst.len() >= 2 * rows * 4);
for chunk in dst.chunks_exact_mut(4).take(2 * rows) {
chunk.copy_from_slice(&src[offset..offset + 4]);
offset += m
let offset: usize = blk << 2;
// src = 4-value chunks spaced by m, dst = sequential 4-value chunks
let src_rows = src.chunks_exact(m).take(2 * rows);
let dst_chunks = dst.chunks_exact_mut(4).take(2 * rows);
for (dst_chunk, src_row) in dst_chunks.zip(src_rows) {
dst_chunk.copy_from_slice(&src_row[offset..offset + 4]);
}
}
#[inline(always)]
pub fn reim4_save_1blk_to_reim_contiguous_ref(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
debug_assert!(blk < (m >> 2));
debug_assert!(src.len() >= 2 * rows * 4);
let offset: usize = blk << 2;
// dst = 4-value chunks spaced by m, src = sequential 4-value chunks
let dst_rows = dst.chunks_exact_mut(m).take(2 * rows);
let src_chunks = src.chunks_exact(4).take(2 * rows);
for (dst_row, src_chunk) in dst_rows.zip(src_chunks) {
dst_row[offset..offset + 4].copy_from_slice(src_chunk);
}
}
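The two contiguous helpers above move a single 4-wide column block between the strided re/im layout (2 * rows rows of length m, real rows followed by imaginary rows) and a packed sequential buffer. The round-trip below is a minimal standalone sketch of that layout with toy sizes; the gather/scatter loops are illustrative only and not part of the crate:
fn main() {
    // Toy layout: 2 * rows rows of length m (the re row followed by the im row).
    let (m, rows, blk) = (8usize, 1usize, 1usize);
    let src: Vec<f64> = (0..2 * rows * m).map(|i| i as f64).collect();
    // Gather: copy the 4 values at column offset blk * 4 of every row into a packed buffer.
    let mut packed = vec![0.0f64; 2 * rows * 4];
    for (dst_chunk, src_row) in packed.chunks_exact_mut(4).zip(src.chunks_exact(m)) {
        dst_chunk.copy_from_slice(&src_row[blk * 4..blk * 4 + 4]);
    }
    assert_eq!(packed, vec![4.0, 5.0, 6.0, 7.0, 12.0, 13.0, 14.0, 15.0]);
    // Scatter: write the packed chunks back at the same column offset of every row.
    let mut back = vec![0.0f64; 2 * rows * m];
    for (dst_row, src_chunk) in back.chunks_exact_mut(m).zip(packed.chunks_exact(4)) {
        dst_row[blk * 4..blk * 4 + 4].copy_from_slice(src_chunk);
    }
    assert_eq!(&back[4..8], &src[4..8]);
    assert_eq!(&back[12..16], &src[12..16]);
}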
@@ -53,7 +72,7 @@ pub fn reim4_save_2blk_to_reim_ref<const OVERWRITE: bool>(m: usize, blk: usize,
debug_assert!(dst.len() >= offset + 3 * m + 4);
debug_assert!(src.len() >= 16);
let dst_off = &mut dst[offset..offset + 4];
let dst_off: &mut [f64] = &mut dst[offset..offset + 4];
if OVERWRITE {
dst_off.copy_from_slice(&src[0..4]);
} else {
@@ -64,7 +83,7 @@ pub fn reim4_save_2blk_to_reim_ref<const OVERWRITE: bool>(m: usize, blk: usize,
}
offset += m;
let dst_off = &mut dst[offset..offset + 4];
let dst_off: &mut [f64] = &mut dst[offset..offset + 4];
if OVERWRITE {
dst_off.copy_from_slice(&src[4..8]);
} else {
@@ -76,7 +95,7 @@ pub fn reim4_save_2blk_to_reim_ref<const OVERWRITE: bool>(m: usize, blk: usize,
offset += m;
let dst_off = &mut dst[offset..offset + 4];
let dst_off: &mut [f64] = &mut dst[offset..offset + 4];
if OVERWRITE {
dst_off.copy_from_slice(&src[8..12]);
} else {
@@ -87,7 +106,7 @@ pub fn reim4_save_2blk_to_reim_ref<const OVERWRITE: bool>(m: usize, blk: usize,
}
offset += m;
let dst_off = &mut dst[offset..offset + 4];
let dst_off: &mut [f64] = &mut dst[offset..offset + 4];
if OVERWRITE {
dst_off.copy_from_slice(&src[12..16]);
} else {
@@ -132,10 +151,7 @@ pub fn reim4_vec_mat2cols_product_ref(
{
assert_eq!(dst.len(), 16, "dst must have 16 doubles");
assert!(u.len() >= nrows * 8, "u must be at least nrows * 8 doubles");
assert!(
v.len() >= nrows * 16,
"v must be at least nrows * 16 doubles"
);
assert!(v.len() >= nrows * 16, "v must be at least nrows * 16 doubles");
}
// zero accumulators
@@ -161,11 +177,7 @@ pub fn reim4_vec_mat2cols_2ndcol_product_ref(
) {
#[cfg(debug_assertions)]
{
assert!(
dst.len() >= 8,
"dst must be at least 8 doubles but is {}",
dst.len()
);
assert!(dst.len() >= 8, "dst must be at least 8 doubles but is {}", dst.len());
assert!(
u.len() >= nrows * 8,
"u must be at least nrows={} * 8 doubles but is {}",
@@ -201,3 +213,57 @@ pub fn reim4_add_mul(dst: &mut [f64; 8], a: &[f64; 8], b: &[f64; 8]) {
dst[k + 4] += ar * bi + ai * br;
}
}
#[inline(always)]
pub fn reim4_convolution_1coeff_ref(k: usize, dst: &mut [f64; 8], a: &[f64], a_size: usize, b: &[f64], b_size: usize) {
reim_zero_ref(dst);
if k >= a_size + b_size {
return;
}
let j_min: usize = k.saturating_sub(a_size - 1);
let j_max: usize = (k + 1).min(b_size);
for j in j_min..j_max {
reim4_add_mul(dst, as_arr(&a[8 * (k - j)..]), as_arr(&b[8 * j..]));
}
}
#[inline(always)]
pub fn reim4_convolution_2coeffs_ref(k: usize, dst: &mut [f64; 16], a: &[f64], a_size: usize, b: &[f64], b_size: usize) {
reim4_convolution_1coeff_ref(k, as_arr_mut(dst), a, a_size, b, b_size);
reim4_convolution_1coeff_ref(k + 1, as_arr_mut(&mut dst[8..]), a, a_size, b, b_size);
}
#[inline(always)]
pub fn reim4_add_mul_b_real_const(dst: &mut [f64; 8], a: &[f64; 8], b: f64) {
for k in 0..4 {
let ar: f64 = a[k];
let ai: f64 = a[k + 4];
dst[k] += ar * b;
dst[k + 4] += ai * b;
}
}
#[inline(always)]
pub fn reim4_convolution_by_real_const_1coeff_ref(k: usize, dst: &mut [f64; 8], a: &[f64], a_size: usize, b: &[f64]) {
reim_zero_ref(dst);
let b_size: usize = b.len();
if k >= a_size + b_size {
return;
}
let j_min: usize = k.saturating_sub(a_size - 1);
let j_max: usize = (k + 1).min(b_size);
for j in j_min..j_max {
reim4_add_mul_b_real_const(dst, as_arr(&a[8 * (k - j)..]), b[j]);
}
}
#[inline(always)]
pub fn reim4_convolution_by_real_const_2coeffs_ref(k: usize, dst: &mut [f64; 16], a: &[f64], a_size: usize, b: &[f64]) {
reim4_convolution_by_real_const_1coeff_ref(k, as_arr_mut(dst), a, a_size, b);
reim4_convolution_by_real_const_1coeff_ref(k + 1, as_arr_mut(&mut dst[8..]), a, a_size, b);
}
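Each helper above accumulates one coefficient of a polynomial product in which every coefficient is an 8-double reim4 block: dst_k = sum over j of a[k - j] * b[j], with j_min = k - (a_size - 1) clamped at 0 and j_max = min(k + 1, b_size) selecting exactly the pairs for which both operands exist. The scalar-complex sketch below reproduces that index logic with (re, im) pairs instead of 8-double blocks; it is illustrative only and not the crate's API:
type C = (f64, f64); // (re, im)
// Scalar model of one convolution coefficient: sum over j of a[k - j] * b[j],
// with j restricted so that both indices stay in range. Assumes a and b are non-empty.
fn conv_coeff(k: usize, a: &[C], b: &[C]) -> C {
    let (mut re, mut im) = (0.0, 0.0);
    if k >= a.len() + b.len() {
        return (re, im);
    }
    let j_min = k.saturating_sub(a.len() - 1);
    let j_max = (k + 1).min(b.len());
    for j in j_min..j_max {
        let ((ar, ai), (br, bi)) = (a[k - j], b[j]);
        re += ar * br - ai * bi; // complex multiply-accumulate
        im += ar * bi + ai * br;
    }
    (re, im)
}
fn main() {
    // (1 + i) times (2 + 3i) as constant polynomials: the only coefficient is -1 + 5i.
    assert_eq!(conv_coeff(0, &[(1.0, 1.0)], &[(2.0, 3.0)]), (-1.0, 5.0));
    // Indices past the end of the product are zero.
    assert_eq!(conv_coeff(2, &[(1.0, 1.0)], &[(2.0, 3.0)]), (0.0, 0.0));
}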

View File

@@ -2,8 +2,14 @@ mod arithmetic_ref;
pub use arithmetic_ref::*;
pub trait Reim4Extract1Blk {
fn reim4_extract_1blk(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]);
use crate::{layouts::Backend, reference::fft64::reim::as_arr_mut};
pub trait Reim4Extract1BlkContiguous {
fn reim4_extract_1blk_contiguous(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]);
}
pub trait Reim4Save1BlkContiguous {
fn reim4_save_1blk_contiguous(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]);
}
pub trait Reim4Save1Blk {
@@ -25,3 +31,63 @@ pub trait Reim4Mat2ColsProd {
pub trait Reim4Mat2Cols2ndColProd {
fn reim4_mat2cols_2ndcol_prod(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]);
}
pub trait Reim4Convolution1Coeff {
fn reim4_convolution_1coeff(k: usize, dst: &mut [f64; 8], a: &[f64], a_size: usize, b: &[f64], b_size: usize);
}
pub trait Reim4Convolution2Coeffs {
fn reim4_convolution_2coeffs(k: usize, dst: &mut [f64; 16], a: &[f64], a_size: usize, b: &[f64], b_size: usize);
}
pub trait Reim4ConvolutionByRealConst1Coeff {
fn reim4_convolution_by_real_const_1coeff(k: usize, dst: &mut [f64; 8], a: &[f64], a_size: usize, b: &[f64]);
}
pub trait Reim4ConvolutionByRealConst2Coeffs {
fn reim4_convolution_by_real_const_2coeffs(k: usize, dst: &mut [f64; 16], a: &[f64], a_size: usize, b: &[f64]);
}
impl<BE: Backend> Reim4Convolution<BE> for BE where Self: Reim4Convolution1Coeff + Reim4Convolution2Coeffs {}
pub trait Reim4Convolution<BE: Backend>
where
BE: Reim4Convolution1Coeff + Reim4Convolution2Coeffs,
{
fn reim4_convolution(dst: &mut [f64], dst_size: usize, offset: usize, a: &[f64], a_size: usize, b: &[f64], b_size: usize) {
assert!(a_size > 0);
assert!(b_size > 0);
for k in (0..dst_size - 1).step_by(2) {
BE::reim4_convolution_2coeffs(k + offset, as_arr_mut(&mut dst[8 * k..]), a, a_size, b, b_size);
}
if !dst_size.is_multiple_of(2) {
let k: usize = dst_size - 1;
BE::reim4_convolution_1coeff(k + offset, as_arr_mut(&mut dst[8 * k..]), a, a_size, b, b_size);
}
}
}
impl<BE: Backend> Reim4ConvolutionByRealConst<BE> for BE where
Self: Reim4ConvolutionByRealConst1Coeff + Reim4ConvolutionByRealConst2Coeffs
{
}
pub trait Reim4ConvolutionByRealConst<BE: Backend>
where
BE: Reim4ConvolutionByRealConst1Coeff + Reim4ConvolutionByRealConst2Coeffs,
{
fn reim4_convolution_by_real_const(dst: &mut [f64], dst_size: usize, offset: usize, a: &[f64], a_size: usize, b: &[f64]) {
assert!(a_size > 0);
for k in (0..dst_size - 1).step_by(2) {
BE::reim4_convolution_by_real_const_2coeffs(k + offset, as_arr_mut(&mut dst[8 * k..]), a, a_size, b);
}
if !dst_size.is_multiple_of(2) {
let k: usize = dst_size - 1;
BE::reim4_convolution_by_real_const_1coeff(k + offset, as_arr_mut(&mut dst[8 * k..]), a, a_size, b);
}
}
}
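The blanket drivers above walk the output coefficients two at a time and fall back to a single-coefficient call when dst_size is odd. The sketch below isolates that iteration pattern with the kernels replaced by closures; it is illustrative only (and uses saturating_sub so that dst_size == 0 is simply a no-op):
// Pairing pattern used by the convolution drivers: even indices are handled in
// pairs, and an odd trailing index gets a single-coefficient call.
fn for_each_coeff(dst_size: usize, offset: usize, mut two: impl FnMut(usize), mut one: impl FnMut(usize)) {
    for k in (0..dst_size.saturating_sub(1)).step_by(2) {
        two(k + offset); // covers coefficients k and k + 1
    }
    if dst_size % 2 == 1 {
        one(dst_size - 1 + offset);
    }
}
fn main() {
    let (mut pairs, mut singles) = (Vec::new(), Vec::new());
    for_each_coeff(5, 10, |k| pairs.push(k), |k| singles.push(k));
    assert_eq!(pairs, vec![10, 12]);
    assert_eq!(singles, vec![14]);
}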

View File

@@ -9,13 +9,14 @@ use crate::{
reference::{
vec_znx::{
vec_znx_add, vec_znx_add_inplace, vec_znx_automorphism, vec_znx_automorphism_inplace, vec_znx_negate,
vec_znx_negate_inplace, vec_znx_normalize, vec_znx_sub, vec_znx_sub_inplace, vec_znx_sub_negate_inplace,
vec_znx_negate_inplace, vec_znx_normalize, vec_znx_normalize_tmp_bytes, vec_znx_sub, vec_znx_sub_inplace,
vec_znx_sub_negate_inplace,
},
znx::{
ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxExtractDigitAddMul, ZnxMulPowerOfTwoInplace, ZnxNegate,
ZnxNegateInplace, ZnxNormalizeDigit, ZnxNormalizeFinalStep, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly,
ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxSub, ZnxSubInplace, ZnxSubNegateInplace, ZnxZero,
znx_add_normal_f64_ref,
ZnxNegateInplace, ZnxNormalizeDigit, ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep,
ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly,
ZnxNormalizeMiddleStepInplace, ZnxSub, ZnxSubInplace, ZnxSubNegateInplace, ZnxZero, znx_add_normal_f64_ref,
},
},
source::Source,
@@ -231,15 +232,17 @@ where
}
pub fn vec_znx_big_normalize_tmp_bytes(n: usize) -> usize {
2 * n * size_of::<i64>()
vec_znx_normalize_tmp_bytes(n)
}
#[allow(clippy::too_many_arguments)]
pub fn vec_znx_big_normalize<R, A, BE>(
res_base2k: usize,
res: &mut R,
res_base2k: usize,
res_offset: i64,
res_col: usize,
a_base2k: usize,
a: &A,
a_base2k: usize,
a_col: usize,
carry: &mut [i64],
) where
@@ -256,7 +259,9 @@ pub fn vec_znx_big_normalize<R, A, BE>(
+ ZnxNormalizeFinalStep
+ ZnxNormalizeFirstStep
+ ZnxExtractDigitAddMul
+ ZnxNormalizeDigit,
+ ZnxNormalizeDigit
+ ZnxNormalizeMiddleStepInplace
+ ZnxNormalizeFinalStepInplace,
{
let a: VecZnxBig<&[u8], _> = a.to_ref();
let a_vznx: VecZnx<&[u8]> = VecZnx {
@@ -267,7 +272,7 @@ pub fn vec_znx_big_normalize<R, A, BE>(
max_size: a.max_size,
};
vec_znx_normalize::<_, _, BE>(res_base2k, res, res_col, a_base2k, &a_vznx, a_col, carry);
vec_znx_normalize::<_, _, BE>(res, res_base2k, res_offset, res_col, &a_vznx, a_base2k, a_col, carry);
}
pub fn vec_znx_big_add_normal_ref<R, B: Backend<ScalarBig = i64>>(
@@ -290,18 +295,13 @@ pub fn vec_znx_big_add_normal_ref<R, B: Backend<ScalarBig = i64>>(
let limb: usize = k.div_ceil(base2k) - 1;
let scale: f64 = (1 << ((limb + 1) * base2k - k)) as f64;
znx_add_normal_f64_ref(
res.at_mut(res_col, limb),
sigma * scale,
bound * scale,
source,
)
znx_add_normal_f64_ref(res.at_mut(res_col, limb), sigma * scale, bound * scale, source)
}
pub fn test_vec_znx_big_add_normal<B>(module: &Module<B>)
where
Module<B>: VecZnxBigAddNormal<B>,
B: Backend<ScalarBig = i64> + VecZnxBigAllocBytesImpl<B>,
B: Backend<ScalarBig = i64> + VecZnxBigAllocBytesImpl,
{
let n: usize = module.n();
let base2k: usize = 17;
@@ -325,12 +325,7 @@ where
})
} else {
let std: f64 = a.stats(base2k, col_i).std() * k_f64;
assert!(
(std - sigma * sqrt2).abs() < 0.1,
"std={} ~!= {}",
std,
sigma * sqrt2
);
assert!((std - sigma * sqrt2).abs() < 0.1, "std={} ~!= {}", std, sigma * sqrt2);
}
})
});
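The add-normal helpers above (and the vec_znx fill/add variants later in this commit) write Gaussian noise of precision k into a single limb, pre-scaling sigma and the bound by 2^{(limb + 1) * base2k - k}. A tiny illustrative check of that derivation, assuming only the limb/scale formulas shown above:
fn main() {
    let (base2k, k) = (17usize, 40usize);
    let limb = k.div_ceil(base2k) - 1;
    let scale = (1u64 << ((limb + 1) * base2k - k)) as f64;
    // The value stored in `limb` carries a weight of 2^{-(limb + 1) * base2k},
    // so after scaling the noise sits exactly at precision 2^{-k}.
    assert_eq!(limb, 2);
    assert_eq!(scale * (0.5f64).powi(((limb + 1) * base2k) as i32), (0.5f64).powi(k as i32));
}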

View File

@@ -4,7 +4,10 @@ use crate::{
oep::VecZnxDftAllocBytesImpl,
reference::fft64::{
reim::{ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimZero},
reim4::{Reim4Extract1Blk, Reim4Mat1ColProd, Reim4Mat2Cols2ndColProd, Reim4Mat2ColsProd, Reim4Save1Blk, Reim4Save2Blks},
reim4::{
Reim4Extract1BlkContiguous, Reim4Mat1ColProd, Reim4Mat2Cols2ndColProd, Reim4Mat2ColsProd, Reim4Save1Blk,
Reim4Save2Blks,
},
vec_znx_dft::vec_znx_dft_apply,
},
};
@@ -17,7 +20,7 @@ pub fn vmp_prepare_tmp_bytes(n: usize) -> usize {
pub fn vmp_prepare<R, A, BE>(table: &ReimFFTTable<f64>, pmat: &mut R, mat: &A, tmp: &mut [f64])
where
BE: Backend<ScalarPrep = f64> + ReimDFTExecute<ReimFFTTable<f64>, f64> + ReimFromZnx + Reim4Extract1Blk,
BE: Backend<ScalarPrep = f64> + ReimDFTExecute<ReimFFTTable<f64>, f64> + ReimFromZnx + Reim4Extract1BlkContiguous,
R: VmpPMatToMut<BE>,
A: MatZnxToRef,
{
@@ -34,13 +37,7 @@ where
res.cols_in(),
a.cols_in()
);
assert_eq!(
res.rows(),
a.rows(),
"res.rows: {} != a.rows: {}",
res.rows(),
a.rows()
);
assert_eq!(res.rows(), a.rows(), "res.rows: {} != a.rows: {}", res.rows(), a.rows());
assert_eq!(
res.cols_out(),
a.cols_out(),
@@ -48,13 +45,7 @@ where
res.cols_out(),
a.cols_out()
);
assert_eq!(
res.size(),
a.size(),
"res.size: {} != a.size: {}",
res.size(),
a.size()
);
assert_eq!(res.size(), a.size(), "res.size: {} != a.size: {}", res.size(), a.size());
}
let nrows: usize = a.cols_in() * a.rows();
@@ -70,7 +61,7 @@ pub(crate) fn vmp_prepare_core<REIM>(
ncols: usize,
tmp: &mut [f64],
) where
REIM: ReimDFTExecute<ReimFFTTable<f64>, f64> + ReimFromZnx + Reim4Extract1Blk,
REIM: ReimDFTExecute<ReimFFTTable<f64>, f64> + ReimFromZnx + Reim4Extract1BlkContiguous,
{
let m: usize = table.m();
let n: usize = m << 1;
@@ -99,7 +90,7 @@ pub(crate) fn vmp_prepare_core<REIM>(
};
for blk_i in 0..m >> 2 {
REIM::reim4_extract_1blk(m, 1, blk_i, &mut dst[blk_i * offset..], tmp);
REIM::reim4_extract_1blk_contiguous(m, 1, blk_i, &mut dst[blk_i * offset..], tmp);
}
}
}
@@ -116,7 +107,7 @@ where
+ VecZnxDftAllocBytesImpl<BE>
+ ReimDFTExecute<ReimFFTTable<f64>, f64>
+ ReimZero
+ Reim4Extract1Blk
+ Reim4Extract1BlkContiguous
+ Reim4Mat1ColProd
+ Reim4Mat2Cols2ndColProd
+ Reim4Mat2ColsProd
@@ -168,7 +159,7 @@ pub fn vmp_apply_dft_to_dft<R, A, M, BE>(res: &mut R, a: &A, pmat: &M, tmp_bytes
where
BE: Backend<ScalarPrep = f64>
+ ReimZero
+ Reim4Extract1Blk
+ Reim4Extract1BlkContiguous
+ Reim4Mat1ColProd
+ Reim4Mat2Cols2ndColProd
+ Reim4Mat2ColsProd
@@ -207,7 +198,7 @@ pub fn vmp_apply_dft_to_dft_add<R, A, M, BE>(res: &mut R, a: &A, pmat: &M, limb_
where
BE: Backend<ScalarPrep = f64>
+ ReimZero
+ Reim4Extract1Blk
+ Reim4Extract1BlkContiguous
+ Reim4Mat1ColProd
+ Reim4Mat2Cols2ndColProd
+ Reim4Mat2ColsProd
@@ -239,16 +230,7 @@ where
let a_raw: &[f64] = a.raw();
let res_raw: &mut [f64] = res.raw_mut();
vmp_apply_dft_to_dft_core::<false, BE>(
n,
res_raw,
a_raw,
pmat_raw,
limb_offset,
nrows,
ncols,
tmp_bytes,
)
vmp_apply_dft_to_dft_core::<false, BE>(n, res_raw, a_raw, pmat_raw, limb_offset, nrows, ncols, tmp_bytes)
}
#[allow(clippy::too_many_arguments)]
@@ -263,7 +245,7 @@ fn vmp_apply_dft_to_dft_core<const OVERWRITE: bool, REIM>(
tmp_bytes: &mut [f64],
) where
REIM: ReimZero
+ Reim4Extract1Blk
+ Reim4Extract1BlkContiguous
+ Reim4Mat1ColProd
+ Reim4Mat2Cols2ndColProd
+ Reim4Mat2ColsProd
@@ -299,41 +281,23 @@ fn vmp_apply_dft_to_dft_core<const OVERWRITE: bool, REIM>(
for blk_i in 0..(m >> 2) {
let mat_blk_start: &[f64] = &pmat[blk_i * (8 * nrows * ncols)..];
REIM::reim4_extract_1blk(m, row_max, blk_i, extracted_blk, a);
REIM::reim4_extract_1blk_contiguous(m, row_max, blk_i, extracted_blk, a);
if limb_offset.is_multiple_of(2) {
for (col_res, col_pmat) in (0..).step_by(2).zip((limb_offset..col_max - 1).step_by(2)) {
let col_offset: usize = col_pmat * (8 * nrows);
REIM::reim4_mat2cols_prod(
row_max,
mat2cols_output,
extracted_blk,
&mat_blk_start[col_offset..],
);
REIM::reim4_mat2cols_prod(row_max, mat2cols_output, extracted_blk, &mat_blk_start[col_offset..]);
REIM::reim4_save_2blks::<OVERWRITE>(m, blk_i, &mut res[col_res * n..], mat2cols_output);
}
} else {
let col_offset: usize = (limb_offset - 1) * (8 * nrows);
REIM::reim4_mat2cols_2ndcol_prod(
row_max,
mat2cols_output,
extracted_blk,
&mat_blk_start[col_offset..],
);
REIM::reim4_mat2cols_2ndcol_prod(row_max, mat2cols_output, extracted_blk, &mat_blk_start[col_offset..]);
REIM::reim4_save_1blk::<OVERWRITE>(m, blk_i, res, mat2cols_output);
for (col_res, col_pmat) in (1..)
.step_by(2)
.zip((limb_offset + 1..col_max - 1).step_by(2))
{
for (col_res, col_pmat) in (1..).step_by(2).zip((limb_offset + 1..col_max - 1).step_by(2)) {
let col_offset: usize = col_pmat * (8 * nrows);
REIM::reim4_mat2cols_prod(
row_max,
mat2cols_output,
extracted_blk,
&mat_blk_start[col_offset..],
);
REIM::reim4_mat2cols_prod(row_max, mat2cols_output, extracted_blk, &mat_blk_start[col_offset..]);
REIM::reim4_save_2blks::<OVERWRITE>(m, blk_i, &mut res[col_res * n..], mat2cols_output);
}
}
@@ -344,26 +308,11 @@ fn vmp_apply_dft_to_dft_core<const OVERWRITE: bool, REIM>(
if last_col >= limb_offset {
if ncols == col_max {
REIM::reim4_mat1col_prod(
row_max,
mat2cols_output,
extracted_blk,
&mat_blk_start[col_offset..],
);
REIM::reim4_mat1col_prod(row_max, mat2cols_output, extracted_blk, &mat_blk_start[col_offset..]);
} else {
REIM::reim4_mat2cols_prod(
row_max,
mat2cols_output,
extracted_blk,
&mat_blk_start[col_offset..],
);
REIM::reim4_mat2cols_prod(row_max, mat2cols_output, extracted_blk, &mat_blk_start[col_offset..]);
}
REIM::reim4_save_1blk::<OVERWRITE>(
m,
blk_i,
&mut res[(last_col - limb_offset) * n..],
mat2cols_output,
);
REIM::reim4_save_1blk::<OVERWRITE>(m, blk_i, &mut res[(last_col - limb_offset) * n..], mat2cols_output);
}
}
}

View File

@@ -24,10 +24,7 @@ where
{
assert_eq!(tmp.len(), res.n());
debug_assert!(
_n_out > _n_in,
"invalid a: output ring degree should be greater"
);
debug_assert!(_n_out > _n_in, "invalid a: output ring degree should be greater");
a[1..].iter().for_each(|ai| {
debug_assert_eq!(
ai.to_ref().n(),

View File

@@ -14,15 +14,17 @@ use crate::{
};
pub fn vec_znx_normalize_tmp_bytes(n: usize) -> usize {
2 * n * size_of::<i64>()
3 * n * size_of::<i64>()
}
#[allow(clippy::too_many_arguments)]
pub fn vec_znx_normalize<R, A, ZNXARI>(
res_base2k: usize,
res: &mut R,
res_base2k: usize,
res_offset: i64,
res_col: usize,
a_base2k: usize,
a: &A,
a_base2k: usize,
a_col: usize,
carry: &mut [i64],
) where
@@ -38,14 +40,40 @@ pub fn vec_znx_normalize<R, A, ZNXARI>(
+ ZnxNormalizeFinalStep
+ ZnxNormalizeFirstStep
+ ZnxExtractDigitAddMul
+ ZnxNormalizeMiddleStepInplace
+ ZnxNormalizeFinalStepInplace
+ ZnxNormalizeDigit,
{
match res_base2k == a_base2k {
true => vec_znx_normalize_inter_base2k::<R, A, ZNXARI>(res_base2k, res, res_offset, res_col, a, a_col, carry),
false => vec_znx_normalize_cross_base2k::<R, A, ZNXARI>(res, res_base2k, res_offset, res_col, a, a_base2k, a_col, carry),
}
}
fn vec_znx_normalize_inter_base2k<R, A, ZNXARI>(
base2k: usize,
res: &mut R,
res_offset: i64,
res_col: usize,
a: &A,
a_col: usize,
carry: &mut [i64],
) where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxZero
+ ZnxNormalizeFirstStepCarryOnly
+ ZnxNormalizeMiddleStepCarryOnly
+ ZnxNormalizeMiddleStep
+ ZnxNormalizeFinalStepInplace
+ ZnxNormalizeMiddleStepInplace,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let a: VecZnx<&[u8]> = a.to_ref();
#[cfg(debug_assertions)]
{
assert!(carry.len() >= 2 * res.n());
assert!(carry.len() >= vec_znx_normalize_tmp_bytes(res.n()) / size_of::<i64>());
assert_eq!(res.n(), a.n());
}
@@ -53,153 +81,323 @@ pub fn vec_znx_normalize<R, A, ZNXARI>(
let res_size: usize = res.size();
let a_size: usize = a.size();
let carry: &mut [i64] = &mut carry[..2 * n];
let (carry, _) = carry.split_at_mut(n);
if res_base2k == a_base2k {
if a_size > res_size {
for j in (res_size..a_size).rev() {
if j == a_size - 1 {
ZNXARI::znx_normalize_first_step_carry_only(res_base2k, 0, a.at(a_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_carry_only(res_base2k, 0, a.at(a_col, j), carry);
}
}
let mut lsh: i64 = res_offset % base2k as i64;
let mut limbs_offset: i64 = res_offset / base2k as i64;
for j in (1..res_size).rev() {
ZNXARI::znx_normalize_middle_step(res_base2k, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
}
// If res_offset is negative, makes it positive
// and corrects by decrementing the limb offset.
if res_offset < 0 && lsh != 0 {
lsh = (lsh + base2k as i64) % (base2k as i64);
limbs_offset -= 1;
}
ZNXARI::znx_normalize_final_step(res_base2k, 0, res.at_mut(res_col, 0), a.at(a_col, 0), carry);
let lsh_pos: usize = lsh as usize;
let res_end: usize = (-limbs_offset).clamp(0, res_size as i64) as usize;
let res_start: usize = (a_size as i64 - limbs_offset).clamp(0, res_size as i64) as usize;
let a_end: usize = limbs_offset.clamp(0, a_size as i64) as usize;
let a_start: usize = (res_size as i64 + limbs_offset).clamp(0, a_size as i64) as usize;
let a_out_range: usize = a_size.saturating_sub(a_start);
// Computes the carry over the discarded limbs of a
for j in 0..a_out_range {
if j == 0 {
ZNXARI::znx_normalize_first_step_carry_only(base2k, lsh_pos, a.at(a_col, a_size - j - 1), carry);
} else {
for j in (0..a_size).rev() {
if j == a_size - 1 {
ZNXARI::znx_normalize_first_step(res_base2k, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
} else if j == 0 {
ZNXARI::znx_normalize_final_step(res_base2k, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step(res_base2k, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
}
}
for j in a_size..res_size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
ZNXARI::znx_normalize_middle_step_carry_only(base2k, lsh_pos, a.at(a_col, a_size - j - 1), carry);
}
} else {
let (a_norm, carry) = carry.split_at_mut(n);
}
// Relevant limbs of res
let res_min_size: usize = (a_size * a_base2k).div_ceil(res_base2k).min(res_size);
// If no limbs were discarded, initialize carry to zero
if a_out_range == 0 {
ZNXARI::znx_zero(carry);
}
// Relevant limbs of a
let a_min_size: usize = (res_size * res_base2k).div_ceil(a_base2k).min(a_size);
// Zeroes the bottom limbs that will not be touched afterwards
for j in res_start..res_size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
// Get carry for limbs of a that have higher precision than res
for j in (a_min_size..a_size).rev() {
if j == a_size - 1 {
ZNXARI::znx_normalize_first_step_carry_only(a_base2k, 0, a.at(a_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_carry_only(a_base2k, 0, a.at(a_col, j), carry);
let mid_range: usize = a_start.saturating_sub(a_end);
// Regular normalization over the overlapping limbs of res and a.
for j in 0..mid_range {
ZNXARI::znx_normalize_middle_step(
base2k,
lsh_pos,
res.at_mut(res_col, res_start - j - 1),
a.at(a_col, a_start - j - 1),
carry,
);
}
// Propagates the carry over the non-overlapping limbs between res and a
for j in 0..res_end {
ZNXARI::znx_zero(res.at_mut(res_col, res_end - j - 1));
if j == res_end - 1 {
ZNXARI::znx_normalize_final_step_inplace(base2k, lsh_pos, res.at_mut(res_col, res_end - j - 1), carry);
} else {
ZNXARI::znx_normalize_middle_step_inplace(base2k, lsh_pos, res.at_mut(res_col, res_end - j - 1), carry);
}
}
}
#[allow(clippy::too_many_arguments)]
fn vec_znx_normalize_cross_base2k<R, A, ZNXARI>(
res: &mut R,
res_base2k: usize,
res_offset: i64,
res_col: usize,
a: &A,
a_base2k: usize,
a_col: usize,
carry: &mut [i64],
) where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxZero
+ ZnxCopy
+ ZnxAddInplace
+ ZnxMulPowerOfTwoInplace
+ ZnxNormalizeFirstStepCarryOnly
+ ZnxNormalizeMiddleStepCarryOnly
+ ZnxNormalizeMiddleStep
+ ZnxNormalizeFinalStep
+ ZnxNormalizeFirstStep
+ ZnxExtractDigitAddMul
+ ZnxNormalizeMiddleStepInplace
+ ZnxNormalizeFinalStepInplace
+ ZnxNormalizeDigit,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let a: VecZnx<&[u8]> = a.to_ref();
#[cfg(debug_assertions)]
{
assert!(carry.len() >= vec_znx_normalize_tmp_bytes(res.n()) / size_of::<i64>());
assert_eq!(res.n(), a.n());
}
let n: usize = res.n();
let res_size: usize = res.size();
let a_size: usize = a.size();
let (a_norm, carry) = carry.split_at_mut(n);
let (res_carry, a_carry) = carry[..2 * n].split_at_mut(n);
ZNXARI::znx_zero(res_carry);
// Total precision (in bits) that `a` and `res` can represent.
let a_tot_bits: usize = a_size * a_base2k;
let res_tot_bits: usize = res_size * res_base2k;
// Derive intra-limb shift and cross-limb offset.
let mut lsh: i64 = res_offset % a_base2k as i64;
let mut limbs_offset: i64 = res_offset / a_base2k as i64;
// If res_offset is negative, ensures it is positive
// and corrects by decrementing the cross-limb offset.
if res_offset < 0 && lsh != 0 {
lsh = (lsh + a_base2k as i64) % (a_base2k as i64);
limbs_offset -= 1;
}
let lsh_pos: usize = lsh as usize;
// Derive start/stop bit indexes of the overlap between `a` and `res` (after taking into account the offset).
let res_end_bit: usize = (-limbs_offset * a_base2k as i64).clamp(0, res_tot_bits as i64) as usize; // Stop bit
let res_start_bit: usize = (a_tot_bits as i64 - limbs_offset * a_base2k as i64).clamp(0, res_tot_bits as i64) as usize; // Start bit
let a_end_bit: usize = (limbs_offset * a_base2k as i64).clamp(0, a_tot_bits as i64) as usize; // Stop bit
let a_start_bit: usize = (res_tot_bits as i64 + limbs_offset * a_base2k as i64).clamp(0, a_tot_bits as i64) as usize; // Start bit
// Convert bits to limb indexes.
let res_end: usize = res_end_bit / res_base2k;
let res_start: usize = res_start_bit.div_ceil(res_base2k);
let a_end: usize = a_end_bit / a_base2k;
let a_start: usize = a_start_bit.div_ceil(a_base2k);
// Zero all limbs of `res`. Unlike the simple case
// where `res_base2k` is equal to `a_base2k`, we also
// need to ensure that the limbs starting from `res_end`
// are zero.
for j in 0..res_size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
// Case where the offset is positive and greater than or equal
// to the precision of a.
if res_start == 0 {
return;
}
// Limbs of `a` that have a greater precision than `res`.
let a_out_range: usize = a_size.saturating_sub(a_start);
for j in 0..a_out_range {
if j == 0 {
ZNXARI::znx_normalize_first_step_carry_only(a_base2k, lsh_pos, a.at(a_col, a_size - j - 1), a_carry);
} else {
ZNXARI::znx_normalize_middle_step_carry_only(a_base2k, lsh_pos, a.at(a_col, a_size - j - 1), a_carry);
}
}
// Zero carry if the above loop didn't trigger.
if a_out_range == 0 {
ZNXARI::znx_zero(a_carry);
}
// How much is left to accumulate to fill a limb of `res`.
let mut res_acc_left: usize = res_base2k;
// Starting limb of `res`.
let mut res_limb: usize = res_start - 1;
// How many limbs of `a` overlap with `res` (after taking into account the offset).
let mid_range: usize = a_start.saturating_sub(a_end);
// Regular normalization over the overlapping limbs of res and a.
'outer: for j in 0..mid_range {
let a_limb: usize = a_start - j - 1;
// Current res & a limbs
let a_slice: &[i64] = a.at(a_col, a_limb);
// Tracker: how much of a_norm is left to
// be flushed on res.
let mut a_take_left: usize = a_base2k;
// Normalizes the j-th limb of a and stores the result into `a_norm`.
// This step is required to avoid overflow in the next step,
// which assumes that |a| is bounded by 2^{a_base2k - 1} (i.e. normalized).
ZNXARI::znx_normalize_middle_step(a_base2k, lsh_pos, a_norm, a_slice, a_carry);
// In the first iteration we need to match the precision `res` and `a`.
if j == 0 {
// Case where `a` has more precision than `res` (after taking into account the offset)
//
// For example:
//
// a: [x x x x x][x x x x x][x x x x x][x x x x x]
// res: [x x x x x x][x x x x x x][x x x x x x]
if !(a_tot_bits - a_start_bit).is_multiple_of(a_base2k) {
let take: usize = (a_tot_bits - a_start_bit) % a_base2k;
ZNXARI::znx_mul_power_of_two_inplace(-(take as i64), a_norm);
a_take_left -= take;
// Case where `res` has more precision than `a` (after taking into account the offset)
//
// For example:
//
// a: [x x x x x][x x x x x][x x x x x][x x x x x]
// res: [x x x x x x][x x x x x x][x x x x x x]
} else if !(res_tot_bits - res_start_bit).is_multiple_of(res_base2k) {
res_acc_left -= (res_tot_bits - res_start_bit) % res_base2k;
}
}
if a_min_size == a_size {
ZNXARI::znx_zero(carry);
}
// Extracts bits of `a_norm` and accumulates them on res[res_limb] until
// res_base2k bits have been accumulated or until all bits of `a` are
// extracted.
'inner: loop {
// Current limb of res
let res_slice: &mut [i64] = res.at_mut(res_col, res_limb);
// Maximum relevant precision of a
let a_prec: usize = a_min_size * a_base2k;
// We can take at most a_base2k bits
// but not more than what is left on a_norm or what is left to
// fully populate the current limb of res.
let a_take: usize = a_base2k.min(a_take_left).min(res_acc_left);
// Maximum relevant precision of res
let res_prec: usize = res_min_size * res_base2k;
// Res limb index
let mut res_idx: usize = res_min_size - 1;
// Tracker: how much of res is left to be populated
// for the current limb.
let mut res_left: usize = res_base2k;
for j in 0..res_size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
for j in (0..a_min_size).rev() {
// Tracker: how much of a_norm is left to
// be flushed on res.
let mut a_left: usize = a_base2k;
// Normalizes the j-th limb of a and stores the result into a_norm.
// This step is required to avoid overflow in the next step,
// which assumes that |a| is bounded by 2^{a_base2k - 1}.
if j != 0 {
ZNXARI::znx_normalize_middle_step(a_base2k, 0, a_norm, a.at(a_col, j), carry);
} else {
ZNXARI::znx_normalize_final_step(a_base2k, 0, a_norm, a.at(a_col, j), carry);
if a_take != 0 {
// Extracts `a_take` bits from a_norm and accumulates them on `res_slice`.
let scale: usize = res_base2k - res_acc_left;
ZNXARI::znx_extract_digit_addmul(a_take, scale, res_slice, a_norm);
a_take_left -= a_take;
res_acc_left -= a_take;
}
// In the first iteration we need to match the precision of the input/output.
// If a_min_size * a_base2k > res_min_size * res_base2k
// then divround a_norm by the precision difference and
// act as if a_norm has already been partially consumed.
// Else act as if res has already been populated
// by the difference.
if j == a_min_size - 1 {
if a_prec > res_prec {
ZNXARI::znx_mul_power_of_two_inplace(res_prec as i64 - a_prec as i64, a_norm);
a_left -= a_prec - res_prec;
} else if res_prec > a_prec {
res_left -= res_prec - a_prec;
}
}
// If either:
// * At least `res_base2k` bits have been accumulated
// * We have reached the last limb of a
// Then: Flushes them onto res
if res_acc_left == 0 || a_limb == 0 {
// This case happens only if `res_offset` is negative.
// If `res_offset` is negative, we need to apply the offset BEFORE
// the normalization to ensure the `res_offset` overflowing bits of `a`
// are in the MSB of `res` instead of being discarded.
if a_limb == 0 && a_take_left == 0 {
// TODO: prove no overflow can happen here (intuitively it should not)
ZNXARI::znx_add_inplace(a_carry, a_norm);
// Flushes a into res
loop {
// Selects the maximum amount of a that can be flushed
let a_take: usize = a_base2k.min(a_left).min(res_left);
// Output limb
let res_slice: &mut [i64] = res.at_mut(res_col, res_idx);
// Scaling of the value to flush
let lsh: usize = res_base2k - res_left;
// Extract the bits to flush on the output and updates
// a_norm accordingly.
ZNXARI::znx_extract_digit_addmul(a_take, lsh, res_slice, a_norm);
// Updates the trackers
a_left -= a_take;
res_left -= a_take;
// If the current limb of res is full,
// then normalizes this limb and adds
// the carry on a_norm.
if res_left == 0 {
// Updates tracker
res_left += res_base2k;
// Normalizes res and propagates the carry on a.
ZNXARI::znx_normalize_digit(res_base2k, res_slice, a_norm);
// If we reached the last limb of res breaks,
// but we might rerun the above loop if the
// base2k of a is much smaller than the base2k
// of res.
if res_idx == 0 {
ZNXARI::znx_add_inplace(carry, a_norm);
break;
// Usual case where for example
// a: [ overflow ][x x x x x][x x x x x][x x x x x][x x x x x]
// res: [x x x x x x][x x x x x x][x x x x x x][x x x x x x]
//
// where [overflow] are the overflowing bits of `a` (note that they are not a limb, but
// stored in a[0] & carry from a[1]) that are moved into the MSB of `res` due to the
// negative offset.
//
// In this case we populate what is left of `res_acc_left` using `a_carry`
//
// TODO: see if this can be simplified (e.g. just add).
if res_acc_left != 0 {
let scale: usize = res_base2k - res_acc_left;
ZNXARI::znx_extract_digit_addmul(res_acc_left, scale, res_slice, a_carry);
}
// Else updates the limb index of res.
res_idx -= 1
ZNXARI::znx_normalize_middle_step_inplace(res_base2k, 0, res_slice, res_carry);
// Previous step might not consume all bits of a_carry
// TODO: prove no overflow can happen here
ZNXARI::znx_add_inplace(res_carry, a_carry);
// We are done, so break out of the loop (yes, we are at a[0], but
// this avoids possible over/underflows of the tracking variables)
break 'outer;
}
// If a_norm is exhausted, breaks the loop.
if a_left == 0 {
ZNXARI::znx_add_inplace(carry, a_norm);
break;
// If we reached the last limb of res
if res_limb == 0 {
break 'outer;
}
res_acc_left += res_base2k;
res_limb -= 1;
}
// If a_norm is exhausted, breaks the inner loop.
if a_take_left == 0 {
ZNXARI::znx_add_inplace(a_carry, a_norm);
break 'inner;
}
}
}
// This case will happen if offset is negative.
if res_end != 0 {
// If there are no overlapping limbs between `res` and `a`
// (can happen if offset is negative), then we propagate the
// carry of `a` on res. Note that the carry of `a` can be
// greater than the precision of res.
//
// For example with offset = -8:
// a carry a[0] a[1] a[2] a[3]
// a: [---------------------- ][x x x][x x x][x x x][x x x]
// b: [x x x x][x x x x ]
// res[0] res[1]
//
// If there are overlapping limbs between `res` and `a`,
// we can use `res_carry`, which contains the carry of propagating
// the shifted reconstruction of `a` in `res_base2k` along with
// the carry of a[0].
let carry_to_use = if a_start == a_end { a_carry } else { res_carry };
for j in 0..res_end {
if j == res_end - 1 {
ZNXARI::znx_normalize_final_step_inplace(res_base2k, 0, res.at_mut(res_col, res_end - j - 1), carry_to_use);
} else {
ZNXARI::znx_normalize_middle_step_inplace(res_base2k, 0, res.at_mut(res_col, res_end - j - 1), carry_to_use);
}
}
}
@@ -229,6 +427,191 @@ where
}
}
#[test]
fn test_vec_znx_normalize_cross_base2k() {
let n: usize = 8;
let mut carry: Vec<i64> = vec![0i64; vec_znx_normalize_tmp_bytes(n) / size_of::<i64>()];
use crate::reference::znx::ZnxRef;
use rug::ops::SubAssignRound;
use rug::{Float, float::Round};
let prec: usize = 128;
for in_base2k in 1..=51 {
for out_base2k in 1..=51 {
for offset in [
-(prec as i64),
-(prec as i64 - 1),
-(prec as i64 - in_base2k as i64),
-(in_base2k as i64 + 1),
-(in_base2k as i64),
-(in_base2k as i64 - 1),
0,
(in_base2k as i64 - 1),
in_base2k as i64,
(in_base2k as i64 + 1),
(prec as i64 - in_base2k as i64),
(prec - 1) as i64,
prec as i64,
] {
let mut source: Source = Source::new([1u8; 32]);
let in_size: usize = prec.div_ceil(in_base2k);
let in_prec: u32 = (in_size * in_base2k) as u32;
// Ensures no loss of precision (mostly for testing purposes)
let out_size: usize = (in_prec as usize).div_ceil(out_base2k);
let out_prec: u32 = (out_size * out_base2k) as u32;
let min_prec: u32 = (in_size * in_base2k).min(out_size * out_base2k) as u32;
let mut want: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, in_size);
want.fill_uniform(60, &mut source);
let mut have: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, out_size);
have.fill_uniform(60, &mut source);
vec_znx_normalize_cross_base2k::<_, _, ZnxRef>(&mut have, out_base2k, offset, 0, &want, in_base2k, 0, &mut carry);
let mut data_have: Vec<Float> = (0..n).map(|_| Float::with_val(out_prec + 60, 0)).collect();
let mut data_want: Vec<Float> = (0..n).map(|_| Float::with_val(in_prec + 60, 0)).collect();
have.decode_vec_float(out_base2k, 0, &mut data_have);
want.decode_vec_float(in_base2k, 0, &mut data_want);
let scale: Float = Float::with_val(out_prec + 60, Float::u_pow_u(2, offset.unsigned_abs() as u32));
if offset > 0 {
for x in &mut data_want {
*x *= &scale;
*x %= 1;
}
} else if offset < 0 {
for x in &mut data_want {
*x /= &scale;
*x %= 1;
}
} else {
for x in &mut data_want {
*x %= 1;
}
}
for x in &mut data_have {
if *x >= 0.5 {
*x -= 1;
} else if *x < -0.5 {
*x += 1;
}
}
for x in &mut data_want {
if *x >= 0.5 {
*x -= 1;
} else if *x < -0.5 {
*x += 1;
}
}
for i in 0..n {
//println!("i:{i:02} {} {}", data_want[i], data_have[i]);
let mut err: Float = data_have[i].clone();
err.sub_assign_round(&data_want[i], Round::Nearest);
err = err.abs();
let err_log2: f64 = err.clone().max(&Float::with_val(prec as u32, 1e-60)).log2().to_f64();
assert!(err_log2 <= -(min_prec as f64) + 1.0, "{} {}", err_log2, -(min_prec as f64))
}
}
}
}
}
#[test]
fn test_vec_znx_normalize_inter_base2k() {
let n: usize = 8;
let mut carry: Vec<i64> = vec![0i64; vec_znx_normalize_tmp_bytes(n) / size_of::<i64>()];
use crate::reference::znx::ZnxRef;
use rug::ops::SubAssignRound;
use rug::{Float, float::Round};
let mut source: Source = Source::new([1u8; 32]);
let prec: usize = 128;
let offset_range: i64 = prec as i64;
for base2k in 1..=51 {
for offset in (-offset_range..=offset_range).step_by(base2k + 1) {
let size: usize = prec.div_ceil(base2k);
let out_prec: u32 = (size * base2k) as u32;
// Fills "want" with uniform values
let mut want: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, size);
want.fill_uniform(60, &mut source);
// Fills "have" with the shifted normalization of "want"
let mut have: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, size);
have.fill_uniform(60, &mut source);
vec_znx_normalize_inter_base2k::<_, _, ZnxRef>(base2k, &mut have, offset, 0, &want, 0, &mut carry);
let mut data_have: Vec<Float> = (0..n).map(|_| Float::with_val(out_prec + 60, 0)).collect();
let mut data_want: Vec<Float> = (0..n).map(|_| Float::with_val(out_prec + 60, 0)).collect();
have.decode_vec_float(base2k, 0, &mut data_have);
want.decode_vec_float(base2k, 0, &mut data_want);
let scale: Float = Float::with_val(out_prec + 60, Float::u_pow_u(2, offset.unsigned_abs() as u32));
if offset > 0 {
for x in &mut data_want {
*x *= &scale;
*x %= 1;
}
} else if offset < 0 {
for x in &mut data_want {
*x /= &scale;
*x %= 1;
}
} else {
for x in &mut data_want {
*x %= 1;
}
}
for x in &mut data_have {
if *x >= 0.5 {
*x -= 1;
} else if *x < -0.5 {
*x += 1;
}
}
for x in &mut data_want {
if *x >= 0.5 {
*x -= 1;
} else if *x < -0.5 {
*x += 1;
}
}
for i in 0..n {
//println!("i:{i:02} {} {}", data_want[i], data_have[i]);
let mut err: Float = data_have[i].clone();
err.sub_assign_round(&data_want[i], Round::Nearest);
err = err.abs();
let err_log2: f64 = err.clone().max(&Float::with_val(prec as u32, 1e-60)).log2().to_f64();
assert!(err_log2 <= -(out_prec as f64), "{} {}", err_log2, -(out_prec as f64))
}
}
}
}
pub fn bench_vec_znx_normalize<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxNormalize<B> + ModuleNew<B> + VecZnxNormalizeTmpBytes,
@@ -261,10 +644,10 @@ where
res.fill_uniform(50, &mut source);
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(module.vec_znx_normalize_tmp_bytes());
let res_offset: i64 = 0;
move || {
for i in 0..cols {
module.vec_znx_normalize(base2k, &mut res, i, base2k, &a, i, scratch.borrow());
module.vec_znx_normalize(&mut res, base2k, res_offset, i, &a, base2k, i, scratch.borrow());
}
black_box(());
}
@@ -326,71 +709,3 @@ where
group.finish();
}
#[test]
fn test_vec_znx_normalize_conv() {
let n: usize = 8;
let mut carry: Vec<i64> = vec![0i64; 2 * n];
use crate::reference::znx::ZnxRef;
use rug::ops::SubAssignRound;
use rug::{Float, float::Round};
let mut source: Source = Source::new([1u8; 32]);
let prec: usize = 128;
let mut data: Vec<i128> = vec![0i128; n];
data.iter_mut().for_each(|x| *x = source.next_i128());
for start_base2k in 1..50 {
for end_base2k in 1..50 {
let end_size: usize = prec.div_ceil(end_base2k);
let mut want: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, end_size);
want.encode_vec_i128(end_base2k, 0, prec, &data);
vec_znx_normalize_inplace::<_, ZnxRef>(end_base2k, &mut want, 0, &mut carry);
// Creates a temporary poly where encoding is in start_base2k
let mut tmp: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, prec.div_ceil(start_base2k));
tmp.encode_vec_i128(start_base2k, 0, prec, &data);
vec_znx_normalize_inplace::<_, ZnxRef>(start_base2k, &mut tmp, 0, &mut carry);
let mut data_tmp: Vec<Float> = (0..n).map(|_| Float::with_val(prec as u32, 0)).collect();
tmp.decode_vec_float(start_base2k, 0, &mut data_tmp);
let mut have: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, end_size);
vec_znx_normalize::<_, _, ZnxRef>(end_base2k, &mut have, 0, start_base2k, &tmp, 0, &mut carry);
let out_prec: u32 = (end_size * end_base2k) as u32;
let mut data_want: Vec<Float> = (0..n).map(|_| Float::with_val(out_prec, 0)).collect();
let mut data_res: Vec<Float> = (0..n).map(|_| Float::with_val(out_prec, 0)).collect();
have.decode_vec_float(end_base2k, 0, &mut data_want);
want.decode_vec_float(end_base2k, 0, &mut data_res);
for i in 0..n {
let mut err: Float = data_want[i].clone();
err.sub_assign_round(&data_res[i], Round::Nearest);
err = err.abs();
let err_log2: f64 = err
.clone()
.max(&Float::with_val(prec as u32, 1e-60))
.log2()
.to_f64();
assert!(
err_log2 <= -(out_prec as f64) + 1.,
"{} {}",
err_log2,
-(out_prec as f64) + 1.
)
}
}
}
}
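Both normalization paths above begin by splitting res_offset into a whole-limb offset and an intra-limb left shift, folding a negative remainder back into [0, base2k) by borrowing one limb. A minimal standalone sketch of just that decomposition (split_offset is illustrative and not part of the crate), checking the invariant offset = limbs * base2k + lsh with 0 <= lsh < base2k:
// Split an offset (in bits) into a whole-limb part and an intra-limb left shift,
// normalized so that 0 <= lsh < base2k and offset = limbs * base2k + lsh.
fn split_offset(offset: i64, base2k: usize) -> (i64, usize) {
    let mut lsh = offset % base2k as i64;
    let mut limbs = offset / base2k as i64;
    if offset < 0 && lsh != 0 {
        lsh += base2k as i64; // fold the negative remainder back into [0, base2k)
        limbs -= 1; // and borrow one limb to compensate
    }
    (limbs, lsh as usize)
}
fn main() {
    let base2k = 12usize;
    for offset in [-30i64, -12, -1, 0, 1, 12, 30] {
        let (limbs, lsh) = split_offset(offset, base2k);
        assert!(lsh < base2k);
        assert_eq!(limbs * base2k as i64 + lsh as i64, offset);
    }
}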

View File

@@ -34,12 +34,7 @@ pub fn vec_znx_fill_normal_ref<R>(
let limb: usize = k.div_ceil(base2k) - 1;
let scale: f64 = (1 << ((limb + 1) * base2k - k)) as f64;
znx_fill_normal_f64_ref(
res.at_mut(res_col, limb),
sigma * scale,
bound * scale,
source,
)
znx_fill_normal_f64_ref(res.at_mut(res_col, limb), sigma * scale, bound * scale, source)
}
pub fn vec_znx_add_normal_ref<R>(
@@ -62,10 +57,5 @@ pub fn vec_znx_add_normal_ref<R>(
let limb: usize = k.div_ceil(base2k) - 1;
let scale: f64 = (1 << ((limb + 1) * base2k - k)) as f64;
znx_add_normal_f64_ref(
res.at_mut(res_col, limb),
sigma * scale,
bound * scale,
source,
)
znx_add_normal_f64_ref(res.at_mut(res_col, limb), sigma * scale, bound * scale, source)
}

View File

@@ -5,13 +5,10 @@ use criterion::{BenchmarkId, Criterion};
use crate::{
api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxLsh, VecZnxLshInplace, VecZnxRsh, VecZnxRshInplace},
layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
reference::{
vec_znx::vec_znx_copy,
znx::{
ZnxCopy, ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly,
ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace,
ZnxZero,
},
reference::znx::{
ZnxCopy, ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly,
ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace,
ZnxZero,
},
source::Source,
};
@@ -54,10 +51,7 @@ where
(0..size - steps).for_each(|j| {
let (lhs, rhs) = res_raw.split_at_mut(slice_size * (j + steps));
ZNXARI::znx_copy(
&mut lhs[start + j * slice_size..end + j * slice_size],
&rhs[start..end],
);
ZNXARI::znx_copy(&mut lhs[start + j * slice_size..end + j * slice_size], &rhs[start..end]);
});
for j in size - steps..size {
@@ -65,16 +59,13 @@ where
}
}
// Inplace normalization with left shift of k % base2k
if !k.is_multiple_of(base2k) {
for j in (0..size - steps).rev() {
if j == size - steps - 1 {
ZNXARI::znx_normalize_first_step_inplace(base2k, k_rem, res.at_mut(res_col, j), carry);
} else if j == 0 {
ZNXARI::znx_normalize_final_step_inplace(base2k, k_rem, res.at_mut(res_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_inplace(base2k, k_rem, res.at_mut(res_col, j), carry);
}
for j in (0..size - steps).rev() {
if j == size - steps - 1 {
ZNXARI::znx_normalize_first_step_inplace(base2k, k_rem, res.at_mut(res_col, j), carry);
} else if j == 0 {
ZNXARI::znx_normalize_final_step_inplace(base2k, k_rem, res.at_mut(res_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_inplace(base2k, k_rem, res.at_mut(res_col, j), carry);
}
}
}
@@ -104,38 +95,13 @@ where
// Simply a left shifted normalization of limbs
// by k/base2k and intra-limb by base2k - k%base2k
if !k.is_multiple_of(base2k) {
for j in (0..min_size).rev() {
if j == min_size - 1 {
ZNXARI::znx_normalize_first_step(
base2k,
k_rem,
res.at_mut(res_col, j),
a.at(a_col, j + steps),
carry,
);
} else if j == 0 {
ZNXARI::znx_normalize_final_step(
base2k,
k_rem,
res.at_mut(res_col, j),
a.at(a_col, j + steps),
carry,
);
} else {
ZNXARI::znx_normalize_middle_step(
base2k,
k_rem,
res.at_mut(res_col, j),
a.at(a_col, j + steps),
carry,
);
}
}
} else {
// If k % base2k = 0, then this is simply a copy.
for j in (0..min_size).rev() {
ZNXARI::znx_copy(res.at_mut(res_col, j), a.at(a_col, j + steps));
for j in (0..min_size).rev() {
if j == min_size - 1 {
ZNXARI::znx_normalize_first_step(base2k, k_rem, res.at_mut(res_col, j), a.at(a_col, j + steps), carry);
} else if j == 0 {
ZNXARI::znx_normalize_final_step(base2k, k_rem, res.at_mut(res_col, j), a.at(a_col, j + steps), carry);
} else {
ZNXARI::znx_normalize_middle_step(base2k, k_rem, res.at_mut(res_col, j), a.at(a_col, j + steps), carry);
}
}
@@ -146,10 +112,10 @@ where
}
pub fn vec_znx_rsh_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
2 * n * size_of::<i64>()
}
pub fn vec_znx_rsh_inplace<R, ZNXARI>(base2k: usize, k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
pub fn vec_znx_rsh_inplace<R, ZNXARI>(base2k: usize, k: usize, res: &mut R, res_col: usize, tmp: &mut [i64])
where
R: VecZnxToMut,
ZNXARI: ZnxZero
@@ -163,76 +129,48 @@ where
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let n: usize = res.n();
let cols: usize = res.cols();
let size: usize = res.size();
let mut steps: usize = k / base2k;
let k_rem: usize = k % base2k;
if k == 0 {
return;
}
if steps >= size {
for j in 0..size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
return;
}
let start: usize = n * res_col;
let end: usize = start + n;
let slice_size: usize = n * cols;
if !k.is_multiple_of(base2k) {
// We rsh by an additional base2k and then lsh by base2k - k.
// This allows re-using the efficient normalization code, avoids
// overflows & produces output that is normalized.
steps += 1;
}
// All limbs of a that would fall outside of the limbs of res are discarded,
// but the carry still needs to be computed.
(size - steps..size).rev().for_each(|j| {
if j == size - 1 {
ZNXARI::znx_normalize_first_step_carry_only(base2k, base2k - k_rem, res.at(res_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_carry_only(base2k, base2k - k_rem, res.at(res_col, j), carry);
}
});
let (carry, tmp) = tmp[..2 * n].split_at_mut(n);
// Continues with shifted normalization
let res_raw: &mut [i64] = res.raw_mut();
(steps..size).rev().for_each(|j| {
let (lhs, rhs) = res_raw.split_at_mut(slice_size * j);
let rhs_slice: &mut [i64] = &mut rhs[start..end];
let lhs_slice: &[i64] = &lhs[(j - steps) * slice_size + start..(j - steps) * slice_size + end];
ZNXARI::znx_normalize_middle_step(base2k, base2k - k_rem, rhs_slice, lhs_slice, carry);
});
let lsh: usize = (base2k - k_rem) % base2k;
// Propagates carry on the rest of the limbs of res
for j in (0..steps).rev() {
ZNXARI::znx_zero(res.at_mut(res_col, j));
if j == 0 {
ZNXARI::znx_normalize_final_step_inplace(base2k, base2k - k_rem, res.at_mut(res_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_inplace(base2k, base2k - k_rem, res.at_mut(res_col, j), carry);
}
// All limbs of a that would fall outside of the limbs of res are discarded,
// but the carry still needs to be computed.
for j in 0..steps {
if j == 0 {
ZNXARI::znx_normalize_first_step_carry_only(base2k, lsh, res.at(res_col, size - j - 1), carry);
} else {
ZNXARI::znx_normalize_middle_step_carry_only(base2k, lsh, res.at(res_col, size - j - 1), carry);
}
} else {
// Shift by multiples of base2k
let res_raw: &mut [i64] = res.raw_mut();
(steps..size).rev().for_each(|j| {
let (lhs, rhs) = res_raw.split_at_mut(slice_size * j);
ZNXARI::znx_copy(
&mut rhs[start..end],
&lhs[(j - steps) * slice_size + start..(j - steps) * slice_size + end],
);
});
}
// Zeroes the top
(0..steps).for_each(|j| {
ZNXARI::znx_zero(res.at_mut(res_col, j));
});
// Continues with shifted normalization
for j in 0..size - steps {
ZNXARI::znx_copy(tmp, res.at(res_col, size - steps - j - 1));
ZNXARI::znx_normalize_middle_step_inplace(base2k, lsh, tmp, carry);
ZNXARI::znx_copy(res.at_mut(res_col, size - j - 1), tmp);
}
// Propagates carry on the rest of the limbs of res
for j in 0..steps {
ZNXARI::znx_zero(res.at_mut(res_col, j));
if j == 0 {
ZNXARI::znx_normalize_final_step_inplace(base2k, lsh, res.at_mut(res_col, steps - j - 1), carry);
} else {
ZNXARI::znx_normalize_middle_step_inplace(base2k, lsh, res.at_mut(res_col, steps - j - 1), carry);
}
}
}
@@ -259,90 +197,59 @@ where
let mut steps: usize = k / base2k;
let k_rem: usize = k % base2k;
if k == 0 {
vec_znx_copy::<_, _, ZNXARI>(&mut res, res_col, &a, a_col);
return;
}
if steps >= res_size {
for j in 0..res_size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
return;
}
if !k.is_multiple_of(base2k) {
// We rsh by an additional base2k and then lsh by base2k - k.
// This allows re-using the efficient normalization code, avoids
// overflows & produces output that is normalized.
steps += 1;
}
// All limbs of a that are moved outside of the limbs of res are discarded,
// but the carry still needs to be computed.
for j in (res_size..a_size + steps).rev() {
if j == a_size + steps - 1 {
ZNXARI::znx_normalize_first_step_carry_only(base2k, base2k - k_rem, a.at(a_col, j - steps), carry);
} else {
ZNXARI::znx_normalize_middle_step_carry_only(base2k, base2k - k_rem, a.at(a_col, j - steps), carry);
}
let lsh: usize = (base2k - k_rem) % base2k; // 0 if k | base2k
let res_end: usize = res_size.min(steps);
let res_start: usize = res_size.min(a_size + steps);
let a_start: usize = a_size.min(res_size.saturating_sub(steps));
// All limbs of a that are moved outside of the limbs of res are discarded,
// but the carry still needs to be computed.
let a_out_range: usize = a_size.saturating_sub(a_start);
for j in 0..a_out_range {
if j == 0 {
ZNXARI::znx_normalize_first_step_carry_only(base2k, lsh, a.at(a_col, a_size - j - 1), carry);
} else {
ZNXARI::znx_normalize_middle_step_carry_only(base2k, lsh, a.at(a_col, a_size - j - 1), carry);
}
}
// Avoids overflow of the limbs of res
let min_size: usize = res_size.min(a_size + steps);
if a_out_range == 0 {
ZNXARI::znx_zero(carry);
}
// Zeroes lower limbs of res if a_size + steps < res_size
(min_size..res_size).for_each(|j| {
ZNXARI::znx_zero(res.at_mut(res_col, j));
});
// Zeroes lower limbs of res if a_size + steps < res_size
for j in 0..res_size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
// Continues with shifted normalization
for j in (steps..min_size).rev() {
// Case if no limb of a was previously discarded
if res_size.saturating_sub(steps) >= a_size && j == min_size - 1 {
ZNXARI::znx_normalize_first_step(
base2k,
base2k - k_rem,
res.at_mut(res_col, j),
a.at(a_col, j - steps),
carry,
);
} else {
ZNXARI::znx_normalize_middle_step(
base2k,
base2k - k_rem,
res.at_mut(res_col, j),
a.at(a_col, j - steps),
carry,
);
}
// Continues with shifted normalization
let mid_range: usize = res_start.saturating_sub(res_end);
for j in 0..mid_range {
ZNXARI::znx_normalize_middle_step(
base2k,
lsh,
res.at_mut(res_col, res_start - j - 1),
a.at(a_col, a_start - j - 1),
carry,
);
}
// Propagates carry on the rest of the limbs of res
for j in 0..res_end {
if j == res_end - 1 {
ZNXARI::znx_normalize_final_step_inplace(base2k, lsh, res.at_mut(res_col, res_end - j - 1), carry);
} else {
ZNXARI::znx_normalize_middle_step_inplace(base2k, lsh, res.at_mut(res_col, res_end - j - 1), carry);
}
// Propagates carry on the rest of the limbs of res
for j in (0..steps).rev() {
ZNXARI::znx_zero(res.at_mut(res_col, j));
if j == 0 {
ZNXARI::znx_normalize_final_step_inplace(base2k, base2k - k_rem, res.at_mut(res_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_inplace(base2k, base2k - k_rem, res.at_mut(res_col, j), carry);
}
}
} else {
let min_size: usize = res_size.min(a_size + steps);
// Zeroes the top
(0..steps).for_each(|j| {
ZNXARI::znx_zero(res.at_mut(res_col, j));
});
// Shift a into res, up to the maximum
for j in (steps..min_size).rev() {
ZNXARI::znx_copy(res.at_mut(res_col, j), a.at(a_col, j - steps));
}
// Zeroes bottom if a_size + steps < res_size
(min_size..res_size).for_each(|j| {
ZNXARI::znx_zero(res.at_mut(res_col, j));
});
}
}
@@ -373,7 +280,7 @@ where
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(n * size_of::<i64>());
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(vec_znx_lsh_tmp_bytes(n));
// Fill a with random i64
a.fill_uniform(50, &mut source);
@@ -423,7 +330,7 @@ where
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(n * size_of::<i64>());
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(vec_znx_lsh_tmp_bytes(n));
// Fill a with random i64
a.fill_uniform(50, &mut source);
@@ -473,7 +380,7 @@ where
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(n * size_of::<i64>());
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(vec_znx_rsh_tmp_bytes(n));
// Fill a with random i64
a.fill_uniform(50, &mut source);
@@ -523,7 +430,7 @@ where
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(n * size_of::<i64>());
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(vec_znx_rsh_tmp_bytes(n));
// Fill a with random i64
a.fill_uniform(50, &mut source);
@@ -552,8 +459,8 @@ mod tests {
layouts::{FillUniform, VecZnx, ZnxView},
reference::{
vec_znx::{
vec_znx_copy, vec_znx_lsh, vec_znx_lsh_inplace, vec_znx_normalize_inplace, vec_znx_rsh, vec_znx_rsh_inplace,
vec_znx_sub_inplace,
vec_znx_copy, vec_znx_lsh, vec_znx_lsh_inplace, vec_znx_lsh_tmp_bytes, vec_znx_normalize_inplace, vec_znx_rsh,
vec_znx_rsh_inplace, vec_znx_rsh_tmp_bytes, vec_znx_sub_inplace,
},
znx::ZnxRef,
},
@@ -572,7 +479,7 @@ mod tests {
let mut source: Source = Source::new([0u8; 32]);
let mut carry: Vec<i64> = vec![0i64; n];
let mut carry: Vec<i64> = vec![0i64; vec_znx_lsh_tmp_bytes(n) / size_of::<i64>()];
let base2k: usize = 50;
@@ -604,7 +511,7 @@ mod tests {
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut carry: Vec<i64> = vec![0i64; n];
let mut carry: Vec<i64> = vec![0i64; vec_znx_rsh_tmp_bytes(n) / size_of::<i64>()];
let base2k: usize = 50;
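vec_znx_rsh_inplace and vec_znx_rsh above express a right shift by k bits as a shift by a whole number of limbs, rounded up, followed by an intra-limb left shift of lsh = (base2k - k % base2k) % base2k bits recovered through the shifted normalization. A small sketch of just that arithmetic (rsh_split is illustrative and not part of the crate), checking the invariant steps * base2k == k + lsh:
// Decompose a right shift by k bits into `steps` whole limbs (rounded up) and a
// compensating intra-limb left shift `lsh`, so that steps * base2k == k + lsh.
fn rsh_split(k: usize, base2k: usize) -> (usize, usize) {
    let mut steps = k / base2k;
    let k_rem = k % base2k;
    if k_rem != 0 {
        steps += 1; // shift one limb too far ...
    }
    let lsh = (base2k - k_rem) % base2k; // ... and shift back by the surplus bits
    (steps, lsh)
}
fn main() {
    let base2k = 19usize;
    for k in 0..200 {
        let (steps, lsh) = rsh_split(k, base2k);
        assert_eq!(steps * base2k, k + lsh);
    }
}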

View File

@@ -22,10 +22,7 @@ where
{
assert_eq!(tmp.len(), a.n());
assert!(
_n_out < _n_in,
"invalid a: output ring degree should be smaller"
);
assert!(_n_out < _n_in, "invalid a: output ring degree should be smaller");
res[1..].iter_mut().for_each(|bi| {
assert_eq!(

View File

@@ -12,10 +12,6 @@ pub fn znx_automorphism_ref(p: i64, res: &mut [i64], a: &[i64]) {
res[0] = a[0];
for ai in a.iter().take(n).skip(1) {
k = (k + p_2n) & mask;
if k < n {
res[k] = *ai
} else {
res[k - n] = -*ai
}
if k < n { res[k] = *ai } else { res[k - n] = -*ai }
}
}
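The reference above applies the Galois automorphism X -> X^p on Z[X]/(X^n + 1) by walking k in increments of p mod 2n and negating coefficients that wrap past n. The sketch below recomputes the same target index directly as (i * p) mod 2n (assuming k starts at 0 in the full function); it is illustrative only:
// Negacyclic automorphism X -> X^p: coefficient a_i moves to slot (i * p) mod 2n,
// with a sign flip whenever it lands in the upper half.
fn automorphism(p: i64, a: &[i64]) -> Vec<i64> {
    let n = a.len();
    let two_n = 2 * n as i64;
    let mut res = vec![0i64; n];
    for (i, &ai) in a.iter().enumerate() {
        let k = (i as i64 * p).rem_euclid(two_n) as usize;
        if k < n { res[k] = ai } else { res[k - n] = -ai }
    }
    res
}
fn main() {
    // X -> X^3 on Z[X]/(X^4 + 1): 1 + 2X + 3X^2 + 4X^3 maps to 1 + 4X - 3X^2 + 2X^3.
    assert_eq!(automorphism(3, &[1, 2, 3, 4]), vec![1, 4, -3, 2]);
}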

View File

@@ -33,9 +33,9 @@ pub fn znx_normalize_first_step_carry_only_ref(base2k: usize, lsh: usize, x: &[i
*c = get_carry_i64(base2k, *x, get_digit_i64(base2k, *x));
});
} else {
let basek_lsh: usize = base2k - lsh;
let base2k_lsh: usize = base2k - lsh;
x.iter().zip(carry.iter_mut()).for_each(|(x, c)| {
*c = get_carry_i64(basek_lsh, *x, get_digit_i64(basek_lsh, *x));
*c = get_carry_i64(base2k_lsh, *x, get_digit_i64(base2k_lsh, *x));
});
}
}
@@ -55,10 +55,10 @@ pub fn znx_normalize_first_step_inplace_ref(base2k: usize, lsh: usize, x: &mut [
*x = digit;
});
} else {
let basek_lsh: usize = base2k - lsh;
let base2k_lsh: usize = base2k - lsh;
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
let digit: i64 = get_digit_i64(basek_lsh, *x);
*c = get_carry_i64(basek_lsh, *x, digit);
let digit: i64 = get_digit_i64(base2k_lsh, *x);
*c = get_carry_i64(base2k_lsh, *x, digit);
*x = digit << lsh;
});
}
@@ -80,10 +80,10 @@ pub fn znx_normalize_first_step_ref(base2k: usize, lsh: usize, x: &mut [i64], a:
*x = digit;
});
} else {
let basek_lsh: usize = base2k - lsh;
let base2k_lsh: usize = base2k - lsh;
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
let digit: i64 = get_digit_i64(basek_lsh, *a);
*c = get_carry_i64(basek_lsh, *a, digit);
let digit: i64 = get_digit_i64(base2k_lsh, *a);
*c = get_carry_i64(base2k_lsh, *a, digit);
*x = digit << lsh;
});
}
@@ -104,10 +104,10 @@ pub fn znx_normalize_middle_step_carry_only_ref(base2k: usize, lsh: usize, x: &[
*c = carry + get_carry_i64(base2k, digit_plus_c, get_digit_i64(base2k, digit_plus_c));
});
} else {
let basek_lsh: usize = base2k - lsh;
let base2k_lsh: usize = base2k - lsh;
x.iter().zip(carry.iter_mut()).for_each(|(x, c)| {
let digit: i64 = get_digit_i64(basek_lsh, *x);
let carry: i64 = get_carry_i64(basek_lsh, *x, digit);
let digit: i64 = get_digit_i64(base2k_lsh, *x);
let carry: i64 = get_carry_i64(base2k_lsh, *x, digit);
let digit_plus_c: i64 = (digit << lsh) + *c;
*c = carry + get_carry_i64(base2k, digit_plus_c, get_digit_i64(base2k, digit_plus_c));
});
@@ -131,10 +131,10 @@ pub fn znx_normalize_middle_step_inplace_ref(base2k: usize, lsh: usize, x: &mut
*c = carry + get_carry_i64(base2k, digit_plus_c, *x);
});
} else {
let basek_lsh: usize = base2k - lsh;
let base2k_lsh: usize = base2k - lsh;
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
let digit: i64 = get_digit_i64(basek_lsh, *x);
let carry: i64 = get_carry_i64(basek_lsh, *x, digit);
let digit: i64 = get_digit_i64(base2k_lsh, *x);
let carry: i64 = get_carry_i64(base2k_lsh, *x, digit);
let digit_plus_c: i64 = (digit << lsh) + *c;
*x = get_digit_i64(base2k, digit_plus_c);
*c = carry + get_carry_i64(base2k, digit_plus_c, *x);
@@ -178,10 +178,10 @@ pub fn znx_normalize_middle_step_ref(base2k: usize, lsh: usize, x: &mut [i64], a
*c = carry + get_carry_i64(base2k, digit_plus_c, *x);
});
} else {
let basek_lsh: usize = base2k - lsh;
let base2k_lsh: usize = base2k - lsh;
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
let digit: i64 = get_digit_i64(basek_lsh, *a);
let carry: i64 = get_carry_i64(basek_lsh, *a, digit);
let digit: i64 = get_digit_i64(base2k_lsh, *a);
let carry: i64 = get_carry_i64(base2k_lsh, *a, digit);
let digit_plus_c: i64 = (digit << lsh) + *c;
*x = get_digit_i64(base2k, digit_plus_c);
*c = carry + get_carry_i64(base2k, digit_plus_c, *x);
@@ -202,9 +202,9 @@ pub fn znx_normalize_final_step_inplace_ref(base2k: usize, lsh: usize, x: &mut [
*x = get_digit_i64(base2k, get_digit_i64(base2k, *x) + *c);
});
} else {
let basek_lsh: usize = base2k - lsh;
let base2k_lsh: usize = base2k - lsh;
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
*x = get_digit_i64(base2k, (get_digit_i64(basek_lsh, *x) << lsh) + *c);
*x = get_digit_i64(base2k, (get_digit_i64(base2k_lsh, *x) << lsh) + *c);
});
}
}
@@ -221,9 +221,9 @@ pub fn znx_normalize_final_step_ref(base2k: usize, lsh: usize, x: &mut [i64], a:
*x = get_digit_i64(base2k, get_digit_i64(base2k, *a) + *c);
});
} else {
let basek_lsh: usize = base2k - lsh;
let base2k_lsh: usize = base2k - lsh;
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
*x = get_digit_i64(base2k, (get_digit_i64(basek_lsh, *a) << lsh) + *c);
*x = get_digit_i64(base2k, (get_digit_i64(base2k_lsh, *a) << lsh) + *c);
});
}
}
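Taken together, the first/middle/final steps above implement a carry-propagating renormalization across limbs. The sketch below shows the overall effect for the plain case (lsh = 0, single operand), under the same assumed centered-digit convention as in the sketch above; treating limb 0 as the most significant limb is also an assumption about the layout rather than a documented fact.

fn normalize_limbs_naive(base2k: usize, limbs: &mut [i64]) {
    let half = 1i64 << (base2k - 1);
    let mask = (1i64 << base2k) - 1;
    let mut carry = 0i64;
    // Walk from the least-significant limb (last index) toward limb 0, re-centering each
    // limb and pushing the carry into the next more-significant one.
    for x in limbs.iter_mut().rev() {
        let v = *x + carry;
        let digit = ((v + half) & mask) - half;
        carry = (v - digit) >> base2k;
        *x = digit;
    }
    // A carry left over past limb 0 is dropped, i.e. the value is reduced mod 2^(base2k * len).
}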

View File

@@ -1,19 +1,80 @@
use rand::RngCore;
use crate::{
api::{
BivariateTensoring, ModuleN, ScratchOwnedAlloc, ScratchOwnedBorrow, ScratchTakeBasic, TakeSlice, VecZnxBigAlloc,
VecZnxBigNormalize, VecZnxDftAlloc, VecZnxDftApply, VecZnxIdftApplyTmpA, VecZnxNormalizeInplace,
CnvPVecAlloc, Convolution, ModuleN, ScratchOwnedAlloc, ScratchOwnedBorrow, ScratchTakeBasic, TakeSlice, VecZnxAdd,
VecZnxBigAlloc, VecZnxBigNormalize, VecZnxCopy, VecZnxDftAlloc, VecZnxDftApply, VecZnxIdftApplyTmpA,
VecZnxNormalizeInplace,
},
layouts::{
Backend, FillUniform, Scratch, ScratchOwned, VecZnx, VecZnxBig, VecZnxDft, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView,
ZnxViewMut, ZnxZero,
Backend, CnvPVecL, CnvPVecR, FillUniform, Scratch, ScratchOwned, VecZnx, VecZnxBig, VecZnxDft, VecZnxToMut, VecZnxToRef,
ZnxInfos, ZnxView, ZnxViewMut, ZnxZero,
},
source::Source,
};
pub fn test_bivariate_tensoring<M, BE: Backend>(module: &M)
pub fn test_convolution_by_const<M, BE: Backend>(module: &M)
where
M: ModuleN + Convolution<BE> + VecZnxBigNormalize<BE> + VecZnxNormalizeInplace<BE> + VecZnxBigAlloc<BE>,
Scratch<BE>: ScratchTakeBasic,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
{
let mut source: Source = Source::new([0u8; 32]);
let base2k: usize = 12;
let a_cols: usize = 2;
let a_size: usize = 15;
let b_size: usize = 15;
let res_size: usize = a_size + b_size;
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), a_cols, a_size);
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), 1, b_size);
let mut res_want: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), 1, res_size);
let mut res_have: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), 1, res_size);
let mut res_big: VecZnxBig<Vec<u8>, BE> = module.vec_znx_big_alloc(1, res_size);
a.fill_uniform(base2k, &mut source);
let mut b_const = vec![0i64; b_size];
let mask = (1 << base2k) - 1;
for (j, x) in b_const[..1].iter_mut().enumerate() {
let r = source.next_u64() & mask;
*x = ((r << (64 - base2k)) as i64) >> (64 - base2k);
b.at_mut(0, j)[0] = *x
}
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(module.cnv_by_const_apply_tmp_bytes(res_size, 0, a_size, b_size));
for a_col in 0..a.cols() {
for offset in 0..res_size {
module.cnv_by_const_apply(&mut res_big, offset, 0, &a, a_col, &b_const, scratch.borrow());
module.vec_znx_big_normalize(&mut res_have, base2k, 0, 0, &res_big, base2k, 0, scratch.borrow());
bivariate_convolution_naive(
module,
base2k,
(offset + 1) as i64,
&mut res_want,
0,
&a,
a_col,
&b,
0,
scratch.borrow(),
);
assert_eq!(res_want, res_have);
}
}
}
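The naive reference used above (called with k = offset + 1) pins down the limb bookkeeping these offset tests check: limb i of `a` and limb j of `b` contribute, via a per-limb negacyclic product, to result limb i + j - offset, and contributions that would land below limb 0 or beyond the result size are discarded. A tiny sketch of just that index rule, with an illustrative name:

// Result limb receiving the product of a's limb `i` and b's limb `j` for a given offset,
// or None when the contribution is discarded because the index would be negative.
// Contributions at or beyond the result's size are likewise dropped by the caller.
fn cnv_res_limb(i: usize, j: usize, offset: usize) -> Option<usize> {
    (i + j).checked_sub(offset)
}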
pub fn test_convolution<M, BE: Backend>(module: &M)
where
M: ModuleN
+ BivariateTensoring<BE>
+ Convolution<BE>
+ CnvPVecAlloc<BE>
+ VecZnxDftAlloc<BE>
+ VecZnxDftApply<BE>
+ VecZnxIdftApplyTmpA<BE>
@@ -27,56 +88,199 @@ where
let base2k: usize = 12;
let a_cols: usize = 3;
let b_cols: usize = 3;
let a_size: usize = 3;
let b_size: usize = 3;
let c_cols: usize = a_cols + b_cols - 1;
let c_size: usize = a_size + b_size;
let a_cols: usize = 2;
let b_cols: usize = 2;
let a_size: usize = 15;
let b_size: usize = 15;
let res_size: usize = a_size + b_size;
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), a_cols, a_size);
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), b_cols, b_size);
let mut c_want: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), c_cols, c_size);
let mut c_have: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), c_cols, c_size);
let mut c_have_dft: VecZnxDft<Vec<u8>, BE> = module.vec_znx_dft_alloc(c_cols, c_size);
let mut c_have_big: VecZnxBig<Vec<u8>, BE> = module.vec_znx_big_alloc(c_cols, c_size);
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(module.convolution_tmp_bytes(b_size));
let mut res_want: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), 1, res_size);
let mut res_have: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), 1, res_size);
let mut res_dft: VecZnxDft<Vec<u8>, BE> = module.vec_znx_dft_alloc(1, res_size);
let mut res_big: VecZnxBig<Vec<u8>, BE> = module.vec_znx_big_alloc(1, res_size);
a.fill_uniform(base2k, &mut source);
b.fill_uniform(base2k, &mut source);
let mut b_dft: VecZnxDft<Vec<u8>, BE> = module.vec_znx_dft_alloc(b_cols, b_size);
for i in 0..b.cols() {
module.vec_znx_dft_apply(1, 0, &mut b_dft, i, &b, i);
let mut a_prep: CnvPVecL<Vec<u8>, BE> = module.cnv_pvec_left_alloc(a_cols, a_size);
let mut b_prep: CnvPVecR<Vec<u8>, BE> = module.cnv_pvec_right_alloc(b_cols, b_size);
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(
module
.cnv_apply_dft_tmp_bytes(res_size, 0, a_size, b_size)
.max(module.cnv_prepare_left_tmp_bytes(res_size, a_size))
.max(module.cnv_prepare_right_tmp_bytes(res_size, b_size)),
);
module.cnv_prepare_left(&mut a_prep, &a, scratch.borrow());
module.cnv_prepare_right(&mut b_prep, &b, scratch.borrow());
for a_col in 0..a.cols() {
for b_col in 0..b.cols() {
for offset in 0..res_size {
module.cnv_apply_dft(&mut res_dft, offset, 0, &a_prep, a_col, &b_prep, b_col, scratch.borrow());
module.vec_znx_idft_apply_tmpa(&mut res_big, 0, &mut res_dft, 0);
module.vec_znx_big_normalize(&mut res_have, base2k, 0, 0, &res_big, base2k, 0, scratch.borrow());
bivariate_convolution_naive(
module,
base2k,
(offset + 1) as i64,
&mut res_want,
0,
&a,
a_col,
&b,
b_col,
scratch.borrow(),
);
assert_eq!(res_want, res_have);
}
}
}
}
pub fn test_convolution_pairwise<M, BE: Backend>(module: &M)
where
M: ModuleN
+ Convolution<BE>
+ CnvPVecAlloc<BE>
+ VecZnxDftAlloc<BE>
+ VecZnxDftApply<BE>
+ VecZnxIdftApplyTmpA<BE>
+ VecZnxBigNormalize<BE>
+ VecZnxNormalizeInplace<BE>
+ VecZnxBigAlloc<BE>
+ VecZnxAdd
+ VecZnxCopy,
Scratch<BE>: ScratchTakeBasic,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
{
let mut source: Source = Source::new([0u8; 32]);
let base2k: usize = 12;
let cols = 2;
let a_size: usize = 15;
let b_size: usize = 15;
let res_size: usize = a_size + b_size;
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), cols, a_size);
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), cols, b_size);
let mut tmp_a: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), 1, a_size);
let mut tmp_b: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), 1, b_size);
let mut res_want: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), 1, res_size);
let mut res_have: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), 1, res_size);
let mut res_dft: VecZnxDft<Vec<u8>, BE> = module.vec_znx_dft_alloc(1, res_size);
let mut res_big: VecZnxBig<Vec<u8>, BE> = module.vec_znx_big_alloc(1, res_size);
a.fill_uniform(base2k, &mut source);
b.fill_uniform(base2k, &mut source);
let mut a_prep: CnvPVecL<Vec<u8>, BE> = module.cnv_pvec_left_alloc(cols, a_size);
let mut b_prep: CnvPVecR<Vec<u8>, BE> = module.cnv_pvec_right_alloc(cols, b_size);
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(
module
.cnv_pairwise_apply_dft_tmp_bytes(res_size, 0, a_size, b_size)
.max(module.cnv_prepare_left_tmp_bytes(res_size, a_size))
.max(module.cnv_prepare_right_tmp_bytes(res_size, b_size)),
);
module.cnv_prepare_left(&mut a_prep, &a, scratch.borrow());
module.cnv_prepare_right(&mut b_prep, &b, scratch.borrow());
for col_i in 0..cols {
for col_j in 0..cols {
for offset in 0..res_size {
module.cnv_pairwise_apply_dft(&mut res_dft, offset, 0, &a_prep, &b_prep, col_i, col_j, scratch.borrow());
module.vec_znx_idft_apply_tmpa(&mut res_big, 0, &mut res_dft, 0);
module.vec_znx_big_normalize(&mut res_have, base2k, 0, 0, &res_big, base2k, 0, scratch.borrow());
if col_i != col_j {
module.vec_znx_add(&mut tmp_a, 0, &a, col_i, &a, col_j);
module.vec_znx_add(&mut tmp_b, 0, &b, col_i, &b, col_j);
} else {
module.vec_znx_copy(&mut tmp_a, 0, &a, col_i);
module.vec_znx_copy(&mut tmp_b, 0, &b, col_j);
}
bivariate_convolution_naive(
module,
base2k,
(offset + 1) as i64,
&mut res_want,
0,
&tmp_a,
0,
&tmp_b,
0,
scratch.borrow(),
);
assert_eq!(res_want, res_have);
}
}
}
}
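The pairwise test above leans on bilinearity: for col_i != col_j it checks the single pairwise call against the naive convolution of the summed columns, which expands into the four cross terms. A self-contained illustration of that identity over plain i64 slices (nothing here is the crate's API):

// conv(a_i + a_j, b_i + b_j) == conv(a_i, b_i) + conv(a_i, b_j) + conv(a_j, b_i) + conv(a_j, b_j)
fn naive_negacyclic(a: &[i64], b: &[i64]) -> Vec<i64> {
    let n = a.len();
    let mut res = vec![0i64; n];
    for i in 0..n {
        for j in 0..n {
            if i + j < n {
                res[i + j] += a[i] * b[j];
            } else {
                res[i + j - n] -= a[i] * b[j];
            }
        }
    }
    res
}

fn add_slices(x: &[i64], y: &[i64]) -> Vec<i64> {
    x.iter().zip(y.iter()).map(|(xi, yi)| xi + yi).collect()
}

fn pairwise_identity_holds(a_i: &[i64], a_j: &[i64], b_i: &[i64], b_j: &[i64]) -> bool {
    let lhs = naive_negacyclic(&add_slices(a_i, a_j), &add_slices(b_i, b_j));
    let rhs = add_slices(
        &add_slices(&naive_negacyclic(a_i, b_i), &naive_negacyclic(a_i, b_j)),
        &add_slices(&naive_negacyclic(a_j, b_i), &naive_negacyclic(a_j, b_j)),
    );
    lhs == rhs
}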
#[allow(clippy::too_many_arguments)]
pub fn bivariate_convolution_naive<R, A, B, M, BE: Backend>(
module: &M,
base2k: usize,
k: i64,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
b: &B,
b_col: usize,
scratch: &mut Scratch<BE>,
) where
R: VecZnxToMut,
A: VecZnxToRef,
B: VecZnxToRef,
M: VecZnxNormalizeInplace<BE>,
Scratch<BE>: TakeSlice,
{
let res: &mut VecZnx<&mut [u8]> = &mut res.to_mut();
let a: &VecZnx<&[u8]> = &a.to_ref();
let b: &VecZnx<&[u8]> = &b.to_ref();
for j in 0..res.size() {
res.zero_at(res_col, j);
}
for mut k in 0..(2 * c_size + 1) as i64 {
k -= c_size as i64;
for a_limb in 0..a.size() {
for b_limb in 0..b.size() {
let res_scale_abs = k.unsigned_abs() as usize;
module.bivariate_tensoring(k, &mut c_have_dft, &a, &b_dft, scratch.borrow());
let mut res_limb: usize = a_limb + b_limb + 1;
for i in 0..c_cols {
module.vec_znx_idft_apply_tmpa(&mut c_have_big, i, &mut c_have_dft, i);
if k <= 0 {
res_limb += res_scale_abs;
if res_limb < res.size() {
negacyclic_convolution_naive_add(res.at_mut(res_col, res_limb), a.at(a_col, a_limb), b.at(b_col, b_limb));
}
} else if res_limb >= res_scale_abs {
res_limb -= res_scale_abs;
if res_limb < res.size() {
negacyclic_convolution_naive_add(res.at_mut(res_col, res_limb), a.at(a_col, a_limb), b.at(b_col, b_limb));
}
}
}
for i in 0..c_cols {
module.vec_znx_big_normalize(
base2k,
&mut c_have,
i,
base2k,
&c_have_big,
i,
scratch.borrow(),
);
}
bivariate_tensoring_naive(module, base2k, k, &mut c_want, &a, &b, scratch.borrow());
assert_eq!(c_want, c_have);
}
module.vec_znx_normalize_inplace(base2k, res, res_col, scratch);
}
fn bivariate_tensoring_naive<R, A, B, M, BE: Backend>(
@@ -154,3 +358,18 @@ fn negacyclic_convolution_naive_add(res: &mut [i64], a: &[i64], b: &[i64]) {
}
}
}
fn negacyclic_convolution_naive(res: &mut [i64], a: &[i64], b: &[i64]) {
let n: usize = res.len();
res.fill(0);
for i in 0..n {
let ai: i64 = a[i];
let lim: usize = n - i;
for j in 0..lim {
res[i + j] += ai * b[j];
}
for j in lim..n {
res[i + j - n] -= ai * b[j];
}
}
}
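As a quick sanity check on the helper above, a tiny worked example written as an illustrative test: in Z[X]/(X^2 + 1), (1 + X)(1 + X) = 1 + 2X + X^2 = 2X, since X^2 = -1.

#[test]
fn negacyclic_convolution_naive_small_example() {
    let mut res = [0i64; 2];
    negacyclic_convolution_naive(&mut res, &[1, 1], &[1, 1]);
    assert_eq!(res, [0, 2]);
}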

View File

@@ -29,10 +29,7 @@ where
receiver.read_from(&mut reader).expect("read_from failed");
// Ensure serialization round-trip correctness
assert_eq!(
&original, &receiver,
"Deserialized object does not match the original"
);
assert_eq!(&original, &receiver, "Deserialized object does not match the original");
}
#[test]

View File

@@ -90,24 +90,8 @@ where
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
for j in 0..cols {
module_ref.vec_znx_big_normalize(
base2k,
&mut res_ref,
j,
base2k,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
base2k,
&mut res_test,
j,
base2k,
&res_big_test,
j,
scratch_test.borrow(),
);
module_ref.vec_znx_big_normalize(&mut res_ref, base2k, 0, j, &res_big_ref, base2k, j, scratch_ref.borrow());
module_test.vec_znx_big_normalize(&mut res_test, base2k, 0, j, &res_big_test, base2k, j, scratch_test.borrow());
}
assert_eq!(res_ref, res_test);
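From here on, every test hunk switches `vec_znx_big_normalize` (and later `vec_znx_normalize`) to the same reordered call. The argument roles annotated below are read off these call sites as a scanning aid; they are an inference from the diff, not a signature quoted from the crate:

// module_ref.vec_znx_big_normalize(
//     &mut res_ref,         // destination VecZnx
//     base2k,               // destination base2k
//     0,                    // destination offset (the res_offset loops further down vary this)
//     j,                    // destination column
//     &res_big_ref,         // source VecZnxBig
//     base2k,               // source base2k
//     j,                    // source column
//     scratch_ref.borrow(), // scratch
// );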
@@ -212,24 +196,8 @@ where
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
for j in 0..cols {
module_ref.vec_znx_big_normalize(
base2k,
&mut res_ref,
j,
base2k,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
base2k,
&mut res_test,
j,
base2k,
&res_big_test,
j,
scratch_test.borrow(),
);
module_ref.vec_znx_big_normalize(&mut res_ref, base2k, 0, j, &res_big_ref, base2k, j, scratch_ref.borrow());
module_test.vec_znx_big_normalize(&mut res_test, base2k, 0, j, &res_big_test, base2k, j, scratch_test.borrow());
}
assert_eq!(res_ref, res_test);
@@ -339,24 +307,8 @@ where
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
for j in 0..cols {
module_ref.vec_znx_big_normalize(
base2k,
&mut res_ref,
j,
base2k,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
base2k,
&mut res_test,
j,
base2k,
&res_big_test,
j,
scratch_test.borrow(),
);
module_ref.vec_znx_big_normalize(&mut res_ref, base2k, 0, j, &res_big_ref, base2k, j, scratch_ref.borrow());
module_test.vec_znx_big_normalize(&mut res_test, base2k, 0, j, &res_big_test, base2k, j, scratch_test.borrow());
}
assert_eq!(res_ref, res_test);
@@ -447,24 +399,8 @@ pub fn test_svp_apply_dft_to_dft_inplace<BR: Backend, BT: Backend>(
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
for j in 0..cols {
module_ref.vec_znx_big_normalize(
base2k,
&mut res_ref,
j,
base2k,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
base2k,
&mut res_test,
j,
base2k,
&res_big_test,
j,
scratch_test.borrow(),
);
module_ref.vec_znx_big_normalize(&mut res_ref, base2k, 0, j, &res_big_ref, base2k, j, scratch_ref.borrow());
module_test.vec_znx_big_normalize(&mut res_test, base2k, 0, j, &res_big_test, base2k, j, scratch_test.borrow());
}
assert_eq!(res_ref, res_test);

View File

@@ -7,8 +7,9 @@ use crate::{
VecZnxFillNormal, VecZnxFillUniform, VecZnxLsh, VecZnxLshInplace, VecZnxLshTmpBytes, VecZnxMergeRings,
VecZnxMergeRingsTmpBytes, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxMulXpMinusOneInplaceTmpBytes,
VecZnxNegate, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate,
VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, VecZnxRsh, VecZnxRshInplace, VecZnxSplitRing, VecZnxSplitRingTmpBytes,
VecZnxSub, VecZnxSubInplace, VecZnxSubNegateInplace, VecZnxSubScalar, VecZnxSubScalarInplace, VecZnxSwitchRing,
VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, VecZnxRsh, VecZnxRshInplace, VecZnxRshTmpBytes, VecZnxSplitRing,
VecZnxSplitRingTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSubNegateInplace, VecZnxSubScalar, VecZnxSubScalarInplace,
VecZnxSwitchRing,
},
layouts::{Backend, DigestU64, FillUniform, Module, ScalarZnx, ScratchOwned, VecZnx, ZnxInfos, ZnxView, ZnxViewMut},
reference::znx::znx_copy_ref,
@@ -341,10 +342,7 @@ where
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_merge_rings_tmp_bytes());
for a_size in [1, 2, 3, 4] {
let mut a: [VecZnx<Vec<u8>>; 2] = [
VecZnx::alloc(n >> 1, cols, a_size),
VecZnx::alloc(n >> 1, cols, a_size),
];
let mut a: [VecZnx<Vec<u8>>; 2] = [VecZnx::alloc(n >> 1, cols, a_size), VecZnx::alloc(n >> 1, cols, a_size)];
a.iter_mut().for_each(|ai| {
ai.fill_uniform(base2k, &mut source);
@@ -549,26 +547,20 @@ where
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
// Set d to garbage
res_ref.fill_uniform(base2k, &mut source);
res_test.fill_uniform(base2k, &mut source);
for res_offset in -(base2k as i64)..=(base2k as i64) {
// Set d to garbage
res_ref.fill_uniform(base2k, &mut source);
res_test.fill_uniform(base2k, &mut source);
// Reference
for i in 0..cols {
module_ref.vec_znx_normalize(base2k, &mut res_ref, i, base2k, &a, i, scratch_ref.borrow());
module_test.vec_znx_normalize(
base2k,
&mut res_test,
i,
base2k,
&a,
i,
scratch_test.borrow(),
);
// Reference
for i in 0..cols {
module_ref.vec_znx_normalize(&mut res_ref, base2k, res_offset, i, &a, base2k, i, scratch_ref.borrow());
module_test.vec_znx_normalize(&mut res_test, base2k, res_offset, i, &a, base2k, i, scratch_test.borrow());
}
assert_eq!(a.digest_u64(), a_digest);
assert_eq!(res_ref, res_test);
}
assert_eq!(a.digest_u64(), a_digest);
assert_eq!(res_ref, res_test);
}
}
}
@@ -718,10 +710,7 @@ where
})
} else {
let std: f64 = a.stats(base2k, col_i).std();
assert!(
(std - one_12_sqrt).abs() < 0.01,
"std={std} ~!= {one_12_sqrt}",
);
assert!((std - one_12_sqrt).abs() < 0.01, "std={std} ~!= {one_12_sqrt}",);
}
})
});
@@ -783,11 +772,7 @@ where
})
} else {
let std: f64 = a.stats(base2k, col_i).std() * k_f64;
assert!(
(std - sigma * sqrt2).abs() < 0.1,
"std={std} ~!= {}",
sigma * sqrt2
);
assert!((std - sigma * sqrt2).abs() < 0.1, "std={std} ~!= {}", sigma * sqrt2);
}
})
});
@@ -872,9 +857,9 @@ where
pub fn test_vec_znx_rsh<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxRsh<BR> + VecZnxLshTmpBytes,
Module<BR>: VecZnxRsh<BR> + VecZnxRshTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
Module<BT>: VecZnxRsh<BT> + VecZnxLshTmpBytes,
Module<BT>: VecZnxRsh<BT> + VecZnxRshTmpBytes,
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
{
assert_eq!(module_ref.n(), module_test.n());
@@ -882,8 +867,8 @@ where
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_lsh_tmp_bytes());
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_lsh_tmp_bytes());
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_rsh_tmp_bytes());
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_rsh_tmp_bytes());
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
@@ -914,9 +899,9 @@ where
pub fn test_vec_znx_rsh_inplace<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxRshInplace<BR> + VecZnxLshTmpBytes,
Module<BR>: VecZnxRshInplace<BR> + VecZnxRshTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
Module<BT>: VecZnxRshInplace<BT> + VecZnxLshTmpBytes,
Module<BT>: VecZnxRshInplace<BT> + VecZnxRshTmpBytes,
ScratchOwned<BT>: ScratchOwnedAlloc<BT> + ScratchOwnedBorrow<BT>,
{
assert_eq!(module_ref.n(), module_test.n());
@@ -924,8 +909,8 @@ where
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_lsh_tmp_bytes());
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_lsh_tmp_bytes());
let mut scratch_ref: ScratchOwned<BR> = ScratchOwned::alloc(module_ref.vec_znx_rsh_tmp_bytes());
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_rsh_tmp_bytes());
for res_size in [1, 2, 3, 4] {
for k in 0..base2k * res_size {
@@ -966,15 +951,11 @@ where
let a_digest = a.digest_u64();
for res_size in [1, 2, 3, 4] {
let mut res_ref: [VecZnx<Vec<u8>>; 2] = [
VecZnx::alloc(n >> 1, cols, res_size),
VecZnx::alloc(n >> 1, cols, res_size),
];
let mut res_ref: [VecZnx<Vec<u8>>; 2] =
[VecZnx::alloc(n >> 1, cols, res_size), VecZnx::alloc(n >> 1, cols, res_size)];
let mut res_test: [VecZnx<Vec<u8>>; 2] = [
VecZnx::alloc(n >> 1, cols, res_size),
VecZnx::alloc(n >> 1, cols, res_size),
];
let mut res_test: [VecZnx<Vec<u8>>; 2] =
[VecZnx::alloc(n >> 1, cols, res_size), VecZnx::alloc(n >> 1, cols, res_size)];
res_ref.iter_mut().for_each(|ri| {
ri.fill_uniform(base2k, &mut source);

View File

@@ -93,20 +93,22 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
base2k,
&mut res_small_ref,
j,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
base2k,
&mut res_small_test,
j,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -188,20 +190,22 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
base2k,
&mut res_small_ref,
j,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
base2k,
&mut res_small_test,
j,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -279,20 +283,22 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
base2k,
&mut res_small_ref,
j,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
base2k,
&mut res_small_test,
j,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -367,20 +373,22 @@ pub fn test_vec_znx_big_add_small_inplace<BR: Backend, BT: Backend>(
for j in 0..cols {
module_ref.vec_znx_big_normalize(
base2k,
&mut res_small_ref,
j,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
base2k,
&mut res_small_test,
j,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -459,20 +467,22 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
base2k,
&mut res_small_ref,
j,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
base2k,
&mut res_small_test,
j,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -546,20 +556,22 @@ pub fn test_vec_znx_big_automorphism_inplace<BR: Backend, BT: Backend>(
for j in 0..cols {
module_ref.vec_znx_big_normalize(
base2k,
&mut res_small_ref,
j,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
base2k,
&mut res_small_test,
j,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -631,20 +643,22 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
base2k,
&mut res_small_ref,
j,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
base2k,
&mut res_small_test,
j,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -709,20 +723,22 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
base2k,
&mut res_small_ref,
j,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
base2k,
&mut res_small_test,
j,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -782,36 +798,40 @@ where
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
// Set d to garbage
source.fill_bytes(res_ref.data_mut());
source.fill_bytes(res_test.data_mut());
for res_offset in -(base2k as i64)..=(base2k as i64) {
// Set d to garbage
source.fill_bytes(res_ref.data_mut());
source.fill_bytes(res_test.data_mut());
// Reference
for j in 0..cols {
module_ref.vec_znx_big_normalize(
base2k,
&mut res_ref,
j,
base2k,
&a_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
base2k,
&mut res_test,
j,
base2k,
&a_test,
j,
scratch_test.borrow(),
);
// Reference
for j in 0..cols {
module_ref.vec_znx_big_normalize(
&mut res_ref,
base2k,
res_offset,
j,
&a_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
&mut res_test,
base2k,
res_offset,
j,
&a_test,
base2k,
j,
scratch_test.borrow(),
);
}
assert_eq!(a_ref.digest_u64(), a_ref_digest);
assert_eq!(a_test.digest_u64(), a_test_digest);
assert_eq!(res_ref, res_test);
}
assert_eq!(a_ref.digest_u64(), a_ref_digest);
assert_eq!(a_test.digest_u64(), a_test_digest);
assert_eq!(res_ref, res_test);
}
}
}
@@ -891,20 +911,22 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
base2k,
&mut res_small_ref,
j,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
base2k,
&mut res_small_test,
j,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -986,20 +1008,22 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
base2k,
&mut res_small_ref,
j,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
base2k,
&mut res_small_test,
j,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -1083,20 +1107,22 @@ pub fn test_vec_znx_big_sub_negate_inplace<BR: Backend, BT: Backend>(
for j in 0..cols {
module_ref.vec_znx_big_normalize(
base2k,
&mut res_small_ref,
j,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
base2k,
&mut res_small_test,
j,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -1180,20 +1206,22 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
base2k,
&mut res_small_ref,
j,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
base2k,
&mut res_small_test,
j,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -1278,20 +1306,22 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
base2k,
&mut res_small_ref,
j,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
base2k,
&mut res_small_test,
j,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -1366,20 +1396,22 @@ pub fn test_vec_znx_big_sub_small_a_inplace<BR: Backend, BT: Backend>(
for j in 0..cols {
module_ref.vec_znx_big_normalize(
base2k,
&mut res_small_ref,
j,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
base2k,
&mut res_small_test,
j,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -1427,55 +1459,59 @@ pub fn test_vec_znx_big_sub_small_b_inplace<BR: Backend, BT: Backend>(
let a_digest: u64 = a.digest_u64();
for res_size in [1, 2, 3, 4] {
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
res.fill_uniform(base2k, &mut source);
for res_offset in -(base2k as i64)..=(base2k as i64) {
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
res.fill_uniform(base2k, &mut source);
let mut res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_big_alloc(cols, res_size);
let mut res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_big_alloc(cols, res_size);
let mut res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_big_alloc(cols, res_size);
let mut res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_big_alloc(cols, res_size);
for j in 0..cols {
module_ref.vec_znx_big_from_small(&mut res_big_ref, j, &res, j);
module_test.vec_znx_big_from_small(&mut res_big_test, j, &res, j);
for j in 0..cols {
module_ref.vec_znx_big_from_small(&mut res_big_ref, j, &res, j);
module_test.vec_znx_big_from_small(&mut res_big_test, j, &res, j);
}
for i in 0..cols {
module_ref.vec_znx_big_sub_small_negate_inplace(&mut res_big_ref, i, &a, i);
module_test.vec_znx_big_sub_small_negate_inplace(&mut res_big_test, i, &a, i);
}
assert_eq!(a.digest_u64(), a_digest);
let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let res_ref_digest: u64 = res_big_ref.digest_u64();
let res_test_digest: u64 = res_big_test.digest_u64();
for j in 0..cols {
module_ref.vec_znx_big_normalize(
&mut res_small_ref,
base2k,
res_offset,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
&mut res_small_test,
base2k,
res_offset,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
}
assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
assert_eq!(res_big_test.digest_u64(), res_test_digest);
assert_eq!(res_small_ref, res_small_test);
}
for i in 0..cols {
module_ref.vec_znx_big_sub_small_negate_inplace(&mut res_big_ref, i, &a, i);
module_test.vec_znx_big_sub_small_negate_inplace(&mut res_big_test, i, &a, i);
}
assert_eq!(a.digest_u64(), a_digest);
let mut res_small_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_small_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let res_ref_digest: u64 = res_big_ref.digest_u64();
let res_test_digest: u64 = res_big_test.digest_u64();
for j in 0..cols {
module_ref.vec_znx_big_normalize(
base2k,
&mut res_small_ref,
j,
base2k,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
base2k,
&mut res_small_test,
j,
base2k,
&res_big_test,
j,
scratch_test.borrow(),
);
}
assert_eq!(res_big_ref.digest_u64(), res_ref_digest);
assert_eq!(res_big_test.digest_u64(), res_test_digest);
assert_eq!(res_small_ref, res_small_test);
}
}
}

View File

@@ -102,20 +102,22 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
base2k,
&mut res_small_ref,
j,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
base2k,
&mut res_small_test,
j,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -208,20 +210,22 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
base2k,
&mut res_small_ref,
j,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
base2k,
&mut res_small_test,
j,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -311,20 +315,22 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
base2k,
&mut res_small_ref,
j,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
base2k,
&mut res_small_test,
j,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -392,13 +398,7 @@ where
for j in 0..cols {
module_ref.vec_znx_idft_apply(&mut res_big_ref, j, &res_dft_ref, j, scratch_ref.borrow());
module_test.vec_znx_idft_apply(
&mut res_big_test,
j,
&res_dft_test,
j,
scratch_test.borrow(),
);
module_test.vec_znx_idft_apply(&mut res_big_test, j, &res_dft_test, j, scratch_test.borrow());
}
assert_eq!(res_dft_ref.digest_u64(), res_dft_ref_digest);
@@ -412,20 +412,22 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
base2k,
&mut res_small_ref,
j,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
base2k,
&mut res_small_test,
j,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -502,20 +504,22 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
base2k,
&mut res_small_ref,
j,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
base2k,
&mut res_small_test,
j,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -589,20 +593,22 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
base2k,
&mut res_small_ref,
j,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
base2k,
&mut res_small_test,
j,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -709,20 +715,22 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
base2k,
&mut res_small_ref,
j,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
base2k,
&mut res_small_test,
j,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -815,20 +823,22 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
base2k,
&mut res_small_ref,
j,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
base2k,
&mut res_small_test,
j,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -923,20 +933,22 @@ pub fn test_vec_znx_dft_sub_negate_inplace<BR: Backend, BT: Backend>(
for j in 0..cols {
module_ref.vec_znx_big_normalize(
base2k,
&mut res_small_ref,
j,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
base2k,
&mut res_small_test,
j,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);

View File

@@ -90,20 +90,22 @@ where
for j in 0..cols_out {
module_ref.vec_znx_big_normalize(
base2k,
&mut res_small_ref,
j,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
base2k,
&mut res_small_test,
j,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -205,18 +207,8 @@ where
source.fill_bytes(res_dft_ref.data_mut());
source.fill_bytes(res_dft_test.data_mut());
module_ref.vmp_apply_dft_to_dft(
&mut res_dft_ref,
&a_dft_ref,
&pmat_ref,
scratch_ref.borrow(),
);
module_test.vmp_apply_dft_to_dft(
&mut res_dft_test,
&a_dft_test,
&pmat_test,
scratch_test.borrow(),
);
module_ref.vmp_apply_dft_to_dft(&mut res_dft_ref, &a_dft_ref, &pmat_ref, scratch_ref.borrow());
module_test.vmp_apply_dft_to_dft(&mut res_dft_test, &a_dft_test, &pmat_test, scratch_test.borrow());
let res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_idft_apply_consume(res_dft_ref);
let res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_idft_apply_consume(res_dft_test);
@@ -229,20 +221,22 @@ where
for j in 0..cols_out {
module_ref.vec_znx_big_normalize(
base2k,
&mut res_small_ref,
j,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
base2k,
&mut res_small_test,
j,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);
@@ -379,20 +373,22 @@ where
for j in 0..cols_out {
module_ref.vec_znx_big_normalize(
base2k,
&mut res_small_ref,
j,
base2k,
0,
j,
&res_big_ref,
base2k,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
base2k,
&mut res_small_test,
j,
base2k,
0,
j,
&res_big_test,
base2k,
j,
scratch_test.borrow(),
);