Add cross-basek normalization (#90)

* added cross_basek_normalization

* updated method signatures to take layouts

* fixed cross-base normalization

fix #91
fix #93
Jean-Philippe Bossuat
2025-09-30 14:40:10 +02:00
committed by GitHub
parent 4da790ea6a
commit 37e13b965c
216 changed files with 12481 additions and 7745 deletions

View File

@@ -20,7 +20,7 @@ A `scalar_znx` is a front-end generic and backend agnostic type that stores a si
#### VecZnx
A `vec_znx` is a front-end generic and backend agnostic type that stores a vector of small polynomials (i.e. a vector of scalars). Each polynomial is a `limb` that provides an additional `basek`-bits of precision in the Torus. For example a `vec_znx` with `n`=1024 `basek`=2 with 3 limbs can store 1024 coefficients with 36 bits of precision in the torus. In practice this type is used for LWE and GLWE ciphertexts/plaintexts.
A `vec_znx` is a front-end generic and backend agnostic type that stores a vector of small polynomials (i.e. a vector of scalars). Each polynomial is a `limb` that provides an additional `base2k`-bits of precision in the Torus. For example a `vec_znx` with `n`=1024 `base2k`=2 with 3 limbs can store 1024 coefficients with 36 bits of precision in the torus. In practice this type is used for LWE and GLWE ciphertexts/plaintexts.
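
A minimal standalone sketch (not code from this repository) of the limb layout described above: a value is split into `size` digits of `base2k` bits each, so the total precision is `size * base2k` bits.

```rust
// Illustration only: split a non-negative value into `size` limbs of `base2k`
// bits (most significant limb first) and recompose it.
fn decompose(mut x: i64, base2k: usize, size: usize) -> Vec<i64> {
    let mask = (1i64 << base2k) - 1;
    let mut limbs = vec![0i64; size];
    for limb in limbs.iter_mut().rev() {
        *limb = x & mask; // least significant digit goes to the last limb
        x >>= base2k;
    }
    limbs
}

fn recompose(limbs: &[i64], base2k: usize) -> i64 {
    limbs.iter().fold(0, |acc, &d| (acc << base2k) + d)
}

fn main() {
    let limbs = decompose(123_456, 12, 3); // 3 limbs of 12 bits = 36 bits
    assert_eq!(recompose(&limbs, 12), 123_456);
}
```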
#### VecZnxDft

View File

@@ -98,10 +98,3 @@ pub trait TakeMatZnx {
size: usize,
) -> (MatZnx<&mut [u8]>, &mut Self);
}
/// Take a slice of bytes from a [Scratch], wraps it into the template's type and returns it
/// as well as a new [Scratch] minus the taken array of bytes.
pub trait TakeLike<'a, B: Backend, T> {
type Output;
fn take_like(&'a mut self, template: &T) -> (Self::Output, &'a mut Self);
}

View File

@@ -9,16 +9,25 @@ pub trait VecZnxNormalizeTmpBytes {
}
pub trait VecZnxNormalize<B: Backend> {
#[allow(clippy::too_many_arguments)]
/// Normalizes the selected column of `a` and stores the result into the selected column of `res`.
fn vec_znx_normalize<R, A>(&self, basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
where
fn vec_znx_normalize<R, A>(
&self,
res_basek: usize,
res: &mut R,
res_col: usize,
a_basek: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<B>,
) where
R: VecZnxToMut,
A: VecZnxToRef;
}
pub trait VecZnxNormalizeInplace<B: Backend> {
/// Normalizes the selected column of `a`.
fn vec_znx_normalize_inplace<A>(&self, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
fn vec_znx_normalize_inplace<A>(&self, base2k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
where
A: VecZnxToMut;
}
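
A toy, single-base sketch (not the library implementation) of the carry-propagating normalization these traits describe; the cross-base `vec_znx_normalize` above additionally re-expresses `a` from `a_basek`-bit limbs into `res_basek`-bit limbs. Digits are reduced to the centered range [-2^(base2k-1), 2^(base2k-1)) from the least significant limb upward, and the final carry is dropped (reduction mod 1 in the Torus).

```rust
// Toy model only: normalize limbs (most significant first) to centered digits.
fn normalize_inplace(limbs: &mut [i64], base2k: usize) {
    let half = 1i64 << (base2k - 1);
    let full = 1i64 << base2k;
    let mut carry = 0i64;
    for limb in limbs.iter_mut().rev() {
        let v = *limb + carry;
        // Centered remainder, then the carry passed to the next (more significant) limb.
        let mut digit = v % full;
        if digit >= half {
            digit -= full;
        } else if digit < -half {
            digit += full;
        }
        carry = (v - digit) >> base2k; // exact: v - digit is a multiple of 2^base2k
        *limb = digit;
    }
    // The last carry is discarded, i.e. the value is reduced mod 1 in the Torus.
}

fn main() {
    let base2k = 12;
    let mut limbs = vec![1i64, 5000, 7000]; // digits exceed the centered range
    normalize_inplace(&mut limbs, base2k);
    assert!(limbs.iter().all(|&d| (-2048..2048).contains(&d)));
    println!("{limbs:?}");
}
```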
@@ -67,21 +76,21 @@ pub trait VecZnxSub {
B: VecZnxToRef;
}
pub trait VecZnxSubABInplace {
pub trait VecZnxSubInplace {
/// Subtracts the selected column of `a` from the selected column of `res` inplace.
///
/// res\[res_col\] -= a\[a_col\]
fn vec_znx_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_sub_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
}
pub trait VecZnxSubBAInplace {
pub trait VecZnxSubNegateInplace {
/// Subtracts the selected column of `res` from the selected column of `a`, storing the result in `res`.
///
/// res\[res_col\] = a\[a_col\] - res\[res_col\]
fn vec_znx_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_sub_negate_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
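
A standalone illustration (not repository code) of the renamed in-place subtraction semantics: `*_sub_inplace` computes `res -= a`, while `*_sub_negate_inplace` computes `res = a - res`.

```rust
fn sub_inplace(res: &mut [i64], a: &[i64]) {
    for (r, x) in res.iter_mut().zip(a) {
        *r -= x; // res -= a
    }
}

fn sub_negate_inplace(res: &mut [i64], a: &[i64]) {
    for (r, x) in res.iter_mut().zip(a) {
        *r = x - *r; // res = a - res
    }
}

fn main() {
    let a = [5i64, 7, 9];
    let mut r1 = [1i64, 2, 3];
    let mut r2 = r1;
    sub_inplace(&mut r1, &a);
    sub_negate_inplace(&mut r2, &a);
    assert_eq!(r1, [-4, -5, -6]);
    assert_eq!(r2, [4, 5, 6]);
}
```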
@@ -127,8 +136,16 @@ pub trait VecZnxLshTmpBytes {
pub trait VecZnxLsh<B: Backend> {
/// Left shift by k bits all columns of `a`.
#[allow(clippy::too_many_arguments)]
fn vec_znx_lsh<R, A>(&self, basek: usize, k: usize, r: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
where
fn vec_znx_lsh<R, A>(
&self,
base2k: usize,
k: usize,
r: &mut R,
res_col: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<B>,
) where
R: VecZnxToMut,
A: VecZnxToRef;
}
@@ -140,22 +157,30 @@ pub trait VecZnxRshTmpBytes {
pub trait VecZnxRsh<B: Backend> {
/// Right shift by k bits all columns of `a`.
#[allow(clippy::too_many_arguments)]
fn vec_znx_rsh<R, A>(&self, basek: usize, k: usize, r: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
where
fn vec_znx_rsh<R, A>(
&self,
base2k: usize,
k: usize,
r: &mut R,
res_col: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<B>,
) where
R: VecZnxToMut,
A: VecZnxToRef;
}
pub trait VecZnxLshInplace<B: Backend> {
/// Left shift by k bits all columns of `a`.
fn vec_znx_lsh_inplace<A>(&self, basek: usize, k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
fn vec_znx_lsh_inplace<A>(&self, base2k: usize, k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
where
A: VecZnxToMut;
}
pub trait VecZnxRshInplace<B: Backend> {
/// Right shift by k bits all columns of `a`.
fn vec_znx_rsh_inplace<A>(&self, basek: usize, k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
fn vec_znx_rsh_inplace<A>(&self, base2k: usize, k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
where
A: VecZnxToMut;
}
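
A standalone sketch of what a right shift by `k < base2k` bits means at the limb level, assuming non-negative digits and limbs ordered most significant first (matching how `decode_vec_i64` later in this commit folds limbs); the library versions also handle signed digits and use scratch space.

```rust
// Toy model only: shift the represented value right by k bits by moving the low
// k bits of each limb into the top of the next (less significant) limb.
fn rsh_inplace(limbs: &mut [i64], base2k: usize, k: usize) {
    assert!(k < base2k);
    let mask = (1i64 << k) - 1;
    let mut spill = 0i64; // low bits carried down from the previous limb
    for limb in limbs.iter_mut() {
        let low = *limb & mask;
        *limb = (*limb >> k) + (spill << (base2k - k));
        spill = low; // bits falling off the last limb are lost precision
    }
}

fn main() {
    let mut limbs = vec![0x0AB, 0x3CD]; // value 0x0AB3CD in base 2^12
    rsh_inplace(&mut limbs, 12, 4);
    assert_eq!(limbs, vec![0x00A, 0xB3C]); // value 0x00AB3C = 0x0AB3CD >> 4
}
```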
@@ -264,8 +289,8 @@ pub trait VecZnxCopy {
}
pub trait VecZnxFillUniform {
/// Fills the first `size` size with uniform values in \[-2^{basek-1}, 2^{basek-1}\]
fn vec_znx_fill_uniform<R>(&self, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
/// Fills the first `size` limbs with uniform values in \[-2^{base2k-1}, 2^{base2k-1}\]
fn vec_znx_fill_uniform<R>(&self, base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: VecZnxToMut;
}
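
A standalone sketch (std-only, with a hypothetical xorshift bit source) of drawing one uniform, centered base-2^k digit per coefficient, i.e. a value in [-2^(base2k-1), 2^(base2k-1)).

```rust
// Illustration only: map uniform random bits to a centered base-2^k digit.
fn centered_digit(word: u64, base2k: usize) -> i64 {
    let half = 1i64 << (base2k - 1);
    ((word & ((1u64 << base2k) - 1)) as i64) - half
}

fn main() {
    // Tiny xorshift, only to have a deterministic bit source for the sketch.
    let mut s: u64 = 0x9E3779B97F4A7C15;
    let mut next = move || {
        s ^= s << 13;
        s ^= s >> 7;
        s ^= s << 17;
        s
    };
    let base2k = 12;
    let digits: Vec<i64> = (0..8).map(|_| centered_digit(next(), base2k)).collect();
    assert!(digits.iter().all(|&d| (-2048..2048).contains(&d)));
    println!("{digits:?}");
}
```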
@@ -274,7 +299,7 @@ pub trait VecZnxFillUniform {
pub trait VecZnxFillNormal {
fn vec_znx_fill_normal<R>(
&self,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -290,7 +315,7 @@ pub trait VecZnxAddNormal {
/// Adds a discrete normal vector scaled by 2^{-k} with the provided standard deviation and bounded to \[-bound, bound\].
fn vec_znx_add_normal<R>(
&self,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,

View File

@@ -30,7 +30,7 @@ pub trait VecZnxBigFromBytes<B: Backend> {
/// Add a discrete normal distribution on res.
///
/// # Arguments
/// * `basek`: base two logarithm of the bivariate representation
/// * `base2k`: base two logarithm of the bivariate representation
/// * `res`: receiver.
/// * `res_col`: column of the receiver on which the operation is performed/stored.
/// * `k`:
@@ -40,7 +40,7 @@ pub trait VecZnxBigFromBytes<B: Backend> {
pub trait VecZnxBigAddNormal<B: Backend> {
fn vec_znx_big_add_normal<R: VecZnxBigToMut<B>>(
&self,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -93,17 +93,17 @@ pub trait VecZnxBigSub<B: Backend> {
C: VecZnxBigToRef<B>;
}
pub trait VecZnxBigSubABInplace<B: Backend> {
pub trait VecZnxBigSubInplace<B: Backend> {
/// Subtracts `a` from `res` and stores the result on `res`.
fn vec_znx_big_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_big_sub_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxBigToRef<B>;
}
pub trait VecZnxBigSubBAInplace<B: Backend> {
pub trait VecZnxBigSubNegateInplace<B: Backend> {
/// Subtracts `res` from `a` and stores the result on `res`.
fn vec_znx_big_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_big_sub_negate_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxBigToRef<B>;
@@ -118,9 +118,9 @@ pub trait VecZnxBigSubSmallA<B: Backend> {
C: VecZnxBigToRef<B>;
}
pub trait VecZnxBigSubSmallAInplace<B: Backend> {
pub trait VecZnxBigSubSmallInplace<B: Backend> {
/// Subtracts `a` from `res` and stores the result on `res`.
fn vec_znx_big_sub_small_a_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_big_sub_small_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxToRef;
@@ -135,9 +135,9 @@ pub trait VecZnxBigSubSmallB<B: Backend> {
C: VecZnxToRef;
}
pub trait VecZnxBigSubSmallBInplace<B: Backend> {
pub trait VecZnxBigSubSmallNegateInplace<B: Backend> {
/// Subtracts `res` from `a` and stores the result on `res`.
fn vec_znx_big_sub_small_b_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_big_sub_small_negate_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxToRef;
@@ -160,12 +160,14 @@ pub trait VecZnxBigNormalizeTmpBytes {
fn vec_znx_big_normalize_tmp_bytes(&self) -> usize;
}
#[allow(clippy::too_many_arguments)]
pub trait VecZnxBigNormalize<B: Backend> {
fn vec_znx_big_normalize<R, A>(
&self,
basek: usize,
res_basek: usize,
res: &mut R,
res_col: usize,
a_basek: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<B>,
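
A toy model (not the library algorithm) of what the split `res_basek` / `a_basek` arguments allow: the same value, held as limbs in base 2^a_base2k, is re-expressed as limbs in base 2^res_base2k.

```rust
fn recompose(limbs: &[i64], base2k: usize) -> i128 {
    limbs.iter().fold(0i128, |acc, &d| (acc << base2k) + d as i128)
}

fn redecompose(mut x: i128, base2k: usize, size: usize) -> Vec<i64> {
    let mask = (1i128 << base2k) - 1;
    let mut out = vec![0i64; size];
    for limb in out.iter_mut().rev() {
        *limb = (x & mask) as i64;
        x >>= base2k;
    }
    out
}

fn main() {
    let (a_base2k, res_base2k) = (10, 15);
    let a = vec![3i64, 517, 42]; // 3 limbs of 10 bits, 30 bits total
    let value = recompose(&a, a_base2k);
    let res = redecompose(value, res_base2k, 2); // 2 limbs of 15 bits, same 30 bits
    assert_eq!(recompose(&res, res_base2k), value);
}
```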

View File

@@ -68,15 +68,15 @@ pub trait VecZnxDftSub<B: Backend> {
D: VecZnxDftToRef<B>;
}
pub trait VecZnxDftSubABInplace<B: Backend> {
fn vec_znx_dft_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
pub trait VecZnxDftSubInplace<B: Backend> {
fn vec_znx_dft_sub_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>;
}
pub trait VecZnxDftSubBAInplace<B: Backend> {
fn vec_znx_dft_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
pub trait VecZnxDftSubNegateInplace<B: Backend> {
fn vec_znx_dft_sub_negate_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>;

View File

@@ -12,14 +12,14 @@ pub trait ZnNormalizeTmpBytes {
pub trait ZnNormalizeInplace<B: Backend> {
/// Normalizes the selected column of `a`.
fn zn_normalize_inplace<R>(&self, n: usize, basek: usize, res: &mut R, res_col: usize, scratch: &mut Scratch<B>)
fn zn_normalize_inplace<R>(&self, n: usize, base2k: usize, res: &mut R, res_col: usize, scratch: &mut Scratch<B>)
where
R: ZnToMut;
}
pub trait ZnFillUniform {
/// Fills the first `size` size with uniform values in \[-2^{basek-1}, 2^{basek-1}\]
fn zn_fill_uniform<R>(&self, n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
/// Fills the first `size` limbs with uniform values in \[-2^{base2k-1}, 2^{base2k-1}\]
fn zn_fill_uniform<R>(&self, n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: ZnToMut;
}
@@ -29,7 +29,7 @@ pub trait ZnFillNormal {
fn zn_fill_normal<R>(
&self,
n: usize,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -46,7 +46,7 @@ pub trait ZnAddNormal {
fn zn_add_normal<R>(
&self,
n: usize,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,

View File

@@ -15,16 +15,16 @@ use crate::{
pub fn bench_svp_prepare<B>(c: &mut Criterion, label: &str)
where
Module<B>: SvpPrepare<B> + SvpPPolAlloc<B> + ModuleNew<B>,
B: Backend<ScalarPrep = f64>,
B: Backend,
{
let group_name: String = format!("svp_prepare::{}", label);
let group_name: String = format!("svp_prepare::{label}");
let mut group = c.benchmark_group(group_name);
fn runner<B>(log_n: usize) -> impl FnMut()
where
Module<B>: SvpPrepare<B> + SvpPPolAlloc<B> + ModuleNew<B>,
B: Backend<ScalarPrep = f64>,
B: Backend,
{
let module: Module<B> = Module::<B>::new(1 << log_n);
@@ -53,16 +53,16 @@ where
pub fn bench_svp_apply_dft<B>(c: &mut Criterion, label: &str)
where
Module<B>: SvpApplyDft<B> + SvpPPolAlloc<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
B: Backend<ScalarPrep = f64>,
B: Backend,
{
let group_name: String = format!("svp_apply_dft::{}", label);
let group_name: String = format!("svp_apply_dft::{label}");
let mut group = c.benchmark_group(group_name);
fn runner<B>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: SvpApplyDft<B> + SvpPPolAlloc<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
B: Backend<ScalarPrep = f64>,
B: Backend,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
@@ -100,16 +100,16 @@ where
pub fn bench_svp_apply_dft_to_dft<B>(c: &mut Criterion, label: &str)
where
Module<B>: SvpApplyDftToDft<B> + SvpPPolAlloc<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
B: Backend<ScalarPrep = f64>,
B: Backend,
{
let group_name: String = format!("svp_apply_dft_to_dft::{}", label);
let group_name: String = format!("svp_apply_dft_to_dft::{label}");
let mut group = c.benchmark_group(group_name);
fn runner<B>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: SvpApplyDftToDft<B> + SvpPPolAlloc<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
B: Backend<ScalarPrep = f64>,
B: Backend,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
@@ -147,16 +147,16 @@ where
pub fn bench_svp_apply_dft_to_dft_add<B>(c: &mut Criterion, label: &str)
where
Module<B>: SvpApplyDftToDftAdd<B> + SvpPPolAlloc<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
B: Backend<ScalarPrep = f64>,
B: Backend,
{
let group_name: String = format!("svp_apply_dft_to_dft_add::{}", label);
let group_name: String = format!("svp_apply_dft_to_dft_add::{label}");
let mut group = c.benchmark_group(group_name);
fn runner<B>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: SvpApplyDftToDftAdd<B> + SvpPPolAlloc<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
B: Backend<ScalarPrep = f64>,
B: Backend,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
@@ -194,16 +194,16 @@ where
pub fn bench_svp_apply_dft_to_dft_inplace<B>(c: &mut Criterion, label: &str)
where
Module<B>: SvpApplyDftToDftInplace<B> + SvpPPolAlloc<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
B: Backend<ScalarPrep = f64>,
B: Backend,
{
let group_name: String = format!("svp_apply_dft_to_dft_inplace::{}", label);
let group_name: String = format!("svp_apply_dft_to_dft_inplace::{label}");
let mut group = c.benchmark_group(group_name);
fn runner<B>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: SvpApplyDftToDftInplace<B> + SvpPPolAlloc<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
B: Backend<ScalarPrep = f64>,
B: Backend,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];

View File

@@ -8,7 +8,7 @@ use crate::{
ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxBigAdd, VecZnxBigAddInplace, VecZnxBigAddSmall,
VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAutomorphism, VecZnxBigAutomorphismInplace,
VecZnxBigAutomorphismInplaceTmpBytes, VecZnxBigNegate, VecZnxBigNegateInplace, VecZnxBigNormalize,
VecZnxBigNormalizeTmpBytes, VecZnxBigSub, VecZnxBigSubABInplace, VecZnxBigSubBAInplace, VecZnxBigSubSmallA,
VecZnxBigNormalizeTmpBytes, VecZnxBigSub, VecZnxBigSubInplace, VecZnxBigSubNegateInplace, VecZnxBigSubSmallA,
VecZnxBigSubSmallB,
},
layouts::{Backend, DataViewMut, Module, ScratchOwned, VecZnx, VecZnxBig},
@@ -19,7 +19,7 @@ pub fn bench_vec_znx_big_add<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxBigAdd<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let group_name: String = format!("vec_znx_big_add::{}", label);
let group_name: String = format!("vec_znx_big_add::{label}");
let mut group = c.benchmark_group(group_name);
@@ -65,7 +65,7 @@ pub fn bench_vec_znx_big_add_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxBigAddInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let group_name: String = format!("vec_znx_big_add_inplace::{}", label);
let group_name: String = format!("vec_znx_big_add_inplace::{label}");
let mut group = c.benchmark_group(group_name);
@@ -109,7 +109,7 @@ pub fn bench_vec_znx_big_add_small<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxBigAddSmall<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let group_name: String = format!("vec_znx_big_add_small::{}", label);
let group_name: String = format!("vec_znx_big_add_small::{label}");
let mut group = c.benchmark_group(group_name);
@@ -155,7 +155,7 @@ pub fn bench_vec_znx_big_add_small_inplace<B: Backend>(c: &mut Criterion, label:
where
Module<B>: VecZnxBigAddSmallInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let group_name: String = format!("vec_znx_big_add_small_inplace::{}", label);
let group_name: String = format!("vec_znx_big_add_small_inplace::{label}");
let mut group = c.benchmark_group(group_name);
@@ -199,7 +199,7 @@ pub fn bench_vec_znx_big_automorphism<B: Backend>(c: &mut Criterion, label: &str
where
Module<B>: VecZnxBigAutomorphism<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let group_name: String = format!("vec_znx_big_automorphism::{}", label);
let group_name: String = format!("vec_znx_big_automorphism::{label}");
let mut group = c.benchmark_group(group_name);
@@ -244,7 +244,7 @@ where
Module<B>: VecZnxBigAutomorphismInplace<B> + VecZnxBigAutomorphismInplaceTmpBytes + ModuleNew<B> + VecZnxBigAlloc<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_automorphism_inplace::{}", label);
let group_name: String = format!("vec_znx_automorphism_inplace::{label}");
let mut group = c.benchmark_group(group_name);
@@ -289,7 +289,7 @@ pub fn bench_vec_znx_big_negate<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxBigNegate<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let group_name: String = format!("vec_znx_big_negate::{}", label);
let group_name: String = format!("vec_znx_big_negate::{label}");
let mut group = c.benchmark_group(group_name);
@@ -332,7 +332,7 @@ pub fn bench_vec_znx_big_negate_inplace<B: Backend>(c: &mut Criterion, label: &s
where
Module<B>: VecZnxBigNegateInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let group_name: String = format!("vec_znx_negate_big_inplace::{}", label);
let group_name: String = format!("vec_znx_negate_big_inplace::{label}");
let mut group = c.benchmark_group(group_name);
@@ -374,7 +374,7 @@ where
Module<B>: VecZnxBigNormalize<B> + ModuleNew<B> + VecZnxBigNormalizeTmpBytes + VecZnxBigAlloc<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_big_normalize::{}", label);
let group_name: String = format!("vec_znx_big_normalize::{label}");
let mut group = c.benchmark_group(group_name);
@@ -389,7 +389,7 @@ where
let module: Module<B> = Module::<B>::new(n as u64);
let basek: usize = 50;
let base2k: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
@@ -404,7 +404,7 @@ where
move || {
for i in 0..cols {
module.vec_znx_big_normalize(basek, &mut res, i, &a, i, scratch.borrow());
module.vec_znx_big_normalize(base2k, &mut res, i, base2k, &a, i, scratch.borrow());
}
black_box(());
}
@@ -423,7 +423,7 @@ pub fn bench_vec_znx_big_sub<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxBigSub<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let group_name: String = format!("vec_znx_big_sub::{}", label);
let group_name: String = format!("vec_znx_big_sub::{label}");
let mut group = c.benchmark_group(group_name);
@@ -464,17 +464,17 @@ where
group.finish();
}
pub fn bench_vec_znx_big_sub_ab_inplace<B: Backend>(c: &mut Criterion, label: &str)
pub fn bench_vec_znx_big_sub_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxBigSubABInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
Module<B>: VecZnxBigSubInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let group_name: String = format!("vec_znx_big_sub_inplace::{}", label);
let group_name: String = format!("vec_znx_big_sub_inplace::{label}");
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxBigSubABInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
Module<B>: VecZnxBigSubInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let module: Module<B> = Module::<B>::new(1 << params[0]);
@@ -492,7 +492,7 @@ where
move || {
for i in 0..cols {
module.vec_znx_big_sub_ab_inplace(&mut c, i, &a, i);
module.vec_znx_big_sub_inplace(&mut c, i, &a, i);
}
black_box(());
}
@@ -507,17 +507,17 @@ where
group.finish();
}
pub fn bench_vec_znx_big_sub_ba_inplace<B: Backend>(c: &mut Criterion, label: &str)
pub fn bench_vec_znx_big_sub_negate_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxBigSubBAInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
Module<B>: VecZnxBigSubNegateInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let group_name: String = format!("vec_znx_big_sub_inplace::{}", label);
let group_name: String = format!("vec_znx_big_sub_inplace::{label}");
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxBigSubBAInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
Module<B>: VecZnxBigSubNegateInplace<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let module: Module<B> = Module::<B>::new(1 << params[0]);
@@ -535,7 +535,7 @@ where
move || {
for i in 0..cols {
module.vec_znx_big_sub_ba_inplace(&mut c, i, &a, i);
module.vec_znx_big_sub_negate_inplace(&mut c, i, &a, i);
}
black_box(());
}
@@ -554,7 +554,7 @@ pub fn bench_vec_znx_big_sub_small_a<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxBigSubSmallA<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let group_name: String = format!("vec_znx_big_sub_small_a::{}", label);
let group_name: String = format!("vec_znx_big_sub_small_a::{label}");
let mut group = c.benchmark_group(group_name);
@@ -599,7 +599,7 @@ pub fn bench_vec_znx_big_sub_small_b<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxBigSubSmallB<B> + ModuleNew<B> + VecZnxBigAlloc<B>,
{
let group_name: String = format!("vec_znx_big_sub_small_b::{}", label);
let group_name: String = format!("vec_znx_big_sub_small_b::{label}");
let mut group = c.benchmark_group(group_name);

View File

@@ -6,7 +6,7 @@ use rand::RngCore;
use crate::{
api::{
ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxBigAlloc, VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAlloc,
VecZnxDftApply, VecZnxDftSub, VecZnxDftSubABInplace, VecZnxDftSubBAInplace, VecZnxIdftApply, VecZnxIdftApplyTmpA,
VecZnxDftApply, VecZnxDftSub, VecZnxDftSubInplace, VecZnxDftSubNegateInplace, VecZnxIdftApply, VecZnxIdftApplyTmpA,
VecZnxIdftApplyTmpBytes,
},
layouts::{Backend, DataViewMut, Module, ScratchOwned, VecZnx, VecZnxBig, VecZnxDft},
@@ -17,7 +17,7 @@ pub fn bench_vec_znx_dft_add<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxDftAdd<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
{
let group_name: String = format!("vec_znx_dft_add::{}", label);
let group_name: String = format!("vec_znx_dft_add::{label}");
let mut group = c.benchmark_group(group_name);
@@ -62,7 +62,7 @@ pub fn bench_vec_znx_dft_add_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxDftAddInplace<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
{
let group_name: String = format!("vec_znx_dft_add_inplace::{}", label);
let group_name: String = format!("vec_znx_dft_add_inplace::{label}");
let mut group = c.benchmark_group(group_name);
@@ -106,7 +106,7 @@ pub fn bench_vec_znx_dft_apply<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxDftApply<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
{
let group_name: String = format!("vec_znx_dft_apply::{}", label);
let group_name: String = format!("vec_znx_dft_apply::{label}");
let mut group = c.benchmark_group(group_name);
@@ -149,7 +149,7 @@ where
Module<B>: VecZnxIdftApply<B> + ModuleNew<B> + VecZnxIdftApplyTmpBytes + VecZnxDftAlloc<B> + VecZnxBigAlloc<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_idft_apply::{}", label);
let group_name: String = format!("vec_znx_idft_apply::{label}");
let mut group = c.benchmark_group(group_name);
@@ -194,7 +194,7 @@ pub fn bench_vec_znx_idft_apply_tmpa<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxIdftApplyTmpA<B> + ModuleNew<B> + VecZnxDftAlloc<B> + VecZnxBigAlloc<B>,
{
let group_name: String = format!("vec_znx_idft_apply_tmpa::{}", label);
let group_name: String = format!("vec_znx_idft_apply_tmpa::{label}");
let mut group = c.benchmark_group(group_name);
@@ -235,7 +235,7 @@ pub fn bench_vec_znx_dft_sub<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxDftSub<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
{
let group_name: String = format!("vec_znx_dft_sub::{}", label);
let group_name: String = format!("vec_znx_dft_sub::{label}");
let mut group = c.benchmark_group(group_name);
@@ -276,17 +276,17 @@ where
group.finish();
}
pub fn bench_vec_znx_dft_sub_ab_inplace<B: Backend>(c: &mut Criterion, label: &str)
pub fn bench_vec_znx_dft_sub_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxDftSubABInplace<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
Module<B>: VecZnxDftSubInplace<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
{
let group_name: String = format!("vec_znx_dft_sub_ab_inplace::{}", label);
let group_name: String = format!("vec_znx_dft_sub_inplace::{label}");
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxDftSubABInplace<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
Module<B>: VecZnxDftSubInplace<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
{
let n: usize = params[0];
let cols: usize = params[1];
@@ -305,7 +305,7 @@ where
move || {
for i in 0..cols {
module.vec_znx_dft_sub_ab_inplace(&mut c, i, &a, i);
module.vec_znx_dft_sub_inplace(&mut c, i, &a, i);
}
black_box(());
}
@@ -320,17 +320,17 @@ where
group.finish();
}
pub fn bench_vec_znx_dft_sub_ba_inplace<B: Backend>(c: &mut Criterion, label: &str)
pub fn bench_vec_znx_dft_sub_negate_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxDftSubBAInplace<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
Module<B>: VecZnxDftSubNegateInplace<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
{
let group_name: String = format!("vec_znx_dft_sub_ba_inplace::{}", label);
let group_name: String = format!("vec_znx_dft_sub_negate_inplace::{label}");
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxDftSubBAInplace<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
Module<B>: VecZnxDftSubNegateInplace<B> + ModuleNew<B> + VecZnxDftAlloc<B>,
{
let n: usize = params[0];
let cols: usize = params[1];
@@ -349,7 +349,7 @@ where
move || {
for i in 0..cols {
module.vec_znx_dft_sub_ba_inplace(&mut c, i, &a, i);
module.vec_znx_dft_sub_negate_inplace(&mut c, i, &a, i);
}
black_box(());
}

View File

@@ -17,7 +17,7 @@ where
Module<B>: ModuleNew<B> + VmpPMatAlloc<B> + VmpPrepare<B> + VmpPrepareTmpBytes,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vmp_prepare::{}", label);
let group_name: String = format!("vmp_prepare::{label}");
let mut group = c.benchmark_group(group_name);
@@ -76,7 +76,7 @@ where
Module<B>: ModuleNew<B> + VmpApplyDftTmpBytes + VmpApplyDft<B> + VmpPMatAlloc<B> + VecZnxDftAlloc<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vmp_apply_dft::{}", label);
let group_name: String = format!("vmp_apply_dft::{label}");
let mut group = c.benchmark_group(group_name);
@@ -137,7 +137,7 @@ where
Module<B>: ModuleNew<B> + VecZnxDftAlloc<B> + VmpPMatAlloc<B> + VmpApplyDftToDft<B> + VmpApplyDftToDftTmpBytes,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vmp_apply_dft_to_dft::{}", label);
let group_name: String = format!("vmp_apply_dft_to_dft::{label}");
let mut group = c.benchmark_group(group_name);
@@ -200,7 +200,7 @@ where
Module<B>: ModuleNew<B> + VecZnxDftAlloc<B> + VmpPMatAlloc<B> + VmpApplyDftToDftAdd<B> + VmpApplyDftToDftAddTmpBytes,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vmp_apply_dft_to_dft_add::{}", label);
let group_name: String = format!("vmp_apply_dft_to_dft_add::{label}");
let mut group = c.benchmark_group(group_name);

View File

@@ -1,11 +1,11 @@
use crate::{
api::{
ScratchAvailable, ScratchFromBytes, ScratchOwnedAlloc, ScratchOwnedBorrow, TakeLike, TakeMatZnx, TakeScalarZnx,
TakeSlice, TakeSvpPPol, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, TakeVecZnxDftSlice, TakeVecZnxSlice, TakeVmpPMat,
ScratchAvailable, ScratchFromBytes, ScratchOwnedAlloc, ScratchOwnedBorrow, TakeMatZnx, TakeScalarZnx, TakeSlice,
TakeSvpPPol, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, TakeVecZnxDftSlice, TakeVecZnxSlice, TakeVmpPMat,
},
layouts::{Backend, DataRef, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat},
layouts::{Backend, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat},
oep::{
ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeLikeImpl, TakeMatZnxImpl,
ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeMatZnxImpl,
TakeScalarZnxImpl, TakeSliceImpl, TakeSvpPPolImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxDftSliceImpl,
TakeVecZnxImpl, TakeVecZnxSliceImpl, TakeVmpPMatImpl,
},
@@ -156,80 +156,3 @@ where
B::take_mat_znx_impl(self, n, rows, cols_in, cols_out, size)
}
}
impl<'a, B: Backend, D> TakeLike<'a, B, ScalarZnx<D>> for Scratch<B>
where
B: TakeLikeImpl<'a, B, ScalarZnx<D>, Output = ScalarZnx<&'a mut [u8]>>,
D: DataRef,
{
type Output = ScalarZnx<&'a mut [u8]>;
fn take_like(&'a mut self, template: &ScalarZnx<D>) -> (Self::Output, &'a mut Self) {
B::take_like_impl(self, template)
}
}
impl<'a, B: Backend, D> TakeLike<'a, B, SvpPPol<D, B>> for Scratch<B>
where
B: TakeLikeImpl<'a, B, SvpPPol<D, B>, Output = SvpPPol<&'a mut [u8], B>>,
D: DataRef,
{
type Output = SvpPPol<&'a mut [u8], B>;
fn take_like(&'a mut self, template: &SvpPPol<D, B>) -> (Self::Output, &'a mut Self) {
B::take_like_impl(self, template)
}
}
impl<'a, B: Backend, D> TakeLike<'a, B, VecZnx<D>> for Scratch<B>
where
B: TakeLikeImpl<'a, B, VecZnx<D>, Output = VecZnx<&'a mut [u8]>>,
D: DataRef,
{
type Output = VecZnx<&'a mut [u8]>;
fn take_like(&'a mut self, template: &VecZnx<D>) -> (Self::Output, &'a mut Self) {
B::take_like_impl(self, template)
}
}
impl<'a, B: Backend, D> TakeLike<'a, B, VecZnxBig<D, B>> for Scratch<B>
where
B: TakeLikeImpl<'a, B, VecZnxBig<D, B>, Output = VecZnxBig<&'a mut [u8], B>>,
D: DataRef,
{
type Output = VecZnxBig<&'a mut [u8], B>;
fn take_like(&'a mut self, template: &VecZnxBig<D, B>) -> (Self::Output, &'a mut Self) {
B::take_like_impl(self, template)
}
}
impl<'a, B: Backend, D> TakeLike<'a, B, VecZnxDft<D, B>> for Scratch<B>
where
B: TakeLikeImpl<'a, B, VecZnxDft<D, B>, Output = VecZnxDft<&'a mut [u8], B>>,
D: DataRef,
{
type Output = VecZnxDft<&'a mut [u8], B>;
fn take_like(&'a mut self, template: &VecZnxDft<D, B>) -> (Self::Output, &'a mut Self) {
B::take_like_impl(self, template)
}
}
impl<'a, B: Backend, D> TakeLike<'a, B, MatZnx<D>> for Scratch<B>
where
B: TakeLikeImpl<'a, B, MatZnx<D>, Output = MatZnx<&'a mut [u8]>>,
D: DataRef,
{
type Output = MatZnx<&'a mut [u8]>;
fn take_like(&'a mut self, template: &MatZnx<D>) -> (Self::Output, &'a mut Self) {
B::take_like_impl(self, template)
}
}
impl<'a, B: Backend, D> TakeLike<'a, B, VmpPMat<D, B>> for Scratch<B>
where
B: TakeLikeImpl<'a, B, VmpPMat<D, B>, Output = VmpPMat<&'a mut [u8], B>>,
D: DataRef,
{
type Output = VmpPMat<&'a mut [u8], B>;
fn take_like(&'a mut self, template: &VmpPMat<D, B>) -> (Self::Output, &'a mut Self) {
B::take_like_impl(self, template)
}
}

View File

@@ -5,8 +5,8 @@ use crate::{
VecZnxLshInplace, VecZnxLshTmpBytes, VecZnxMergeRings, VecZnxMergeRingsTmpBytes, VecZnxMulXpMinusOne,
VecZnxMulXpMinusOneInplace, VecZnxMulXpMinusOneInplaceTmpBytes, VecZnxNegate, VecZnxNegateInplace, VecZnxNormalize,
VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes,
VecZnxRsh, VecZnxRshInplace, VecZnxRshTmpBytes, VecZnxSplitRing, VecZnxSplitRingTmpBytes, VecZnxSub, VecZnxSubABInplace,
VecZnxSubBAInplace, VecZnxSubScalar, VecZnxSubScalarInplace, VecZnxSwitchRing,
VecZnxRsh, VecZnxRshInplace, VecZnxRshTmpBytes, VecZnxSplitRing, VecZnxSplitRingTmpBytes, VecZnxSub, VecZnxSubInplace,
VecZnxSubNegateInplace, VecZnxSubScalar, VecZnxSubScalarInplace, VecZnxSwitchRing,
},
layouts::{Backend, Module, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef},
oep::{
@@ -17,7 +17,7 @@ use crate::{
VecZnxMulXpMinusOneInplaceTmpBytesImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl,
VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl,
VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl,
VecZnxSplitRingTmpBytesImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarImpl,
VecZnxSplitRingTmpBytesImpl, VecZnxSubImpl, VecZnxSubInplaceImpl, VecZnxSubNegateInplaceImpl, VecZnxSubScalarImpl,
VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl,
},
source::Source,
@@ -36,12 +36,21 @@ impl<B> VecZnxNormalize<B> for Module<B>
where
B: Backend + VecZnxNormalizeImpl<B>,
{
fn vec_znx_normalize<R, A>(&self, basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<B>)
where
#[allow(clippy::too_many_arguments)]
fn vec_znx_normalize<R, A>(
&self,
res_basek: usize,
res: &mut R,
res_col: usize,
a_basek: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<B>,
) where
R: VecZnxToMut,
A: VecZnxToRef,
{
B::vec_znx_normalize_impl(self, basek, res, res_col, a, a_col, scratch)
B::vec_znx_normalize_impl(self, res_basek, res, res_col, a_basek, a, a_col, scratch)
}
}
@@ -49,11 +58,11 @@ impl<B> VecZnxNormalizeInplace<B> for Module<B>
where
B: Backend + VecZnxNormalizeInplaceImpl<B>,
{
fn vec_znx_normalize_inplace<A>(&self, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
fn vec_znx_normalize_inplace<A>(&self, base2k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
where
A: VecZnxToMut,
{
B::vec_znx_normalize_inplace_impl(self, basek, a, a_col, scratch)
B::vec_znx_normalize_inplace_impl(self, base2k, a, a_col, scratch)
}
}
@@ -125,29 +134,29 @@ where
}
}
impl<B> VecZnxSubABInplace for Module<B>
impl<B> VecZnxSubInplace for Module<B>
where
B: Backend + VecZnxSubABInplaceImpl<B>,
B: Backend + VecZnxSubInplaceImpl<B>,
{
fn vec_znx_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_sub_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
B::vec_znx_sub_ab_inplace_impl(self, res, res_col, a, a_col)
B::vec_znx_sub_inplace_impl(self, res, res_col, a, a_col)
}
}
impl<B> VecZnxSubBAInplace for Module<B>
impl<B> VecZnxSubNegateInplace for Module<B>
where
B: Backend + VecZnxSubBAInplaceImpl<B>,
B: Backend + VecZnxSubNegateInplaceImpl<B>,
{
fn vec_znx_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_sub_negate_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
B::vec_znx_sub_ba_inplace_impl(self, res, res_col, a, a_col)
B::vec_znx_sub_negate_inplace_impl(self, res, res_col, a, a_col)
}
}
@@ -227,7 +236,7 @@ where
{
fn vec_znx_lsh<R, A>(
&self,
basek: usize,
base2k: usize,
k: usize,
res: &mut R,
res_col: usize,
@@ -238,7 +247,7 @@ where
R: VecZnxToMut,
A: VecZnxToRef,
{
B::vec_znx_lsh_inplace_impl(self, basek, k, res, res_col, a, a_col, scratch);
B::vec_znx_lsh_impl(self, base2k, k, res, res_col, a, a_col, scratch);
}
}
@@ -248,7 +257,7 @@ where
{
fn vec_znx_rsh<R, A>(
&self,
basek: usize,
base2k: usize,
k: usize,
res: &mut R,
res_col: usize,
@@ -259,7 +268,7 @@ where
R: VecZnxToMut,
A: VecZnxToRef,
{
B::vec_znx_rsh_inplace_impl(self, basek, k, res, res_col, a, a_col, scratch);
B::vec_znx_rsh_impl(self, base2k, k, res, res_col, a, a_col, scratch);
}
}
@@ -267,11 +276,11 @@ impl<B> VecZnxLshInplace<B> for Module<B>
where
B: Backend + VecZnxLshInplaceImpl<B>,
{
fn vec_znx_lsh_inplace<A>(&self, basek: usize, k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
fn vec_znx_lsh_inplace<A>(&self, base2k: usize, k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
where
A: VecZnxToMut,
{
B::vec_znx_lsh_inplace_impl(self, basek, k, a, a_col, scratch)
B::vec_znx_lsh_inplace_impl(self, base2k, k, a, a_col, scratch)
}
}
@@ -279,11 +288,11 @@ impl<B> VecZnxRshInplace<B> for Module<B>
where
B: Backend + VecZnxRshInplaceImpl<B>,
{
fn vec_znx_rsh_inplace<A>(&self, basek: usize, k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
fn vec_znx_rsh_inplace<A>(&self, base2k: usize, k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
where
A: VecZnxToMut,
{
B::vec_znx_rsh_inplace_impl(self, basek, k, a, a_col, scratch)
B::vec_znx_rsh_inplace_impl(self, base2k, k, a, a_col, scratch)
}
}
@@ -463,11 +472,11 @@ impl<B> VecZnxFillUniform for Module<B>
where
B: Backend + VecZnxFillUniformImpl<B>,
{
fn vec_znx_fill_uniform<R>(&self, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
fn vec_znx_fill_uniform<R>(&self, base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: VecZnxToMut,
{
B::vec_znx_fill_uniform_impl(self, basek, res, res_col, source);
B::vec_znx_fill_uniform_impl(self, base2k, res, res_col, source);
}
}
@@ -477,7 +486,7 @@ where
{
fn vec_znx_fill_normal<R>(
&self,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -487,7 +496,7 @@ where
) where
R: VecZnxToMut,
{
B::vec_znx_fill_normal_impl(self, basek, res, res_col, k, source, sigma, bound);
B::vec_znx_fill_normal_impl(self, base2k, res, res_col, k, source, sigma, bound);
}
}
@@ -497,7 +506,7 @@ where
{
fn vec_znx_add_normal<R>(
&self,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -507,6 +516,6 @@ where
) where
R: VecZnxToMut,
{
B::vec_znx_add_normal_impl(self, basek, res, res_col, k, source, sigma, bound);
B::vec_znx_add_normal_impl(self, base2k, res, res_col, k, source, sigma, bound);
}
}

View File

@@ -3,17 +3,17 @@ use crate::{
VecZnxBigAdd, VecZnxBigAddInplace, VecZnxBigAddNormal, VecZnxBigAddSmall, VecZnxBigAddSmallInplace, VecZnxBigAlloc,
VecZnxBigAllocBytes, VecZnxBigAutomorphism, VecZnxBigAutomorphismInplace, VecZnxBigAutomorphismInplaceTmpBytes,
VecZnxBigFromBytes, VecZnxBigFromSmall, VecZnxBigNegate, VecZnxBigNegateInplace, VecZnxBigNormalize,
VecZnxBigNormalizeTmpBytes, VecZnxBigSub, VecZnxBigSubABInplace, VecZnxBigSubBAInplace, VecZnxBigSubSmallA,
VecZnxBigSubSmallAInplace, VecZnxBigSubSmallB, VecZnxBigSubSmallBInplace,
VecZnxBigNormalizeTmpBytes, VecZnxBigSub, VecZnxBigSubInplace, VecZnxBigSubNegateInplace, VecZnxBigSubSmallA,
VecZnxBigSubSmallB, VecZnxBigSubSmallInplace, VecZnxBigSubSmallNegateInplace,
},
layouts::{Backend, Module, Scratch, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef},
oep::{
VecZnxBigAddImpl, VecZnxBigAddInplaceImpl, VecZnxBigAddNormalImpl, VecZnxBigAddSmallImpl, VecZnxBigAddSmallInplaceImpl,
VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl, VecZnxBigAutomorphismInplaceImpl,
VecZnxBigAutomorphismInplaceTmpBytesImpl, VecZnxBigFromBytesImpl, VecZnxBigFromSmallImpl, VecZnxBigNegateImpl,
VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl, VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubABInplaceImpl,
VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl, VecZnxBigSubSmallAImpl, VecZnxBigSubSmallAInplaceImpl,
VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl,
VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl, VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubImpl,
VecZnxBigSubInplaceImpl, VecZnxBigSubNegateInplaceImpl, VecZnxBigSubSmallAImpl, VecZnxBigSubSmallBImpl,
VecZnxBigSubSmallInplaceImpl, VecZnxBigSubSmallNegateInplaceImpl,
},
source::Source,
};
@@ -64,7 +64,7 @@ where
{
fn vec_znx_big_add_normal<R: VecZnxBigToMut<B>>(
&self,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -72,7 +72,7 @@ where
sigma: f64,
bound: f64,
) {
B::add_normal_impl(self, basek, res, res_col, k, source, sigma, bound);
B::add_normal_impl(self, base2k, res, res_col, k, source, sigma, bound);
}
}
@@ -144,29 +144,29 @@ where
}
}
impl<B> VecZnxBigSubABInplace<B> for Module<B>
impl<B> VecZnxBigSubInplace<B> for Module<B>
where
B: Backend + VecZnxBigSubABInplaceImpl<B>,
B: Backend + VecZnxBigSubInplaceImpl<B>,
{
fn vec_znx_big_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_big_sub_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxBigToRef<B>,
{
B::vec_znx_big_sub_ab_inplace_impl(self, res, res_col, a, a_col);
B::vec_znx_big_sub_inplace_impl(self, res, res_col, a, a_col);
}
}
impl<B> VecZnxBigSubBAInplace<B> for Module<B>
impl<B> VecZnxBigSubNegateInplace<B> for Module<B>
where
B: Backend + VecZnxBigSubBAInplaceImpl<B>,
B: Backend + VecZnxBigSubNegateInplaceImpl<B>,
{
fn vec_znx_big_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_big_sub_negate_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxBigToRef<B>,
{
B::vec_znx_big_sub_ba_inplace_impl(self, res, res_col, a, a_col);
B::vec_znx_big_sub_negate_inplace_impl(self, res, res_col, a, a_col);
}
}
@@ -184,16 +184,16 @@ where
}
}
impl<B> VecZnxBigSubSmallAInplace<B> for Module<B>
impl<B> VecZnxBigSubSmallInplace<B> for Module<B>
where
B: Backend + VecZnxBigSubSmallAInplaceImpl<B>,
B: Backend + VecZnxBigSubSmallInplaceImpl<B>,
{
fn vec_znx_big_sub_small_a_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_big_sub_small_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxToRef,
{
B::vec_znx_big_sub_small_a_inplace_impl(self, res, res_col, a, a_col);
B::vec_znx_big_sub_small_inplace_impl(self, res, res_col, a, a_col);
}
}
@@ -211,16 +211,16 @@ where
}
}
impl<B> VecZnxBigSubSmallBInplace<B> for Module<B>
impl<B> VecZnxBigSubSmallNegateInplace<B> for Module<B>
where
B: Backend + VecZnxBigSubSmallBInplaceImpl<B>,
B: Backend + VecZnxBigSubSmallNegateInplaceImpl<B>,
{
fn vec_znx_big_sub_small_b_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_big_sub_small_negate_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxToRef,
{
B::vec_znx_big_sub_small_b_inplace_impl(self, res, res_col, a, a_col);
B::vec_znx_big_sub_small_negate_inplace_impl(self, res, res_col, a, a_col);
}
}
@@ -264,9 +264,10 @@ where
{
fn vec_znx_big_normalize<R, A>(
&self,
basek: usize,
res_basek: usize,
res: &mut R,
res_col: usize,
a_basek: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<B>,
@@ -274,7 +275,7 @@ where
R: VecZnxToMut,
A: VecZnxBigToRef<B>,
{
B::vec_znx_big_normalize_impl(self, basek, res, res_col, a, a_col, scratch);
B::vec_znx_big_normalize_impl(self, res_basek, res, res_col, a_basek, a, a_col, scratch);
}
}

View File

@@ -1,7 +1,7 @@
use crate::{
api::{
VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy,
VecZnxDftFromBytes, VecZnxDftSub, VecZnxDftSubABInplace, VecZnxDftSubBAInplace, VecZnxDftZero, VecZnxIdftApply,
VecZnxDftFromBytes, VecZnxDftSub, VecZnxDftSubInplace, VecZnxDftSubNegateInplace, VecZnxDftZero, VecZnxIdftApply,
VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxIdftApplyTmpBytes,
},
layouts::{
@@ -10,7 +10,7 @@ use crate::{
},
oep::{
VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftApplyImpl,
VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl, VecZnxDftSubImpl,
VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubImpl, VecZnxDftSubInplaceImpl, VecZnxDftSubNegateInplaceImpl,
VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl, VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl,
},
};
@@ -143,29 +143,29 @@ where
}
}
impl<B> VecZnxDftSubABInplace<B> for Module<B>
impl<B> VecZnxDftSubInplace<B> for Module<B>
where
B: Backend + VecZnxDftSubABInplaceImpl<B>,
B: Backend + VecZnxDftSubInplaceImpl<B>,
{
fn vec_znx_dft_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_dft_sub_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>,
{
B::vec_znx_dft_sub_ab_inplace_impl(self, res, res_col, a, a_col);
B::vec_znx_dft_sub_inplace_impl(self, res, res_col, a, a_col);
}
}
impl<B> VecZnxDftSubBAInplace<B> for Module<B>
impl<B> VecZnxDftSubNegateInplace<B> for Module<B>
where
B: Backend + VecZnxDftSubBAInplaceImpl<B>,
B: Backend + VecZnxDftSubNegateInplaceImpl<B>,
{
fn vec_znx_dft_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_dft_sub_negate_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>,
{
B::vec_znx_dft_sub_ba_inplace_impl(self, res, res_col, a, a_col);
B::vec_znx_dft_sub_negate_inplace_impl(self, res, res_col, a, a_col);
}
}

View File

@@ -18,11 +18,11 @@ impl<B> ZnNormalizeInplace<B> for Module<B>
where
B: Backend + ZnNormalizeInplaceImpl<B>,
{
fn zn_normalize_inplace<A>(&self, n: usize, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
fn zn_normalize_inplace<A>(&self, n: usize, base2k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
where
A: ZnToMut,
{
B::zn_normalize_inplace_impl(n, basek, a, a_col, scratch)
B::zn_normalize_inplace_impl(n, base2k, a, a_col, scratch)
}
}
@@ -30,11 +30,11 @@ impl<B> ZnFillUniform for Module<B>
where
B: Backend + ZnFillUniformImpl<B>,
{
fn zn_fill_uniform<R>(&self, n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
fn zn_fill_uniform<R>(&self, n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: ZnToMut,
{
B::zn_fill_uniform_impl(n, basek, res, res_col, source);
B::zn_fill_uniform_impl(n, base2k, res, res_col, source);
}
}
@@ -45,7 +45,7 @@ where
fn zn_fill_normal<R>(
&self,
n: usize,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -55,7 +55,7 @@ where
) where
R: ZnToMut,
{
B::zn_fill_normal_impl(n, basek, res, res_col, k, source, sigma, bound);
B::zn_fill_normal_impl(n, base2k, res, res_col, k, source, sigma, bound);
}
}
@@ -66,7 +66,7 @@ where
fn zn_add_normal<R>(
&self,
n: usize,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -76,6 +76,6 @@ where
) where
R: ZnToMut,
{
B::zn_add_normal_impl(n, basek, res, res_col, k, source, sigma, bound);
B::zn_add_normal_impl(n, base2k, res, res_col, k, source, sigma, bound);
}
}

View File

@@ -3,65 +3,108 @@ use rug::{Assign, Float};
use crate::{
layouts::{DataMut, DataRef, VecZnx, VecZnxToMut, VecZnxToRef, Zn, ZnToMut, ZnToRef, ZnxInfos, ZnxView, ZnxViewMut},
reference::znx::znx_zero_ref,
reference::znx::{
ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStepInplace, ZnxRef, ZnxZero,
get_carry_i128, get_digit_i128, znx_zero_ref,
},
};
impl<D: DataMut> VecZnx<D> {
pub fn encode_vec_i64(&mut self, basek: usize, col: usize, k: usize, data: &[i64], log_max: usize) {
let size: usize = k.div_ceil(basek);
pub fn encode_vec_i64(&mut self, base2k: usize, col: usize, k: usize, data: &[i64]) {
let size: usize = k.div_ceil(base2k);
#[cfg(debug_assertions)]
{
let a: VecZnx<&mut [u8]> = self.to_mut();
assert!(
size <= a.size(),
"invalid argument k.div_ceil(basek)={} > a.size()={}",
"invalid argument k.div_ceil(base2k)={} > a.size()={}",
size,
a.size()
);
assert!(col < a.cols());
assert!(data.len() <= a.n())
assert!(data.len() == a.n())
}
let data_len: usize = data.len();
let mut a: VecZnx<&mut [u8]> = self.to_mut();
let k_rem: usize = basek - (k % basek);
let a_size: usize = a.size();
// Zeroes the limbs of the selected column
(0..a.size()).for_each(|i| {
for i in 0..a_size {
znx_zero_ref(a.at_mut(col, i));
});
// If 2^{basek} * 2^{k_rem} < 2^{63}-1, then we can simply copy
// values on the last limb.
// Else we decompose values base2k.
if log_max + k_rem < 63 || k_rem == basek {
a.at_mut(col, size - 1)[..data_len].copy_from_slice(&data[..data_len]);
} else {
let mask: i64 = (1 << basek) - 1;
let steps: usize = size.min(log_max.div_ceil(basek));
(size - steps..size)
.rev()
.enumerate()
.for_each(|(i, i_rev)| {
let shift: usize = i * basek;
izip!(a.at_mut(col, i_rev).iter_mut(), data.iter()).for_each(|(y, x)| *y = (x >> shift) & mask);
})
}
// Case where self.prec % self.k != 0.
if k_rem != basek {
let steps: usize = size.min(log_max.div_ceil(basek));
(size - steps..size).rev().for_each(|i| {
a.at_mut(col, i)[..data_len]
.iter_mut()
.for_each(|x| *x <<= k_rem);
})
// Copies the data onto the last (least significant) used limb
a.at_mut(col, size - 1).copy_from_slice(data);
let mut carry: Vec<i64> = vec![0i64; a.n()];
let k_rem: usize = (base2k - (k % base2k)) % base2k;
// Normalizes and shifts if necessary.
for j in (0..size).rev() {
if j == size - 1 {
ZnxRef::znx_normalize_first_step_inplace(base2k, k_rem, a.at_mut(col, j), &mut carry);
} else if j == 0 {
ZnxRef::znx_normalize_final_step_inplace(base2k, k_rem, a.at_mut(col, j), &mut carry);
} else {
ZnxRef::znx_normalize_middle_step_inplace(base2k, k_rem, a.at_mut(col, j), &mut carry);
}
}
}
pub fn encode_coeff_i64(&mut self, basek: usize, col: usize, k: usize, idx: usize, data: i64, log_max: usize) {
let size: usize = k.div_ceil(basek);
pub fn encode_vec_i128(&mut self, base2k: usize, col: usize, k: usize, data: &[i128]) {
let size: usize = k.div_ceil(base2k);
#[cfg(debug_assertions)]
{
let a: VecZnx<&mut [u8]> = self.to_mut();
assert!(
size <= a.size(),
"invalid argument k.div_ceil(base2k)={} > a.size()={}",
size,
a.size()
);
assert!(col < a.cols());
assert!(data.len() == a.n())
}
let mut a: VecZnx<&mut [u8]> = self.to_mut();
let a_size: usize = a.size();
{
let mut carry_i128: Vec<i128> = vec![0i128; a.n()];
carry_i128.copy_from_slice(data);
for j in (0..size).rev() {
for (x, a) in izip!(a.at_mut(col, j).iter_mut(), carry_i128.iter_mut()) {
let digit: i128 = get_digit_i128(base2k, *a);
let carry: i128 = get_carry_i128(base2k, *a, digit);
*x = digit as i64;
*a = carry;
}
}
}
for j in size..a_size {
ZnxRef::znx_zero(a.at_mut(col, j));
}
let mut carry: Vec<i64> = vec![0i64; a.n()];
let k_rem: usize = (base2k - (k % base2k)) % base2k;
for j in (0..size).rev() {
if j == a_size - 1 {
ZnxRef::znx_normalize_first_step_inplace(base2k, k_rem, a.at_mut(col, j), &mut carry);
} else if j == 0 {
ZnxRef::znx_normalize_final_step_inplace(base2k, k_rem, a.at_mut(col, j), &mut carry);
} else {
ZnxRef::znx_normalize_middle_step_inplace(base2k, k_rem, a.at_mut(col, j), &mut carry);
}
}
}
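
A standalone sketch of the digit/carry extraction loop used by `encode_vec_i128` above; the library's `get_digit_i128` / `get_carry_i128` are assumed here to produce centered base-2^k digits with an exact carry, as the helpers below do.

```rust
fn get_digit(base2k: usize, x: i128) -> i128 {
    let half = 1i128 << (base2k - 1);
    let full = 1i128 << base2k;
    let mut d = x % full;
    if d >= half {
        d -= full;
    } else if d < -half {
        d += full;
    }
    d
}

fn get_carry(base2k: usize, x: i128, digit: i128) -> i128 {
    (x - digit) >> base2k // exact: x - digit is a multiple of 2^base2k
}

fn main() {
    let base2k = 17;
    let mut x: i128 = -123_456_789_012_345;
    let mut digits = Vec::new();
    while x != 0 {
        let d = get_digit(base2k, x);
        x = get_carry(base2k, x, d);
        digits.push(d as i64); // stored as i64 limbs, as in the VecZnx layout
    }
    // Recompose (the most significant digit was produced last).
    let back = digits.iter().rev().fold(0i128, |acc, &d| (acc << base2k) + d as i128);
    assert_eq!(back, -123_456_789_012_345);
}
```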
pub fn encode_coeff_i64(&mut self, base2k: usize, col: usize, k: usize, idx: usize, data: i64) {
let size: usize = k.div_ceil(base2k);
#[cfg(debug_assertions)]
{
@@ -69,46 +112,42 @@ impl<D: DataMut> VecZnx<D> {
assert!(idx < a.n());
assert!(
size <= a.size(),
"invalid argument k.div_ceil(basek)={} > a.size()={}",
"invalid argument k.div_ceil(base2k)={} > a.size()={}",
size,
a.size()
);
assert!(col < a.cols());
}
let k_rem: usize = basek - (k % basek);
let mut a: VecZnx<&mut [u8]> = self.to_mut();
(0..a.size()).for_each(|j| a.at_mut(col, j)[idx] = 0);
let a_size = a.size();
// If 2^{basek} * 2^{k_rem} < 2^{63}-1, then we can simply copy
// values on the last limb.
// Else we decompose values base2k.
if log_max + k_rem < 63 || k_rem == basek {
a.at_mut(col, size - 1)[idx] = data;
} else {
let mask: i64 = (1 << basek) - 1;
let steps: usize = size.min(log_max.div_ceil(basek));
(size - steps..size)
.rev()
.enumerate()
.for_each(|(j, j_rev)| {
a.at_mut(col, j_rev)[idx] = (data >> (j * basek)) & mask;
})
for j in 0..a_size {
a.at_mut(col, j)[idx] = 0
}
// Case where prec % k != 0.
if k_rem != basek {
let steps: usize = size.min(log_max.div_ceil(basek));
(size - steps..size).rev().for_each(|j| {
a.at_mut(col, j)[idx] <<= k_rem;
})
a.at_mut(col, size - 1)[idx] = data;
let mut carry: Vec<i64> = vec![0i64; 1];
let k_rem: usize = (base2k - (k % base2k)) % base2k;
for j in (0..size).rev() {
let slice = &mut a.at_mut(col, j)[idx..idx + 1];
if j == size - 1 {
ZnxRef::znx_normalize_first_step_inplace(base2k, k_rem, slice, &mut carry);
} else if j == 0 {
ZnxRef::znx_normalize_final_step_inplace(base2k, k_rem, slice, &mut carry);
} else {
ZnxRef::znx_normalize_middle_step_inplace(base2k, k_rem, slice, &mut carry);
}
}
}
}
impl<D: DataRef> VecZnx<D> {
pub fn decode_vec_i64(&self, basek: usize, col: usize, k: usize, data: &mut [i64]) {
let size: usize = k.div_ceil(basek);
pub fn decode_vec_i64(&self, base2k: usize, col: usize, k: usize, data: &mut [i64]) {
let size: usize = k.div_ceil(base2k);
#[cfg(debug_assertions)]
{
let a: VecZnx<&[u8]> = self.to_ref();
@@ -123,26 +162,26 @@ impl<D: DataRef> VecZnx<D> {
let a: VecZnx<&[u8]> = self.to_ref();
data.copy_from_slice(a.at(col, 0));
let rem: usize = basek - (k % basek);
if k < basek {
let rem: usize = base2k - (k % base2k);
if k < base2k {
data.iter_mut().for_each(|x| *x >>= rem);
} else {
(1..size).for_each(|i| {
if i == size - 1 && rem != basek {
let k_rem: usize = basek - rem;
if i == size - 1 && rem != base2k {
let k_rem: usize = (base2k - rem) % base2k;
izip!(a.at(col, i).iter(), data.iter_mut()).for_each(|(x, y)| {
*y = (*y << k_rem) + (x >> rem);
});
} else {
izip!(a.at(col, i).iter(), data.iter_mut()).for_each(|(x, y)| {
*y = (*y << basek) + x;
*y = (*y << base2k) + x;
});
}
})
}
}
pub fn decode_coeff_i64(&self, basek: usize, col: usize, k: usize, idx: usize) -> i64 {
pub fn decode_coeff_i64(&self, base2k: usize, col: usize, k: usize, idx: usize) -> i64 {
#[cfg(debug_assertions)]
{
let a: VecZnx<&[u8]> = self.to_ref();
@@ -151,22 +190,22 @@ impl<D: DataRef> VecZnx<D> {
}
let a: VecZnx<&[u8]> = self.to_ref();
let size: usize = k.div_ceil(basek);
let size: usize = k.div_ceil(base2k);
let mut res: i64 = 0;
let rem: usize = basek - (k % basek);
let rem: usize = base2k - (k % base2k);
(0..size).for_each(|j| {
let x: i64 = a.at(col, j)[idx];
if j == size - 1 && rem != basek {
let k_rem: usize = basek - rem;
if j == size - 1 && rem != base2k {
let k_rem: usize = (base2k - rem) % base2k;
res = (res << k_rem) + (x >> rem);
} else {
res = (res << basek) + x;
res = (res << base2k) + x;
}
});
res
}
pub fn decode_vec_float(&self, basek: usize, col: usize, data: &mut [Float]) {
pub fn decode_vec_float(&self, base2k: usize, col: usize, data: &mut [Float]) {
#[cfg(debug_assertions)]
{
let a: VecZnx<&[u8]> = self.to_ref();
@@ -181,12 +220,12 @@ impl<D: DataRef> VecZnx<D> {
let a: VecZnx<&[u8]> = self.to_ref();
let size: usize = a.size();
let prec: u32 = (basek * size) as u32;
let prec: u32 = (base2k * size) as u32;
// 2^{basek}
let base = Float::with_val(prec, (1u64 << basek) as f64);
// 2^{base2k}
let base: Float = Float::with_val(prec, (1u64 << base2k) as f64);
// y[i] = sum x[j][i] * 2^{-basek*j}
// y[i] = sum x[j][i] * 2^{-base2k*j}
(0..size).for_each(|i| {
if i == 0 {
izip!(a.at(col, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
@@ -204,78 +243,74 @@ impl<D: DataRef> VecZnx<D> {
}
impl<D: DataMut> Zn<D> {
pub fn encode_i64(&mut self, basek: usize, k: usize, data: i64, log_max: usize) {
let size: usize = k.div_ceil(basek);
pub fn encode_i64(&mut self, base2k: usize, k: usize, data: i64) {
let size: usize = k.div_ceil(base2k);
#[cfg(debug_assertions)]
{
let a: Zn<&mut [u8]> = self.to_mut();
assert!(
size <= a.size(),
"invalid argument k.div_ceil(basek)={} > a.size()={}",
"invalid argument k.div_ceil(base2k)={} > a.size()={}",
size,
a.size()
);
}
let k_rem: usize = basek - (k % basek);
let mut a: Zn<&mut [u8]> = self.to_mut();
(0..a.size()).for_each(|j| a.at_mut(0, j)[0] = 0);
let a_size = a.size();
// If 2^{basek} * 2^{k_rem} < 2^{63}-1, then we can simply copy
// values on the last limb.
// Else we decompose values base2k.
if log_max + k_rem < 63 || k_rem == basek {
a.at_mut(0, size - 1)[0] = data;
} else {
let mask: i64 = (1 << basek) - 1;
let steps: usize = size.min(log_max.div_ceil(basek));
(size - steps..size)
.rev()
.enumerate()
.for_each(|(j, j_rev)| {
a.at_mut(0, j_rev)[0] = (data >> (j * basek)) & mask;
})
for j in 0..a_size {
a.at_mut(0, j)[0] = 0
}
// Case where prec % k != 0.
if k_rem != basek {
let steps: usize = size.min(log_max.div_ceil(basek));
(size - steps..size).rev().for_each(|j| {
a.at_mut(0, j)[0] <<= k_rem;
})
a.at_mut(0, size - 1)[0] = data;
let mut carry: Vec<i64> = vec![0i64; 1];
let k_rem: usize = (base2k - (k % base2k)) % base2k;
for j in (0..size).rev() {
let slice = &mut a.at_mut(0, j)[..1];
if j == size - 1 {
ZnxRef::znx_normalize_first_step_inplace(base2k, k_rem, slice, &mut carry);
} else if j == 0 {
ZnxRef::znx_normalize_final_step_inplace(base2k, k_rem, slice, &mut carry);
} else {
ZnxRef::znx_normalize_middle_step_inplace(base2k, k_rem, slice, &mut carry);
}
}
}
}
impl<D: DataRef> Zn<D> {
pub fn decode_i64(&self, basek: usize, k: usize) -> i64 {
pub fn decode_i64(&self, base2k: usize, k: usize) -> i64 {
let a: Zn<&[u8]> = self.to_ref();
let size: usize = k.div_ceil(basek);
let size: usize = k.div_ceil(base2k);
let mut res: i64 = 0;
let rem: usize = basek - (k % basek);
let rem: usize = base2k - (k % base2k);
(0..size).for_each(|j| {
let x: i64 = a.at(0, j)[0];
if j == size - 1 && rem != basek {
let k_rem: usize = basek - rem;
if j == size - 1 && rem != base2k {
let k_rem: usize = (base2k - rem) % base2k;
res = (res << k_rem) + (x >> rem);
} else {
res = (res << basek) + x;
res = (res << base2k) + x;
}
});
res
}
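A minimal standalone sketch (not crate code) of the recomposition `decode_i64` performs: limbs are most-significant first, and the last limb contributes only its top base2k − rem bits when k is not a multiple of base2k.

```rust
fn recompose(limbs: &[i64], base2k: usize, k: usize) -> i64 {
    let size = k.div_ceil(base2k);
    let rem = base2k - (k % base2k);
    let mut res = 0i64;
    for (j, &x) in limbs.iter().take(size).enumerate() {
        if j == size - 1 && rem != base2k {
            // Last limb: keep only its top (base2k - rem) useful bits.
            res = (res << ((base2k - rem) % base2k)) + (x >> rem);
        } else {
            res = (res << base2k) + x;
        }
    }
    res
}

fn main() {
    // base2k = 8, k = 12: two limbs, the second holding only 4 useful bits.
    assert_eq!(recompose(&[0x0A, 0xB0], 8, 12), 0xAB);
}
```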
pub fn decode_float(&self, basek: usize) -> Float {
pub fn decode_float(&self, base2k: usize) -> Float {
let a: Zn<&[u8]> = self.to_ref();
let size: usize = a.size();
let prec: u32 = (basek * size) as u32;
let prec: u32 = (base2k * size) as u32;
// 2^{basek}
let base: Float = Float::with_val(prec, (1 << basek) as f64);
let mut res: Float = Float::with_val(prec, (1 << basek) as f64);
// 2^{base2k}
let base: Float = Float::with_val(prec, (1 << base2k) as f64);
let mut res: Float = Float::with_val(prec, (1 << base2k) as f64);
// y[i] = sum x[j][i] * 2^{-basek*j}
// y[i] = sum x[j][i] * 2^{-base2k*j}
(0..size).for_each(|i| {
if i == 0 {
res.assign(a.at(0, size - i - 1)[0]);


@@ -1,7 +1,7 @@
use crate::{
alloc_aligned,
layouts::{
Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, FillUniform, ReaderFrom, Reset, ToOwnedDeep, VecZnx, WriterTo,
Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, FillUniform, ReaderFrom, ToOwnedDeep, VecZnx, WriterTo,
ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
},
source::Source,
@@ -54,7 +54,7 @@ impl<D: DataRef> ToOwnedDeep for MatZnx<D> {
impl<D: DataRef> fmt::Debug for MatZnx<D> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self)
write!(f, "{self}")
}
}
@@ -211,17 +211,6 @@ impl<D: DataMut> FillUniform for MatZnx<D> {
}
}
impl<D: DataMut> Reset for MatZnx<D> {
fn reset(&mut self) {
self.zero();
self.n = 0;
self.size = 0;
self.rows = 0;
self.cols_in = 0;
self.cols_out = 0;
}
}
pub type MatZnxOwned = MatZnx<Vec<u8>>;
pub type MatZnxMut<'a> = MatZnx<&'a mut [u8]>;
pub type MatZnxRef<'a> = MatZnx<&'a [u8]>;
@@ -316,9 +305,9 @@ impl<D: DataRef> fmt::Display for MatZnx<D> {
)?;
for row_i in 0..self.rows {
writeln!(f, "Row {}:", row_i)?;
writeln!(f, "Row {row_i}:")?;
for col_i in 0..self.cols_in {
writeln!(f, "cols_in {}:", col_i)?;
writeln!(f, "cols_in {col_i}:")?;
writeln!(f, "{}:", self.at(row_i, col_i))?;
}
}


@@ -26,7 +26,7 @@ pub use vmp_pmat::*;
pub use zn::*;
pub use znx_base::*;
pub trait Data = PartialEq + Eq + Sized;
pub trait Data = PartialEq + Eq + Sized + Default;
pub trait DataRef = Data + AsRef<[u8]>;
pub trait DataMut = DataRef + AsMut<[u8]>;


@@ -7,7 +7,7 @@ use rand_distr::{Distribution, weighted::WeightedIndex};
use crate::{
alloc_aligned,
layouts::{
Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, FillUniform, ReaderFrom, Reset, ToOwnedDeep, VecZnx, WriterTo,
Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, FillUniform, ReaderFrom, ToOwnedDeep, VecZnx, WriterTo,
ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
},
source::Source,
@@ -173,14 +173,6 @@ impl<D: DataMut> FillUniform for ScalarZnx<D> {
}
}
impl<D: DataMut> Reset for ScalarZnx<D> {
fn reset(&mut self) {
self.zero();
self.n = 0;
self.cols = 0;
}
}
pub type ScalarZnxOwned = ScalarZnx<Vec<u8>>;
impl<D: Data> ScalarZnx<D> {


@@ -7,10 +7,10 @@ use rug::{
use crate::layouts::{Backend, DataRef, VecZnx, VecZnxBig, VecZnxBigToRef, ZnxInfos};
impl<D: DataRef> VecZnx<D> {
pub fn std(&self, basek: usize, col: usize) -> f64 {
let prec: u32 = (self.size() * basek) as u32;
pub fn std(&self, base2k: usize, col: usize) -> f64 {
let prec: u32 = (self.size() * base2k) as u32;
let mut data: Vec<Float> = (0..self.n()).map(|_| Float::with_val(prec, 0)).collect();
self.decode_vec_float(basek, col, &mut data);
self.decode_vec_float(base2k, col, &mut data);
// std = sqrt(sum((xi - avg)^2) / n)
let mut avg: Float = Float::with_val(prec, 0);
data.iter().for_each(|x| {
@@ -29,7 +29,7 @@ impl<D: DataRef> VecZnx<D> {
}
impl<D: DataRef, B: Backend + Backend<ScalarBig = i64>> VecZnxBig<D, B> {
pub fn std(&self, basek: usize, col: usize) -> f64 {
pub fn std(&self, base2k: usize, col: usize) -> f64 {
let self_ref: VecZnxBig<&[u8], B> = self.to_ref();
let znx: VecZnx<&[u8]> = VecZnx {
data: self_ref.data,
@@ -38,6 +38,6 @@ impl<D: DataRef, B: Backend + Backend<ScalarBig = i64>> VecZnxBig<D, B> {
size: self_ref.size,
max_size: self_ref.max_size,
};
znx.std(basek, col)
znx.std(base2k, col)
}
}
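The `Float`-based loop above computes the population standard deviation of the decoded coefficients; a plain-`f64` equivalent, for reference:

```rust
fn std_dev(xs: &[f64]) -> f64 {
    let n = xs.len() as f64;
    let avg: f64 = xs.iter().sum::<f64>() / n;
    // std = sqrt(sum((xi - avg)^2) / n)
    (xs.iter().map(|x| (x - avg) * (x - avg)).sum::<f64>() / n).sqrt()
}
```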


@@ -176,7 +176,7 @@ impl<D: DataRef, B: Backend> fmt::Display for SvpPPol<D, B> {
writeln!(f, "SvpPPol(n={}, cols={})", self.n, self.cols)?;
for col in 0..self.cols {
writeln!(f, "Column {}:", col)?;
writeln!(f, "Column {col}:")?;
let coeffs = self.at(col, 0);
write!(f, "[")?;
@@ -187,7 +187,7 @@ impl<D: DataRef, B: Backend> fmt::Display for SvpPPol<D, B> {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{}", coeff)?;
write!(f, "{coeff}")?;
}
if coeffs.len() > max_show {


@@ -6,8 +6,8 @@ use std::{
use crate::{
alloc_aligned,
layouts::{
Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, FillUniform, ReaderFrom, Reset, ToOwnedDeep, WriterTo,
ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, FillUniform, ReaderFrom, ToOwnedDeep, WriterTo, ZnxInfos,
ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
},
source::Source,
};
@@ -25,6 +25,18 @@ pub struct VecZnx<D: Data> {
pub max_size: usize,
}
impl<D: Data + Default> Default for VecZnx<D> {
fn default() -> Self {
Self {
data: D::default(),
n: 0,
cols: 0,
size: 0,
max_size: 0,
}
}
}
impl<D: DataRef> DigestU64 for VecZnx<D> {
fn digest_u64(&self) -> u64 {
let mut h: DefaultHasher = DefaultHasher::new();
@@ -52,7 +64,7 @@ impl<D: DataRef> ToOwnedDeep for VecZnx<D> {
impl<D: DataRef> fmt::Debug for VecZnx<D> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self)
write!(f, "{self}")
}
}
@@ -162,10 +174,10 @@ impl<D: DataRef> fmt::Display for VecZnx<D> {
)?;
for col in 0..self.cols {
writeln!(f, "Column {}:", col)?;
writeln!(f, "Column {col}:")?;
for size in 0..self.size {
let coeffs = self.at(col, size);
write!(f, " Size {}: [", size)?;
write!(f, " Size {size}: [")?;
let max_show = 100;
let show_count = coeffs.len().min(max_show);
@@ -174,7 +186,7 @@ impl<D: DataRef> fmt::Display for VecZnx<D> {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{}", coeff)?;
write!(f, "{coeff}")?;
}
if coeffs.len() > max_show {
@@ -204,16 +216,6 @@ impl<D: DataMut> FillUniform for VecZnx<D> {
}
}
impl<D: DataMut> Reset for VecZnx<D> {
fn reset(&mut self) {
self.zero();
self.n = 0;
self.cols = 0;
self.size = 0;
self.max_size = 0;
}
}
pub type VecZnxOwned = VecZnx<Vec<u8>>;
pub type VecZnxMut<'a> = VecZnx<&'a mut [u8]>;
pub type VecZnxRef<'a> = VecZnx<&'a [u8]>;


@@ -179,10 +179,10 @@ impl<D: DataRef, B: Backend> fmt::Display for VecZnxBig<D, B> {
)?;
for col in 0..self.cols {
writeln!(f, "Column {}:", col)?;
writeln!(f, "Column {col}:")?;
for size in 0..self.size {
let coeffs = self.at(col, size);
write!(f, " Size {}: [", size)?;
write!(f, " Size {size}: [")?;
let max_show = 100;
let show_count = coeffs.len().min(max_show);
@@ -191,7 +191,7 @@ impl<D: DataRef, B: Backend> fmt::Display for VecZnxBig<D, B> {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{}", coeff)?;
write!(f, "{coeff}")?;
}
if coeffs.len() > max_show {


@@ -199,10 +199,10 @@ impl<D: DataRef, B: Backend> fmt::Display for VecZnxDft<D, B> {
)?;
for col in 0..self.cols {
writeln!(f, "Column {}:", col)?;
writeln!(f, "Column {col}:")?;
for size in 0..self.size {
let coeffs = self.at(col, size);
write!(f, " Size {}: [", size)?;
write!(f, " Size {size}: [")?;
let max_show = 100;
let show_count = coeffs.len().min(max_show);
@@ -211,7 +211,7 @@ impl<D: DataRef, B: Backend> fmt::Display for VecZnxDft<D, B> {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{}", coeff)?;
write!(f, "{coeff}")?;
}
if coeffs.len() > max_show {


@@ -6,8 +6,8 @@ use std::{
use crate::{
alloc_aligned,
layouts::{
Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, FillUniform, ReaderFrom, Reset, ToOwnedDeep, WriterTo,
ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, FillUniform, ReaderFrom, ToOwnedDeep, WriterTo, ZnxInfos,
ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
},
source::Source,
};
@@ -52,7 +52,7 @@ impl<D: DataRef> ToOwnedDeep for Zn<D> {
impl<D: DataRef> fmt::Debug for Zn<D> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self)
write!(f, "{self}")
}
}
@@ -162,10 +162,10 @@ impl<D: DataRef> fmt::Display for Zn<D> {
)?;
for col in 0..self.cols {
writeln!(f, "Column {}:", col)?;
writeln!(f, "Column {col}:")?;
for size in 0..self.size {
let coeffs = self.at(col, size);
write!(f, " Size {}: [", size)?;
write!(f, " Size {size}: [")?;
let max_show = 100;
let show_count = coeffs.len().min(max_show);
@@ -174,7 +174,7 @@ impl<D: DataRef> fmt::Display for Zn<D> {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{}", coeff)?;
write!(f, "{coeff}")?;
}
if coeffs.len() > max_show {
@@ -204,16 +204,6 @@ impl<D: DataMut> FillUniform for Zn<D> {
}
}
impl<D: DataMut> Reset for Zn<D> {
fn reset(&mut self) {
self.zero();
self.n = 0;
self.cols = 0;
self.size = 0;
self.max_size = 0;
}
}
pub type ZnOwned = Zn<Vec<u8>>;
pub type ZnMut<'a> = Zn<&'a mut [u8]>;
pub type ZnRef<'a> = Zn<&'a [u8]>;


@@ -119,7 +119,3 @@ where
pub trait FillUniform {
fn fill_uniform(&mut self, log_bound: usize, source: &mut Source);
}
pub trait Reset {
fn reset(&mut self);
}


@@ -56,15 +56,12 @@ pub fn cast_mut<T, V>(data: &[T]) -> &mut [V] {
fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec<u8> {
assert!(
align.is_power_of_two(),
"Alignment must be a power of two but is {}",
align
"Alignment must be a power of two but is {align}"
);
assert_eq!(
(size * size_of::<u8>()) % align,
0,
"size={} must be a multiple of align={}",
size,
align
"size={size} must be a multiple of align={align}"
);
unsafe {
let layout: std::alloc::Layout = std::alloc::Layout::from_size_align(size, align).expect("Invalid alignment");
@@ -74,9 +71,7 @@ fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec<u8> {
}
assert!(
is_aligned_custom(ptr, align),
"Memory allocation at {:p} is not aligned to {} bytes",
ptr,
align
"Memory allocation at {ptr:p} is not aligned to {align} bytes"
);
// Init allocated memory to zero
std::ptr::write_bytes(ptr, 0, size);
@@ -89,16 +84,14 @@ fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec<u8> {
pub fn alloc_aligned_custom<T>(size: usize, align: usize) -> Vec<T> {
assert!(
align.is_power_of_two(),
"Alignment must be a power of two but is {}",
align
"Alignment must be a power of two but is {align}"
);
assert_eq!(
(size * size_of::<T>()) % align,
0,
"size*size_of::<T>()={} must be a multiple of align={}",
"size*size_of::<T>()={} must be a multiple of align={align}",
size * size_of::<T>(),
align
);
let mut vec_u8: Vec<u8> = alloc_aligned_custom_u8(size_of::<T>() * size, align);
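A hypothetical usage sketch of `alloc_aligned_custom` under the asserted constraints (alignment a power of two, byte size a multiple of the alignment); it assumes the function is in scope:

```rust
// 512 * size_of::<i64>() = 4096 bytes, a multiple of the 64-byte alignment.
let buf: Vec<i64> = alloc_aligned_custom::<i64>(512, 64);
// The returned buffer is zero-initialized and its base pointer honors the alignment.
assert_eq!(buf.as_ptr() as usize % 64, 0);
```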


@@ -1,8 +1,8 @@
use crate::layouts::{Backend, Module};
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/module.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/module.rs) reference implementation.
/// * See [crate::api::ModuleNew] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait ModuleNewImpl<B: Backend> {
fn new_impl(n: u64) -> Module<B>;


@@ -1,74 +1,72 @@
use crate::layouts::{
Backend, DataRef, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, ZnxInfos,
};
use crate::layouts::{Backend, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat};
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation.
/// * See [crate::api::ScratchOwnedAlloc] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait ScratchOwnedAllocImpl<B: Backend> {
fn scratch_owned_alloc_impl(size: usize) -> ScratchOwned<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation.
/// * See [crate::api::ScratchOwnedBorrow] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait ScratchOwnedBorrowImpl<B: Backend> {
fn scratch_owned_borrow_impl(scratch: &mut ScratchOwned<B>) -> &mut Scratch<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation.
/// * See [crate::api::ScratchFromBytes] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait ScratchFromBytesImpl<B: Backend> {
fn scratch_from_bytes_impl(data: &mut [u8]) -> &mut Scratch<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation.
/// * See [crate::api::ScratchAvailable] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait ScratchAvailableImpl<B: Backend> {
fn scratch_available_impl(scratch: &Scratch<B>) -> usize;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation.
/// * See [crate::api::ScratchOwnedAlloc] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait TakeSliceImpl<B: Backend> {
fn take_slice_impl<T>(scratch: &mut Scratch<B>, len: usize) -> (&mut [T], &mut Scratch<B>);
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation.
/// * See [crate::api::TakeScalarZnx] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait TakeScalarZnxImpl<B: Backend> {
fn take_scalar_znx_impl(scratch: &mut Scratch<B>, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Scratch<B>);
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation.
/// * See [crate::api::TakeSvpPPol] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait TakeSvpPPolImpl<B: Backend> {
fn take_svp_ppol_impl(scratch: &mut Scratch<B>, n: usize, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Scratch<B>);
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation.
/// * See [crate::api::TakeVecZnx] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait TakeVecZnxImpl<B: Backend> {
fn take_vec_znx_impl(scratch: &mut Scratch<B>, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Scratch<B>);
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation.
/// * See [crate::api::TakeVecZnxSlice] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait TakeVecZnxSliceImpl<B: Backend> {
fn take_vec_znx_slice_impl(
@@ -81,8 +79,8 @@ pub unsafe trait TakeVecZnxSliceImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation.
/// * See [crate::api::TakeVecZnxBig] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait TakeVecZnxBigImpl<B: Backend> {
fn take_vec_znx_big_impl(
@@ -94,8 +92,8 @@ pub unsafe trait TakeVecZnxBigImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation.
/// * See [crate::api::TakeVecZnxDft] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait TakeVecZnxDftImpl<B: Backend> {
fn take_vec_znx_dft_impl(
@@ -107,8 +105,8 @@ pub unsafe trait TakeVecZnxDftImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation.
/// * See [crate::api::TakeVecZnxDftSlice] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait TakeVecZnxDftSliceImpl<B: Backend> {
fn take_vec_znx_dft_slice_impl(
@@ -121,8 +119,8 @@ pub unsafe trait TakeVecZnxDftSliceImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation.
/// * See [crate::api::TakeVmpPMat] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait TakeVmpPMatImpl<B: Backend> {
fn take_vmp_pmat_impl(
@@ -136,8 +134,8 @@ pub unsafe trait TakeVmpPMatImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/scratch.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/scratch.rs) reference implementation.
/// * See [crate::api::TakeMatZnx] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait TakeMatZnxImpl<B: Backend> {
fn take_mat_znx_impl(
@@ -149,110 +147,3 @@ pub unsafe trait TakeMatZnxImpl<B: Backend> {
size: usize,
) -> (MatZnx<&mut [u8]>, &mut Scratch<B>);
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub trait TakeLikeImpl<'a, B: Backend, T> {
type Output;
fn take_like_impl(scratch: &'a mut Scratch<B>, template: &T) -> (Self::Output, &'a mut Scratch<B>);
}
impl<'a, B: Backend, D> TakeLikeImpl<'a, B, VmpPMat<D, B>> for B
where
B: TakeVmpPMatImpl<B>,
D: DataRef,
{
type Output = VmpPMat<&'a mut [u8], B>;
fn take_like_impl(scratch: &'a mut Scratch<B>, template: &VmpPMat<D, B>) -> (Self::Output, &'a mut Scratch<B>) {
B::take_vmp_pmat_impl(
scratch,
template.n(),
template.rows(),
template.cols_in(),
template.cols_out(),
template.size(),
)
}
}
impl<'a, B: Backend, D> TakeLikeImpl<'a, B, MatZnx<D>> for B
where
B: TakeMatZnxImpl<B>,
D: DataRef,
{
type Output = MatZnx<&'a mut [u8]>;
fn take_like_impl(scratch: &'a mut Scratch<B>, template: &MatZnx<D>) -> (Self::Output, &'a mut Scratch<B>) {
B::take_mat_znx_impl(
scratch,
template.n(),
template.rows(),
template.cols_in(),
template.cols_out(),
template.size(),
)
}
}
impl<'a, B: Backend, D> TakeLikeImpl<'a, B, VecZnxDft<D, B>> for B
where
B: TakeVecZnxDftImpl<B>,
D: DataRef,
{
type Output = VecZnxDft<&'a mut [u8], B>;
fn take_like_impl(scratch: &'a mut Scratch<B>, template: &VecZnxDft<D, B>) -> (Self::Output, &'a mut Scratch<B>) {
B::take_vec_znx_dft_impl(scratch, template.n(), template.cols(), template.size())
}
}
impl<'a, B: Backend, D> TakeLikeImpl<'a, B, VecZnxBig<D, B>> for B
where
B: TakeVecZnxBigImpl<B>,
D: DataRef,
{
type Output = VecZnxBig<&'a mut [u8], B>;
fn take_like_impl(scratch: &'a mut Scratch<B>, template: &VecZnxBig<D, B>) -> (Self::Output, &'a mut Scratch<B>) {
B::take_vec_znx_big_impl(scratch, template.n(), template.cols(), template.size())
}
}
impl<'a, B: Backend, D> TakeLikeImpl<'a, B, SvpPPol<D, B>> for B
where
B: TakeSvpPPolImpl<B>,
D: DataRef,
{
type Output = SvpPPol<&'a mut [u8], B>;
fn take_like_impl(scratch: &'a mut Scratch<B>, template: &SvpPPol<D, B>) -> (Self::Output, &'a mut Scratch<B>) {
B::take_svp_ppol_impl(scratch, template.n(), template.cols())
}
}
impl<'a, B: Backend, D> TakeLikeImpl<'a, B, VecZnx<D>> for B
where
B: TakeVecZnxImpl<B>,
D: DataRef,
{
type Output = VecZnx<&'a mut [u8]>;
fn take_like_impl(scratch: &'a mut Scratch<B>, template: &VecZnx<D>) -> (Self::Output, &'a mut Scratch<B>) {
B::take_vec_znx_impl(scratch, template.n(), template.cols(), template.size())
}
}
impl<'a, B: Backend, D> TakeLikeImpl<'a, B, ScalarZnx<D>> for B
where
B: TakeScalarZnxImpl<B>,
D: DataRef,
{
type Output = ScalarZnx<&'a mut [u8]>;
fn take_like_impl(scratch: &'a mut Scratch<B>, template: &ScalarZnx<D>) -> (Self::Output, &'a mut Scratch<B>) {
B::take_scalar_znx_impl(scratch, template.n(), template.cols())
}
}
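With the `TakeLikeImpl` blanket impls removed, callers pass the template's dimensions to the remaining `take_*` entry points directly. A hedged sketch of the resulting call shape (method name and receiver follow the `crate::api::TakeVecZnx` reference above and are assumptions):

```rust
// Previously: let (tmp, scratch) = scratch.take_like(&template);
// Now the dimensions come from the template explicitly:
let (tmp, scratch) = scratch.take_vec_znx(template.n(), template.cols(), template.size());
```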


@@ -3,32 +3,32 @@ use crate::layouts::{
};
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/svp.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/svp.rs) reference implementation.
/// * See [crate::api::SvpPPolFromBytes] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait SvpPPolFromBytesImpl<B: Backend> {
fn svp_ppol_from_bytes_impl(n: usize, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/svp.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/svp.rs) reference implementation.
/// * See [crate::api::SvpPPolAlloc] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait SvpPPolAllocImpl<B: Backend> {
fn svp_ppol_alloc_impl(n: usize, cols: usize) -> SvpPPolOwned<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/svp.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/svp.rs) reference implementation.
/// * See [crate::api::SvpPPolAllocBytes] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait SvpPPolAllocBytesImpl<B: Backend> {
fn svp_ppol_alloc_bytes_impl(n: usize, cols: usize) -> usize;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/svp.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/svp.rs) reference implementation.
/// * See [crate::api::SvpPrepare] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait SvpPrepareImpl<B: Backend> {
fn svp_prepare_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
@@ -38,8 +38,8 @@ pub unsafe trait SvpPrepareImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/svp.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/svp.rs) reference implementation.
/// * See [crate::api::SvpApplyDft] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait SvpApplyDftImpl<B: Backend> {
fn svp_apply_dft_impl<R, A, C>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
@@ -50,8 +50,8 @@ pub unsafe trait SvpApplyDftImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/svp.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/svp.rs) reference implementation.
/// * See [crate::api::SvpApplyDftToDft] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait SvpApplyDftToDftImpl<B: Backend> {
fn svp_apply_dft_to_dft_impl<R, A, C>(
@@ -69,8 +69,8 @@ pub unsafe trait SvpApplyDftToDftImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/svp.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/svp.rs) reference implementation.
/// * See [crate::api::SvpApplyDftToDftAdd] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait SvpApplyDftToDftAddImpl<B: Backend> {
fn svp_apply_dft_to_dft_add_impl<R, A, C>(
@@ -88,8 +88,8 @@ pub unsafe trait SvpApplyDftToDftAddImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/svp.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/svp.rs) reference implementation.
/// * See [crate::api::SvpApplyDftToDftInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait SvpApplyDftToDftInplaceImpl: Backend {
fn svp_apply_dft_to_dft_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)


@@ -4,7 +4,7 @@ use crate::{
};
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_normalize_base2k_tmp_bytes_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L245C17-L245C55) for reference code.
/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation.
/// * See [crate::api::VecZnxNormalizeTmpBytes] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxNormalizeTmpBytesImpl<B: Backend> {
@@ -12,15 +12,17 @@ pub unsafe trait VecZnxNormalizeTmpBytesImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_normalize_base2k_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L212) for reference code.
/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation.
/// * See [crate::api::VecZnxNormalize] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxNormalizeImpl<B: Backend> {
#[allow(clippy::too_many_arguments)]
fn vec_znx_normalize_impl<R, A>(
module: &Module<B>,
basek: usize,
res_basek: usize,
res: &mut R,
res_col: usize,
a_basek: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<B>,
@@ -30,17 +32,17 @@ pub unsafe trait VecZnxNormalizeImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_normalize_base2k_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L212) for reference code.
/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation.
/// * See [crate::api::VecZnxNormalizeInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxNormalizeInplaceImpl<B: Backend> {
fn vec_znx_normalize_inplace_impl<A>(module: &Module<B>, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
fn vec_znx_normalize_inplace_impl<A>(module: &Module<B>, base2k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
where
A: VecZnxToMut;
}
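The `res_basek`/`a_basek` split above is what enables cross-base normalization: the same fixed-point value is re-expressed with a different limb width. A standalone sketch (not crate code; unsigned, and overflow-free only for small sizes) of that re-chunking, ignoring the signed carry handling the real implementation performs:

```rust
fn rechunk(limbs: &[u64], a_base2k: u32, res_base2k: u32, res_size: usize) -> Vec<u64> {
    // Recompose the value; u128 is enough for this small illustration.
    let mut v: u128 = 0;
    for &x in limbs {
        v = (v << a_base2k) | x as u128;
    }
    // Align the total fractional precision to res_base2k * res_size bits.
    let a_bits = a_base2k as usize * limbs.len();
    let r_bits = res_base2k as usize * res_size;
    if r_bits >= a_bits {
        v <<= r_bits - a_bits;
    } else {
        v >>= a_bits - r_bits;
    }
    // Split back into res_base2k-bit limbs, most significant first.
    (0..res_size)
        .rev()
        .map(|j| ((v >> (res_base2k as usize * j)) & ((1u128 << res_base2k) - 1)) as u64)
        .collect()
}

fn main() {
    // Two 12-bit limbs carry the same 24 fractional bits as three 8-bit limbs.
    assert_eq!(rechunk(&[0xABC, 0xDEF], 12, 8, 3), vec![0xAB, 0xCD, 0xEF]);
}
```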
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_add_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L86) for reference code.
/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation.
/// * See [crate::api::VecZnxAdd] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxAddImpl<B: Backend> {
@@ -52,7 +54,7 @@ pub unsafe trait VecZnxAddImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_add_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L86) for reference code.
/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation.
/// * See [crate::api::VecZnxAddInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxAddInplaceImpl<B: Backend> {
@@ -63,7 +65,7 @@ pub unsafe trait VecZnxAddInplaceImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_add_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L86) for reference code.
/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation.
/// * See [crate::api::VecZnxAddScalar] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxAddScalarImpl<D: Backend> {
@@ -85,7 +87,7 @@ pub unsafe trait VecZnxAddScalarImpl<D: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_add_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L86) for reference code.
/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation.
/// * See [crate::api::VecZnxAddScalarInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxAddScalarInplaceImpl<B: Backend> {
@@ -102,7 +104,7 @@ pub unsafe trait VecZnxAddScalarInplaceImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_sub_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L125) for reference code.
/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation.
/// * See [crate::api::VecZnxSub] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxSubImpl<B: Backend> {
@@ -114,29 +116,29 @@ pub unsafe trait VecZnxSubImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_sub_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L125) for reference code.
/// * See [crate::api::VecZnxSubABInplace] for corresponding public API.
/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation.
/// * See [crate::api::VecZnxSubInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxSubABInplaceImpl<B: Backend> {
fn vec_znx_sub_ab_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
pub unsafe trait VecZnxSubInplaceImpl<B: Backend> {
fn vec_znx_sub_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_sub_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L125) for reference code.
/// * See [crate::api::VecZnxSubBAInplace] for corresponding public API.
/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation.
/// * See [crate::api::VecZnxSubNegateInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxSubBAInplaceImpl<B: Backend> {
fn vec_znx_sub_ba_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
pub unsafe trait VecZnxSubNegateInplaceImpl<B: Backend> {
fn vec_znx_sub_negate_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
}
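The two renames above change names only, not behavior. A hedged migration sketch (the call-site receiver is assumed to be the module implementing the corresponding api traits):

```rust
// res -= a            (old name: vec_znx_sub_ab_inplace)
module.vec_znx_sub_inplace(&mut res, 0, &a, 0);
// res = a - res       (old name: vec_znx_sub_ba_inplace)
module.vec_znx_sub_negate_inplace(&mut res, 0, &a, 0);
```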
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO.
/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation.
/// * See [crate::api::VecZnxAddScalar] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxSubScalarImpl<D: Backend> {
@@ -158,7 +160,7 @@ pub unsafe trait VecZnxSubScalarImpl<D: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_sub_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L125) for reference code.
/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation.
/// * See [crate::api::VecZnxSubScalarInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxSubScalarInplaceImpl<B: Backend> {
@@ -175,7 +177,7 @@ pub unsafe trait VecZnxSubScalarInplaceImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_negate_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L322C13-L322C31) for reference code.
/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation.
/// * See [crate::api::VecZnxNegate] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxNegateImpl<B: Backend> {
@@ -186,7 +188,7 @@ pub unsafe trait VecZnxNegateImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_negate_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L322C13-L322C31) for reference code.
/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation.
/// * See [crate::api::VecZnxNegateInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxNegateInplaceImpl<B: Backend> {
@@ -196,7 +198,7 @@ pub unsafe trait VecZnxNegateInplaceImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::reference::vec_znx::shift::vec_znx_rsh_tmp_bytes] for reference code.
/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation.
/// * See [crate::api::VecZnxRshTmpBytes] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxRshTmpBytesImpl<B: Backend> {
@@ -204,14 +206,14 @@ pub unsafe trait VecZnxRshTmpBytesImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::reference::vec_znx::shift::vec_znx_rsh_inplace] for reference code.
/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation.
/// * See [crate::api::VecZnxRsh] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxRshImpl<B: Backend> {
#[allow(clippy::too_many_arguments)]
fn vec_znx_rsh_inplace_impl<R, A>(
fn vec_znx_rsh_impl<R, A>(
module: &Module<B>,
basek: usize,
base2k: usize,
k: usize,
res: &mut R,
res_col: usize,
@@ -224,7 +226,7 @@ pub unsafe trait VecZnxRshImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::reference::vec_znx::shift::vec_znx_lsh_tmp_bytes] for reference code.
/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation.
/// * See [crate::api::VecZnxLshTmpBytes] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxLshTmpBytesImpl<B: Backend> {
@@ -232,14 +234,14 @@ pub unsafe trait VecZnxLshTmpBytesImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::reference::vec_znx::shift::vec_znx_lsh_inplace] for reference code.
/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation.
/// * See [crate::api::VecZnxLsh] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxLshImpl<B: Backend> {
#[allow(clippy::too_many_arguments)]
fn vec_znx_lsh_inplace_impl<R, A>(
fn vec_znx_lsh_impl<R, A>(
module: &Module<B>,
basek: usize,
base2k: usize,
k: usize,
res: &mut R,
res_col: usize,
@@ -252,13 +254,13 @@ pub unsafe trait VecZnxLshImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::cpu_spqlios::vec_znx::vec_znx_rsh_inplace_ref] for reference code.
/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation.
/// * See [crate::api::VecZnxRshInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxRshInplaceImpl<B: Backend> {
fn vec_znx_rsh_inplace_impl<R>(
module: &Module<B>,
basek: usize,
base2k: usize,
k: usize,
res: &mut R,
res_col: usize,
@@ -268,13 +270,13 @@ pub unsafe trait VecZnxRshInplaceImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::cpu_spqlios::vec_znx::vec_znx_lsh_inplace_ref] for reference code.
/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation.
/// * See [crate::api::VecZnxLshInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxLshInplaceImpl<B: Backend> {
fn vec_znx_lsh_inplace_impl<R>(
module: &Module<B>,
basek: usize,
base2k: usize,
k: usize,
res: &mut R,
res_col: usize,
@@ -284,7 +286,7 @@ pub unsafe trait VecZnxLshInplaceImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_rotate_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L164) for reference code.
/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation.
/// * See [crate::api::VecZnxRotate] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxRotateImpl<B: Backend> {
@@ -295,7 +297,7 @@ pub unsafe trait VecZnxRotateImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO;
/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation.
/// * See [crate::api::VecZnxRotateInplaceTmpBytes] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxRotateInplaceTmpBytesImpl<B: Backend> {
@@ -303,7 +305,7 @@ pub unsafe trait VecZnxRotateInplaceTmpBytesImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_rotate_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L164) for reference code.
/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation.
/// * See [crate::api::VecZnxRotateInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxRotateInplaceImpl<B: Backend> {
@@ -313,7 +315,7 @@ pub unsafe trait VecZnxRotateInplaceImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_automorphism_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L188) for reference code.
/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation.
/// * See [crate::api::VecZnxAutomorphism] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxAutomorphismImpl<B: Backend> {
@@ -324,7 +326,7 @@ pub unsafe trait VecZnxAutomorphismImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO;
/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation.
/// * See [crate::api::VecZnxAutomorphismInplaceTmpBytes] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxAutomorphismInplaceTmpBytesImpl<B: Backend> {
@@ -332,7 +334,7 @@ pub unsafe trait VecZnxAutomorphismInplaceTmpBytesImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_automorphism_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L188) for reference code.
/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation.
/// * See [crate::api::VecZnxAutomorphismInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxAutomorphismInplaceImpl<B: Backend> {
@@ -342,7 +344,7 @@ pub unsafe trait VecZnxAutomorphismInplaceImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_mul_xp_minus_one_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/7160f588da49712a042931ea247b4259b95cefcc/spqlios/arithmetic/vec_znx.c#L200C13-L200C41) for reference code.
/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation.
/// * See [crate::api::VecZnxMulXpMinusOne] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxMulXpMinusOneImpl<B: Backend> {
@@ -353,7 +355,7 @@ pub unsafe trait VecZnxMulXpMinusOneImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO;
/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation.
/// * See [crate::api::VecZnxMulXpMinusOneInplaceTmpBytes] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxMulXpMinusOneInplaceTmpBytesImpl<B: Backend> {
@@ -361,7 +363,7 @@ pub unsafe trait VecZnxMulXpMinusOneInplaceTmpBytesImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_mul_xp_minus_one_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/7160f588da49712a042931ea247b4259b95cefcc/spqlios/arithmetic/vec_znx.c#L200C13-L200C41) for reference code.
/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation.
/// * See [crate::api::VecZnxMulXpMinusOneInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxMulXpMinusOneInplaceImpl<B: Backend> {
@@ -376,7 +378,7 @@ pub unsafe trait VecZnxMulXpMinusOneInplaceImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO;
/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation.
/// * See [crate::api::VecZnxSplitRingTmpBytes] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxSplitRingTmpBytesImpl<B: Backend> {
@@ -401,7 +403,7 @@ pub unsafe trait VecZnxSplitRingImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO;
/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation.
/// * See [crate::api::VecZnxMergeRingsTmpBytes] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxMergeRingsTmpBytesImpl<B: Backend> {
@@ -426,7 +428,7 @@ pub unsafe trait VecZnxMergeRingsImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::cpu_spqlios::vec_znx::vec_znx_switch_degree_ref] for reference code.
/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation.
/// * See [crate::api::VecZnxSwithcDegree] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxSwitchRingImpl<B: Backend> {
@@ -440,7 +442,7 @@ pub unsafe trait VecZnxSwitchRingImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::cpu_spqlios::vec_znx::vec_znx_copy_ref] for reference code.
/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation.
/// * See [crate::api::VecZnxCopy] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxCopyImpl<B: Backend> {
@@ -451,22 +453,24 @@ pub unsafe trait VecZnxCopyImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation.
/// * See [crate::api::VecZnxFillUniform] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxFillUniformImpl<B: Backend> {
fn vec_znx_fill_uniform_impl<R>(module: &Module<B>, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
fn vec_znx_fill_uniform_impl<R>(module: &Module<B>, base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: VecZnxToMut;
}
#[allow(clippy::too_many_arguments)]
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation.
/// * See [crate::api::VecZnxFillNormal] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxFillNormalImpl<B: Backend> {
fn vec_znx_fill_normal_impl<R>(
module: &Module<B>,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -479,12 +483,13 @@ pub unsafe trait VecZnxFillNormalImpl<B: Backend> {
#[allow(clippy::too_many_arguments)]
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation.
/// * See [crate::api::VecZnxAddNormal] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxAddNormalImpl<B: Backend> {
fn vec_znx_add_normal_impl<R>(
module: &Module<B>,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,


@@ -4,8 +4,8 @@ use crate::{
};
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation.
/// * See [crate::api::VecZnxBigFromSmall] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigFromSmallImpl<B: Backend> {
fn vec_znx_big_from_small_impl<R, A>(res: &mut R, res_col: usize, a: &A, a_col: usize)
@@ -15,24 +15,24 @@ pub unsafe trait VecZnxBigFromSmallImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation.
/// * See [crate::api::VecZnxBigAlloc] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigAllocImpl<B: Backend> {
fn vec_znx_big_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxBigOwned<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation.
/// * See [crate::api::VecZnxBigFromBytes] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigFromBytesImpl<B: Backend> {
fn vec_znx_big_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation.
/// * See [crate::api::VecZnxBigAllocBytes] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigAllocBytesImpl<B: Backend> {
fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize;
@@ -40,13 +40,13 @@ pub unsafe trait VecZnxBigAllocBytesImpl<B: Backend> {
#[allow(clippy::too_many_arguments)]
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation.
/// * See [crate::api::VecZnxBigAddNormal] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigAddNormalImpl<B: Backend> {
fn add_normal_impl<R: VecZnxBigToMut<B>>(
module: &Module<B>,
basek: usize,
res_basek: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -57,8 +57,8 @@ pub unsafe trait VecZnxBigAddNormalImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation.
/// * See [crate::api::VecZnxBigAdd] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigAddImpl<B: Backend> {
fn vec_znx_big_add_impl<R, A, C>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
@@ -69,8 +69,8 @@ pub unsafe trait VecZnxBigAddImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation.
/// * See [crate::api::VecZnxBigAddInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigAddInplaceImpl<B: Backend> {
fn vec_znx_big_add_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
@@ -80,8 +80,8 @@ pub unsafe trait VecZnxBigAddInplaceImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation.
/// * See [crate::api::VecZnxBigAddSmall] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigAddSmallImpl<B: Backend> {
fn vec_znx_big_add_small_impl<R, A, C>(
@@ -99,8 +99,8 @@ pub unsafe trait VecZnxBigAddSmallImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation.
/// * See [crate::api::VecZnxBigAddSmallInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigAddSmallInplaceImpl<B: Backend> {
fn vec_znx_big_add_small_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
@@ -110,8 +110,8 @@ pub unsafe trait VecZnxBigAddSmallInplaceImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation.
/// * See [crate::api::VecZnxBigSub] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigSubImpl<B: Backend> {
fn vec_znx_big_sub_impl<R, A, C>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
@@ -122,30 +122,30 @@ pub unsafe trait VecZnxBigSubImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation.
/// * See [crate::api::VecZnxBigSubInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigSubABInplaceImpl<B: Backend> {
fn vec_znx_big_sub_ab_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
pub unsafe trait VecZnxBigSubInplaceImpl<B: Backend> {
fn vec_znx_big_sub_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxBigToRef<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation.
/// * See [crate::api::VecZnxBigSubNegateInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigSubBAInplaceImpl<B: Backend> {
fn vec_znx_big_sub_ba_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
pub unsafe trait VecZnxBigSubNegateInplaceImpl<B: Backend> {
fn vec_znx_big_sub_negate_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxBigToRef<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation.
/// * See [crate::api::VecZnxBigSubSmallA] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigSubSmallAImpl<B: Backend> {
fn vec_znx_big_sub_small_a_impl<R, A, C>(
@@ -163,19 +163,19 @@ pub unsafe trait VecZnxBigSubSmallAImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation.
/// * See [crate::api::VecZnxBigSubSmallInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigSubSmallAInplaceImpl<B: Backend> {
fn vec_znx_big_sub_small_a_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
pub unsafe trait VecZnxBigSubSmallInplaceImpl<B: Backend> {
fn vec_znx_big_sub_small_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation.
/// * See [crate::api::VecZnxBigSubSmallB] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigSubSmallBImpl<B: Backend> {
fn vec_znx_big_sub_small_b_impl<R, A, C>(
@@ -193,19 +193,19 @@ pub unsafe trait VecZnxBigSubSmallBImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation.
/// * See [crate::api::VecZnxBigSubSmallNegateInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigSubSmallBInplaceImpl<B: Backend> {
fn vec_znx_big_sub_small_b_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
pub unsafe trait VecZnxBigSubSmallNegateInplaceImpl<B: Backend> {
fn vec_znx_big_sub_small_negate_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxToRef;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation.
/// * See [crate::api::VecZnxBigNegate] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigNegateImpl<B: Backend> {
fn vec_znx_big_negate_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
@@ -215,8 +215,8 @@ pub unsafe trait VecZnxBigNegateImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation.
/// * See [crate::api::VecZnxBigNegateInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigNegateInplaceImpl<B: Backend> {
fn vec_znx_big_negate_inplace_impl<A>(module: &Module<B>, a: &mut A, a_col: usize)
@@ -225,23 +225,25 @@ pub unsafe trait VecZnxBigNegateInplaceImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation.
/// * See [crate::api::VecZnxBigNormalizeTmpBytes] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigNormalizeTmpBytesImpl<B: Backend> {
fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module<B>) -> usize;
}
#[allow(clippy::too_many_arguments)]
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation.
/// * See [crate::api::VecZnxBigNormalize] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigNormalizeImpl<B: Backend> {
fn vec_znx_big_normalize_impl<R, A>(
module: &Module<B>,
basek: usize,
res_basek: usize,
res: &mut R,
res_col: usize,
a_basek: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<B>,
@@ -251,8 +253,8 @@ pub unsafe trait VecZnxBigNormalizeImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation.
/// * See [crate::api::VecZnxBigAutomorphism] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigAutomorphismImpl<B: Backend> {
fn vec_znx_big_automorphism_impl<R, A>(module: &Module<B>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
@@ -262,16 +264,16 @@ pub unsafe trait VecZnxBigAutomorphismImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation.
/// * See [crate::api::VecZnxBigAutomorphismInplaceTmpBytes] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigAutomorphismInplaceTmpBytesImpl<B: Backend> {
fn vec_znx_big_automorphism_inplace_tmp_bytes_impl(module: &Module<B>) -> usize;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs) reference implementation.
/// * See [crate::api::VecZnxBigAutomorphismInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigAutomorphismInplaceImpl<B: Backend> {
fn vec_znx_big_automorphism_inplace_impl<A>(module: &Module<B>, k: i64, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)

View File

@@ -4,24 +4,24 @@ use crate::layouts::{
};
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs) reference implementation.
/// * See [crate::api::VecZnxDftAlloc] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxDftAllocImpl<B: Backend> {
fn vec_znx_dft_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxDftOwned<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs) reference implementation.
/// * See [crate::api::VecZnxDftFromBytes] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxDftFromBytesImpl<B: Backend> {
fn vec_znx_dft_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs) reference implementation.
/// * See [crate::api::VecZnxDftApply] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxDftApplyImpl<B: Backend> {
fn vec_znx_dft_apply_impl<R, A>(
@@ -38,24 +38,24 @@ pub unsafe trait VecZnxDftApplyImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs) reference implementation.
/// * See [crate::api::VecZnxDftAllocBytes] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxDftAllocBytesImpl<B: Backend> {
fn vec_znx_dft_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs) reference implementation.
/// * See [crate::api::VecZnxIdftApplyTmpBytes] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxIdftApplyTmpBytesImpl<B: Backend> {
fn vec_znx_idft_apply_tmp_bytes_impl(module: &Module<B>) -> usize;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs) reference implementation.
/// * See [crate::api::VecZnxIdftApply] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxIdftApplyImpl<B: Backend> {
fn vec_znx_idft_apply_impl<R, A>(
@@ -71,8 +71,8 @@ pub unsafe trait VecZnxIdftApplyImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs) reference implementation.
/// * See [crate::api::VecZnxIdftApplyTmpA] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxIdftApplyTmpAImpl<B: Backend> {
fn vec_znx_idft_apply_tmpa_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
@@ -82,8 +82,8 @@ pub unsafe trait VecZnxIdftApplyTmpAImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs) reference implementation.
/// * See [crate::api::VecZnxIdftApplyConsume] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxIdftApplyConsumeImpl<B: Backend> {
fn vec_znx_idft_apply_consume_impl<D: Data>(module: &Module<B>, a: VecZnxDft<D, B>) -> VecZnxBig<D, B>
@@ -92,8 +92,8 @@ pub unsafe trait VecZnxIdftApplyConsumeImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs) reference implementation.
/// * See [crate::api::VecZnxDftAdd] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxDftAddImpl<B: Backend> {
fn vec_znx_dft_add_impl<R, A, D>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
@@ -104,8 +104,8 @@ pub unsafe trait VecZnxDftAddImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs) reference implementation.
/// * See [crate::api::VecZnxDftAddInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxDftAddInplaceImpl<B: Backend> {
fn vec_znx_dft_add_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
@@ -115,8 +115,8 @@ pub unsafe trait VecZnxDftAddInplaceImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs) reference implementation.
/// * See [crate::api::VecZnxDftSub] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxDftSubImpl<B: Backend> {
fn vec_znx_dft_sub_impl<R, A, D>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
@@ -127,30 +127,30 @@ pub unsafe trait VecZnxDftSubImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs) reference implementation.
/// * See [crate::api::VecZnxDftSubInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxDftSubABInplaceImpl<B: Backend> {
fn vec_znx_dft_sub_ab_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
pub unsafe trait VecZnxDftSubInplaceImpl<B: Backend> {
fn vec_znx_dft_sub_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs) reference implementation.
/// * See [crate::api::VecZnxDftSubNegateInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxDftSubBAInplaceImpl<B: Backend> {
fn vec_znx_dft_sub_ba_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
pub unsafe trait VecZnxDftSubNegateInplaceImpl<B: Backend> {
fn vec_znx_dft_sub_negate_inplace_impl<R, A>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs) reference implementation.
/// * See [crate::api::VecZnxDftCopy] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxDftCopyImpl<B: Backend> {
fn vec_znx_dft_copy_impl<R, A>(
@@ -167,8 +167,8 @@ pub unsafe trait VecZnxDftCopyImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See the [poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs) reference implementation.
/// * See [crate::api::VecZnxDftZero] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxDftZeroImpl<B: Backend> {
fn vec_znx_dft_zero_impl<R>(module: &Module<B>, res: &mut R)

View File

@@ -3,24 +3,24 @@ use crate::layouts::{
};
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See [poulpy-backend/src/cpu_fft64_ref/vmp.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vmp.rs) for reference implementation.
/// * See [crate::api::VmpPMatAlloc] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VmpPMatAllocImpl<B: Backend> {
fn vmp_pmat_alloc_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<B>;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See [poulpy-backend/src/cpu_fft64_ref/vmp.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vmp.rs) for reference implementation.
/// * See [crate::api::VmpPMatAllocBytes] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VmpPMatAllocBytesImpl<B: Backend> {
fn vmp_pmat_alloc_bytes_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See [poulpy-backend/src/cpu_fft64_ref/vmp.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vmp.rs) for reference implementation.
/// * See [crate::api::VmpPMatFromBytes] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VmpPMatFromBytesImpl<B: Backend> {
fn vmp_pmat_from_bytes_impl(
@@ -34,16 +34,16 @@ pub unsafe trait VmpPMatFromBytesImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See [poulpy-backend/src/cpu_fft64_ref/vmp.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vmp.rs) for reference implementation.
/// * See [crate::api::VmpPrepareTmpBytes] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VmpPrepareTmpBytesImpl<B: Backend> {
fn vmp_prepare_tmp_bytes_impl(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See [poulpy-backend/src/cpu_fft64_ref/vmp.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vmp.rs) for reference implementation.
/// * See [crate::api::VmpPrepare] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VmpPrepareImpl<B: Backend> {
fn vmp_prepare_impl<R, A>(module: &Module<B>, res: &mut R, a: &A, scratch: &mut Scratch<B>)
@@ -54,8 +54,8 @@ pub unsafe trait VmpPrepareImpl<B: Backend> {
#[allow(clippy::too_many_arguments)]
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See [poulpy-backend/src/cpu_fft64_ref/vmp.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vmp.rs) for reference implementation.
/// * See [crate::api::VmpApplyDftTmpBytes] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VmpApplyDftTmpBytesImpl<B: Backend> {
fn vmp_apply_dft_tmp_bytes_impl(
@@ -70,8 +70,8 @@ pub unsafe trait VmpApplyDftTmpBytesImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See [poulpy-backend/src/cpu_fft64_ref/vmp.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vmp.rs) for reference implementation.
/// * See [crate::api::VmpApplyDft] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VmpApplyDftImpl<B: Backend> {
fn vmp_apply_dft_impl<R, A, C>(module: &Module<B>, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<B>)
@@ -83,8 +83,8 @@ pub unsafe trait VmpApplyDftImpl<B: Backend> {
#[allow(clippy::too_many_arguments)]
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See [poulpy-backend/src/cpu_fft64_ref/vmp.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vmp.rs) for reference implementation.
/// * See [crate::api::VmpApplyDftToDftTmpBytes] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VmpApplyDftToDftTmpBytesImpl<B: Backend> {
fn vmp_apply_dft_to_dft_tmp_bytes_impl(
@@ -99,8 +99,8 @@ pub unsafe trait VmpApplyDftToDftTmpBytesImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See [poulpy-backend/src/cpu_fft64_ref/vmp.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vmp.rs) for reference implementation.
/// * See [crate::api::VmpApplyDftToDft] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VmpApplyDftToDftImpl<B: Backend> {
fn vmp_apply_dft_to_dft_impl<R, A, C>(module: &Module<B>, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<B>)
@@ -112,8 +112,8 @@ pub unsafe trait VmpApplyDftToDftImpl<B: Backend> {
#[allow(clippy::too_many_arguments)]
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See [poulpy-backend/src/cpu_fft64_ref/vmp.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vmp.rs) for reference implementation.
/// * See [crate::api::VmpApplyDftToDftAddTmpBytes] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VmpApplyDftToDftAddTmpBytesImpl<B: Backend> {
fn vmp_apply_dft_to_dft_add_tmp_bytes_impl(
@@ -128,8 +128,8 @@ pub unsafe trait VmpApplyDftToDftAddTmpBytesImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference code.
/// * See TODO for corresponding public API.
/// * See [poulpy-backend/src/cpu_fft64_ref/vmp.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vmp.rs) for reference implementation.
/// * See [crate::api::VmpApplyDftToDftAdd] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VmpApplyDftToDftAddImpl<B: Backend> {
// Same as [MatZnxDftOps::vmp_apply], except the result is added to R instead of overwriting R.

View File

@@ -4,7 +4,7 @@ use crate::{
};
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO
/// * See [poulpy-backend/src/cpu_fft64_ref/zn.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/zn.rs) for reference implementation.
/// * See [crate::api::ZnNormalizeTmpBytes] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait ZnNormalizeTmpBytesImpl<B: Backend> {
@@ -12,32 +12,34 @@ pub unsafe trait ZnNormalizeTmpBytesImpl<B: Backend> {
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [zn_normalize_base2k_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/zn64.c#L9) for reference code.
/// * See [poulpy-backend/src/cpu_fft64_ref/zn.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/zn.rs) for reference implementation.
/// * See [crate::api::ZnNormalizeInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait ZnNormalizeInplaceImpl<B: Backend> {
fn zn_normalize_inplace_impl<R>(n: usize, basek: usize, res: &mut R, res_col: usize, scratch: &mut Scratch<B>)
fn zn_normalize_inplace_impl<R>(n: usize, base2k: usize, res: &mut R, res_col: usize, scratch: &mut Scratch<B>)
where
R: ZnToMut;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [poulpy-backend/src/cpu_fft64_ref/zn.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/zn.rs) for reference implementation.
/// * See [crate::api::ZnFillUniform] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait ZnFillUniformImpl<B: Backend> {
fn zn_fill_uniform_impl<R>(n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
fn zn_fill_uniform_impl<R>(n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: ZnToMut;
}
#[allow(clippy::too_many_arguments)]
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [poulpy-backend/src/cpu_fft64_ref/zn.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/zn.rs) for reference implementation.
/// * See [crate::api::ZnFillNormal] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait ZnFillNormalImpl<B: Backend> {
fn zn_fill_normal_impl<R>(
n: usize,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -50,12 +52,13 @@ pub unsafe trait ZnFillNormalImpl<B: Backend> {
#[allow(clippy::too_many_arguments)]
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [poulpy-backend/src/cpu_fft64_ref/zn.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/zn.rs) for reference implementation.
/// * See [crate::api::ZnAddNormal] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait ZnAddNormalImpl<B: Backend> {
fn zn_add_normal_impl<R>(
n: usize,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,

View File

@@ -37,7 +37,7 @@ pub fn reim_sub_ref(res: &mut [f64], a: &[f64], b: &[f64]) {
}
#[inline(always)]
pub fn reim_sub_ab_inplace_ref(res: &mut [f64], a: &[f64]) {
pub fn reim_sub_inplace_ref(res: &mut [f64], a: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());
@@ -49,7 +49,7 @@ pub fn reim_sub_ab_inplace_ref(res: &mut [f64], a: &[f64]) {
}
#[inline(always)]
pub fn reim_sub_ba_inplace_ref(res: &mut [f64], a: &[f64]) {
pub fn reim_sub_negate_inplace_ref(res: &mut [f64], a: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());

View File

@@ -91,12 +91,12 @@ pub trait ReimSub {
fn reim_sub(res: &mut [f64], a: &[f64], b: &[f64]);
}
pub trait ReimSubABInplace {
fn reim_sub_ab_inplace(res: &mut [f64], a: &[f64]);
pub trait ReimSubInplace {
fn reim_sub_inplace(res: &mut [f64], a: &[f64]);
}
pub trait ReimSubBAInplace {
fn reim_sub_ba_inplace(res: &mut [f64], a: &[f64]);
pub trait ReimSubNegateInplace {
fn reim_sub_negate_inplace(res: &mut [f64], a: &[f64]);
}
pub trait ReimNegate {

View File

@@ -22,7 +22,7 @@ pub struct ReimFFTTable<R: Float + FloatConst + Debug> {
impl<R: Float + FloatConst + Debug + 'static> ReimFFTTable<R> {
pub fn new(m: usize) -> Self {
assert!(m & (m - 1) == 0, "m must be a power of two but is {}", m);
assert!(m & (m - 1) == 0, "m must be a power of two but is {m}");
let mut omg: Vec<R> = alloc_aligned::<R>(2 * m);
let quarter: R = R::from(1. / 4.).unwrap();

View File

@@ -22,7 +22,7 @@ pub struct ReimIFFTTable<R: Float + FloatConst + Debug> {
impl<R: Float + FloatConst + Debug> ReimIFFTTable<R> {
pub fn new(m: usize) -> Self {
assert!(m & (m - 1) == 0, "m must be a power of two but is {}", m);
assert!(m & (m - 1) == 0, "m must be a power of two but is {m}");
let mut omg: Vec<R> = alloc_aligned::<R>(2 * m);
let quarter: R = R::exp2(R::from(-2).unwrap());

View File

@@ -9,12 +9,13 @@ use crate::{
reference::{
vec_znx::{
vec_znx_add, vec_znx_add_inplace, vec_znx_automorphism, vec_znx_automorphism_inplace, vec_znx_negate,
vec_znx_negate_inplace, vec_znx_normalize, vec_znx_sub, vec_znx_sub_ab_inplace, vec_znx_sub_ba_inplace,
vec_znx_negate_inplace, vec_znx_normalize, vec_znx_sub, vec_znx_sub_inplace, vec_znx_sub_negate_inplace,
},
znx::{
ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxNormalizeFinalStep,
ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly,
ZnxSub, ZnxSubABInplace, ZnxSubBAInplace, ZnxZero, znx_add_normal_f64_ref,
ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxExtractDigitAddMul, ZnxMulPowerOfTwoInplace, ZnxNegate,
ZnxNegateInplace, ZnxNormalizeDigit, ZnxNormalizeFinalStep, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly,
ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxSub, ZnxSubInplace, ZnxSubNegateInplace, ZnxZero,
znx_add_normal_f64_ref,
},
},
source::Source,
@@ -230,20 +231,32 @@ where
}
pub fn vec_znx_big_normalize_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
2 * n * size_of::<i64>()
}
pub fn vec_znx_big_normalize<R, A, BE>(basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
where
pub fn vec_znx_big_normalize<R, A, BE>(
res_basek: usize,
res: &mut R,
res_col: usize,
a_basek: usize,
a: &A,
a_col: usize,
carry: &mut [i64],
) where
R: VecZnxToMut,
A: VecZnxBigToRef<BE>,
BE: Backend<ScalarBig = i64>
+ ZnxZero
+ ZnxCopy
+ ZnxAddInplace
+ ZnxMulPowerOfTwoInplace
+ ZnxNormalizeFirstStepCarryOnly
+ ZnxNormalizeMiddleStepCarryOnly
+ ZnxNormalizeMiddleStep
+ ZnxNormalizeFinalStep
+ ZnxNormalizeFirstStep
+ ZnxZero,
+ ZnxExtractDigitAddMul
+ ZnxNormalizeDigit,
{
let a: VecZnxBig<&[u8], _> = a.to_ref();
let a_vznx: VecZnx<&[u8]> = VecZnx {
@@ -254,11 +267,11 @@ where
max_size: a.max_size,
};
vec_znx_normalize::<_, _, BE>(basek, res, res_col, &a_vznx, a_col, carry);
vec_znx_normalize::<_, _, BE>(res_basek, res, res_col, a_basek, &a_vznx, a_col, carry);
}
pub fn vec_znx_big_add_normal_ref<R, B: Backend<ScalarBig = i64>>(
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -275,8 +288,8 @@ pub fn vec_znx_big_add_normal_ref<R, B: Backend<ScalarBig = i64>>(
(bound.log2().ceil() as i64)
);
let limb: usize = k.div_ceil(basek) - 1;
let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64;
let limb: usize = k.div_ceil(base2k) - 1;
let scale: f64 = (1 << ((limb + 1) * base2k - k)) as f64;
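// Worked example (illustrative values only): with base2k = 17 and k = 30,
// limb = ceil(30/17) - 1 = 1 and scale = 2^((1+1)*17 - 30) = 2^4 = 16, i.e. the
// noise targeted at precision k is added to limb 1 after being scaled up by the
// 4 spare bits of that limb.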
znx_add_normal_f64_ref(
res.at_mut(res_col, limb),
sigma * scale,
@@ -291,7 +304,7 @@ where
B: Backend<ScalarBig = i64> + VecZnxBigAllocBytesImpl<B>,
{
let n: usize = module.n();
let basek: usize = 17;
let base2k: usize = 17;
let k: usize = 2 * 17;
let size: usize = 5;
let sigma: f64 = 3.2;
@@ -303,15 +316,15 @@ where
let sqrt2: f64 = SQRT_2;
(0..cols).for_each(|col_i| {
let mut a: VecZnxBig<Vec<u8>, B> = VecZnxBig::alloc(n, cols, size);
module.vec_znx_big_add_normal(basek, &mut a, col_i, k, &mut source, sigma, bound);
module.vec_znx_big_add_normal(basek, &mut a, col_i, k, &mut source, sigma, bound);
module.vec_znx_big_add_normal(base2k, &mut a, col_i, k, &mut source, sigma, bound);
module.vec_znx_big_add_normal(base2k, &mut a, col_i, k, &mut source, sigma, bound);
(0..cols).for_each(|col_j| {
if col_j != col_i {
(0..size).for_each(|limb_i| {
assert_eq!(a.at(col_j, limb_i), zero);
})
} else {
let std: f64 = a.std(basek, col_i) * k_f64;
let std: f64 = a.std(base2k, col_i) * k_f64;
assert!(
(std - sigma * sqrt2).abs() < 0.1,
"std={} ~!= {}",
@@ -363,9 +376,9 @@ where
}
/// R <- A - B
pub fn vec_znx_big_sub_ab_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
pub fn vec_znx_big_sub_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarBig = i64> + ZnxSubABInplace,
BE: Backend<ScalarBig = i64> + ZnxSubInplace,
R: VecZnxBigToMut<BE>,
A: VecZnxBigToRef<BE>,
{
@@ -388,13 +401,13 @@ where
max_size: a.max_size,
};
vec_znx_sub_ab_inplace::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col);
vec_znx_sub_inplace::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col);
}
/// R <- B - A
pub fn vec_znx_big_sub_ba_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
pub fn vec_znx_big_sub_negate_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarBig = i64> + ZnxSubBAInplace + ZnxNegateInplace,
BE: Backend<ScalarBig = i64> + ZnxSubNegateInplace + ZnxNegateInplace,
R: VecZnxBigToMut<BE>,
A: VecZnxBigToRef<BE>,
{
@@ -417,7 +430,7 @@ where
max_size: a.max_size,
};
vec_znx_sub_ba_inplace::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col);
vec_znx_sub_negate_inplace::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col);
}
/// R <- A - B
@@ -483,7 +496,7 @@ where
/// R <- R - A
pub fn vec_znx_big_sub_small_a_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarBig = i64> + ZnxSubABInplace,
BE: Backend<ScalarBig = i64> + ZnxSubInplace,
R: VecZnxBigToMut<BE>,
A: VecZnxToRef,
{
@@ -497,13 +510,13 @@ where
max_size: res.max_size,
};
vec_znx_sub_ab_inplace::<_, _, BE>(&mut res_vznx, res_col, a, a_col);
vec_znx_sub_inplace::<_, _, BE>(&mut res_vznx, res_col, a, a_col);
}
/// R <- A - R
pub fn vec_znx_big_sub_small_b_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarBig = i64> + ZnxSubBAInplace + ZnxNegateInplace,
BE: Backend<ScalarBig = i64> + ZnxSubNegateInplace + ZnxNegateInplace,
R: VecZnxBigToMut<BE>,
A: VecZnxToRef,
{
@@ -517,5 +530,5 @@ where
max_size: res.max_size,
};
vec_znx_sub_ba_inplace::<_, _, BE>(&mut res_vznx, res_col, a, a_col);
vec_znx_sub_negate_inplace::<_, _, BE>(&mut res_vznx, res_col, a, a_col);
}

View File

@@ -8,7 +8,7 @@ use crate::{
reference::{
fft64::reim::{
ReimAdd, ReimAddInplace, ReimCopy, ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimIFFTTable, ReimNegate,
ReimNegateInplace, ReimSub, ReimSubABInplace, ReimSubBAInplace, ReimToZnx, ReimToZnxInplace, ReimZero,
ReimNegateInplace, ReimSub, ReimSubInplace, ReimSubNegateInplace, ReimToZnx, ReimToZnxInplace, ReimZero,
},
znx::ZnxZero,
},
@@ -308,9 +308,9 @@ where
}
}
pub fn vec_znx_dft_sub_ab_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
pub fn vec_znx_dft_sub_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarPrep = f64> + ReimSubABInplace,
BE: Backend<ScalarPrep = f64> + ReimSubInplace,
R: VecZnxDftToMut<BE>,
A: VecZnxDftToRef<BE>,
{
@@ -328,13 +328,13 @@ where
let sum_size: usize = a_size.min(res_size);
for j in 0..sum_size {
BE::reim_sub_ab_inplace(res.at_mut(res_col, j), a.at(a_col, j));
BE::reim_sub_inplace(res.at_mut(res_col, j), a.at(a_col, j));
}
}
pub fn vec_znx_dft_sub_ba_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
pub fn vec_znx_dft_sub_negate_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarPrep = f64> + ReimSubBAInplace + ReimNegateInplace,
BE: Backend<ScalarPrep = f64> + ReimSubNegateInplace + ReimNegateInplace,
R: VecZnxDftToMut<BE>,
A: VecZnxDftToRef<BE>,
{
@@ -352,7 +352,7 @@ where
let sum_size: usize = a_size.min(res_size);
for j in 0..sum_size {
BE::reim_sub_ba_inplace(res.at_mut(res_col, j), a.at(a_col, j));
BE::reim_sub_negate_inplace(res.at_mut(res_col, j), a.at(a_col, j));
}
for j in sum_size..res_size {

View File

@@ -91,7 +91,7 @@ pub fn bench_vec_znx_add<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxAdd + ModuleNew<B>,
{
let group_name: String = format!("vec_znx_add::{}", label);
let group_name: String = format!("vec_znx_add::{label}");
let mut group = c.benchmark_group(group_name);
@@ -136,7 +136,7 @@ pub fn bench_vec_znx_add_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxAddInplace + ModuleNew<B>,
{
let group_name: String = format!("vec_znx_add_inplace::{}", label);
let group_name: String = format!("vec_znx_add_inplace::{label}");
let mut group = c.benchmark_group(group_name);

View File

@@ -18,12 +18,7 @@ where
#[cfg(debug_assertions)]
{
assert!(
b_limb < min_size,
"b_limb: {} > min_size: {}",
b_limb,
min_size
);
assert!(b_limb < min_size, "b_limb: {b_limb} > min_size: {min_size}");
}
for j in 0..min_size {

View File

@@ -63,7 +63,7 @@ pub fn bench_vec_znx_automorphism<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxAutomorphism + ModuleNew<B>,
{
let group_name: String = format!("vec_znx_automorphism::{}", label);
let group_name: String = format!("vec_znx_automorphism::{label}");
let mut group = c.benchmark_group(group_name);
@@ -108,7 +108,7 @@ where
Module<B>: VecZnxAutomorphismInplace<B> + VecZnxAutomorphismInplaceTmpBytes + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_automorphism_inplace::{}", label);
let group_name: String = format!("vec_znx_automorphism_inplace::{label}");
let mut group = c.benchmark_group(group_name);

View File

@@ -9,8 +9,8 @@ use crate::{
},
layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
reference::{
vec_znx::{vec_znx_rotate, vec_znx_sub_ab_inplace},
znx::{ZnxNegate, ZnxRotate, ZnxSubABInplace, ZnxSubBAInplace, ZnxZero},
vec_znx::{vec_znx_rotate, vec_znx_sub_inplace},
znx::{ZnxNegate, ZnxRotate, ZnxSubInplace, ZnxSubNegateInplace, ZnxZero},
},
source::Source,
};
@@ -23,16 +23,16 @@ pub fn vec_znx_mul_xp_minus_one<R, A, ZNXARI>(p: i64, res: &mut R, res_col: usiz
where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxRotate + ZnxZero + ZnxSubABInplace,
ZNXARI: ZnxRotate + ZnxZero + ZnxSubInplace,
{
vec_znx_rotate::<_, _, ZNXARI>(p, res, res_col, a, a_col);
vec_znx_sub_ab_inplace::<_, _, ZNXARI>(res, res_col, a, a_col);
vec_znx_sub_inplace::<_, _, ZNXARI>(res, res_col, a, a_col);
}
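// I.e. res = X^p * a - a = (X^p - 1) * a: a rotation by X^p followed by an
// in-place subtraction of a.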
pub fn vec_znx_mul_xp_minus_one_inplace<R, ZNXARI>(p: i64, res: &mut R, res_col: usize, tmp: &mut [i64])
where
R: VecZnxToMut,
ZNXARI: ZnxRotate + ZnxNegate + ZnxSubBAInplace,
ZNXARI: ZnxRotate + ZnxNegate + ZnxSubNegateInplace,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
@@ -41,7 +41,7 @@ where
}
for j in 0..res.size() {
ZNXARI::znx_rotate(p, tmp, res.at(res_col, j));
ZNXARI::znx_sub_ba_inplace(res.at_mut(res_col, j), tmp);
ZNXARI::znx_sub_negate_inplace(res.at_mut(res_col, j), tmp);
}
}
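// Per limb j: res_j <- X^p * res_j - res_j, using tmp as the rotation buffer,
// so the whole column is multiplied by (X^p - 1) in place.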
@@ -49,7 +49,7 @@ pub fn bench_vec_znx_mul_xp_minus_one<B: Backend>(c: &mut Criterion, label: &str
where
Module<B>: VecZnxMulXpMinusOne + ModuleNew<B>,
{
let group_name: String = format!("vec_znx_mul_xp_minus_one::{}", label);
let group_name: String = format!("vec_znx_mul_xp_minus_one::{label}");
let mut group = c.benchmark_group(group_name);
@@ -94,7 +94,7 @@ where
Module<B>: VecZnxMulXpMinusOneInplace<B> + VecZnxMulXpMinusOneInplaceTmpBytes + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_mul_xp_minus_one_inplace::{}", label);
let group_name: String = format!("vec_znx_mul_xp_minus_one_inplace::{label}");
let mut group = c.benchmark_group(group_name);

View File

@@ -49,7 +49,7 @@ pub fn bench_vec_znx_negate<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxNegate + ModuleNew<B>,
{
let group_name: String = format!("vec_znx_negate::{}", label);
let group_name: String = format!("vec_znx_negate::{label}");
let mut group = c.benchmark_group(group_name);
@@ -93,7 +93,7 @@ pub fn bench_vec_znx_negate_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxNegateInplace + ModuleNew<B>,
{
let group_name: String = format!("vec_znx_negate_inplace::{}", label);
let group_name: String = format!("vec_znx_negate_inplace::{label}");
let mut group = c.benchmark_group(group_name);

View File

@@ -6,71 +6,204 @@ use crate::{
api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes},
layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
reference::znx::{
ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly,
ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace,
ZnxZero,
ZnxAddInplace, ZnxCopy, ZnxExtractDigitAddMul, ZnxMulPowerOfTwoInplace, ZnxNormalizeDigit, ZnxNormalizeFinalStep,
ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeFirstStepInplace,
ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, ZnxZero,
},
source::Source,
};
pub fn vec_znx_normalize_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
2 * n * size_of::<i64>()
}
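// The doubled scratch holds two n-word buffers: the cross-base2k path below splits
// it via split_at_mut(n) into a per-limb normalization buffer (a_norm) and the
// running carry.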
pub fn vec_znx_normalize<R, A, ZNXARI>(basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
where
pub fn vec_znx_normalize<R, A, ZNXARI>(
res_base2k: usize,
res: &mut R,
res_col: usize,
a_base2k: usize,
a: &A,
a_col: usize,
carry: &mut [i64],
) where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxZero
+ ZnxCopy
+ ZnxAddInplace
+ ZnxMulPowerOfTwoInplace
+ ZnxNormalizeFirstStepCarryOnly
+ ZnxNormalizeMiddleStepCarryOnly
+ ZnxNormalizeMiddleStep
+ ZnxNormalizeFinalStep
+ ZnxNormalizeFirstStep,
+ ZnxNormalizeFirstStep
+ ZnxExtractDigitAddMul
+ ZnxNormalizeDigit,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let a: VecZnx<&[u8]> = a.to_ref();
#[cfg(debug_assertions)]
{
assert!(carry.len() >= res.n());
assert!(carry.len() >= 2 * res.n());
assert_eq!(res.n(), a.n());
}
let n: usize = res.n();
let res_size: usize = res.size();
let a_size = a.size();
let a_size: usize = a.size();
if a_size > res_size {
for j in (res_size..a_size).rev() {
if j == a_size - 1 {
ZNXARI::znx_normalize_first_step_carry_only(basek, 0, a.at(a_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_carry_only(basek, 0, a.at(a_col, j), carry);
if res_base2k == a_base2k {
if a_size > res_size {
for j in (res_size..a_size).rev() {
if j == a_size - 1 {
ZNXARI::znx_normalize_first_step_carry_only(res_base2k, 0, a.at(a_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_carry_only(res_base2k, 0, a.at(a_col, j), carry);
}
}
for j in (1..res_size).rev() {
ZNXARI::znx_normalize_middle_step(res_base2k, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
}
ZNXARI::znx_normalize_final_step(res_base2k, 0, res.at_mut(res_col, 0), a.at(a_col, 0), carry);
} else {
for j in (0..a_size).rev() {
if j == a_size - 1 {
ZNXARI::znx_normalize_first_step(res_base2k, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
} else if j == 0 {
ZNXARI::znx_normalize_final_step(res_base2k, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step(res_base2k, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
}
}
for j in a_size..res_size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
}
for j in (1..res_size).rev() {
ZNXARI::znx_normalize_middle_step(basek, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
}
ZNXARI::znx_normalize_final_step(basek, 0, res.at_mut(res_col, 0), a.at(a_col, 0), carry);
} else {
for j in (0..a_size).rev() {
let (a_norm, carry) = carry.split_at_mut(n);
// Relevant limbs of res
let res_min_size: usize = (a_size * a_base2k).div_ceil(res_base2k).min(res_size);
// Relevant limbs of a
let a_min_size: usize = (res_size * res_base2k).div_ceil(a_base2k).min(a_size);
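// Worked example (hypothetical sizes): a_size = 8, a_base2k = 17 (136 bits of input)
// and res_size = 3, res_base2k = 50 (150 bits of output) give
// res_min_size = min(ceil(136/50), 3) = 3 and a_min_size = min(ceil(150/17), 8) = 8,
// so every limb of a contributes and every limb of res gets written.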
// Get carry for limbs of a that have higher precision than res
for j in (a_min_size..a_size).rev() {
if j == a_size - 1 {
ZNXARI::znx_normalize_first_step(basek, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
} else if j == 0 {
ZNXARI::znx_normalize_final_step(basek, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
ZNXARI::znx_normalize_first_step_carry_only(res_base2k, 0, a.at(a_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step(basek, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
ZNXARI::znx_normalize_middle_step_carry_only(res_base2k, 0, a.at(a_col, j), carry);
}
}
for j in a_size..res_size {
if a_min_size == a_size {
ZNXARI::znx_zero(carry);
}
// Maximum relevant precision of a
let a_prec: usize = a_min_size * a_base2k;
// Maximum relevant precision of res
let res_prec: usize = res_min_size * res_base2k;
// Res limb index
let mut res_idx: usize = res_min_size - 1;
// Tracker: how much of res is left to be populated
// for the current limb.
let mut res_left: usize = res_base2k;
for j in (0..a_min_size).rev() {
// Tracker: how much of a_norm is left to
// be flushed onto res.
let mut a_left: usize = a_base2k;
// Normalizes the j-th limb of a and stores the result into a_norm.
// This step is required to avoid overflow in the next step,
// which assumes that |a| is bounded by 2^{a_base2k - 1}.
if j != 0 {
ZNXARI::znx_normalize_middle_step(a_base2k, 0, a_norm, a.at(a_col, j), carry);
} else {
ZNXARI::znx_normalize_final_step(a_base2k, 0, a_norm, a.at(a_col, j), carry);
}
// In the first iteration we need to match the precision of the input and output.
// If a_min_size * a_base2k > res_min_size * res_base2k,
// divround a_norm by the precision difference and act as if
// a_norm had already been partially consumed.
// Otherwise act as if res had already been populated
// by the difference.
if j == a_min_size - 1 {
if a_prec > res_prec {
ZNXARI::znx_mul_power_of_two_inplace(res_prec as i64 - a_prec as i64, a_norm);
a_left -= a_prec - res_prec;
} else if res_prec > a_prec {
res_left -= res_prec - a_prec;
}
}
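// Continuing the example above: a_prec = 136 < res_prec = 150, so the top limb of
// res is treated as if its 14 most-significant bits were already populated and only
// res_left = 50 - 14 = 36 bits remain to be filled from a.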
// Flushes a into res
loop {
// Selects the maximum amount of a that can be flushed
let a_take: usize = a_base2k.min(a_left).min(res_left);
// Output limb
let res_slice: &mut [i64] = res.at_mut(res_col, res_idx);
// Scaling of the value to flush
let lsh: usize = res_base2k - res_left;
// Extracts the bits to flush onto the output and updates
// a_norm accordingly.
ZNXARI::znx_extract_digit_addmul(a_take, lsh, res_slice, a_norm);
// Updates the trackers
a_left -= a_take;
res_left -= a_take;
// If the current limb of res is full,
// then normalizes this limb and adds
// the carry onto a_norm.
if res_left == 0 {
// Updates tracker
res_left += res_base2k;
// Normalizes the res limb and propagates the carry onto a_norm.
ZNXARI::znx_normalize_digit(res_base2k, res_slice, a_norm);
// If we reached the last limb of res, break,
// but we might rerun the above loop if the
// base2k of a is much smaller than the base2k
// of res.
if res_idx == 0 {
ZNXARI::znx_add_inplace(carry, a_norm);
break;
}
// Else updates the limb index of res.
res_idx -= 1
}
// If a_norm is exhausted, break the loop.
if a_left == 0 {
ZNXARI::znx_add_inplace(carry, a_norm);
break;
}
}
}
for j in res_min_size..res_size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
}
}
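// A minimal standalone sketch (not part of the crate) of the same re-encoding idea on
// plain unsigned integers: regroup a value given as big-endian base-2^a_k digits into
// big-endian base-2^r_k digits of at least the same precision. Signed (balanced) digits
// and the carry normalization handled above are deliberately ignored; the function and
// its names are illustrative only.
fn reencode_base2k(a_k: u32, r_k: u32, a_digits: &[u64]) -> Vec<u64> {
    // Total precision carried by the input digits (each assumed in [0, 2^a_k)).
    let prec: u32 = a_k * a_digits.len() as u32;
    // Number of output digits needed to hold it.
    let r_size: usize = prec.div_ceil(r_k) as usize;
    // Accumulate into one big integer (assumes it fits in 128 bits for this toy example).
    let mut v: u128 = 0;
    for &d in a_digits {
        v = (v << a_k) | u128::from(d);
    }
    // Align to the output precision (the sketch only ever pads; the routine above also
    // handles the divround case where the input has more precision than the output).
    let r_prec: u32 = r_k * r_size as u32;
    v <<= r_prec - prec;
    // Peel base-2^r_k digits from most significant to least significant.
    let mask: u128 = (1u128 << r_k) - 1;
    (0..r_size)
        .rev()
        .map(|i| ((v >> (i as u32 * r_k)) & mask) as u64)
        .collect()
}
// Example: reencode_base2k(4, 3, &[0xA, 0xB]) regroups the 8-bit value 0xAB into three
// 3-bit digits [5, 2, 6] (i.e. 0b101_010_110 after left-aligning to 9 bits).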
pub fn vec_znx_normalize_inplace<R: VecZnxToMut, ZNXARI>(basek: usize, res: &mut R, res_col: usize, carry: &mut [i64])
pub fn vec_znx_normalize_inplace<R: VecZnxToMut, ZNXARI>(base2k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
where
ZNXARI: ZnxNormalizeFirstStepInplace + ZnxNormalizeMiddleStepInplace + ZnxNormalizeFinalStepInplace,
{
@@ -85,11 +218,11 @@ where
for j in (0..res_size).rev() {
if j == res_size - 1 {
ZNXARI::znx_normalize_first_step_inplace(basek, 0, res.at_mut(res_col, j), carry);
ZNXARI::znx_normalize_first_step_inplace(base2k, 0, res.at_mut(res_col, j), carry);
} else if j == 0 {
ZNXARI::znx_normalize_final_step_inplace(basek, 0, res.at_mut(res_col, j), carry);
ZNXARI::znx_normalize_final_step_inplace(base2k, 0, res.at_mut(res_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_inplace(basek, 0, res.at_mut(res_col, j), carry);
ZNXARI::znx_normalize_middle_step_inplace(base2k, 0, res.at_mut(res_col, j), carry);
}
}
}
@@ -99,7 +232,7 @@ where
Module<B>: VecZnxNormalize<B> + ModuleNew<B> + VecZnxNormalizeTmpBytes,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_normalize::{}", label);
let group_name: String = format!("vec_znx_normalize::{label}");
let mut group = c.benchmark_group(group_name);
@@ -114,7 +247,7 @@ where
let module: Module<B> = Module::<B>::new(n as u64);
let basek: usize = 50;
let base2k: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
@@ -129,7 +262,7 @@ where
move || {
for i in 0..cols {
module.vec_znx_normalize(basek, &mut res, i, &a, i, scratch.borrow());
module.vec_znx_normalize(base2k, &mut res, i, base2k, &a, i, scratch.borrow());
}
black_box(());
}
@@ -149,7 +282,7 @@ where
Module<B>: VecZnxNormalizeInplace<B> + ModuleNew<B> + VecZnxNormalizeTmpBytes,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_normalize_inplace::{}", label);
let group_name: String = format!("vec_znx_normalize_inplace::{label}");
let mut group = c.benchmark_group(group_name);
@@ -164,7 +297,7 @@ where
let module: Module<B> = Module::<B>::new(n as u64);
let basek: usize = 50;
let base2k: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
@@ -177,7 +310,7 @@ where
move || {
for i in 0..cols {
module.vec_znx_normalize_inplace(basek, &mut a, i, scratch.borrow());
module.vec_znx_normalize_inplace(base2k, &mut a, i, scratch.borrow());
}
black_box(());
}
@@ -191,3 +324,83 @@ where
group.finish();
}
#[test]
fn test_vec_znx_normalize_conv() {
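// Checks that converting a vector encoded in base 2^start_base2k into base 2^end_base2k
// with vec_znx_normalize agrees, up to roughly one bit of the output precision, with
// encoding the same i128 data directly in base 2^end_base2k and normalizing in place.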
let n: usize = 8;
let mut carry: Vec<i64> = vec![0i64; 2 * n];
use crate::reference::znx::ZnxRef;
use rug::ops::SubAssignRound;
use rug::{Float, float::Round};
let mut source: Source = Source::new([1u8; 32]);
let prec: usize = 128;
let mut data: Vec<i128> = vec![0i128; n];
data.iter_mut().for_each(|x| *x = source.next_i128());
for start_base2k in 1..50 {
for end_base2k in 1..50 {
let end_size: usize = prec.div_ceil(end_base2k);
let mut want: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, end_size);
want.encode_vec_i128(end_base2k, 0, prec, &data);
vec_znx_normalize_inplace::<_, ZnxRef>(end_base2k, &mut want, 0, &mut carry);
// Creates a temporary poly where encoding is in start_base2k
let mut tmp: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, prec.div_ceil(start_base2k));
tmp.encode_vec_i128(start_base2k, 0, prec, &data);
vec_znx_normalize_inplace::<_, ZnxRef>(start_base2k, &mut tmp, 0, &mut carry);
let mut data_tmp: Vec<Float> = (0..n).map(|_| Float::with_val(prec as u32, 0)).collect();
tmp.decode_vec_float(start_base2k, 0, &mut data_tmp);
let mut have: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, end_size);
vec_znx_normalize::<_, _, ZnxRef>(end_base2k, &mut have, 0, start_base2k, &tmp, 0, &mut carry);
let out_prec: u32 = (end_size * end_base2k) as u32;
let mut data_have: Vec<Float> = (0..n).map(|_| Float::with_val(out_prec, 0)).collect();
let mut data_want: Vec<Float> = (0..n).map(|_| Float::with_val(out_prec, 0)).collect();
have.decode_vec_float(end_base2k, 0, &mut data_have);
want.decode_vec_float(end_base2k, 0, &mut data_want);
for i in 0..n {
let mut err: Float = data_have[i].clone();
err.sub_assign_round(&data_want[i], Round::Nearest);
err = err.abs();
// println!(
//     "want: {} have: {} tmp: {} |have - want|: {}",
//     data_want[i].to_f64(),
//     data_have[i].to_f64(),
//     data_tmp[i].to_f64(),
//     err.to_f64()
// );
let err_log2: f64 = err
.clone()
.max(&Float::with_val(prec as u32, 1e-60))
.log2()
.to_f64();
assert!(
err_log2 <= -(out_prec as f64) + 1.,
"{} {}",
err_log2,
-(out_prec as f64) + 1.
)
}
}
}
}

View File

@@ -61,7 +61,7 @@ pub fn bench_vec_znx_rotate<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxRotate + ModuleNew<B>,
{
let group_name: String = format!("vec_znx_rotate::{}", label);
let group_name: String = format!("vec_znx_rotate::{label}");
let mut group = c.benchmark_group(group_name);
@@ -106,7 +106,7 @@ where
Module<B>: VecZnxRotateInplace<B> + VecZnxRotateInplaceTmpBytes + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_rotate_inplace::{}", label);
let group_name: String = format!("vec_znx_rotate_inplace::{label}");
let mut group = c.benchmark_group(group_name);

View File

@@ -4,18 +4,18 @@ use crate::{
source::Source,
};
pub fn vec_znx_fill_uniform_ref<R>(basek: usize, res: &mut R, res_col: usize, source: &mut Source)
pub fn vec_znx_fill_uniform_ref<R>(base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: VecZnxToMut,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
for j in 0..res.size() {
znx_fill_uniform_ref(basek, res.at_mut(res_col, j), source)
znx_fill_uniform_ref(base2k, res.at_mut(res_col, j), source)
}
}
pub fn vec_znx_fill_normal_ref<R>(
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -32,8 +32,8 @@ pub fn vec_znx_fill_normal_ref<R>(
(bound.log2().ceil() as i64)
);
let limb: usize = k.div_ceil(basek) - 1;
let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64;
let limb: usize = k.div_ceil(base2k) - 1;
let scale: f64 = (1 << ((limb + 1) * base2k - k)) as f64;
znx_fill_normal_f64_ref(
res.at_mut(res_col, limb),
sigma * scale,
@@ -42,8 +42,15 @@ pub fn vec_znx_fill_normal_ref<R>(
)
}
pub fn vec_znx_add_normal_ref<R>(basek: usize, res: &mut R, res_col: usize, k: usize, sigma: f64, bound: f64, source: &mut Source)
where
pub fn vec_znx_add_normal_ref<R>(
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
sigma: f64,
bound: f64,
source: &mut Source,
) where
R: VecZnxToMut,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
@@ -53,8 +60,8 @@ where
(bound.log2().ceil() as i64)
);
let limb: usize = k.div_ceil(basek) - 1;
let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64;
let limb: usize = k.div_ceil(base2k) - 1;
let scale: f64 = (1 << ((limb + 1) * base2k - k)) as f64;
znx_add_normal_f64_ref(
res.at_mut(res_col, limb),
sigma * scale,

View File

@@ -20,7 +20,7 @@ pub fn vec_znx_lsh_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
}
pub fn vec_znx_lsh_inplace<R, ZNXARI>(basek: usize, k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
pub fn vec_znx_lsh_inplace<R, ZNXARI>(base2k: usize, k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
where
R: VecZnxToMut,
ZNXARI: ZnxZero
@@ -35,8 +35,8 @@ where
let n: usize = res.n();
let cols: usize = res.cols();
let size: usize = res.size();
let steps: usize = k / basek;
let k_rem: usize = k % basek;
let steps: usize = k / base2k;
let k_rem: usize = k % base2k;
if steps >= size {
for j in 0..size {
@@ -45,7 +45,7 @@ where
return;
}
// Inplace shift of limbs by a k/basek
// Inplace shift of limbs by k/base2k
if steps > 0 {
let start: usize = n * res_col;
let end: usize = start + n;
@@ -65,21 +65,21 @@ where
}
}
// Inplace normalization with left shift of k % basek
if !k.is_multiple_of(basek) {
// Inplace normalization with left shift of k % base2k
if !k.is_multiple_of(base2k) {
for j in (0..size - steps).rev() {
if j == size - steps - 1 {
ZNXARI::znx_normalize_first_step_inplace(basek, k_rem, res.at_mut(res_col, j), carry);
ZNXARI::znx_normalize_first_step_inplace(base2k, k_rem, res.at_mut(res_col, j), carry);
} else if j == 0 {
ZNXARI::znx_normalize_final_step_inplace(basek, k_rem, res.at_mut(res_col, j), carry);
ZNXARI::znx_normalize_final_step_inplace(base2k, k_rem, res.at_mut(res_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_inplace(basek, k_rem, res.at_mut(res_col, j), carry);
ZNXARI::znx_normalize_middle_step_inplace(base2k, k_rem, res.at_mut(res_col, j), carry);
}
}
}
}
pub fn vec_znx_lsh<R, A, ZNXARI>(basek: usize, k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
pub fn vec_znx_lsh<R, A, ZNXARI>(base2k: usize, k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
where
R: VecZnxToMut,
A: VecZnxToRef,
@@ -90,8 +90,8 @@ where
let res_size: usize = res.size();
let a_size = a.size();
let steps: usize = k / basek;
let k_rem: usize = k % basek;
let steps: usize = k / base2k;
let k_rem: usize = k % base2k;
if steps >= res_size.min(a_size) {
for j in 0..res_size {
@@ -103,12 +103,12 @@ where
let min_size: usize = a_size.min(res_size) - steps;
// Simply a left-shifted normalization of the limbs
// by k/basek and intra-limb by basek - k%basek
if !k.is_multiple_of(basek) {
// by k/base2k and intra-limb by base2k - k%base2k
if !k.is_multiple_of(base2k) {
for j in (0..min_size).rev() {
if j == min_size - 1 {
ZNXARI::znx_normalize_first_step(
basek,
base2k,
k_rem,
res.at_mut(res_col, j),
a.at(a_col, j + steps),
@@ -116,7 +116,7 @@ where
);
} else if j == 0 {
ZNXARI::znx_normalize_final_step(
basek,
base2k,
k_rem,
res.at_mut(res_col, j),
a.at(a_col, j + steps),
@@ -124,7 +124,7 @@ where
);
} else {
ZNXARI::znx_normalize_middle_step(
basek,
base2k,
k_rem,
res.at_mut(res_col, j),
a.at(a_col, j + steps),
@@ -133,7 +133,7 @@ where
}
}
} else {
// If k % basek = 0, then this is simply a copy.
// If k % base2k = 0, then this is simply a copy.
for j in (0..min_size).rev() {
ZNXARI::znx_copy(res.at_mut(res_col, j), a.at(a_col, j + steps));
}
@@ -149,7 +149,7 @@ pub fn vec_znx_rsh_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
}
pub fn vec_znx_rsh_inplace<R, ZNXARI>(basek: usize, k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
pub fn vec_znx_rsh_inplace<R, ZNXARI>(base2k: usize, k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
where
R: VecZnxToMut,
ZNXARI: ZnxZero
@@ -166,8 +166,8 @@ where
let cols: usize = res.cols();
let size: usize = res.size();
let mut steps: usize = k / basek;
let k_rem: usize = k % basek;
let mut steps: usize = k / base2k;
let k_rem: usize = k % base2k;
if k == 0 {
return;
@@ -184,8 +184,8 @@ where
let end: usize = start + n;
let slice_size: usize = n * cols;
if !k.is_multiple_of(basek) {
// We rsh by an additional basek and then lsh by basek-k
if !k.is_multiple_of(base2k) {
// We rsh by an additional base2k and then lsh by base2k-k
// Allows reusing the efficient normalization code, avoids overflows,
// and produces output that is normalized.
steps += 1;
@@ -194,9 +194,9 @@ where
// but the carry still needs to be computed.
(size - steps..size).rev().for_each(|j| {
if j == size - 1 {
ZNXARI::znx_normalize_first_step_carry_only(basek, basek - k_rem, res.at(res_col, j), carry);
ZNXARI::znx_normalize_first_step_carry_only(base2k, base2k - k_rem, res.at(res_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_carry_only(basek, basek - k_rem, res.at(res_col, j), carry);
ZNXARI::znx_normalize_middle_step_carry_only(base2k, base2k - k_rem, res.at(res_col, j), carry);
}
});
@@ -206,20 +206,20 @@ where
let (lhs, rhs) = res_raw.split_at_mut(slice_size * j);
let rhs_slice: &mut [i64] = &mut rhs[start..end];
let lhs_slice: &[i64] = &lhs[(j - steps) * slice_size + start..(j - steps) * slice_size + end];
ZNXARI::znx_normalize_middle_step(basek, basek - k_rem, rhs_slice, lhs_slice, carry);
ZNXARI::znx_normalize_middle_step(base2k, base2k - k_rem, rhs_slice, lhs_slice, carry);
});
// Propagates the carry through the remaining limbs of res
for j in (0..steps).rev() {
ZNXARI::znx_zero(res.at_mut(res_col, j));
if j == 0 {
ZNXARI::znx_normalize_final_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry);
ZNXARI::znx_normalize_final_step_inplace(base2k, base2k - k_rem, res.at_mut(res_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry);
ZNXARI::znx_normalize_middle_step_inplace(base2k, base2k - k_rem, res.at_mut(res_col, j), carry);
}
}
} else {
// Shift by multiples of basek
// Shift by multiples of base2k
let res_raw: &mut [i64] = res.raw_mut();
(steps..size).rev().for_each(|j| {
let (lhs, rhs) = res_raw.split_at_mut(slice_size * j);
@@ -236,7 +236,7 @@ where
}
}
pub fn vec_znx_rsh<R, A, ZNXARI>(basek: usize, k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
pub fn vec_znx_rsh<R, A, ZNXARI>(base2k: usize, k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
where
R: VecZnxToMut,
A: VecZnxToRef,
@@ -256,8 +256,8 @@ where
let res_size: usize = res.size();
let a_size: usize = a.size();
let mut steps: usize = k / basek;
let k_rem: usize = k % basek;
let mut steps: usize = k / base2k;
let k_rem: usize = k % base2k;
if k == 0 {
vec_znx_copy::<_, _, ZNXARI>(&mut res, res_col, &a, a_col);
@@ -271,8 +271,8 @@ where
return;
}
if !k.is_multiple_of(basek) {
// We rsh by an additional basek and then lsh by basek-k
if !k.is_multiple_of(base2k) {
// We rsh by an additional base2k and then lsh by base2k-k
// Allows reusing the efficient normalization code, avoids overflows,
// and produces output that is normalized.
steps += 1;
@@ -281,9 +281,9 @@ where
// but the carry still needs to be computed.
for j in (res_size..a_size + steps).rev() {
if j == a_size + steps - 1 {
ZNXARI::znx_normalize_first_step_carry_only(basek, basek - k_rem, a.at(a_col, j - steps), carry);
ZNXARI::znx_normalize_first_step_carry_only(base2k, base2k - k_rem, a.at(a_col, j - steps), carry);
} else {
ZNXARI::znx_normalize_middle_step_carry_only(basek, basek - k_rem, a.at(a_col, j - steps), carry);
ZNXARI::znx_normalize_middle_step_carry_only(base2k, base2k - k_rem, a.at(a_col, j - steps), carry);
}
}
@@ -300,16 +300,16 @@ where
// Case if no limb of a was previously discarded
if res_size.saturating_sub(steps) >= a_size && j == min_size - 1 {
ZNXARI::znx_normalize_first_step(
basek,
basek - k_rem,
base2k,
base2k - k_rem,
res.at_mut(res_col, j),
a.at(a_col, j - steps),
carry,
);
} else {
ZNXARI::znx_normalize_middle_step(
basek,
basek - k_rem,
base2k,
base2k - k_rem,
res.at_mut(res_col, j),
a.at(a_col, j - steps),
carry,
@@ -321,9 +321,9 @@ where
for j in (0..steps).rev() {
ZNXARI::znx_zero(res.at_mut(res_col, j));
if j == 0 {
ZNXARI::znx_normalize_final_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry);
ZNXARI::znx_normalize_final_step_inplace(base2k, base2k - k_rem, res.at_mut(res_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry);
ZNXARI::znx_normalize_middle_step_inplace(base2k, base2k - k_rem, res.at_mut(res_col, j), carry);
}
}
} else {
@@ -351,7 +351,7 @@ where
Module<B>: ModuleNew<B> + VecZnxLshInplace<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_lsh_inplace::{}", label);
let group_name: String = format!("vec_znx_lsh_inplace::{label}");
let mut group = c.benchmark_group(group_name);
@@ -366,7 +366,7 @@ where
let module: Module<B> = Module::<B>::new(n as u64);
let basek: usize = 50;
let base2k: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
@@ -381,7 +381,7 @@ where
move || {
for i in 0..cols {
module.vec_znx_lsh_inplace(basek, basek - 1, &mut b, i, scratch.borrow());
module.vec_znx_lsh_inplace(base2k, base2k - 1, &mut b, i, scratch.borrow());
}
black_box(());
}
@@ -401,7 +401,7 @@ where
Module<B>: VecZnxLsh<B> + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_lsh::{}", label);
let group_name: String = format!("vec_znx_lsh::{label}");
let mut group = c.benchmark_group(group_name);
@@ -416,7 +416,7 @@ where
let module: Module<B> = Module::<B>::new(n as u64);
let basek: usize = 50;
let base2k: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
@@ -431,7 +431,7 @@ where
move || {
for i in 0..cols {
module.vec_znx_lsh(basek, basek - 1, &mut res, i, &a, i, scratch.borrow());
module.vec_znx_lsh(base2k, base2k - 1, &mut res, i, &a, i, scratch.borrow());
}
black_box(());
}
@@ -451,7 +451,7 @@ where
Module<B>: VecZnxRshInplace<B> + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_rsh_inplace::{}", label);
let group_name: String = format!("vec_znx_rsh_inplace::{label}");
let mut group = c.benchmark_group(group_name);
@@ -466,7 +466,7 @@ where
let module: Module<B> = Module::<B>::new(n as u64);
let basek: usize = 50;
let base2k: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
@@ -481,7 +481,7 @@ where
move || {
for i in 0..cols {
module.vec_znx_rsh_inplace(basek, basek - 1, &mut b, i, scratch.borrow());
module.vec_znx_rsh_inplace(base2k, base2k - 1, &mut b, i, scratch.borrow());
}
black_box(());
}
@@ -501,7 +501,7 @@ where
Module<B>: VecZnxRsh<B> + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_rsh::{}", label);
let group_name: String = format!("vec_znx_rsh::{label}");
let mut group = c.benchmark_group(group_name);
@@ -516,7 +516,7 @@ where
let module: Module<B> = Module::<B>::new(n as u64);
let basek: usize = 50;
let base2k: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
@@ -531,7 +531,7 @@ where
move || {
for i in 0..cols {
module.vec_znx_rsh(basek, basek - 1, &mut res, i, &a, i, scratch.borrow());
module.vec_znx_rsh(base2k, base2k - 1, &mut res, i, &a, i, scratch.borrow());
}
black_box(());
}
@@ -553,7 +553,7 @@ mod tests {
reference::{
vec_znx::{
vec_znx_copy, vec_znx_lsh, vec_znx_lsh_inplace, vec_znx_normalize_inplace, vec_znx_rsh, vec_znx_rsh_inplace,
vec_znx_sub_ab_inplace,
vec_znx_sub_inplace,
},
znx::ZnxRef,
},
@@ -574,20 +574,20 @@ mod tests {
let mut carry: Vec<i64> = vec![0i64; n];
let basek: usize = 50;
let base2k: usize = 50;
for k in 0..256 {
a.fill_uniform(50, &mut source);
for i in 0..cols {
vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut a, i, &mut carry);
vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut a, i, &mut carry);
vec_znx_copy::<_, _, ZnxRef>(&mut res_ref, i, &a, i);
}
for i in 0..cols {
vec_znx_lsh_inplace::<_, ZnxRef>(basek, k, &mut res_ref, i, &mut carry);
vec_znx_lsh::<_, _, ZnxRef>(basek, k, &mut res_test, i, &a, i, &mut carry);
vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut res_test, i, &mut carry);
vec_znx_lsh_inplace::<_, ZnxRef>(base2k, k, &mut res_ref, i, &mut carry);
vec_znx_lsh::<_, _, ZnxRef>(base2k, k, &mut res_test, i, &a, i, &mut carry);
vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut res_test, i, &mut carry);
}
assert_eq!(res_ref, res_test);
@@ -606,7 +606,7 @@ mod tests {
let mut carry: Vec<i64> = vec![0i64; n];
let basek: usize = 50;
let base2k: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
@@ -615,29 +615,29 @@ mod tests {
for a_size in [res_size - 1, res_size, res_size + 1] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
for k in 0..res_size * basek {
for k in 0..res_size * base2k {
a.fill_uniform(50, &mut source);
for i in 0..cols {
vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut a, i, &mut carry);
vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut a, i, &mut carry);
vec_znx_copy::<_, _, ZnxRef>(&mut res_ref, i, &a, i);
}
res_test.fill_uniform(50, &mut source);
for j in 0..cols {
vec_znx_rsh_inplace::<_, ZnxRef>(basek, k, &mut res_ref, j, &mut carry);
vec_znx_rsh::<_, _, ZnxRef>(basek, k, &mut res_test, j, &a, j, &mut carry);
vec_znx_rsh_inplace::<_, ZnxRef>(base2k, k, &mut res_ref, j, &mut carry);
vec_znx_rsh::<_, _, ZnxRef>(base2k, k, &mut res_test, j, &a, j, &mut carry);
}
for j in 0..cols {
vec_znx_lsh_inplace::<_, ZnxRef>(basek, k, &mut res_ref, j, &mut carry);
vec_znx_lsh_inplace::<_, ZnxRef>(basek, k, &mut res_test, j, &mut carry);
vec_znx_lsh_inplace::<_, ZnxRef>(base2k, k, &mut res_ref, j, &mut carry);
vec_znx_lsh_inplace::<_, ZnxRef>(base2k, k, &mut res_test, j, &mut carry);
}
// Case where res has enough limbs to fully store the right-shifted input without any loss.
// In this case we can check exact equality.
if a_size + k.div_ceil(basek) <= res_size {
if a_size + k.div_ceil(base2k) <= res_size {
assert_eq!(res_ref, res_test);
for i in 0..cols {
@@ -656,14 +656,14 @@ mod tests {
// res.
} else {
for j in 0..cols {
vec_znx_sub_ab_inplace::<_, _, ZnxRef>(&mut res_ref, j, &a, j);
vec_znx_sub_ab_inplace::<_, _, ZnxRef>(&mut res_test, j, &a, j);
vec_znx_sub_inplace::<_, _, ZnxRef>(&mut res_ref, j, &a, j);
vec_znx_sub_inplace::<_, _, ZnxRef>(&mut res_test, j, &a, j);
vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut res_ref, j, &mut carry);
vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut res_test, j, &mut carry);
vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut res_ref, j, &mut carry);
vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut res_test, j, &mut carry);
assert!(res_ref.std(basek, j).log2() - (k as f64) <= (k * basek) as f64);
assert!(res_test.std(basek, j).log2() - (k as f64) <= (k * basek) as f64);
assert!(res_ref.std(base2k, j).log2() - (k as f64) <= (k * base2k) as f64);
assert!(res_test.std(base2k, j).log2() - (k as f64) <= (k * base2k) as f64);
}
}
}

View File

@@ -3,10 +3,10 @@ use std::hint::black_box;
use criterion::{BenchmarkId, Criterion};
use crate::{
api::{ModuleNew, VecZnxSub, VecZnxSubABInplace, VecZnxSubBAInplace},
api::{ModuleNew, VecZnxSub, VecZnxSubInplace, VecZnxSubNegateInplace},
layouts::{Backend, FillUniform, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
oep::{ModuleNewImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl},
reference::znx::{ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxSub, ZnxSubABInplace, ZnxSubBAInplace, ZnxZero},
oep::{ModuleNewImpl, VecZnxSubImpl, VecZnxSubInplaceImpl, VecZnxSubNegateInplaceImpl},
reference::znx::{ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxSub, ZnxSubInplace, ZnxSubNegateInplace, ZnxZero},
source::Source,
};
@@ -64,11 +64,11 @@ where
}
}
pub fn vec_znx_sub_ab_inplace<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize)
pub fn vec_znx_sub_inplace<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxSubABInplace,
ZNXARI: ZnxSubInplace,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
@@ -84,15 +84,15 @@ where
let sum_size: usize = a_size.min(res_size);
for j in 0..sum_size {
ZNXARI::znx_sub_ab_inplace(res.at_mut(res_col, j), a.at(a_col, j));
ZNXARI::znx_sub_inplace(res.at_mut(res_col, j), a.at(a_col, j));
}
}
pub fn vec_znx_sub_ba_inplace<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize)
pub fn vec_znx_sub_negate_inplace<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxSubBAInplace + ZnxNegateInplace,
ZNXARI: ZnxSubNegateInplace + ZnxNegateInplace,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
@@ -108,7 +108,7 @@ where
let sum_size: usize = a_size.min(res_size);
for j in 0..sum_size {
ZNXARI::znx_sub_ba_inplace(res.at_mut(res_col, j), a.at(a_col, j));
ZNXARI::znx_sub_negate_inplace(res.at_mut(res_col, j), a.at(a_col, j));
}
for j in sum_size..res_size {
@@ -120,7 +120,7 @@ pub fn bench_vec_znx_sub<B>(c: &mut Criterion, label: &str)
where
B: Backend + ModuleNewImpl<B> + VecZnxSubImpl<B>,
{
let group_name: String = format!("vec_znx_sub::{}", label);
let group_name: String = format!("vec_znx_sub::{label}");
let mut group = c.benchmark_group(group_name);
@@ -161,17 +161,17 @@ where
group.finish();
}
pub fn bench_vec_znx_sub_ab_inplace<B>(c: &mut Criterion, label: &str)
pub fn bench_vec_znx_sub_inplace<B>(c: &mut Criterion, label: &str)
where
B: Backend + ModuleNewImpl<B> + VecZnxSubABInplaceImpl<B>,
B: Backend + ModuleNewImpl<B> + VecZnxSubInplaceImpl<B>,
{
let group_name: String = format!("vec_znx_sub_ab_inplace::{}", label);
let group_name: String = format!("vec_znx_sub_inplace::{label}");
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxSubABInplace + ModuleNew<B>,
Module<B>: VecZnxSubInplace + ModuleNew<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
@@ -190,7 +190,7 @@ where
move || {
for i in 0..cols {
module.vec_znx_sub_ab_inplace(&mut b, i, &a, i);
module.vec_znx_sub_inplace(&mut b, i, &a, i);
}
black_box(());
}
@@ -205,17 +205,17 @@ where
group.finish();
}
pub fn bench_vec_znx_sub_ba_inplace<B>(c: &mut Criterion, label: &str)
pub fn bench_vec_znx_sub_negate_inplace<B>(c: &mut Criterion, label: &str)
where
B: Backend + ModuleNewImpl<B> + VecZnxSubBAInplaceImpl<B>,
B: Backend + ModuleNewImpl<B> + VecZnxSubNegateInplaceImpl<B>,
{
let group_name: String = format!("vec_znx_sub_ba_inplace::{}", label);
let group_name: String = format!("vec_znx_sub_negate_inplace::{label}");
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxSubBAInplace + ModuleNew<B>,
Module<B>: VecZnxSubNegateInplace + ModuleNew<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
@@ -234,7 +234,7 @@ where
move || {
for i in 0..cols {
module.vec_znx_sub_ba_inplace(&mut b, i, &a, i);
module.vec_znx_sub_negate_inplace(&mut b, i, &a, i);
}
black_box(());
}

View File

@@ -1,7 +1,7 @@
use crate::layouts::{ScalarZnxToRef, VecZnxToMut, VecZnxToRef};
use crate::{
layouts::{ScalarZnx, VecZnx, ZnxInfos, ZnxView, ZnxViewMut},
reference::znx::{ZnxSub, ZnxSubABInplace, ZnxZero},
reference::znx::{ZnxSub, ZnxSubInplace, ZnxZero},
};
pub fn vec_znx_sub_scalar<R, A, B, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize, b_limb: usize)
@@ -19,12 +19,7 @@ where
#[cfg(debug_assertions)]
{
assert!(
b_limb < min_size,
"b_limb: {} > min_size: {}",
b_limb,
min_size
);
assert!(b_limb < min_size, "b_limb: {b_limb} > min_size: {min_size}");
}
for j in 0..min_size {
@@ -44,7 +39,7 @@ pub fn vec_znx_sub_scalar_inplace<R, A, ZNXARI>(res: &mut R, res_col: usize, res
where
R: VecZnxToMut,
A: ScalarZnxToRef,
ZNXARI: ZnxSubABInplace,
ZNXARI: ZnxSubInplace,
{
let a: ScalarZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
@@ -54,5 +49,5 @@ where
assert!(res_limb < res.size());
}
ZNXARI::znx_sub_ab_inplace(res.at_mut(res_col, res_limb), a.at(a_col, 0));
ZNXARI::znx_sub_inplace(res.at_mut(res_col, res_limb), a.at(a_col, 0));
}

View File

@@ -9,7 +9,7 @@ pub fn zn_normalize_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
}
pub fn zn_normalize_inplace<R, ARI>(n: usize, basek: usize, res: &mut R, res_col: usize, carry: &mut [i64])
pub fn zn_normalize_inplace<R, ARI>(n: usize, base2k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
where
R: ZnToMut,
ARI: ZnxNormalizeFirstStepInplace + ZnxNormalizeFinalStepInplace + ZnxNormalizeMiddleStepInplace,
@@ -27,11 +27,11 @@ where
let out = &mut res.at_mut(res_col, j)[..n];
if j == res_size - 1 {
ARI::znx_normalize_first_step_inplace(basek, 0, out, carry);
ARI::znx_normalize_first_step_inplace(base2k, 0, out, carry);
} else if j == 0 {
ARI::znx_normalize_final_step_inplace(basek, 0, out, carry);
ARI::znx_normalize_final_step_inplace(base2k, 0, out, carry);
} else {
ARI::znx_normalize_middle_step_inplace(basek, 0, out, carry);
ARI::znx_normalize_middle_step_inplace(base2k, 0, out, carry);
}
}
}
@@ -43,7 +43,7 @@ where
{
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let basek: usize = 12;
let base2k: usize = 12;
let n = 33;
@@ -63,8 +63,8 @@ where
// Reference
for i in 0..cols {
zn_normalize_inplace::<_, ZnxRef>(n, basek, &mut res_0, i, &mut carry);
module.zn_normalize_inplace(n, basek, &mut res_1, i, scratch.borrow());
zn_normalize_inplace::<_, ZnxRef>(n, base2k, &mut res_0, i, &mut carry);
module.zn_normalize_inplace(n, base2k, &mut res_1, i, scratch.borrow());
}
assert_eq!(res_0.raw(), res_1.raw());

View File

@@ -4,20 +4,20 @@ use crate::{
source::Source,
};
pub fn zn_fill_uniform<R>(n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
pub fn zn_fill_uniform<R>(n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: ZnToMut,
{
let mut res: Zn<&mut [u8]> = res.to_mut();
for j in 0..res.size() {
znx_fill_uniform_ref(basek, &mut res.at_mut(res_col, j)[..n], source)
znx_fill_uniform_ref(base2k, &mut res.at_mut(res_col, j)[..n], source)
}
}
#[allow(clippy::too_many_arguments)]
pub fn zn_fill_normal<R>(
n: usize,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -34,8 +34,8 @@ pub fn zn_fill_normal<R>(
(bound.log2().ceil() as i64)
);
let limb: usize = k.div_ceil(basek) - 1;
let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64;
let limb: usize = k.div_ceil(base2k) - 1;
let scale: f64 = (1 << ((limb + 1) * base2k - k)) as f64;
znx_fill_normal_f64_ref(
&mut res.at_mut(res_col, limb)[..n],
sigma * scale,
@@ -47,7 +47,7 @@ pub fn zn_fill_normal<R>(
#[allow(clippy::too_many_arguments)]
pub fn zn_add_normal<R>(
n: usize,
basek: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
@@ -64,8 +64,8 @@ pub fn zn_add_normal<R>(
(bound.log2().ceil() as i64)
);
let limb: usize = k.div_ceil(basek) - 1;
let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64;
let limb: usize = k.div_ceil(base2k) - 1;
let scale: f64 = (1 << ((limb + 1) * base2k - k)) as f64;
znx_add_normal_f64_ref(
&mut res.at_mut(res_col, limb)[..n],
sigma * scale,

View File

@@ -1,8 +1,9 @@
use crate::reference::znx::{
ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxNormalizeFinalStep,
ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeFirstStepInplace,
ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, ZnxSub, ZnxSubABInplace,
ZnxSubBAInplace, ZnxSwitchRing, ZnxZero,
ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxExtractDigitAddMul, ZnxMulAddPowerOfTwo, ZnxMulPowerOfTwo,
ZnxMulPowerOfTwoInplace, ZnxNegate, ZnxNegateInplace, ZnxNormalizeDigit, ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace,
ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep,
ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, ZnxSub, ZnxSubInplace, ZnxSubNegateInplace, ZnxSwitchRing,
ZnxZero,
add::{znx_add_inplace_ref, znx_add_ref},
automorphism::znx_automorphism_ref,
copy::znx_copy_ref,
@@ -12,9 +13,11 @@ use crate::reference::znx::{
znx_normalize_first_step_inplace_ref, znx_normalize_first_step_ref, znx_normalize_middle_step_carry_only_ref,
znx_normalize_middle_step_inplace_ref, znx_normalize_middle_step_ref,
},
sub::{znx_sub_ab_inplace_ref, znx_sub_ba_inplace_ref, znx_sub_ref},
sub::{znx_sub_inplace_ref, znx_sub_negate_inplace_ref, znx_sub_ref},
switch_ring::znx_switch_ring_ref,
zero::znx_zero_ref,
znx_extract_digit_addmul_ref, znx_mul_add_power_of_two_ref, znx_mul_power_of_two_inplace_ref, znx_mul_power_of_two_ref,
znx_normalize_digit_ref,
};
pub struct ZnxRef {}
@@ -40,17 +43,17 @@ impl ZnxSub for ZnxRef {
}
}
impl ZnxSubABInplace for ZnxRef {
impl ZnxSubInplace for ZnxRef {
#[inline(always)]
fn znx_sub_ab_inplace(res: &mut [i64], a: &[i64]) {
znx_sub_ab_inplace_ref(res, a);
fn znx_sub_inplace(res: &mut [i64], a: &[i64]) {
znx_sub_inplace_ref(res, a);
}
}
impl ZnxSubBAInplace for ZnxRef {
impl ZnxSubNegateInplace for ZnxRef {
#[inline(always)]
fn znx_sub_ba_inplace(res: &mut [i64], a: &[i64]) {
znx_sub_ba_inplace_ref(res, a);
fn znx_sub_negate_inplace(res: &mut [i64], a: &[i64]) {
znx_sub_negate_inplace_ref(res, a);
}
}
@@ -61,6 +64,27 @@ impl ZnxAutomorphism for ZnxRef {
}
}
impl ZnxMulPowerOfTwo for ZnxRef {
#[inline(always)]
fn znx_mul_power_of_two(k: i64, res: &mut [i64], a: &[i64]) {
znx_mul_power_of_two_ref(k, res, a);
}
}
impl ZnxMulAddPowerOfTwo for ZnxRef {
#[inline(always)]
fn znx_muladd_power_of_two(k: i64, res: &mut [i64], a: &[i64]) {
znx_mul_add_power_of_two_ref(k, res, a);
}
}
impl ZnxMulPowerOfTwoInplace for ZnxRef {
#[inline(always)]
fn znx_mul_power_of_two_inplace(k: i64, res: &mut [i64]) {
znx_mul_power_of_two_inplace_ref(k, res);
}
}
impl ZnxCopy for ZnxRef {
#[inline(always)]
fn znx_copy(res: &mut [i64], a: &[i64]) {
@@ -98,56 +122,70 @@ impl ZnxSwitchRing for ZnxRef {
impl ZnxNormalizeFinalStep for ZnxRef {
#[inline(always)]
fn znx_normalize_final_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_final_step_ref(basek, lsh, x, a, carry);
fn znx_normalize_final_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_final_step_ref(base2k, lsh, x, a, carry);
}
}
impl ZnxNormalizeFinalStepInplace for ZnxRef {
#[inline(always)]
fn znx_normalize_final_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_final_step_inplace_ref(basek, lsh, x, carry);
fn znx_normalize_final_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_final_step_inplace_ref(base2k, lsh, x, carry);
}
}
impl ZnxNormalizeFirstStep for ZnxRef {
#[inline(always)]
fn znx_normalize_first_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_first_step_ref(basek, lsh, x, a, carry);
fn znx_normalize_first_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_first_step_ref(base2k, lsh, x, a, carry);
}
}
impl ZnxNormalizeFirstStepCarryOnly for ZnxRef {
#[inline(always)]
fn znx_normalize_first_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
znx_normalize_first_step_carry_only_ref(basek, lsh, x, carry);
fn znx_normalize_first_step_carry_only(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
znx_normalize_first_step_carry_only_ref(base2k, lsh, x, carry);
}
}
impl ZnxNormalizeFirstStepInplace for ZnxRef {
#[inline(always)]
fn znx_normalize_first_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_first_step_inplace_ref(basek, lsh, x, carry);
fn znx_normalize_first_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_first_step_inplace_ref(base2k, lsh, x, carry);
}
}
impl ZnxNormalizeMiddleStep for ZnxRef {
#[inline(always)]
fn znx_normalize_middle_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_middle_step_ref(basek, lsh, x, a, carry);
fn znx_normalize_middle_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_middle_step_ref(base2k, lsh, x, a, carry);
}
}
impl ZnxNormalizeMiddleStepCarryOnly for ZnxRef {
#[inline(always)]
fn znx_normalize_middle_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
znx_normalize_middle_step_carry_only_ref(basek, lsh, x, carry);
fn znx_normalize_middle_step_carry_only(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
znx_normalize_middle_step_carry_only_ref(base2k, lsh, x, carry);
}
}
impl ZnxNormalizeMiddleStepInplace for ZnxRef {
#[inline(always)]
fn znx_normalize_middle_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_middle_step_inplace_ref(basek, lsh, x, carry);
fn znx_normalize_middle_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_middle_step_inplace_ref(base2k, lsh, x, carry);
}
}
impl ZnxExtractDigitAddMul for ZnxRef {
#[inline(always)]
fn znx_extract_digit_addmul(base2k: usize, lsh: usize, res: &mut [i64], src: &mut [i64]) {
znx_extract_digit_addmul_ref(base2k, lsh, res, src);
}
}
impl ZnxNormalizeDigit for ZnxRef {
#[inline(always)]
fn znx_normalize_digit(base2k: usize, res: &mut [i64], src: &mut [i64]) {
znx_normalize_digit_ref(base2k, res, src);
}
}

View File

@@ -2,6 +2,7 @@ mod add;
mod arithmetic_ref;
mod automorphism;
mod copy;
mod mul;
mod neg;
mod normalization;
mod rotate;
@@ -14,6 +15,7 @@ pub use add::*;
pub use arithmetic_ref::*;
pub use automorphism::*;
pub use copy::*;
pub use mul::*;
pub use neg::*;
pub use normalization::*;
pub use rotate::*;
@@ -35,12 +37,12 @@ pub trait ZnxSub {
fn znx_sub(res: &mut [i64], a: &[i64], b: &[i64]);
}
pub trait ZnxSubABInplace {
fn znx_sub_ab_inplace(res: &mut [i64], a: &[i64]);
pub trait ZnxSubInplace {
fn znx_sub_inplace(res: &mut [i64], a: &[i64]);
}
pub trait ZnxSubBAInplace {
fn znx_sub_ba_inplace(res: &mut [i64], a: &[i64]);
pub trait ZnxSubNegateInplace {
fn znx_sub_negate_inplace(res: &mut [i64], a: &[i64]);
}
pub trait ZnxAutomorphism {
@@ -67,38 +69,58 @@ pub trait ZnxZero {
fn znx_zero(res: &mut [i64]);
}
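// The three traits below scale coefficients by 2^k; a negative `k` divides
// with rounding to the nearest integer.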
pub trait ZnxMulPowerOfTwo {
fn znx_mul_power_of_two(k: i64, res: &mut [i64], a: &[i64]);
}
pub trait ZnxMulAddPowerOfTwo {
fn znx_muladd_power_of_two(k: i64, res: &mut [i64], a: &[i64]);
}
pub trait ZnxMulPowerOfTwoInplace {
fn znx_mul_power_of_two_inplace(k: i64, res: &mut [i64]);
}
pub trait ZnxSwitchRing {
fn znx_switch_ring(res: &mut [i64], a: &[i64]);
}
pub trait ZnxNormalizeFirstStepCarryOnly {
fn znx_normalize_first_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]);
fn znx_normalize_first_step_carry_only(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]);
}
pub trait ZnxNormalizeFirstStepInplace {
fn znx_normalize_first_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]);
fn znx_normalize_first_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]);
}
pub trait ZnxNormalizeFirstStep {
fn znx_normalize_first_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]);
fn znx_normalize_first_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]);
}
pub trait ZnxNormalizeMiddleStepCarryOnly {
fn znx_normalize_middle_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]);
fn znx_normalize_middle_step_carry_only(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]);
}
pub trait ZnxNormalizeMiddleStepInplace {
fn znx_normalize_middle_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]);
fn znx_normalize_middle_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]);
}
pub trait ZnxNormalizeMiddleStep {
fn znx_normalize_middle_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]);
fn znx_normalize_middle_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]);
}
pub trait ZnxNormalizeFinalStepInplace {
fn znx_normalize_final_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]);
fn znx_normalize_final_step_inplace(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]);
}
pub trait ZnxNormalizeFinalStep {
fn znx_normalize_final_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]);
fn znx_normalize_final_step(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]);
}
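/// Extracts the low `base2k`-bit signed digit of each coefficient of `src`,
/// adds it (shifted left by `lsh`) into `res`, and leaves the carry in `src`.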
pub trait ZnxExtractDigitAddMul {
fn znx_extract_digit_addmul(base2k: usize, lsh: usize, res: &mut [i64], src: &mut [i64]);
}
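/// Reduces each coefficient of `res` to its signed `base2k`-bit digit and adds
/// the resulting carry into `src`.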
pub trait ZnxNormalizeDigit {
fn znx_normalize_digit(base2k: usize, res: &mut [i64], src: &mut [i64]);
}

View File

@@ -0,0 +1,76 @@
use crate::reference::znx::{znx_add_inplace_ref, znx_copy_ref};
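/// Computes `res = a * 2^k`; a negative `k` divides by `2^(-k)`, rounding each
/// coefficient to the nearest integer.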
pub fn znx_mul_power_of_two_ref(mut k: i64, res: &mut [i64], a: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());
}
if k == 0 {
znx_copy_ref(res, a);
return;
}
if k > 0 {
for (y, x) in res.iter_mut().zip(a.iter()) {
*y = *x << k
}
return;
}
k = -k;
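// k is now positive: divide by 2^k with rounding to nearest; the sign-dependent
// bias keeps the rounding symmetric around zero.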
for (y, x) in res.iter_mut().zip(a.iter()) {
let sign_bit: i64 = (x >> 63) & 1;
let bias: i64 = (1_i64 << (k - 1)) - sign_bit;
*y = (x + bias) >> k;
}
}
pub fn znx_mul_power_of_two_inplace_ref(mut k: i64, res: &mut [i64]) {
if k == 0 {
return;
}
if k > 0 {
for x in res.iter_mut() {
*x <<= k
}
return;
}
k = -k;
for x in res.iter_mut() {
let sign_bit: i64 = (*x >> 63) & 1;
let bias: i64 = (1_i64 << (k - 1)) - sign_bit;
*x = (*x + bias) >> k;
}
}
pub fn znx_mul_add_power_of_two_ref(mut k: i64, res: &mut [i64], a: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());
}
if k == 0 {
znx_add_inplace_ref(res, a);
return;
}
if k > 0 {
for (y, x) in res.iter_mut().zip(a.iter()) {
*y += *x << k
}
return;
}
k = -k;
for (y, x) in res.iter_mut().zip(a.iter()) {
let sign_bit: i64 = (x >> 63) & 1;
let bias: i64 = (1_i64 << (k - 1)) - sign_bit;
*y += (x + bias) >> k;
}
}
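// A minimal sketch of the rounding behaviour (hypothetical values, not part of
// the library's test suite): negative `k` divides by 2^(-k) and rounds each
// coefficient to the nearest integer, positive `k` multiplies by 2^k.
#[cfg(test)]
mod mul_power_of_two_sketch {
    use super::*;

    #[test]
    fn znx_mul_power_of_two_rounds_to_nearest() {
        let a: Vec<i64> = vec![8, 5, -3, -7];
        let mut res: Vec<i64> = vec![0; 4];

        // k = -2 divides by 4: 8/4 = 2, 5/4 -> 1, -3/4 -> -1, -7/4 -> -2.
        znx_mul_power_of_two_ref(-2, &mut res, &a);
        assert_eq!(res, vec![2, 1, -1, -2]);

        // k = 2 multiplies by 4.
        znx_mul_power_of_two_inplace_ref(2, &mut res);
        assert_eq!(res, vec![8, 4, -4, -8]);
    }
}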

View File

@@ -1,199 +1,229 @@
use itertools::izip;
#[inline(always)]
pub fn get_digit(basek: usize, x: i64) -> i64 {
(x << (u64::BITS - basek as u32)) >> (u64::BITS - basek as u32)
pub fn get_digit_i64(base2k: usize, x: i64) -> i64 {
(x << (u64::BITS - base2k as u32)) >> (u64::BITS - base2k as u32)
}
#[inline(always)]
pub fn get_carry(basek: usize, x: i64, digit: i64) -> i64 {
(x.wrapping_sub(digit)) >> basek
pub fn get_carry_i64(base2k: usize, x: i64, digit: i64) -> i64 {
(x.wrapping_sub(digit)) >> base2k
}
#[inline(always)]
pub fn znx_normalize_first_step_carry_only_ref(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
pub fn get_digit_i128(base2k: usize, x: i128) -> i128 {
(x << (u128::BITS - base2k as u32)) >> (u128::BITS - base2k as u32)
}
#[inline(always)]
pub fn get_carry_i128(base2k: usize, x: i128, digit: i128) -> i128 {
(x.wrapping_sub(digit)) >> base2k
}
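// Worked example for the digit/carry helpers above: with base2k = 4, the
// coefficient 27 splits into a signed (balanced) digit and a carry such that
// carry * 2^4 + digit == 27:
//   get_digit_i64(4, 27) == -5      (low 4 bits 0b1011, sign-extended)
//   get_carry_i64(4, 27, -5) == 2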
#[inline(always)]
pub fn znx_normalize_first_step_carry_only_ref(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert!(x.len() <= carry.len());
assert!(lsh < basek);
assert!(lsh < base2k);
}
if lsh == 0 {
x.iter().zip(carry.iter_mut()).for_each(|(x, c)| {
*c = get_carry(basek, *x, get_digit(basek, *x));
*c = get_carry_i64(base2k, *x, get_digit_i64(base2k, *x));
});
} else {
let basek_lsh: usize = basek - lsh;
let basek_lsh: usize = base2k - lsh;
x.iter().zip(carry.iter_mut()).for_each(|(x, c)| {
*c = get_carry(basek_lsh, *x, get_digit(basek_lsh, *x));
*c = get_carry_i64(basek_lsh, *x, get_digit_i64(basek_lsh, *x));
});
}
}
#[inline(always)]
pub fn znx_normalize_first_step_inplace_ref(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
pub fn znx_normalize_first_step_inplace_ref(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert!(x.len() <= carry.len());
assert!(lsh < basek);
assert!(lsh < base2k);
}
if lsh == 0 {
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
let digit: i64 = get_digit(basek, *x);
*c = get_carry(basek, *x, digit);
let digit: i64 = get_digit_i64(base2k, *x);
*c = get_carry_i64(base2k, *x, digit);
*x = digit;
});
} else {
let basek_lsh: usize = basek - lsh;
let basek_lsh: usize = base2k - lsh;
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
let digit: i64 = get_digit(basek_lsh, *x);
*c = get_carry(basek_lsh, *x, digit);
let digit: i64 = get_digit_i64(basek_lsh, *x);
*c = get_carry_i64(basek_lsh, *x, digit);
*x = digit << lsh;
});
}
}
#[inline(always)]
pub fn znx_normalize_first_step_ref(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
pub fn znx_normalize_first_step_ref(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(x.len(), a.len());
assert!(x.len() <= carry.len());
assert!(lsh < basek);
assert!(lsh < base2k);
}
if lsh == 0 {
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
let digit: i64 = get_digit(basek, *a);
*c = get_carry(basek, *a, digit);
let digit: i64 = get_digit_i64(base2k, *a);
*c = get_carry_i64(base2k, *a, digit);
*x = digit;
});
} else {
let basek_lsh: usize = basek - lsh;
let basek_lsh: usize = base2k - lsh;
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
let digit: i64 = get_digit(basek_lsh, *a);
*c = get_carry(basek_lsh, *a, digit);
let digit: i64 = get_digit_i64(basek_lsh, *a);
*c = get_carry_i64(basek_lsh, *a, digit);
*x = digit << lsh;
});
}
}
#[inline(always)]
pub fn znx_normalize_middle_step_carry_only_ref(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
pub fn znx_normalize_middle_step_carry_only_ref(base2k: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert!(x.len() <= carry.len());
assert!(lsh < basek);
assert!(lsh < base2k);
}
if lsh == 0 {
x.iter().zip(carry.iter_mut()).for_each(|(x, c)| {
let digit: i64 = get_digit(basek, *x);
let carry: i64 = get_carry(basek, *x, digit);
let digit: i64 = get_digit_i64(base2k, *x);
let carry: i64 = get_carry_i64(base2k, *x, digit);
let digit_plus_c: i64 = digit + *c;
*c = carry + get_carry(basek, digit_plus_c, get_digit(basek, digit_plus_c));
*c = carry + get_carry_i64(base2k, digit_plus_c, get_digit_i64(base2k, digit_plus_c));
});
} else {
let basek_lsh: usize = basek - lsh;
let basek_lsh: usize = base2k - lsh;
x.iter().zip(carry.iter_mut()).for_each(|(x, c)| {
let digit: i64 = get_digit(basek_lsh, *x);
let carry: i64 = get_carry(basek_lsh, *x, digit);
let digit: i64 = get_digit_i64(basek_lsh, *x);
let carry: i64 = get_carry_i64(basek_lsh, *x, digit);
let digit_plus_c: i64 = (digit << lsh) + *c;
*c = carry + get_carry(basek, digit_plus_c, get_digit(basek, digit_plus_c));
*c = carry + get_carry_i64(base2k, digit_plus_c, get_digit_i64(base2k, digit_plus_c));
});
}
}
#[inline(always)]
pub fn znx_normalize_middle_step_inplace_ref(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
pub fn znx_normalize_middle_step_inplace_ref(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert!(x.len() <= carry.len());
assert!(lsh < basek);
assert!(lsh < base2k);
}
if lsh == 0 {
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
let digit: i64 = get_digit(basek, *x);
let carry: i64 = get_carry(basek, *x, digit);
let digit: i64 = get_digit_i64(base2k, *x);
let carry: i64 = get_carry_i64(base2k, *x, digit);
let digit_plus_c: i64 = digit + *c;
*x = get_digit(basek, digit_plus_c);
*c = carry + get_carry(basek, digit_plus_c, *x);
*x = get_digit_i64(base2k, digit_plus_c);
*c = carry + get_carry_i64(base2k, digit_plus_c, *x);
});
} else {
let basek_lsh: usize = basek - lsh;
let basek_lsh: usize = base2k - lsh;
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
let digit: i64 = get_digit(basek_lsh, *x);
let carry: i64 = get_carry(basek_lsh, *x, digit);
let digit: i64 = get_digit_i64(basek_lsh, *x);
let carry: i64 = get_carry_i64(basek_lsh, *x, digit);
let digit_plus_c: i64 = (digit << lsh) + *c;
*x = get_digit(basek, digit_plus_c);
*c = carry + get_carry(basek, digit_plus_c, *x);
*x = get_digit_i64(base2k, digit_plus_c);
*c = carry + get_carry_i64(base2k, digit_plus_c, *x);
});
}
}
#[inline(always)]
pub fn znx_normalize_middle_step_ref(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
pub fn znx_extract_digit_addmul_ref(base2k: usize, lsh: usize, res: &mut [i64], src: &mut [i64]) {
for (r, s) in res.iter_mut().zip(src.iter_mut()) {
let digit: i64 = get_digit_i64(base2k, *s);
*s = get_carry_i64(base2k, *s, digit);
*r += digit << lsh;
}
}
#[inline(always)]
pub fn znx_normalize_digit_ref(base2k: usize, res: &mut [i64], src: &mut [i64]) {
for (r, s) in res.iter_mut().zip(src.iter_mut()) {
let ri_digit: i64 = get_digit_i64(base2k, *r);
let ri_carry: i64 = get_carry_i64(base2k, *r, ri_digit);
*r = ri_digit;
*s += ri_carry;
}
}
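// The two helpers above operate digit by digit: `znx_extract_digit_addmul_ref`
// moves the lowest `base2k`-bit digit of `src` into `res` (shifted left by
// `lsh`) and keeps the carry in `src`, while `znx_normalize_digit_ref`
// re-centers `res` into balanced digits and pushes its carry back into `src`.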
#[inline(always)]
pub fn znx_normalize_middle_step_ref(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(x.len(), a.len());
assert!(x.len() <= carry.len());
assert!(lsh < basek);
assert!(lsh < base2k);
}
if lsh == 0 {
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
let digit: i64 = get_digit(basek, *a);
let carry: i64 = get_carry(basek, *a, digit);
let digit: i64 = get_digit_i64(base2k, *a);
let carry: i64 = get_carry_i64(base2k, *a, digit);
let digit_plus_c: i64 = digit + *c;
*x = get_digit(basek, digit_plus_c);
*c = carry + get_carry(basek, digit_plus_c, *x);
*x = get_digit_i64(base2k, digit_plus_c);
*c = carry + get_carry_i64(base2k, digit_plus_c, *x);
});
} else {
let basek_lsh: usize = basek - lsh;
let basek_lsh: usize = base2k - lsh;
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
let digit: i64 = get_digit(basek_lsh, *a);
let carry: i64 = get_carry(basek_lsh, *a, digit);
let digit: i64 = get_digit_i64(basek_lsh, *a);
let carry: i64 = get_carry_i64(basek_lsh, *a, digit);
let digit_plus_c: i64 = (digit << lsh) + *c;
*x = get_digit(basek, digit_plus_c);
*c = carry + get_carry(basek, digit_plus_c, *x);
*x = get_digit_i64(base2k, digit_plus_c);
*c = carry + get_carry_i64(base2k, digit_plus_c, *x);
});
}
}
#[inline(always)]
pub fn znx_normalize_final_step_inplace_ref(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
pub fn znx_normalize_final_step_inplace_ref(base2k: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert!(x.len() <= carry.len());
assert!(lsh < basek);
assert!(lsh < base2k);
}
if lsh == 0 {
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
*x = get_digit(basek, get_digit(basek, *x) + *c);
*x = get_digit_i64(base2k, get_digit_i64(base2k, *x) + *c);
});
} else {
let basek_lsh: usize = basek - lsh;
let basek_lsh: usize = base2k - lsh;
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
*x = get_digit(basek, (get_digit(basek_lsh, *x) << lsh) + *c);
*x = get_digit_i64(base2k, (get_digit_i64(basek_lsh, *x) << lsh) + *c);
});
}
}
#[inline(always)]
pub fn znx_normalize_final_step_ref(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
pub fn znx_normalize_final_step_ref(base2k: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert!(x.len() <= carry.len());
assert!(lsh < basek);
assert!(lsh < base2k);
}
if lsh == 0 {
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
*x = get_digit(basek, get_digit(basek, *a) + *c);
*x = get_digit_i64(base2k, get_digit_i64(base2k, *a) + *c);
});
} else {
let basek_lsh: usize = basek - lsh;
let basek_lsh: usize = base2k - lsh;
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
*x = get_digit(basek, (get_digit(basek_lsh, *a) << lsh) + *c);
*x = get_digit_i64(base2k, (get_digit_i64(basek_lsh, *a) << lsh) + *c);
});
}
}

View File

@@ -2,8 +2,8 @@ use rand_distr::{Distribution, Normal};
use crate::source::Source;
pub fn znx_fill_uniform_ref(basek: usize, res: &mut [i64], source: &mut Source) {
let pow2k: u64 = 1 << basek;
pub fn znx_fill_uniform_ref(base2k: usize, res: &mut [i64], source: &mut Source) {
let pow2k: u64 = 1 << base2k;
let mask: u64 = pow2k - 1;
let pow2k_half: i64 = (pow2k >> 1) as i64;
res.iter_mut()

View File

@@ -11,7 +11,7 @@ pub fn znx_sub_ref(res: &mut [i64], a: &[i64], b: &[i64]) {
}
}
pub fn znx_sub_ab_inplace_ref(res: &mut [i64], a: &[i64]) {
pub fn znx_sub_inplace_ref(res: &mut [i64], a: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());
@@ -23,7 +23,7 @@ pub fn znx_sub_ab_inplace_ref(res: &mut [i64], a: &[i64]) {
}
}
pub fn znx_sub_ba_inplace_ref(res: &mut [i64], a: &[i64]) {
pub fn znx_sub_negate_inplace_ref(res: &mut [i64], a: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());

View File

@@ -48,6 +48,16 @@ impl Source {
pub fn next_i64(&mut self) -> i64 {
self.next_u64() as i64
}
#[inline(always)]
pub fn next_i128(&mut self) -> i128 {
self.next_u128() as i128
}
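// Draws two 64-bit words and concatenates them into one 128-bit sample.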
#[inline(always)]
pub fn next_u128(&mut self) -> u128 {
(self.next_u64() as u128) << 64 | (self.next_u64() as u128)
}
}
impl RngCore for Source {

View File

@@ -41,7 +41,7 @@ macro_rules! cross_backend_test_suite {
backend_ref = $backend_ref:ty,
backend_test = $backend_test:ty,
size = $size:expr,
basek = $basek:expr,
base2k = $base2k:expr,
tests = {
$( $(#[$attr:meta])* $test_name:ident => $impl:path ),+ $(,)?
}
@@ -60,7 +60,7 @@ macro_rules! cross_backend_test_suite {
$(#[$attr])*
#[test]
fn $test_name() {
($impl)($basek, &*MODULE_REF, &*MODULE_TEST);
($impl)($base2k, &*MODULE_REF, &*MODULE_TEST);
}
)+
}

View File

@@ -1,7 +1,7 @@
use std::fmt::Debug;
use crate::{
layouts::{FillUniform, ReaderFrom, Reset, WriterTo},
layouts::{FillUniform, ReaderFrom, WriterTo},
source::Source,
};
@@ -10,7 +10,7 @@ use crate::{
/// - `T` must implement I/O traits, zeroing, cloning, and random filling.
pub fn test_reader_writer_interface<T>(mut original: T)
where
T: WriterTo + ReaderFrom + PartialEq + Eq + Debug + Clone + Reset + FillUniform,
T: WriterTo + ReaderFrom + PartialEq + Eq + Debug + Clone + FillUniform,
{
// Fill original with uniform random data
let mut source = Source::new([0u8; 32]);
@@ -20,9 +20,9 @@ where
let mut buffer = Vec::new();
original.write_to(&mut buffer).expect("write_to failed");
// Prepare receiver: same shape, but zeroed
// Prepare receiver: same shape, but randomized
let mut receiver = original.clone();
receiver.reset();
receiver.fill_uniform(50, &mut source);
// Deserialize from buffer
let mut reader: &[u8] = &buffer;

View File

@@ -10,7 +10,7 @@ use crate::{
source::Source,
};
pub fn test_svp_apply_dft<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_svp_apply_dft<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: SvpPrepare<BR>
+ SvpApplyDft<BR>
@@ -40,7 +40,7 @@ where
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
let mut scalar: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, cols);
scalar.fill_uniform(basek, &mut source);
scalar.fill_uniform(base2k, &mut source);
let scalar_digest: u64 = scalar.digest_u64();
@@ -60,7 +60,7 @@ where
for a_size in [1, 2, 3, 4] {
// Create a random input VecZnx
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
@@ -91,17 +91,19 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
base2k,
&mut res_ref,
j,
base2k,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
base2k,
&mut res_test,
j,
base2k,
&res_big_test,
j,
scratch_test.borrow(),
@@ -113,7 +115,7 @@ where
}
}
pub fn test_svp_apply_dft_to_dft<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_svp_apply_dft_to_dft<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: SvpPrepare<BR>
+ SvpApplyDftToDft<BR>
@@ -145,7 +147,7 @@ where
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
let mut scalar: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, cols);
scalar.fill_uniform(basek, &mut source);
scalar.fill_uniform(base2k, &mut source);
let scalar_digest: u64 = scalar.digest_u64();
@@ -165,7 +167,7 @@ where
for a_size in [3] {
// Create a random input VecZnx
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
@@ -211,17 +213,19 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
base2k,
&mut res_ref,
j,
base2k,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
base2k,
&mut res_test,
j,
base2k,
&res_big_test,
j,
scratch_test.borrow(),
@@ -233,7 +237,7 @@ where
}
}
pub fn test_svp_apply_dft_to_dft_add<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_svp_apply_dft_to_dft_add<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: SvpPrepare<BR>
+ SvpApplyDftToDftAdd<BR>
@@ -265,7 +269,7 @@ where
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
let mut scalar: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, cols);
scalar.fill_uniform(basek, &mut source);
scalar.fill_uniform(base2k, &mut source);
let scalar_digest: u64 = scalar.digest_u64();
@@ -285,7 +289,7 @@ where
for a_size in [1, 2, 3, 4] {
// Create a random input VecZnx
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
@@ -302,7 +306,7 @@ where
for res_size in [1, 2, 3, 4] {
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
res.fill_uniform(basek, &mut source);
res.fill_uniform(base2k, &mut source);
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, res_size);
let mut res_dft_test: VecZnxDft<Vec<u8>, BT> = module_test.vec_znx_dft_alloc(cols, res_size);
@@ -336,17 +340,19 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
base2k,
&mut res_ref,
j,
base2k,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
base2k,
&mut res_test,
j,
base2k,
&res_big_test,
j,
scratch_test.borrow(),
@@ -359,7 +365,7 @@ where
}
pub fn test_svp_apply_dft_to_dft_inplace<BR: Backend, BT: Backend>(
basek: usize,
base2k: usize,
module_ref: &Module<BR>,
module_test: &Module<BT>,
) where
@@ -393,7 +399,7 @@ pub fn test_svp_apply_dft_to_dft_inplace<BR: Backend, BT: Backend>(
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
let mut scalar: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, cols);
scalar.fill_uniform(basek, &mut source);
scalar.fill_uniform(base2k, &mut source);
let scalar_digest: u64 = scalar.digest_u64();
@@ -412,7 +418,7 @@ pub fn test_svp_apply_dft_to_dft_inplace<BR: Backend, BT: Backend>(
for res_size in [1, 2, 3, 4] {
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
res.fill_uniform(basek, &mut source);
res.fill_uniform(base2k, &mut source);
let res_digest: u64 = res.digest_u64();
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, res_size);
@@ -442,17 +448,19 @@ pub fn test_svp_apply_dft_to_dft_inplace<BR: Backend, BT: Backend>(
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
base2k,
&mut res_ref,
j,
base2k,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
base2k,
&mut res_test,
j,
base2k,
&res_big_test,
j,
scratch_test.borrow(),

View File

@@ -8,38 +8,18 @@ use crate::{
VecZnxMergeRingsTmpBytes, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxMulXpMinusOneInplaceTmpBytes,
VecZnxNegate, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate,
VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, VecZnxRsh, VecZnxRshInplace, VecZnxSplitRing, VecZnxSplitRingTmpBytes,
VecZnxSub, VecZnxSubABInplace, VecZnxSubBAInplace, VecZnxSubScalar, VecZnxSubScalarInplace, VecZnxSwitchRing,
VecZnxSub, VecZnxSubInplace, VecZnxSubNegateInplace, VecZnxSubScalar, VecZnxSubScalarInplace, VecZnxSwitchRing,
},
layouts::{Backend, DigestU64, FillUniform, Module, ScalarZnx, ScratchOwned, VecZnx, ZnxInfos, ZnxView, ZnxViewMut},
reference::znx::znx_copy_ref,
source::Source,
};
pub fn test_vec_znx_encode_vec_i64_lo_norm() {
pub fn test_vec_znx_encode_vec_i64() {
let n: usize = 32;
let basek: usize = 17;
let base2k: usize = 17;
let size: usize = 5;
let k: usize = size * basek - 5;
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, 2, size);
let mut source: Source = Source::new([0u8; 32]);
let raw: &mut [i64] = a.raw_mut();
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
(0..a.cols()).for_each(|col_i| {
let mut have: Vec<i64> = vec![i64::default(); n];
have.iter_mut()
.for_each(|x| *x = (source.next_i64() << 56) >> 56);
a.encode_vec_i64(basek, col_i, k, &have, 10);
let mut want: Vec<i64> = vec![i64::default(); n];
a.decode_vec_i64(basek, col_i, k, &mut want);
assert_eq!(have, want, "{:?} != {:?}", &have, &want);
});
}
pub fn test_vec_znx_encode_vec_i64_hi_norm() {
let n: usize = 32;
let basek: usize = 17;
let size: usize = 5;
for k in [1, basek / 2, size * basek - 5] {
for k in [1, base2k / 2, size * base2k - 5] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, 2, size);
let mut source = Source::new([0u8; 32]);
let raw: &mut [i64] = a.raw_mut();
@@ -53,15 +33,15 @@ pub fn test_vec_znx_encode_vec_i64_hi_norm() {
*x = source.next_i64();
}
});
a.encode_vec_i64(basek, col_i, k, &have, 63);
a.encode_vec_i64(base2k, col_i, k, &have);
let mut want: Vec<i64> = vec![i64::default(); n];
a.decode_vec_i64(basek, col_i, k, &mut want);
a.decode_vec_i64(base2k, col_i, k, &mut want);
assert_eq!(have, want, "{:?} != {:?}", &have, &want);
})
}
}
pub fn test_vec_znx_add_scalar<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_add_scalar<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxAddScalar,
Module<BT>: VecZnxAddScalar,
@@ -74,12 +54,12 @@ where
let cols: usize = 2;
let mut a: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, cols);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let a_digest = a.digest_u64();
for a_size in [1, 2, 3, 4] {
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
b.fill_uniform(basek, &mut source);
b.fill_uniform(base2k, &mut source);
let b_digest: u64 = b.digest_u64();
for res_size in [1, 2, 3, 4] {
@@ -87,8 +67,8 @@ where
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
// Set d to garbage
rest_ref.fill_uniform(basek, &mut source);
res_test.fill_uniform(basek, &mut source);
rest_ref.fill_uniform(base2k, &mut source);
res_test.fill_uniform(base2k, &mut source);
// Reference
for i in 0..cols {
@@ -103,7 +83,7 @@ where
}
}
pub fn test_vec_znx_add_scalar_inplace<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_add_scalar_inplace<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxAddScalarInplace,
Module<BT>: VecZnxAddScalarInplace,
@@ -116,14 +96,14 @@ where
let cols: usize = 2;
let mut b: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, cols);
b.fill_uniform(basek, &mut source);
b.fill_uniform(base2k, &mut source);
let b_digest: u64 = b.digest_u64();
for res_size in [1, 2, 3, 4] {
let mut rest_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
rest_ref.fill_uniform(basek, &mut source);
rest_ref.fill_uniform(base2k, &mut source);
res_test.raw_mut().copy_from_slice(rest_ref.raw());
for i in 0..cols {
@@ -135,7 +115,7 @@ where
assert_eq!(rest_ref, res_test);
}
}
pub fn test_vec_znx_add<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_add<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxAdd,
Module<BT>: VecZnxAdd,
@@ -148,13 +128,13 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
for b_size in [1, 2, 3, 4] {
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, b_size);
b.fill_uniform(basek, &mut source);
b.fill_uniform(base2k, &mut source);
let b_digest: u64 = b.digest_u64();
@@ -163,8 +143,8 @@ where
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
// Set d to garbage
res_ref.fill_uniform(basek, &mut source);
res_test.fill_uniform(basek, &mut source);
res_ref.fill_uniform(base2k, &mut source);
res_test.fill_uniform(base2k, &mut source);
// Reference
for i in 0..cols {
@@ -181,7 +161,7 @@ where
}
}
pub fn test_vec_znx_add_inplace<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_add_inplace<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxAddInplace,
Module<BT>: VecZnxAddInplace,
@@ -194,14 +174,14 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
for res_size in [1, 2, 3, 4] {
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
res_ref.fill_uniform(basek, &mut source);
res_ref.fill_uniform(base2k, &mut source);
res_test.raw_mut().copy_from_slice(res_ref.raw());
for i in 0..cols {
@@ -215,7 +195,7 @@ where
}
}
pub fn test_vec_znx_automorphism<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_automorphism<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxAutomorphism,
Module<BT>: VecZnxAutomorphism,
@@ -228,7 +208,7 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
for res_size in [1, 2, 3, 4] {
@@ -261,7 +241,7 @@ where
}
pub fn test_vec_znx_automorphism_inplace<BR: Backend, BT: Backend>(
basek: usize,
base2k: usize,
module_ref: &Module<BR>,
module_test: &Module<BT>,
) where
@@ -284,7 +264,7 @@ pub fn test_vec_znx_automorphism_inplace<BR: Backend, BT: Backend>(
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
// Fill a with random i64
res_ref.fill_uniform(basek, &mut source);
res_ref.fill_uniform(base2k, &mut source);
znx_copy_ref(res_test.raw_mut(), res_ref.raw());
let p: i64 = -7;
@@ -309,7 +289,7 @@ pub fn test_vec_znx_automorphism_inplace<BR: Backend, BT: Backend>(
}
}
pub fn test_vec_znx_copy<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_copy<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxCopy,
Module<BT>: VecZnxCopy,
@@ -322,7 +302,7 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
for res_size in [1, 2, 3, 4] {
@@ -330,8 +310,8 @@ where
let mut res_1: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
// Set d to garbage
res_0.fill_uniform(basek, &mut source);
res_1.fill_uniform(basek, &mut source);
res_0.fill_uniform(base2k, &mut source);
res_1.fill_uniform(base2k, &mut source);
// Reference
for i in 0..cols {
@@ -345,7 +325,7 @@ where
}
}
pub fn test_vec_znx_merge_rings<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_merge_rings<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxMergeRings<BR> + ModuleNew<BR> + VecZnxMergeRingsTmpBytes,
Module<BT>: VecZnxMergeRings<BT> + ModuleNew<BT> + VecZnxMergeRingsTmpBytes,
@@ -367,7 +347,7 @@ where
];
a.iter_mut().for_each(|ai| {
ai.fill_uniform(basek, &mut source);
ai.fill_uniform(base2k, &mut source);
});
let a_digests: [u64; 2] = [a[0].digest_u64(), a[1].digest_u64()];
@@ -376,8 +356,8 @@ where
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
res_ref.fill_uniform(basek, &mut source);
res_test.fill_uniform(basek, &mut source);
res_ref.fill_uniform(base2k, &mut source);
res_test.fill_uniform(base2k, &mut source);
for i in 0..cols {
module_ref.vec_znx_merge_rings(&mut res_test, i, &a, i, scratch_ref.borrow());
@@ -390,7 +370,7 @@ where
}
}
pub fn test_vec_znx_mul_xp_minus_one<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_mul_xp_minus_one<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxMulXpMinusOne,
Module<BT>: VecZnxMulXpMinusOne,
@@ -403,7 +383,7 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
@@ -437,7 +417,7 @@ where
}
pub fn test_vec_znx_mul_xp_minus_one_inplace<BR: Backend, BT: Backend>(
basek: usize,
base2k: usize,
module_ref: &Module<BR>,
module_test: &Module<BT>,
) where
@@ -460,7 +440,7 @@ pub fn test_vec_znx_mul_xp_minus_one_inplace<BR: Backend, BT: Backend>(
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
// Fill a with random i64
res_ref.fill_uniform(basek, &mut source);
res_ref.fill_uniform(base2k, &mut source);
znx_copy_ref(res_test.raw_mut(), res_ref.raw());
let p: i64 = -7;
@@ -483,7 +463,7 @@ pub fn test_vec_znx_mul_xp_minus_one_inplace<BR: Backend, BT: Backend>(
}
}
pub fn test_vec_znx_negate<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_negate<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxNegate,
Module<BT>: VecZnxNegate,
@@ -496,14 +476,14 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
for res_size in [1, 2, 3, 4] {
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
res_ref.fill_uniform(basek, &mut source);
res_ref.fill_uniform(base2k, &mut source);
res_test.raw_mut().copy_from_slice(res_ref.raw());
for i in 0..cols {
@@ -517,7 +497,7 @@ where
}
}
pub fn test_vec_znx_negate_inplace<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_negate_inplace<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxNegateInplace,
Module<BT>: VecZnxNegateInplace,
@@ -532,7 +512,7 @@ where
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
res_ref.fill_uniform(basek, &mut source);
res_ref.fill_uniform(base2k, &mut source);
res_test.raw_mut().copy_from_slice(res_ref.raw());
for i in 0..cols {
@@ -544,7 +524,7 @@ where
}
}
pub fn test_vec_znx_normalize<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_normalize<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxNormalize<BR> + VecZnxNormalizeTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
@@ -562,7 +542,7 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
for res_size in [1, 2, 3, 4] {
@@ -570,13 +550,21 @@ where
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
// Set d to garbage
res_ref.fill_uniform(basek, &mut source);
res_test.fill_uniform(basek, &mut source);
res_ref.fill_uniform(base2k, &mut source);
res_test.fill_uniform(base2k, &mut source);
// Reference
for i in 0..cols {
module_ref.vec_znx_normalize(basek, &mut res_ref, i, &a, i, scratch_ref.borrow());
module_test.vec_znx_normalize(basek, &mut res_test, i, &a, i, scratch_test.borrow());
module_ref.vec_znx_normalize(base2k, &mut res_ref, i, base2k, &a, i, scratch_ref.borrow());
module_test.vec_znx_normalize(
base2k,
&mut res_test,
i,
base2k,
&a,
i,
scratch_test.borrow(),
);
}
assert_eq!(a.digest_u64(), a_digest);
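The normalize call now takes two base parameters, the base of the destination followed by the base of the source (both `base2k` in this test). A hypothetical cross-base use, assuming that argument order and reusing the surrounding test setup (`module`, `a`, `scratch`):

let src_base2k: usize = 17; // base the input `a` is encoded in
let dst_base2k: usize = 12; // base requested for the output
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
for col in 0..cols {
    module.vec_znx_normalize(dst_base2k, &mut res, col, src_base2k, &a, col, scratch.borrow());
}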
@@ -585,7 +573,7 @@ where
}
}
pub fn test_vec_znx_normalize_inplace<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_normalize_inplace<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxNormalizeInplace<BR> + VecZnxNormalizeTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
@@ -605,20 +593,20 @@ where
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
res_ref.fill_uniform(basek, &mut source);
res_ref.fill_uniform(base2k, &mut source);
res_test.raw_mut().copy_from_slice(res_ref.raw());
// Reference
for i in 0..cols {
module_ref.vec_znx_normalize_inplace(basek, &mut res_ref, i, scratch_ref.borrow());
module_test.vec_znx_normalize_inplace(basek, &mut res_test, i, scratch_test.borrow());
module_ref.vec_znx_normalize_inplace(base2k, &mut res_ref, i, scratch_ref.borrow());
module_test.vec_znx_normalize_inplace(base2k, &mut res_test, i, scratch_test.borrow());
}
assert_eq!(res_ref, res_test);
}
}
pub fn test_vec_znx_rotate<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_rotate<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxRotate,
Module<BT>: VecZnxRotate,
@@ -631,7 +619,7 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
for res_size in [1, 2, 3, 4] {
@@ -663,7 +651,7 @@ where
}
}
pub fn test_vec_znx_rotate_inplace<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_rotate_inplace<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxRotateInplace<BR> + VecZnxRotateInplaceTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
@@ -684,7 +672,7 @@ where
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
// Fill a with random i64
res_ref.fill_uniform(basek, &mut source);
res_ref.fill_uniform(base2k, &mut source);
znx_copy_ref(res_test.raw_mut(), res_ref.raw());
let p: i64 = -5;
@@ -714,7 +702,7 @@ where
Module<B>: VecZnxFillUniform,
{
let n: usize = module.n();
let basek: usize = 17;
let base2k: usize = 17;
let size: usize = 5;
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
@@ -722,19 +710,17 @@ where
let one_12_sqrt: f64 = 0.28867513459481287;
(0..cols).for_each(|col_i| {
let mut a: VecZnx<_> = VecZnx::alloc(n, cols, size);
module.vec_znx_fill_uniform(basek, &mut a, col_i, &mut source);
module.vec_znx_fill_uniform(base2k, &mut a, col_i, &mut source);
(0..cols).for_each(|col_j| {
if col_j != col_i {
(0..size).for_each(|limb_i| {
assert_eq!(a.at(col_j, limb_i), zero);
})
} else {
let std: f64 = a.std(basek, col_i);
let std: f64 = a.std(base2k, col_i);
assert!(
(std - one_12_sqrt).abs() < 0.01,
"std={} ~!= {}",
std,
one_12_sqrt
"std={std} ~!= {one_12_sqrt}",
);
}
})
@@ -746,7 +732,7 @@ where
Module<B>: VecZnxFillNormal,
{
let n: usize = module.n();
let basek: usize = 17;
let base2k: usize = 17;
let k: usize = 2 * 17;
let size: usize = 5;
let sigma: f64 = 3.2;
@@ -757,15 +743,15 @@ where
let k_f64: f64 = (1u64 << k as u64) as f64;
(0..cols).for_each(|col_i| {
let mut a: VecZnx<_> = VecZnx::alloc(n, cols, size);
module.vec_znx_fill_normal(basek, &mut a, col_i, k, &mut source, sigma, bound);
module.vec_znx_fill_normal(base2k, &mut a, col_i, k, &mut source, sigma, bound);
(0..cols).for_each(|col_j| {
if col_j != col_i {
(0..size).for_each(|limb_i| {
assert_eq!(a.at(col_j, limb_i), zero);
})
} else {
let std: f64 = a.std(basek, col_i) * k_f64;
assert!((std - sigma).abs() < 0.1, "std={} ~!= {}", std, sigma);
let std: f64 = a.std(base2k, col_i) * k_f64;
assert!((std - sigma).abs() < 0.1, "std={std} ~!= {sigma}");
}
})
});
@@ -776,7 +762,7 @@ where
Module<B>: VecZnxFillNormal + VecZnxAddNormal,
{
let n: usize = module.n();
let basek: usize = 17;
let base2k: usize = 17;
let k: usize = 2 * 17;
let size: usize = 5;
let sigma: f64 = 3.2;
@@ -788,19 +774,18 @@ where
let sqrt2: f64 = SQRT_2;
(0..cols).for_each(|col_i| {
let mut a: VecZnx<_> = VecZnx::alloc(n, cols, size);
module.vec_znx_fill_normal(basek, &mut a, col_i, k, &mut source, sigma, bound);
module.vec_znx_add_normal(basek, &mut a, col_i, k, &mut source, sigma, bound);
module.vec_znx_fill_normal(base2k, &mut a, col_i, k, &mut source, sigma, bound);
module.vec_znx_add_normal(base2k, &mut a, col_i, k, &mut source, sigma, bound);
(0..cols).for_each(|col_j| {
if col_j != col_i {
(0..size).for_each(|limb_i| {
assert_eq!(a.at(col_j, limb_i), zero);
})
} else {
let std: f64 = a.std(basek, col_i) * k_f64;
let std: f64 = a.std(base2k, col_i) * k_f64;
assert!(
(std - sigma * sqrt2).abs() < 0.1,
"std={} ~!= {}",
std,
"std={std} ~!= {}",
sigma * sqrt2
);
}
@@ -808,7 +793,7 @@ where
});
}
pub fn test_vec_znx_lsh<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_lsh<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxLsh<BR> + VecZnxLshTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
@@ -826,22 +811,22 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
for res_size in [1, 2, 3, 4] {
for k in 0..res_size * basek {
for k in 0..res_size * base2k {
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
// Set d to garbage
res_ref.fill_uniform(basek, &mut source);
res_test.fill_uniform(basek, &mut source);
res_ref.fill_uniform(base2k, &mut source);
res_test.fill_uniform(base2k, &mut source);
// Reference
for i in 0..cols {
module_ref.vec_znx_lsh(basek, k, &mut res_ref, i, &a, i, scratch_ref.borrow());
module_test.vec_znx_lsh(basek, k, &mut res_test, i, &a, i, scratch_test.borrow());
module_ref.vec_znx_lsh(base2k, k, &mut res_ref, i, &a, i, scratch_ref.borrow());
module_test.vec_znx_lsh(base2k, k, &mut res_test, i, &a, i, scratch_test.borrow());
}
assert_eq!(a.digest_u64(), a_digest);
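The shift amount is iterated over every bit position representable by the output (`0..res_size * base2k`). A single left shift, as a sketch under the same assumptions as the tests above:

let shift: usize = 3; // illustrative shift, in bits
for col in 0..cols {
    module.vec_znx_lsh(base2k, shift, &mut res, col, &a, col, scratch.borrow());
}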
@@ -851,7 +836,7 @@ where
}
}
pub fn test_vec_znx_lsh_inplace<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_lsh_inplace<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxLshInplace<BR> + VecZnxLshTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
@@ -868,16 +853,16 @@ where
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_lsh_tmp_bytes());
for res_size in [1, 2, 3, 4] {
for k in 0..basek * res_size {
for k in 0..base2k * res_size {
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
res_ref.fill_uniform(basek, &mut source);
res_ref.fill_uniform(base2k, &mut source);
res_test.raw_mut().copy_from_slice(res_ref.raw());
for i in 0..cols {
module_ref.vec_znx_lsh_inplace(basek, k, &mut res_ref, i, scratch_ref.borrow());
module_test.vec_znx_lsh_inplace(basek, k, &mut res_test, i, scratch_test.borrow());
module_ref.vec_znx_lsh_inplace(base2k, k, &mut res_ref, i, scratch_ref.borrow());
module_test.vec_znx_lsh_inplace(base2k, k, &mut res_test, i, scratch_test.borrow());
}
assert_eq!(res_ref, res_test);
@@ -885,7 +870,7 @@ where
}
}
pub fn test_vec_znx_rsh<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_rsh<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxRsh<BR> + VecZnxLshTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
@@ -902,22 +887,22 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
for res_size in [1, 2, 3, 4] {
for k in 0..res_size * basek {
for k in 0..res_size * base2k {
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
// Set d to garbage
res_ref.fill_uniform(basek, &mut source);
res_test.fill_uniform(basek, &mut source);
res_ref.fill_uniform(base2k, &mut source);
res_test.fill_uniform(base2k, &mut source);
// Reference
for i in 0..cols {
module_ref.vec_znx_rsh(basek, k, &mut res_ref, i, &a, i, scratch_ref.borrow());
module_test.vec_znx_rsh(basek, k, &mut res_test, i, &a, i, scratch_test.borrow());
module_ref.vec_znx_rsh(base2k, k, &mut res_ref, i, &a, i, scratch_ref.borrow());
module_test.vec_znx_rsh(base2k, k, &mut res_test, i, &a, i, scratch_test.borrow());
}
assert_eq!(a.digest_u64(), a_digest);
@@ -927,7 +912,7 @@ where
}
}
pub fn test_vec_znx_rsh_inplace<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_rsh_inplace<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxRshInplace<BR> + VecZnxLshTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
@@ -943,16 +928,16 @@ where
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_lsh_tmp_bytes());
for res_size in [1, 2, 3, 4] {
for k in 0..basek * res_size {
for k in 0..base2k * res_size {
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
res_ref.fill_uniform(basek, &mut source);
res_ref.fill_uniform(base2k, &mut source);
res_test.raw_mut().copy_from_slice(res_ref.raw());
for i in 0..cols {
module_ref.vec_znx_rsh_inplace(basek, k, &mut res_ref, i, scratch_ref.borrow());
module_test.vec_znx_rsh_inplace(basek, k, &mut res_test, i, scratch_test.borrow());
module_ref.vec_znx_rsh_inplace(base2k, k, &mut res_ref, i, scratch_ref.borrow());
module_test.vec_znx_rsh_inplace(base2k, k, &mut res_test, i, scratch_test.borrow());
}
assert_eq!(res_ref, res_test);
@@ -960,7 +945,7 @@ where
}
}
pub fn test_vec_znx_split_ring<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_split_ring<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxSplitRing<BR> + ModuleNew<BR> + VecZnxSplitRingTmpBytes,
ScratchOwned<BR>: ScratchOwnedAlloc<BR> + ScratchOwnedBorrow<BR>,
@@ -977,7 +962,7 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let a_digest = a.digest_u64();
for res_size in [1, 2, 3, 4] {
@@ -992,11 +977,11 @@ where
];
res_ref.iter_mut().for_each(|ri| {
ri.fill_uniform(basek, &mut source);
ri.fill_uniform(base2k, &mut source);
});
res_test.iter_mut().for_each(|ri| {
ri.fill_uniform(basek, &mut source);
ri.fill_uniform(base2k, &mut source);
});
for i in 0..cols {
@@ -1013,7 +998,7 @@ where
}
}
pub fn test_vec_znx_sub_scalar<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_sub_scalar<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxSubScalar,
Module<BT>: VecZnxSubScalar,
@@ -1025,12 +1010,12 @@ where
let cols: usize = 2;
let mut a: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, cols);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
for b_size in [1, 2, 3, 4] {
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, b_size);
b.fill_uniform(basek, &mut source);
b.fill_uniform(base2k, &mut source);
let b_digest: u64 = b.digest_u64();
for res_size in [1, 2, 3, 4] {
@@ -1038,8 +1023,8 @@ where
let mut res_1: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
// Set d to garbage
res_0.fill_uniform(basek, &mut source);
res_1.fill_uniform(basek, &mut source);
res_0.fill_uniform(base2k, &mut source);
res_1.fill_uniform(base2k, &mut source);
// Reference
for i in 0..cols {
@@ -1054,7 +1039,7 @@ where
}
}
pub fn test_vec_znx_sub_scalar_inplace<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_sub_scalar_inplace<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxSubScalarInplace,
Module<BT>: VecZnxSubScalarInplace,
@@ -1066,14 +1051,14 @@ where
let cols: usize = 2;
let mut a: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, cols);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
for res_size in [1, 2, 3, 4] {
let mut res_0: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_1: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
res_0.fill_uniform(basek, &mut source);
res_0.fill_uniform(base2k, &mut source);
res_1.raw_mut().copy_from_slice(res_0.raw());
for i in 0..cols {
@@ -1086,7 +1071,7 @@ where
}
}
pub fn test_vec_znx_sub<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_sub<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxSub,
Module<BT>: VecZnxSub,
@@ -1099,12 +1084,12 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
for b_size in [1, 2, 3, 4] {
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, b_size);
b.fill_uniform(basek, &mut source);
b.fill_uniform(base2k, &mut source);
let b_digest: u64 = b.digest_u64();
for res_size in [1, 2, 3, 4] {
@@ -1112,8 +1097,8 @@ where
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
// Set d to garbage
res_ref.fill_uniform(basek, &mut source);
res_test.fill_uniform(basek, &mut source);
res_ref.fill_uniform(base2k, &mut source);
res_test.fill_uniform(base2k, &mut source);
// Reference
for i in 0..cols {
@@ -1130,10 +1115,10 @@ where
}
}
pub fn test_vec_znx_sub_ab_inplace<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_sub_inplace<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxSubABInplace,
Module<BT>: VecZnxSubABInplace,
Module<BR>: VecZnxSubInplace,
Module<BT>: VecZnxSubInplace,
{
assert_eq!(module_ref.n(), module_test.n());
let n: usize = module_ref.n();
@@ -1143,19 +1128,19 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
for res_size in [1, 2, 3, 4] {
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
res_ref.fill_uniform(basek, &mut source);
res_ref.fill_uniform(base2k, &mut source);
res_test.raw_mut().copy_from_slice(res_ref.raw());
for i in 0..cols {
module_test.vec_znx_sub_ab_inplace(&mut res_ref, i, &a, i);
module_ref.vec_znx_sub_ab_inplace(&mut res_test, i, &a, i);
module_test.vec_znx_sub_inplace(&mut res_ref, i, &a, i);
module_ref.vec_znx_sub_inplace(&mut res_test, i, &a, i);
}
assert_eq!(a.digest_u64(), a_digest);
@@ -1164,10 +1149,10 @@ where
}
}
pub fn test_vec_znx_sub_ba_inplace<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_sub_negate_inplace<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxSubBAInplace,
Module<BT>: VecZnxSubBAInplace,
Module<BR>: VecZnxSubNegateInplace,
Module<BT>: VecZnxSubNegateInplace,
{
assert_eq!(module_ref.n(), module_test.n());
let n: usize = module_ref.n();
@@ -1177,19 +1162,19 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
for res_size in [1, 2, 3, 4] {
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
res_ref.fill_uniform(basek, &mut source);
res_ref.fill_uniform(base2k, &mut source);
res_test.raw_mut().copy_from_slice(res_ref.raw());
for i in 0..cols {
module_test.vec_znx_sub_ba_inplace(&mut res_ref, i, &a, i);
module_ref.vec_znx_sub_ba_inplace(&mut res_test, i, &a, i);
module_test.vec_znx_sub_negate_inplace(&mut res_ref, i, &a, i);
module_ref.vec_znx_sub_negate_inplace(&mut res_test, i, &a, i);
}
assert_eq!(a.digest_u64(), a_digest);
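The two in-place subtraction variants keep their calling convention; only the names change (`vec_znx_sub_ab_inplace` becomes `vec_znx_sub_inplace`, `vec_znx_sub_ba_inplace` becomes `vec_znx_sub_negate_inplace`). A sketch under the same assumptions as the tests above, with comments reflecting the semantics suggested by the new names:

for col in 0..cols {
    // subtract a[col] from res[col] in place
    module.vec_znx_sub_inplace(&mut res, col, &a, col);
    // the negate variant; assumed to produce the negated difference, per its name
    module.vec_znx_sub_negate_inplace(&mut res, col, &a, col);
}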
@@ -1198,7 +1183,7 @@ where
}
}
pub fn test_vec_znx_switch_ring<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_switch_ring<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxSwitchRing,
Module<BT>: VecZnxSwitchRing,
@@ -1213,7 +1198,7 @@ where
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
// Fill a with random i64
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
for res_size in [1, 2, 3, 4] {
@@ -1221,8 +1206,8 @@ where
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n << 1, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n << 1, cols, res_size);
res_ref.fill_uniform(basek, &mut source);
res_test.fill_uniform(basek, &mut source);
res_ref.fill_uniform(base2k, &mut source);
res_test.fill_uniform(base2k, &mut source);
// Normalize on c
for i in 0..cols {
@@ -1238,8 +1223,8 @@ where
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n >> 1, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n >> 1, cols, res_size);
res_ref.fill_uniform(basek, &mut source);
res_test.fill_uniform(basek, &mut source);
res_ref.fill_uniform(base2k, &mut source);
res_test.fill_uniform(base2k, &mut source);
// Normalize on c
for i in 0..cols {

View File

@@ -5,14 +5,14 @@ use crate::{
ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxBigAdd, VecZnxBigAddInplace, VecZnxBigAddSmall, VecZnxBigAddSmallInplace,
VecZnxBigAlloc, VecZnxBigAutomorphism, VecZnxBigAutomorphismInplace, VecZnxBigAutomorphismInplaceTmpBytes,
VecZnxBigFromSmall, VecZnxBigNegate, VecZnxBigNegateInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
VecZnxBigSub, VecZnxBigSubABInplace, VecZnxBigSubBAInplace, VecZnxBigSubSmallA, VecZnxBigSubSmallAInplace,
VecZnxBigSubSmallB, VecZnxBigSubSmallBInplace,
VecZnxBigSub, VecZnxBigSubInplace, VecZnxBigSubNegateInplace, VecZnxBigSubSmallA, VecZnxBigSubSmallB,
VecZnxBigSubSmallInplace, VecZnxBigSubSmallNegateInplace,
},
layouts::{Backend, DataViewMut, DigestU64, FillUniform, Module, ScratchOwned, VecZnx, VecZnxBig},
source::Source,
};
pub fn test_vec_znx_big_add<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_big_add<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>:
VecZnxBigAdd<BR> + VecZnxBigAlloc<BR> + VecZnxBigFromSmall<BR> + VecZnxBigNormalize<BR> + VecZnxBigNormalizeTmpBytes,
@@ -32,7 +32,7 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let a_digest = a.digest_u64();
let mut a_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_big_alloc(cols, a_size);
@@ -50,7 +50,7 @@ where
for b_size in [1, 2, 3, 4] {
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, b_size);
b.fill_uniform(basek, &mut source);
b.fill_uniform(base2k, &mut source);
let b_digest = b.digest_u64();
let mut b_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_big_alloc(cols, b_size);
@@ -93,17 +93,19 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_ref,
j,
base2k,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_test,
j,
base2k,
&res_big_test,
j,
scratch_test.borrow(),
@@ -119,7 +121,7 @@ where
}
}
pub fn test_vec_znx_big_add_inplace<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_big_add_inplace<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxBigAddInplace<BR>
+ VecZnxBigAlloc<BR>
@@ -145,7 +147,7 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let mut a_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_big_alloc(cols, a_size);
let mut a_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_big_alloc(cols, a_size);
@@ -160,7 +162,7 @@ where
for res_size in [1, 2, 3, 4] {
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
res.fill_uniform(basek, &mut source);
res.fill_uniform(base2k, &mut source);
let mut res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_big_alloc(cols, res_size);
let mut res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_big_alloc(cols, res_size);
@@ -186,17 +188,19 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_ref,
j,
base2k,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_test,
j,
base2k,
&res_big_test,
j,
scratch_test.borrow(),
@@ -211,7 +215,7 @@ where
}
}
pub fn test_vec_znx_big_add_small<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_big_add_small<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>:
VecZnxBigAddSmall<BR> + VecZnxBigAlloc<BR> + VecZnxBigFromSmall<BR> + VecZnxBigNormalize<BR> + VecZnxBigNormalizeTmpBytes,
@@ -231,7 +235,7 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let mut a_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_big_alloc(cols, a_size);
let mut a_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_big_alloc(cols, a_size);
@@ -246,7 +250,7 @@ where
for b_size in [1, 2, 3, 4] {
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, b_size);
b.fill_uniform(basek, &mut source);
b.fill_uniform(base2k, &mut source);
let b_digest: u64 = b.digest_u64();
for res_size in [1, 2, 3, 4] {
@@ -275,17 +279,19 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_ref,
j,
base2k,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_test,
j,
base2k,
&res_big_test,
j,
scratch_test.borrow(),
@@ -302,7 +308,7 @@ where
}
pub fn test_vec_znx_big_add_small_inplace<BR: Backend, BT: Backend>(
basek: usize,
base2k: usize,
module_ref: &Module<BR>,
module_test: &Module<BT>,
) where
@@ -330,13 +336,13 @@ pub fn test_vec_znx_big_add_small_inplace<BR: Backend, BT: Backend>(
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
for res_size in [1, 2, 3, 4] {
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
res.fill_uniform(basek, &mut source);
res.fill_uniform(base2k, &mut source);
let mut res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_big_alloc(cols, res_size);
let mut res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_big_alloc(cols, res_size);
@@ -361,17 +367,19 @@ pub fn test_vec_znx_big_add_small_inplace<BR: Backend, BT: Backend>(
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_ref,
j,
base2k,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_test,
j,
base2k,
&res_big_test,
j,
scratch_test.borrow(),
@@ -386,7 +394,7 @@ pub fn test_vec_znx_big_add_small_inplace<BR: Backend, BT: Backend>(
}
}
pub fn test_vec_znx_big_automorphism<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_big_automorphism<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxBigAutomorphism<BR>
+ VecZnxBigAlloc<BR>
@@ -412,7 +420,7 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let mut a_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_big_alloc(cols, a_size);
let mut a_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_big_alloc(cols, a_size);
@@ -451,17 +459,19 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_ref,
j,
base2k,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_test,
j,
base2k,
&res_big_test,
j,
scratch_test.borrow(),
@@ -478,7 +488,7 @@ where
}
pub fn test_vec_znx_big_automorphism_inplace<BR: Backend, BT: Backend>(
basek: usize,
base2k: usize,
module_ref: &Module<BR>,
module_test: &Module<BT>,
) where
@@ -512,7 +522,7 @@ pub fn test_vec_znx_big_automorphism_inplace<BR: Backend, BT: Backend>(
for res_size in [1, 2, 3, 4] {
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
res.fill_uniform(basek, &mut source);
res.fill_uniform(base2k, &mut source);
let mut res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_big_alloc(cols, res_size);
let mut res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_big_alloc(cols, res_size);
@@ -536,17 +546,19 @@ pub fn test_vec_znx_big_automorphism_inplace<BR: Backend, BT: Backend>(
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_ref,
j,
base2k,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_test,
j,
base2k,
&res_big_test,
j,
scratch_test.borrow(),
@@ -561,7 +573,7 @@ pub fn test_vec_znx_big_automorphism_inplace<BR: Backend, BT: Backend>(
}
}
pub fn test_vec_znx_big_negate<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_big_negate<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>:
VecZnxBigNegate<BR> + VecZnxBigAlloc<BR> + VecZnxBigFromSmall<BR> + VecZnxBigNormalize<BR> + VecZnxBigNormalizeTmpBytes,
@@ -581,7 +593,7 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let mut a_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_big_alloc(cols, a_size);
let mut a_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_big_alloc(cols, a_size);
@@ -619,17 +631,19 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_ref,
j,
base2k,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_test,
j,
base2k,
&res_big_test,
j,
scratch_test.borrow(),
@@ -644,7 +658,7 @@ where
}
}
pub fn test_vec_znx_big_negate_inplace<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_big_negate_inplace<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxBigNegateInplace<BR>
+ VecZnxBigAlloc<BR>
@@ -672,7 +686,7 @@ where
for res_size in [1, 2, 3, 4] {
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
res.fill_uniform(basek, &mut source);
res.fill_uniform(base2k, &mut source);
let mut res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_big_alloc(cols, res_size);
let mut res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_big_alloc(cols, res_size);
@@ -695,17 +709,19 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_ref,
j,
base2k,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_test,
j,
base2k,
&res_big_test,
j,
scratch_test.borrow(),
@@ -719,7 +735,7 @@ where
}
}
pub fn test_vec_znx_big_normalize<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_big_normalize<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxBigAlloc<BR>
+ VecZnxBigFromSmall<BR>
@@ -772,8 +788,24 @@ where
// Reference
for j in 0..cols {
module_ref.vec_znx_big_normalize(basek, &mut res_ref, j, &a_ref, j, scratch_ref.borrow());
module_test.vec_znx_big_normalize(basek, &mut res_test, j, &a_test, j, scratch_test.borrow());
module_ref.vec_znx_big_normalize(
base2k,
&mut res_ref,
j,
base2k,
&a_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
base2k,
&mut res_test,
j,
base2k,
&a_test,
j,
scratch_test.borrow(),
);
}
assert_eq!(a_ref.digest_u64(), a_ref_digest);
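As with the small-vector normalize, `vec_znx_big_normalize` now takes the destination base and the source base separately (both `base2k` here). A hypothetical cross-base call, assuming that argument order and a setup like the test above (a big-coefficient accumulator `res_big`, a small output `res_small`, a scratch buffer):

let acc_base2k: usize = 17; // base associated with the accumulator
let out_base2k: usize = 12; // base of the normalized output
for col in 0..cols {
    module.vec_znx_big_normalize(out_base2k, &mut res_small, col, acc_base2k, &res_big, col, scratch.borrow());
}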
@@ -784,7 +816,7 @@ where
}
}
pub fn test_vec_znx_big_sub<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_big_sub<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>:
VecZnxBigSub<BR> + VecZnxBigAlloc<BR> + VecZnxBigFromSmall<BR> + VecZnxBigNormalize<BR> + VecZnxBigNormalizeTmpBytes,
@@ -804,7 +836,7 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let mut a_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_big_alloc(cols, a_size);
let mut a_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_big_alloc(cols, a_size);
@@ -819,7 +851,7 @@ where
for b_size in [1, 2, 3, 4] {
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, b_size);
b.fill_uniform(basek, &mut source);
b.fill_uniform(base2k, &mut source);
let mut b_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_big_alloc(cols, b_size);
let mut b_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_big_alloc(cols, b_size);
@@ -859,17 +891,19 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_ref,
j,
base2k,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_test,
j,
base2k,
&res_big_test,
j,
scratch_test.borrow(),
@@ -885,14 +919,14 @@ where
}
}
pub fn test_vec_znx_big_sub_ab_inplace<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_big_sub_inplace<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxBigSubABInplace<BR>
Module<BR>: VecZnxBigSubInplace<BR>
+ VecZnxBigAlloc<BR>
+ VecZnxBigFromSmall<BR>
+ VecZnxBigNormalize<BR>
+ VecZnxBigNormalizeTmpBytes,
Module<BT>: VecZnxBigSubABInplace<BT>
Module<BT>: VecZnxBigSubInplace<BT>
+ VecZnxBigAlloc<BT>
+ VecZnxBigFromSmall<BT>
+ VecZnxBigNormalize<BT>
@@ -911,7 +945,7 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let mut a_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_big_alloc(cols, a_size);
let mut a_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_big_alloc(cols, a_size);
@@ -926,7 +960,7 @@ where
for res_size in [1, 2, 3, 4] {
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
res.fill_uniform(basek, &mut source);
res.fill_uniform(base2k, &mut source);
let mut res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_big_alloc(cols, res_size);
let mut res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_big_alloc(cols, res_size);
@@ -937,8 +971,8 @@ where
}
for i in 0..cols {
module_ref.vec_znx_big_sub_ab_inplace(&mut res_big_ref, i, &a_ref, i);
module_test.vec_znx_big_sub_ab_inplace(&mut res_big_test, i, &a_test, i);
module_ref.vec_znx_big_sub_inplace(&mut res_big_ref, i, &a_ref, i);
module_test.vec_znx_big_sub_inplace(&mut res_big_test, i, &a_test, i);
}
assert_eq!(a_ref.digest_u64(), a_ref_digest);
@@ -952,17 +986,19 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_ref,
j,
base2k,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_test,
j,
base2k,
&res_big_test,
j,
scratch_test.borrow(),
@@ -977,14 +1013,17 @@ where
}
}
pub fn test_vec_znx_big_sub_ba_inplace<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxBigSubBAInplace<BR>
pub fn test_vec_znx_big_sub_negate_inplace<BR: Backend, BT: Backend>(
base2k: usize,
module_ref: &Module<BR>,
module_test: &Module<BT>,
) where
Module<BR>: VecZnxBigSubNegateInplace<BR>
+ VecZnxBigAlloc<BR>
+ VecZnxBigFromSmall<BR>
+ VecZnxBigNormalize<BR>
+ VecZnxBigNormalizeTmpBytes,
Module<BT>: VecZnxBigSubBAInplace<BT>
Module<BT>: VecZnxBigSubNegateInplace<BT>
+ VecZnxBigAlloc<BT>
+ VecZnxBigFromSmall<BT>
+ VecZnxBigNormalize<BT>
@@ -1003,7 +1042,7 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let mut a_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_big_alloc(cols, a_size);
let mut a_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_big_alloc(cols, a_size);
@@ -1018,7 +1057,7 @@ where
for res_size in [1, 2, 3, 4] {
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
res.fill_uniform(basek, &mut source);
res.fill_uniform(base2k, &mut source);
let mut res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_big_alloc(cols, res_size);
let mut res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_big_alloc(cols, res_size);
@@ -1029,8 +1068,8 @@ where
}
for i in 0..cols {
module_ref.vec_znx_big_sub_ba_inplace(&mut res_big_ref, i, &a_ref, i);
module_test.vec_znx_big_sub_ba_inplace(&mut res_big_test, i, &a_test, i);
module_ref.vec_znx_big_sub_negate_inplace(&mut res_big_ref, i, &a_ref, i);
module_test.vec_znx_big_sub_negate_inplace(&mut res_big_test, i, &a_test, i);
}
assert_eq!(a_ref.digest_u64(), a_ref_digest);
@@ -1044,17 +1083,19 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_ref,
j,
base2k,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_test,
j,
base2k,
&res_big_test,
j,
scratch_test.borrow(),
@@ -1069,7 +1110,7 @@ where
}
}
pub fn test_vec_znx_big_sub_small_a<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_big_sub_small_a<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxBigSubSmallA<BR>
+ VecZnxBigAlloc<BR>
@@ -1095,7 +1136,7 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let mut a_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_big_alloc(cols, a_size);
let mut a_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_big_alloc(cols, a_size);
@@ -1110,7 +1151,7 @@ where
for b_size in [1, 2, 3, 4] {
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, b_size);
b.fill_uniform(basek, &mut source);
b.fill_uniform(base2k, &mut source);
let b_digest: u64 = b.digest_u64();
for res_size in [1, 2, 3, 4] {
@@ -1139,17 +1180,19 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_ref,
j,
base2k,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_test,
j,
base2k,
&res_big_test,
j,
scratch_test.borrow(),
@@ -1165,7 +1208,7 @@ where
}
}
pub fn test_vec_znx_big_sub_small_b<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_big_sub_small_b<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxBigSubSmallB<BR>
+ VecZnxBigAlloc<BR>
@@ -1191,7 +1234,7 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let mut a_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_big_alloc(cols, a_size);
let mut a_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_big_alloc(cols, a_size);
@@ -1206,7 +1249,7 @@ where
for b_size in [1, 2, 3, 4] {
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, b_size);
b.fill_uniform(basek, &mut source);
b.fill_uniform(base2k, &mut source);
let b_digest: u64 = b.digest_u64();
for res_size in [1, 2, 3, 4] {
@@ -1235,17 +1278,19 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_ref,
j,
base2k,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_test,
j,
base2k,
&res_big_test,
j,
scratch_test.borrow(),
@@ -1262,16 +1307,16 @@ where
}
pub fn test_vec_znx_big_sub_small_a_inplace<BR: Backend, BT: Backend>(
basek: usize,
base2k: usize,
module_ref: &Module<BR>,
module_test: &Module<BT>,
) where
Module<BR>: VecZnxBigSubSmallAInplace<BR>
Module<BR>: VecZnxBigSubSmallInplace<BR>
+ VecZnxBigAlloc<BR>
+ VecZnxBigFromSmall<BR>
+ VecZnxBigNormalize<BR>
+ VecZnxBigNormalizeTmpBytes,
Module<BT>: VecZnxBigSubSmallAInplace<BT>
Module<BT>: VecZnxBigSubSmallInplace<BT>
+ VecZnxBigAlloc<BT>
+ VecZnxBigFromSmall<BT>
+ VecZnxBigNormalize<BT>
@@ -1290,13 +1335,13 @@ pub fn test_vec_znx_big_sub_small_a_inplace<BR: Backend, BT: Backend>(
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
for res_size in [1, 2, 3, 4] {
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
res.fill_uniform(basek, &mut source);
res.fill_uniform(base2k, &mut source);
let mut res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_big_alloc(cols, res_size);
let mut res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_big_alloc(cols, res_size);
@@ -1307,8 +1352,8 @@ pub fn test_vec_znx_big_sub_small_a_inplace<BR: Backend, BT: Backend>(
}
for i in 0..cols {
module_ref.vec_znx_big_sub_small_a_inplace(&mut res_big_ref, i, &a, i);
module_test.vec_znx_big_sub_small_a_inplace(&mut res_big_test, i, &a, i);
module_ref.vec_znx_big_sub_small_inplace(&mut res_big_ref, i, &a, i);
module_test.vec_znx_big_sub_small_inplace(&mut res_big_test, i, &a, i);
}
assert_eq!(a.digest_u64(), a_digest);
@@ -1321,17 +1366,19 @@ pub fn test_vec_znx_big_sub_small_a_inplace<BR: Backend, BT: Backend>(
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_ref,
j,
base2k,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_test,
j,
base2k,
&res_big_test,
j,
scratch_test.borrow(),
@@ -1347,16 +1394,16 @@ pub fn test_vec_znx_big_sub_small_a_inplace<BR: Backend, BT: Backend>(
}
pub fn test_vec_znx_big_sub_small_b_inplace<BR: Backend, BT: Backend>(
basek: usize,
base2k: usize,
module_ref: &Module<BR>,
module_test: &Module<BT>,
) where
Module<BR>: VecZnxBigSubSmallBInplace<BR>
Module<BR>: VecZnxBigSubSmallNegateInplace<BR>
+ VecZnxBigAlloc<BR>
+ VecZnxBigFromSmall<BR>
+ VecZnxBigNormalize<BR>
+ VecZnxBigNormalizeTmpBytes,
Module<BT>: VecZnxBigSubSmallBInplace<BT>
Module<BT>: VecZnxBigSubSmallNegateInplace<BT>
+ VecZnxBigAlloc<BT>
+ VecZnxBigFromSmall<BT>
+ VecZnxBigNormalize<BT>
@@ -1375,13 +1422,13 @@ pub fn test_vec_znx_big_sub_small_b_inplace<BR: Backend, BT: Backend>(
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
for res_size in [1, 2, 3, 4] {
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
res.fill_uniform(basek, &mut source);
res.fill_uniform(base2k, &mut source);
let mut res_big_ref: VecZnxBig<Vec<u8>, BR> = module_ref.vec_znx_big_alloc(cols, res_size);
let mut res_big_test: VecZnxBig<Vec<u8>, BT> = module_test.vec_znx_big_alloc(cols, res_size);
@@ -1392,8 +1439,8 @@ pub fn test_vec_znx_big_sub_small_b_inplace<BR: Backend, BT: Backend>(
}
for i in 0..cols {
module_ref.vec_znx_big_sub_small_b_inplace(&mut res_big_ref, i, &a, i);
module_test.vec_znx_big_sub_small_b_inplace(&mut res_big_test, i, &a, i);
module_ref.vec_znx_big_sub_small_negate_inplace(&mut res_big_ref, i, &a, i);
module_test.vec_znx_big_sub_small_negate_inplace(&mut res_big_test, i, &a, i);
}
assert_eq!(a.digest_u64(), a_digest);
@@ -1406,17 +1453,19 @@ pub fn test_vec_znx_big_sub_small_b_inplace<BR: Backend, BT: Backend>(
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_ref,
j,
base2k,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_test,
j,
base2k,
&res_big_test,
j,
scratch_test.borrow(),

View File

@@ -3,14 +3,14 @@ use rand::RngCore;
use crate::{
api::{
ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxBigAlloc, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAdd,
VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftApply, VecZnxDftCopy, VecZnxDftSub, VecZnxDftSubABInplace,
VecZnxDftSubBAInplace, VecZnxIdftApply, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxIdftApplyTmpBytes,
VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftApply, VecZnxDftCopy, VecZnxDftSub, VecZnxDftSubInplace,
VecZnxDftSubNegateInplace, VecZnxIdftApply, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxIdftApplyTmpBytes,
},
layouts::{Backend, DataViewMut, DigestU64, FillUniform, Module, ScratchOwned, VecZnx, VecZnxBig, VecZnxDft},
source::Source,
};
pub fn test_vec_znx_dft_add<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_dft_add<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxDftAdd<BR>
+ VecZnxDftAlloc<BR>
@@ -38,7 +38,7 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let a_digest = a.digest_u64();
let mut a_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, a_size);
@@ -56,7 +56,7 @@ where
for b_size in [1, 2, 3, 4] {
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, b_size);
b.fill_uniform(basek, &mut source);
b.fill_uniform(base2k, &mut source);
let b_digest: u64 = b.digest_u64();
let mut b_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, b_size);
@@ -102,17 +102,19 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_ref,
j,
base2k,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_test,
j,
base2k,
&res_big_test,
j,
scratch_test.borrow(),
@@ -128,7 +130,7 @@ where
}
}
pub fn test_vec_znx_dft_add_inplace<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_dft_add_inplace<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxDftAddInplace<BR>
+ VecZnxDftAlloc<BR>
@@ -155,7 +157,7 @@ where
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let a_digest = a.digest_u64();
let mut a_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, a_size);
@@ -173,7 +175,7 @@ where
for res_size in [1, 2, 3, 4] {
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
res.fill_uniform(basek, &mut source);
res.fill_uniform(base2k, &mut source);
let res_digest: u64 = res.digest_u64();
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, res_size);
@@ -206,17 +208,19 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_ref,
j,
base2k,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_test,
j,
base2k,
&res_big_test,
j,
scratch_test.borrow(),
@@ -231,7 +235,7 @@ where
}
}
pub fn test_vec_znx_copy<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_copy<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxDftCopy<BR>
+ VecZnxDftAlloc<BR>
@@ -259,7 +263,7 @@ where
for a_size in [1, 2, 6, 11] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let a_digest = a.digest_u64();
let mut a_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, a_size);
@@ -307,17 +311,19 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_ref,
j,
base2k,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_test,
j,
base2k,
&res_big_test,
j,
scratch_test.borrow(),
@@ -333,7 +339,7 @@ where
}
}
pub fn test_vec_znx_idft_apply<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_idft_apply<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxDftApply<BR>
+ VecZnxDftAlloc<BR>
@@ -361,7 +367,7 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
for res_size in [1, 2, 3, 4] {
@@ -406,17 +412,19 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_ref,
j,
base2k,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_test,
j,
base2k,
&res_big_test,
j,
scratch_test.borrow(),
@@ -432,7 +440,7 @@ where
}
}
pub fn test_vec_znx_idft_apply_tmpa<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_idft_apply_tmpa<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxDftApply<BR>
+ VecZnxDftAlloc<BR>
@@ -460,7 +468,7 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
for res_size in [1, 2, 3, 4] {
@@ -494,17 +502,19 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_ref,
j,
base2k,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_test,
j,
base2k,
&res_big_test,
j,
scratch_test.borrow(),
@@ -520,7 +530,7 @@ where
}
}
pub fn test_vec_znx_idft_apply_consume<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_idft_apply_consume<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxDftApply<BR>
+ VecZnxIdftApplyTmpBytes
@@ -550,7 +560,7 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
for res_size in [1, 2, 3, 4] {
@@ -579,17 +589,19 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_ref,
j,
base2k,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_test,
j,
base2k,
&res_big_test,
j,
scratch_test.borrow(),
@@ -605,7 +617,7 @@ where
}
}
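Each of these tests also snapshots its inputs with a digest and re-checks it afterwards; a minimal standalone sketch of that guard, assuming a is any layout exposing digest_u64:
let a_digest: u64 = a.digest_u64();   // snapshot before the operation under test
// ... run the DFT / IDFT / normalization calls on other buffers ...
assert_eq!(a.digest_u64(), a_digest); // the operation must not have mutated a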
pub fn test_vec_znx_dft_sub<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_dft_sub<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxDftSub<BR>
+ VecZnxDftAlloc<BR>
@@ -633,7 +645,7 @@ where
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let a_digest = a.digest_u64();
let mut a_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, a_size);
@@ -651,7 +663,7 @@ where
for b_size in [1, 2, 3, 4] {
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, b_size);
b.fill_uniform(basek, &mut source);
b.fill_uniform(base2k, &mut source);
let b_digest: u64 = b.digest_u64();
let mut b_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, b_size);
@@ -697,17 +709,19 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_ref,
j,
base2k,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_test,
j,
base2k,
&res_big_test,
j,
scratch_test.borrow(),
@@ -723,15 +737,15 @@ where
}
}
pub fn test_vec_znx_dft_sub_ab_inplace<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vec_znx_dft_sub_inplace<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxDftSubABInplace<BR>
Module<BR>: VecZnxDftSubInplace<BR>
+ VecZnxDftAlloc<BR>
+ VecZnxDftApply<BR>
+ VecZnxIdftApplyConsume<BR>
+ VecZnxBigNormalize<BR>
+ VecZnxBigNormalizeTmpBytes,
Module<BT>: VecZnxDftSubABInplace<BT>
Module<BT>: VecZnxDftSubInplace<BT>
+ VecZnxDftAlloc<BT>
+ VecZnxDftApply<BT>
+ VecZnxIdftApplyConsume<BT>
@@ -750,7 +764,7 @@ where
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let a_digest = a.digest_u64();
let mut a_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, a_size);
@@ -768,7 +782,7 @@ where
for res_size in [1, 2, 3, 4] {
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
res.fill_uniform(basek, &mut source);
res.fill_uniform(base2k, &mut source);
let res_digest: u64 = res.digest_u64();
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, res_size);
@@ -783,8 +797,8 @@ where
// Reference
for i in 0..cols {
module_ref.vec_znx_dft_sub_ab_inplace(&mut res_dft_ref, i, &a_dft_ref, i);
module_test.vec_znx_dft_sub_ab_inplace(&mut res_dft_test, i, &a_dft_test, i);
module_ref.vec_znx_dft_sub_inplace(&mut res_dft_ref, i, &a_dft_ref, i);
module_test.vec_znx_dft_sub_inplace(&mut res_dft_test, i, &a_dft_test, i);
}
assert_eq!(a_dft_ref.digest_u64(), a_dft_ref_digest);
@@ -801,17 +815,19 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_ref,
j,
base2k,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_test,
j,
base2k,
&res_big_test,
j,
scratch_test.borrow(),
@@ -826,15 +842,18 @@ where
}
}
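A short sketch contrasting the two in-place DFT subtractions exercised here and in the next test, with operands assumed to be set up as in the test bodies; the per-column formulas are inferred from the method names rather than stated in this diff:
for i in 0..cols {
    // res_dft_ref[i] -= a_dft_ref[i]
    module_ref.vec_znx_dft_sub_inplace(&mut res_dft_ref, i, &a_dft_ref, i);
}
for i in 0..cols {
    // res_dft_ref[i] = a_dft_ref[i] - res_dft_ref[i] (the "negate" variant swaps the operand roles)
    module_ref.vec_znx_dft_sub_negate_inplace(&mut res_dft_ref, i, &a_dft_ref, i);
}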
pub fn test_vec_znx_dft_sub_ba_inplace<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: VecZnxDftSubBAInplace<BR>
pub fn test_vec_znx_dft_sub_negate_inplace<BR: Backend, BT: Backend>(
base2k: usize,
module_ref: &Module<BR>,
module_test: &Module<BT>,
) where
Module<BR>: VecZnxDftSubNegateInplace<BR>
+ VecZnxDftAlloc<BR>
+ VecZnxDftApply<BR>
+ VecZnxIdftApplyConsume<BR>
+ VecZnxBigNormalize<BR>
+ VecZnxBigNormalizeTmpBytes,
Module<BT>: VecZnxDftSubBAInplace<BT>
Module<BT>: VecZnxDftSubNegateInplace<BT>
+ VecZnxDftAlloc<BT>
+ VecZnxDftApply<BT>
+ VecZnxIdftApplyConsume<BT>
@@ -853,7 +872,7 @@ where
let mut scratch_test: ScratchOwned<BT> = ScratchOwned::alloc(module_test.vec_znx_big_normalize_tmp_bytes());
for a_size in [1, 2, 3, 4] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let a_digest = a.digest_u64();
let mut a_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, a_size);
@@ -871,7 +890,7 @@ where
for res_size in [1, 2, 3, 4] {
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
res.fill_uniform(basek, &mut source);
res.fill_uniform(base2k, &mut source);
let res_digest: u64 = res.digest_u64();
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols, res_size);
@@ -886,8 +905,8 @@ where
// Reference
for i in 0..cols {
module_ref.vec_znx_dft_sub_ba_inplace(&mut res_dft_ref, i, &a_dft_ref, i);
module_test.vec_znx_dft_sub_ba_inplace(&mut res_dft_test, i, &a_dft_test, i);
module_ref.vec_znx_dft_sub_negate_inplace(&mut res_dft_ref, i, &a_dft_ref, i);
module_test.vec_znx_dft_sub_negate_inplace(&mut res_dft_test, i, &a_dft_test, i);
}
assert_eq!(a_dft_ref.digest_u64(), a_dft_ref_digest);
@@ -904,17 +923,19 @@ where
for j in 0..cols {
module_ref.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_ref,
j,
base2k,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_test,
j,
base2k,
&res_big_test,
j,
scratch_test.borrow(),

View File

@@ -11,7 +11,7 @@ use rand::RngCore;
use crate::layouts::{Backend, VecZnxDft, VmpPMat};
pub fn test_vmp_apply_dft<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vmp_apply_dft<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: ModuleNew<BR>
+ VmpApplyDftTmpBytes
@@ -53,11 +53,11 @@ where
let rows: usize = cols_in;
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols_in, size_in);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
let mut mat: MatZnx<Vec<u8>> = MatZnx::alloc(n, rows, cols_in, cols_out, size_out);
mat.fill_uniform(basek, &mut source);
mat.fill_uniform(base2k, &mut source);
let mat_digest: u64 = mat.digest_u64();
let mut pmat_ref: VmpPMat<Vec<u8>, BR> = module_ref.vmp_pmat_alloc(rows, cols_in, cols_out, size_out);
@@ -90,17 +90,19 @@ where
for j in 0..cols_out {
module_ref.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_ref,
j,
base2k,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_test,
j,
base2k,
&res_big_test,
j,
scratch_test.borrow(),
@@ -117,7 +119,7 @@ where
}
}
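A minimal sketch of how these vmp tests allocate the plain matrix and its prepared counterpart with matching shapes, assuming n, rows, cols_in, cols_out, size_out, base2k, source and the generic backend BR as in the test above; the preparation call itself is elided from this excerpt:
let mut mat: MatZnx<Vec<u8>> = MatZnx::alloc(n, rows, cols_in, cols_out, size_out);
mat.fill_uniform(base2k, &mut source); // random test matrix
let mut pmat_ref: VmpPMat<Vec<u8>, BR> = module_ref.vmp_pmat_alloc(rows, cols_in, cols_out, size_out); // prepared form, same shape
// mat is then prepared into pmat_ref before the vmp apply is run against both backends.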
pub fn test_vmp_apply_dft_to_dft<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vmp_apply_dft_to_dft<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: ModuleNew<BR>
+ VmpApplyDftToDftTmpBytes
@@ -162,7 +164,7 @@ where
let rows: usize = size_in;
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols_in, size_in);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
let mut a_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols_in, size_in);
@@ -176,7 +178,7 @@ where
assert_eq!(a.digest_u64(), a_digest);
let mut mat: MatZnx<Vec<u8>> = MatZnx::alloc(n, rows, cols_in, cols_out, size_out);
mat.fill_uniform(basek, &mut source);
mat.fill_uniform(base2k, &mut source);
let mat_digest: u64 = mat.digest_u64();
let mut pmat_ref: VmpPMat<Vec<u8>, BR> = module_ref.vmp_pmat_alloc(rows, cols_in, cols_out, size_out);
@@ -217,17 +219,19 @@ where
for j in 0..cols_out {
module_ref.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_ref,
j,
base2k,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_test,
j,
base2k,
&res_big_test,
j,
scratch_test.borrow(),
@@ -244,7 +248,7 @@ where
}
}
pub fn test_vmp_apply_dft_to_dft_add<BR: Backend, BT: Backend>(basek: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
pub fn test_vmp_apply_dft_to_dft_add<BR: Backend, BT: Backend>(base2k: usize, module_ref: &Module<BR>, module_test: &Module<BT>)
where
Module<BR>: ModuleNew<BR>
+ VmpApplyDftToDftAddTmpBytes
@@ -289,7 +293,7 @@ where
let rows: usize = size_in;
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols_in, size_in);
a.fill_uniform(basek, &mut source);
a.fill_uniform(base2k, &mut source);
let a_digest: u64 = a.digest_u64();
let mut a_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols_in, size_in);
@@ -303,7 +307,7 @@ where
assert_eq!(a.digest_u64(), a_digest);
let mut mat: MatZnx<Vec<u8>> = MatZnx::alloc(n, rows, cols_in, cols_out, size_out);
mat.fill_uniform(basek, &mut source);
mat.fill_uniform(base2k, &mut source);
let mat_digest: u64 = mat.digest_u64();
let mut pmat_ref: VmpPMat<Vec<u8>, BR> = module_ref.vmp_pmat_alloc(rows, cols_in, cols_out, size_out);
@@ -316,7 +320,7 @@ where
for limb_offset in 0..size_out {
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols_out, size_out);
res.fill_uniform(basek, &mut source);
res.fill_uniform(base2k, &mut source);
let res_digest: u64 = res.digest_u64();
let mut res_dft_ref: VecZnxDft<Vec<u8>, BR> = module_ref.vec_znx_dft_alloc(cols_out, size_out);
@@ -355,17 +359,19 @@ where
for j in 0..cols_out {
module_ref.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_ref,
j,
base2k,
&res_big_ref,
j,
scratch_ref.borrow(),
);
module_test.vec_znx_big_normalize(
basek,
base2k,
&mut res_small_test,
j,
base2k,
&res_big_test,
j,
scratch_test.borrow(),